From e4747e192fd3291e38bf01beea48a1a41c2ea8e8 Mon Sep 17 00:00:00 2001
From: Fanyou Wu <fanyou.wu@outlook.com>
Date: Thu, 20 Jun 2024 14:48:07 -0400
Subject: [PATCH] update

---
 models/v4.py | 32 +++++++++++++++++---------------
 1 file changed, 17 insertions(+), 15 deletions(-)

diff --git a/models/v4.py b/models/v4.py
index 8fe8d40..66d2858 100644
--- a/models/v4.py
+++ b/models/v4.py
@@ -231,21 +231,21 @@ class V4(Base):
                 is_not_static, final_answers_flag, final_answers, "i don't know"
             )
 
-            question_type = self.query_type_by_llm(queries, final_answers_flag)
-            question_type = [
-                x if x != "unknown" else y
-                for x, y in zip(train_data_question_type, question_type)
-            ]
-
-            is_aggregation = [x == "aggregation" for x in question_type]
-            final_answers_flag, final_answers = self.fillin_answer_with_flag(
-                is_aggregation, final_answers_flag, final_answers, "i don't know"
-            )
+            # question_type = self.query_type_by_llm(queries, final_answers_flag)
+            # question_type = [
+            #     x if x != "unknown" else y
+            #     for x, y in zip(train_data_question_type, question_type)
+            # ]
+
+            # is_aggregation = [x == "aggregation" for x in question_type]
+            # final_answers_flag, final_answers = self.fillin_answer_with_flag(
+            #     is_aggregation, final_answers_flag, final_answers, "i don't know"
+            # )
 
-            is_post_processing = [x == "post-processing" for x in question_type]
-            final_answers_flag, final_answers = self.fillin_answer_with_flag(
-                is_post_processing, final_answers_flag, final_answers, "i don't know"
-            )
+            # is_post_processing = [x == "post-processing" for x in question_type]
+            # final_answers_flag, final_answers = self.fillin_answer_with_flag(
+            #     is_post_processing, final_answers_flag, final_answers, "i don't know"
+            # )
 
             simple_qa, simple_not_skip = self.simple_qa(
                 queries, domains, final_answers_flag
@@ -271,6 +271,7 @@ class V4(Base):
             contents = self.content_retriever(
                 queries, search_results, final_answers_flag
             )
+
             answers = self.qa_with_ref(
                 queries,
                 contents,
@@ -278,6 +279,7 @@ class V4(Base):
                 domains,
                 final_answers_flag,
             )
+            
             false_premise_by_rag = self.false_premise_by_llm(
                 queries, answers, final_answers_flag, return_bool=True
             )
@@ -346,7 +348,7 @@ class V4(Base):
             return {
                 "queries": queries,
                 "prediction": [self.get_answer_output(x) for x in answers],
-                "reason": [find_by_xml_tag(x, "reason", "") for x in answers],
+                #"reason": [find_by_xml_tag(x, "reason", "") for x in answers],
                 "confidence": [self.get_confidence_output(x) for x in answers],
             }
 
-- 
GitLab