Skip to content
Snippets Groups Projects
prompt_agent.py 3.38 KiB
Newer Older
from typing import List, Dict

class DummyPromptAgent:
    """Baseline agent that spends a fixed number of API calls per utterance.

    For each utterance it requests up to ``api_limit`` API calls with a
    placeholder prompt, then returns the most recent API outputs as the
    final responses and resets its per-utterance call counter.
    """

    def __init__(self, api_limit: int = 2, max_generated_token_per_call: int = 20):
        """Set up per-agent limits and counters.

        Any retrieval models or other resources can also be initialized here.

        Args:
            api_limit: Max number of API calls per utterance (default 2).
            max_generated_token_per_call: Generation limit requested for each
                API call (default 20); can be tuned per utterance/dialogue.
        """
        self.api_limit = api_limit                  # Max number of api calls per utterance
        # Per-dialogue budgets (combined token usage of all 7 utterances).
        # NOTE: these are recorded here but not checked by generate_responses.
        self.input_token_limit = 10_000             # Max number of input tokens per dialogue
        self.output_token_limit = 1_000             # Max number of output tokens per dialogue

        self.max_generated_token_per_call = max_generated_token_per_call
        # API calls already requested for the current utterance; reset to 0
        # when the final responses are returned.
        self.api_usage_count = 0

    def generate_responses(self, test_data: List[Dict], api_responses: List[str]) -> Dict:
        """
        You will be provided with a batch of up to 50 independent conversations.

        Input 1 (test_data)
        [
            {"persona A": ..., "persona B": ..., "dialogue": ... },  # conversation 1  Turn 1
            ...
            {"persona A": ..., "persona B": ..., "dialogue": ... }   # conversation 50 Turn 1
        ]

        Model should return 50 responses for Turn 1

        ...
        Input 7 (test_data)
        [
            {"persona A": ..., "persona B": ..., "dialogue": ... },  # conversation 1  Turn 7
            ...
            {"persona A": ..., "persona B": ..., "dialogue": ... }   # conversation 50 Turn 7
        ]
        Model should return 50 responses for Turn 7

        api_responses - A list of strings, one per conversation, holding the API
                        output for the previous prompt. It is a list of blank
                        strings on the first call.

        Note: Turn numbers will NOT be provided as input.

        Return a dictionary with the following format:

        "use_api": True/False                                                               - Note that this cannot be used when using GPU
        "prompts": [ <list of the prompts that go as "content" to the api> ]                - Note that every api call is independent and we don't use threads
        "max_generated_tokens": [ list of ints for the max generation limit on each call ]  - Note that the submission will fail if the total generation limit is exceeded
        "final_responses": [ <list of strings with the final responses> ]                   - Only used when use_api is set to False

        While fewer than ``self.api_limit`` API calls have been made for the
        current utterance, this requests another API call. Otherwise it returns
        a copy of ``api_responses`` as the final responses and resets the counter.
        """
        batch_size = len(test_data)
        tokens_per_call = [self.max_generated_token_per_call] * batch_size

        if self.api_usage_count < self.api_limit:
            self.api_usage_count += 1
            return {
                "use_api": True,
                "prompts": ["You're a helpful assistant, say this is a test"] * batch_size,
                "max_generated_tokens": tokens_per_call,
                "final_responses": ["not used"] * batch_size,
            }

        # API budget for this utterance is spent: emit final responses and
        # reset the counter for the next utterance.
        self.api_usage_count = 0
        return {
            "use_api": False,
            "prompts": [""] * batch_size,
            "max_generated_tokens": tokens_per_call,
            # Copy so a later mutation of the caller's list cannot alter this
            # response (can preprocess in between calls if needed).
            "final_responses": list(api_responses),
        }