From e4747e192fd3291e38bf01beea48a1a41c2ea8e8 Mon Sep 17 00:00:00 2001 From: Fanyou Wu <fanyou.wu@outlook.com> Date: Thu, 20 Jun 2024 14:48:07 -0400 Subject: [PATCH] update --- models/v4.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/models/v4.py b/models/v4.py index 8fe8d40..66d2858 100644 --- a/models/v4.py +++ b/models/v4.py @@ -231,21 +231,21 @@ class V4(Base): is_not_static, final_answers_flag, final_answers, "i don't know" ) - question_type = self.query_type_by_llm(queries, final_answers_flag) - question_type = [ - x if x != "unknown" else y - for x, y in zip(train_data_question_type, question_type) - ] - - is_aggregation = [x == "aggregation" for x in question_type] - final_answers_flag, final_answers = self.fillin_answer_with_flag( - is_aggregation, final_answers_flag, final_answers, "i don't know" - ) + # question_type = self.query_type_by_llm(queries, final_answers_flag) + # question_type = [ + # x if x != "unknown" else y + # for x, y in zip(train_data_question_type, question_type) + # ] + + # is_aggregation = [x == "aggregation" for x in question_type] + # final_answers_flag, final_answers = self.fillin_answer_with_flag( + # is_aggregation, final_answers_flag, final_answers, "i don't know" + # ) - is_post_processing = [x == "post-processing" for x in question_type] - final_answers_flag, final_answers = self.fillin_answer_with_flag( - is_post_processing, final_answers_flag, final_answers, "i don't know" - ) + # is_post_processing = [x == "post-processing" for x in question_type] + # final_answers_flag, final_answers = self.fillin_answer_with_flag( + # is_post_processing, final_answers_flag, final_answers, "i don't know" + # ) simple_qa, simple_not_skip = self.simple_qa( queries, domains, final_answers_flag @@ -271,6 +271,7 @@ class V4(Base): contents = self.content_retriever( queries, search_results, final_answers_flag ) + answers = self.qa_with_ref( queries, contents, @@ -278,6 +279,7 @@ class V4(Base): domains, final_answers_flag, ) + false_premise_by_rag = self.false_premise_by_llm( queries, answers, final_answers_flag, return_bool=True ) @@ -346,7 +348,7 @@ class V4(Base): return { "queries": queries, "prediction": [self.get_answer_output(x) for x in answers], - "reason": [find_by_xml_tag(x, "reason", "") for x in answers], + #"reason": [find_by_xml_tag(x, "reason", "") for x in answers], "confidence": [self.get_confidence_output(x) for x in answers], } -- GitLab