diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 133d0ae5429ec21f65ae9bbdb73a66429600c538..294ffab233458f1f3b98c18be50743ba65bd2d73 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -356,7 +356,7 @@ class RailEnv(Environment): # Perform step on all agents for i_agent in range(self.get_num_agents()): - self._step_agent(i_agent, action_dict_) + self._step_agent(i_agent, action_dict_.get(i_agent)) # Check for end of episode + set global reward to all rewards! if np.all([np.array_equal(agent.position, agent.target) for agent in self.agents]): @@ -384,7 +384,7 @@ class RailEnv(Environment): return self._get_observations(), self.rewards_dict, self.dones, info_dict - def _step_agent(self, i_agent, action_dict_: Dict[int, RailEnvActions]): + def _step_agent(self, i_agent, action: Optional[RailEnvActions] = None): """ Performs a step and step, start and stop penalty on a single agent in the following sub steps: - malfunction @@ -416,10 +416,8 @@ class RailEnv(Environment): # Is the agent at the beginning of the cell? Then, it can take an action. if agent.speed_data['position_fraction'] == 0.0: # No action has been supplied for this agent -> set DO_NOTHING as default - if i_agent not in action_dict_: + if action is None: action = RailEnvActions.DO_NOTHING - else: - action = action_dict_[i_agent] if action < 0 or action > len(RailEnvActions): print('ERROR: illegal action=', action,