diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 8ab96d3a31ac169457adad9899a466592abec705..8590922131d57cd76ba4a638b10e3e7e169a984a 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -78,6 +78,30 @@ class TreeObsForRailEnv(ObservationBuilder): self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)}) self.predicted_dir.update({t: dir_list}) self.max_prediction_depth = len(self.predicted_pos) + # Update local lookup table for all agents' positions + # ignore other agents not in the grid (only status active and done) + # self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents if + # agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]} + + self.location_has_agent = {} + self.location_has_agent_direction = {} + self.location_has_agent_speed = {} + self.location_has_agent_malfunction = {} + self.location_has_agent_ready_to_depart = {} + + for _agent in self.env.agents: + if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and \ + _agent.position: + self.location_has_agent[tuple(_agent.position)] = 1 + self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction + self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_data['speed'] + self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data[ + 'malfunction'] + + if _agent.status in [RailAgentStatus.READY_TO_DEPART] and \ + _agent.initial_position: + self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \ + self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1 observations = super().get_many(handles) @@ -162,30 +186,6 @@ class TreeObsForRailEnv(ObservationBuilder): In case the target node is reached, the values are [0, 0, 0, 0, 0]. """ - # Update local lookup table for all agents' positions - # ignore other agents not in the grid (only status active and done) - # self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents if - # agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]} - - self.location_has_agent = {} - self.location_has_agent_direction = {} - self.location_has_agent_speed = {} - self.location_has_agent_malfunction = {} - self.location_has_agent_ready_to_depart = {} - - for _agent in self.env.agents: - if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and \ - _agent.position: - self.location_has_agent[tuple(_agent.position)] = 1 - self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction - self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_data['speed'] - self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data['malfunction'] - - if _agent.status in [RailAgentStatus.READY_TO_DEPART] and \ - _agent.initial_position: - self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \ - self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1 - if handle > len(self.env.agents): print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents)) agent = self.env.agents[handle] # TODO: handle being treated as index