# Copyright (c) Sony Group Corporation.
# Released under the MIT license
# Taken as is from - https://github.com/Silin159/PersonaChat-BART-PeaCoK/blob/main/eval_utils.py
from torch.nn.utils.rnn import pad_sequence
import torch
import numpy as np
def create_encoder_input(
    per,
    partner,
    history,
    query_id, res_id, latent_id, persona_id, partner_id, sep_id, eos_id
):
    """Flatten persona, partner persona and dialogue history into encoder ids.

    Persona sentences are prefixed with [latent][persona] and partner
    sentences with [partner]; each sentence is terminated by `sep_id`.
    History turns alternate speaker markers: even indices get `query_id`,
    odd indices get `res_id`, each turn terminated by `eos_id`.

    Returns (encoder_input_ids, attention_mask, per_input_ids,
    per_attention_mask); both masks are all-ones of matching length.
    """
    per_input_ids = [latent_id, persona_id]
    for sentence in per:
        per_input_ids.extend(sentence)
        per_input_ids.append(sep_id)

    partner_ids = [partner_id]
    for sentence in partner:
        partner_ids.extend(sentence)
        partner_ids.append(sep_id)

    # New list: per_input_ids is returned on its own and must not alias it.
    encoder_input_ids = per_input_ids + partner_ids
    for turn_idx, turn in enumerate(history):
        marker = query_id if turn_idx % 2 == 0 else res_id
        encoder_input_ids.extend([marker] + turn + [eos_id])

    attention_mask = [1] * len(encoder_input_ids)
    per_attention_mask = [1] * len(per_input_ids)
    return encoder_input_ids, attention_mask, per_input_ids, per_attention_mask
def create_decoder_input(response_ids, res_id, eos_id, golden=None):
    """Build decoder-side label/input/mask lists for one candidate response.

    Args:
        response_ids: token ids of the response (no special tokens).
        res_id: response-marker token prepended to the decoder input.
        eos_id: end-of-sequence token appended to the labels.
        golden: True for the gold response, False for a distractor whose
            LM labels must be fully masked. Must be passed explicitly.

    Returns:
        (decoder_lmlabel, decoder_input_ids, decoder_cls_index,
        decoder_attention_mask), all equal-length lists.
    """
    # PEP 8: compare to None/False with identity, not (in)equality.
    assert golden is not None
    decoder_lmlabel = response_ids + [eos_id]
    decoder_input_ids = [res_id] + response_ids
    # Ignore (-100) every position except the final eos slot, which is used
    # for the classification head.
    decoder_cls_index = [-100] * (len(decoder_lmlabel) - 1) + [eos_id]
    decoder_attention_mask = [1] * len(decoder_input_ids)
    if golden is False:
        # Distractor response: mask all LM labels so it contributes no LM loss.
        decoder_lmlabel = [-100] * len(decoder_lmlabel)
    assert len(decoder_lmlabel) == len(decoder_input_ids)
    return decoder_lmlabel, decoder_input_ids, decoder_cls_index, decoder_attention_mask
def pad_dataset(dataset, pad_id):
    """Pad the variable-length fields of `dataset` in place into tensors.

    Each listed field is a list of per-example id lists; it is replaced by a
    (batch, max_len) LongTensor padded with a field-specific value. The
    "clslabel" field (one label per example) becomes a (batch, 1) tensor.
    Unrecognized fields are left untouched. Mutates `dataset`; returns None
    (stdlib convention for in-place operations).

    Args:
        dataset: dict mapping field name -> list of sequences (or labels).
        pad_id: padding token id used for the *_input_ids fields.
    """
    # Field -> padding value. Label fields pad with -100 (ignored by the
    # loss); mask fields pad with 0 (not attended); id fields with pad_id.
    pad_values = {
        "input_ids": pad_id,
        "per_input_ids": pad_id,
        "decoder_input_ids": pad_id,
        "lmlabels": -100,
        "cls_index": -100,
        "attention_mask": 0,
        "decoder_attention_mask": 0,
        "per_attention_mask": 0,
    }
    for name, item in dataset.items():
        if name == "clslabel":
            # One class label per example -> column vector (batch, 1).
            dataset[name] = torch.tensor(item).view(-1, 1)
        elif name in pad_values:
            dataset[name] = pad_sequence(
                [torch.from_numpy(np.array(seq)) for seq in item],
                batch_first=True,
                padding_value=pad_values[name],
            )