Skip to content
Snippets Groups Projects
Commit 1b510bb5 authored by spmohanty's avatar spmohanty
Browse files

add tokenizer + response trimming

parent 9a7c5b04
No related branches found
No related tags found
No related merge requests found
......@@ -6,6 +6,9 @@ from loguru import logger
from openai import APIConnectionError, OpenAI, RateLimitError
from prompts.templates import IN_CONTEXT_EXAMPLES, INSTRUCTIONS
from tqdm.auto import tqdm
from transformers import LlamaTokenizerFast
# Load the Llama tokenizer from the local "tokenizer" directory.
# NOTE(review): the path is relative to the current working directory —
# confirm the evaluation harness always runs from the repo root.
tokenizer = LlamaTokenizerFast.from_pretrained("tokenizer")
def load_json_file(file_path):
......@@ -41,10 +44,12 @@ def attempt_api_call(client, model_name, messages, max_retries=10):
return None
def log_response(messages, response):
def log_response(messages, response, output_directory="api_responses"):
    """Save an API request/response pair to a timestamped JSON file.

    Args:
        messages: The chat messages that were sent to the API.
        response: The response payload returned by the API.
        output_directory: Directory the JSON log is written to; created
            if it does not already exist. Defaults to "api_responses".
    """
    os.makedirs(output_directory, exist_ok=True)
    # Timestamped filename; second-level resolution may collide if two
    # responses are logged within the same second.
    file_name = datetime.now().strftime("%d-%m-%Y-%H-%M-%S.json")
    file_path = os.path.join(output_directory, file_name)
    # Explicit UTF-8 so the log encoding does not depend on the platform locale.
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump({"messages": messages, "response": response}, f)
......@@ -71,6 +76,13 @@ def parse_response(resp: str):
except:
return -1
def trim_predictions_to_max_token_length(prediction, max_token_length=75):
    """Trim a prediction string to at most ``max_token_length`` tokens.

    Args:
        prediction: The model's raw answer text.
        max_token_length: Maximum number of content tokens to keep.
            Defaults to 75, the original hard-coded limit.

    Returns:
        The prediction re-decoded from its first ``max_token_length``
        content tokens (unchanged if it was already short enough).
    """
    tokenized_prediction = tokenizer.encode(prediction)
    # Skip index 0: the tokenizer prepends a BOS token (add_bos_token=True in
    # tokenizer_config.json), which is not part of the prediction text.
    trimmed_tokenized_prediction = tokenized_prediction[1 : max_token_length + 1]
    trimmed_prediction = tokenizer.decode(trimmed_tokenized_prediction)
    return trimmed_prediction
def generate_predictions(dataset_path, participant_model):
qa = load_json_file(os.path.join(dataset_path, "qa.json"))
......@@ -82,6 +94,8 @@ def generate_predictions(dataset_path, participant_model):
prediction = participant_model.generate_answer(
query, query_web_search_results
)
# trim prediction to 75 tokens
prediction = trim_predictions_to_max_token_length(prediction)
predictions.append(
{
"query": query,
......
import os
from typing import List
from models.utils import trim_predictions_to_max_token_length
# Load the environment variable that specifies the URL of the MockAPI. This URL is essential
# for accessing the correct API endpoint in Task 2 and Task 3. The value of this environment variable
# may vary across different evaluation settings, emphasizing the importance of dynamically obtaining
......@@ -44,4 +46,7 @@ class DummyModel:
# Default response when unsure about the answer
answer = "i don't know"
return answer
# Trim prediction to a max of 75 tokens
trimmed_answer = trim_predictions_to_max_token_length(answer)
return trimmed_answer
#!/usr/bin/env python
import os
from transformers import LlamaTokenizerFast
# Resolve the bundled tokenizer directory relative to this file so loading
# works regardless of the current working directory.
tokenizer_path = os.path.join(os.path.dirname(__file__), "..", "tokenizer")
tokenizer = LlamaTokenizerFast.from_pretrained(tokenizer_path)
def trim_predictions_to_max_token_length(prediction, max_token_length=75):
    """Trim a prediction string to at most ``max_token_length`` tokens.

    Args:
        prediction: The model's raw answer text.
        max_token_length: Maximum number of content tokens to keep.
            Defaults to 75, the original hard-coded limit.

    Returns:
        The prediction re-decoded from its first ``max_token_length``
        content tokens (unchanged if it was already short enough).
    """
    tokenized_prediction = tokenizer.encode(prediction)
    # Skip index 0: the tokenizer prepends a BOS token (add_bos_token=True in
    # tokenizer_config.json), which is not part of the prediction text.
    trimmed_tokenized_prediction = tokenized_prediction[1 : max_token_length + 1]
    trimmed_prediction = tokenizer.decode(trimmed_tokenized_prediction)
    return trimmed_prediction
\ No newline at end of file
# hf-internal-testing/llama-tokenizer
This tokenizer has been obtained from: https://huggingface.co/hf-internal-testing/llama-tokenizer
\ No newline at end of file
{
"bos_token": {
"content": "<s>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "</s>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<unk>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
}
}
source diff could not be displayed: it is too large. Options to address this: view the blob.
File added
{
"add_bos_token": true,
"add_eos_token": false,
"bos_token": {
"__type": "AddedToken",
"content": "<s>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"clean_up_tokenization_spaces": false,
"eos_token": {
"__type": "AddedToken",
"content": "</s>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"model_max_length": 2048,
"pad_token": null,
"sp_model_kwargs": {},
"tokenizer_class": "LlamaTokenizer",
"unk_token": {
"__type": "AddedToken",
"content": "<unk>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment