Skip to content
Snippets Groups Projects
Commit 2b40ef62 authored by spiglerg's avatar spiglerg
Browse files

fixed issue #60

parent 739df52f
No related branches found
No related tags found
No related merge requests found
...@@ -7,8 +7,8 @@ from flatland.envs.observations import TreeObsForRailEnv ...@@ -7,8 +7,8 @@ from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool from flatland.utils.rendertools import RenderTool
random.seed(10) random.seed(1)
np.random.seed(10) np.random.seed(1)
env = RailEnv(width=7, env = RailEnv(width=7,
height=7, height=7,
...@@ -19,7 +19,7 @@ env = RailEnv(width=7, ...@@ -19,7 +19,7 @@ env = RailEnv(width=7,
# Print the observation vector for agent 0 # Print the observation vector for agent 0
obs, all_rewards, done, _ = env.step({0: 0}) obs, all_rewards, done, _ = env.step({0: 0})
for i in range(env.get_num_agents()): for i in range(env.get_num_agents()):
env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=5) env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=7)
env_renderer = RenderTool(env, gl="PIL") env_renderer = RenderTool(env, gl="PIL")
env_renderer.renderEnv(show=True, frames=True) env_renderer.renderEnv(show=True, frames=True)
......
...@@ -31,19 +31,32 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -31,19 +31,32 @@ class TreeObsForRailEnv(ObservationBuilder):
self.location_has_agent = {} self.location_has_agent = {}
self.location_has_agent_direction = {} self.location_has_agent_direction = {}
self.agents_previous_reset = None
def reset(self): def reset(self):
agents = self.env.agents agents = self.env.agents
nAgents = len(agents) nAgents = len(agents)
self.distance_map = np.inf * np.ones(shape=(nAgents, # self.env.number_of_agents,
self.env.height,
self.env.width,
4))
self.max_dist = np.zeros(nAgents)
self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
# Update local lookup table for all agents' target locations compute_distance_map = True
self.location_has_target = {tuple(agent.target): 1 for agent in agents} if self.agents_previous_reset is not None:
if nAgents == len(self.agents_previous_reset):
compute_distance_map = False
for i in range(nAgents):
if agents[i].target != self.agents_previous_reset[i].target:
compute_distance_map = True
self.agents_previous_reset = agents
if compute_distance_map:
self.distance_map = np.inf * np.ones(shape=(nAgents, # self.env.number_of_agents,
self.env.height,
self.env.width,
4))
self.max_dist = np.zeros(nAgents)
self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
# Update local lookup table for all agents' target locations
self.location_has_target = {tuple(agent.target): 1 for agent in agents}
def _distance_map_walker(self, position, target_nr): def _distance_map_walker(self, position, target_nr):
""" """
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment