From c4105a1e6f436a53fc66ee11a1e61419ab1be36e Mon Sep 17 00:00:00 2001
From: Fanyou Wu <fanyou.wu@outlook.com>
Date: Mon, 17 Jun 2024 00:26:34 -0400
Subject: [PATCH] Test simple qa

---
 models/v4.py | 57 +++++++++++++++++++++++--------------------------------
 1 file changed, 25 insertions(+), 32 deletions(-)

diff --git a/models/v4.py b/models/v4.py
index 1c9770b..1f737f5 100644
--- a/models/v4.py
+++ b/models/v4.py
@@ -142,11 +142,10 @@ class V4(Base):
 
     def batch_generate_answer(self, batch: Dict, return_dict: bool = False):
         if not return_dict:
-            queries = [str(x) for x in batch["query"]]
+            queries = batch["query"]
             final_answers_flag = [False] * len(queries)
             final_answers = [""] * len(queries)
 
-            # Check if the question type is other or not.
             question_type = self.query_type_by_llm(queries, final_answers_flag)
             is_not_other = [x != "other" for x in question_type]
             final_answers_flag, final_answers = self.fillin_answer_with_flag(
@@ -212,12 +211,6 @@ class V4(Base):
             ]
 
             output = [x.lower() for x in output]
-
-            # print("after final_answers")
-            # print("final_answers_flag", final_answers_flag)
-            # print("final_answers", final_answers)
-            # print("-----")
-
             return output
 
         else:
@@ -332,34 +325,34 @@ class V4(Base):
     def simple_qa(
         self, queries: List[str], domains: List[str], skip_flag: List[bool]
     ) -> List[str]:
-        try:
-            not_skip = [(d == "movie") and (not s) for d, s in zip(domains, skip_flag)]
-            contain_keywords = [
-                check_keywords(x, ["oscar", "academy", "award"]) for x in queries
-            ]
-            skip = [(not x) or (not y) for x, y in zip(not_skip, contain_keywords)]
+
+        # not_skip = [(d == "movie") and (not s) for d, s in zip(domains, skip_flag)]
+        # contain_keywords = [
+        #     check_keywords(x, ["oscar", "academy", "award"]) for x in queries
+        # ]
+        # skip = [(not x) or (not y) for x, y in zip(not_skip, contain_keywords)]
 
-            inputs = [
-                self.template.format(context={"query": query}, template="simple_qa")
-                for query in queries
-            ]
+        skip = [False] * len(queries)
 
-            responses = self.llm_inference_with_skip_flag(
-                inputs,
-                skip,
-                max_tokens=96,
-                stop=["<|eot_id|>"],
-                include_stop_str_in_output=False,
-            )
+        inputs = [
+            self.template.format(context={"query": query}, template="simple_qa")
+            for query in queries
+        ]
 
-            answers = [
-                response.outputs[0].text if response else "" for response in responses
-            ]
+        responses = self.llm_inference_with_skip_flag(
+            inputs,
+            skip,
+            max_tokens=64,
+            stop=["<|eot_id|>"],
+            include_stop_str_in_output=False,
+        )
 
-            answers = [x.lower().strip() for x in answers]
-            return answers
-        except:
-            return [""] * len(queries)
+        answers = [
+            response.outputs[0].text if response else "" for response in responses
+        ]
+
+        answers = [x.lower().strip() for x in answers]
+        return answers
 
     def query_type_by_llm(self, queries: List[str], skip_flag: List[bool]) -> List[str]:
         inputs = [
-- 
GitLab