From a21d388af3be3c38c8cda4a4325a6db6ad81aaa8 Mon Sep 17 00:00:00 2001
From: Fanyou Wu <fanyou.wu@outlook.com>
Date: Mon, 17 Jun 2024 02:13:37 -0400
Subject: [PATCH] Update

---
 models/llama_loader.py | 2 --
 models/v4.py           | 6 ++++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/models/llama_loader.py b/models/llama_loader.py
index ae01055..8fff3f3 100644
--- a/models/llama_loader.py
+++ b/models/llama_loader.py
@@ -63,8 +63,6 @@ class VLLMConfig:
             "disable_custom_all_reduce": self.disable_custom_all_reduce
         }
 
-    # temperature=0.6,
-    # top_p=0.9,
     def sampling_args(self, override_kwargs: dict = {}) -> vllm.SamplingParams:
         kwargs = {
             "n": self.n,
diff --git a/models/v4.py b/models/v4.py
index c67a991..dd0b2b6 100644
--- a/models/v4.py
+++ b/models/v4.py
@@ -162,6 +162,7 @@ class V4(Base):
             )
 
             simple_qa = self.simple_qa(queries, domains, final_answers_flag)
+
             final_answers = [
                 (x if x != "" else y) for x, y in zip(simple_qa, final_answers)
             ]
@@ -328,7 +329,6 @@ class V4(Base):
         self, queries: List[str], domains: List[str], skip_flag: List[bool]
     ) -> List[str]:
         
-        skip = [False] * len(queries)
         inputs = [
             self.template.format(context={"query": query}, template="simple_qa")
             for query in queries
@@ -336,7 +336,7 @@ class V4(Base):
 
         responses = self.llm_inference_with_skip_flag(
             inputs,
-            skip,
+            [False] * len(queries),
             max_tokens=64,
             stop=["<|eot_id|>"],
             include_stop_str_in_output=False,
@@ -347,6 +347,8 @@ class V4(Base):
         ]
         answers = [x.lower().strip() for x in answers]
 
+        print(answers)
+
         not_skip = [(d == "movie") and (not s) for d, s in zip(domains, skip_flag)]
         contain_keywords = [
             check_keywords(x, ["oscar", "academy", "award"]) for x in queries
-- 
GitLab