# Adapted from - https://github.com/Silin159/PersonaChat-BART-PeaCoK/blob/main/eval_parlai.py
from typing import List, Dict

import torch
from transformers import BartTokenizer
from tqdm.auto import tqdm

from agents.bart_model.eval_utils import create_encoder_input
from agents.bart_model.modeling_bart import BartForConditionalGeneration


class BARTResponseAgent(object):
    def __init__(self):
        """ Load your model(s) here """
        torch.random.manual_seed(0)
        torch.cuda.manual_seed(0)

        tokenizer_path = 'agents/bart_model/checkpoints/checkpoints_persona_chat_peacok_random'
        checkpoint_path = 'agents/bart_model/checkpoints/checkpoints_persona_chat_peacok_random/checkpoint_epoch_30/'

        self.tokenizer = BartTokenizer.from_pretrained(tokenizer_path)
        self.model = BartForConditionalGeneration.from_pretrained(checkpoint_path)

        # Ids of the special tokens the checkpoint was trained with.
        self.query_id, self.res_id, self.latent_id, self.persona_id, self.partner_id = \
            self.tokenizer.convert_tokens_to_ids([
                '<query>', '<response>', '<latent>', '<persona>', '<partner>'
            ])
        self.bos_id = self.tokenizer.bos_token_id
        self.eos_id = self.tokenizer.eos_token_id
        self.pad_id = self.tokenizer.pad_token_id
        self.sep_id = self.tokenizer.sep_token_id

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = self.model.to(self.device)
        self.model.eval()

        self.turn_id = 1

        print("Model loaded!!")

        self.max_input_tokens = 1024

    def tokenize_conversation(self, conversation):
        def tokenize(text):
            return self.tokenizer.convert_tokens_to_ids(
                self.tokenizer.tokenize(text.strip(), add_prefix_space=True)
            )

        persona = [tokenize(line.strip()) for line in conversation['persona B']]
        # partner = [tokenize(line.strip()) for line in conversation['persona A']]
        partner = []  # Baseline was not trained with the partner persona
        history = [tokenize(line['text'].strip()) for line in conversation['dialogue']]
        return persona, partner, history

    def prepare_tensors(self, conversation):
        persona, partner, history = self.tokenize_conversation(conversation)
        input_ids, attention_mask, _, _ = create_encoder_input(
            persona, partner, history,
            self.query_id, self.res_id, self.latent_id,
            self.persona_id, self.partner_id,
            self.sep_id, self.eos_id
        )
        # Keep only the last max_input_tokens tokens, then add a batch dimension.
        tensor_input_ids = torch.tensor(input_ids, device=self.device)[-self.max_input_tokens:].unsqueeze(0)
        tensor_attention_mask = torch.tensor(attention_mask, device=self.device)[-self.max_input_tokens:].unsqueeze(0)
        return tensor_input_ids, tensor_attention_mask

    def generate_responses(self, test_data: List[Dict]) -> List[str]:
        """
        You will be provided with a batch of up to 50 independent conversations.
        Return a string for every conversation.

        Input 1
        [
            {"persona A": ..., "persona B": ..., "dialogue": ...},  # conversation 1, Turn 1
            ...
            {"persona A": ..., "persona B": ..., "dialogue": ...}   # conversation 50, Turn 1
        ]
        The model should return 50 responses for Turn 1.

        ...

        Input 7
        [
            {"persona A": ..., "persona B": ..., "dialogue": ...},  # conversation 1, Turn 7
            ...
            {"persona A": ..., "persona B": ..., "dialogue": ...}   # conversation 50, Turn 7
        ]
        The model should return 50 responses for Turn 7.

        Note: Turn numbers will NOT be provided as input.
        """
        all_responses = []

        for conversation in tqdm(test_data):
            tensor_input_ids, tensor_attention_mask = self.prepare_tensors(conversation)

            with torch.no_grad():
                out_ids = self.model.generate(
                    input_ids=tensor_input_ids,
                    attention_mask=tensor_attention_mask,
                    max_length=50,
                    num_beams=2
                )

            response = self.tokenizer.batch_decode(
                out_ids,
                skip_special_tokens=True,
                spaces_between_special_tokens=False,
                clean_up_tokenization_spaces=False
            )[0].strip()

            all_responses.append(response)

        self.turn_id = self.turn_id % 7 + 1  # Turn id cycles through 1..7

        return all_responses
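

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file; the names and data
# below are hypothetical). Field shapes are inferred from
# tokenize_conversation, which reads 'persona B' as a list of strings and
# 'dialogue' as a list of {'text': ...} dicts. Assumes the checkpoint paths
# hard-coded in __init__ exist locally.
if __name__ == '__main__':
    agent = BARTResponseAgent()
    sample_batch = [
        {
            'persona A': ['I love hiking.'],                      # partner persona (unused by this baseline)
            'persona B': ['I am a chef.', 'I live in Seattle.'],  # agent persona
            'dialogue': [
                {'text': 'Hi! What do you do for a living?'}
            ]
        }
    ]
    responses = agent.generate_responses(sample_batch)
    print(responses)  # one generated string per conversation in the batch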