import json
import os
import random

import faiss
import numpy as np
import torch
from torch import Tensor
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

from .base_model import ShopBenchBaseModel

# Set a consistent seed for reproducibility
AICROWD_RUN_SEED = int(os.getenv("AICROWD_RUN_SEED", 3142))


class DummyModel(ShopBenchBaseModel):
    """
    A dummy model implementation for ShopBench, illustrating how to handle both
    multiple-choice questions and other task types such as Ranking, Retrieval,
    and Named Entity Recognition.
    This model uses a consistent random seed for reproducible results.
    """

    def __init__(self):
        """Initializes the model and sets the random seed for consistency."""
        random.seed(AICROWD_RUN_SEED)
    def predict(self, prompt: str, is_multiple_choice: bool) -> str:
        """
        Generates a prediction based on the input prompt and task type.

        For multiple choice tasks, it randomly selects a choice.
        For other tasks, it returns a list of integers as a string,
        representing the model's prediction in a format compatible with task-specific parsers.

        Args:
            prompt (str): The input prompt for the model.
            is_multiple_choice (bool): Indicates whether the task is a multiple choice question.

        Returns:
            str: The prediction as a string representing a single integer in [0, 3] for multiple choice tasks,
                a comma-separated list of integers for Ranking and Retrieval tasks,
                a comma-separated list of named entities for Named Entity Recognition tasks,
                or the (unconstrained) generated response for generation tasks.
                Please refer to parsers.py for more details on how these responses will be parsed by the evaluator.
        """
        possible_responses = [1, 2, 3, 4]

        if is_multiple_choice:
            # Randomly select one of the possible responses for multiple choice tasks
            return str(random.choice(possible_responses))
        else:
            # For other tasks, shuffle the possible responses and return them as a string.
            # Note: as this is a dummy model, we return random responses for non-multiple-choice tasks.
            # For generation tasks, this should ideally return an unconstrained string.
            random.shuffle(possible_responses)
            return str(possible_responses)
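
# A minimal usage sketch for DummyModel (illustrative prompts; the evaluation
# harness, not shown here, is what actually calls predict):
#
#     model = DummyModel()
#     model.predict("Which of these ...? 0. A 1. B 2. C 3. D", is_multiple_choice=True)   # e.g. "2"
#     model.predict("Rank the following products ...", is_multiple_choice=False)          # e.g. "[3, 1, 4, 2]"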

class llama3_8b_FewShot(ShopBenchBaseModel):
    """A Llama-3-8B-Instruct model that augments prompts with similar few-shot
    examples retrieved from a FAISS index of multilingual-e5-large embeddings."""

    def __init__(self):
        random.seed(AICROWD_RUN_SEED)

        self.build_vector_database()

        model_path = './models/Meta-Llama-3-8B-Instruct'
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map='auto', trust_remote_code=True
        )
        self.system_prompt = (
            "You are a helpful and multilingual online shopping assistant. "
            "You can understand and respond to user queries in English, German, Italian, "
            "French, Japanese, Spanish, Chinese. You are knowledgeable about various products. "
            "NOTE:ONLY OUTPUT THE ANSWER!!\n"
        )
        # Stop generation at either the model's EOS token or Llama 3's end-of-turn token.
        self.terminators = [
            self.tokenizer.eos_token_id,
            self.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]

    def average_pool(self, last_hidden_states: Tensor,
                     attention_mask: Tensor) -> Tensor:
        """Mean-pools token embeddings, ignoring padding positions (E5-style pooling)."""
        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
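
    # A minimal sketch of how average_pool would be used to embed texts with the
    # raw transformers API instead of SentenceTransformer. The encoder and its
    # tokenizer are supplied by the caller (they are assumptions; __init__ above
    # does not load them), e.g. AutoModel/AutoTokenizer for
    # intfloat/multilingual-e5-large.
    def encode_with_average_pool(self, encoder, encoder_tokenizer, texts):
        batch = encoder_tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            outputs = encoder(**batch)
        embeddings = self.average_pool(outputs.last_hidden_state, batch["attention_mask"])
        # E5 embeddings are compared by inner product, so L2-normalise first.
        return torch.nn.functional.normalize(embeddings, p=2, dim=1).cpu().numpy()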

    def build_vector_database(self):
        """Loads the few-shot example passages and the prebuilt FAISS index."""
        self.embed_model = SentenceTransformer("./models/intfloat-multilingual-e5-large")

        # Load the few-shot examples. The "passage: " / "query: " prefixes follow
        # the E5 convention for documents and queries respectively.
        self.few_shot_example_text = []
        with open('./models/sample_example1.jsonl', 'r', encoding='utf8') as f:
            for line in f.readlines():
                t_data = json.loads(line.strip())
                if "input" in t_data:
                    passage = t_data['instruction'] + t_data['input'] + '\nOutput:' + str(t_data['output']) + '\n'
                else:
                    passage = t_data['instruction'] + str(t_data['output']) + '\n'
                passage = passage.replace('\\n', '\n')
                self.few_shot_example_text.append('passage: ' + passage)

        # Load the prebuilt IVF index (1024-dimensional inner-product embeddings
        # from intfloat/multilingual-e5-large) and probe 3 clusters per query.
        self.index = faiss.read_index("./models/index.ivf")
        self.index.nprobe = 3
        self.metadata = [{"fewshot_example": text} for text in self.few_shot_example_text]
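
    # A hedged offline sketch of how ./models/index.ivf could have been produced
    # from the few-shot passages. The actual build script is not part of this
    # file, so the details below (nlist=1024, normalised embeddings) are
    # assumptions consistent with the inner-product search used above.
    def build_index_offline(self, out_path="./models/index.ivf", nlist=1024):
        embeddings = self.embed_model.encode(
            self.few_shot_example_text, normalize_embeddings=True
        ).astype(np.float32)
        dim = embeddings.shape[1]  # 1024 for intfloat/multilingual-e5-large
        quantizer = faiss.IndexFlatIP(dim)
        index = faiss.IndexIVFFlat(quantizer, dim, nlist, faiss.METRIC_INNER_PRODUCT)
        index.train(embeddings)  # IVF indexes must be trained before adding vectors
        index.add(embeddings)
        faiss.write_index(index, out_path)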


    def predict(self, prompt: str, is_multiple_choice: bool) -> str:
        # Embed the query (E5 "query: " prefix) and retrieve the top-3 nearest
        # few-shot examples from the FAISS index.
        query_text = 'query: ' + prompt
        query_embed = self.embed_model.encode([query_text])[0]
        topk = 3
        scores, indices = self.index.search(np.array([query_embed]).astype(np.float32), topk)

        # Keep only sufficiently similar examples, and only for non-multiple-choice
        # tasks; strip the leading "passage: " prefix before use. (Initialising the
        # list up front avoids a NameError on multiple-choice questions.)
        example_prompt = []
        if not is_multiple_choice:
            for score, idx in zip(scores[0], indices[0]):
                if score >= 0.85:
                    fewshot_example = self.metadata[idx]["fewshot_example"]
                    example_prompt.append(fewshot_example[len('passage: '):])

        if len(example_prompt) > 0:
            prompt_example = self.system_prompt + 'Here are some similar questions and answers you can refer to:\n'
            for example in example_prompt:
                prompt_example += example + '\n'
            prompt_example += '\nQuestion:' + prompt
        else:
            prompt_example = self.system_prompt + '\n' + prompt
        print(prompt_example)  # Debug: log the fully assembled prompt.

        # Split the assembled text back into system and user messages; the system
        # prompt is always the prefix of prompt_example, so slicing by its length
        # recovers the two parts.
        messages = [
            {"role": "system", "content": prompt_example[:len(self.system_prompt)]},
            {"role": "user", "content": prompt_example[len(self.system_prompt):]},
        ]
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(self.model.device)

        # Multiple-choice answers need a single token; other tasks get a longer budget.
        max_new_tokens = 1 if is_multiple_choice else 128
        outputs = self.model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            eos_token_id=self.terminators,
            do_sample=False,
        )
        # Decode only the newly generated tokens (everything after the prompt).
        response = outputs[0][input_ids.shape[-1]:]
        response = self.tokenizer.decode(response, skip_special_tokens=True)
        print(response)  # Debug: log the model's raw answer.
        return response
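

# A minimal usage sketch for llama3_8b_FewShot (it cannot be run directly as a
# script because of the relative import of ShopBenchBaseModel at the top; the
# evaluation harness imports this module and calls predict). The prompt is
# illustrative only, and assumes the weights and index exist under ./models/:
#
#     model = llama3_8b_FewShot()
#     answer = model.predict("Which of the following products ...?", is_multiple_choice=True)
#     print(answer)  # expected to be a single option token, e.g. "2"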