diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 3bbc1ab092a92d479f91934a871cbbe2aaa30583..3cb38c16c3ea1cf223f9b6c32a06e5faf96c0d52 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -436,11 +436,6 @@ class RailEnv(Environment): self._elapsed_steps += 1 - # Reset the step rewards - self.rewards_dict = dict() - for i_agent in range(self.get_num_agents()): - self.rewards_dict[i_agent] = 0 - # If we're done, set reward and info_dict and step() is done. if self.dones["__all__"]: self.rewards_dict = {} @@ -450,39 +445,50 @@ class RailEnv(Environment): "speed" : {}, "status" : {}, } - for i_agent in range(self.get_num_agents()): + for i_agent, agent in enumerate(self.agents): self.rewards_dict[i_agent] = self.global_reward info_dict["action_required"][i_agent] = False info_dict["malfunction"][i_agent] = 0 info_dict["speed"][i_agent] = 0 - info_dict["status"][i_agent] = self.agents[i_agent].status + info_dict["status"][i_agent] = agent.status return self._get_observations(), self.rewards_dict, self.dones, info_dict - # Perform step on all agents - for i_agent in range(self.get_num_agents()): + # Reset the step rewards + self.rewards_dict = dict() + info_dict = { + "action_required" : {}, + "malfunction" : {}, + "speed" : {}, + "status" : {}, + } + have_all_agents_ended = True # boolean flag to check if all agents are done + for i_agent, agent in enumerate(self.agents): + # Reset the step rewards + self.rewards_dict[i_agent] = 0 + + # Perform step on the agent self._step_agent(i_agent, action_dict_.get(i_agent)) + # manage the boolean flag to check if all agents are indeed done (or done_removed) + have_all_agents_ended &= (agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]) + + # Build info dict + info_dict["action_required"][i_agent] = \ + (agent.status == RailAgentStatus.READY_TO_DEPART or ( + agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0, + rtol=1e-03))) + info_dict["malfunction"][i_agent] = agent.malfunction_data['malfunction'] + info_dict["speed"][i_agent] = agent.speed_data['speed'] + info_dict["status"][i_agent] = agent.status + # Check for end of episode + set global reward to all rewards! - if np.all([agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED] for agent in self.agents]): + if have_all_agents_ended: self.dones["__all__"] = True self.rewards_dict = {i: self.global_reward for i in range(self.get_num_agents())} if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps): self.dones["__all__"] = True - info_dict = { - 'action_required': { - i: (agent.status == RailAgentStatus.READY_TO_DEPART or ( - agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0, - rtol=1e-03))) - for i, agent in enumerate(self.agents)}, - 'malfunction': { - i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents()) - }, - 'speed': {i: self.agents[i].speed_data['speed'] for i in range(self.get_num_agents())}, - 'status': {i: agent.status for i, agent in enumerate(self.agents)} - } - return self._get_observations(), self.rewards_dict, self.dones, info_dict def _step_agent(self, i_agent, action: Optional[RailEnvActions] = None):