# Copyright (c) Sony Group Corporation.
# Released under the MIT license

# Adapted from - https://github.com/Silin159/PersonaChat-BART-PeaCoK/blob/main/eval_parlai.py

from typing import Any, Dict, List

import torch
from transformers import BartTokenizer
from tqdm.auto import tqdm

from agents.bart_model.eval_utils import create_encoder_input
from agents.bart_model.modeling_bart import BartForConditionalGeneration


class BARTResponseAgent(object):
    """Response generator backed by a fine-tuned BART checkpoint.

    Every call to :meth:`generate_responses` produces one reply per conversation
    locally with ``model.generate``; the external API is never requested.
    """

    # Default artifact locations (relative paths, resolved against the working directory).
    DEFAULT_TOKENIZER_PATH = 'agents/bart_model/checkpoints/checkpoints_persona_chat_peacok_random'
    DEFAULT_CHECKPOINT_PATH = (
        'agents/bart_model/checkpoints/checkpoints_persona_chat_peacok_random/checkpoint_epoch_30/'
    )
    # Maximum number of encoder tokens kept; tokens beyond this are cut from the start.
    # NOTE(review): presumably chosen to match the model's position-embedding limit — confirm.
    MAX_INPUT_TOKENS = 1024
    # Decoding settings forwarded to ``model.generate``.
    MAX_GENERATION_LENGTH = 50
    NUM_BEAMS = 2

    def __init__(self, tokenizer_path=DEFAULT_TOKENIZER_PATH, checkpoint_path=DEFAULT_CHECKPOINT_PATH):
        """Load the tokenizer and model, resolve special-token ids and move the model to a device.

        Args:
            tokenizer_path: directory passed to ``BartTokenizer.from_pretrained``.
            checkpoint_path: directory passed to ``BartForConditionalGeneration.from_pretrained``.
        """
        # Fixed seeds so repeated runs are reproducible.
        torch.random.manual_seed(0)
        torch.cuda.manual_seed(0)

        self.tokenizer = BartTokenizer.from_pretrained(tokenizer_path)
        self.model = BartForConditionalGeneration.from_pretrained(checkpoint_path)

        # Ids of the custom marker tokens consumed by create_encoder_input.
        self.query_id, self.res_id, self.latent_id, self.persona_id, self.partner_id = \
            self.tokenizer.convert_tokens_to_ids([
                '<query>', '<response>', '<latent>', '<persona>', '<partner>'
            ])
        self.bos_id = self.tokenizer.bos_token_id
        self.eos_id = self.tokenizer.eos_token_id
        self.pad_id = self.tokenizer.pad_token_id
        self.sep_id = self.tokenizer.sep_token_id

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = self.model.to(self.device)
        self.model.eval()
        # Cycles 1..7 across successive generate_responses calls.
        self.turn_id = 1
        print("Model loaded!!")

        self.max_input_tokens = self.MAX_INPUT_TOKENS

    def tokenize_conversation(self, conversation):
        """Convert one conversation mapping into lists of token ids.

        Args:
            conversation: mapping with a ``'persona B'`` list of persona strings and a
                ``'dialogue'`` list of turns, each a mapping with a ``'text'`` string.

        Returns:
            ``(persona, partner, history)``. Each is a list of token-id lists, one per
            persona line or dialogue turn. ``partner`` is always empty because no
            partner persona is supplied.
        """
        def tokenize(text):
            # Strip surrounding whitespace once, then tokenize with a prefix space.
            tokens = self.tokenizer.tokenize(text.strip(), add_prefix_space=True)
            return self.tokenizer.convert_tokens_to_ids(tokens)

        persona = [tokenize(line) for line in conversation['persona B']]
        partner = []
        history = [tokenize(turn['text']) for turn in conversation['dialogue']]
        return persona, partner, history

    def prepare_tensors(self, conversation):
        """Build encoder inputs for a single conversation as a batch of size one.

        Returns:
            ``(input_ids, attention_mask)`` tensors on ``self.device``, each truncated to
            at most ``self.max_input_tokens`` trailing entries and given a leading batch
            dimension via ``unsqueeze(0)``.
        """
        persona, partner, history = self.tokenize_conversation(conversation)
        # NOTE(review): assumes create_encoder_input returns flat, equal-length id/mask
        # sequences (so the results below are shaped (1, L)) — confirm against eval_utils.
        input_ids, attention_mask, _, _ = create_encoder_input(
            persona, partner, history,
            self.query_id, self.res_id, self.latent_id, self.persona_id, self.partner_id,
            self.sep_id, self.eos_id,
        )
        # Keep only the last max_input_tokens positions; anything earlier is discarded.
        tensor_input_ids = torch.tensor(input_ids, device=self.device)[-self.max_input_tokens:].unsqueeze(0)
        tensor_attention_mask = torch.tensor(attention_mask, device=self.device)[-self.max_input_tokens:].unsqueeze(0)
        return tensor_input_ids, tensor_attention_mask

    def _generate_one(self, conversation):
        """Generate and decode the reply for a single conversation (caller disables grad)."""
        tensor_input_ids, tensor_attention_mask = self.prepare_tensors(conversation)
        out_ids = self.model.generate(
            input_ids=tensor_input_ids,
            attention_mask=tensor_attention_mask,
            max_length=self.MAX_GENERATION_LENGTH,
            num_beams=self.NUM_BEAMS,
        )
        # Batch size is one, so the single decoded string is the reply.
        return self.tokenizer.batch_decode(
            out_ids,
            skip_special_tokens=True,
            spaces_between_special_tokens=False,
            clean_up_tokenization_spaces=False,
        )[0].strip()

    def generate_responses(self, test_data: List[Dict], api_responses: List[str]) -> Dict[str, Any]:
        """Generate one response for each conversation in a batch.

        You will be provided with a batch of up to 50 independent conversations.

        Input 1
        [
            {"persona B": ... "dialogue": ... }, # conversation 1  Turn 1
            ...
            {"persona B": ... "dialogue": ... }  # conversation 50 Turn 1
        ]

        Model should return 50 responses for Turn 1

        ...
        Input 7 (test_data)
        [
            {"persona B": ... "dialogue": ... }, # conversation 1  Turn 7
            ...
            {"persona B": ... "dialogue": ... }  # conversation 50 Turn 7
        ]
        Model should return 50 responses for Turn 7

        Args:
            test_data: the conversations for the current turn.
            api_responses: output strings returned by the API for each previous prompt.
                This is a list of blank strings on the first call. This agent never
                uses the API, so the argument is ignored.

        Note: turn numbers are NOT provided as input.

        Returns:
            A dictionary with the following keys:
            "use_api": True/False - cannot be True when using a GPU.
            "prompts": [ <list of the prompts that go as "messages" to the api> ]
                - every API call is independent and threads are not used.
            "max_generated_tokens": [ list of ints for the max generation limit on each call ]
                - the submission fails if the total generation limit is exceeded.
            "final_responses": [ <list of strings with the final responses> ]
                - only used when use_api is False.
        """
        # Inference only: disable autograd once for the whole batch.
        with torch.no_grad():
            final_responses = [self._generate_one(conversation) for conversation in tqdm(test_data)]

        self.turn_id = self.turn_id % 7 + 1  # Turn id goes from 1 to 7

        response = {
            "use_api": False,  # Cannot use API if GPU true is set in aicrowd.json
            "prompts": ["" for _ in test_data],  # Cannot use API if GPU true is set in aicrowd.json
            "max_generated_tokens": [0 for _ in test_data],
            "final_responses": final_responses
        }
        return response