From 8679e7d4794eadf5c2edc87c03e76c8e4824bafb Mon Sep 17 00:00:00 2001
From: "S.P. Mohanty" <spmohanty91@gmail.com>
Date: Wed, 24 Apr 2024 17:59:48 +0000
Subject: [PATCH] Add query_time as a parameter to the generate_answer
 interface

---
 local_evaluation.py              | 23 +++++++++++++++--------
 models/README.md                 |  1 +
 models/dummy_model.py            |  5 ++++-
 models/rag_llama_baseline.py     |  5 ++++-
 models/vanilla_llama_baseline.py |  5 ++++-
 5 files changed, 28 insertions(+), 11 deletions(-)

diff --git a/local_evaluation.py b/local_evaluation.py
index b38feeb..bacda0f 100644
--- a/local_evaluation.py
+++ b/local_evaluation.py
@@ -77,27 +77,32 @@ def parse_response(resp: str):
     except:
         return -1
 
+
 def trim_predictions_to_max_token_length(prediction):
     """Trims prediction output to 75 tokens"""
     max_token_length = 75
     tokenized_prediction = tokenizer.encode(prediction)
-    trimmed_tokenized_prediction = tokenized_prediction[1: max_token_length+1]
+    trimmed_tokenized_prediction = tokenized_prediction[
+        1 : max_token_length + 1
+    ]
     trimmed_prediction = tokenizer.decode(trimmed_tokenized_prediction)
     return trimmed_prediction
 
-def generate_predictions(dataset_path, participant_model):
+
+def generate_predictions(dataset_path, participant_model):
     predictions = []
     with bz2.open(DATASET_PATH, "rt") as bz2_file:
         for line in tqdm(bz2_file, desc="Generating Predictions"):
             data = json.loads(line)
-
+
             query = data["query"]
             web_search_results = data["search_results"]
-
+            query_time = data["query_time"]
+
             prediction = participant_model.generate_answer(
-                query, web_search_results
+                query, web_search_results, query_time
             )
-
+
             # trim prediction to 75 tokens
             prediction = trim_predictions_to_max_token_length(prediction)
             predictions.append(
@@ -106,7 +111,7 @@ def generate_predictions(dataset_path, participant_model):
                     "ground_truth": str(data["answer"]).strip().lower(),
                     "prediction": str(prediction).strip().lower(),
                 }
-            )
+            )
 
     return predictions
 
@@ -115,7 +120,9 @@ def evaluate_predictions(predictions, evaluation_model_name, openai_client):
     n_miss, n_correct, n_correct_exact = 0, 0, 0
     system_message = get_system_message()
 
-    for prediction_dict in tqdm(predictions, total=len(predictions), desc="Evaluating Predictions"):
+    for prediction_dict in tqdm(
+        predictions, total=len(predictions), desc="Evaluating Predictions"
+    ):
         query, ground_truth, prediction = (
             prediction_dict["query"],
             prediction_dict["ground_truth"],
diff --git a/models/README.md b/models/README.md
index e72366f..f9b9788 100644
--- a/models/README.md
+++ b/models/README.md
@@ -15,6 +15,7 @@ To ensure your model is recognized and utilized correctly, please specify your m
 Your model will receive two pieces of information for every task:
 - `query`: String representing the input query
 - `search_results`: List of strings, each comes from scraped HTML text of the search query.
+- `query_time`: The time at which the query was made, represented as a string.
 
 ### Outputs
 The output from your model's `generate_answer` function should always be a string.
diff --git a/models/dummy_model.py b/models/dummy_model.py
index 2ea9412..ac3919f 100644
--- a/models/dummy_model.py
+++ b/models/dummy_model.py
@@ -24,7 +24,9 @@ class DummyModel:
         """
         pass
 
-    def generate_answer(self, query: str, search_results: List[Dict]) -> str:
+    def generate_answer(
+        self, query: str, search_results: List[Dict], query_time: str
+    ) -> str:
         """
         Generate an answer based on a provided query and a list of pre-cached search results.
 
@@ -33,6 +35,7 @@ class DummyModel:
         - search_results (List[Dict]): A list containing the search result objects,
           as described here:
           https://gitlab.aicrowd.com/aicrowd/challenges/meta-comprehensive-rag-benchmark-kdd-cup-2024/meta-comphrehensive-rag-benchmark-starter-kit/-/blob/master/docs/dataset.md#search-results-detail
+        - query_time (str): The time at which the query was made, represented as a string.
 
         Returns:
         - (str): A plain text response that answers the query. This response is limited to 75 tokens.
diff --git a/models/rag_llama_baseline.py b/models/rag_llama_baseline.py
index 0f16d3a..177c8ab 100644
--- a/models/rag_llama_baseline.py
+++ b/models/rag_llama_baseline.py
@@ -101,7 +101,9 @@ class RAGModel:
             max_new_tokens=10,
         )
 
-    def generate_answer(self, query: str, search_results: List[Dict]) -> str:
+    def generate_answer(
+        self, query: str, search_results: List[Dict], query_time: str
+    ) -> str:
         """
         Generate an answer based on the provided query and a list of pre-cached search results.
 
@@ -110,6 +112,7 @@ class RAGModel:
         - search_results (List[Dict]): A list containing the search result objects,
           as described here:
           https://gitlab.aicrowd.com/aicrowd/challenges/meta-comprehensive-rag-benchmark-kdd-cup-2024/meta-comphrehensive-rag-benchmark-starter-kit/-/blob/master/docs/dataset.md#search-results-detail
+        - query_time (str): The time at which the query was made, represented as a string.
 
         Returns:
         - str: A text response that answers the query. Limited to 75 tokens.
diff --git a/models/vanilla_llama_baseline.py b/models/vanilla_llama_baseline.py
index a3d2e83..314f8b0 100644
--- a/models/vanilla_llama_baseline.py
+++ b/models/vanilla_llama_baseline.py
@@ -88,7 +88,9 @@ class ChatModel:
             max_new_tokens=75,
         )
 
-    def generate_answer(self, query: str, search_results: List[Dict]) -> str:
+    def generate_answer(
+        self, query: str, search_results: List[Dict], query_time: str
+    ) -> str:
         """
         Generate an answer based on a provided query and a list of pre-cached search results.
 
@@ -97,6 +99,7 @@ class ChatModel:
         - search_results (List[Dict]): A list containing the search result objects,
           as described here:
           https://gitlab.aicrowd.com/aicrowd/challenges/meta-comprehensive-rag-benchmark-kdd-cup-2024/meta-comphrehensive-rag-benchmark-starter-kit/-/blob/master/docs/dataset.md#search-results-detail
+        - query_time (str): The time at which the query was made, represented as a string.
 
         Returns:
         - (str): A plain text response that answers the query. This response is limited to 75 tokens.
-- 
GitLab
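
For participants updating their own submissions, a minimal sketch of a model class that accepts the new three-argument interface is shown below. The class name and the timestamp format passed to strptime are assumptions made for illustration only; the patch guarantees nothing beyond query_time arriving as a string.

from datetime import datetime
from typing import Dict, List


class EchoTimeModel:
    """Hypothetical participant model targeting the updated interface."""

    def generate_answer(
        self, query: str, search_results: List[Dict], query_time: str
    ) -> str:
        # The "%m/%d/%Y, %H:%M:%S PT" format is an assumption for illustration;
        # fall back to treating query_time as an opaque string if parsing fails.
        try:
            asked_at = datetime.strptime(query_time, "%m/%d/%Y, %H:%M:%S PT")
        except ValueError:
            asked_at = None

        if asked_at is not None and "today" in query.lower():
            # Time-sensitive questions can be grounded against the query timestamp.
            return asked_at.strftime("%B %d, %Y").lower()

        # A real model would consult search_results here; keep the fallback safe.
        return "i don't know"

With this patch applied, local_evaluation.py calls generate_answer once per example, passing the example's query_time alongside the query and its pre-cached search results, and then trims the returned answer to 75 tokens before scoring.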