Skip to content
Snippets Groups Projects
Commit 7ecdafb9 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

added possibility of multiple agents on same location to tree observation

parent d925dfb7
No related branches found
No related tags found
No related merge requests found
import numpy as np import numpy as np
from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator from flatland.envs.rail_generators import sparse_rail_generator
...@@ -39,7 +39,7 @@ env = RailEnv(width=100, ...@@ -39,7 +39,7 @@ env = RailEnv(width=100,
schedule_generator=sparse_schedule_generator(speed_ration_map), schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=20, number_of_agents=20,
stochastic_data=stochastic_data, # Malfunction data generator stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=GlobalObsForRailEnv(), obs_builder_object=TreeObservation,
remove_agents_at_target=True remove_agents_at_target=True
) )
......
...@@ -160,7 +160,12 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -160,7 +160,12 @@ class TreeObsForRailEnv(ObservationBuilder):
""" """
# Update local lookup table for all agents' positions # Update local lookup table for all agents' positions
self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents} self.location_has_agent = dict()
for agent in self.env.agents:
if tuple(agent.position) in self.location_has_agent:
self.location_has_agent[tuple(agent.position)] = self.location_has_agent[tuple(agent.position)] + 1
else:
self.location_has_agent[tuple(agent.position)] = 1
# TODO: Update this to handle number of agents at same location # TODO: Update this to handle number of agents at same location
self.location_has_agent_direction = {tuple(agent.position): agent.direction for agent in self.env.agents} self.location_has_agent_direction = {tuple(agent.position): agent.direction for agent in self.env.agents}
self.location_has_agent_speed = {tuple(agent.position): agent.speed_data['speed'] for agent in self.env.agents} self.location_has_agent_speed = {tuple(agent.position): agent.speed_data['speed'] for agent in self.env.agents}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment