From 663d93aef94ac3ce589971b2db0f20660e73190e Mon Sep 17 00:00:00 2001 From: Fanyou Wu <fanyou.wu@outlook.com> Date: Sun, 16 Jun 2024 14:49:20 -0400 Subject: [PATCH] Update --- models/base.py | 3 - models/v4.py | 221 +++++++++++++++++++++++++++++-------------------- 2 files changed, 129 insertions(+), 95 deletions(-) diff --git a/models/base.py b/models/base.py index 0820e0f..6f73d6a 100644 --- a/models/base.py +++ b/models/base.py @@ -25,6 +25,3 @@ class Base: def __call__(self): self.init() return self - - def generate_answer(self, query: str, search_results: List[str]) -> str: - return "i don't know" diff --git a/models/v4.py b/models/v4.py index 8700956..f3ed71d 100644 --- a/models/v4.py +++ b/models/v4.py @@ -67,7 +67,7 @@ class V4Config: ) eval_batch_size: int = 8 max_num_doc: int = 10 - false_premise_threshold: float = 0.97 + false_premise_threshold: float = 0.8 #0.97 offline best trusted_source: list = field( default_factory=lambda: [ "wiki", @@ -94,7 +94,7 @@ class V4(Base): ) self.model = VLLMLamaLoader(config=self.config.llama_loader_config).load() self.template = TemplateRouter() - self.kg_api = CRAG(server=CRAG_MOCK_API_URL) + def post_processing( self, answers: List[str], cond: List[bool], replace_text: str = "i don't know" @@ -161,7 +161,7 @@ class V4(Base): final_answers = [ x if len(x) > 0 else y for x, y in zip(simple_qa, final_answers) ] - final_answers_flag = [(len(x) > 0) or (y) for x, y in zip(simple_qa,final_answers_flag) ] + final_answers_flag = [(len(x) > 0) or (y) for x, y in zip(simple_qa, final_answers_flag) ] # print("after simple_qa") # print("final_answers_flag", final_answers_flag) @@ -169,8 +169,11 @@ class V4(Base): # print("-----") search_results = self.remove_duplicated_and_sort_trusted( - batch["search_results"], domains=domains + batch["search_results"] ) + # search_results = self.remove_duplicated_and_sort_trusted( + # batch["search_results"], domains=domains + # ) search_results = [x[: self.config.max_num_doc] for x in search_results] contents = self.content_retriever( queries, search_results, final_answers_flag @@ -188,13 +191,19 @@ class V4(Base): false_premise = self.false_premise_by_llm( queries, answers, final_answers_flag ) + low_confidence = [self.get_confidence_output(x) == "low" for x in answers] + false_premise = [a & b for a, b in zip(false_premise, low_confidence)] final_answers_flag, final_answers = self.fillin_answer_with_flag( false_premise, final_answers_flag, final_answers, "invalid question" ) + print("after false_premise") + print("false_premise", false_premise) + print("-----") + confidence = [self.get_confidence_output(x) != "high" for x in answers] final_answers_flag, final_answers = self.fillin_answer_with_flag( confidence, final_answers_flag, final_answers, "i don't know" @@ -206,7 +215,14 @@ class V4(Base): for idx, flag in enumerate(final_answers_flag) ] final_answers = [x.lower() for x in final_answers] + + # print("after final_answers") + # print("final_answers_flag", final_answers_flag) + # print("final_answers", final_answers) + # print("-----") + return final_answers + else: # queries = batch["query"] # skip_flags = [False] * len(queries) @@ -226,7 +242,7 @@ class V4(Base): query_classes = self.query_classifier(queries, skip_flags) domains = query_classes["domain"] search_results = self.remove_duplicated_and_sort_trusted( - batch["search_results"], domains + batch["search_results"] ) contents = self.content_retriever(queries, search_results, skip_flags) query_times = batch["query_time"] @@ -391,6 +407,7 @@ class V4(Base): extract_score(response.outputs[0]) if response else (0.0, 0.0) for response in responses ] + # print("false premise score", scores) if return_bool: output = [ (x[0] - x[1]) > self.config.false_premise_threshold for x in scores @@ -409,98 +426,118 @@ class V4(Base): def get_batch_size(self) -> int: return self.config.eval_batch_size - def remove_duplicated_and_sort_trusted( - self, search_results: List[List[Dict]], domains: List[str] - ): - output = [] - for row, domain in zip(search_results, domains): - if domain != "open": - output.append(self.remove_duplicated_and_sort_trusted_other(row)) - else: - output.append(self.remove_duplicated_and_sort_trusted_open(row)) - return output - - def remove_duplicated_and_sort_trusted_other(self, search_results: List[Dict]): - output = [] - added_url = set() - for web in search_results: - url = web["page_url"] - added_url.add(url) - if web["page_result"] != "": - output.append(web) - if len(output) == 0: - output = search_results - return output - - def remove_duplicated_and_sort_trusted_open(self, search_results: List[Dict]): + def remove_duplicated_and_sort_trusted(self, search_results: List[List[Dict]]): output = [] - - trust_url = set() - added_url = set() - for keyword in self.config.trusted_source: - for r in search_results: - url = r["page_url"] - domain = urlparse(url).netloc - if keyword in domain: - if r["page_result"] != "": - trust_url.add(url) - - if len(trust_url) == 0: - trust_url = set([x["page_url"] for x in search_results]) - - for web in search_results: - url = web["page_url"] - if (url in added_url) | (url not in trust_url): - continue - added_url.add(url) - if web["page_result"] != "": - output.append(web) - - if len(output) == 0: + for row in search_results: + temp = [] added_url = set() - for web in search_results: + for web in row: url = web["page_url"] added_url.add(url) if web["page_result"] != "": - output.append(web) - if len(output) == 0: - output = search_results + temp.append(web) + if len(temp) == 0: + temp = row + output.append(temp) return output - def get_kg_results(self, entities: List, domains: List[str], skip_flag: List[bool]): - results = [ - self.kg_api(e, d) if ~f else "" - for e, d, f in zip(entities, domains, skip_flag) - ] - return results - - def kg_entity_extract( - self, - queries: List[str], - query_times: List[str], - domains: List[str], - skip_flag: List[bool], - ) -> List[str]: - skip_flag = [(d == "open") | s for d, s in zip(domains, skip_flag)] - - inputs = [ - self.template.format( - context={"query": query, "query_time": query_time}, - template=f"entity_{domain}", - ) - for query, domain, query_time in zip(queries, domains, query_times) - ] - - responses = self.llm_inference_with_skip_flag(inputs, skip_flag, stop=["}"]) - answers = [ - response.outputs[0].text if response else "" for response in responses - ] - entities = [] - for res in answers: - try: - res = json.loads(res) - except: - res = extract_json_objects(res) - entities.append(res) - return entities + # def remove_duplicated_and_sort_trusted( + # self, search_results: List[List[Dict]], domains: List[str] + # ): + # output = [] + # for row, domain in zip(search_results, domains): + # if domain != "open": + # output.append(self.remove_duplicated_and_sort_trusted_other(row)) + # else: + # output.append(self.remove_duplicated_and_sort_trusted_open(row)) + # return output + + # def remove_duplicated_and_sort_trusted_other(self, search_results: List[Dict]): + # output = [] + # added_url = set() + # for web in search_results: + # url = web["page_url"] + # added_url.add(url) + # if web["page_result"] != "": + # output.append(web) + # if len(output) == 0: + # output = search_results + # return output + + # def remove_duplicated_and_sort_trusted_open(self, search_results: List[Dict]): + # output = [] + + # trust_url = set() + # added_url = set() + + # for keyword in self.config.trusted_source: + # for r in search_results: + # url = r["page_url"] + # domain = urlparse(url).netloc + # if keyword in domain: + # if r["page_result"] != "": + # trust_url.add(url) + + # if len(trust_url) == 0: + # trust_url = set([x["page_url"] for x in search_results]) + + # for web in search_results: + # url = web["page_url"] + # if (url in added_url) | (url not in trust_url): + # continue + # added_url.add(url) + # if web["page_result"] != "": + # output.append(web) + + # if len(output) == 0: + # added_url = set() + # for web in search_results: + # url = web["page_url"] + # added_url.add(url) + # if web["page_result"] != "": + # output.append(web) + + # if len(output) == 0: + # output = search_results + + # return output + + # def get_kg_results(self, entities: List, domains: List[str], skip_flag: List[bool]): + # kg_api = CRAG(server=CRAG_MOCK_API_URL) + # results = [ + # kg_api(e, d) if ~f else "" + # for e, d, f in zip(entities, domains, skip_flag) + # ] + # return results + + # def kg_entity_extract( + # self, + # queries: List[str], + # query_times: List[str], + # domains: List[str], + # skip_flag: List[bool], + # ) -> List[str]: + # skip_flag = [(d == "open") | s for d, s in zip(domains, skip_flag)] + + # inputs = [ + # self.template.format( + # context={"query": query, "query_time": query_time}, + # template=f"entity_{domain}", + # ) + # for query, domain, query_time in zip(queries, domains, query_times) + # ] + + # responses = self.llm_inference_with_skip_flag(inputs, skip_flag, stop=["}"]) + # answers = [ + # response.outputs[0].text if response else "" for response in responses + # ] + + # entities = [] + # for res in answers: + # try: + # res = json.loads(res) + # except: + # res = extract_json_objects(res) + # entities.append(res) + # return entities -- GitLab