From 1b65447a0f892eb2260ac9e5da24beb67d3196b6 Mon Sep 17 00:00:00 2001 From: yilun_jin <jyl.jal123@gmail.com> Date: Sun, 24 Mar 2024 17:29:45 +0000 Subject: [PATCH] Update models/dummy_model.py --- models/dummy_model.py | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/models/dummy_model.py b/models/dummy_model.py index 8c15b55..5126746 100644 --- a/models/dummy_model.py +++ b/models/dummy_model.py @@ -50,29 +50,3 @@ class DummyModel(ShopBenchBaseModel): # Note: As this is dummy model, we are returning random responses for non-multiple choice tasks. # For generation tasks, this should ideally return an unconstrained string. -class Vicuna2ZeroShot(ShopBenchBaseModel): - """ - A baseline solution that uses Vicuna-7B to generate answers with zero-shot prompting. - """ - def __init__(self): - random.seed(AICROWD_RUN_SEED) - ### model_path = 'lmsys/vicuna-7b-v1.5' - ### Before submitting, please put necessary files to run Vicuna-7B at the corresponding path, and submit them with `git lfs`. - self.tokenizer = AutoTokenizer.from_pretrained('./models/vicuna-7b-v1.5/', trust_remote_code=True) - self.model = AutoModelForCausalLM.from_pretrained('./models/vicuna-7b-v1.5/', device_map='auto', torch_dtype='auto', trust_remote_code=True, do_sample=True) - self.system_prompt = "You are a helpful online shopping assistant. Please answer the following question about online shopping and follow the given instructions.\n\n" - - - def predict(self, prompt: str, is_multiple_choice: bool) -> str: - prompt = self.system_prompt + prompt - inputs = self.tokenizer(prompt, return_tensors='pt') - inputs.input_ids = inputs.input_ids.cuda() - if is_multiple_choice: - # only one token for multiple choice questions. - generate_ids = self.model.generate(inputs.input_ids, max_new_tokens=1, temperature=0) - else: - # 100 tokens for non-multiple choice questions. - generate_ids = self.model.generate(inputs.input_ids, max_new_tokens=100, temperature=0) - result = self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - generation = result[len(prompt):] - return generation -- GitLab