From 609359d8f2d983043c93603e017c27b3544c744c Mon Sep 17 00:00:00 2001
From: yilun_jin <jyl.jal123@gmail.com>
Date: Fri, 22 Mar 2024 05:07:11 +0000
Subject: [PATCH] Update models/dummy_model.py

---
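Note: the new class assumes models/dummy_model.py already has `import random`,
`from transformers import AutoModelForCausalLM, AutoTokenizer`, and the
`AICROWD_RUN_SEED` constant in scope; none of those lines are part of this hunk,
so please double-check them before submitting. A minimal sketch of how the
baseline would be called (the example prompts below are illustrative only and
not part of the evaluator):

    model = Vicuna2ZeroShot()
    # Multiple-choice task: the model emits a single option token, e.g. "2".
    model.predict("Which option is closest in meaning to 'durable'?\n1. fragile\n2. sturdy\nAnswer: ", is_multiple_choice=True)
    # Generation task: the model emits up to 100 new tokens of free-form text.
    model.predict("Write a short product title for a stainless steel water bottle.", is_multiple_choice=False)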
 models/dummy_model.py | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/models/dummy_model.py b/models/dummy_model.py
index e126fe5..8c15b55 100644
--- a/models/dummy_model.py
+++ b/models/dummy_model.py
@@ -49,3 +49,30 @@ class DummyModel(ShopBenchBaseModel):
             return str(possible_responses)
             # Note: As this is dummy model, we are returning random responses for non-multiple choice tasks.
             # For generation tasks, this should ideally return an unconstrained string.
+
+class Vicuna2ZeroShot(ShopBenchBaseModel):
+    """
+    A baseline solution that uses Vicuna-7B to generate answers with zero-shot prompting. 
+    """
+    def __init__(self):
+        random.seed(AICROWD_RUN_SEED)
+        ### model_path = 'lmsys/vicuna-7b-v1.5'
+        ### Before submitting, please place the files needed to run Vicuna-7B at the corresponding path and submit them with `git lfs`.
+        self.tokenizer = AutoTokenizer.from_pretrained('./models/vicuna-7b-v1.5/', trust_remote_code=True)
+        self.model = AutoModelForCausalLM.from_pretrained('./models/vicuna-7b-v1.5/', device_map='auto', torch_dtype='auto', trust_remote_code=True)
+        self.system_prompt = "You are a helpful online shopping assistant. Please answer the following question about online shopping and follow the given instructions.\n\n"
+
+
+    def predict(self, prompt: str, is_multiple_choice: bool) -> str:
+        prompt = self.system_prompt + prompt
+        inputs = self.tokenizer(prompt, return_tensors='pt')
+        input_ids = inputs.input_ids.to(self.model.device)
+        if is_multiple_choice:
+            # Greedy decoding; only one new token for multiple-choice questions.
+            generate_ids = self.model.generate(input_ids, max_new_tokens=1, do_sample=False)
+        else:
+            # Up to 100 new tokens for non-multiple-choice (generation) questions.
+            generate_ids = self.model.generate(input_ids, max_new_tokens=100, do_sample=False)
+        result = self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        generation = result[len(prompt):]
+        return generation
-- 
GitLab