diff --git a/local_evaluation.py b/local_evaluation.py
index d3b6de6c152377d9945eef0dd1d748d34ccf6646..b54511a7d000a589e0a3787a9aa2e8427124daea 100644
--- a/local_evaluation.py
+++ b/local_evaluation.py
@@ -15,45 +15,64 @@ def load_json_file(file_path):
     with open(file_path) as f:
         return json.load(f)
 
+
 def get_system_message():
     """Returns the system message containing instructions and in context examples."""
     return INSTRUCTIONS + IN_CONTEXT_EXAMPLES
 
+
 def attempt_api_call(client, model_name, messages, max_retries=10):
     """Attempt an API call with retries upon encountering specific errors."""
-    #todo: add default response when all efforts fail
+    # todo: add default response when all efforts fail
     for attempt in range(max_retries):
         try:
-            response = client.chat.completions.create(model=model_name, messages=messages)
+            response = client.chat.completions.create(
+                model=model_name,
+                messages=messages,
+                response_format={"type": "json_object"},
+            )
             return response.choices[0].message.content
         except (APIConnectionError, RateLimitError):
-            logger.warning(f"API call failed on attempt {attempt + 1}, retrying...")
+            logger.warning(
+                f"API call failed on attempt {attempt + 1}, retrying..."
+            )
         except Exception as e:
             logger.error(f"Unexpected error: {e}")
             break
     return None
 
+
 def log_response(messages, response):
     """Save the response from the API to a file."""
     file_name = datetime.now().strftime("%d-%m-%Y-%H-%M-%S.json")
-    with open(f"api_responses/{file_name}", 'w') as f:
+    with open(f"api_responses/{file_name}", "w") as f:
         json.dump({"messages": messages, "response": response}, f)
 
+
 def parse_response(resp: str):
     """Pass auto-eval output from the evaluator."""
     try:
         resp = resp.lower()
         model_resp = json.loads(resp)
         answer = -1
-        if "accuracy" in model_resp and ((model_resp["accuracy"] is True) or (isinstance(model_resp["accuracy"], str) and model_resp["accuracy"].lower() == "true")):
+        if "accuracy" in model_resp and (
+            (model_resp["accuracy"] is True)
+            or (
+                isinstance(model_resp["accuracy"], str)
+                and model_resp["accuracy"].lower() == "true"
+            )
+        ):
             answer = 1
         else:
-            raise ValueError(f"Could not parse answer from response: {model_resp}")
+            raise ValueError(
+                f"Could not parse answer from response: {model_resp}"
+            )
         return answer
     except:
         return -1
 
+
 def evaluate(dataset_path, model_name):
     qa = load_json_file(os.path.join(dataset_path, "qa.json"))
     web_results = load_json_file(os.path.join(dataset_path, "web.json"))
@@ -63,14 +82,24 @@ def evaluate(dataset_path, model_name):
     n_miss, n_correct, n_correct_exact = 0, 0, 0
     system_message = get_system_message()
 
-    for query_dict, query_web_search_results in tqdm(zip(qa, web_results), total=len(qa)):
-        query, ground_truth = query_dict['query'], query_dict['answer'].strip().lower()
-        prediction = participant_model.generate_answer(query, query_web_search_results)
+    for query_dict, query_web_search_results in tqdm(
+        zip(qa, web_results), total=len(qa)
+    ):
+        query, ground_truth = (
+            query_dict["query"],
+            query_dict["answer"].strip().lower(),
+        )
+        prediction = participant_model.generate_answer(
+            query, query_web_search_results
+        )
         prediction = prediction.strip().lower()
-        
+
         messages = [
             {"role": "system", "content": system_message},
-            {"role": "user", "content": f"Question: {query}\n Ground truth: {ground_truth}\n Prediction: {prediction}\n"},
+            {
+                "role": "user",
+                "content": f"Question: {query}\n Ground truth: {ground_truth}\n Prediction: {prediction}\n",
+            },
         ]
         if prediction == "i don't know":
             n_miss += 1
@@ -89,7 +118,7 @@ def evaluate(dataset_path, model_name):
 
     n = len(qa)
     results = {
-        "score": (2*n_correct + n_miss) / n - 1,
+        "score": (2 * n_correct + n_miss) / n - 1,
         "exact_accuracy": n_correct_exact / n,
         "accuracy": n_correct / n,
         "hallucination": (n - n_correct - n_miss) / n,
@@ -97,12 +126,13 @@ def evaluate(dataset_path, model_name):
         "n_miss": n_miss,
         "n_correct": n_correct,
         "n_correct_exact": n_correct_exact,
-        "total": n, 
+        "total": n,
     }
     logger.info(results)
     return results
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     DATASET_PATH = "example_data/"
     MODEL_NAME = os.getenv("EVALUATION_MODEL_NAME", "gpt-4-0125-preview")
     evaluate(DATASET_PATH, MODEL_NAME)
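
Note (not part of the patch): a minimal sketch of how the touched scoring pieces behave after this change, assuming local_evaluation.py and its dependencies are importable from the working directory.

# Minimal sketch, not part of the patch: assumes local_evaluation.py is importable.
from local_evaluation import parse_response

# With response_format={"type": "json_object"} the evaluator returns JSON such as
# '{"accuracy": true}'; parse_response maps it to 1 (correct) or -1 (anything else).
assert parse_response('{"accuracy": true}') == 1
assert parse_response('{"accuracy": "True"}') == 1   # string form is also accepted
assert parse_response('{"accuracy": false}') == -1
assert parse_response("not json") == -1              # parse failures fall back to -1

# Score from evaluate(): correct answers earn +1, "i don't know" earns 0,
# hallucinations earn -1, averaged over n, i.e. (2 * n_correct + n_miss) / n - 1.
n, n_correct, n_miss = 10, 6, 2
score = (2 * n_correct + n_miss) / n - 1                      # 0.4
per_item = (n_correct * 1 + n_miss * 0 - (n - n_correct - n_miss)) / n
assert abs(score - per_item) < 1e-9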