From 718646755e32dea55f645360f093478daa37469f Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Thu, 19 Sep 2019 09:30:27 +0200
Subject: [PATCH] #178 pass action instead of action_dict to step_agent

---
 flatland/envs/rail_env.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 133d0ae5..294ffab2 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -356,7 +356,7 @@ class RailEnv(Environment):
 
         # Perform step on all agents
         for i_agent in range(self.get_num_agents()):
-            self._step_agent(i_agent, action_dict_)
+            self._step_agent(i_agent, action_dict_.get(i_agent))
 
         # Check for end of episode + set global reward to all rewards!
         if np.all([np.array_equal(agent.position, agent.target) for agent in self.agents]):
@@ -384,7 +384,7 @@ class RailEnv(Environment):
 
         return self._get_observations(), self.rewards_dict, self.dones, info_dict
 
-    def _step_agent(self, i_agent, action_dict_: Dict[int, RailEnvActions]):
+    def _step_agent(self, i_agent, action: Optional[RailEnvActions] = None):
         """
         Performs a step and step, start and stop penalty on a single agent in the following sub steps:
         - malfunction
@@ -416,10 +416,8 @@ class RailEnv(Environment):
         # Is the agent at the beginning of the cell? Then, it can take an action.
         if agent.speed_data['position_fraction'] == 0.0:
             # No action has been supplied for this agent -> set DO_NOTHING as default
-            if i_agent not in action_dict_:
+            if action is None:
                 action = RailEnvActions.DO_NOTHING
-            else:
-                action = action_dict_[i_agent]
 
             if action < 0 or action > len(RailEnvActions):
                 print('ERROR: illegal action=', action,
-- 
GitLab