Commit 8679e7d4 authored by spmohanty

Add query_time as a parameter to the generate_answer interface

parent 048f5f45
@@ -77,27 +77,32 @@ def parse_response(resp: str):
     except:
         return -1


 def trim_predictions_to_max_token_length(prediction):
     """Trims prediction output to 75 tokens"""
     max_token_length = 75
     tokenized_prediction = tokenizer.encode(prediction)
-    trimmed_tokenized_prediction = tokenized_prediction[1: max_token_length+1]
+    trimmed_tokenized_prediction = tokenized_prediction[
+        1 : max_token_length + 1
+    ]
     trimmed_prediction = tokenizer.decode(trimmed_tokenized_prediction)
     return trimmed_prediction


 def generate_predictions(dataset_path, participant_model):
     predictions = []
     with bz2.open(DATASET_PATH, "rt") as bz2_file:
         for line in tqdm(bz2_file, desc="Generating Predictions"):
             data = json.loads(line)
             query = data["query"]
             web_search_results = data["search_results"]
+            query_time = data["query_time"]
             prediction = participant_model.generate_answer(
-                query, web_search_results
+                query, web_search_results, query_time
             )
             # trim prediction to 75 tokens
             prediction = trim_predictions_to_max_token_length(prediction)
             predictions.append(
@@ -106,7 +111,7 @@ def generate_predictions(dataset_path, participant_model):
                     "ground_truth": str(data["answer"]).strip().lower(),
                     "prediction": str(prediction).strip().lower(),
                 }
             )

     return predictions
@@ -115,7 +120,9 @@ def evaluate_predictions(predictions, evaluation_model_name, openai_client):
     n_miss, n_correct, n_correct_exact = 0, 0, 0
     system_message = get_system_message()

-    for prediction_dict in tqdm(predictions, total=len(predictions), desc="Evaluating Predictions"):
+    for prediction_dict in tqdm(
+        predictions, total=len(predictions), desc="Evaluating Predictions"
+    ):
         query, ground_truth, prediction = (
             prediction_dict["query"],
             prediction_dict["ground_truth"],
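For orientation, each line of the compressed dataset file is a standalone JSON object, and the generate_predictions loop above reads exactly four of its fields: `query`, `search_results`, `query_time`, and `answer`. The sketch below fabricates one such record; the values and the inner search-result keys are illustrative placeholders only, not real dataset contents (see docs/dataset.md in the starter kit for the full schema).

# Minimal sketch of one dataset record, limited to the fields the
# evaluation script reads above. All values are illustrative, not taken
# from the real CRAG data; real records carry additional fields.
import bz2
import json

example_record = {
    "query": "who won the 2023 formula 1 drivers championship?",
    "search_results": [
        # each entry is one pre-cached search result object; the keys shown
        # here are placeholders, the real schema is documented in dataset.md
        {"page_name": "Example page", "page_snippet": "Max Verstappen won ..."},
    ],
    "query_time": "03/05/2024, 23:18:31 PT",
    "answer": "max verstappen",
}

# One JSON object per line in a bz2 file, which is the layout the
# generate_predictions() loop above iterates over.
with bz2.open("toy_dataset.jsonl.bz2", "wt") as f:
    f.write(json.dumps(example_record) + "\n")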
@@ -15,6 +15,7 @@ To ensure your model is recognized and utilized correctly, please specify your m
 Your model will receive two pieces of information for every task:
 - `query`: String representing the input query
 - `search_results`: List of strings, each comes from scraped HTML text of the search query.
+- `query_time`: The time at which the query was made, represented as a string.

 ### Outputs
 The output from your model's `generate_answer` function should always be a string.
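On the participant side, the only change this commit requires is accepting the extra argument. A minimal sketch of a conforming model, mirroring the DummyModel change in the next hunk (the class name and the fixed answer string are purely illustrative):

from typing import Dict, List


class MyModel:
    """Minimal sketch of a model satisfying the updated interface."""

    def generate_answer(
        self, query: str, search_results: List[Dict], query_time: str
    ) -> str:
        # query_time arrives as a plain string; a trivial baseline may simply
        # ignore it and return a fixed (illustrative) answer under 75 tokens.
        return "i don't know"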
@@ -24,7 +24,9 @@ class DummyModel:
         """
         pass

-    def generate_answer(self, query: str, search_results: List[Dict]) -> str:
+    def generate_answer(
+        self, query: str, search_results: List[Dict], query_time: str
+    ) -> str:
         """
         Generate an answer based on a provided query and a list of pre-cached search results.
@@ -33,6 +35,7 @@ class DummyModel:
         - search_results (List[Dict]): A list containing the search result objects,
           as described here:
           https://gitlab.aicrowd.com/aicrowd/challenges/meta-comprehensive-rag-benchmark-kdd-cup-2024/meta-comphrehensive-rag-benchmark-starter-kit/-/blob/master/docs/dataset.md#search-results-detail
+        - query_time (str): The time at which the query was made, represented as a string.

         Returns:
         - (str): A plain text response that answers the query. This response is limited to 75 tokens.
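Models that care about time-sensitive questions can do more than thread the argument through. A hypothetical helper, not part of the starter kit, that turns query_time into a datetime might look like this (the format strings are guesses; adjust them to whatever the real data uses):

from datetime import datetime
from typing import Optional


def parse_query_time(query_time: str) -> Optional[datetime]:
    """Best-effort parse of the query_time string into a datetime.

    The concrete timestamp format is defined by the dataset, so this helper
    simply tries a few plausible formats and returns None if none match.
    """
    for fmt in ("%m/%d/%Y, %H:%M:%S PT", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M:%S"):
        try:
            return datetime.strptime(query_time, fmt)
        except ValueError:
            continue
    return None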
@@ -101,7 +101,9 @@ class RAGModel:
             max_new_tokens=10,
         )

-    def generate_answer(self, query: str, search_results: List[Dict]) -> str:
+    def generate_answer(
+        self, query: str, search_results: List[Dict], query_time: str
+    ) -> str:
         """
         Generate an answer based on the provided query and a list of pre-cached search results.
@@ -110,6 +112,7 @@ class RAGModel:
         - search_results (List[Dict]): A list containing the search result objects,
           as described here:
           https://gitlab.aicrowd.com/aicrowd/challenges/meta-comprehensive-rag-benchmark-kdd-cup-2024/meta-comphrehensive-rag-benchmark-starter-kit/-/blob/master/docs/dataset.md#search-results-detail
+        - query_time (str): The time at which the query was made, represented as a string.

         Returns:
         - str: A text response that answers the query. Limited to 75 tokens.
@@ -88,7 +88,9 @@ class ChatModel:
             max_new_tokens=75,
         )

-    def generate_answer(self, query: str, search_results: List[Dict]) -> str:
+    def generate_answer(
+        self, query: str, search_results: List[Dict], query_time: str
+    ) -> str:
         """
         Generate an answer based on a provided query and a list of pre-cached search results.
@@ -97,6 +99,7 @@ class ChatModel:
         - search_results (List[Dict]): A list containing the search result objects,
           as described here:
           https://gitlab.aicrowd.com/aicrowd/challenges/meta-comprehensive-rag-benchmark-kdd-cup-2024/meta-comphrehensive-rag-benchmark-starter-kit/-/blob/master/docs/dataset.md#search-results-detail
+        - query_time (str): The time at which the query was made, represented as a string.

         Returns:
         - (str): A plain text response that answers the query. This response is limited to 75 tokens.
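For the prompted baselines (RAGModel and ChatModel above), one obvious way to actually use the new argument is to fold it into the prompt. A hypothetical sketch, with the helper name and prompt wording invented for illustration and not taken from the repository:

def build_prompt(query: str, references: str, query_time: str) -> str:
    # Hypothetical helper (not from the repository): folds the timestamp into
    # the prompt so the model can ground phrases like "currently" or "last week".
    return (
        "Answer the question using only the references below.\n"
        f"The question was asked at: {query_time}\n\n"
        f"References:\n{references}\n\n"
        f"Question: {query}\n"
        "Answer (at most 75 tokens):"
    )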