From 4c13a84604a2b3511d7950e8e6bd4d02537eb6c6 Mon Sep 17 00:00:00 2001 From: "S.P. Mohanty" <spmohanty91@gmail.com> Date: Mon, 1 Apr 2024 13:44:21 +0000 Subject: [PATCH] update example data, local eval, dummy model --- example_data/qa.json | 4 ++-- local_evaluation.py | 31 +++++++++++++++---------------- models/dummy_model.py | 2 +- 3 files changed, 18 insertions(+), 19 deletions(-) diff --git a/example_data/qa.json b/example_data/qa.json index 8007f4f..147fa65 100644 --- a/example_data/qa.json +++ b/example_data/qa.json @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:4257e05de6242a319640bddf017391e2f8d9d347d3c1c440d85d23c1be13e6bd -size 1283 +oid sha256:6b9555060dd6c7e9e5e76e95e87d47fe0155a64cc93ecf94fb6cba337bd7d3a3 +size 1413 diff --git a/local_evaluation.py b/local_evaluation.py index b22f8c5..d3b6de6 100644 --- a/local_evaluation.py +++ b/local_evaluation.py @@ -54,25 +54,20 @@ def parse_response(resp: str): except: return -1 -def evaluate_response(response): - """Evaluate the response to determine if it's missing or correct.""" - is_missing = "Missing: True" in response - is_correct = "Accuracy: True" in response - return is_missing, is_correct - def evaluate(dataset_path, model_name): qa = load_json_file(os.path.join(dataset_path, "qa.json")) web_results = load_json_file(os.path.join(dataset_path, "web.json")) openai_client = OpenAI() participant_model = UserModel() - character_limit = 50 # todo: Make character limit dynamic n_miss, n_correct, n_correct_exact = 0, 0, 0 system_message = get_system_message() for query_dict, query_web_search_results in tqdm(zip(qa, web_results), total=len(qa)): - query, ground_truth = query_dict['q'], query_dict['fact_ans'].strip().lower() - prediction = participant_model.generate_answer(query, query_web_search_results, character_limit=character_limit)[:character_limit].strip().lower() + query, ground_truth = query_dict['query'], query_dict['answer'].strip().lower() + prediction = participant_model.generate_answer(query, query_web_search_results) + prediction = prediction.strip().lower() + messages = [ {"role": "system", "content": system_message}, {"role": "user", "content": f"Question: {query}\n Ground truth: {ground_truth}\n Prediction: {prediction}\n"}, @@ -80,7 +75,7 @@ def evaluate(dataset_path, model_name): if prediction == "i don't know": n_miss += 1 continue - if row["prediction"] == row["gold_ans"]: + if prediction == ground_truth: n_correct_exact += 1 n_correct += 1 continue @@ -94,14 +89,18 @@ def evaluate(dataset_path, model_name): n = len(qa) results = { - "Exact Accuracy": n_exact / n, - "Accuracy": n_correct / n, - "Hallucination": (n - n_correct - n_miss) / n - "Missing": n_miss / n, - "Total": n + "score": (2*n_correct + n_miss) / n - 1, + "exact_accuracy": n_correct_exact / n, + "accuracy": n_correct / n, + "hallucination": (n - n_correct - n_miss) / n, + "missing": n_miss / n, + "n_miss": n_miss, + "n_correct": n_correct, + "n_correct_exact": n_correct_exact, + "total": n, } logger.info(results) - return (2*n_correct + n_miss) / n - 1 + return results if __name__ == '__main__': DATASET_PATH = "example_data/" diff --git a/models/dummy_model.py b/models/dummy_model.py index 0927299..8dcb9a5 100644 --- a/models/dummy_model.py +++ b/models/dummy_model.py @@ -17,5 +17,5 @@ class DummyModel: string response - Your answer in plain text, should be limited to the character limit, Any longer responses will be trimmed to meet the character limit """ - answer = "I'm sorry, I can't help with that." + answer = "i don't know" return answer \ No newline at end of file -- GitLab