diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 9acbeaf806258491d33297473469bb43bd0f4128..c181b8faf837159ea870f7b1afd7c41d7f4a74e0 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -174,13 +174,15 @@ class TreeObsForRailEnv(ObservationBuilder): self.location_has_agent_ready_to_depart = {} for _agent in self.env.agents: - if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]: + if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and \ + _agent.position: self.location_has_agent[tuple(_agent.position)] = 1 self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_data['speed'] self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data['malfunction'] - if _agent.status in [RailAgentStatus.READY_TO_DEPART]: + if _agent.status in [RailAgentStatus.READY_TO_DEPART] and \ + _agent.initial_position: self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \ self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1