diff --git a/local_evaluation.py b/local_evaluation.py
index b54511a7d000a589e0a3787a9aa2e8427124daea..2f01cbf26c2644b18fbf702305b7804b939c1378 100644
--- a/local_evaluation.py
+++ b/local_evaluation.py
@@ -73,26 +73,37 @@ def parse_response(resp: str):
     return -1
 
 
-def evaluate(dataset_path, model_name):
+def generate_predictions(dataset_path, participant_model):
     qa = load_json_file(os.path.join(dataset_path, "qa.json"))
     web_results = load_json_file(os.path.join(dataset_path, "web.json"))
-    openai_client = OpenAI()
-    participant_model = UserModel()
 
+    predictions = []
+    for query_dict, query_web_search_results in tqdm(zip(qa, web_results), total=len(qa), desc="Generating Predictions"):
+        query = query_dict["query"]
+        prediction = participant_model.generate_answer(
+            query, query_web_search_results
+        )
+        predictions.append(
+            {
+                "query": query,
+                "ground_truth": query_dict["answer"].strip().lower(),
+                "prediction": prediction.strip().lower(),
+            }
+        )
+
+    return predictions
+
+
+def evaluate_predictions(predictions, evaluation_model_name, openai_client):
     n_miss, n_correct, n_correct_exact = 0, 0, 0
     system_message = get_system_message()
 
-    for query_dict, query_web_search_results in tqdm(
-        zip(qa, web_results), total=len(qa)
-    ):
-        query, ground_truth = (
-            query_dict["query"],
-            query_dict["answer"].strip().lower(),
+    for prediction_dict in tqdm(predictions, total=len(predictions), desc="Evaluating Predictions"):
+        query, ground_truth, prediction = (
+            prediction_dict["query"],
+            prediction_dict["ground_truth"],
+            prediction_dict["prediction"],
         )
-        prediction = participant_model.generate_answer(
-            query, query_web_search_results
-        )
-        prediction = prediction.strip().lower()
 
         messages = [
             {"role": "system", "content": system_message},
@@ -101,7 +112,7 @@ def evaluate(dataset_path, model_name):
                 "content": f"Question: {query}\n Ground truth: {ground_truth}\n Prediction: {prediction}\n",
             },
         ]
-        if prediction == "i don't know":
+        if "i don't know" in prediction:
             n_miss += 1
             continue
         if prediction == ground_truth:
@@ -109,14 +120,16 @@ def evaluate(dataset_path, model_name):
             n_correct += 1
             continue
 
-        response = attempt_api_call(openai_client, model_name, messages)
+        response = attempt_api_call(
+            openai_client, evaluation_model_name, messages
+        )
         if response:
             log_response(messages, response)
             eval_res = parse_response(response)
             if eval_res == 1:
                 n_correct += 1
 
-    n = len(qa)
+    n = len(predictions)
     results = {
         "score": (2 * n_correct + n_miss) / n - 1,
         "exact_accuracy": n_correct_exact / n,
@@ -134,5 +147,16 @@ def evaluate(dataset_path, model_name):
 
 if __name__ == "__main__":
     DATASET_PATH = "example_data/"
-    MODEL_NAME = os.getenv("EVALUATION_MODEL_NAME", "gpt-4-0125-preview")
-    evaluate(DATASET_PATH, MODEL_NAME)
+    EVALUATION_MODEL_NAME = os.getenv(
+        "EVALUATION_MODEL_NAME", "gpt-4-0125-preview"
+    )
+
+    # Generate predictions
+    participant_model = UserModel()
+    predictions = generate_predictions(DATASET_PATH, participant_model)
+
+    # Evaluate Predictions
+    openai_client = OpenAI()
+    evaluation_results = evaluate_predictions(
+        predictions, EVALUATION_MODEL_NAME, openai_client
+    )
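
For context, a minimal usage sketch of the two functions this patch splits out. It assumes the patched local_evaluation.py, the starter kit's other modules, and the example_data/ files are importable from the same directory, and that OPENAI_API_KEY is set for the evaluation client; AbstainModel is a hypothetical stand-in for the participant's UserModel and is not part of the patch.

# usage_sketch.py -- illustrative only, not part of the patch above.
import os

from openai import OpenAI

from local_evaluation import evaluate_predictions, generate_predictions


class AbstainModel:
    """Hypothetical stand-in for the participant's UserModel: always abstains."""

    def generate_answer(self, query, query_web_search_results):
        # generate_predictions() strips and lower-cases this, so it is
        # counted as a miss rather than a hallucination during evaluation.
        return "I don't know"


if __name__ == "__main__":
    # Step 1: run the (stub) participant model over the example dataset.
    predictions = generate_predictions("example_data/", AbstainModel())

    # Step 2: score the cached predictions with the LLM judge.
    evaluation_results = evaluate_predictions(
        predictions,
        os.getenv("EVALUATION_MODEL_NAME", "gpt-4-0125-preview"),
        OpenAI(),
    )
    print(evaluation_results)

Because every answer here is a miss, the reported score, (2 * n_correct + n_miss) / n - 1, works out to 0: all-correct predictions would score 1 and pure hallucinations -1. Caching predictions before evaluation makes this kind of stub-model probing cheap, since the judge calls are the only OpenAI traffic.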