From df8fc20bade2d5e469344c15ed9d10bc0c44fcef Mon Sep 17 00:00:00 2001
From: "S.P. Mohanty" <spmohanty91@gmail.com>
Date: Mon, 1 Apr 2024 14:11:40 +0000
Subject: [PATCH] split pred generation and evaluation

---
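Notes:

This patch splits the old evaluate() into two steps so that running the
participant model and scoring its answers can be done independently:

* generate_predictions(dataset_path, participant_model) runs the
  participant's UserModel over qa.json / web.json and returns a list of
  {query, ground_truth, prediction} dicts.
* evaluate_predictions(predictions, evaluation_model_name, openai_client)
  scores that list with the OpenAI evaluator model and computes the
  final metrics.

Illustrative usage after this change (a sketch only; caching the
predictions to "predictions.json" is an example workflow, not something
this patch adds):

    import json

    from openai import OpenAI

    # UserModel is assumed to be importable via local_evaluation.py's own
    # module-level import, as the new __main__ block suggests.
    from local_evaluation import (
        UserModel,
        evaluate_predictions,
        generate_predictions,
    )

    # Step 1: run the participant model once and keep its answers around;
    # "predictions.json" is only an example cache path, not part of this patch.
    predictions = generate_predictions("example_data/", UserModel())
    with open("predictions.json", "w") as f:
        json.dump(predictions, f)

    # Step 2: score the cached answers with the evaluator model
    # (OpenAI() reads the OPENAI_API_KEY environment variable).
    evaluation_results = evaluate_predictions(
        predictions, "gpt-4-0125-preview", OpenAI()
    )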
 local_evaluation.py | 60 +++++++++++++++++++++++++++++++--------------
 1 file changed, 42 insertions(+), 18 deletions(-)

diff --git a/local_evaluation.py b/local_evaluation.py
index b54511a..2f01cbf 100644
--- a/local_evaluation.py
+++ b/local_evaluation.py
@@ -73,26 +73,37 @@ def parse_response(resp: str):
         return -1
 
 
-def evaluate(dataset_path, model_name):
+def generate_predictions(dataset_path, participant_model):
     qa = load_json_file(os.path.join(dataset_path, "qa.json"))
     web_results = load_json_file(os.path.join(dataset_path, "web.json"))
-    openai_client = OpenAI()
-    participant_model = UserModel()
 
+    predictions = []
+    for query_dict, query_web_search_results in tqdm(zip(qa, web_results), total=len(qa), desc="Generating Predictions"):
+        query = query_dict["query"]
+        prediction = participant_model.generate_answer(
+            query, query_web_search_results
+        )
+        predictions.append(
+            {
+                "query": query,
+                "ground_truth": query_dict["answer"].strip().lower(),
+                "prediction": prediction.strip().lower(),
+            }
+        )
+
+    return predictions
+
+
+def evaluate_predictions(predictions, evaluation_model_name, openai_client):
     n_miss, n_correct, n_correct_exact = 0, 0, 0
     system_message = get_system_message()
 
-    for query_dict, query_web_search_results in tqdm(
-        zip(qa, web_results), total=len(qa)
-    ):
-        query, ground_truth = (
-            query_dict["query"],
-            query_dict["answer"].strip().lower(),
+    for prediction_dict in tqdm(predictions, total=len(predictions), desc="Evaluating Predictions"):
+        query, ground_truth, prediction = (
+            prediction_dict["query"],
+            prediction_dict["ground_truth"],
+            prediction_dict["prediction"],
         )
-        prediction = participant_model.generate_answer(
-            query, query_web_search_results
-        )
-        prediction = prediction.strip().lower()
 
         messages = [
             {"role": "system", "content": system_message},
@@ -101,7 +112,7 @@ def evaluate(dataset_path, model_name):
                 "content": f"Question: {query}\n Ground truth: {ground_truth}\n Prediction: {prediction}\n",
             },
         ]
-        if prediction == "i don't know":
+        if prediction == "i don't know..":
             n_miss += 1
             continue
         if prediction == ground_truth:
@@ -109,14 +120,16 @@ def evaluate(dataset_path, model_name):
             n_correct += 1
             continue
 
-        response = attempt_api_call(openai_client, model_name, messages)
+        response = attempt_api_call(
+            openai_client, evaluation_model_name, messages
+        )
         if response:
             log_response(messages, response)
             eval_res = parse_response(response)
             if eval_res == 1:
                 n_correct += 1
 
-    n = len(qa)
+    n = len(predictions)
     results = {
         "score": (2 * n_correct + n_miss) / n - 1,
         "exact_accuracy": n_correct_exact / n,
@@ -134,5 +147,16 @@ def evaluate(dataset_path, model_name):
 
 if __name__ == "__main__":
     DATASET_PATH = "example_data/"
-    MODEL_NAME = os.getenv("EVALUATION_MODEL_NAME", "gpt-4-0125-preview")
-    evaluate(DATASET_PATH, MODEL_NAME)
+    EVALUATION_MODEL_NAME = os.getenv(
+        "EVALUATION_MODEL_NAME", "gpt-4-0125-preview"
+    )
+
+    # Generate predictions
+    participant_model = UserModel()
+    predictions = generate_predictions(DATASET_PATH, participant_model)
+
+    # Evaluate predictions
+    openai_client = OpenAI()
+    evaluation_results = evaluate_predictions(
+        predictions, EVALUATION_MODEL_NAME, openai_client
+    )
-- 
GitLab