From 2b40ef62209d3ca0bc94ac85bb41a58a19c5779a Mon Sep 17 00:00:00 2001 From: Giacomo Spigler <spiglerg@gmail.com> Date: Tue, 11 Jun 2019 11:24:01 +0200 Subject: [PATCH] fixed issue #60 --- examples/simple_example_3.py | 6 +++--- flatland/envs/observations.py | 31 ++++++++++++++++++++++--------- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py index 41fd5a31..1661ef65 100644 --- a/examples/simple_example_3.py +++ b/examples/simple_example_3.py @@ -7,8 +7,8 @@ from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool -random.seed(10) -np.random.seed(10) +random.seed(1) +np.random.seed(1) env = RailEnv(width=7, height=7, @@ -19,7 +19,7 @@ env = RailEnv(width=7, # Print the observation vector for agent 0 obs, all_rewards, done, _ = env.step({0: 0}) for i in range(env.get_num_agents()): - env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=5) + env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=7) env_renderer = RenderTool(env, gl="PIL") env_renderer.renderEnv(show=True, frames=True) diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 5cc3f26d..2214544a 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -31,19 +31,32 @@ class TreeObsForRailEnv(ObservationBuilder): self.location_has_agent = {} self.location_has_agent_direction = {} + self.agents_previous_reset = None + def reset(self): agents = self.env.agents nAgents = len(agents) - self.distance_map = np.inf * np.ones(shape=(nAgents, # self.env.number_of_agents, - self.env.height, - self.env.width, - 4)) - self.max_dist = np.zeros(nAgents) - - self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)] - # Update local lookup table for all agents' target locations - self.location_has_target = {tuple(agent.target): 1 for agent in agents} + compute_distance_map = True + if self.agents_previous_reset is not None: + if nAgents == len(self.agents_previous_reset): + compute_distance_map = False + for i in range(nAgents): + if agents[i].target != self.agents_previous_reset[i].target: + compute_distance_map = True + self.agents_previous_reset = agents + + if compute_distance_map: + self.distance_map = np.inf * np.ones(shape=(nAgents, # self.env.number_of_agents, + self.env.height, + self.env.width, + 4)) + self.max_dist = np.zeros(nAgents) + + self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)] + + # Update local lookup table for all agents' target locations + self.location_has_target = {tuple(agent.target): 1 for agent in agents} def _distance_map_walker(self, position, target_nr): """ -- GitLab