diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 23f0dd577833a50a6ce587cfee81c57fd9f3fcb5..1c2df4eaa68c17900a480ab3906084eac8a0b08e 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -307,8 +307,19 @@ class RailEnv(Environment): self.obs_builder.reset() self.distance_map.reset(self.agents, self.rail) + info_dict = { + 'action_required': { + i: (agent.status == RailAgentStatus.READY_TO_DEPART or ( + agent.status == RailAgentStatus.ACTIVE and agent.speed_data['position_fraction'] == 0.0)) + for i, agent in enumerate(self.agents)}, + 'malfunction': { + i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents()) + }, + 'speed': {i: self.agents[i].speed_data['speed'] for i in range(self.get_num_agents())}, + 'status': {i: agent.status for i, agent in enumerate(self.agents)} + } # Return the new observation vectors for each agent - return self._get_observations() + return self._get_observations(), info_dict def _agent_malfunction(self, i_agent) -> bool: """