From c63e78aebd33ec11607d8d9118a80f2a53a39510 Mon Sep 17 00:00:00 2001 From: u229589 <christian.baumberger@sbb.ch> Date: Tue, 17 Sep 2019 16:12:28 +0200 Subject: [PATCH] Refactoring: move logic to compute new distance_map into get method --- flatland/envs/distance_map.py | 54 +++++++++++++++++++++-------------- flatland/envs/rail_env.py | 2 +- tests/tests_generators.py | 2 -- 3 files changed, 33 insertions(+), 25 deletions(-) diff --git a/flatland/envs/distance_map.py b/flatland/envs/distance_map.py index d155042e..94019319 100644 --- a/flatland/envs/distance_map.py +++ b/flatland/envs/distance_map.py @@ -1,5 +1,5 @@ from collections import deque -from typing import List +from typing import List, Optional import numpy as np @@ -13,8 +13,10 @@ class DistanceMap: self.env_height = env_height self.env_width = env_width self.distance_map = None - self.distance_map_computed = False - self.agents_previous_reset = None + self.agents_previous_computation = None + self.reset_was_called = False + self.agents: List[EnvAgent] = agents + self.rail: Optional[GridTransitionMap] = None """ Set the distance map @@ -26,31 +28,39 @@ class DistanceMap: Get the distance map """ def get(self) -> np.ndarray: - return self.distance_map - """ - Compute the distance map - """ - def compute(self, agents: List[EnvAgent], rail: GridTransitionMap): + if self.reset_was_called: + self.reset_was_called = False + + nb_agents = len(self.agents) + compute_distance_map = True + if self.agents_previous_computation is not None and nb_agents == len(self.agents_previous_computation): + compute_distance_map = False + for i in range(nb_agents): + if self.agents[i].target != self.agents_previous_computation[i].target: + compute_distance_map = True + # Don't compute the distance map if it was loaded + if self.agents_previous_computation is None and self.distance_map is not None: + compute_distance_map = False - nb_agents = len(agents) - compute_distance_map = True - if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset): - compute_distance_map = False - for i in range(nb_agents): - if agents[i].target != self.agents_previous_reset[i].target: - compute_distance_map = True - # Don't compute the distance map if it was loaded - if self.agents_previous_reset is None and self.distance_map is not None: - compute_distance_map = False + if compute_distance_map: + self._compute(self.agents, self.rail) - if compute_distance_map: - self._compute(agents, rail) + elif self.distance_map is None: + self._compute(self.agents, self.rail) - self.agents_previous_reset = agents + return self.distance_map + + """ + Reset the distance map + """ + def reset(self, agents: List[EnvAgent], rail: GridTransitionMap): + self.reset_was_called = True + self.agents = agents + self.rail = rail def _compute(self, agents: List[EnvAgent], rail: GridTransitionMap): - self.distance_map_computed = True + self.agents_previous_computation = self.agents self.distance_map = np.inf * np.ones(shape=(len(agents), self.env_height, self.env_width, diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 2d358a29..f5881836 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -276,7 +276,7 @@ class RailEnv(Environment): # Reset the state of the observation builder with the new environment self.obs_builder.reset() self.observation_space = self.obs_builder.observation_space # <-- change on reset? - self.distance_map.compute(self.agents, self.rail) + self.distance_map.reset(self.agents, self.rail) # Return the new observation vectors for each agent return self._get_observations() diff --git a/tests/tests_generators.py b/tests/tests_generators.py index 8918a585..4c925789 100644 --- a/tests/tests_generators.py +++ b/tests/tests_generators.py @@ -148,7 +148,6 @@ def tests_rail_from_file(): assert agents_initial == agents_loaded # Check that distance map was not recomputed - assert env.distance_map.distance_map_computed is False assert np.shape(env.distance_map.get()) == dist_map_shape assert env.distance_map.get() is not None @@ -222,6 +221,5 @@ def tests_rail_from_file(): assert agents_initial_2 == agents_loaded_4 # Check that distance map was generated with correct shape - assert env4.distance_map.distance_map_computed is True assert env4.distance_map.get() is not None assert np.shape(env4.distance_map.get()) == dist_map_shape -- GitLab