diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 133d0ae5429ec21f65ae9bbdb73a66429600c538..294ffab233458f1f3b98c18be50743ba65bd2d73 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -356,7 +356,7 @@ class RailEnv(Environment):
 
         # Perform step on all agents
         for i_agent in range(self.get_num_agents()):
-            self._step_agent(i_agent, action_dict_)
+            self._step_agent(i_agent, action_dict_.get(i_agent))
 
         # Check for end of episode + set global reward to all rewards!
         if np.all([np.array_equal(agent.position, agent.target) for agent in self.agents]):
@@ -384,7 +384,7 @@ class RailEnv(Environment):
 
         return self._get_observations(), self.rewards_dict, self.dones, info_dict
 
-    def _step_agent(self, i_agent, action_dict_: Dict[int, RailEnvActions]):
+    def _step_agent(self, i_agent, action: Optional[RailEnvActions] = None):
         """
         Performs a step and step, start and stop penalty on a single agent in the following sub steps:
         - malfunction
@@ -416,10 +416,8 @@ class RailEnv(Environment):
         # Is the agent at the beginning of the cell? Then, it can take an action.
         if agent.speed_data['position_fraction'] == 0.0:
             # No action has been supplied for this agent -> set DO_NOTHING as default
-            if i_agent not in action_dict_:
+            if action is None:
                 action = RailEnvActions.DO_NOTHING
-            else:
-                action = action_dict_[i_agent]
 
             if action < 0 or action > len(RailEnvActions):
                 print('ERROR: illegal action=', action,