From 9dbe6d34467356479d9436ac465f058495e2a35e Mon Sep 17 00:00:00 2001 From: Fanyou Wu <fanyou.wu@outlook.com> Date: Sun, 16 Jun 2024 20:10:31 -0400 Subject: [PATCH] Remove simple qa and change to t1 --- aicrowd.json | 2 +- models/v4.py | 24 ++++++++++++++++-------- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/aicrowd.json b/aicrowd.json index 6fec5d7..d0c3347 100644 --- a/aicrowd.json +++ b/aicrowd.json @@ -1,5 +1,5 @@ { - "challenge_id": "meta-kdd-cup-24-crag-end-to-end-retrieval-augmented-generation", + "challenge_id": "meta-kdd-cup-24-crag-retrieval-summarization", "authors": [ "Fanyou Wu" ], diff --git a/models/v4.py b/models/v4.py index 81c4ce4..273a389 100644 --- a/models/v4.py +++ b/models/v4.py @@ -130,6 +130,7 @@ class V4(Base): def batch_generate_answer(self, batch: Dict, return_dict: bool = False): if not return_dict: queries = batch["query"] + # print(queries) final_answers_flag = [False] * len(queries) final_answers = [""] * len(queries) @@ -148,6 +149,7 @@ class V4(Base): # Check the domain of the query domains = self.query_classifier(queries, final_answers_flag)["domain"] is_finance = [x == "finance" for x in domains] + final_answers_flag, final_answers = self.fillin_answer_with_flag( is_finance, final_answers_flag, final_answers, "i don't know" ) @@ -156,12 +158,13 @@ class V4(Base): # print("final_answers_flag", final_answers_flag) # print("final_answers", final_answers) # print("-----") + # final_answers_flag = [True] * len(queries) - simple_qa = self.simple_qa(queries, domains, final_answers_flag) - final_answers = [ - x if len(x) > 0 else y for x, y in zip(simple_qa, final_answers) - ] - final_answers_flag = [(len(x) > 0) or (y) for x, y in zip(simple_qa, final_answers_flag) ] + # simple_qa = self.simple_qa(queries, domains, final_answers_flag) + # final_answers = [ + # x if len(x) > 0 else y for x, y in zip(simple_qa, final_answers) + # ] + # final_answers_flag = [(len(x) > 0) or (y) for x, y in zip(simple_qa, final_answers_flag) ] # print("after simple_qa") # print("final_answers_flag", final_answers_flag) @@ -200,9 +203,9 @@ class V4(Base): false_premise, final_answers_flag, final_answers, "invalid question" ) - print("after false_premise") - print("false_premise", false_premise) - print("-----") + # print("after false_premise") + # print("false_premise", false_premise) + # print("-----") confidence = [self.get_confidence_output(x) != "high" for x in answers] final_answers_flag, final_answers = self.fillin_answer_with_flag( @@ -335,10 +338,13 @@ class V4(Base): def simple_qa( self, queries: List[str], domains: List[str], skip_flag: List[bool] ) -> List[str]: + + skip = [s | (d != "movie") for d, s in zip(domains, skip_flag)] skip = [ s | (not (("oscar" in q) | ("academy" in q))) for q, s in zip(queries, skip) ] + inputs = [ self.template.format(context={"query": query}, template="simple_qa") for query in queries @@ -346,9 +352,11 @@ class V4(Base): responses = self.llm_inference_with_skip_flag( inputs, skip, max_tokens=128, include_stop_str_in_output=False ) + answers = [ response.outputs[0].text if response else "" for response in responses ] + answers = [x.lower().strip() for x in answers] return answers -- GitLab