From 3f3030e2babe7f5ae4352726464132f517614710 Mon Sep 17 00:00:00 2001 From: Dipam Chakraborty <dipamc77@gmail.com> Date: Thu, 7 Dec 2023 23:42:19 +0530 Subject: [PATCH] remove persona A --- agents/README.md | 8 ++++---- agents/bart_agent.py | 19 +++++++++---------- agents/dummy_agent.py | 8 ++++---- agents/prompt_agent.py | 8 ++++---- agents/user_config.py | 4 ++-- local_evaluation.py | 2 +- local_evaluation_with_api.py | 2 +- 7 files changed, 25 insertions(+), 26 deletions(-) diff --git a/agents/README.md b/agents/README.md index c087b33..58debb7 100644 --- a/agents/README.md +++ b/agents/README.md @@ -23,9 +23,9 @@ Each batch will be of the following format: ```python Input 1 [ - {"persona A": ..., "persona B": ... "dialogue": ... }, # conversation 1 Turn 1 + {"persona B": ... "dialogue": ... }, # conversation 1 Turn 1 ... - {"persona A": ..., "persona B": ... "dialogue": ... } # conversation 50 Turn 1 + {"persona B": ... "dialogue": ... } # conversation 50 Turn 1 ] Model should return 50 responses for Turn 1 @@ -33,9 +33,9 @@ Model should return 50 responses for Turn 1 ... Input 7 [ - {"persona A": ..., "persona B": ... "dialogue": ... }, # conversation 1 Turn 7 + {"persona B": ... "dialogue": ... }, # conversation 1 Turn 7 ... - {"persona A": ..., "persona B": ... "dialogue": ... } # conversation 50 Turn 7 + {"persona B": ... "dialogue": ... } # conversation 50 Turn 7 ] ``` diff --git a/agents/bart_agent.py b/agents/bart_agent.py index 4711b8d..cd702e3 100644 --- a/agents/bart_agent.py +++ b/agents/bart_agent.py @@ -48,8 +48,7 @@ class BARTResponseAgent(object): ) persona = [tokenize(line.strip()) for line in conversation['persona B']] - # partner = [tokenize(line.strip()) for line in conversation['persona A']] - partner = [] # Baseline not trained with the partner personaj + partner = [] history = [tokenize(line['text'].strip()) for line in conversation['dialogue']] return persona, partner, history @@ -58,12 +57,12 @@ class BARTResponseAgent(object): input_ids, attention_mask, _, _ = create_encoder_input(persona, partner, history, - self.query_id, - self.res_id, + self.query_id, + self.res_id, self.latent_id, - self.persona_id, + self.persona_id, self.partner_id, - self.sep_id, + self.sep_id, self.eos_id ) tensor_input_ids = torch.tensor(input_ids, device=self.device)[-self.max_input_tokens:].unsqueeze(0) @@ -77,9 +76,9 @@ class BARTResponseAgent(object): Input 1 [ - {"persona A": ..., "persona B": ... "dialogue": ... }, # conversation 1 Turn 1 + {"persona B": ... "dialogue": ... }, # conversation 1 Turn 1 ... - {"persona A": ..., "persona B": ... "dialogue": ... } # conversation 50 Turn 1 + {"persona B": ... "dialogue": ... } # conversation 50 Turn 1 ] Model should return 50 responses for Turn 1 @@ -87,9 +86,9 @@ class BARTResponseAgent(object): ... Input 7 [ - {"persona A": ..., "persona B": ... "dialogue": ... }, # conversation 1 Turn 7 + {"persona B": ... "dialogue": ... }, # conversation 1 Turn 7 ... - {"persona A": ..., "persona B": ... "dialogue": ... } # conversation 50 Turn 7 + {"persona B": ... "dialogue": ... } # conversation 50 Turn 7 ] Model should return 50 responses for Turn 7 diff --git a/agents/dummy_agent.py b/agents/dummy_agent.py index abdd9f2..027fb1b 100644 --- a/agents/dummy_agent.py +++ b/agents/dummy_agent.py @@ -11,9 +11,9 @@ class DummyResponseAgent(object): Input 1 [ - {"persona A": ..., "persona B": ... "dialogue": ... }, # conversation 1 Turn 1 + {"persona B": ... "dialogue": ... }, # conversation 1 Turn 1 ... - {"persona A": ..., "persona B": ... "dialogue": ... } # conversation 50 Turn 1 + {"persona B": ... "dialogue": ... } # conversation 50 Turn 1 ] Model should return 50 responses for Turn 1 @@ -21,9 +21,9 @@ class DummyResponseAgent(object): ... Input 7 [ - {"persona A": ..., "persona B": ... "dialogue": ... }, # conversation 1 Turn 7 + {"persona B": ... "dialogue": ... }, # conversation 1 Turn 7 ... - {"persona A": ..., "persona B": ... "dialogue": ... } # conversation 50 Turn 7 + {"persona B": ... "dialogue": ... } # conversation 50 Turn 7 ] Model should return 50 responses for Turn 7 diff --git a/agents/prompt_agent.py b/agents/prompt_agent.py index dd79368..d40e4ee 100644 --- a/agents/prompt_agent.py +++ b/agents/prompt_agent.py @@ -16,9 +16,9 @@ class DummyPromptAgent(object): Input 1 (test_data) [ - {"persona A": ..., "persona B": ... "dialogue": ... }, # conversation 1 Turn 1 + {"persona B": ... "dialogue": ... }, # conversation 1 Turn 1 ... - {"persona A": ..., "persona B": ... "dialogue": ... } # conversation 50 Turn 1 + {"persona B": ... "dialogue": ... } # conversation 50 Turn 1 ] Model should return 50 responses for Turn 1 @@ -26,9 +26,9 @@ class DummyPromptAgent(object): ... Input 7 (test_data) [ - {"persona A": ..., "persona B": ... "dialogue": ... }, # conversation 1 Turn 7 + {"persona B": ... "dialogue": ... }, # conversation 1 Turn 7 ... - {"persona A": ..., "persona B": ... "dialogue": ... } # conversation 50 Turn 7 + {"persona B": ... "dialogue": ... } # conversation 50 Turn 7 ] Model should return 50 responses for Turn 7 diff --git a/agents/user_config.py b/agents/user_config.py index f67f832..6a00cc6 100644 --- a/agents/user_config.py +++ b/agents/user_config.py @@ -3,5 +3,5 @@ from agents.prompt_agent import DummyPromptAgent from agents.bart_agent import BARTResponseAgent # UserAgent = DummyResponseAgent -UserAgent = DummyPromptAgent -# UserAgent = BARTResponseAgent \ No newline at end of file +# UserAgent = DummyPromptAgent +UserAgent = BARTResponseAgent \ No newline at end of file diff --git a/local_evaluation.py b/local_evaluation.py index e5b8834..97c99c2 100644 --- a/local_evaluation.py +++ b/local_evaluation.py @@ -26,7 +26,7 @@ def load_json_data(file_path: str, keys: List[str]) -> List[Dict]: def load_data(file_path: str) -> List[Dict]: # NOTE to participants: Gold reference will not available during actual evaluations - keys = ["persona A", "persona B", "dialogue", "gold_reference"] + keys = ["persona B", "dialogue", "gold_reference"] return load_json_data(file_path, keys) diff --git a/local_evaluation_with_api.py b/local_evaluation_with_api.py index b1de213..a339e3c 100644 --- a/local_evaluation_with_api.py +++ b/local_evaluation_with_api.py @@ -28,7 +28,7 @@ def load_json_data(file_path: str, keys: List[str]) -> List[Dict]: def load_data(file_path: str) -> List[Dict]: # NOTE to participants: Gold reference will not available during actual evaluations - keys = ["persona A", "persona B", "dialogue", "gold_reference"] + keys = ["persona B", "dialogue", "gold_reference"] return load_json_data(file_path, keys) class LLM_API: -- GitLab