Skip to content
Snippets Groups Projects
Commit 8d679a75 authored by spmohanty
Browse files

refactor local evaluation

parent 310d1bf1
No related branches found
No related tags found
No related merge requests found
import json
import os
from tqdm.auto import tqdm
from openai import OpenAI, APIConnectionError, RateLimitError
from datetime import datetime
from loguru import logger
from models.user_config import UserModel
from openai import APIConnectionError, OpenAI, RateLimitError
from prompts.templates import IN_CONTEXT_EXAMPLES, INSTRUCTIONS
from tqdm.auto import tqdm
def get_system_message():
INSTRUCTIONS = """You are given a question and the ground truth prediction is correct by comparing to the list of ground truth answers. You should evaluate for Accuracy and Missing.
- For Missing, check whether the prediction returns any concrete answer. If the prediction is "I don't know", "I don't have enough information to answer", or similar responses, Missing should be True, otherwise Missing should be False.
- For Accuracy, check whether a prediction is "correct" according to the ground truth answers. If the prediction is correct, Accuracy should be "True"; if the prediction is wrong, Accuracy should be "False". If the ground truth answer contains a number, the prediction needs to predict a number that matches the ground truth answer for the accuracy to be True.\n
"""
IN_CONTEXT_EXAMPLES = """You need to check whether the prediction of a question-answering system to a question is Accurate or Missing. You should make the judgment based on a list of ground truth answers provided to you. Your response should be "correct" if the prediction is correct or "incorrect" if the prediction is wrong.
Examples:
Question: Who authored The Taming of the Shrew (published in 2002)?
Ground truth: ["William Shakespeare", "Roma Gill"]
Prediction: W Shakespeare
Accuracy: True
Missing: False
Question: how many seconds is 3 minutes 15 seconds?
Ground truth: ["195 seconds"]
Prediction: 3 minutes 15 seconds is 195 seconds.
Accuracy: True
Missing: False
Question: Who authored The Taming of the Shrew (published in 2002)?
Ground truth: ["William Shakespeare", "Roma Gill"]
Prediction: The author to The Taming of the Shrew is Roma Shakespeare.
Accuracy: False
Missing: False
Question: Who played Sheldon in Big Bang Theory?
Ground truth: ["Jim Parsons", "Iain Armitage"]
Prediction: I am sorry I don't know.
Accuracy: False
Missing: True
Question: What country is Maharashtra Metro Rail Corporation Limited
located in?
Ground truth: ["India"]
Prediction: Maharashtra
Accuracy: False
Missing: False
Question: What’s the job of Song Kang-ho in Parasite (2019)?
Ground truth: ["actor"]
Prediction: He plays the role of Kim Ki-taek, the patriarch of the Kim
family. He also plays in the TV Show Friends.
Accuracy: True
Missing: False
Question: Which era did Michael Oakeshott belong to?
Ground truth: ["20th-century philosophy"]
Prediction: 20th century.
Accuracy: True
Missing: False
Question: What the Meta stock price now?
Ground truth: ["382.63 dollars"]
Prediction: I apologize, but I couldn't find any information on this.
Accuracy: False
Missing: True
Question: How long is the great wall of China?
Ground truth: ["13171 miles"]
Prediction: The great wall of China is 13271 miles.
Accuracy: False
Missing: False
Question: When did Brian Fleming release his last song?
Ground truth: ["2023"]
Prediction: Brian Fleming has not released any songs as a musician. He is a playwright and actor, not a musician. My apologies for the confusion.
Accuracy: False
Missing: False
Question: How old is Maria Carrie?
Ground truth: ["53 year old"]
Prediction: Could you please provide more context or information about who or what this person or thing is?
Accuracy: False
Missing: True
Question: What is the best role of Mayim Bialik?
Ground truth: ["Amy Farrah Fowler in Big Bang Theory"]
Prediction: Mayim Bialik is not a known artist. It's possible that you intended to ask about Meryl Streep.
Accuracy: False
Missing: True
"""
def load_json_file(file_path):
    """Load and return the parsed content of a JSON file.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The deserialized document, exactly as produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    logger.info(f"Loading JSON from {file_path}")
    # Be explicit about the encoding: without it the text is decoded with the
    # platform's locale encoding, which breaks on non-ASCII content on some
    # systems.
    with open(file_path, encoding="utf-8") as f:
        return json.load(f)
def get_system_message():
    """Build the system prompt: the instructions followed by the in-context examples."""
    system_prompt = INSTRUCTIONS + IN_CONTEXT_EXAMPLES
    return system_prompt
def attempt_api_call(client, model_name, messages, max_retries=10, retry_delay=1.0):
    """Request a chat completion, retrying on transient API failures.

    Connection and rate-limit errors are retried with exponential backoff;
    any other exception is logged and aborts the retry loop immediately.

    Args:
        client: Client object exposing ``chat.completions.create``.
        model_name: Name of the model that should produce the completion.
        messages: List of chat messages sent with the request.
        max_retries: Maximum number of attempts before giving up.
        retry_delay: Base delay in seconds before the second attempt. The
            delay doubles after every failed attempt and is capped at 30
            seconds. Pass ``0`` to retry without waiting.

    Returns:
        The text content of the first returned choice, or ``None`` when no
        attempt succeeded.
    """
    import time  # local import: only needed on the retry path

    max_delay = 30.0
    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(model=model_name, messages=messages)
            return response.choices[0].message.content
        except (APIConnectionError, RateLimitError):
            logger.warning(f"API call failed on attempt {attempt + 1}, retrying...")
            # Back off before the next attempt; hammering a rate-limited
            # endpoint immediately just burns the remaining attempts.
            if attempt + 1 < max_retries and retry_delay > 0:
                time.sleep(min(retry_delay * 2 ** attempt, max_delay))
        except Exception as e:
            logger.error(f"Unexpected error: {e}")
            break
    # Every attempt failed (or an unexpected error stopped the loop early).
    return None


def gpt_eval(client, model_name, query, gt, pred):
    """Grade a single prediction with the evaluation model.

    Thin wrapper around ``attempt_api_call`` kept so that callers of the
    previous entry point keep working.

    Args:
        client: Client object exposing ``chat.completions.create``.
        model_name: Name of the model used as grader.
        query: The question that was asked.
        gt: The ground-truth answer for the question.
        pred: The prediction to grade.

    Returns:
        A ``(miss, correct)`` tuple of booleans. Both are ``False`` when no
        response could be obtained, so a failed call counts as neither
        missing nor correct.
    """
    messages = [
        {"role": "system", "content": get_system_message()},
        {"role": "user", "content": f"Question: {query}\n Ground truth: {gt}\n Prediction: {pred}\n"},
    ]
    response = attempt_api_call(client, model_name, messages)
    if not response:
        return False, False
    log_response(messages, response)
    miss = "Missing: True" in response
    correct = "Accuracy: True" in response
    return miss, correct
def log_response(messages, response):
    """Save the messages sent to the API and its response to a JSON file.

    The file is written to ``api_responses/`` relative to the current working
    directory and is named after the current local time.

    Args:
        messages: The chat messages that were sent with the request.
        response: The response content returned by the API.
    """
    # open() does not create missing directories, so make sure the target
    # directory exists before the first write.
    os.makedirs("api_responses", exist_ok=True)
    # NOTE(review): the name has one-second resolution, so two responses
    # logged within the same second end up in the same file.
    file_name = datetime.now().strftime("%d-%m-%Y-%H-%M-%S.json")
    with open(f"api_responses/{file_name}", 'w', encoding="utf-8") as f:
        json.dump({"messages": messages, "response": response}, f)
def evaluate(dataset_path, model_name):
# Load dataset
with open(f'{dataset_path}/qa.json') as f:
qa = json.load(f)
with open(f'{dataset_path}/web.json') as f:
web_results = json.load(f)
def evaluate_response(response):
    """Turn the grader's reply into ``(is_missing, is_correct)`` flags.

    A flag is True only when its exact marker text occurs in ``response``.
    """
    markers = ("Missing: True", "Accuracy: True")
    flags = tuple(marker in response for marker in markers)
    return flags
# Setup
def evaluate(dataset_path, model_name):
    """Run the participant model over a dataset and grade its answers.

    Reads ``qa.json`` and ``web.json`` from ``dataset_path``, asks
    ``UserModel`` for an answer to every question, has the evaluation model
    grade each answer, and logs the aggregated scores.

    Args:
        dataset_path: Directory containing ``qa.json`` and ``web.json``.
        model_name: Name of the model used as grader.
    """
    qa = load_json_file(os.path.join(dataset_path, "qa.json"))
    web_results = load_json_file(os.path.join(dataset_path, "web.json"))
    if len(qa) != len(web_results):
        # zip() below stops at the shorter input; make the mismatch visible
        # instead of silently dropping questions.
        logger.warning(
            f"qa.json has {len(qa)} entries but web.json has {len(web_results)}; "
            "extra entries are ignored"
        )

    openai_client = OpenAI()
    participant_model = UserModel()
    character_limit = 50  # todo: Make character limit dynamic

    n_miss, n_correct, n_exact = 0, 0, 0
    system_message = get_system_message()

    for query_dict, query_web_search_results in tqdm(zip(qa, web_results), total=len(qa)):
        query, ground_truth = query_dict['q'], query_dict['fact_ans']
        # The model is told the limit, and the answer is truncated as well in
        # case the model does not respect it.
        prediction = participant_model.generate_answer(
            query, query_web_search_results, character_limit=character_limit
        )[:character_limit]

        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": f"Question: {query}\n Ground truth: {ground_truth}\n Prediction: {prediction}\n"},
        ]
        response = attempt_api_call(openai_client, model_name, messages)
        if response:
            log_response(messages, response)
            miss, correct = evaluate_response(response)
            n_miss += miss
            n_correct += correct
        # Exact match does not depend on the grader, so it is counted even
        # when the API call failed.
        # NOTE(review): assumes 'fact_ans' is a string - confirm against the dataset.
        n_exact += (prediction.strip() == ground_truth.strip())

    total = len(qa)
    # Avoid ZeroDivisionError on an empty dataset; all rates are then 0.
    denominator = total or 1
    results = {
        "Exact Accuracy": n_exact / denominator,
        "Accuracy": n_correct / denominator,
        "Missing": n_miss / denominator,
        "Total": total
    }
    logger.info(results)
if __name__ == '__main__':
    # Directory that holds qa.json and web.json for the local evaluation run.
    DATASET_PATH = "example_data/"
    # The grader model can be overridden through the environment.
    MODEL_NAME = os.getenv("EVALUATION_MODEL_NAME", "gpt-4-0125-preview")
    evaluate(DATASET_PATH, MODEL_NAME)
#!/usr/bin/env python3
# Grading instructions for the evaluation model: how to decide the "Accuracy"
# and "Missing" verdicts for a prediction.
# NOTE(review): the first sentence reads as garbled ("...and the ground truth
# prediction is correct by comparing..."). The text is sent to the model
# verbatim, so rewording it may change grading results - confirm the intended
# wording before editing.
INSTRUCTIONS = """You are given a question and the ground truth prediction is correct by comparing to the list of ground truth answers. You should evaluate for Accuracy and Missing.
- For Missing, check whether the prediction returns any concrete answer. If the prediction is "I don't know", "I don't have enough information to answer", or similar responses, Missing should be True, otherwise Missing should be False.
- For Accuracy, check whether a prediction is "correct" according to the ground truth answers. If the prediction is correct, Accuracy should be "True"; if the prediction is wrong, Accuracy should be "False". If the ground truth answer contains a number, the prediction needs to predict a number that matches the ground truth answer for the accuracy to be True.\n
"""
# Few-shot examples for the evaluation model. Each example lists a question,
# the ground-truth answers, a prediction, and the expected "Accuracy" and
# "Missing" verdicts.
# NOTE(review): the opening paragraph asks for a "correct"/"incorrect" reply,
# while every example below answers with "Accuracy: ..." / "Missing: ..."
# lines. Confirm which output format is intended before editing; the text is
# sent to the model verbatim.
IN_CONTEXT_EXAMPLES = """You need to check whether the prediction of a question-answering system to a question is Accurate or Missing. You should make the judgment based on a list of ground truth answers provided to you. Your response should be "correct" if the prediction is correct or "incorrect" if the prediction is wrong.
Examples:
Question: Who authored The Taming of the Shrew (published in 2002)?
Ground truth: ["William Shakespeare", "Roma Gill"]
Prediction: W Shakespeare
Accuracy: True
Missing: False
Question: how many seconds is 3 minutes 15 seconds?
Ground truth: ["195 seconds"]
Prediction: 3 minutes 15 seconds is 195 seconds.
Accuracy: True
Missing: False
Question: Who authored The Taming of the Shrew (published in 2002)?
Ground truth: ["William Shakespeare", "Roma Gill"]
Prediction: The author to The Taming of the Shrew is Roma Shakespeare.
Accuracy: False
Missing: False
Question: Who played Sheldon in Big Bang Theory?
Ground truth: ["Jim Parsons", "Iain Armitage"]
Prediction: I am sorry I don't know.
Accuracy: False
Missing: True
Question: What country is Maharashtra Metro Rail Corporation Limited
located in?
Ground truth: ["India"]
Prediction: Maharashtra
Accuracy: False
Missing: False
Question: What’s the job of Song Kang-ho in Parasite (2019)?
Ground truth: ["actor"]
Prediction: He plays the role of Kim Ki-taek, the patriarch of the Kim
family. He also plays in the TV Show Friends.
Accuracy: True
Missing: False
Question: Which era did Michael Oakeshott belong to?
Ground truth: ["20th-century philosophy"]
Prediction: 20th century.
Accuracy: True
Missing: False
Question: What the Meta stock price now?
Ground truth: ["382.63 dollars"]
Prediction: I apologize, but I couldn't find any information on this.
Accuracy: False
Missing: True
Question: How long is the great wall of China?
Ground truth: ["13171 miles"]
Prediction: The great wall of China is 13271 miles.
Accuracy: False
Missing: False
Question: When did Brian Fleming release his last song?
Ground truth: ["2023"]
Prediction: Brian Fleming has not released any songs as a musician. He is a playwright and actor, not a musician. My apologies for the confusion.
Accuracy: False
Missing: False
Question: How old is Maria Carrie?
Ground truth: ["53 year old"]
Prediction: Could you please provide more context or information about who or what this person or thing is?
Accuracy: False
Missing: True
Question: What is the best role of Mayim Bialik?
Ground truth: ["Amy Farrah Fowler in Big Bang Theory"]
Prediction: Mayim Bialik is not a known artist. It's possible that you intended to ask about Meryl Streep.
Accuracy: False
Missing: True
"""
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment