# Copyright (c) Sony Group Corporation.
# Released under the MIT license
from typing import List, Dict


class PromptAgent(object):
    def __init__(self):
        """ Can initialize any retrieval models etc here """
        self.api_used = False
        self.max_generated_token_per_call = 20  # Can be changed by participant

    def get_prompt(self, test_sample: Dict):
        profile = "\n".join(test_sample["persona B"])
        system = {'role': 'system', 'content': 'You are a chit-chat agent who is playing a role described by the following profile:\n' + profile}
        prompt = [system]
        for turn in test_sample["dialogue"]:
            if turn["persona_id"] == "A":
                prompt.append({'role': 'user', 'content': turn["text"]})
            else:  # turn["persona_id"] == "B"
                prompt.append({'role': 'assistant', 'content': turn["text"]})
        return prompt
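
    # Illustrative example (hypothetical data, not from the challenge set): given a test_sample like
    #     {"persona B": ["I love hiking.", "I own two dogs."],
    #      "dialogue": [{"persona_id": "A", "text": "Hi! Any plans for the weekend?"}]}
    # get_prompt() returns the chat-style message list
    #     [{'role': 'system', 'content': 'You are a chit-chat agent ... profile:\nI love hiking.\nI own two dogs.'},
    #      {'role': 'user', 'content': 'Hi! Any plans for the weekend?'}]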

    def generate_responses(self, test_data: List[Dict], api_responses: List[str], final=False) -> Dict:
        """
        You will be provided with a batch of up to 50 independent conversations.

        Input 1 (test_data)
        [
            {"persona B": ..., "dialogue": ...},  # conversation 1, Turn 1
            ...
            {"persona B": ..., "dialogue": ...}   # conversation 50, Turn 1
        ]
        Model should return 50 responses for Turn 1.
        ...
        Input 7 (test_data)
        [
            {"persona B": ..., "dialogue": ...},  # conversation 1, Turn 7
            ...
            {"persona B": ..., "dialogue": ...}   # conversation 50, Turn 7
        ]
        Model should return 50 responses for Turn 7.

        api_responses - A list of output strings returned by the API for each prompt in the
            previous call; will be a list of blank strings on the first call.

        Note: Turn numbers will NOT be provided as input.

        Return a dictionary with the following format:
        "use_api": True/False - Note that this cannot be used when using GPU
        "prompts": [<list of the prompts that go as "messages" to the api>] - Note that every api call is independent and we don't use threads
        "max_generated_tokens": [<list of ints for the max generation limit on each call>] - Note that the submission will fail if the total generation limit is exceeded
        "final_responses": [<list of strings with the final responses>] - Only used when use_api is set to False
        """
        # print(f"{len(test_data)=}, {test_data[0].keys()=}, {len(test_data[-1]['dialogue'])}")
        if not self.api_used:
            self.api_used = True  # First pass: ask the harness to call the API with these prompts.
            return {
                "use_api": True,
                "prompts": [self.get_prompt(td) for td in test_data],
                "max_generated_tokens": [self.max_generated_token_per_call for _ in test_data],
                "final_responses": ["not used" for _ in test_data]
            }
        else:
            self.api_used = False  # Reset so the next batch triggers a fresh API call.
            return {
                "use_api": False,
                "max_generated_tokens": [self.max_generated_token_per_call for _ in test_data],
                "final_responses": api_responses  # Can preprocess in between calls if needed.
            }
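

# A minimal local sketch of how an evaluation harness might drive this agent. This is NOT the
# official challenge runner: the single-conversation batch and the dummy API output below are
# made up purely to illustrate the two-phase use_api protocol described in generate_responses().
if __name__ == "__main__":
    agent = PromptAgent()
    batch = [
        {
            "persona B": ["I love hiking.", "I own two dogs."],
            "dialogue": [{"persona_id": "A", "text": "Hi! Any plans for the weekend?"}],
        }
    ]

    # First call: the agent hands back prompts and asks the harness to query the API.
    first = agent.generate_responses(batch, api_responses=["" for _ in batch])
    assert first["use_api"] is True
    print(first["prompts"][0])

    # The harness would now call the API once per prompt; here we fake its output.
    dummy_api_responses = ["I'm planning a hike with my dogs!"]

    # Second call: the agent turns the API outputs into its final responses.
    second = agent.generate_responses(batch, api_responses=dummy_api_responses)
    assert second["use_api"] is False
    print(second["final_responses"])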