From c4105a1e6f436a53fc66ee11a1e61419ab1be36e Mon Sep 17 00:00:00 2001
From: Fanyou Wu <fanyou.wu@outlook.com>
Date: Mon, 17 Jun 2024 00:26:34 -0400
Subject: [PATCH] Text simple qa

---
 models/v4.py | 61 +++++++++++++++++++++++++---------------------------
 1 file changed, 29 insertions(+), 32 deletions(-)

diff --git a/models/v4.py b/models/v4.py
index 1c9770b..1f737f5 100644
--- a/models/v4.py
+++ b/models/v4.py
@@ -142,11 +142,12 @@ class V4(Base):
 
     def batch_generate_answer(self, batch: Dict, return_dict: bool = False):
         if not return_dict:
-            queries = [str(x) for x in batch["query"]]
+            queries = batch["query"]
+            print(queries)
+
             final_answers_flag = [False] * len(queries)
             final_answers = [""] * len(queries)
 
-            # Check if the question type is other or not.
             question_type = self.query_type_by_llm(queries, final_answers_flag)
             is_not_other = [x != "other" for x in question_type]
             final_answers_flag, final_answers = self.fillin_answer_with_flag(
@@ -212,12 +213,6 @@ class V4(Base):
             ]
 
             output = [x.lower() for x in output]
-
-            # print("after final_answers")
-            # print("final_answers_flag", final_answers_flag)
-            # print("final_answers", final_answers)
-            # print("-----")
-
             return output
 
         else:
@@ -332,34 +327,36 @@ class V4(Base):
     def simple_qa(
         self, queries: List[str], domains: List[str], skip_flag: List[bool]
     ) -> List[str]:
-        try:
-            not_skip = [(d == "movie") and (not s) for d, s in zip(domains, skip_flag)]
-            contain_keywords = [
-                check_keywords(x, ["oscar", "academy", "award"]) for x in queries
-            ]
-            skip = [(not x) or (not y) for x, y in zip(not_skip, contain_keywords)]
+        
+        # not_skip = [(d == "movie") and (not s) for d, s in zip(domains, skip_flag)]
+        # contain_keywords = [
+        #     check_keywords(x, ["oscar", "academy", "award"]) for x in queries
+        # ]
+        # skip = [(not x) or (not y) for x, y in zip(not_skip, contain_keywords)]
 
-            inputs = [
-                self.template.format(context={"query": query}, template="simple_qa")
-                for query in queries
-            ]
+        skip = [False] * len(queries)
 
-            responses = self.llm_inference_with_skip_flag(
-                inputs,
-                skip,
-                max_tokens=96,
-                stop=["<|eot_id|>"],
-                include_stop_str_in_output=False,
-            )
+        inputs = [
+            self.template.format(context={"query": query}, template="simple_qa")
+            for query in queries
+        ]
 
-            answers = [
-                response.outputs[0].text if response else "" for response in responses
-            ]
+        responses = self.llm_inference_with_skip_flag(
+            inputs,
+            skip,
+            max_tokens=64,
+            stop=["<|eot_id|>"],
+            include_stop_str_in_output=False,
+        )
 
-            answers = [x.lower().strip() for x in answers]
-            return answers
-        except:
-            return [""] * len(queries)
+        answers = [
+            response.outputs[0].text if response else "" for response in responses
+        ]
+
+        answers = [x.lower().strip() for x in answers]
+        print(answers)
+        exit()
+        return answers
 
     def query_type_by_llm(self, queries: List[str], skip_flag: List[bool]) -> List[str]:
         inputs = [
-- 
GitLab