From ccb8561a1d47448cb79598a31037820e13c37f14 Mon Sep 17 00:00:00 2001
From: Xiao Yang <xiaoyangfb@meta.com>
Date: Sat, 30 Mar 2024 20:01:45 +0000
Subject: [PATCH] local_evaluation.py: parse judge JSON and report hallucination rate

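Parse the auto-eval judge output as JSON instead of matching fixed
strings, short-circuit exact matches and "i don't know" predictions
before calling the judge model, and report a Hallucination rate
alongside Exact Accuracy, Accuracy and Missing.

evaluate() now returns a single score, (2*n_correct + n_miss)/n - 1.
With n_hallucination = n - n_correct - n_miss this equals
(n_correct - n_hallucination)/n, i.e. +1 per correct answer, 0 per
missing answer and -1 per hallucination, averaged over the n queries.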
---
 local_evaluation.py | 49 ++++++++++++++++++++++++++++++++++-----------
 1 file changed, 37 insertions(+), 12 deletions(-)

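Note: a minimal sketch of the intended parse_response() behaviour,
assuming the judge model replies with a JSON object carrying an
"accuracy" field (commentary only, not applied by git am):

    >>> parse_response('{"accuracy": true}')
    1
    >>> parse_response('{"accuracy": "True"}')
    1
    >>> parse_response('{"accuracy": false}')
    -1
    >>> parse_response('not valid json')
    -1
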
diff --git a/local_evaluation.py b/local_evaluation.py
index 6b6dcdb..b22f8c5 100644
--- a/local_evaluation.py
+++ b/local_evaluation.py
@@ -39,6 +39,21 @@ def log_response(messages, response):
     with open(f"api_responses/{file_name}", 'w') as f:
         json.dump({"messages": messages, "response": response}, f)
 
+def parse_response(resp: str):
+    """Pass auto-eval output from the evaluator."""
+    try:
+        resp = resp.lower()
+        model_resp = json.loads(resp)
+        answer = -1
+        if "accuracy" in model_resp and ((model_resp["accuracy"] is True) or (isinstance(model_resp["accuracy"], str) and model_resp["accuracy"].lower() == "true")):
+            answer = 1
+        else:
+            raise ValueError(f"Could not parse answer from response: {model_resp}")
+
+        return answer
+    except Exception:  # any parsing failure counts as not accurate
+        return -1
+
 def evaluate_response(response):
     """Evaluate the response to determine if it's missing or correct."""
     is_missing = "Missing: True" in response
@@ -52,33 +67,43 @@ def evaluate(dataset_path, model_name):
     participant_model = UserModel()
     character_limit = 50  # todo: Make character limit dynamic
 
-    n_miss, n_correct, n_exact = 0, 0, 0
+    n_miss, n_correct, n_correct_exact = 0, 0, 0
     system_message = get_system_message()
 
     for query_dict, query_web_search_results in tqdm(zip(qa, web_results), total=len(qa)):
-        query, ground_truth = query_dict['q'], query_dict['fact_ans']
-        prediction = participant_model.generate_answer(query, query_web_search_results, character_limit=character_limit)[:character_limit]
+        query, ground_truth = query_dict['q'], query_dict['fact_ans'].strip().lower()
+        prediction = participant_model.generate_answer(query, query_web_search_results, character_limit=character_limit)[:character_limit].strip().lower()
         messages = [
             {"role": "system", "content": system_message},
             {"role": "user", "content": f"Question: {query}\n Ground truth: {ground_truth}\n Prediction: {prediction}\n"},
         ]
+        if prediction == "i don't know":
+            n_miss += 1
+            continue
+        if prediction == ground_truth:  # exact match, skip the judge call
+            n_correct_exact += 1
+            n_correct += 1
+            continue
+
         response = attempt_api_call(openai_client, model_name, messages)
         if response:
             log_response(messages, response)
-            miss, correct = evaluate_response(response)
-            n_miss += miss
-            n_correct += correct
-            n_exact += (prediction.strip() == ground_truth.strip())
+            eval_res = parse_response(response)
+            if eval_res == 1:
+                n_correct += 1
 
+    n = len(qa)
     results = {
-        "Exact Accuracy": n_exact / len(qa),
-        "Accuracy": n_correct / len(qa),
-        "Missing": n_miss / len(qa),
-        "Total": len(qa)
+        "Exact Accuracy": n_exact / n,
+        "Accuracy": n_correct / n,
+        "Hallucination": (n - n_correct - n_miss) / n
+        "Missing": n_miss / n,
+        "Total": n
     }
     logger.info(results)
+    return (2*n_correct + n_miss) / n - 1
 
 if __name__ == '__main__':
-    DATASET_PATH = "example_data/"    
+    DATASET_PATH = "example_data/"
     MODEL_NAME = os.getenv("EVALUATION_MODEL_NAME", "gpt-4-0125-preview")
     evaluate(DATASET_PATH, MODEL_NAME)
-- 
GitLab