Skip to content
Snippets Groups Projects
Commit 3ac61cd2 authored by Silin's avatar Silin
Browse files

Update local_evaluation_with_api.py

parent 21bd06d3
No related branches found
No related tags found
No related merge requests found
......@@ -38,24 +38,27 @@ class LLM_API:
self.model = "gpt-3.5-turbo-1106"
def api_call(self, prompt, max_tokens):
    """Send *prompt* to the chat-completions endpoint and return the result.

    Parameters
    ----------
    prompt : str | list
        Either a single user message (str) or a full message list in the
        OpenAI chat format (a list of {"role": ..., "content": ...} dicts).
    max_tokens : int
        Upper bound on the number of tokens to generate.

    Returns
    -------
    tuple[str, int, int, bool]
        (response_text, input_tokens, output_tokens, success).  On any
        failure the text is "" and both token counts are 0.
    """
    try:
        # Normalize both prompt forms to one message list so the API is
        # called from a single site instead of two duplicated create() blocks.
        if isinstance(prompt, str):  # Single-message prompt
            messages = [{"role": "user", "content": prompt}]
        elif isinstance(prompt, list):  # Multi-message prompt
            messages = prompt
        else:
            raise TypeError(
                f"prompt must be str or list, got {type(prompt).__name__}"
            )
        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=max_tokens,
        )
        response_text = response.choices[0].message.content
        input_tokens = response.usage.prompt_tokens
        output_tokens = response.usage.completion_tokens
        return response_text, input_tokens, output_tokens, True
    except Exception:
        # `except Exception`, not a bare `except:`, so KeyboardInterrupt /
        # SystemExit still propagate.  Failure token counts are 0 (ints),
        # not [] as in the original, keeping the tuple element types
        # consistent with the success path; callers unpack and discard
        # them, so this is backward-compatible.
        return "", 0, 0, False
llm_api = LLM_API()
......@@ -71,10 +74,15 @@ def get_responses(agent, test_data, BATCH_SIZE):
if not agent_response['use_api']:
break
api_responses = []
batch_success = []
for prompt, max_tokens in zip(agent_response['prompts'], agent_response['max_generated_tokens']):
api_resp, _, _ = llm_api.api_call(prompt, max_tokens)
api_resp, _, _, success = llm_api.api_call(prompt, max_tokens)
print("Prompt:", prompt, "\nResponse:", api_resp)
api_responses.append(api_resp)
batch_success.append(success)
if all(batch_success):
agent_response = agent.generate_responses(batch_inputs, api_responses, final=True)
break
responses = agent_response['final_responses']
for bi, resp in zip(batch_idx, responses):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.