Skip to content
Snippets Groups Projects
Commit 22c39889 authored by spmohanty's avatar spmohanty
Browse files

Addresses #229 - Refactor internal lookup table creation to not use dict comprehensions

parent 9239f752
No related branches found
No related tags found
No related merge requests found
...@@ -164,25 +164,25 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -164,25 +164,25 @@ class TreeObsForRailEnv(ObservationBuilder):
# Update local lookup table for all agents' positions # Update local lookup table for all agents' positions
# ignore other agents not in the grid (only status active and done) # ignore other agents not in the grid (only status active and done)
self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents if # self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents if
agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]} # agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]}
self.location_has_agent = {}
self.location_has_agent_direction = {}
self.location_has_agent_speed = {}
self.location_has_agent_malfunction = {}
self.location_has_agent_ready_to_depart = {} self.location_has_agent_ready_to_depart = {}
for agent in self.env.agents:
if agent.status == RailAgentStatus.READY_TO_DEPART: for _agent in self.env.agents:
self.location_has_agent_ready_to_depart[tuple(agent.initial_position)] = \ if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]:
self.location_has_agent_ready_to_depart.get(tuple(agent.initial_position), 0) + 1 self.location_has_agent[tuple(_agent.position)] = 1
self.location_has_agent_direction = { self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction
tuple(agent.position): agent.direction self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_data['speed']
for agent in self.env.agents if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data['malfunction']
}
self.location_has_agent_speed = { if _agent.status in [RailAgentStatus.READY_TO_DEPART]:
tuple(agent.position): agent.speed_data['speed'] self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \
for agent in self.env.agents if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1
}
self.location_has_agent_malfunction = {
tuple(agent.position): agent.malfunction_data['malfunction']
for agent in self.env.agents if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]
}
if handle > len(self.env.agents): if handle > len(self.env.agents):
print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents)) print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment