from typing import List, Union
import random
import os
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoModel, AutoTokenizer
from transformers import GenerationConfig
import torch
import json
from torch import Tensor
from tqdm.auto import tqdm
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from .base_model import ShopBenchBaseModel

# Seed used by the models below for their random number generator.
# Taken from the AICROWD_RUN_SEED environment variable when it is set,
# otherwise the fixed default 3142 keeps runs reproducible.
AICROWD_RUN_SEED = int(os.environ.get("AICROWD_RUN_SEED", 3142))


class DummyModel(ShopBenchBaseModel):
    """
    Random-answer baseline for ShopBench.

    It shows the expected shape of a model: a no-argument constructor and a
    ``predict`` method that returns a string for both multiple choice tasks and
    the other task types (Ranking, Retrieval, Named Entity Recognition, generation).
    The random generator is seeded so that repeated runs give the same answers.
    """

    def __init__(self):
        """Seeds Python's global random generator with ``AICROWD_RUN_SEED``."""
        random.seed(AICROWD_RUN_SEED)

    def predict(self, prompt: str, is_multiple_choice: bool) -> str:
        """
        Produces a random answer for the given prompt.

        The prompt text itself is ignored; only the task type influences the
        shape of the answer.

        Args:
            prompt (str): The input prompt for the model (unused by this baseline).
            is_multiple_choice (bool): Whether the task is a multiple choice question.

        Returns:
            str: For multiple choice tasks, one randomly chosen value out of
                        1, 2, 3, 4 rendered as a string.
                        For every other task, the string form of the shuffled list
                        of those four values (for example "[3, 1, 4, 2]").
                        Please refer to parsers.py for more details on how these responses will be parsed by the evaluator.
        """
        candidates = [1, 2, 3, 4]

        if not is_multiple_choice:
            # Non multiple choice tasks: hand back all candidates in random order.
            # Note: this is only a placeholder answer. A real model should return
            # an unconstrained string for generation tasks.
            random.shuffle(candidates)
            return str(candidates)

        # Multiple choice tasks: pick exactly one candidate.
        return str(random.choice(candidates))


class llama3_8b_FewShot(ShopBenchBaseModel):
    """
    Llama-3-8B-Instruct ShopBench model with retrieved few-shot examples.

    On construction the class loads a sentence embedding model, a serialized
    FAISS index and the few-shot example texts from ``./models``, followed by
    the Llama-3 tokenizer and model. For prompts that are not multiple choice,
    ``predict`` retrieves up to ``TOP_K`` examples whose similarity score is at
    least ``MIN_SIMILARITY`` and places them in front of the question.
    """

    # Number of neighbours requested from the index per prompt.
    TOP_K = 3
    # Minimum similarity score a retrieved example needs to be used.
    MIN_SIMILARITY = 0.85
    # Generation budgets: a multiple choice answer is a single token.
    MAX_NEW_TOKENS_MULTIPLE_CHOICE = 1
    MAX_NEW_TOKENS_DEFAULT = 128
    # Prefix stored in front of every few-shot example text; it is stripped
    # again before an example is placed into the prompt.
    _PASSAGE_PREFIX = 'passage: '
    # Prefix put in front of the prompt before it is embedded for retrieval.
    _QUERY_PREFIX = 'query: '

    def __init__(self):
        """Seeds the RNG, loads the retrieval resources, then the tokenizer and model."""
        random.seed(AICROWD_RUN_SEED)

        self.build_vector_database()

        model_path = './models/Meta-Llama-3-8B-Instruct'
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            device_map='auto',
            trust_remote_code=True,
        )
        self.system_prompt = (
            "You are a helpful and multilingual online shopping assistant. "
            "You can understand and respond to user queries in English, German, Italian, French, Japanese, Spanish, Chinese. "
            "You are knowledgeable about various products. "
            "NOTE:ONLY OUTPUT THE ANSWER!!\n"
        )
        # Generation stops on the tokenizer's EOS token or on "<|eot_id|>".
        self.terminators = [
            self.tokenizer.eos_token_id,
            self.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]

    def average_pool(self, last_hidden_states: Tensor,
                     attention_mask: Tensor) -> Tensor:
        """
        Averages hidden states over dimension 1, ignoring masked-out positions.

        Positions where ``attention_mask`` is zero are set to 0.0 before the sum,
        and the sum is divided by the number of unmasked positions.
        Not called by the other methods of this class.

        Args:
            last_hidden_states (Tensor): Hidden states to pool
                (assumed (batch, seq, dim) - TODO confirm).
            attention_mask (Tensor): Mask with non-zero entries for positions to keep
                (assumed (batch, seq) - TODO confirm).

        Returns:
            Tensor: The masked mean over dimension 1.
        """
        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    def build_vector_database(self):
        """
        Loads everything needed for few-shot retrieval.

        Sets ``self.embed_model`` (sentence embedding model), ``self.index``
        (FAISS index read from ``./models/index.ivf``),
        ``self.few_shot_example_text`` (one prefixed passage per example) and
        ``self.metadata`` (one dict per example, in file order).

        Raises:
            OSError: If the example file cannot be opened.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
            KeyError: If an example lacks the 'instruction' or 'output' field.
        """
        self.embed_model = SentenceTransformer("./models/intfloat-multilingual-e5-large")
        self.few_shot_example_text = []
        # Never filled in this class; kept so the attribute still exists.
        self.fewshot_embeddings = []

        with open('./models/sample_example1.jsonl', 'r', encoding='utf8') as f:
            # Stream the file line by line instead of loading it all at once.
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    # A blank (e.g. trailing) line would make json.loads fail.
                    continue
                t_data = json.loads(line)
                if "input" in t_data:
                    passage = t_data['instruction'] + t_data['input'] + '\nOutput:' + str(t_data['output']) + '\n'
                else:
                    passage = t_data['instruction'] + str(t_data['output']) + '\n'
                # The examples store line breaks as a literal backslash + 'n'.
                passage = passage.replace('\\n', '\n')
                self.few_shot_example_text.append(self._PASSAGE_PREFIX + passage)

        # The index is loaded from disk, so no empty index is built first.
        # NOTE(review): earlier code set nprobe = 3 on a throw-away index that
        # this read replaced, so that value never applied to searches; the
        # loaded index uses whatever nprobe it carries - confirm that is intended.
        self.index = faiss.read_index("./models/index.ivf")
        # Position i is expected to describe vector i of the index
        # (presumably both were built from the same file - verify).
        self.metadata = [{"fewshot_examaple": fewshot_examaple} for fewshot_examaple in self.few_shot_example_text]

    def _retrieve_examples(self, prompt: str) -> List[str]:
        """Returns the texts of stored examples similar enough to ``prompt``, best first."""
        query_embed = self.embed_model.encode([self._QUERY_PREFIX + prompt])[0]
        scores, indices = self.index.search(
            np.array([query_embed]).astype(np.float32), self.TOP_K
        )

        examples = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS reports index -1 when it found fewer than TOP_K neighbours;
            # such an entry must not wrap around to the last stored example.
            if score >= self.MIN_SIMILARITY and idx >= 0:
                fewshot_examaple = self.metadata[idx]["fewshot_examaple"]
                examples.append(fewshot_examaple[len(self._PASSAGE_PREFIX):])
        return examples

    def predict(self, prompt: str, is_multiple_choice: bool) -> str:
        """
        Generates the answer for a prompt with greedy decoding.

        Multiple choice prompts are sent to the model as they are and the answer
        is limited to a single new token. Other prompts are preceded by the
        retrieved few-shot examples (when any pass the similarity threshold)
        and may produce up to 128 new tokens. The full prompt and the response
        are printed to stdout.

        Args:
            prompt (str): The input prompt for the model.
            is_multiple_choice (bool): Indicates whether the task is a multiple choice question.

        Returns:
            str: The decoded newly generated tokens, without special tokens.
        """
        # Retrieval results are only used for non multiple choice prompts, so
        # the embedding and index search are skipped for multiple choice.
        examples = [] if is_multiple_choice else self._retrieve_examples(prompt)

        if examples:
            user_content = 'Here are some similar questions and answers you can refer to:\n'
            user_content += ''.join(example + '\n' for example in examples)
            user_content += '\nQuestion:' + prompt
        else:
            user_content = '\n' + prompt
        print(self.system_prompt + user_content)

        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": user_content},
        ]
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(self.model.device)

        if is_multiple_choice:
            max_new_tokens = self.MAX_NEW_TOKENS_MULTIPLE_CHOICE
        else:
            max_new_tokens = self.MAX_NEW_TOKENS_DEFAULT
        outputs = self.model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            eos_token_id=self.terminators,
            do_sample=False,
        )

        # The output starts with the prompt tokens; keep only what was generated.
        generated_ids = outputs[0][input_ids.shape[-1]:]
        response = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        print(response)
        return response