From ff801756628744c0635a2354497f7216019b143b Mon Sep 17 00:00:00 2001
From: Fanyou Wu <fanyou.wu@outlook.com>
Date: Sun, 16 Jun 2024 22:20:26 -0400
Subject: [PATCH] convert query to string

---
 models/llama_loader.py |  4 ++++
 models/v4.py           | 49 +++++++++++++++++++++++++++---------------
 2 files changed, 36 insertions(+), 17 deletions(-)

diff --git a/models/llama_loader.py b/models/llama_loader.py
index 11eee30..270a1e4 100644
--- a/models/llama_loader.py
+++ b/models/llama_loader.py
@@ -45,6 +45,7 @@ class VLLMConfig:
     temperature: float = 0 # less than 1e-5 is greedy, 0.1
     skip_special_tokens = True
     max_tokens = 256
+    truncate_prompt_tokens:int = 7936
     stop_token_ids: list = field(default_factory=lambda: [128009])
     include_stop_str_in_output: bool = True
     stop: List[str] = field(default_factory=lambda: ["</answer>"])
@@ -59,6 +60,8 @@ class VLLMConfig:
             "enforce_eager": self.enforce_eager,
         }

+    # temperature=0.6,
+    # top_p=0.9,
     def sampling_args(self, override_kwargs: dict = {}) -> vllm.SamplingParams:
         kwargs = {
             "n": self.n,
@@ -69,6 +72,7 @@ class VLLMConfig:
             "stop_token_ids": self.stop_token_ids,
             "stop": self.stop,
             "include_stop_str_in_output": self.include_stop_str_in_output,
+            "truncate_prompt_tokens": self.truncate_prompt_tokens,
         }
         kwargs.update(override_kwargs)
         return vllm.SamplingParams(**kwargs)
diff --git a/models/v4.py b/models/v4.py
index 104bf61..3e3de6c 100644
--- a/models/v4.py
+++ b/models/v4.py
@@ -20,6 +20,15 @@ import os

 CRAG_MOCK_API_URL = os.getenv("CRAG_MOCK_API_URL", "http://localhost:8000")

+import re
+
+
+def check_keywords(text, keywords):
+    for keyword in keywords:
+        if re.search(keyword, text, re.IGNORECASE):
+            return True
+    return False
+

 def extract_json_objects(text, decoder=JSONDecoder()):
     """Find JSON objects in text, and yield the decoded JSON data"""
@@ -128,8 +137,7 @@ class V4(Base):

     def batch_generate_answer(self, batch: Dict, return_dict: bool = False):
         if not return_dict:
-            queries = batch["query"]
-            # print(queries)
+            queries = [str(x) for x in batch["query"]]

             final_answers_flag = [False] * len(queries)
             final_answers = [""] * len(queries)
@@ -148,13 +156,11 @@ class V4(Base):
             )

             simple_qa = self.simple_qa(queries, domains, final_answers_flag)
-
             final_answers = [
-                x if len(x) > 0 else y for x, y in zip(simple_qa, final_answers)
+                (x if x != "" else y) for x, y in zip(simple_qa, final_answers)
             ]
-
             final_answers_flag = [
-                (len(x) > 0) or (y) for x, y in zip(simple_qa, final_answers_flag)
+                ((x != "") or (y)) for x, y in zip(simple_qa, final_answers_flag)
             ]

             search_results = self.remove_duplicated_and_sort_trusted(
@@ -165,7 +171,7 @@ class V4(Base):
             contents = self.content_retriever(
                 queries, search_results, final_answers_flag
             )
-
+
             query_times = batch["query_time"]
             answers = self.qa_with_ref(
                 queries,
@@ -194,18 +200,20 @@ class V4(Base):
             )

             answers = [self.get_answer_output(x) for x in answers]
-            final_answers = [
+
+            output = [
                 final_answers[idx] if flag else answers[idx]
                 for idx, flag in enumerate(final_answers_flag)
             ]
-            final_answers = [x.lower() for x in final_answers]
+
+            output = [x.lower() for x in output]

             # print("after final_answers")
             # print("final_answers_flag", final_answers_flag)
             # print("final_answers", final_answers)
             # print("-----")

-            return final_answers
+            return output
         else:
             # queries = batch["query"]
             # skip_flags = [False] * len(queries)
@@ -219,7 +227,7 @@ class V4(Base):
             # question_type = self.query_type_by_llm(queries, skip_flags)
             # # "pred_domain": domains,
             # }
-            queries = batch["query"]
+            queries = [str(x) for x in batch["query"]]
             skip_flags = [False] * len(queries)
             question_type = self.query_type_by_llm(queries, skip_flags)

@@ -319,23 +327,30 @@ class V4(Base):
     def simple_qa(
         self, queries: List[str], domains: List[str], skip_flag: List[bool]
     ) -> List[str]:
+
         not_skip = [(d == "movie") and (not s) for d, s in zip(domains, skip_flag)]
-        not_skip = [
-            (("oscar" in q) or ("academy" in q)) and s
-            for q, s in zip(queries, not_skip)
+        contain_keywords = [
+            check_keywords(x, ["oscar", "academy", "award"]) for x in queries
         ]
-        skip = [not x for x in not_skip]
+        skip = [(not x) or (not y) for x, y in zip(not_skip, contain_keywords)]

         inputs = [
             self.template.format(context={"query": query}, template="simple_qa")
             for query in queries
         ]
+
         responses = self.llm_inference_with_skip_flag(
-            inputs, skip, max_tokens=128, include_stop_str_in_output=False
+            inputs,
+            skip,
+            max_tokens=96,
+            stop=["<|eot_id|>"],
+            include_stop_str_in_output=False,
         )
+
         answers = [
             response.outputs[0].text if response else "" for response in responses
         ]
+
         answers = [x.lower().strip() for x in answers]
         return answers

@@ -357,6 +372,7 @@ class V4(Base):
         self, inputs: List[str], skip_flag: List[bool], **kwargs
     ) -> List:
         llm_inputs = []
+
         for input, flag in zip(inputs, skip_flag):
             if not flag:
                 llm_inputs.append(input)
@@ -394,7 +410,6 @@ class V4(Base):
             extract_score(response.outputs[0]) if response else (0.0, 0.0)
             for response in responses
         ]
-        # print("false premise score", scores)
         if return_bool:
             output = [
                 (x[0] - x[1]) > self.config.false_premise_threshold for x in scores
-- 
GitLab
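A note on the `models/llama_loader.py` half of the patch: `truncate_prompt_tokens` is forwarded into `vllm.SamplingParams`, so over-long retrieval prompts are cut down to their last 7936 tokens instead of overflowing the context window; 7936 looks like an 8192-token Llama 3 context minus the 256-token completion budget, though the patch does not say so explicitly. The snippet below is a minimal sketch of the parameters `VLLMConfig.sampling_args()` now produces, assuming a vLLM version that supports `truncate_prompt_tokens` and assuming `n=1` (the config's `n` value is not shown in this diff).

```python
from vllm import SamplingParams

# Sketch of the sampling parameters VLLMConfig.sampling_args() now builds.
# n=1 is an assumption; the other values come from the config in this patch.
params = SamplingParams(
    n=1,
    temperature=0,                    # greedy decoding
    max_tokens=256,
    skip_special_tokens=True,
    stop_token_ids=[128009],          # Llama 3 <|eot_id|>
    stop=["</answer>"],
    include_stop_str_in_output=True,
    truncate_prompt_tokens=7936,      # keep only the last 7936 prompt tokens
)
print(params)
```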
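The `models/v4.py` half of the patch does two things: it coerces every `batch["query"]` entry to `str` before any string handling, and it rebuilds the `simple_qa` gate on top of the new `check_keywords` regex helper. The standalone sketch below replays that gating logic on a made-up batch with made-up domain predictions; only the list comprehensions mirror the diff, everything else is illustrative.

```python
import re


def check_keywords(text, keywords):
    # Same helper the patch adds to models/v4.py: case-insensitive regex search.
    for keyword in keywords:
        if re.search(keyword, text, re.IGNORECASE):
            return True
    return False


# Hypothetical batch: query values are not guaranteed to be str, hence str(x).
batch = {"query": ["Who won the Oscar for Best Picture in 2020?", 12345]}
queries = [str(x) for x in batch["query"]]

# Hypothetical domain predictions and skip flags for the two queries.
domains = ["movie", "finance"]
skip_flag = [False, False]

# Mirrors the patched simple_qa: a query stays on the simple-QA path only if it
# is an unanswered movie-domain query AND mentions an award-related keyword.
not_skip = [(d == "movie") and (not s) for d, s in zip(domains, skip_flag)]
contain_keywords = [check_keywords(x, ["oscar", "academy", "award"]) for x in queries]
skip = [(not x) or (not y) for x, y in zip(not_skip, contain_keywords)]

print(skip)  # [False, True] -> only the first query is sent to the LLM
```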