From 3411dd12739826c191c5b97c2a607f935011b637 Mon Sep 17 00:00:00 2001 From: Fanyou Wu <fanyou.wu@outlook.com> Date: Sun, 16 Jun 2024 01:55:21 -0400 Subject: [PATCH] update --- models/v4.py | 223 +++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 197 insertions(+), 26 deletions(-) diff --git a/models/v4.py b/models/v4.py index a2574d8..8700956 100644 --- a/models/v4.py +++ b/models/v4.py @@ -10,10 +10,32 @@ from models.rust_content_retriever import ( from datetime import datetime from models.prompt_template import TemplateRouter from models.rust_query_classifier import QueryClassifier, QueryClassifierConfig - +from models.kg import CRAG from vllm.outputs import CompletionOutput -from math import exp from urllib.parse import urlparse +from math import exp +from json import JSONDecoder +import json +import os + +CRAG_MOCK_API_URL = os.getenv("CRAG_MOCK_API_URL", "http://localhost:8000") + + +def extract_json_objects(text, decoder=JSONDecoder()): + """Find JSON objects in text, and yield the decoded JSON data""" + pos = 0 + results = [] + while True: + match = text.find("{", pos) + if match == -1: + break + try: + result, index = decoder.raw_decode(text[match:]) + results.append(result) + pos = match + index + except ValueError: + pos = match + 1 + return results def extract_score(completion_output: CompletionOutput) -> Tuple[float, float]: @@ -45,22 +67,19 @@ class V4Config: ) eval_batch_size: int = 8 max_num_doc: int = 10 - false_premise_threshold: float = 0.1 + false_premise_threshold: float = 0.97 trusted_source: list = field( default_factory=lambda: [ - "wikipedia", + "wiki", "imdb.com", "grammy.com", "fandom.com", ] ) - # blacklist: list = field( - # default_factory=lambda: ["investopedia.com", "quora.com", "amazon.com"] - # ) class V4(Base): - method_version: str = "v4" + method_version: str = "v4_simple_qa" def __init__(self, config: V4Config, init_class: bool = False): self.config = config @@ -75,6 +94,7 @@ class V4(Base): ) self.model = VLLMLamaLoader(config=self.config.llama_loader_config).load() self.template = TemplateRouter() + self.kg_api = CRAG(server=CRAG_MOCK_API_URL) def post_processing( self, answers: List[str], cond: List[bool], replace_text: str = "i don't know" @@ -120,15 +140,36 @@ class V4(Base): is_not_other, final_answers_flag, final_answers, "i don't know" ) + # print("after question_type") + # print("final_answers_flag", final_answers_flag) + # print("final_answers", final_answers) + # print("-----") + # Check the domain of the query - query_classes = self.query_classifier(queries, final_answers_flag) - is_finance = [x == "finance" for x in query_classes["domain"]] + domains = self.query_classifier(queries, final_answers_flag)["domain"] + is_finance = [x == "finance" for x in domains] final_answers_flag, final_answers = self.fillin_answer_with_flag( is_finance, final_answers_flag, final_answers, "i don't know" ) + # print("after is_finance") + # print("final_answers_flag", final_answers_flag) + # print("final_answers", final_answers) + # print("-----") + + simple_qa = self.simple_qa(queries, domains, final_answers_flag) + final_answers = [ + x if len(x) > 0 else y for x, y in zip(simple_qa, final_answers) + ] + final_answers_flag = [(len(x) > 0) or (y) for x, y in zip(simple_qa,final_answers_flag) ] + + # print("after simple_qa") + # print("final_answers_flag", final_answers_flag) + # print("final_answers", final_answers) + # print("-----") + search_results = self.remove_duplicated_and_sort_trusted( - batch["search_results"] + batch["search_results"], domains=domains ) search_results = [x[: self.config.max_num_doc] for x in search_results] contents = self.content_retriever( @@ -139,7 +180,7 @@ class V4(Base): queries, contents, query_times, - query_classes["domain"], + domains, final_answers_flag, ) @@ -167,21 +208,43 @@ class V4(Base): final_answers = [x.lower() for x in final_answers] return final_answers else: + # queries = batch["query"] + # skip_flags = [False] * len(queries) + + # domains = self.query_classifier(queries, skip_flags)["domain"] + + # return { + # "queries": queries, + # "prediction": self.simple_qa(queries, domains, skip_flags), + # # "pred_domain": domains, + # } + queries = batch["query"] skip_flags = [False] * len(queries) question_type = self.query_type_by_llm(queries, skip_flags) + query_classes = self.query_classifier(queries, skip_flags) + domains = query_classes["domain"] search_results = self.remove_duplicated_and_sort_trusted( - batch["search_results"] + batch["search_results"], domains ) contents = self.content_retriever(queries, search_results, skip_flags) query_times = batch["query_time"] - domains = query_classes["domain"] + answers = self.qa_with_ref( queries, contents, query_times, domains, skip_flags ) + simple_qa_answer = self.simple_qa(queries, domains, skip_flags) prediction = [self.get_answer_output(x) for x in answers] - false_premise = self.false_premise_by_llm(queries, answers, skip_flags, return_bool=False) + prediction = [ + y if len(y) > 0 else x for x, y in zip(prediction, simple_qa_answer) + ] + + false_premise = self.false_premise_by_llm( + queries, answers, skip_flags, return_bool=False + ) + # entities = self.kg_entity_extract(queries, query_times, domains, skip_flags) + # kg_results = self.get_kg_results(entities, domains, skip_flags) return { "queries": queries, "prediction": prediction, @@ -190,6 +253,9 @@ class V4(Base): "pred_domain": domains, "confidence": [self.get_confidence_output(x) for x in answers], "question_type": question_type, + # "kd_entity": entities, + # "kg_result": kg_results, + # "simple_qa": self.simple_qa(queries, skip_flags), } def qa_with_ref( @@ -250,6 +316,25 @@ class V4(Base): input = self.template.format(context, template=domain) return input + def simple_qa( + self, queries: List[str], domains: List[str], skip_flag: List[bool] + ) -> List[str]: + skip = [s | (d != "movie") for d, s in zip(domains, skip_flag)] + skip = [ + s | (not (("oscar" in q) | ("academy" in q))) for q, s in zip(queries, skip) + ] + inputs = [ + self.template.format(context={"query": query}, template="simple_qa") + for query in queries + ] + responses = self.llm_inference_with_skip_flag( + inputs, skip, max_tokens=128, include_stop_str_in_output=False + ) + answers = [ + response.outputs[0].text if response else "" for response in responses + ] + answers = [x.lower().strip() for x in answers] + return answers def query_type_by_llm(self, queries: List[str], skip_flag: List[bool]) -> List[str]: inputs = [ @@ -287,7 +372,11 @@ class V4(Base): return output def false_premise_by_llm( - self, queries: List[str], answers: List[str], skip_flag: List[bool], return_bool:bool=True + self, + queries: List[str], + answers: List[str], + skip_flag: List[bool], + return_bool: bool = True, ) -> List: inputs = [ self.template.format( @@ -303,8 +392,9 @@ class V4(Base): for response in responses ] if return_bool: - - output = [(x[0] - x[1]) > self.config.false_premise_threshold for x in scores] + output = [ + (x[0] - x[1]) > self.config.false_premise_threshold for x in scores + ] return output return scores @@ -319,17 +409,98 @@ class V4(Base): def get_batch_size(self) -> int: return self.config.eval_batch_size - def remove_duplicated_and_sort_trusted(self, search_results: List[List[Dict]]): + def remove_duplicated_and_sort_trusted( + self, search_results: List[List[Dict]], domains: List[str] + ): + output = [] + for row, domain in zip(search_results, domains): + if domain != "open": + output.append(self.remove_duplicated_and_sort_trusted_other(row)) + else: + output.append(self.remove_duplicated_and_sort_trusted_open(row)) + return output + + def remove_duplicated_and_sort_trusted_other(self, search_results: List[Dict]): + output = [] + added_url = set() + for web in search_results: + url = web["page_url"] + added_url.add(url) + if web["page_result"] != "": + output.append(web) + if len(output) == 0: + output = search_results + return output + + def remove_duplicated_and_sort_trusted_open(self, search_results: List[Dict]): output = [] - for row in search_results: - temp = [] + + trust_url = set() + added_url = set() + for keyword in self.config.trusted_source: + for r in search_results: + url = r["page_url"] + domain = urlparse(url).netloc + if keyword in domain: + if r["page_result"] != "": + trust_url.add(url) + + if len(trust_url) == 0: + trust_url = set([x["page_url"] for x in search_results]) + + for web in search_results: + url = web["page_url"] + if (url in added_url) | (url not in trust_url): + continue + added_url.add(url) + if web["page_result"] != "": + output.append(web) + + if len(output) == 0: added_url = set() - for web in row: + for web in search_results: url = web["page_url"] added_url.add(url) if web["page_result"] != "": - temp.append(web) - if len(temp) == 0: - temp = row - output.append(temp) + output.append(web) + if len(output) == 0: + output = search_results return output + + def get_kg_results(self, entities: List, domains: List[str], skip_flag: List[bool]): + results = [ + self.kg_api(e, d) if ~f else "" + for e, d, f in zip(entities, domains, skip_flag) + ] + return results + + def kg_entity_extract( + self, + queries: List[str], + query_times: List[str], + domains: List[str], + skip_flag: List[bool], + ) -> List[str]: + skip_flag = [(d == "open") | s for d, s in zip(domains, skip_flag)] + + inputs = [ + self.template.format( + context={"query": query, "query_time": query_time}, + template=f"entity_{domain}", + ) + for query, domain, query_time in zip(queries, domains, query_times) + ] + + responses = self.llm_inference_with_skip_flag(inputs, skip_flag, stop=["}"]) + answers = [ + response.outputs[0].text if response else "" for response in responses + ] + + entities = [] + for res in answers: + try: + res = json.loads(res) + except: + res = extract_json_objects(res) + entities.append(res) + return entities -- GitLab