Skip to content
Snippets Groups Projects
bart_agent.py 5.38 KiB
Newer Older
Dipam Chakraborty's avatar
Dipam Chakraborty committed
# Adapted from - https://github.com/Silin159/PersonaChat-BART-PeaCoK/blob/main/eval_parlai.py

from typing import List, Dict

import torch
from transformers import BartTokenizer
from tqdm.auto import tqdm

from agents.bart_model.eval_utils import create_encoder_input
from agents.bart_model.modeling_bart import BartForConditionalGeneration

class BARTResponseAgent(object):
    """Persona-grounded response generator backed by a fine-tuned BART checkpoint.

    The agent speaks as "persona B": for every conversation it tokenizes the
    ``'persona B'`` facts and the ``'dialogue'`` history, builds the encoder
    input with ``create_encoder_input`` and decodes a reply with
    ``BartForConditionalGeneration.generate``.
    """

    def __init__(self):
        """Load the tokenizer and model checkpoint and cache the special-token ids."""
        # Fixed seeds so repeated runs produce the same generations.
        torch.random.manual_seed(0)
        torch.cuda.manual_seed(0)

        # Relative paths: resolved against the current working directory.
        tokenizer_path = 'agents/bart_model/checkpoints/checkpoints_persona_chat_peacok_random'
        checkpoint_path = 'agents/bart_model/checkpoints/checkpoints_persona_chat_peacok_random/checkpoint_epoch_30/'

        self.tokenizer = BartTokenizer.from_pretrained(tokenizer_path)
        self.model = BartForConditionalGeneration.from_pretrained(checkpoint_path)

        # Ids of the custom marker tokens that create_encoder_input consumes.
        self.query_id, self.res_id, self.latent_id, self.persona_id, self.partner_id = \
            self.tokenizer.convert_tokens_to_ids([
                '<query>', '<response>', '<latent>', '<persona>', '<partner>'
            ])
        self.bos_id = self.tokenizer.bos_token_id
        self.eos_id = self.tokenizer.eos_token_id
        self.pad_id = self.tokenizer.pad_token_id
        self.sep_id = self.tokenizer.sep_token_id

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.model = self.model.to(self.device)
        self.model.eval()

        # Turn numbers are not given as input; this counter cycles 1..7 across
        # successive generate_responses calls.
        self.turn_id = 1

        # Only the last `max_input_tokens` encoder tokens are kept.
        # NOTE(review): 1024 presumably matches BART's position-embedding limit — confirm for this checkpoint.
        self.max_input_tokens = 1024
        # Decoding settings forwarded to model.generate.
        self.max_generation_length = 50
        self.num_beams = 2

        print("Model loaded!!")

    def tokenize_conversation(self, conversation):
        """Convert a conversation's text fields into lists of token ids.

        Args:
            conversation: dict with a ``'persona B'`` list of strings and a
                ``'dialogue'`` list of dicts, each carrying a ``'text'`` string.

        Returns:
            ``(persona, partner, history)``: each is a list with one token-id
            list per sentence/turn. ``partner`` is always empty.
        """
        def tokenize(text):
            # Strip once here; every segment is tokenized with a leading space.
            return self.tokenizer.convert_tokens_to_ids(
                self.tokenizer.tokenize(text.strip(), add_prefix_space=True)
            )

        persona = [tokenize(line) for line in conversation['persona B']]
        # The baseline was not trained with the partner persona, so
        # conversation['persona A'] is deliberately ignored.
        partner = []
        history = [tokenize(turn['text']) for turn in conversation['dialogue']]
        return persona, partner, history

    def prepare_tensors(self, conversation):
        """Build the encoder input ids and attention mask for one conversation.

        Both sequences are truncated from the left to at most
        ``self.max_input_tokens`` items (keeping the most recent tokens), moved
        to ``self.device`` and given a leading batch dimension of size 1.

        Returns:
            ``(input_ids, attention_mask)`` tensors.
        """
        persona, partner, history = self.tokenize_conversation(conversation)
        input_ids, attention_mask, _, _ = create_encoder_input(
            persona,
            partner,
            history,
            self.query_id,
            self.res_id,
            self.latent_id,
            self.persona_id,
            self.partner_id,
            self.sep_id,
            self.eos_id,
        )
        # Truncate before building the tensor so the full sequence is never
        # materialised on the device.
        # NOTE(review): left truncation drops the start of the encoder input; if
        # create_encoder_input places persona tokens first, long dialogues lose them — confirm.
        input_ids = input_ids[-self.max_input_tokens:]
        attention_mask = attention_mask[-self.max_input_tokens:]
        tensor_input_ids = torch.tensor(input_ids, device=self.device).unsqueeze(0)
        tensor_attention_mask = torch.tensor(attention_mask, device=self.device).unsqueeze(0)
        return tensor_input_ids, tensor_attention_mask

    def generate_responses(self, test_data: List[Dict]) -> Dict:
        """
        You will be provided with a batch of up to 50 independent conversations

        Input 1
        [
            {"persona A": ..., "persona B": ... "dialogue": ... }, # conversation 1  Turn 1
            ...
            {"persona A": ..., "persona B": ... "dialogue": ... }  # conversation 50 Turn 1
        ]

        Model should return 50 responses for Turn 1

        ...
        Input 7
        [
            {"persona A": ..., "persona B": ... "dialogue": ... }, # conversation 1  Turn 7
            ...
            {"persona A": ..., "persona B": ... "dialogue": ... }  # conversation 50 Turn 7
        ]
        Model should return 50 responses for Turn 7

        Note: Turn numbers will NOT be provided as input

        Return a dictionary with the following format

        "use_api": True/False                                                               - Note that this cannot be used when using GPU
        "prompts": [ <list of the prompts that go as "content" to the api > ]               - Note that every call is independent and we don't use threads
        "max_generated_tokens": [ list of ints for the max generation limit on each call]   - Note that the submission will fail if the total generation limit is exceeded
        "final_responses": [ <list of strings with the final responses> ]                   - Only used when use_api is set to False

        This agent always generates locally. It returns ``use_api=False``, an
        empty prompt per conversation, and one decoded response per
        conversation (in input order) in ``final_responses``.
        ``max_generated_tokens`` is not included.
        """
        final_responses = []
        for conversation in tqdm(test_data):
            tensor_input_ids, tensor_attention_mask = self.prepare_tensors(conversation)
            with torch.no_grad():
                out_ids = self.model.generate(
                    input_ids=tensor_input_ids,
                    attention_mask=tensor_attention_mask,
                    max_length=self.max_generation_length,
                    num_beams=self.num_beams,
                )

            # Batch size is 1, so the single decoded sequence is element 0.
            response = self.tokenizer.batch_decode(
                out_ids,
                skip_special_tokens=True,
                spaces_between_special_tokens=False,
                clean_up_tokenization_spaces=False,
            )[0].strip()
            final_responses.append(response)

        self.turn_id = self.turn_id % 7 + 1  # Turn id goes from 1 to 7

        return {
            "use_api": False,                    # Ignored if GPU true is set in aicrowd.json
            "prompts": ["" for _ in test_data],  # Ignored if GPU true is set in aicrowd.json
            "final_responses": final_responses,
        }