Skip to content
Snippets Groups Projects
Commit 4c13a846 authored by spmohanty's avatar spmohanty
Browse files

update example data, local eval, dummy model

parent ea84b35c
No related branches found
No related tags found
No related merge requests found
source diff could not be displayed: it is stored in LFS. Options to address this: view the blob.
......@@ -54,25 +54,20 @@ def parse_response(resp: str):
except:
return -1
def evaluate_response(response):
    """Scan an evaluator's text response for its verdict markers.

    Args:
        response: Raw text produced by the evaluator model.

    Returns:
        tuple[bool, bool]: ``(is_missing, is_correct)`` — whether the literal
        substrings "Missing: True" and "Accuracy: True" occur in *response*.
    """
    missing_flag, correct_flag = (
        marker in response for marker in ("Missing: True", "Accuracy: True")
    )
    return missing_flag, correct_flag
def evaluate(dataset_path, model_name):
qa = load_json_file(os.path.join(dataset_path, "qa.json"))
web_results = load_json_file(os.path.join(dataset_path, "web.json"))
openai_client = OpenAI()
participant_model = UserModel()
character_limit = 50 # todo: Make character limit dynamic
n_miss, n_correct, n_correct_exact = 0, 0, 0
system_message = get_system_message()
for query_dict, query_web_search_results in tqdm(zip(qa, web_results), total=len(qa)):
query, ground_truth = query_dict['q'], query_dict['fact_ans'].strip().lower()
prediction = participant_model.generate_answer(query, query_web_search_results, character_limit=character_limit)[:character_limit].strip().lower()
query, ground_truth = query_dict['query'], query_dict['answer'].strip().lower()
prediction = participant_model.generate_answer(query, query_web_search_results)
prediction = prediction.strip().lower()
messages = [
{"role": "system", "content": system_message},
{"role": "user", "content": f"Question: {query}\n Ground truth: {ground_truth}\n Prediction: {prediction}\n"},
......@@ -80,7 +75,7 @@ def evaluate(dataset_path, model_name):
if prediction == "i don't know":
n_miss += 1
continue
if row["prediction"] == row["gold_ans"]:
if prediction == ground_truth:
n_correct_exact += 1
n_correct += 1
continue
......@@ -94,14 +89,18 @@ def evaluate(dataset_path, model_name):
n = len(qa)
results = {
"Exact Accuracy": n_exact / n,
"Accuracy": n_correct / n,
"Hallucination": (n - n_correct - n_miss) / n
"Missing": n_miss / n,
"Total": n
"score": (2*n_correct + n_miss) / n - 1,
"exact_accuracy": n_correct_exact / n,
"accuracy": n_correct / n,
"hallucination": (n - n_correct - n_miss) / n,
"missing": n_miss / n,
"n_miss": n_miss,
"n_correct": n_correct,
"n_correct_exact": n_correct_exact,
"total": n,
}
logger.info(results)
return (2*n_correct + n_miss) / n - 1
return results
if __name__ == '__main__':
DATASET_PATH = "example_data/"
......
......@@ -17,5 +17,5 @@ class DummyModel:
string response - Your answer in plain text, should be limited to the character limit,
Any longer responses will be trimmed to meet the character limit
"""
answer = "I'm sorry, I can't help with that."
answer = "i don't know"
return answer
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment