From ff801756628744c0635a2354497f7216019b143b Mon Sep 17 00:00:00 2001
From: Fanyou Wu <fanyou.wu@outlook.com>
Date: Sun, 16 Jun 2024 22:20:26 -0400
Subject: [PATCH] convert query to string

---
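Notes (cover-note area, not part of the commit message): llama_loader.py now
forwards truncate_prompt_tokens to vllm.SamplingParams, so prompts longer than
7936 tokens are cut down to their last 7936 tokens instead of erroring out at
generation time. A minimal sketch of the sampling parameters this config builds,
assuming n=1 and otherwise using the VLLMConfig defaults visible in the diff
below:

    import vllm

    params = vllm.SamplingParams(
        n=1,                           # assumed; the config's self.n is not shown here
        temperature=0,                 # effectively greedy decoding
        max_tokens=256,
        skip_special_tokens=True,
        truncate_prompt_tokens=7936,   # keep only the last 7936 prompt tokens
        stop_token_ids=[128009],
        stop=["</answer>"],
        include_stop_str_in_output=True,
    )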
 models/llama_loader.py |  4 ++++
 models/v4.py           | 49 +++++++++++++++++++++++++++---------------
 2 files changed, 36 insertions(+), 17 deletions(-)

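In models/v4.py, entries of batch["query"] are now coerced with str() before
templating (the assumption being that the loader can hand back non-str objects;
the diff itself only shows the coercion), and the movie-domain simple_qa
shortcut gates on case-insensitive keyword hits via the new check_keywords
helper, which treats each keyword as a regex pattern. Illustrative use with
made-up queries:

    queries = ["Who won the Academy Award for Best Picture?", "latest AAPL price?"]
    hits = [check_keywords(q, ["oscar", "academy", "award"]) for q in queries]
    # hits == [True, False]
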
diff --git a/models/llama_loader.py b/models/llama_loader.py
index 11eee30..270a1e4 100644
--- a/models/llama_loader.py
+++ b/models/llama_loader.py
@@ -45,6 +45,7 @@ class VLLMConfig:
     temperature: float = 0  # less than 1e-5 is greedy, 0.1
     skip_special_tokens = True
     max_tokens = 256
+    truncate_prompt_tokens: int = 7936
     stop_token_ids: list = field(default_factory=lambda: [128009])
     include_stop_str_in_output: bool = True
     stop: List[str] = field(default_factory=lambda: ["</answer>"])
@@ -59,6 +60,8 @@ class VLLMConfig:
             "enforce_eager": self.enforce_eager,
         }
 
+    # Alternative non-greedy sampling settings, kept here for reference:
+    # temperature=0.6, top_p=0.9
     def sampling_args(self, override_kwargs: dict = {}) -> vllm.SamplingParams:
         kwargs = {
             "n": self.n,
@@ -69,6 +72,7 @@ class VLLMConfig:
             "stop_token_ids": self.stop_token_ids,
             "stop": self.stop,
             "include_stop_str_in_output": self.include_stop_str_in_output,
+            "truncate_prompt_tokens": self.truncate_prompt_tokens,
         }
         kwargs.update(override_kwargs)
         return vllm.SamplingParams(**kwargs)
diff --git a/models/v4.py b/models/v4.py
index 104bf61..3e3de6c 100644
--- a/models/v4.py
+++ b/models/v4.py
@@ -20,6 +20,15 @@ import os
 
 CRAG_MOCK_API_URL = os.getenv("CRAG_MOCK_API_URL", "http://localhost:8000")
 
+import re
+
+# True if any keyword (used as a regex pattern) matches text, case-insensitively.
+def check_keywords(text, keywords):
+    for keyword in keywords:
+        if re.search(keyword, text, re.IGNORECASE):
+            return True
+    return False
+
 
 def extract_json_objects(text, decoder=JSONDecoder()):
     """Find JSON objects in text, and yield the decoded JSON data"""
@@ -128,8 +137,7 @@ class V4(Base):
 
     def batch_generate_answer(self, batch: Dict, return_dict: bool = False):
         if not return_dict:
-            queries = batch["query"]
-            # print(queries)
+            queries = [str(x) for x in batch["query"]]
             final_answers_flag = [False] * len(queries)
             final_answers = [""] * len(queries)
 
@@ -148,13 +156,11 @@ class V4(Base):
             )
 
             simple_qa = self.simple_qa(queries, domains, final_answers_flag)
-
             final_answers = [
-                x if len(x) > 0 else y for x, y in zip(simple_qa, final_answers)
+                (x if x != "" else y) for x, y in zip(simple_qa, final_answers)
             ]
-
             final_answers_flag = [
-                (len(x) > 0) or (y) for x, y in zip(simple_qa, final_answers_flag)
+                ((x != "") or (y)) for x, y in zip(simple_qa, final_answers_flag)
             ]
 
             search_results = self.remove_duplicated_and_sort_trusted(
@@ -165,7 +171,7 @@ class V4(Base):
             contents = self.content_retriever(
                 queries, search_results, final_answers_flag
             )
-            
+
             query_times = batch["query_time"]
             answers = self.qa_with_ref(
                 queries,
@@ -194,18 +200,20 @@ class V4(Base):
             )
 
             answers = [self.get_answer_output(x) for x in answers]
-            final_answers = [
+
+            output = [
                 final_answers[idx] if flag else answers[idx]
                 for idx, flag in enumerate(final_answers_flag)
             ]
-            final_answers = [x.lower() for x in final_answers]
+
+            output = [x.lower() for x in output]
 
             # print("after final_answers")
             # print("final_answers_flag", final_answers_flag)
             # print("final_answers", final_answers)
             # print("-----")
 
-            return final_answers
+            return output
 
         else:
             # queries = batch["query"]
@@ -219,7 +227,7 @@ class V4(Base):
             #     # "pred_domain": domains,
             # }
 
-            queries = batch["query"]
+            queries = [str(x) for x in batch["query"]]
             skip_flags = [False] * len(queries)
             question_type = self.query_type_by_llm(queries, skip_flags)
 
@@ -319,23 +327,30 @@ class V4(Base):
     def simple_qa(
         self, queries: List[str], domains: List[str], skip_flag: List[bool]
     ) -> List[str]:
+
         not_skip = [(d == "movie") and (not s) for d, s in zip(domains, skip_flag)]
-        not_skip = [
-            (("oscar" in q) or ("academy" in q)) and s
-            for q, s in zip(queries, not_skip)
+        contain_keywords = [
+            check_keywords(x, ["oscar", "academy", "award"]) for x in queries
         ]
-        skip = [not x for x in not_skip]
+        skip = [(not x) or (not y) for x, y in zip(not_skip, contain_keywords)]
 
         inputs = [
             self.template.format(context={"query": query}, template="simple_qa")
             for query in queries
         ]
+
         responses = self.llm_inference_with_skip_flag(
-            inputs, skip, max_tokens=128, include_stop_str_in_output=False
+            inputs,
+            skip,
+            max_tokens=96,
+            stop=["<|eot_id|>"],
+            include_stop_str_in_output=False,
         )
+
         answers = [
             response.outputs[0].text if response else "" for response in responses
         ]
+
         answers = [x.lower().strip() for x in answers]
         return answers
 
@@ -357,6 +372,7 @@ class V4(Base):
         self, inputs: List[str], skip_flag: List[bool], **kwargs
     ) -> List:
         llm_inputs = []
+
         for input, flag in zip(inputs, skip_flag):
             if not flag:
                 llm_inputs.append(input)
@@ -394,7 +410,6 @@ class V4(Base):
             extract_score(response.outputs[0]) if response else (0.0, 0.0)
             for response in responses
         ]
-        # print("false premise score", scores)
         if return_bool:
             output = [
                 (x[0] - x[1]) > self.config.false_premise_threshold for x in scores
-- 
GitLab