From d4efd7847331df416b5c53bba85181015b154cd8 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Mon, 18 Jan 2021 14:59:39 +0100
Subject: [PATCH] action_size (full|reduced) and log writing
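
Add an --action_size command line option to multi_agent_training.py so the
learning agents (DDDQN, PPO, ...) can be trained either on Flatland's full
action space (5 actions) or on a reduced one (4 actions). The new helpers
set_action_size_full() / set_action_size_reduced() store the choice in
utils/agent_action_config.py, get_action_size() returns it, and
map_action_policy() replaces the previous hard-coded "- 1" when feeding
stored transitions to the policy. The TensorBoard SummaryWriter now gets a
comment suffix built from the policy name and the action size, so runs with
different configurations end up in distinguishable log directories.
DeadLockAvoidanceAgent applies epsilon-greedy exploration only when the new
enable_eps flag is set, and the default training policy is now
DeadLockAvoidance.

Example invocation (one possible configuration):

    python reinforcement_learning/multi_agent_training.py --policy DDDQN --action_size reduced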

---
 .../multi_agent_training.py                   | 15 ++++++---
 utils/agent_action_config.py                  | 33 +++++++++++++++----
 utils/dead_lock_avoidance_agent.py            |  9 ++++--
 3 files changed, 44 insertions(+), 13 deletions(-)
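
A minimal usage sketch of the new helpers in utils/agent_action_config.py,
assuming the patch is applied and the reduced action space is selected
(values in the comments follow from the code below):

    from utils.agent_action_config import (
        get_action_size,
        get_flatland_full_action_size,
        map_action_policy,
        set_action_size_reduced,
    )

    set_action_size_reduced()
    print(get_action_size())                # 4: action space seen by DDDQN/PPO
    print(get_flatland_full_action_size())  # 5: Flatland's native action space
    print(map_action_policy(2))             # 1: Flatland action shifted into the reduced range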

diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index 2d513ed..2e74d68 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -22,7 +22,8 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy
 from reinforcement_learning.deadlockavoidance_with_decision_agent import DeadLockAvoidanceWithDecisionAgent
 from reinforcement_learning.multi_decision_agent import MultiDecisionAgent
 from reinforcement_learning.ppo_agent import PPOPolicy
-from utils.agent_action_config import get_flatland_full_action_size, get_action_size, map_actions, map_action
+from utils.agent_action_config import get_flatland_full_action_size, get_action_size, map_actions, map_action, \
+    set_action_size_reduced, set_action_size_full, map_action_policy
 from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
 
 base_dir = Path(__file__).resolve().parent.parent
@@ -169,6 +170,11 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     scores_window = deque(maxlen=checkpoint_interval)  # todo smooth when rendering instead
     completion_window = deque(maxlen=checkpoint_interval)
 
+    if train_params.action_size == "reduced":
+        set_action_size_reduced()
+    else:
+        set_action_size_full()
+
     # Double Dueling DQN policy
     if train_params.policy == "DDDQN":
         policy = DDDQNPolicy(state_size, get_action_size(), train_params)
@@ -212,7 +218,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                 hdd.free / (2 ** 30)))
 
     # TensorBoard writer
-    writer = SummaryWriter()
+    writer = SummaryWriter(comment="_" + train_params.policy + "_" + train_params.action_size)
 
     training_timer = Timer()
     training_timer.start()
@@ -313,7 +319,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                     learn_timer.start()
                     policy.step(agent_handle,
                                 agent_prev_obs[agent_handle],
-                                agent_prev_action[agent_handle] - 1,
+                                map_action_policy(agent_prev_action[agent_handle]),
                                 all_rewards[agent_handle],
                                 agent_obs[agent_handle],
                                 done[agent_handle])
@@ -540,7 +546,8 @@ if __name__ == "__main__":
     parser.add_argument("--max_depth", help="max depth", default=2, type=int)
     parser.add_argument("--policy",
                         help="policy name [DDDQN, PPO, DeadLockAvoidance, DeadLockAvoidanceWithDecision, MultiDecision]",
-                        default="ppo")
+                        default="DeadLockAvoidance")
+    parser.add_argument("--action_size", help="define the action size [reduced,full]", default="full", type=str)
 
     training_params = parser.parse_args()
     env_params = [
diff --git a/utils/agent_action_config.py b/utils/agent_action_config.py
index 4c1f83f..9c2af58 100644
--- a/utils/agent_action_config.py
+++ b/utils/agent_action_config.py
@@ -1,26 +1,47 @@
 from flatland.envs.rail_env import RailEnvActions
 
+# global action size
+global _agent_action_config_action_size
+_agent_action_config_action_size = 5
+
 
 def get_flatland_full_action_size():
     # The action space of flatland is 5 discrete actions
     return 5
 
 
+def set_action_size_full():
+    # The agents (DDDQN, PPO, ...) have this action space
+    global _agent_action_config_action_size
+    _agent_action_config_action_size = 5
+
+
+def set_action_size_reduced():
+    # The agents (DDDQN, PPO, ...) have this action space
+    global _agent_action_config_action_size
+    _agent_action_config_action_size = 4
+
+
 def get_action_size():
     # The agents (DDDQN, PPO, ... ) have this actions space
-    return 4
+    return _agent_action_config_action_size
 
 
 def map_actions(actions):
     # Map the
-    if get_action_size() == get_flatland_full_action_size():
-        return actions
-    for key in actions:
-        value = actions.get(key, 0)
-        actions.update({key: map_action(value)})
+    if get_action_size() != get_flatland_full_action_size():
+        for key in actions:
+            value = actions.get(key, 0)
+            actions.update({key: map_action(value)})
     return actions
 
 
+def map_action_policy(action):
+    if get_action_size() != get_flatland_full_action_size():
+        return action - 1
+    return action
+
+
 def map_action(action):
     if get_action_size() == get_flatland_full_action_size():
         return action
diff --git a/utils/dead_lock_avoidance_agent.py b/utils/dead_lock_avoidance_agent.py
index ac7fd0c..4c4c903 100644
--- a/utils/dead_lock_avoidance_agent.py
+++ b/utils/dead_lock_avoidance_agent.py
@@ -67,7 +67,8 @@ class DeadlockAvoidanceShortestDistanceWalker(ShortestDistanceWalker):
         self.full_shortest_distance_agent_map[(handle, position[0], position[1])] = 1
 
 class DeadLockAvoidanceAgent(HeuristicPolicy):
-    def __init__(self, env: RailEnv, action_size, show_debug_plot=False):
+    def __init__(self, env: RailEnv, action_size, enable_eps=False, show_debug_plot=False):
+        print(">> DeadLockAvoidance")
         self.env = env
         self.memory = DummyMemory()
         self.loss = 0
@@ -76,14 +77,16 @@ class DeadLockAvoidanceAgent(HeuristicPolicy):
         self.agent_can_move_value = {}
         self.switches = {}
         self.show_debug_plot = show_debug_plot
+        self.enable_eps = enable_eps
 
     def step(self, handle, state, action, reward, next_state, done):
         pass
 
     def act(self, handle, state, eps=0.):
         # Epsilon-greedy action selection
-        if np.random.random() < eps:
-            return np.random.choice(np.arange(self.action_size))
+        if self.enable_eps:
+            if np.random.random() < eps:
+                return np.random.choice(np.arange(self.action_size))
 
         # agent = self.env.agents[state[0]]
         check = self.agent_can_move.get(handle, None)
-- 
GitLab