From 9dbe6d34467356479d9436ac465f058495e2a35e Mon Sep 17 00:00:00 2001
From: Fanyou Wu <fanyou.wu@outlook.com>
Date: Sun, 16 Jun 2024 20:10:31 -0400
Subject: [PATCH] Remove simple qa and change to t1

---
 aicrowd.json |  2 +-
 models/v4.py | 24 ++++++++++++++++--------
 2 files changed, 17 insertions(+), 9 deletions(-)

diff --git a/aicrowd.json b/aicrowd.json
index 6fec5d7..d0c3347 100644
--- a/aicrowd.json
+++ b/aicrowd.json
@@ -1,5 +1,5 @@
 {
-    "challenge_id": "meta-kdd-cup-24-crag-end-to-end-retrieval-augmented-generation",
+    "challenge_id": "meta-kdd-cup-24-crag-retrieval-summarization",
     "authors": [
         "Fanyou Wu"
     ],
diff --git a/models/v4.py b/models/v4.py
index 81c4ce4..273a389 100644
--- a/models/v4.py
+++ b/models/v4.py
@@ -130,6 +130,7 @@ class V4(Base):
     def batch_generate_answer(self, batch: Dict, return_dict: bool = False):
         if not return_dict:
             queries = batch["query"]
+            # print(queries)
             final_answers_flag = [False] * len(queries)
             final_answers = [""] * len(queries)
 
@@ -148,6 +149,7 @@ class V4(Base):
             # Check the domain of the query
             domains = self.query_classifier(queries, final_answers_flag)["domain"]
             is_finance = [x == "finance" for x in domains]
+            
             final_answers_flag, final_answers = self.fillin_answer_with_flag(
                 is_finance, final_answers_flag, final_answers, "i don't know"
             )
@@ -156,12 +158,13 @@ class V4(Base):
             # print("final_answers_flag", final_answers_flag)
             # print("final_answers", final_answers)
             # print("-----")
+            # final_answers_flag = [True] * len(queries)
 
-            simple_qa = self.simple_qa(queries, domains, final_answers_flag)
-            final_answers = [
-                x if len(x) > 0 else y for x, y in zip(simple_qa, final_answers)
-            ]
-            final_answers_flag = [(len(x) > 0) or (y) for x, y in zip(simple_qa, final_answers_flag) ]
+            # simple_qa = self.simple_qa(queries, domains, final_answers_flag)
+            # final_answers = [
+            #     x if len(x) > 0 else y for x, y in zip(simple_qa, final_answers)
+            # ]
+            # final_answers_flag = [(len(x) > 0) or (y) for x, y in zip(simple_qa, final_answers_flag) ]
 
             # print("after simple_qa")
             # print("final_answers_flag", final_answers_flag)
@@ -200,9 +203,9 @@ class V4(Base):
                 false_premise, final_answers_flag, final_answers, "invalid question"
             )
 
-            print("after false_premise")
-            print("false_premise", false_premise)
-            print("-----")
+            # print("after false_premise")
+            # print("false_premise", false_premise)
+            # print("-----")
 
             confidence = [self.get_confidence_output(x) != "high" for x in answers]
             final_answers_flag, final_answers = self.fillin_answer_with_flag(
@@ -335,10 +338,13 @@ class V4(Base):
     def simple_qa(
         self, queries: List[str], domains: List[str], skip_flag: List[bool]
     ) -> List[str]:
+
+
         skip = [s | (d != "movie") for d, s in zip(domains, skip_flag)]
         skip = [
             s | (not (("oscar" in q) | ("academy" in q))) for q, s in zip(queries, skip)
         ]
+        
         inputs = [
             self.template.format(context={"query": query}, template="simple_qa")
             for query in queries
@@ -346,9 +352,11 @@ class V4(Base):
         responses = self.llm_inference_with_skip_flag(
             inputs, skip, max_tokens=128, include_stop_str_in_output=False
         )
+
         answers = [
             response.outputs[0].text if response else "" for response in responses
         ]
+
         answers = [x.lower().strip() for x in answers]
         return answers
 
-- 
GitLab