From ccb8561a1d47448cb79598a31037820e13c37f14 Mon Sep 17 00:00:00 2001
From: Xiao Yang <xiaoyangfb@meta.com>
Date: Sat, 30 Mar 2024 20:01:45 +0000
Subject: [PATCH] Update local_evaluation.py

---
 local_evaluation.py | 49 ++++++++++++++++++++++++++++++++++-----------
 1 file changed, 37 insertions(+), 12 deletions(-)

diff --git a/local_evaluation.py b/local_evaluation.py
index 6b6dcdb..b22f8c5 100644
--- a/local_evaluation.py
+++ b/local_evaluation.py
@@ -39,6 +39,21 @@ def log_response(messages, response):
     with open(f"api_responses/{file_name}", 'w') as f:
         json.dump({"messages": messages, "response": response}, f)
 
+def parse_response(resp: str):
+    """Parse auto-eval output from the evaluator."""
+    try:
+        resp = resp.lower()
+        model_resp = json.loads(resp)
+        answer = -1  # default: not judged correct
+        if "accuracy" in model_resp and ((model_resp["accuracy"] is True) or (isinstance(model_resp["accuracy"], str) and model_resp["accuracy"].lower() == "true")):
+            answer = 1
+        else:
+            raise ValueError(f"Could not parse answer from response: {model_resp}")
+
+        return answer
+    except Exception:
+        return -1
+
 def evaluate_response(response):
     """Evaluate the response to determine if it's missing or correct."""
     is_missing = "Missing: True" in response
@@ -52,33 +67,43 @@ def evaluate(dataset_path, model_name):
     participant_model = UserModel()
     character_limit = 50 # todo: Make character limit dynamic
 
-    n_miss, n_correct, n_exact = 0, 0, 0
+    n_miss, n_correct, n_correct_exact = 0, 0, 0
     system_message = get_system_message()
     for query_dict, query_web_search_results in tqdm(zip(qa, web_results), total=len(qa)):
-        query, ground_truth = query_dict['q'], query_dict['fact_ans']
-        prediction = participant_model.generate_answer(query, query_web_search_results, character_limit=character_limit)[:character_limit]
+        query, ground_truth = query_dict['q'], query_dict['fact_ans'].strip().lower()
+        prediction = participant_model.generate_answer(query, query_web_search_results, character_limit=character_limit)[:character_limit].strip().lower()
 
         messages = [
             {"role": "system", "content": system_message},
             {"role": "user", "content": f"Question: {query}\n Ground truth: {ground_truth}\n Prediction: {prediction}\n"},
         ]
+        if prediction == "i don't know":
+            n_miss += 1
+            continue
+        if prediction == ground_truth:  # exact match; skip the LLM judge
+            n_correct_exact += 1
+            n_correct += 1
+            continue
+
         response = attempt_api_call(openai_client, model_name, messages)
         if response:
             log_response(messages, response)
-            miss, correct = evaluate_response(response)
-            n_miss += miss
-            n_correct += correct
-            n_exact += (prediction.strip() == ground_truth.strip())
+            eval_res = parse_response(response)
+            if eval_res == 1:
+                n_correct += 1
 
+    n = len(qa)
     results = {
-        "Exact Accuracy": n_exact / len(qa),
-        "Accuracy": n_correct / len(qa),
-        "Missing": n_miss / len(qa),
-        "Total": len(qa)
+        "Exact Accuracy": n_correct_exact / n,
+        "Accuracy": n_correct / n,
+        "Hallucination": (n - n_correct - n_miss) / n,
+        "Missing": n_miss / n,
+        "Total": n
     }
     logger.info(results)
+    return (2*n_correct + n_miss) / n - 1
 
 if __name__ == '__main__':
-    DATASET_PATH = "example_data/"
+    DATASET_PATH = "example_data/"
     MODEL_NAME = os.getenv("EVALUATION_MODEL_NAME", "gpt-4-0125-preview")
     evaluate(DATASET_PATH, MODEL_NAME)
--
GitLab
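
Editor's note, not part of the patch: the score that the updated evaluate() returns,
(2*n_correct + n_miss)/n - 1, is algebraically the mean of per-example rewards of
+1 for a correct answer, 0 for a missing answer ("i don't know"), and -1 for a
hallucination, matching the Hallucination rate added to the results dict. A minimal
self-contained sketch under that reading (helper names are hypothetical):

    import math

    def score_from_counts(n_correct, n_miss, n):
        # Score exactly as returned by the patched evaluate().
        return (2 * n_correct + n_miss) / n - 1

    def score_from_rewards(n_correct, n_miss, n):
        # Same score, written as an average of +1 / 0 / -1 rewards:
        # (2c + m)/n - 1 = (c - h)/n where h = n - c - m.
        n_halluc = n - n_correct - n_miss
        return (n_correct - n_halluc) / n

    # e.g. 5 correct, 1 missing, 2 hallucinated out of 8 -> 0.375 either way
    assert math.isclose(score_from_counts(5, 1, 8), score_from_rewards(5, 1, 8))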