# Adapted from - https://github.com/Silin159/PersonaChat-BART-PeaCoK/blob/main/eval_parlai.py
from typing import List, Dict

import torch
from transformers import BartTokenizer
from tqdm.auto import tqdm

from agents.bart_model.eval_utils import create_encoder_input
from agents.bart_model.modeling_bart import BartForConditionalGeneration


class BARTResponseAgent(object):
    def __init__(self):
        """ Load your model(s) here """
        torch.random.manual_seed(0)
        torch.cuda.manual_seed(0)

        tokenizer_path = 'agents/bart_model/checkpoints/checkpoints_persona_chat_peacok_random'
        checkpoint_path = 'agents/bart_model/checkpoints/checkpoints_persona_chat_peacok_random/checkpoint_epoch_30/'

        self.tokenizer = BartTokenizer.from_pretrained(tokenizer_path)
        self.model = BartForConditionalGeneration.from_pretrained(checkpoint_path)

        # IDs of the special marker tokens used to assemble the encoder input
        self.query_id, self.res_id, self.latent_id, self.persona_id, self.partner_id = \
            self.tokenizer.convert_tokens_to_ids([
                '<query>', '<response>', '<latent>', '<persona>', '<partner>'
            ])
        self.bos_id = self.tokenizer.bos_token_id
        self.eos_id = self.tokenizer.eos_token_id
        self.pad_id = self.tokenizer.pad_token_id
        self.sep_id = self.tokenizer.sep_token_id

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = self.model.to(self.device)
        self.model.eval()

        self.turn_id = 1
        print("Model loaded!!")

        self.max_input_tokens = 1024
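        # 1024 matches BART's maximum number of encoder positions;
        # prepare_tensors truncates longer inputs from the left so the most
        # recent dialogue history is kept.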

    def tokenize_conversation(self, conversation):
        def tokenize(text):
            return self.tokenizer.convert_tokens_to_ids(
                self.tokenizer.tokenize(text.strip(), add_prefix_space=True)
            )

        persona = [tokenize(line.strip()) for line in conversation['persona B']]
        # partner = [tokenize(line.strip()) for line in conversation['persona A']]
        partner = []  # Baseline not trained with the partner persona
        history = [tokenize(line['text'].strip()) for line in conversation['dialogue']]
        return persona, partner, history

    def prepare_tensors(self, conversation):
        persona, partner, history = self.tokenize_conversation(conversation)
        input_ids, attention_mask, _, _ = create_encoder_input(persona,
                                                               partner,
                                                               history,
                                                               self.query_id,
                                                               self.res_id,
                                                               self.latent_id,
                                                               self.persona_id,
                                                               self.partner_id,
                                                               self.sep_id,
                                                               self.eos_id)
        tensor_input_ids = torch.tensor(input_ids, device=self.device)[-self.max_input_tokens:].unsqueeze(0)
        tensor_attention_mask = torch.tensor(attention_mask, device=self.device)[-self.max_input_tokens:].unsqueeze(0)
        return tensor_input_ids, tensor_attention_mask
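
    # Both tensors returned by prepare_tensors have shape (1, seq_len): the
    # agent decodes one conversation at a time inside generate_responses
    # rather than batching the up-to-50 conversations it receives per call.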

    def generate_responses(self, test_data: List[Dict]) -> Dict:
        """
        You will be provided with a batch of up to 50 independent conversations.

        Input 1
        [
            {"persona A": ..., "persona B": ..., "dialogue": ...},  # conversation 1 Turn 1
            ...
            {"persona A": ..., "persona B": ..., "dialogue": ...}   # conversation 50 Turn 1
        ]
        Model should return 50 responses for Turn 1
        ...
        Input 7
        [
            {"persona A": ..., "persona B": ..., "dialogue": ...},  # conversation 1 Turn 7
            ...
            {"persona A": ..., "persona B": ..., "dialogue": ...}   # conversation 50 Turn 7
        ]
        Model should return 50 responses for Turn 7

        Note: Turn numbers will NOT be provided as input.

        Return a dictionary with the following format:
            "use_api": True/False - Note that this cannot be used when using GPU
            "prompts": [<list of the prompts that go as "content" to the API>] - Note that every call is independent and we don't use threads
            "max_generated_tokens": [<list of ints for the max generation limit on each call>] - Note that the submission will fail if the total generation limit is exceeded
            "final_responses": [<list of strings with the final responses>] - Only used when use_api is set to False
        """
        final_responses = []
        for conversation in tqdm(test_data):
            tensor_input_ids, tensor_attention_mask = self.prepare_tensors(conversation)
            with torch.no_grad():
                out_ids = self.model.generate(
                    input_ids=tensor_input_ids,
                    attention_mask=tensor_attention_mask,
                    max_length=50,
                    num_beams=2
                )
            response = self.tokenizer.batch_decode(
                out_ids,
                skip_special_tokens=True,
                spaces_between_special_tokens=False,
                clean_up_tokenization_spaces=False
            )[0].strip()
            final_responses.append(response)

        self.turn_id = self.turn_id % 7 + 1  # Turn id goes from 1 to 7

        response = {
            "use_api": False,                    # Ignored if GPU true is set in aicrowd.json
            "prompts": ["" for _ in test_data],  # Ignored if GPU true is set in aicrowd.json
            "final_responses": final_responses
        }
        return response
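

# ---------------------------------------------------------------------------
# Minimal local smoke test: a sketch, not part of the evaluation harness.
# The sample conversation and the __main__ guard below are illustrative
# assumptions; the real evaluator builds batches of up to 50 conversations
# per turn and calls generate_responses once per turn.
if __name__ == '__main__':
    agent = BARTResponseAgent()  # requires the checkpoints referenced above
    sample_batch = [{
        'persona A': ['I enjoy hiking on weekends.'],
        'persona B': ['I am a chef.', 'I live in Texas.'],
        'dialogue': [{'text': 'Hi! What do you do for a living?'}],
    }]
    result = agent.generate_responses(sample_batch)
    print(result['final_responses'])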