diff --git a/local_evaluation_with_api.py b/local_evaluation_with_api.py
index a339e3c25b3ee994301b072083245406fac39fa6..be284f6556769189f1671d6ad2d6c71f842b37fe 100644
--- a/local_evaluation_with_api.py
+++ b/local_evaluation_with_api.py
@@ -38,12 +38,29 @@ class LLM_API:
         self.model = "gpt-3.5-turbo-1106"
 
     def api_call(self, prompt, max_tokens):
-        """ Simple single prompt api call """
-        response = self.client.chat.completions.create(
-            model=self.model,
-            messages=[{"role": "user", "content": prompt}],
-            max_tokens=max_tokens,
-        )
+        """Call the chat completions API with a single prompt or message list.
+
+        Args:
+            prompt: Either a plain string (wrapped as one user message) or a
+                pre-built list of message dicts passed through unchanged.
+            max_tokens: Maximum number of tokens the model may generate.
+
+        Raises:
+            TypeError: If ``prompt`` is neither a str nor a list.
+        """
+        # Normalize both accepted prompt forms into one messages list so the
+        # API call itself is written only once.
+        if isinstance(prompt, str):  # Single-message prompt
+            messages = [{"role": "user", "content": prompt}]
+        elif isinstance(prompt, list):  # Multi-message prompt
+            messages = prompt
+        else:
+            raise TypeError(
+                f"prompt must be a str or a list of messages, got {type(prompt).__name__}"
+            )
+        response = self.client.chat.completions.create(
+            model=self.model,
+            messages=messages,
+            max_tokens=max_tokens,
+        )
         response_text = response.choices[0].message.content
         input_tokens = response.usage.prompt_tokens
         output_tokens = response.usage.completion_tokens