From 663d93aef94ac3ce589971b2db0f20660e73190e Mon Sep 17 00:00:00 2001
From: Fanyou Wu <fanyou.wu@outlook.com>
Date: Sun, 16 Jun 2024 14:49:20 -0400
Subject: [PATCH] Simplify search-result filtering and disable the CRAG KG path

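Simplify V4's retrieval pre-processing and prune the knowledge-graph path:

- models/base.py: drop the unused generate_answer() stub from Base.
- models/v4.py: lower false_premise_threshold from 0.97 (the offline
  best) to 0.8.
- models/v4.py: replace the domain-aware remove_duplicated_and_sort_trusted()
  and its per-domain helpers with a single domain-agnostic pass; the old
  variants are kept commented out for reference.
- models/v4.py: comment out the CRAG mock-API path (get_kg_results,
  kg_entity_extract) and the kg_api handle it required.
- models/v4.py: leave temporary debug prints around the false-premise
  check enabled.

A minimal sketch of the new filtering behavior; the "page_url" and
"page_result" keys match the fields the helper reads, while the
variable names and sample rows here are hypothetical:

    # given an already-initialized V4 instance `model`:
    rows = [[
        {"page_url": "https://a.example", "page_result": "some text"},
        {"page_url": "https://a.example", "page_result": "again"},
        {"page_url": "https://b.example", "page_result": ""},
    ]]
    # per query, keeps the first page per URL with a non-empty
    # page_result; falls back to the unfiltered row if none survive
    filtered = model.remove_duplicated_and_sort_trusted(rows)
    assert [w["page_url"] for w in filtered[0]] == ["https://a.example"]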
---
 models/base.py |   3 -
 models/v4.py   | 226 +++++++++++++++++++++++++++++--------------------
 2 files changed, 134 insertions(+), 95 deletions(-)

diff --git a/models/base.py b/models/base.py
index 0820e0f..6f73d6a 100644
--- a/models/base.py
+++ b/models/base.py
@@ -25,6 +25,3 @@ class Base:
     def __call__(self):
         self.init()
         return self
-
-    def generate_answer(self, query: str, search_results: List[str]) -> str:
-        return "i don't know"
diff --git a/models/v4.py b/models/v4.py
index 8700956..f3ed71d 100644
--- a/models/v4.py
+++ b/models/v4.py
@@ -67,7 +67,7 @@ class V4Config:
     )
     eval_batch_size: int = 8
     max_num_doc: int = 10
-    false_premise_threshold: float = 0.97
+    false_premise_threshold: float = 0.8  # 0.97 was the offline best
     trusted_source: list = field(
         default_factory=lambda: [
             "wiki",
@@ -94,7 +94,7 @@ class V4(Base):
         )
         self.model = VLLMLamaLoader(config=self.config.llama_loader_config).load()
         self.template = TemplateRouter()
-        self.kg_api = CRAG(server=CRAG_MOCK_API_URL)
+
 
     def post_processing(
         self, answers: List[str], cond: List[bool], replace_text: str = "i don't know"
@@ -161,7 +161,7 @@ class V4(Base):
             final_answers = [
                 x if len(x) > 0 else y for x, y in zip(simple_qa, final_answers)
             ]
-            final_answers_flag = [(len(x) > 0) or (y) for x, y in zip(simple_qa,final_answers_flag) ]
+            final_answers_flag = [(len(x) > 0) or y for x, y in zip(simple_qa, final_answers_flag)]
 
             # print("after simple_qa")
             # print("final_answers_flag", final_answers_flag)
@@ -169,8 +169,11 @@ class V4(Base):
             # print("-----")
 
             search_results = self.remove_duplicated_and_sort_trusted(
-                batch["search_results"], domains=domains
+                batch["search_results"]
             )
+            # search_results = self.remove_duplicated_and_sort_trusted(
+            #     batch["search_results"], domains=domains
+            # )
             search_results = [x[: self.config.max_num_doc] for x in search_results]
             contents = self.content_retriever(
                 queries, search_results, final_answers_flag
@@ -188,13 +191,19 @@ class V4(Base):
             false_premise = self.false_premise_by_llm(
                 queries, answers, final_answers_flag
             )
+
             low_confidence = [self.get_confidence_output(x) == "low" for x in answers]
+
             false_premise = [a & b for a, b in zip(false_premise, low_confidence)]
 
             final_answers_flag, final_answers = self.fillin_answer_with_flag(
                 false_premise, final_answers_flag, final_answers, "invalid question"
             )
 
+            print("after false_premise")
+            print("false_premise", false_premise)
+            print("-----")
+
             confidence = [self.get_confidence_output(x) != "high" for x in answers]
             final_answers_flag, final_answers = self.fillin_answer_with_flag(
                 confidence, final_answers_flag, final_answers, "i don't know"
@@ -206,7 +215,14 @@ class V4(Base):
                 for idx, flag in enumerate(final_answers_flag)
             ]
             final_answers = [x.lower() for x in final_answers]
+
+            # print("after final_answers")
+            # print("final_answers_flag", final_answers_flag)
+            # print("final_answers", final_answers)
+            # print("-----")
+
             return final_answers
+
         else:
             # queries = batch["query"]
             # skip_flags = [False] * len(queries)
@@ -226,7 +242,7 @@ class V4(Base):
             query_classes = self.query_classifier(queries, skip_flags)
             domains = query_classes["domain"]
             search_results = self.remove_duplicated_and_sort_trusted(
-                batch["search_results"], domains
+                batch["search_results"]
             )
             contents = self.content_retriever(queries, search_results, skip_flags)
             query_times = batch["query_time"]
@@ -391,6 +407,7 @@ class V4(Base):
             extract_score(response.outputs[0]) if response else (0.0, 0.0)
             for response in responses
         ]
+        # print("false premise score", scores)
         if return_bool:
             output = [
                 (x[0] - x[1]) > self.config.false_premise_threshold for x in scores
@@ -409,98 +426,123 @@
     def get_batch_size(self) -> int:
         return self.config.eval_batch_size
 
-    def remove_duplicated_and_sort_trusted(
-        self, search_results: List[List[Dict]], domains: List[str]
-    ):
-        output = []
-        for row, domain in zip(search_results, domains):
-            if domain != "open":
-                output.append(self.remove_duplicated_and_sort_trusted_other(row))
-            else:
-                output.append(self.remove_duplicated_and_sort_trusted_open(row))
-        return output
-
-    def remove_duplicated_and_sort_trusted_other(self, search_results: List[Dict]):
-        output = []
-        added_url = set()
-        for web in search_results:
-            url = web["page_url"]
-            added_url.add(url)
-            if web["page_result"] != "":
-                output.append(web)
-        if len(output) == 0:
-            output = search_results
-        return output
-
-    def remove_duplicated_and_sort_trusted_open(self, search_results: List[Dict]):
+    def remove_duplicated_and_sort_trusted(self, search_results: List[List[Dict]]):
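+        # Per query: keep pages with a non-empty "page_result",
+        # deduplicated by "page_url"; if nothing survives, fall back
+        # to the unfiltered row so downstream stages always get input.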
         output = []
-
-        trust_url = set()
-        added_url = set()
-        for keyword in self.config.trusted_source:
-            for r in search_results:
-                url = r["page_url"]
-                domain = urlparse(url).netloc
-                if keyword in domain:
-                    if r["page_result"] != "":
-                        trust_url.add(url)
-
-        if len(trust_url) == 0:
-            trust_url = set([x["page_url"] for x in search_results])
-
-        for web in search_results:
-            url = web["page_url"]
-            if (url in added_url) | (url not in trust_url):
-                continue
-            added_url.add(url)
-            if web["page_result"] != "":
-                output.append(web)
-
-        if len(output) == 0:
+        for row in search_results:
+            temp = []
             added_url = set()
-            for web in search_results:
+            for web in row:
                 url = web["page_url"]
                 added_url.add(url)
                 if web["page_result"] != "":
-                    output.append(web)
-        if len(output) == 0:
-            output = search_results
+                    temp.append(web)
+            if len(temp) == 0:
+                temp = row
+            output.append(temp)
         return output
 
-    def get_kg_results(self, entities: List, domains: List[str], skip_flag: List[bool]):
-        results = [
-            self.kg_api(e, d) if ~f else ""
-            for e, d, f in zip(entities, domains, skip_flag)
-        ]
-        return results
-
-    def kg_entity_extract(
-        self,
-        queries: List[str],
-        query_times: List[str],
-        domains: List[str],
-        skip_flag: List[bool],
-    ) -> List[str]:
-        skip_flag = [(d == "open") | s for d, s in zip(domains, skip_flag)]
-
-        inputs = [
-            self.template.format(
-                context={"query": query, "query_time": query_time},
-                template=f"entity_{domain}",
-            )
-            for query, domain, query_time in zip(queries, domains, query_times)
-        ]
-
-        responses = self.llm_inference_with_skip_flag(inputs, skip_flag, stop=["}"])
-        answers = [
-            response.outputs[0].text if response else "" for response in responses
-        ]
 
-        entities = []
-        for res in answers:
-            try:
-                res = json.loads(res)
-            except:
-                res = extract_json_objects(res)
-            entities.append(res)
-        return entities
+    # def remove_duplicated_and_sort_trusted(
+    #     self, search_results: List[List[Dict]], domains: List[str]
+    # ):
+    #     output = []
+    #     for row, domain in zip(search_results, domains):
+    #         if domain != "open":
+    #             output.append(self.remove_duplicated_and_sort_trusted_other(row))
+    #         else:
+    #             output.append(self.remove_duplicated_and_sort_trusted_open(row))
+    #     return output
+
+    # def remove_duplicated_and_sort_trusted_other(self, search_results: List[Dict]):
+    #     output = []
+    #     added_url = set()
+    #     for web in search_results:
+    #         url = web["page_url"]
+    #         added_url.add(url)
+    #         if web["page_result"] != "":
+    #             output.append(web)
+    #     if len(output) == 0:
+    #         output = search_results
+    #     return output
+
+    # def remove_duplicated_and_sort_trusted_open(self, search_results: List[Dict]):
+    #     output = []
+
+    #     trust_url = set()
+    #     added_url = set()
+
+    #     for keyword in self.config.trusted_source:
+    #         for r in search_results:
+    #             url = r["page_url"]
+    #             domain = urlparse(url).netloc
+    #             if keyword in domain:
+    #                 if r["page_result"] != "":
+    #                     trust_url.add(url)
+
+    #     if len(trust_url) == 0:
+    #         trust_url = set([x["page_url"] for x in search_results])
+
+    #     for web in search_results:
+    #         url = web["page_url"]
+    #         if (url in added_url) | (url not in trust_url):
+    #             continue
+    #         added_url.add(url)
+    #         if web["page_result"] != "":
+    #             output.append(web)
+
+    #     if len(output) == 0:
+    #         added_url = set()
+    #         for web in search_results:
+    #             url = web["page_url"]
+    #             added_url.add(url)
+    #             if web["page_result"] != "":
+    #                 output.append(web)
+
+    #     if len(output) == 0:
+    #         output = search_results
+
+    #     return output
+
+    # def get_kg_results(self, entities: List, domains: List[str], skip_flag: List[bool]):
+    #     kg_api = CRAG(server=CRAG_MOCK_API_URL)
+    #     results = [
+    #         kg_api(e, d) if ~f else ""
+    #         for e, d, f in zip(entities, domains, skip_flag)
+    #     ]
+    #     return results
+
+    # def kg_entity_extract(
+    #     self,
+    #     queries: List[str],
+    #     query_times: List[str],
+    #     domains: List[str],
+    #     skip_flag: List[bool],
+    # ) -> List[str]:
+    #     skip_flag = [(d == "open") | s for d, s in zip(domains, skip_flag)]
+
+    #     inputs = [
+    #         self.template.format(
+    #             context={"query": query, "query_time": query_time},
+    #             template=f"entity_{domain}",
+    #         )
+    #         for query, domain, query_time in zip(queries, domains, query_times)
+    #     ]
+
+    #     responses = self.llm_inference_with_skip_flag(inputs, skip_flag, stop=["}"])
+    #     answers = [
+    #         response.outputs[0].text if response else "" for response in responses
+    #     ]
+
+    #     entities = []
+    #     for res in answers:
+    #         try:
+    #             res = json.loads(res)
+    #         except:
+    #             res = extract_json_objects(res)
+    #         entities.append(res)
+    #     return entities
-- 
GitLab