# Copyright (c) Sony Group Corporation.
# Released under the MIT license

from typing import List, Dict

class PromptAgent(object):
    def __init__(self):
        """ Can initialize any retrieval models etc here """
        self.api_used = False
        self.max_generated_token_per_call = 20  # Can be changed by participant

    def get_prompt(self, test_sample: Dict):
        profile = "\n".join(test_sample["persona B"])
        system = {'role': 'system', 'content': 'You are a chit-chat agent who is playing a role described by the following profile:\n'+profile}
        prompt = [system]
        for turn in test_sample["dialogue"]:
            if turn["persona_id"] == "A":
                prompt.append({'role': 'user', 'content': turn["text"]})
            else:  # turn["persona_id"] == "B":
                prompt.append({'role': 'assistant', 'content': turn["text"]})
        return prompt
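
    # Example (illustrative only; the sample data below is hypothetical):
    # for a test_sample such as
    #   {"persona B": ["I like hiking.", "I own a dog."],
    #    "dialogue": [{"persona_id": "A", "text": "Hi, any plans this weekend?"}]}
    # get_prompt returns
    #   [{'role': 'system', 'content': 'You are a chit-chat agent who is playing a role described by the following profile:\nI like hiking.\nI own a dog.'},
    #    {'role': 'user', 'content': 'Hi, any plans this weekend?'}]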

    def generate_responses(self, test_data: List[Dict], api_responses: List[str], final=False) -> Dict:
        """
        You will be provided with a batch of up to 50 independent conversations.
        
        Input 1 (test_data)
        [
            {"persona B": ... "dialogue": ... }, # conversation 1  Turn 1
            ...
            {"persona B": ... "dialogue": ... }  # conversation 50 Turn 1
        ]

        Model should return 50 responses for Turn 1

        ...
        Input 7 (test_data)
        [
            {"persona B": ... "dialogue": ... }, # conversation 1  Turn 7
            ...
            {"persona B": ... "dialogue": ... }  # conversation 50 Turn 7
        ]
        Model should return 50 responses for Turn 7

        api_responses - A list of output strings returned by the api for each prompt from the previous call;
                        it will be a list of blank strings on the first call.

        Note: Turn numbers will NOT be provided as input

        Return a dictionary with the following format:

        "use_api": True/False                                                                - Note that this cannot be used when using GPU
        "prompts": [ <list of the prompts that go as "messages" to the api> ]                - Note that every api call is independent and we don't use threads
        "max_generated_tokens": [ <list of ints for the max generation limit on each call> ] - Note that the submission will fail if the total generation limit is exceeded
        "final_responses": [ <list of strings with the final responses> ]                    - Only used when use_api is set to False

        """
        # print(f"{len(test_data)=}, {test_data[0].keys()=}, {len(test_data[-1]['dialogue'])}")

        # Calls alternate: the first call for a batch asks for api calls and
        # returns the prompts; the next call turns the api outputs into the
        # final responses.
        if not self.api_used:
            response = {
                "use_api": True,                                    
                "prompts": [self.get_prompt(td) for td in test_data],
Silin's avatar
Silin committed
                "max_generated_tokens": [self.max_generated_token_per_call for _ in test_data], 
                "final_responses": ["not used" for _ in test_data]
            }
            self.api_used = True
        else:
            response = {
                "use_api": False,
                "prompts": [[] for _ in test_data],
                "max_generated_tokens": [self.max_generated_token_per_call for _ in test_data],
                "final_responses": api_responses  # Can preprocess in between calls if needed.
            }
            self.api_used = False
        return response
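

# Minimal usage sketch (an assumption about how the evaluator drives the agent;
# the sample conversation and the canned api reply below are hypothetical and
# only illustrate the two-phase protocol described in generate_responses).
if __name__ == "__main__":
    agent = PromptAgent()
    test_data = [
        {
            "persona B": ["I like hiking.", "I own a dog."],
            "dialogue": [{"persona_id": "A", "text": "Hi, any plans this weekend?"}],
        }
    ]

    # First call: the agent asks for api calls and returns the prompts to send.
    first = agent.generate_responses(test_data, api_responses=["" for _ in test_data])
    assert first["use_api"] is True

    # Stand-in for the real api: echo one canned reply per prompt.
    api_outputs = ["I'm planning a hike with my dog." for _ in first["prompts"]]

    # Second call: the agent receives the api outputs and returns the final responses.
    second = agent.generate_responses(test_data, api_responses=api_outputs, final=True)
    assert second["use_api"] is False
    print(second["final_responses"])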