From 3411dd12739826c191c5b97c2a607f935011b637 Mon Sep 17 00:00:00 2001
From: Fanyou Wu <fanyou.wu@outlook.com>
Date: Sun, 16 Jun 2024 01:55:21 -0400
Subject: [PATCH] update

---
 models/v4.py | 223 +++++++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 197 insertions(+), 26 deletions(-)

diff --git a/models/v4.py b/models/v4.py
index a2574d8..8700956 100644
--- a/models/v4.py
+++ b/models/v4.py
@@ -10,10 +10,32 @@ from models.rust_content_retriever import (
 from datetime import datetime
 from models.prompt_template import TemplateRouter
 from models.rust_query_classifier import QueryClassifier, QueryClassifierConfig
-
+from models.kg import CRAG
 from vllm.outputs import CompletionOutput
-from math import exp
 from urllib.parse import urlparse
+from math import exp
+from json import JSONDecoder
+import json
+import os
+
+CRAG_MOCK_API_URL = os.getenv("CRAG_MOCK_API_URL", "http://localhost:8000")
+
+
+def extract_json_objects(text, decoder=JSONDecoder()):
+    """Find JSON objects in text, and yield the decoded JSON data"""
+    pos = 0
+    results = []
+    while True:
+        match = text.find("{", pos)
+        if match == -1:
+            break
+        try:
+            result, index = decoder.raw_decode(text[match:])
+            results.append(result)
+            pos = match + index
+        except ValueError:
+            pos = match + 1
+    return results
 
 
 def extract_score(completion_output: CompletionOutput) -> Tuple[float, float]:
@@ -45,22 +67,19 @@ class V4Config:
     )
     eval_batch_size: int = 8
     max_num_doc: int = 10
-    false_premise_threshold: float = 0.1
+    false_premise_threshold: float = 0.97
     trusted_source: list = field(
         default_factory=lambda: [
-            "wikipedia",
+            "wiki",
             "imdb.com",
             "grammy.com",
             "fandom.com",
         ]
     )
-    # blacklist: list = field(
-    #     default_factory=lambda: ["investopedia.com", "quora.com", "amazon.com"]
-    # )
 
 
 class V4(Base):
-    method_version: str = "v4"
+    method_version: str = "v4_simple_qa"
 
     def __init__(self, config: V4Config, init_class: bool = False):
         self.config = config
@@ -75,6 +94,7 @@ class V4(Base):
         )
         self.model = VLLMLamaLoader(config=self.config.llama_loader_config).load()
         self.template = TemplateRouter()
+        self.kg_api = CRAG(server=CRAG_MOCK_API_URL)
 
     def post_processing(
         self, answers: List[str], cond: List[bool], replace_text: str = "i don't know"
@@ -120,15 +140,36 @@ class V4(Base):
                 is_not_other, final_answers_flag, final_answers, "i don't know"
             )
 
+            # print("after question_type")
+            # print("final_answers_flag", final_answers_flag)
+            # print("final_answers", final_answers)
+            # print("-----")
+
             # Check the domain of the query
-            query_classes = self.query_classifier(queries, final_answers_flag)
-            is_finance = [x == "finance" for x in query_classes["domain"]]
+            domains = self.query_classifier(queries, final_answers_flag)["domain"]
+            is_finance = [x == "finance" for x in domains]
             final_answers_flag, final_answers = self.fillin_answer_with_flag(
                 is_finance, final_answers_flag, final_answers, "i don't know"
             )
 
+            # print("after is_finance")
+            # print("final_answers_flag", final_answers_flag)
+            # print("final_answers", final_answers)
+            # print("-----")
+
+            simple_qa = self.simple_qa(queries, domains, final_answers_flag)
+            final_answers = [
+                x if len(x) > 0 else y for x, y in zip(simple_qa, final_answers)
+            ]
+            final_answers_flag = [(len(x) > 0) or y for x, y in zip(simple_qa, final_answers_flag)]
+
+            # print("after simple_qa")
+            # print("final_answers_flag", final_answers_flag)
+            # print("final_answers", final_answers)
+            # print("-----")
+
             search_results = self.remove_duplicated_and_sort_trusted(
-                batch["search_results"]
+                batch["search_results"], domains=domains
             )
             search_results = [x[: self.config.max_num_doc] for x in search_results]
             contents = self.content_retriever(
@@ -139,7 +180,7 @@ class V4(Base):
                 queries,
                 contents,
                 query_times,
-                query_classes["domain"],
+                domains,
                 final_answers_flag,
             )
 
@@ -167,21 +208,43 @@ class V4(Base):
             final_answers = [x.lower() for x in final_answers]
             return final_answers
         else:
+            # queries = batch["query"]
+            # skip_flags = [False] * len(queries)
+
+            # domains = self.query_classifier(queries, skip_flags)["domain"]
+
+            # return {
+            #     "queries": queries,
+            #     "prediction": self.simple_qa(queries, domains, skip_flags),
+            #     # "pred_domain": domains,
+            # }
+
             queries = batch["query"]
             skip_flags = [False] * len(queries)
             question_type = self.query_type_by_llm(queries, skip_flags)
+
             query_classes = self.query_classifier(queries, skip_flags)
+            domains = query_classes["domain"]
             search_results = self.remove_duplicated_and_sort_trusted(
-                batch["search_results"]
+                batch["search_results"], domains
             )
             contents = self.content_retriever(queries, search_results, skip_flags)
             query_times = batch["query_time"]
-            domains = query_classes["domain"]
+
             answers = self.qa_with_ref(
                 queries, contents, query_times, domains, skip_flags
             )
+            simple_qa_answer = self.simple_qa(queries, domains, skip_flags)
             prediction = [self.get_answer_output(x) for x in answers]
-            false_premise = self.false_premise_by_llm(queries, answers, skip_flags, return_bool=False)
+            prediction = [
+                y if len(y) > 0 else x for x, y in zip(prediction, simple_qa_answer)
+            ]
+
+            false_premise = self.false_premise_by_llm(
+                queries, answers, skip_flags, return_bool=False
+            )
+            # entities = self.kg_entity_extract(queries, query_times, domains, skip_flags)
+            # kg_results = self.get_kg_results(entities, domains, skip_flags)
             return {
                 "queries": queries,
                 "prediction": prediction,
@@ -190,6 +253,9 @@ class V4(Base):
                 "pred_domain": domains,
                 "confidence": [self.get_confidence_output(x) for x in answers],
                 "question_type": question_type,
+                # "kd_entity": entities,
+                # "kg_result": kg_results,
+                # "simple_qa": self.simple_qa(queries, skip_flags),
             }
 
     def qa_with_ref(
@@ -250,6 +316,25 @@ class V4(Base):
         input = self.template.format(context, template=domain)
         return input
 
+    def simple_qa(
+        self, queries: List[str], domains: List[str], skip_flag: List[bool]
+    ) -> List[str]:
+        skip = [s | (d != "movie") for d, s in zip(domains, skip_flag)]
+        skip = [
+            s | (not (("oscar" in q) | ("academy" in q))) for q, s in zip(queries, skip)
+        ]
+        inputs = [
+            self.template.format(context={"query": query}, template="simple_qa")
+            for query in queries
+        ]
+        responses = self.llm_inference_with_skip_flag(
+            inputs, skip, max_tokens=128, include_stop_str_in_output=False
+        )
+        answers = [
+            response.outputs[0].text if response else "" for response in responses
+        ]
+        answers = [x.lower().strip() for x in answers]
+        return answers
 
     def query_type_by_llm(self, queries: List[str], skip_flag: List[bool]) -> List[str]:
         inputs = [
@@ -287,7 +372,11 @@ class V4(Base):
         return output
 
     def false_premise_by_llm(
-        self, queries: List[str], answers: List[str], skip_flag: List[bool], return_bool:bool=True
+        self,
+        queries: List[str],
+        answers: List[str],
+        skip_flag: List[bool],
+        return_bool: bool = True,
     ) -> List:
         inputs = [
             self.template.format(
@@ -303,8 +392,9 @@ class V4(Base):
             for response in responses
         ]
         if return_bool:
-
-            output = [(x[0] - x[1]) > self.config.false_premise_threshold for x in scores]
+            output = [
+                (x[0] - x[1]) > self.config.false_premise_threshold for x in scores
+            ]
             return output
         return scores
 
@@ -319,17 +409,98 @@ class V4(Base):
     def get_batch_size(self) -> int:
         return self.config.eval_batch_size
 
-    def remove_duplicated_and_sort_trusted(self, search_results: List[List[Dict]]):
+    def remove_duplicated_and_sort_trusted(
+        self, search_results: List[List[Dict]], domains: List[str]
+    ):
+        output = []
+        for row, domain in zip(search_results, domains):
+            if domain != "open":
+                output.append(self.remove_duplicated_and_sort_trusted_other(row))
+            else:
+                output.append(self.remove_duplicated_and_sort_trusted_open(row))
+        return output
+
+    def remove_duplicated_and_sort_trusted_other(self, search_results: List[Dict]):
+        output = []
+        added_url = set()
+        for web in search_results:
+            url = web["page_url"]
+            if url not in added_url and web["page_result"] != "":
+                added_url.add(url)
+                output.append(web)
+        if len(output) == 0:
+            output = search_results
+        return output
+
+    def remove_duplicated_and_sort_trusted_open(self, search_results: List[Dict]):
         output = []
-        for row in search_results:
-            temp = []
+
+        trust_url = set()
+        added_url = set()
+        for keyword in self.config.trusted_source:
+            for r in search_results:
+                url = r["page_url"]
+                domain = urlparse(url).netloc
+                if keyword in domain:
+                    if r["page_result"] != "":
+                        trust_url.add(url)
+
+        if len(trust_url) == 0:
+            trust_url = set([x["page_url"] for x in search_results])
+
+        for web in search_results:
+            url = web["page_url"]
+            if (url in added_url) | (url not in trust_url):
+                continue
+            added_url.add(url)
+            if web["page_result"] != "":
+                output.append(web)
+
+        if len(output) == 0:
             added_url = set()
-            for web in row:
+            for web in search_results:
                 url = web["page_url"]
                 added_url.add(url)
                 if web["page_result"] != "":
-                    temp.append(web)
-            if len(temp) == 0:
-                temp = row
-            output.append(temp)
+                    output.append(web)
+        if len(output) == 0:
+            output = search_results
         return output
+
+    def get_kg_results(self, entities: List, domains: List[str], skip_flag: List[bool]):
+        results = [
+            self.kg_api(e, d) if not f else ""
+            for e, d, f in zip(entities, domains, skip_flag)
+        ]
+        return results
+
+    def kg_entity_extract(
+        self,
+        queries: List[str],
+        query_times: List[str],
+        domains: List[str],
+        skip_flag: List[bool],
+    ) -> List[str]:
+        skip_flag = [(d == "open") | s for d, s in zip(domains, skip_flag)]
+
+        inputs = [
+            self.template.format(
+                context={"query": query, "query_time": query_time},
+                template=f"entity_{domain}",
+            )
+            for query, domain, query_time in zip(queries, domains, query_times)
+        ]
+
+        responses = self.llm_inference_with_skip_flag(inputs, skip_flag, stop=["}"])
+        answers = [
+            response.outputs[0].text if response else "" for response in responses
+        ]
+
+        entities = []
+        for res in answers:
+            try:
+                res = json.loads(res)
+            except ValueError:
+                res = extract_json_objects(res)
+            entities.append(res)
+        return entities
-- 
GitLab