Commit 03955d20 authored by Dipam Chakraborty's avatar Dipam Chakraborty

add bart baseline

parent 67d2c707
Showing with 51471 additions and 2 deletions
agents/bart_model/checkpoints/checkpoints_persona_chat_peacok_random/checkpoint_epoch_30/pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
# Adapted from - https://github.com/Silin159/PersonaChat-BART-PeaCoK/blob/main/eval_parlai.py
from typing import List, Dict

import torch
from transformers import BartTokenizer
from tqdm.auto import tqdm

from agents.bart_model.eval_utils import create_encoder_input
from agents.bart_model.modeling_bart import BartForConditionalGeneration


class BARTResponseAgent(object):
    def __init__(self):
        """ Load your model(s) here """
        torch.random.manual_seed(0)
        torch.cuda.manual_seed(0)

        tokenizer_path = 'agents/bart_model/checkpoints/checkpoints_persona_chat_peacok_random'
        checkpoint_path = 'agents/bart_model/checkpoints/checkpoints_persona_chat_peacok_random/checkpoint_epoch_30/'

        self.tokenizer = BartTokenizer.from_pretrained(tokenizer_path)
        self.model = BartForConditionalGeneration.from_pretrained(checkpoint_path)

        # Note: this checkpoint's added_tokens.json defines <query>, <response>,
        # <latent> and <persona> but not '<partner>', so '<partner>' resolves to
        # the unk token id. The baseline never encodes partner persona lines
        # (see tokenize_conversation), so only a single placeholder id is affected.
        self.query_id, self.res_id, self.latent_id, self.persona_id, self.partner_id = \
            self.tokenizer.convert_tokens_to_ids([
                '<query>', '<response>', '<latent>', '<persona>', '<partner>'
            ])
        self.bos_id = self.tokenizer.bos_token_id
        self.eos_id = self.tokenizer.eos_token_id
        self.pad_id = self.tokenizer.pad_token_id
        self.sep_id = self.tokenizer.sep_token_id

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = self.model.to(self.device)
        self.model.eval()

        self.turn_id = 1
        print("Model loaded!!")

        self.max_input_tokens = 1024

    def tokenize_conversation(self, conversation):
        def tokenize(text):
            return self.tokenizer.convert_tokens_to_ids(
                self.tokenizer.tokenize(text.strip(), add_prefix_space=True)
            )

        persona = [tokenize(line.strip()) for line in conversation['persona B']]
        # partner = [tokenize(line.strip()) for line in conversation['persona A']]
        partner = []  # Baseline not trained with the partner persona
        history = [tokenize(line['text'].strip()) for line in conversation['dialogue'][:-1]]
        return persona, partner, history

    def prepare_tensors(self, conversation):
        persona, partner, history = self.tokenize_conversation(conversation)
        input_ids, attention_mask, _, _ = create_encoder_input(
            persona,
            partner,
            history,
            self.query_id,
            self.res_id,
            self.latent_id,
            self.persona_id,
            self.partner_id,
            self.sep_id,
            self.eos_id
        )
        # Keep only the trailing max_input_tokens ids and add a batch dimension.
        tensor_input_ids = torch.tensor(input_ids, device=self.device)[-self.max_input_tokens:].unsqueeze(0)
        tensor_attention_mask = torch.tensor(attention_mask, device=self.device)[-self.max_input_tokens:].unsqueeze(0)
        return tensor_input_ids, tensor_attention_mask

    def generate_responses(self, test_data: List[Dict]) -> List[str]:
        """
        You will be provided with a batch of up to 50 independent conversations.
        Return a string for every conversation.

        Input 1
        [
            {"persona A": ..., "persona B": ..., "dialogue": ...},  # conversation 1, Turn 1
            ...
            {"persona A": ..., "persona B": ..., "dialogue": ...}   # conversation 50, Turn 1
        ]
        Model should return 50 responses for Turn 1.
        ...
        Input 7
        [
            {"persona A": ..., "persona B": ..., "dialogue": ...},  # conversation 1, Turn 7
            ...
            {"persona A": ..., "persona B": ..., "dialogue": ...}   # conversation 50, Turn 7
        ]
        Model should return 50 responses for Turn 7.

        Note: Turn numbers will NOT be provided as input.
        """
        all_responses = []
        for conversation in tqdm(test_data):
            tensor_input_ids, tensor_attention_mask = self.prepare_tensors(conversation)
            with torch.no_grad():
                out_ids = self.model.generate(
                    input_ids=tensor_input_ids,
                    attention_mask=tensor_attention_mask,
                    max_length=50,
                    num_beams=2
                )
            response = self.tokenizer.batch_decode(
                out_ids,
                skip_special_tokens=True,
                spaces_between_special_tokens=False,
                clean_up_tokenization_spaces=False
            )[0].strip()
            all_responses.append(response)
        self.turn_id = self.turn_id % 7 + 1  # Turn id goes from 1 to 7
        return all_responses
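

# Usage sketch (illustrative, an assumption rather than part of the original
# baseline): runs the agent on one toy conversation shaped like the docstring
# above. Requires the checkpoint files referenced in __init__ to be present.
if __name__ == "__main__":
    agent = BARTResponseAgent()
    sample = [{
        "persona A": ["I like hiking."],
        "persona B": ["I am a chef.", "I live in Paris."],
        # tokenize_conversation drops the final dialogue entry ([:-1]),
        # so the last turn here is only a placeholder.
        "dialogue": [
            {"text": "Hi! What do you do for a living?"},
            {"text": ""},
        ],
    }]
    print(agent.generate_responses(sample))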
File added
{"<query>": 50265, "<response>": 50266, "<latent>": 50267, "<persona>": 50268}
\ No newline at end of file
{
"_name_or_path": "facebook/bart-large",
"activation_dropout": 0.1,
"activation_function": "gelu",
"add_bias_logits": false,
"add_final_layer_norm": false,
"architectures": [
"BartForConditionalGeneration"
],
"attention_dropout": 0.1,
"bos_token_id": 0,
"classif_dropout": 0.1,
"classifier_dropout": 0.0,
"d_model": 1024,
"decoder_attention_heads": 16,
"decoder_ffn_dim": 4096,
"decoder_layerdrop": 0.0,
"decoder_layers": 12,
"decoder_start_token_id": 50266,
"dropout": 0.1,
"early_stopping": true,
"encoder_attention_heads": 16,
"encoder_ffn_dim": 4096,
"encoder_layerdrop": 0.0,
"encoder_layers": 12,
"eos_token_id": 2,
"forced_eos_token_id": 2,
"gradient_checkpointing": false,
"id2label": {
"0": "LABEL_0"
},
"init_std": 0.02,
"is_encoder_decoder": true,
"label2id": {
"LABEL_0": 0
},
"max_position_embeddings": 1024,
"model_type": "bart",
"no_repeat_ngram_size": 3,
"normalize_before": false,
"num_beams": 4,
"num_hidden_layers": 12,
"pad_token_id": 1,
"scale_embedding": false,
"task_specific_params": {
"summarization": {
"length_penalty": 1.0,
"max_length": 128,
"min_length": 12,
"num_beams": 4
},
"summarization_cnn": {
"length_penalty": 2.0,
"max_length": 142,
"min_length": 56,
"num_beams": 4
},
"summarization_xsum": {
"length_penalty": 1.0,
"max_length": 62,
"min_length": 11,
"num_beams": 6
}
},
"torch_dtype": "float32",
"transformers_version": "4.14.1",
"use_cache": true,
"vocab_size": 50269
}
FINAL F1: 0.2186
FINAL BLEU: 0.01491
\ No newline at end of file
{
"_name_or_path": "facebook/bart-large",
"activation_dropout": 0.1,
"activation_function": "gelu",
"add_bias_logits": false,
"add_final_layer_norm": false,
"architectures": [
"BartModel"
],
"attention_dropout": 0.1,
"bos_token_id": 0,
"classif_dropout": 0.1,
"classifier_dropout": 0.0,
"d_model": 1024,
"decoder_attention_heads": 16,
"decoder_ffn_dim": 4096,
"decoder_layerdrop": 0.0,
"decoder_layers": 12,
"decoder_start_token_id": 50266,
"dropout": 0.1,
"early_stopping": true,
"encoder_attention_heads": 16,
"encoder_ffn_dim": 4096,
"encoder_layerdrop": 0.0,
"encoder_layers": 12,
"eos_token_id": 2,
"forced_eos_token_id": 2,
"gradient_checkpointing": false,
"id2label": {
"0": "LABEL_0"
},
"init_std": 0.02,
"is_encoder_decoder": true,
"label2id": {
"LABEL_0": 0
},
"max_position_embeddings": 1024,
"model_type": "bart",
"no_repeat_ngram_size": 3,
"normalize_before": false,
"num_beams": 4,
"num_hidden_layers": 12,
"pad_token_id": 1,
"scale_embedding": false,
"task_specific_params": {
"summarization": {
"length_penalty": 1.0,
"max_length": 128,
"min_length": 12,
"num_beams": 4
},
"summarization_cnn": {
"length_penalty": 2.0,
"max_length": 142,
"min_length": 56,
"num_beams": 4
},
"summarization_xsum": {
"length_penalty": 1.0,
"max_length": 62,
"min_length": 11,
"num_beams": 6
}
},
"transformers_version": "4.14.1",
"use_cache": true,
"vocab_size": 50269
}
{"bos_token": {"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<unk>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "sep_token": {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "pad_token": {"content": "<pad>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "cls_token": {"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": true}, "additional_special_tokens": ["<query>", "<response>", "<latent>", "<persona>"]}
\ No newline at end of file
{"errors": "replace", "unk_token": {"content": "<unk>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "sep_token": {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "cls_token": {"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "pad_token": {"content": "<pad>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "model_max_length": 1024, "special_tokens_map_file": null, "tokenizer_file": "/home/cutura/.cache/huggingface/transformers/d94f53c8851dcda40774f97280e634b94b721a58e71bcc152b5f51d0d49a046a.fc9576039592f026ad76a1c231b89aee8668488c671dfbe6616bab2ed298d730", "name_or_path": "facebook/bart-large", "tokenizer_class": "BartTokenizer"}
\ No newline at end of file
File added
# Taken as is from - https://github.com/Silin159/PersonaChat-BART-PeaCoK/blob/main/eval_utils.py
import torch
import numpy as np
from torch.nn.utils.rnn import pad_sequence


def create_encoder_input(
    per,
    partner,
    history,
    query_id, res_id, latent_id, persona_id, partner_id, sep_id, eos_id
):
    encoder_input_ids = []

    # Persona segment: <latent> <persona> line1 <sep> line2 <sep> ...
    per_input_ids = [latent_id] + [persona_id]
    for x in per:
        per_input_ids += x + [sep_id]

    # Partner segment: <partner> line1 <sep> ... (empty for this baseline)
    partner_input_ids = [partner_id]
    for x in partner:
        partner_input_ids += x + [sep_id]

    encoder_input_ids += per_input_ids + partner_input_ids

    # Dialogue history alternates partner turns and own turns.
    for i in range(len(history)):
        if i % 2 == 0:
            encoder_input_ids += [query_id] + history[i] + [eos_id]
        else:
            encoder_input_ids += [res_id] + history[i] + [eos_id]

    attention_mask = [1] * len(encoder_input_ids)
    per_attention_mask = [1] * len(per_input_ids)
    return encoder_input_ids, attention_mask, per_input_ids, per_attention_mask
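
# Illustration (toy ids, an assumption for clarity): with per = [[10, 11]],
# partner = [] and history = [[20], [21]], the returned encoder_input_ids are
#   [latent_id, persona_id, 10, 11, sep_id, partner_id,
#    query_id, 20, eos_id, res_id, 21, eos_id]
# i.e. even-indexed history turns are tagged as partner queries (<query>) and
# odd-indexed turns as the model's own responses (<response>).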


def create_decoder_input(response_ids, res_id, eos_id, golden=None):
    assert golden is not None
    decoder_lmlabel = response_ids + [eos_id]
    decoder_input_ids = [res_id] + response_ids
    decoder_cls_index = [-100] * (len(decoder_lmlabel) - 1) + [eos_id]
    decoder_attention_mask = [1] * len(decoder_input_ids)
    if golden is False:
        # Non-golden (distractor) responses are masked out of the LM loss.
        decoder_lmlabel = [-100] * len(decoder_lmlabel)
    assert len(decoder_lmlabel) == len(decoder_input_ids)
    return decoder_lmlabel, decoder_input_ids, decoder_cls_index, decoder_attention_mask
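
# Illustration (toy ids, an assumption for clarity): with response_ids = [30, 31]
# and golden=True:
#   decoder_lmlabel   = [30, 31, eos_id]
#   decoder_input_ids = [res_id, 30, 31]
#   decoder_cls_index = [-100, -100, eos_id]
# With golden=False the labels become [-100, -100, -100], PyTorch's ignore
# index, so distractor candidates contribute nothing to the LM loss.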


def pad_dataset(dataset, pad_id):
    for item_name, item in dataset.items():
        if item_name == "input_ids" or item_name == "per_input_ids":
            item = pad_sequence([torch.from_numpy(np.array(x)) for x in item],
                                batch_first=True, padding_value=pad_id)
            dataset[item_name] = item
        elif item_name == "lmlabels":
            item = pad_sequence([torch.from_numpy(np.array(x)) for x in item],
                                batch_first=True, padding_value=-100)
            dataset[item_name] = item
        elif item_name in ("attention_mask", "decoder_attention_mask", "per_attention_mask"):
            item = pad_sequence([torch.from_numpy(np.array(x)) for x in item],
                                batch_first=True, padding_value=0)
            dataset[item_name] = item
        elif item_name == "decoder_input_ids":
            item = pad_sequence([torch.from_numpy(np.array(x)) for x in item],
                                batch_first=True, padding_value=pad_id)
            dataset[item_name] = item
        elif item_name == "clslabel":
            dataset[item_name] = torch.tensor(item).view(-1, 1)
        elif item_name == "cls_index":
            item = pad_sequence([torch.from_numpy(np.array(x)) for x in item],
                                batch_first=True, padding_value=-100)
            dataset[item_name] = item
    return dataset
\ No newline at end of file
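
# Usage sketch (illustrative): pads a ragged two-example batch in place.
#   batch = {"input_ids": [[5, 6, 7], [5]], "attention_mask": [[1, 1, 1], [1]]}
#   padded = pad_dataset(batch, pad_id=1)
#   padded["input_ids"]      -> tensor([[5, 6, 7], [5, 1, 1]])
#   padded["attention_mask"] -> tensor([[1, 1, 1], [1, 0, 0]])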
This diff is collapsed.
from agents.dummy_agent import DummyResponseAgent
from agents.bart_agent import BARTResponseAgent

# UserAgent = DummyResponseAgent
UserAgent = BARTResponseAgent
\ No newline at end of file
numpy
transformers
torch
accelerate
nltk
\ No newline at end of file