diff --git a/agents/README.md b/agents/README.md
index c087b330fdec7bb6c57a8d482ca79d68b451c0ee..58debb7b248ee266fe391357f4015a695eefecaa 100644
--- a/agents/README.md
+++ b/agents/README.md
@@ -23,9 +23,9 @@ Each batch will be of the following format:
 ```python
 Input 1
 [
-    {"persona A": ..., "persona B": ... "dialogue": ... }, # conversation 1 Turn 1
+    {"persona B": ... "dialogue": ... }, # conversation 1 Turn 1
     ...
-    {"persona A": ..., "persona B": ... "dialogue": ... } # conversation 50 Turn 1
+    {"persona B": ... "dialogue": ... } # conversation 50 Turn 1
 ]
 
 Model should return 50 responses for Turn 1
@@ -33,9 +33,9 @@ Model should return 50 responses for Turn 1
 ...
 Input 7
 [
-    {"persona A": ..., "persona B": ... "dialogue": ... }, # conversation 1 Turn 7
+    {"persona B": ... "dialogue": ... }, # conversation 1 Turn 7
     ...
-    {"persona A": ..., "persona B": ... "dialogue": ... } # conversation 50 Turn 7
+    {"persona B": ... "dialogue": ... } # conversation 50 Turn 7
 ]
 
 ```
diff --git a/agents/bart_agent.py b/agents/bart_agent.py
index 4711b8d32960a864f27112a8a3e166f2b391918e..cd702e3d0c1769955cfaab4cf6a1813825c73f35 100644
--- a/agents/bart_agent.py
+++ b/agents/bart_agent.py
@@ -48,8 +48,7 @@ class BARTResponseAgent(object):
         )
 
         persona = [tokenize(line.strip()) for line in conversation['persona B']]
-        # partner = [tokenize(line.strip()) for line in conversation['persona A']]
-        partner = [] # Baseline not trained with the partner personaj
+        partner = []
         history = [tokenize(line['text'].strip()) for line in conversation['dialogue']]
         return persona, partner, history
 
@@ -58,12 +57,12 @@ class BARTResponseAgent(object):
         input_ids, attention_mask, _, _ = create_encoder_input(persona,
                                                partner,
                                                history,
-                                               self.query_id, 
-                                               self.res_id, 
+                                               self.query_id,
+                                               self.res_id,
                                                self.latent_id,
-                                               self.persona_id, 
+                                               self.persona_id,
                                                self.partner_id,
-                                               self.sep_id, 
+                                               self.sep_id,
                                                self.eos_id
                                                )
         tensor_input_ids = torch.tensor(input_ids, device=self.device)[-self.max_input_tokens:].unsqueeze(0)
@@ -77,9 +76,9 @@ class BARTResponseAgent(object):
 
         Input 1
         [
-            {"persona A": ..., "persona B": ... "dialogue": ... }, # conversation 1 Turn 1
+            {"persona B": ... "dialogue": ... }, # conversation 1 Turn 1
             ...
-            {"persona A": ..., "persona B": ... "dialogue": ... } # conversation 50 Turn 1
+            {"persona B": ... "dialogue": ... } # conversation 50 Turn 1
         ]
 
         Model should return 50 responses for Turn 1
@@ -87,9 +86,9 @@ class BARTResponseAgent(object):
         ...
         Input 7
         [
-            {"persona A": ..., "persona B": ... "dialogue": ... }, # conversation 1 Turn 7
+            {"persona B": ... "dialogue": ... }, # conversation 1 Turn 7
             ...
-            {"persona A": ..., "persona B": ... "dialogue": ... } # conversation 50 Turn 7
+            {"persona B": ... "dialogue": ... } # conversation 50 Turn 7
         ]
 
         Model should return 50 responses for Turn 7
diff --git a/agents/dummy_agent.py b/agents/dummy_agent.py
index abdd9f28bf999b526566cdb26f5afd03c635aad7..027fb1b2054762d359da76e8bb18f09368877bef 100644
--- a/agents/dummy_agent.py
+++ b/agents/dummy_agent.py
@@ -11,9 +11,9 @@ class DummyResponseAgent(object):
 
         Input 1
         [
-            {"persona A": ..., "persona B": ... "dialogue": ... }, # conversation 1 Turn 1
+            {"persona B": ... "dialogue": ... }, # conversation 1 Turn 1
             ...
-            {"persona A": ..., "persona B": ... "dialogue": ... } # conversation 50 Turn 1
+            {"persona B": ... "dialogue": ... } # conversation 50 Turn 1
         ]
 
         Model should return 50 responses for Turn 1
@@ -21,9 +21,9 @@ class DummyResponseAgent(object):
         ...
         Input 7
         [
-            {"persona A": ..., "persona B": ... "dialogue": ... }, # conversation 1 Turn 7
+            {"persona B": ... "dialogue": ... }, # conversation 1 Turn 7
             ...
-            {"persona A": ..., "persona B": ... "dialogue": ... } # conversation 50 Turn 7
+            {"persona B": ... "dialogue": ... } # conversation 50 Turn 7
         ]
 
         Model should return 50 responses for Turn 7
diff --git a/agents/prompt_agent.py b/agents/prompt_agent.py
index dd793682ac4cdbeb1d195f0d2e8668d67919e78d..d40e4eed924dc5cad0776cc99a95f1b1fd8029f2 100644
--- a/agents/prompt_agent.py
+++ b/agents/prompt_agent.py
@@ -16,9 +16,9 @@ class DummyPromptAgent(object):
 
         Input 1 (test_data)
         [
-            {"persona A": ..., "persona B": ... "dialogue": ... }, # conversation 1 Turn 1
+            {"persona B": ... "dialogue": ... }, # conversation 1 Turn 1
             ...
-            {"persona A": ..., "persona B": ... "dialogue": ... } # conversation 50 Turn 1
+            {"persona B": ... "dialogue": ... } # conversation 50 Turn 1
         ]
 
         Model should return 50 responses for Turn 1
@@ -26,9 +26,9 @@ class DummyPromptAgent(object):
         ...
         Input 7 (test_data)
         [
-            {"persona A": ..., "persona B": ... "dialogue": ... }, # conversation 1 Turn 7
+            {"persona B": ... "dialogue": ... }, # conversation 1 Turn 7
             ...
-            {"persona A": ..., "persona B": ... "dialogue": ... } # conversation 50 Turn 7
+            {"persona B": ... "dialogue": ... } # conversation 50 Turn 7
         ]
 
         Model should return 50 responses for Turn 7
diff --git a/agents/user_config.py b/agents/user_config.py
index f67f8323c33ec1a9ce0025a4543b50af0c47b186..6a00cc611827a58a9a4dae07ffea671bd5db0c36 100644
--- a/agents/user_config.py
+++ b/agents/user_config.py
@@ -3,5 +3,5 @@ from agents.prompt_agent import DummyPromptAgent
 from agents.bart_agent import BARTResponseAgent
 
 # UserAgent = DummyResponseAgent
-UserAgent = DummyPromptAgent
-# UserAgent = BARTResponseAgent
\ No newline at end of file
+# UserAgent = DummyPromptAgent
+UserAgent = BARTResponseAgent
\ No newline at end of file
diff --git a/local_evaluation.py b/local_evaluation.py
index e5b8834cb544dd56164de8b614d61f5ddb88871f..97c99c2d50807a97e120757517b377598803cee4 100644
--- a/local_evaluation.py
+++ b/local_evaluation.py
@@ -26,7 +26,7 @@ def load_json_data(file_path: str, keys: List[str]) -> List[Dict]:
 
 def load_data(file_path: str) -> List[Dict]:
     # NOTE to participants: Gold reference will not available during actual evaluations
-    keys = ["persona A", "persona B", "dialogue", "gold_reference"]
+    keys = ["persona B", "dialogue", "gold_reference"]
     return load_json_data(file_path, keys)
 
 
diff --git a/local_evaluation_with_api.py b/local_evaluation_with_api.py
index b1de213e001f02e1791092b6e39377f93fafa675..a339e3c25b3ee994301b072083245406fac39fa6 100644
--- a/local_evaluation_with_api.py
+++ b/local_evaluation_with_api.py
@@ -28,7 +28,7 @@ def load_json_data(file_path: str, keys: List[str]) -> List[Dict]:
 
 def load_data(file_path: str) -> List[Dict]:
     # NOTE to participants: Gold reference will not available during actual evaluations
-    keys = ["persona A", "persona B", "dialogue", "gold_reference"]
+    keys = ["persona B", "dialogue", "gold_reference"]
     return load_json_data(file_path, keys)
 
 class LLM_API:
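Illustrative note (not part of the patch): after these changes each conversation dict carries only the responding agent's own persona. The sketch below shows the post-change input shape; the key names ("persona B", "dialogue", "gold_reference") and the list/dict structure come from the docstrings, bart_agent.py, and load_data() above, while the persona lines and dialogue text are invented placeholders.

```python
# Minimal sketch of one conversation dict in the post-change format.
# Key names follow the updated docstrings and load_data(); values are made up.
conversation = {
    "persona B": [                      # the responding agent's own persona lines
        "I like to garden.",
        "I have two dogs.",
    ],
    "dialogue": [                       # turn history; bart_agent.py reads line['text']
        {"text": "Hi! What do you do on weekends?"},
    ],
}

# Keys that local_evaluation.py now loads; "gold_reference" is only present locally,
# not during actual evaluations (see the NOTE in load_data above).
keys = ["persona B", "dialogue", "gold_reference"]

# A batch is a list of such dicts (50 conversations per turn, per the README).
batch = [conversation]
```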