Commit 3ac61cd2 authored by Silin

Update local_evaluation_with_api.py

parent 21bd06d3
local_evaluation_with_api.py

@@ -38,24 +38,27 @@ class LLM_API:
         self.model = "gpt-3.5-turbo-1106"

     def api_call(self, prompt, max_tokens):
-        if isinstance(prompt, str):  # Single-message prompt
-            response = self.client.chat.completions.create(
-                model=self.model,
-                messages=[{"role": "user", "content": prompt}],
-                max_tokens=max_tokens,
-            )
-        elif isinstance(prompt, list):  # Multi-message prompt
-            response = self.client.chat.completions.create(
-                model=self.model,
-                messages=prompt,
-                max_tokens=max_tokens,
-            )
-        else:
-            raise TypeError
-        response_text = response.choices[0].message.content
-        input_tokens = response.usage.prompt_tokens
-        output_tokens = response.usage.completion_tokens
-        return response_text, input_tokens, output_tokens
+        try:
+            if isinstance(prompt, str):  # Single-message prompt
+                response = self.client.chat.completions.create(
+                    model=self.model,
+                    messages=[{"role": "user", "content": prompt}],
+                    max_tokens=max_tokens,
+                )
+            elif isinstance(prompt, list):  # Multi-message prompt
+                response = self.client.chat.completions.create(
+                    model=self.model,
+                    messages=prompt,
+                    max_tokens=max_tokens,
+                )
+            else:
+                raise TypeError
+            response_text = response.choices[0].message.content
+            input_tokens = response.usage.prompt_tokens
+            output_tokens = response.usage.completion_tokens
+            return response_text, input_tokens, output_tokens, True
+        except:
+            return "", [], [], False

 llm_api = LLM_API()
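With this change, api_call no longer propagates exceptions; it returns a fourth element indicating whether the request succeeded, with empty placeholders when it did not. A minimal sketch of how a caller might consume the new four-tuple, assuming the llm_api instance defined above (the wrapper function and retry limit are illustrative assumptions, not part of the script):

def call_with_retries(prompt, max_tokens, max_retries=3):
    # Retry the API call until it reports success, up to a hypothetical limit.
    for attempt in range(max_retries):
        text, in_tokens, out_tokens, success = llm_api.api_call(prompt, max_tokens)
        if success:
            return text, in_tokens, out_tokens
        print(f"API call failed (attempt {attempt + 1}/{max_retries}), retrying...")
    return "", 0, 0  # every attempt failed; the caller decides how to handle this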
@@ -71,10 +74,15 @@ def get_responses(agent, test_data, BATCH_SIZE):
             if not agent_response['use_api']:
                 break
             api_responses = []
+            batch_success = []
             for prompt, max_tokens in zip(agent_response['prompts'], agent_response['max_generated_tokens']):
-                api_resp, _, _ = llm_api.api_call(prompt, max_tokens)
+                api_resp, _, _, success = llm_api.api_call(prompt, max_tokens)
                 print("Prompt:", prompt, "\nResponse:", api_resp)
                 api_responses.append(api_resp)
+                batch_success.append(success)
+            if all(batch_success):
+                agent_response = agent.generate_responses(batch_inputs, api_responses, final=True)
+                break
         responses = agent_response['final_responses']
         for bi, resp in zip(batch_idx, responses):
...
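The second hunk only hands the batch back to the agent for final generation when every API call in it succeeded; if any call failed, the enclosing loop (elided from the hunk) runs the whole batch of prompts again instead of passing partial results through. A self-contained sketch of that control flow, using stand-ins for the agent and the API (MockAgent, flaky_api_call, and the failure rate are assumptions for illustration, not project code):

import random

def flaky_api_call(prompt, max_tokens):
    # Stand-in for the updated LLM_API.api_call: fails ~30% of the time
    # and signals failure through the fourth return value instead of raising.
    if random.random() < 0.3:
        return "", [], [], False
    return f"answer to: {prompt}", [], [], True

class MockAgent:
    # Stand-in for the evaluated agent (illustrative, not project code).
    def generate_responses(self, batch_inputs, api_responses=None, final=False):
        if not final:
            # First pass: ask the harness to run one API call per input.
            return {"use_api": True,
                    "prompts": [f"rewrite: {x}" for x in batch_inputs],
                    "max_generated_tokens": [64] * len(batch_inputs)}
        # Final pass: turn the API outputs into final answers.
        return {"use_api": False, "final_responses": api_responses}

agent = MockAgent()
batch_inputs = ["q1", "q2", "q3"]

while True:
    agent_response = agent.generate_responses(batch_inputs)
    if not agent_response["use_api"]:
        break
    api_responses, batch_success = [], []
    for prompt, max_tokens in zip(agent_response["prompts"],
                                   agent_response["max_generated_tokens"]):
        api_resp, _, _, success = flaky_api_call(prompt, max_tokens)
        api_responses.append(api_resp)
        batch_success.append(success)
    if all(batch_success):
        # Only a fully successful batch is handed back for final generation.
        agent_response = agent.generate_responses(batch_inputs, api_responses, final=True)
        break
    # Otherwise fall through: the while-loop repeats and the batch is retried.

print(agent_response["final_responses"])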