from typing import List, Union
import random
import os
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoModel, AutoTokenizer
from transformers import GenerationConfig
import torch
import json
from torch import Tensor
from tqdm.auto import tqdm
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from .base_model import ShopBenchBaseModel

# Seed used by the models in this module so that runs are repeatable. The
# AICROWD_RUN_SEED environment variable overrides the default of 3142.
AICROWD_RUN_SEED = int(os.environ.get("AICROWD_RUN_SEED", 3142))


class DummyModel(ShopBenchBaseModel):
    """
    Baseline ShopBench model that answers every prompt with random integers.

    It only illustrates the interface a model has to provide: one answer
    format for multiple choice questions and another one for all remaining
    task types (Ranking, Retrieval, Named Entity Recognition, generation).
    Python's ``random`` module is seeded on construction so that the sequence
    of answers is repeatable.
    """

    def __init__(self):
        """Seeds the ``random`` module with ``AICROWD_RUN_SEED``."""
        random.seed(AICROWD_RUN_SEED)

    def predict(self, prompt: str, is_multiple_choice: bool) -> str:
        """
        Produces a random answer whose format depends on the task type.

        The prompt itself is never inspected.

        Args:
            prompt (str): The input prompt for the model (unused here).
            is_multiple_choice (bool): Whether the task is a multiple choice question.

        Returns:
            str: For multiple choice tasks, a single integer drawn from 1..4 as a string.
                 For every other task, the string form of the list [1, 2, 3, 4] after
                 shuffling, e.g. "[3, 1, 4, 2]".
                 A real model would return a comma separated list of integers for
                 Ranking/Retrieval, a comma separated list of entities for Named Entity
                 Recognition, and free text for generation tasks.
                 Please refer to parsers.py for how the evaluator parses these responses.
        """
        candidates = [1, 2, 3, 4]

        if not is_multiple_choice:
            # Dummy behaviour: every non-multiple-choice task gets a random
            # permutation. Generation tasks would ideally return an
            # unconstrained string instead.
            random.shuffle(candidates)
            return str(candidates)

        # Multiple choice: pick exactly one of the candidate answers.
        return str(random.choice(candidates))


class llama3_8b_FewShot(ShopBenchBaseModel):
    """
    Llama-3-8B-Instruct with retrieval-augmented few-shot prompting.

    On construction the few-shot examples are read from
    ``./models/sample_example1.jsonl`` and a pre-built FAISS index is loaded
    from ``./models/index.ivf``. For every prompt the closest examples are
    looked up in the index, and those whose score reaches
    ``SIMILARITY_THRESHOLD`` are prepended to the prompt before generation.
    """

    # Minimum retrieval score an example needs to be added to the prompt.
    SIMILARITY_THRESHOLD = 0.896
    # Number of neighbours requested from the index for each query.
    TOP_K = 3

    def __init__(self):
        """Seeds ``random``, loads the retrieval components and the LLM."""
        random.seed(AICROWD_RUN_SEED)

        self.build_vector_database()

        model_path = './models/Meta-Llama-3-8B-Instruct'
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            device_map='auto',
            trust_remote_code=True,
        )
        self.system_prompt = (
            "You are a helpful and multilingual online shopping assistant. "
            "You can understand and respond to user queries in English, German, Italian, French, Japanese, Spanish, Chinese. "
            "You are knowledgeable about various products. "
            "NOTE:ONLY OUTPUT THE ANSWER!!\n"
        )
        # Generation stops on either the regular EOS token or the
        # "<|eot_id|>" token.
        self.terminators = [
            self.tokenizer.eos_token_id,
            self.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]

    def get_detailed_instruct(self, task_description: str, query: str) -> str:
        """Formats a task description and a query as 'Instruct: ...\\nQuery: ...'."""
        return f'Instruct: {task_description}\nQuery: {query}'

    def build_vector_database(self) -> None:
        """
        Loads the embedding model, the few-shot examples and the FAISS index.

        Sets ``self.embed_model``, ``self.few_shot_example_text``,
        ``self.fewshot_embeddings`` (left empty; embeddings live in the index),
        ``self.index`` and ``self.metadata``. Entry ``i`` of ``self.metadata``
        is expected to describe vector ``i`` of the index.
        """
        self.embed_model = SentenceTransformer("./models/multilingual-e5-large-instruct")

        self.few_shot_example_text = []
        self.fewshot_embeddings = []
        with open('./models/sample_example1.jsonl', 'r', encoding='utf8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank lines instead of failing in json.loads.
                    continue
                t_data = json.loads(line)
                if "input" in t_data:
                    passage = t_data['instruction'] + t_data['input'] + '\nOutput:' + str(t_data['output']) + '\n'
                else:
                    passage = t_data['instruction'] + str(t_data['output']) + '\n'
                # The data stores line breaks as the two characters "\n".
                passage = passage.replace('\\n', '\n')
                self.few_shot_example_text.append(passage)

        # NOTE: the index file is built offline. According to the original
        # build recipe it is an inner-product faiss.IndexIVFFlat (dim=1024,
        # nlist=1024, nprobe=3) trained on and filled with
        # embed_model.encode(few_shot_example_text) and saved with
        # faiss.write_index(index, "./models/index.ivf").
        self.index = faiss.read_index("./models/index.ivf")
        self.metadata = [{"fewshot_examaple": text} for text in self.few_shot_example_text]

    def _retrieve_examples(self, prompt: str) -> List[str]:
        """Returns the few-shot examples whose retrieval score passes the threshold."""
        task_description = "Given a online shopping user query, retrieve relevant Question-Answer that similar to the query."
        query_text = ' ' + prompt
        query_embed = self.embed_model.encode([self.get_detailed_instruct(task_description, query_text)])[0]
        scores, indices = self.index.search(np.array([query_embed]).astype(np.float32), self.TOP_K)

        examples = []
        for score, idx in zip(scores[0], indices[0]):
            # faiss pads missing neighbours with label -1; a negative label
            # would otherwise silently select the last example, and a label
            # beyond the metadata would raise IndexError.
            if idx < 0 or idx >= len(self.metadata):
                continue
            example = self.metadata[idx]["fewshot_examaple"]
            print(f'score:{score} meta data:{example}')
            if score >= self.SIMILARITY_THRESHOLD:
                examples.append(example)
        return examples

    def _build_user_content(self, prompt: str, examples: List[str]) -> str:
        """Builds the text that follows the system prompt (examples + question)."""
        if not examples:
            return '\nNow answer the Question:' + prompt
        header = 'Here are some similar questions and answers you can refer to:\n'
        shots = ''.join(example + '\n' for example in examples)
        return header + shots + 'Now answer the Question:' + prompt

    def predict(self, prompt: str, is_multiple_choice: bool) -> str:
        """
        Answers a prompt, optionally augmented with retrieved few-shot examples.

        Multiple choice prompts are fed to the model as plain text (system
        prompt followed by the user content, no chat template) and exactly one
        new token is generated. All other prompts go through the tokenizer's
        chat template and generate up to 200 new tokens without sampling.

        Args:
            prompt (str): The input prompt for the model.
            is_multiple_choice (bool): Whether the task is a multiple choice question.

        Returns:
            str: The decoded text of the newly generated tokens only.
        """
        examples = self._retrieve_examples(prompt)
        user_content = self._build_user_content(prompt, examples)
        prompt_example = self.system_prompt + user_content
        print(prompt_example)

        if is_multiple_choice:
            # todo add retrive few shot example
            input_ids = self.tokenizer.encode(
                prompt_example,
                add_special_tokens=False,
                return_tensors="pt",
            ).to(self.model.device)
            outputs = self.model.generate(
                input_ids,
                max_new_tokens=1,
                temperature=0.1,
                eos_token_id=self.terminators,
            )
            # Slice by token count rather than by len(prompt_example): the
            # decoded prompt is not guaranteed to match the original string
            # character for character.
            new_tokens = outputs[0][input_ids.shape[-1]:]
            generation = self.tokenizer.decode(
                new_tokens,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
        else:
            messages = [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": user_content},
            ]
            input_ids = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                return_tensors="pt"
            ).to(self.model.device)
            outputs = self.model.generate(
                input_ids,
                max_new_tokens=200,
                eos_token_id=self.terminators,
                do_sample=False,
            )
            new_tokens = outputs[0][input_ids.shape[-1]:]
            generation = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

        print(f'model generate answer : {generation}')
        return generation