From 8679e7d4794eadf5c2edc87c03e76c8e4824bafb Mon Sep 17 00:00:00 2001
From: "S.P. Mohanty" <spmohanty91@gmail.com>
Date: Wed, 24 Apr 2024 17:59:48 +0000
Subject: [PATCH] Add query_time as a parameter to the generate_answer
 interface

---
 local_evaluation.py              | 23 +++++++++++++++--------
 models/README.md                 |  1 +
 models/dummy_model.py            |  5 ++++-
 models/rag_llama_baseline.py     |  5 ++++-
 models/vanilla_llama_baseline.py |  5 ++++-
 5 files changed, 28 insertions(+), 11 deletions(-)

diff --git a/local_evaluation.py b/local_evaluation.py
index b38feeb..bacda0f 100644
--- a/local_evaluation.py
+++ b/local_evaluation.py
@@ -77,27 +77,32 @@ def parse_response(resp: str):
     except:
         return -1
 
+
 def trim_predictions_to_max_token_length(prediction):
     """Trims prediction output to 75 tokens"""
     max_token_length = 75
     tokenized_prediction = tokenizer.encode(prediction)
-    trimmed_tokenized_prediction = tokenized_prediction[1: max_token_length+1]
+    trimmed_tokenized_prediction = tokenized_prediction[
+        1 : max_token_length + 1
+    ]
     trimmed_prediction = tokenizer.decode(trimmed_tokenized_prediction)
     return trimmed_prediction
 
-def generate_predictions(dataset_path, participant_model):
+
+def generate_predictions(dataset_path, participant_model):
     predictions = []
     with bz2.open(DATASET_PATH, "rt") as bz2_file:
         for line in tqdm(bz2_file, desc="Generating Predictions"):
             data = json.loads(line)
-
+
             query = data["query"]
             web_search_results = data["search_results"]
-
+            query_time = data["query_time"]
+
             prediction = participant_model.generate_answer(
-                query, web_search_results
+                query, web_search_results, query_time
             )
-
+
             # trim prediction to 75 tokens
             prediction = trim_predictions_to_max_token_length(prediction)
             predictions.append(
@@ -106,7 +111,7 @@ def generate_predictions(dataset_path, participant_model):
                     "ground_truth": str(data["answer"]).strip().lower(),
                     "prediction": str(prediction).strip().lower(),
                 }
-            )
+            )
 
     return predictions
 
@@ -115,7 +120,9 @@ def evaluate_predictions(predictions, evaluation_model_name, openai_client):
     n_miss, n_correct, n_correct_exact = 0, 0, 0
     system_message = get_system_message()
 
-    for prediction_dict in tqdm(predictions, total=len(predictions), desc="Evaluating Predictions"):
+    for prediction_dict in tqdm(
+        predictions, total=len(predictions), desc="Evaluating Predictions"
+    ):
         query, ground_truth, prediction = (
             prediction_dict["query"],
             prediction_dict["ground_truth"],
diff --git a/models/README.md b/models/README.md
index e72366f..f9b9788 100644
--- a/models/README.md
+++ b/models/README.md
@@ -15,6 +15,7 @@ To ensure your model is recognized and utilized correctly, please specify your m
 Your model will receive two pieces of information for every task:
 - `query`: String representing the input query
 - `search_results`: List of strings, each comes from scraped HTML text of the search query.
+- `query_time`: The time at which the query was made, represented as a string.
 
 ### Outputs
 The output from your model's `generate_answer` function should always be a string.
diff --git a/models/dummy_model.py b/models/dummy_model.py
index 2ea9412..ac3919f 100644
--- a/models/dummy_model.py
+++ b/models/dummy_model.py
@@ -24,7 +24,9 @@ class DummyModel:
         """
         pass
 
-    def generate_answer(self, query: str, search_results: List[Dict]) -> str:
+    def generate_answer(
+        self, query: str, search_results: List[Dict], query_time: str
+    ) -> str:
         """
         Generate an answer based on a provided query and a list of pre-cached search results.
 
@@ -33,6 +35,7 @@ class DummyModel:
         - search_results (List[Dict]): A list containing the search result objects,
           as described here:
           https://gitlab.aicrowd.com/aicrowd/challenges/meta-comprehensive-rag-benchmark-kdd-cup-2024/meta-comphrehensive-rag-benchmark-starter-kit/-/blob/master/docs/dataset.md#search-results-detail
+        - query_time (str): The time at which the query was made, represented as a string.
 
         Returns:
         - (str): A plain text response that answers the query. This response is limited to 75 tokens.
diff --git a/models/rag_llama_baseline.py b/models/rag_llama_baseline.py
index 0f16d3a..177c8ab 100644
--- a/models/rag_llama_baseline.py
+++ b/models/rag_llama_baseline.py
@@ -101,7 +101,9 @@ class RAGModel:
             max_new_tokens=10,
         )
 
-    def generate_answer(self, query: str, search_results: List[Dict]) -> str:
+    def generate_answer(
+        self, query: str, search_results: List[Dict], query_time: str
+    ) -> str:
         """
         Generate an answer based on the provided query and a list of pre-cached search results.
 
@@ -110,6 +112,7 @@ class RAGModel:
         - search_results (List[Dict]): A list containing the search result objects,
           as described here:
           https://gitlab.aicrowd.com/aicrowd/challenges/meta-comprehensive-rag-benchmark-kdd-cup-2024/meta-comphrehensive-rag-benchmark-starter-kit/-/blob/master/docs/dataset.md#search-results-detail
+        - query_time (str): The time at which the query was made, represented as a string.
 
         Returns:
         - str: A text response that answers the query. Limited to 75 tokens.
diff --git a/models/vanilla_llama_baseline.py b/models/vanilla_llama_baseline.py
index a3d2e83..314f8b0 100644
--- a/models/vanilla_llama_baseline.py
+++ b/models/vanilla_llama_baseline.py
@@ -88,7 +88,9 @@ class ChatModel:
             max_new_tokens=75,
         )
 
-    def generate_answer(self, query: str, search_results: List[Dict]) -> str:
+    def generate_answer(
+        self, query: str, search_results: List[Dict], query_time: str
+    ) -> str:
         """
         Generate an answer based on a provided query and a list of pre-cached search results.
 
@@ -97,6 +99,7 @@ class ChatModel:
         - search_results (List[Dict]): A list containing the search result objects,
           as described here:
           https://gitlab.aicrowd.com/aicrowd/challenges/meta-comprehensive-rag-benchmark-kdd-cup-2024/meta-comphrehensive-rag-benchmark-starter-kit/-/blob/master/docs/dataset.md#search-results-detail
+        - query_time (str): The time at which the query was made, represented as a string.
 
         Returns:
         - (str): A plain text response that answers the query. This response is limited to 75 tokens.
-- 
GitLab
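
For participants updating their own submissions, a minimal sketch of a model class that accepts the new three-argument interface is shown below. The class name and the timestamp format passed to strptime are assumptions made for illustration only; the patch guarantees nothing beyond query_time arriving as a string.

from datetime import datetime
from typing import Dict, List


class EchoTimeModel:
    """Hypothetical participant model targeting the updated interface."""

    def generate_answer(
        self, query: str, search_results: List[Dict], query_time: str
    ) -> str:
        # The "%m/%d/%Y, %H:%M:%S PT" format is an assumption for illustration;
        # fall back to treating query_time as an opaque string if parsing fails.
        try:
            asked_at = datetime.strptime(query_time, "%m/%d/%Y, %H:%M:%S PT")
        except ValueError:
            asked_at = None

        if asked_at is not None and "today" in query.lower():
            # Time-sensitive questions can be grounded against the query timestamp.
            return asked_at.strftime("%B %d, %Y").lower()

        # A real model would consult search_results here; keep the fallback safe.
        return "i don't know"

With this patch applied, local_evaluation.py calls generate_answer once per example, passing the example's query_time alongside the query and its pre-cached search results, and then trims the returned answer to 75 tokens before scoring.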