From 3ac61cd2295b95b02879e99424f9b96f78462bbd Mon Sep 17 00:00:00 2001
From: Silin <silin.gao@epfl.ch>
Date: Sun, 10 Dec 2023 08:37:50 +0000
Subject: [PATCH] Handle API call failures in local_evaluation_with_api.py

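Wrap the chat completion request in api_call() in a try/except so that a
failed request no longer aborts local evaluation. The method now returns
a fourth value indicating whether the call succeeded:

    response_text, input_tokens, output_tokens, success = llm_api.api_call(prompt, max_tokens)

In get_responses(), the per-prompt success flags of a batch are collected,
and the agent is only asked for its final responses (final=True) once
every API call in the batch has succeeded.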
---
 local_evaluation_with_api.py | 46 +++++++++++++++++++++---------------
 1 file changed, 27 insertions(+), 19 deletions(-)

diff --git a/local_evaluation_with_api.py b/local_evaluation_with_api.py
index be284f6..e4c323e 100644
--- a/local_evaluation_with_api.py
+++ b/local_evaluation_with_api.py
@@ -38,24 +38,27 @@ class LLM_API:
         self.model = "gpt-3.5-turbo-1106"
 
     def api_call(self, prompt, max_tokens):
-        if isinstance(prompt, str):  # Single-message prompt
-            response = self.client.chat.completions.create(
-                model=self.model,
-                messages=[{"role": "user", "content": prompt}],
-                max_tokens=max_tokens,
-            )
-        elif isinstance(prompt, list):  # Multi-message prompt
-            response = self.client.chat.completions.create(
-                model=self.model,
-                messages=prompt,
-                max_tokens=max_tokens,
-            )
-        else:
-            raise TypeError
-        response_text = response.choices[0].message.content
-        input_tokens = response.usage.prompt_tokens
-        output_tokens = response.usage.completion_tokens
-        return response_text, input_tokens, output_tokens
+        try:
+            if isinstance(prompt, str):  # Single-message prompt
+                response = self.client.chat.completions.create(
+                    model=self.model,
+                    messages=[{"role": "user", "content": prompt}],
+                    max_tokens=max_tokens,
+                )
+            elif isinstance(prompt, list):  # Multi-message prompt
+                response = self.client.chat.completions.create(
+                    model=self.model,
+                    messages=prompt,
+                    max_tokens=max_tokens,
+                )
+            else:
+                raise TypeError("prompt must be a str or a list of message dicts")
+            response_text = response.choices[0].message.content
+            input_tokens = response.usage.prompt_tokens
+            output_tokens = response.usage.completion_tokens
+            return response_text, input_tokens, output_tokens, True
+        except Exception:  # e.g. network or rate-limit error from the API
+            return "", 0, 0, False
   
 llm_api = LLM_API()
 
@@ -71,10 +74,15 @@ def get_responses(agent, test_data, BATCH_SIZE):
                 if not agent_response['use_api']:
                     break
                 api_responses = []
+                batch_success = []  # whether each API call in the batch succeeded
                 for prompt, max_tokens in zip(agent_response['prompts'], agent_response['max_generated_tokens']):
-                    api_resp, _, _ = llm_api.api_call(prompt, max_tokens)
+                    api_resp, _, _, success = llm_api.api_call(prompt, max_tokens)
                     print("Prompt:", prompt, "\nResponse:", api_resp)
                     api_responses.append(api_resp)
+                    batch_success.append(success)
+                if all(batch_success):  # finalize only if every API call succeeded
+                    agent_response = agent.generate_responses(batch_inputs, api_responses, final=True)
+                    break
 
             responses = agent_response['final_responses']
             for bi, resp in zip(batch_idx, responses):
-- 
GitLab