diff --git a/local_evaluation.py b/local_evaluation.py
index d3b6de6c152377d9945eef0dd1d748d34ccf6646..b54511a7d000a589e0a3787a9aa2e8427124daea 100644
--- a/local_evaluation.py
+++ b/local_evaluation.py
@@ -15,45 +15,64 @@ def load_json_file(file_path):
     with open(file_path) as f:
         return json.load(f)
 
+
 def get_system_message():
     """Returns the system message containing instructions and in context examples."""
     return INSTRUCTIONS + IN_CONTEXT_EXAMPLES
 
+
 def attempt_api_call(client, model_name, messages, max_retries=10):
     """Attempt an API call with retries upon encountering specific errors."""
-    #todo: add default response when all efforts fail
+    # todo: add default response when all efforts fail
     for attempt in range(max_retries):
         try:
-            response = client.chat.completions.create(model=model_name, messages=messages)
+            response = client.chat.completions.create(
+                model=model_name,
+                messages=messages,
+                response_format={"type": "json_object"},
+            )
             return response.choices[0].message.content
         except (APIConnectionError, RateLimitError):
-            logger.warning(f"API call failed on attempt {attempt + 1}, retrying...")
+            logger.warning(
+                f"API call failed on attempt {attempt + 1}, retrying..."
+            )
         except Exception as e:
             logger.error(f"Unexpected error: {e}")
             break
     return None
 
+
 def log_response(messages, response):
     """Save the response from the API to a file."""
     file_name = datetime.now().strftime("%d-%m-%Y-%H-%M-%S.json")
-    with open(f"api_responses/{file_name}", 'w') as f:
+    with open(f"api_responses/{file_name}", "w") as f:
         json.dump({"messages": messages, "response": response}, f)
 
+
 def parse_response(resp: str):
     """Pass auto-eval output from the evaluator."""
     try:
         resp = resp.lower()
         model_resp = json.loads(resp)
         answer = -1
-        if "accuracy" in model_resp and ((model_resp["accuracy"] is True) or (isinstance(model_resp["accuracy"], str) and model_resp["accuracy"].lower() == "true")):
+        if "accuracy" in model_resp and (
+            (model_resp["accuracy"] is True)
+            or (
+                isinstance(model_resp["accuracy"], str)
+                and model_resp["accuracy"].lower() == "true"
+            )
+        ):
             answer = 1
         else:
-            raise ValueError(f"Could not parse answer from response: {model_resp}")
+            raise ValueError(
+                f"Could not parse answer from response: {model_resp}"
+            )
         return answer
     except:
         return -1
 
+
 def evaluate(dataset_path, model_name):
     qa = load_json_file(os.path.join(dataset_path, "qa.json"))
     web_results = load_json_file(os.path.join(dataset_path, "web.json"))
@@ -63,14 +82,24 @@ def evaluate(dataset_path, model_name):
     n_miss, n_correct, n_correct_exact = 0, 0, 0
     system_message = get_system_message()
 
-    for query_dict, query_web_search_results in tqdm(zip(qa, web_results), total=len(qa)):
-        query, ground_truth = query_dict['query'], query_dict['answer'].strip().lower()
-        prediction = participant_model.generate_answer(query, query_web_search_results)
+    for query_dict, query_web_search_results in tqdm(
+        zip(qa, web_results), total=len(qa)
+    ):
+        query, ground_truth = (
+            query_dict["query"],
+            query_dict["answer"].strip().lower(),
+        )
+        prediction = participant_model.generate_answer(
+            query, query_web_search_results
+        )
         prediction = prediction.strip().lower()
-        
+
         messages = [
             {"role": "system", "content": system_message},
-            {"role": "user", "content": f"Question: {query}\n Ground truth: {ground_truth}\n Prediction: {prediction}\n"},
+            {
+                "role": "user",
+                "content": f"Question: {query}\n Ground truth: {ground_truth}\n Prediction: {prediction}\n",
+            },
         ]
         if prediction == "i don't know":
             n_miss += 1
@@ -89,7 +118,7 @@ def evaluate(dataset_path, model_name):
 
     n = len(qa)
     results = {
-        "score": (2*n_correct + n_miss) / n - 1,
+        "score": (2 * n_correct + n_miss) / n - 1,
         "exact_accuracy": n_correct_exact / n,
         "accuracy": n_correct / n,
         "hallucination": (n - n_correct - n_miss) / n,
@@ -97,12 +126,13 @@ def evaluate(dataset_path, model_name):
         "n_miss": n_miss,
         "n_correct": n_correct,
         "n_correct_exact": n_correct_exact,
-        "total": n, 
+        "total": n,
     }
     logger.info(results)
     return results
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     DATASET_PATH = "example_data/"
     MODEL_NAME = os.getenv("EVALUATION_MODEL_NAME", "gpt-4-0125-preview")
     evaluate(DATASET_PATH, MODEL_NAME)
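
Note (not part of the patch): a minimal sketch of how the touched scoring pieces behave after this change, assuming local_evaluation.py and its dependencies are importable from the working directory.

# Minimal sketch, not part of the patch: assumes local_evaluation.py is importable.
from local_evaluation import parse_response

# With response_format={"type": "json_object"} the evaluator returns JSON such as
# '{"accuracy": true}'; parse_response maps it to 1 (correct) or -1 (anything else).
assert parse_response('{"accuracy": true}') == 1
assert parse_response('{"accuracy": "True"}') == 1   # string form is also accepted
assert parse_response('{"accuracy": false}') == -1
assert parse_response("not json") == -1              # parse failures fall back to -1

# Score from evaluate(): correct answers earn +1, "i don't know" earns 0,
# hallucinations earn -1, averaged over n, i.e. (2 * n_correct + n_miss) / n - 1.
n, n_correct, n_miss = 10, 6, 2
score = (2 * n_correct + n_miss) / n - 1                      # 0.4
per_item = (n_correct * 1 + n_miss * 0 - (n - n_correct - n_miss)) / n
assert abs(score - per_item) < 1e-9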