From 3ac61cd2295b95b02879e99424f9b96f78462bbd Mon Sep 17 00:00:00 2001 From: Silin <silin.gao@epfl.ch> Date: Sun, 10 Dec 2023 08:37:50 +0000 Subject: [PATCH] Update local_evaluation_with_api.py --- local_evaluation_with_api.py | 46 +++++++++++++++++++++--------------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/local_evaluation_with_api.py b/local_evaluation_with_api.py index be284f6..e4c323e 100644 --- a/local_evaluation_with_api.py +++ b/local_evaluation_with_api.py @@ -38,24 +38,27 @@ class LLM_API: self.model = "gpt-3.5-turbo-1106" def api_call(self, prompt, max_tokens): - if isinstance(prompt, str): # Single-message prompt - response = self.client.chat.completions.create( - model=self.model, - messages=[{"role": "user", "content": prompt}], - max_tokens=max_tokens, - ) - elif isinstance(prompt, list): # Multi-message prompt - response = self.client.chat.completions.create( - model=self.model, - messages=prompt, - max_tokens=max_tokens, - ) - else: - raise TypeError - response_text = response.choices[0].message.content - input_tokens = response.usage.prompt_tokens - output_tokens = response.usage.completion_tokens - return response_text, input_tokens, output_tokens + try: + if isinstance(prompt, str): # Single-message prompt + response = self.client.chat.completions.create( + model=self.model, + messages=[{"role": "user", "content": prompt}], + max_tokens=max_tokens, + ) + elif isinstance(prompt, list): # Multi-message prompt + response = self.client.chat.completions.create( + model=self.model, + messages=prompt, + max_tokens=max_tokens, + ) + else: + raise TypeError + response_text = response.choices[0].message.content + input_tokens = response.usage.prompt_tokens + output_tokens = response.usage.completion_tokens + return response_text, input_tokens, output_tokens, True + except: + return "", [], [], False llm_api = LLM_API() @@ -71,10 +74,15 @@ def get_responses(agent, test_data, BATCH_SIZE): if not agent_response['use_api']: break api_responses = [] + batch_success = [] for prompt, max_tokens in zip(agent_response['prompts'], agent_response['max_generated_tokens']): - api_resp, _, _ = llm_api.api_call(prompt, max_tokens) + api_resp, _, _, success = llm_api.api_call(prompt, max_tokens) print("Prompt:", prompt, "\nResponse:", api_resp) api_responses.append(api_resp) + batch_success.append(success) + if all(batch_success): + agent_response = agent.generate_responses(batch_inputs, api_responses, final=True) + break responses = agent_response['final_responses'] for bi, resp in zip(batch_idx, responses): -- GitLab