# Copyright (c) Sony Group Corporation.
# Released under the MIT license

from typing import List, Dict

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm.auto import tqdm

class Phi2ResponseAgent:
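    """Generate Person B's replies with a locally stored phi-2 model, one batch of conversations at a time."""
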
    def __init__(self):
        """ Load your model(s) here """
        torch.random.manual_seed(0)
        torch.cuda.manual_seed(0)

        # The model weights are expected to be available locally at this path
        # (a pre-downloaded snapshot of microsoft/phi-2).
        model_path = './agents/phi2_model/microsoft--phi-2'
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.model = self.model.to(self.device)
        self.model.eval()

        # Track the current turn (1..7) across calls, since turn numbers are not provided as input.
        self.turn_id = 1

        print("Model loaded!!")

        self.max_input_tokens = 1024

    def generate_responses(self, test_data: List[Dict], api_responses: List[str]) -> List[str]:
        """
        You will be provided with a batch of up to 50 independent conversations.

        Input 1
        [
            {"persona B": ... "dialogue": ... }, # conversation 1  Turn 1
            ...
            {"persona B": ... "dialogue": ... }  # conversation 50 Turn 1
        ]

        Model should return 50 responses for Turn 1

        ...
        Input 7
        [
            {"persona B": ... "dialogue": ... }, # conversation 1  Turn 7
            ...
            {"persona B": ... "dialogue": ... }  # conversation 50 Turn 7
        ]
        Model should return 50 responses for Turn 7

        Note: Turn numbers will NOT be provided as input

        Return a dictionary with the following format:

        "use_api": True/False                                                                - Note that the API cannot be used when running on GPU
        "prompts": [ <list of the prompts that go as "content" to the API> ]                 - Note that every call is independent and we don't use threads
        "max_generated_tokens": [ <list of ints for the max generation limit on each call> ] - Note that the submission will fail if the total generation limit is exceeded
        "final_responses": [ <list of strings with the final responses> ]                    - Only used when use_api is set to False

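        For example, a batch of two conversations handled locally (use_api set to
        False) could be returned as (the response strings are illustrative):

            {
                "use_api": False,
                "prompts": ["", ""],
                "max_generated_tokens": [0, 0],
                "final_responses": ["i work as a nurse, so i mostly rest.", "i would love to go hiking with you."]
            }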
        """

        final_responses = []

        for conversation in tqdm(test_data):

            # The dialogue text comes with spaces before punctuation; normalise the
            # last utterance for the debug print below.
            last_utterance = conversation['dialogue'][-1]['text'].lower()
            for punct in ('.', ',', '?', '!', ')'):
                last_utterance = last_utterance.replace(f' {punct}', punct)
            # Build the prompt: persona facts first, then the task instruction,
            # then the dialogue history, following phi-2's Instruct/Output format.
            prompt_persona = f'''
Person B has the following Persona information.
'''
            for ipersona in conversation['persona B']:
                prompt_persona += f'''Persona of Person B: {ipersona}\n'''
            prompt = f'''{prompt_persona}
Instruct: Person A and Person B are now having a conversation.  Following the conversation below, write a response that Person B would say based on the above Persona information.  Please carefully consider the flow and context of the conversation below, and use Person B's Persona information appropriately to generate the response that you think is the most appropriate reply for Person B.
'''
            # Append the dialogue history with the same punctuation normalisation.
            for iturn in conversation['dialogue']:
                iperson = iturn['persona_id']
                iinput = iturn['text'].lower()
                for punct in ('.', ',', '?', '!', ')'):
                    iinput = iinput.replace(f' {punct}', punct)
                prompt += f'''Person {iperson}: {iinput}\n'''
            prompt += 'Output: \n'
            # Tokenise the prompt and generate at most 20 new tokens; passing the full
            # encoding (input_ids and attention_mask) avoids the missing-mask warning.
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

            generation_output = self.model.generate(**inputs, max_new_tokens=20, pad_token_id=self.tokenizer.eos_token_id)

            # Decode only the newly generated tokens and keep the first line of the
            # reply (falling back to the last line if the first is empty).
            prompt_length = inputs.input_ids.shape[1]
            output = self.tokenizer.decode(generation_output[0][prompt_length:], skip_special_tokens=True)
            output = output.replace('Person B: ', '')
            response = output.split('\n')[0]
            if response == '':
                response = output.split('\n')[-1]
            print('\n')
            print('input:', last_utterance)
            print('response:', response)

            final_responses.append(response)

        self.turn_id = self.turn_id % 7 + 1 # Turn id goes from 1 to 7

        response = {
            "use_api": False,                                    # Cannot use API if GPU true is set in aicrowd.json
            "prompts": ["" for _ in test_data],                  # Cannot use API if GPU true is set in aicrowd.json
            "max_generated_tokens": [0 for _ in test_data],
            "final_responses": final_responses
        }
        return response
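

if __name__ == '__main__':
    # Minimal local smoke test (illustrative only). It assumes the phi-2 snapshot is
    # already present at the model_path used in __init__; the sample conversation
    # mirrors the input format described in the generate_responses docstring.
    agent = Phi2ResponseAgent()
    sample_batch = [
        {
            "persona B": ["i like to go hiking on weekends.", "i work as a nurse."],
            "dialogue": [
                {"persona_id": "A", "text": "hi , how was your weekend ?"},
                {"persona_id": "B", "text": "great , i went hiking ."},
                {"persona_id": "A", "text": "nice ! where did you go ?"},
            ],
        }
    ]
    result = agent.generate_responses(sample_batch, api_responses=["" for _ in sample_batch])
    print(result["final_responses"])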