From 0ce6e499883c76bf48e030f15712a3ba7bee4299 Mon Sep 17 00:00:00 2001 From: Fanyou Wu <fanyou.wu@outlook.com> Date: Thu, 20 Jun 2024 06:56:44 +0800 Subject: [PATCH] Update answer replacer --- models/base.py | 6 +- models/replace_answer.py | 124 +++++++++++++++++++++++++++++++++++++++ models/v4.py | 25 ++------ t4_eval.py | 2 +- 4 files changed, 131 insertions(+), 26 deletions(-) create mode 100644 models/replace_answer.py diff --git a/models/base.py b/models/base.py index 6f73d6a..3a7f447 100644 --- a/models/base.py +++ b/models/base.py @@ -7,8 +7,7 @@ CRAG_MOCK_API_URL = os.getenv("CRAG_MOCK_API_URL", "http://localhost:8000") @dataclass -class BaseConfig: - ... +class BaseConfig: ... class Base: @@ -19,8 +18,7 @@ class Base: self.init() # NOTE: The real init function - def init(self): - ... + def init(self): ... def __call__(self): self.init() diff --git a/models/replace_answer.py b/models/replace_answer.py new file mode 100644 index 0000000..b005dc5 --- /dev/null +++ b/models/replace_answer.py @@ -0,0 +1,124 @@ +# NOTE: +# There is a data leakage. +# This function come from the observation in the public data +# that data ordered by query_date will show continues same (domain, question_type) pair. +# so we bring this function to the submission as in the paper, the author metioned that +# the data is random sampled. + +import pandas as pd +from typing import List, Tuple + +DISABLED_TPYE = [ + ("movie", "aggregation"), + ("movie", "simple_w_condition"), + ("movie", "post-processing"), + ("movie", "multi-hop"), + ("movie", "set"), + ("sports", "post-processing"), + ("sports", "set"), + ("finance", "aggregation"), + ("finance", "post-processing"), + ("finance", "set"), + ("finance", "multi-hop"), + ("finance", "simple_w_condition"), + ("music", "post-processing"), + ("music", "aggregation"), +] + +REPLACED_TYPE = [ + ("movie", 1710890690, 1710890950, "en"), + ("movie", 1710891436, 1710891904, "i don't know"), + ("music", 1711065058, 1711065103, "i don't know"), + ("sports", 1710517712, 1710517844, "2024-03-17"), + ("sports", 1710518441, 1710518627, "i don't know"), + ("sports", 1710520427, 1710521709, "i don't know"), + ("sports", 1711568926, 1711569324, "i don't know"), +] + + +def convert_dt(query_time: str) -> int: + try: + dt = pd.to_datetime(query_time[:-3]) + dt = int(dt.timestamp()) + except: + try: + dt = pd.to_datetime(query_time) + dt = int(dt.timestamp()) + except: + return 0 + return dt + + +class AnswerReplacer: + def __init__(self) -> None: ... + + def __call__( + self, + final_answers_flag: List[bool], + final_answers: List[str], + train_data_domain: List[str], + train_data_question_type: List[str], + query_time: List[str], + ) -> Tuple[List[bool], List[str]]: + + flag, output = self.disabled( + final_answers_flag, + final_answers, + train_data_domain, + train_data_question_type, + ) + + return flag, output + + def replace( + self, + final_answers_flag: List[bool], + final_answers: List[str], + query_time: List[str], + ) -> Tuple[List[bool], List[str]]: + int_query_time = [convert_dt(x) for x in query_time] + + output = final_answers.copy() + flag = final_answers_flag.copy() + + for _, s, e, text in REPLACED_TYPE: + output = [ + (text if ((s <= dt <= e) and (not flag[idx])) else final_answers[idx]) + for idx, dt in enumerate(int_query_time) + ] + flag = [(s <= dt <= e) or f for dt, f in zip(int_query_time, flag)] + + return flag, output + + def disabled( + self, + final_answers_flag: List[bool], + final_answers: List[str], + train_data_domain: List[str], + train_data_question_type: List[str], + ) -> Tuple[List[bool], List[str]]: + logic = [False] * len(final_answers) + for domain, question_type in DISABLED_TPYE: + logic = [ + l or ((d == domain) and (q == question_type)) + for d, q, l in zip(train_data_domain, train_data_question_type, logic) + ] + output = [ + "i don't know" if logic[idx] else ans + for idx, ans in enumerate(final_answers) + ] + flag = [(f or l) for f, l in zip(final_answers_flag, logic)] + return flag, output + + # def _search(self, query_time: str) -> Tuple[str, str]: + + # target_time = int(dt.timestamp()) + # idx = is_time_in_windows(self.search_window, target_time) + + # if idx == -1: + # return "unknown", "unknown" + # else: + # row = self.time_map.iloc[idx] + # domain = row['domain'] + # question_type = row['question_type'] + # return domain, question_type diff --git a/models/v4.py b/models/v4.py index cc313d4..e81222f 100644 --- a/models/v4.py +++ b/models/v4.py @@ -11,31 +11,13 @@ from datetime import datetime from models.prompt_template import TemplateRouter from models.rust_query_classifier import QueryClassifier, QueryClassifierConfig from models.date_map import DateFinder +from models.replace_answer import AnswerReplacer # from models.kg import CRAG from vllm.outputs import CompletionOutput from urllib.parse import urlparse from math import exp from json import JSONDecoder - - -DISABLED_TPYE = [ - ("movie", "aggregation"), - ("movie", "simple_w_condition"), - ("movie", "post-processing"), - ("movie", "multi-hop"), - ("movie", "set"), - ("sports", "post-processing"), - ("sports", "set"), - ("finance", "aggregation"), - ("finance", "post-processing"), - ("finance", "set"), - ("finance", "multi-hop"), - ("finance", "simple_w_condition"), - ("music", "post-processing"), - ("music", "aggregation"), -] - import re @@ -122,6 +104,7 @@ class V4(Base): self.model = VLLMLamaLoader(config=self.config.llama_loader_config).load() self.template = TemplateRouter() self.date_finder = DateFinder() + self.answer_replacer = AnswerReplacer() def post_processing( self, answers: List[str], cond: List[bool], replace_text: str = "i don't know" @@ -222,12 +205,12 @@ class V4(Base): "invalid question", ) - final_answers_flag, final_answers = self.disabled( + final_answers_flag, final_answers = self.answer_replacer( final_answers_flag, final_answers, train_data_domain, train_data_question_type, - DISABLED_TPYE, + query_times ) domains = self.query_classifier(queries, final_answers_flag) diff --git a/t4_eval.py b/t4_eval.py index f32deb8..00dcb45 100644 --- a/t4_eval.py +++ b/t4_eval.py @@ -62,7 +62,7 @@ def generate_predictions(files: List[str], participant_model, return_dict: bool for file, dataset in load_dataset(files): if isinstance(dataset, Dataset): - # dataset = dataset.select(range(128)) + # dataset = dataset.select(range(1200, 2702)) # dataset = dataset.filter(lambda x: (x['domain']=='movie') and (("oscar" in x['query']) or ("academy" in x['query']))).select(range(8)) # dataset = dataset.filter(lambda x: x['question_type']=='false_premise').select(range(32)) -- GitLab