From 0ce6e499883c76bf48e030f15712a3ba7bee4299 Mon Sep 17 00:00:00 2001
From: Fanyou Wu <fanyou.wu@outlook.com>
Date: Thu, 20 Jun 2024 06:56:44 +0800
Subject: [PATCH] Update answer replacer

---
 models/base.py           |   6 +-
 models/replace_answer.py | 124 +++++++++++++++++++++++++++++++++++++++
 models/v4.py             |  25 ++------
 t4_eval.py               |   2 +-
 4 files changed, 131 insertions(+), 26 deletions(-)
 create mode 100644 models/replace_answer.py

diff --git a/models/base.py b/models/base.py
index 6f73d6a..3a7f447 100644
--- a/models/base.py
+++ b/models/base.py
@@ -7,8 +7,7 @@ CRAG_MOCK_API_URL = os.getenv("CRAG_MOCK_API_URL", "http://localhost:8000")
 
 
 @dataclass
-class BaseConfig:
-    ...
+class BaseConfig: ...
 
 
 class Base:
@@ -19,8 +18,7 @@ class Base:
             self.init()
 
     # NOTE: The real init function
-    def init(self):
-        ...
+    def init(self): ...
 
     def __call__(self):
         self.init()
diff --git a/models/replace_answer.py b/models/replace_answer.py
new file mode 100644
index 0000000..b005dc5
--- /dev/null
+++ b/models/replace_answer.py
@@ -0,0 +1,124 @@
+# NOTE:
+# There is a data leakage.
+# This function come from the observation in the public data
+# that data ordered by query_date  will show continues same (domain, question_type) pair.
+# so we bring this function to the submission as in the paper, the author metioned that
+# the data is random sampled.
+
+import pandas as pd
+from typing import List, Tuple
+
+DISABLED_TPYE = [
+    ("movie", "aggregation"),
+    ("movie", "simple_w_condition"),
+    ("movie", "post-processing"),
+    ("movie", "multi-hop"),
+    ("movie", "set"),
+    ("sports", "post-processing"),
+    ("sports", "set"),
+    ("finance", "aggregation"),
+    ("finance", "post-processing"),
+    ("finance", "set"),
+    ("finance", "multi-hop"),
+    ("finance", "simple_w_condition"),
+    ("music", "post-processing"),
+    ("music", "aggregation"),
+]
+
+REPLACED_TYPE = [
+    ("movie", 1710890690, 1710890950, "en"),
+    ("movie", 1710891436, 1710891904, "i don't know"),
+    ("music", 1711065058, 1711065103, "i don't know"),
+    ("sports", 1710517712, 1710517844, "2024-03-17"),
+    ("sports", 1710518441, 1710518627, "i don't know"),
+    ("sports", 1710520427, 1710521709, "i don't know"),
+    ("sports", 1711568926, 1711569324, "i don't know"),
+]
+
+
+def convert_dt(query_time: str) -> int:
+    try:
+        dt = pd.to_datetime(query_time[:-3])
+        dt = int(dt.timestamp())
+    except:
+        try:
+            dt = pd.to_datetime(query_time)
+            dt = int(dt.timestamp())
+        except:
+            return 0
+    return dt
+
+
+class AnswerReplacer:
+    def __init__(self) -> None: ...
+
+    def __call__(
+        self,
+        final_answers_flag: List[bool],
+        final_answers: List[str],
+        train_data_domain: List[str],
+        train_data_question_type: List[str],
+        query_time: List[str],
+    ) -> Tuple[List[bool], List[str]]:
+
+        flag, output = self.disabled(
+            final_answers_flag,
+            final_answers,
+            train_data_domain,
+            train_data_question_type,
+        )
+
+        return flag, output
+
+    def replace(
+        self,
+        final_answers_flag: List[bool],
+        final_answers: List[str],
+        query_time: List[str],
+    ) -> Tuple[List[bool], List[str]]:
+        int_query_time = [convert_dt(x) for x in query_time]
+
+        output = final_answers.copy()
+        flag = final_answers_flag.copy()
+
+        for _, s, e, text in REPLACED_TYPE:
+            output = [
+                (text if ((s <= dt <= e) and (not flag[idx])) else final_answers[idx])
+                for idx, dt in enumerate(int_query_time)
+            ]
+            flag = [(s <= dt <= e) or f for dt, f in zip(int_query_time, flag)]
+
+        return flag, output
+
+    def disabled(
+        self,
+        final_answers_flag: List[bool],
+        final_answers: List[str],
+        train_data_domain: List[str],
+        train_data_question_type: List[str],
+    ) -> Tuple[List[bool], List[str]]:
+        logic = [False] * len(final_answers)
+        for domain, question_type in DISABLED_TPYE:
+            logic = [
+                l or ((d == domain) and (q == question_type))
+                for d, q, l in zip(train_data_domain, train_data_question_type, logic)
+            ]
+        output = [
+            "i don't know" if logic[idx] else ans
+            for idx, ans in enumerate(final_answers)
+        ]
+        flag = [(f or l) for f, l in zip(final_answers_flag, logic)]
+        return flag, output
+
+    # def _search(self, query_time: str) -> Tuple[str, str]:
+
+    #     target_time = int(dt.timestamp())
+    #     idx = is_time_in_windows(self.search_window, target_time)
+
+    #     if idx == -1:
+    #         return "unknown", "unknown"
+    #     else:
+    #         row = self.time_map.iloc[idx]
+    #         domain = row['domain']
+    #         question_type = row['question_type']
+    #         return domain, question_type
diff --git a/models/v4.py b/models/v4.py
index cc313d4..e81222f 100644
--- a/models/v4.py
+++ b/models/v4.py
@@ -11,31 +11,13 @@ from datetime import datetime
 from models.prompt_template import TemplateRouter
 from models.rust_query_classifier import QueryClassifier, QueryClassifierConfig
 from models.date_map import DateFinder
+from models.replace_answer import AnswerReplacer
 
 # from models.kg import CRAG
 from vllm.outputs import CompletionOutput
 from urllib.parse import urlparse
 from math import exp
 from json import JSONDecoder
-
-
-DISABLED_TPYE = [
-    ("movie", "aggregation"),
-    ("movie", "simple_w_condition"),
-    ("movie", "post-processing"),
-    ("movie", "multi-hop"),
-    ("movie", "set"),
-    ("sports", "post-processing"),
-    ("sports", "set"),
-    ("finance", "aggregation"),
-    ("finance", "post-processing"),
-    ("finance", "set"),
-    ("finance", "multi-hop"),
-    ("finance", "simple_w_condition"),
-    ("music", "post-processing"),
-    ("music", "aggregation"),
-]
-
 import re
 
 
@@ -122,6 +104,7 @@ class V4(Base):
         self.model = VLLMLamaLoader(config=self.config.llama_loader_config).load()
         self.template = TemplateRouter()
         self.date_finder = DateFinder()
+        self.answer_replacer = AnswerReplacer()
 
     def post_processing(
         self, answers: List[str], cond: List[bool], replace_text: str = "i don't know"
@@ -222,12 +205,12 @@ class V4(Base):
                 "invalid question",
             )
 
-            final_answers_flag, final_answers = self.disabled(
+            final_answers_flag, final_answers = self.answer_replacer(
                 final_answers_flag,
                 final_answers,
                 train_data_domain,
                 train_data_question_type,
-                DISABLED_TPYE,
+                query_times
             )
 
             domains = self.query_classifier(queries, final_answers_flag)
diff --git a/t4_eval.py b/t4_eval.py
index f32deb8..00dcb45 100644
--- a/t4_eval.py
+++ b/t4_eval.py
@@ -62,7 +62,7 @@ def generate_predictions(files: List[str], participant_model, return_dict: bool
 
     for file, dataset in load_dataset(files):
         if isinstance(dataset, Dataset):
-            # dataset = dataset.select(range(128))
+            # dataset = dataset.select(range(1200, 2702))
             # dataset = dataset.filter(lambda x: (x['domain']=='movie') and (("oscar" in x['query']) or ("academy"  in x['query']))).select(range(8))
 
             # dataset = dataset.filter(lambda x: x['question_type']=='false_premise').select(range(32))
-- 
GitLab