Skip to content
Snippets Groups Projects
Commit 00cf76e3 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

updated tree obs to not do computations for each agent every time again

parent 9ee8a299
No related branches found
No related tags found
No related merge requests found
...@@ -78,6 +78,30 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -78,6 +78,30 @@ class TreeObsForRailEnv(ObservationBuilder):
self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)}) self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
self.predicted_dir.update({t: dir_list}) self.predicted_dir.update({t: dir_list})
self.max_prediction_depth = len(self.predicted_pos) self.max_prediction_depth = len(self.predicted_pos)
# Update local lookup table for all agents' positions
# ignore other agents not in the grid (only status active and done)
# self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents if
# agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]}
self.location_has_agent = {}
self.location_has_agent_direction = {}
self.location_has_agent_speed = {}
self.location_has_agent_malfunction = {}
self.location_has_agent_ready_to_depart = {}
for _agent in self.env.agents:
if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and \
_agent.position:
self.location_has_agent[tuple(_agent.position)] = 1
self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction
self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_data['speed']
self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data[
'malfunction']
if _agent.status in [RailAgentStatus.READY_TO_DEPART] and \
_agent.initial_position:
self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \
self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1
observations = super().get_many(handles) observations = super().get_many(handles)
...@@ -162,30 +186,6 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -162,30 +186,6 @@ class TreeObsForRailEnv(ObservationBuilder):
In case the target node is reached, the values are [0, 0, 0, 0, 0]. In case the target node is reached, the values are [0, 0, 0, 0, 0].
""" """
# Update local lookup table for all agents' positions
# ignore other agents not in the grid (only status active and done)
# self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents if
# agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]}
self.location_has_agent = {}
self.location_has_agent_direction = {}
self.location_has_agent_speed = {}
self.location_has_agent_malfunction = {}
self.location_has_agent_ready_to_depart = {}
for _agent in self.env.agents:
if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and \
_agent.position:
self.location_has_agent[tuple(_agent.position)] = 1
self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction
self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_data['speed']
self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data['malfunction']
if _agent.status in [RailAgentStatus.READY_TO_DEPART] and \
_agent.initial_position:
self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \
self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1
if handle > len(self.env.agents): if handle > len(self.env.agents):
print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents)) print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
agent = self.env.agents[handle] # TODO: handle being treated as index agent = self.env.agents[handle] # TODO: handle being treated as index
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment