From 00cf76e3dd55b3cad198944ed955701d8220a4bf Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Fri, 25 Oct 2019 11:50:30 -0400 Subject: [PATCH] updated tree obs to not do computations for each agent every time again --- flatland/envs/observations.py | 48 +++++++++++++++++------------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 8ab96d3a..85909221 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -78,6 +78,30 @@ class TreeObsForRailEnv(ObservationBuilder): self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)}) self.predicted_dir.update({t: dir_list}) self.max_prediction_depth = len(self.predicted_pos) + # Update local lookup table for all agents' positions + # ignore other agents not in the grid (only status active and done) + # self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents if + # agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]} + + self.location_has_agent = {} + self.location_has_agent_direction = {} + self.location_has_agent_speed = {} + self.location_has_agent_malfunction = {} + self.location_has_agent_ready_to_depart = {} + + for _agent in self.env.agents: + if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and \ + _agent.position: + self.location_has_agent[tuple(_agent.position)] = 1 + self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction + self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_data['speed'] + self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data[ + 'malfunction'] + + if _agent.status in [RailAgentStatus.READY_TO_DEPART] and \ + _agent.initial_position: + self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \ + self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1 observations = super().get_many(handles) @@ -162,30 +186,6 @@ class TreeObsForRailEnv(ObservationBuilder): In case the target node is reached, the values are [0, 0, 0, 0, 0]. """ - # Update local lookup table for all agents' positions - # ignore other agents not in the grid (only status active and done) - # self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents if - # agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]} - - self.location_has_agent = {} - self.location_has_agent_direction = {} - self.location_has_agent_speed = {} - self.location_has_agent_malfunction = {} - self.location_has_agent_ready_to_depart = {} - - for _agent in self.env.agents: - if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and \ - _agent.position: - self.location_has_agent[tuple(_agent.position)] = 1 - self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction - self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_data['speed'] - self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data['malfunction'] - - if _agent.status in [RailAgentStatus.READY_TO_DEPART] and \ - _agent.initial_position: - self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \ - self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1 - if handle > len(self.env.agents): print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents)) agent = self.env.agents[handle] # TODO: handle being treated as index -- GitLab