From 3f3030e2babe7f5ae4352726464132f517614710 Mon Sep 17 00:00:00 2001
From: Dipam Chakraborty <dipamc77@gmail.com>
Date: Thu, 7 Dec 2023 23:42:19 +0530
Subject: [PATCH] Remove persona A

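Drop the unused partner persona ("persona A") from the documented
batch format in agents/README.md and the agent docstrings, from the
BART baseline's input construction (the baseline was not trained with
the partner persona), and from the keys loaded by local_evaluation.py
and local_evaluation_with_api.py. Also switch the default UserAgent
from DummyPromptAgent to BARTResponseAgent.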
---
 agents/README.md             |  8 ++++----
 agents/bart_agent.py         | 19 +++++++++----------
 agents/dummy_agent.py        |  8 ++++----
 agents/prompt_agent.py       |  8 ++++----
 agents/user_config.py        |  4 ++--
 local_evaluation.py          |  2 +-
 local_evaluation_with_api.py |  2 +-
 7 files changed, 25 insertions(+), 26 deletions(-)

diff --git a/agents/README.md b/agents/README.md
index c087b33..58debb7 100644
--- a/agents/README.md
+++ b/agents/README.md
@@ -23,9 +23,9 @@ Each batch will be of the following format:
 ```python
 Input 1
 [
-    {"persona A": ..., "persona B": ... "dialogue": ... }, # conversation 1  Turn 1
+    {"persona B": ... "dialogue": ... }, # conversation 1  Turn 1
     ...
-    {"persona A": ..., "persona B": ... "dialogue": ... }  # conversation 50 Turn 1
+    {"persona B": ... "dialogue": ... }  # conversation 50 Turn 1
 ]
 
 Model should return 50 responses for Turn 1
@@ -33,9 +33,9 @@ Model should return 50 responses for Turn 1
 ...
 Input 7
 [
-    {"persona A": ..., "persona B": ... "dialogue": ... }, # conversation 1  Turn 7
+    {"persona B": ... "dialogue": ... }, # conversation 1  Turn 7
     ...
-    {"persona A": ..., "persona B": ... "dialogue": ... }  # conversation 50 Turn 7
+    {"persona B": ... "dialogue": ... }  # conversation 50 Turn 7
 ]
 ```
 
diff --git a/agents/bart_agent.py b/agents/bart_agent.py
index 4711b8d..cd702e3 100644
--- a/agents/bart_agent.py
+++ b/agents/bart_agent.py
@@ -48,8 +48,7 @@ class BARTResponseAgent(object):
             )
        
         persona = [tokenize(line.strip()) for line in conversation['persona B']]
-        # partner = [tokenize(line.strip()) for line in conversation['persona A']]
-        partner = [] # Baseline not trained with the partner personaj
+        partner = []  # baseline was not trained with the partner persona
         history = [tokenize(line['text'].strip()) for line in conversation['dialogue']]
         return persona, partner, history
    
@@ -58,12 +57,12 @@ class BARTResponseAgent(object):
         input_ids, attention_mask, _, _ = create_encoder_input(persona,
             partner,
             history,
-            self.query_id, 
-            self.res_id, 
+            self.query_id,
+            self.res_id,
             self.latent_id,
-            self.persona_id, 
+            self.persona_id,
             self.partner_id,
-            self.sep_id, 
+            self.sep_id,
             self.eos_id
         )
         tensor_input_ids = torch.tensor(input_ids, device=self.device)[-self.max_input_tokens:].unsqueeze(0)
@@ -77,9 +76,9 @@ class BARTResponseAgent(object):
 
         Input 1
         [
-            {"persona A": ..., "persona B": ... "dialogue": ... }, # conversation 1  Turn 1
+            {"persona B": ... "dialogue": ... }, # conversation 1  Turn 1
             ...
-            {"persona A": ..., "persona B": ... "dialogue": ... }  # conversation 50 Turn 1
+            {"persona B": ... "dialogue": ... }  # conversation 50 Turn 1
         ]
 
         Model should return 50 responses for Turn 1
@@ -87,9 +86,9 @@ class BARTResponseAgent(object):
         ...
         Input 7
         [
-            {"persona A": ..., "persona B": ... "dialogue": ... }, # conversation 1  Turn 7
+            {"persona B": ... "dialogue": ... }, # conversation 1  Turn 7
             ...
-            {"persona A": ..., "persona B": ... "dialogue": ... }  # conversation 50 Turn 7
+            {"persona B": ... "dialogue": ... }  # conversation 50 Turn 7
         ]
         Model should return 50 responses for Turn 7
 
diff --git a/agents/dummy_agent.py b/agents/dummy_agent.py
index abdd9f2..027fb1b 100644
--- a/agents/dummy_agent.py
+++ b/agents/dummy_agent.py
@@ -11,9 +11,9 @@ class DummyResponseAgent(object):
         
         Input 1
         [
-            {"persona A": ..., "persona B": ... "dialogue": ... }, # conversation 1  Turn 1
+            {"persona B": ... "dialogue": ... }, # conversation 1  Turn 1
             ...
-            {"persona A": ..., "persona B": ... "dialogue": ... }  # conversation 50 Turn 1
+            {"persona B": ... "dialogue": ... }  # conversation 50 Turn 1
         ]
 
         Model should return 50 responses for Turn 1
@@ -21,9 +21,9 @@ class DummyResponseAgent(object):
         ...
         Input 7
         [
-            {"persona A": ..., "persona B": ... "dialogue": ... }, # conversation 1  Turn 7
+            {"persona B": ... "dialogue": ... }, # conversation 1  Turn 7
             ...
-            {"persona A": ..., "persona B": ... "dialogue": ... }  # conversation 50 Turn 7
+            {"persona B": ... "dialogue": ... }  # conversation 50 Turn 7
         ]
         Model should return 50 responses for Turn 7
 
diff --git a/agents/prompt_agent.py b/agents/prompt_agent.py
index dd79368..d40e4ee 100644
--- a/agents/prompt_agent.py
+++ b/agents/prompt_agent.py
@@ -16,9 +16,9 @@ class DummyPromptAgent(object):
         
         Input 1 (test_data)
         [
-            {"persona A": ..., "persona B": ... "dialogue": ... }, # conversation 1  Turn 1
+            {"persona B": ... "dialogue": ... }, # conversation 1  Turn 1
             ...
-            {"persona A": ..., "persona B": ... "dialogue": ... }  # conversation 50 Turn 1
+            {"persona B": ... "dialogue": ... }  # conversation 50 Turn 1
         ]
 
         Model should return 50 responses for Turn 1
@@ -26,9 +26,9 @@ class DummyPromptAgent(object):
         ...
         Input 7 (test_data)
         [
-            {"persona A": ..., "persona B": ... "dialogue": ... }, # conversation 1  Turn 7
+            {"persona B": ... "dialogue": ... }, # conversation 1  Turn 7
             ...
-            {"persona A": ..., "persona B": ... "dialogue": ... }  # conversation 50 Turn 7
+            {"persona B": ... "dialogue": ... }  # conversation 50 Turn 7
         ]
         Model should return 50 responses for Turn 7
 
diff --git a/agents/user_config.py b/agents/user_config.py
index f67f832..6a00cc6 100644
--- a/agents/user_config.py
+++ b/agents/user_config.py
@@ -3,5 +3,5 @@ from agents.prompt_agent import DummyPromptAgent
 from agents.bart_agent import BARTResponseAgent
 
 # UserAgent = DummyResponseAgent
-UserAgent = DummyPromptAgent
-# UserAgent = BARTResponseAgent
\ No newline at end of file
+# UserAgent = DummyPromptAgent
+UserAgent = BARTResponseAgent
\ No newline at end of file
diff --git a/local_evaluation.py b/local_evaluation.py
index e5b8834..97c99c2 100644
--- a/local_evaluation.py
+++ b/local_evaluation.py
@@ -26,7 +26,7 @@ def load_json_data(file_path: str, keys: List[str]) -> List[Dict]:
 
 def load_data(file_path: str) -> List[Dict]:
     # NOTE to participants: Gold reference will not available during actual evaluations
-    keys = ["persona A", "persona B", "dialogue", "gold_reference"] 
+    keys = ["persona B", "dialogue", "gold_reference"] 
     return load_json_data(file_path, keys)
 
 
diff --git a/local_evaluation_with_api.py b/local_evaluation_with_api.py
index b1de213..a339e3c 100644
--- a/local_evaluation_with_api.py
+++ b/local_evaluation_with_api.py
@@ -28,7 +28,7 @@ def load_json_data(file_path: str, keys: List[str]) -> List[Dict]:
 
 def load_data(file_path: str) -> List[Dict]:
     # NOTE to participants: Gold reference will not available during actual evaluations
-    keys = ["persona A", "persona B", "dialogue", "gold_reference"]
+    keys = ["persona B", "dialogue", "gold_reference"]
     return load_json_data(file_path, keys)
 
 class LLM_API:
-- 
GitLab