diff --git a/flatland/envs/distance_map.py b/flatland/envs/distance_map.py index d155042e7c045780872974c86c0655d59178b930..940193198a18a63f669d6a50a04cbd8f67740a32 100644 --- a/flatland/envs/distance_map.py +++ b/flatland/envs/distance_map.py @@ -1,5 +1,5 @@ from collections import deque -from typing import List +from typing import List, Optional import numpy as np @@ -13,8 +13,10 @@ class DistanceMap: self.env_height = env_height self.env_width = env_width self.distance_map = None - self.distance_map_computed = False - self.agents_previous_reset = None + self.agents_previous_computation = None + self.reset_was_called = False + self.agents: List[EnvAgent] = agents + self.rail: Optional[GridTransitionMap] = None """ Set the distance map @@ -26,31 +28,39 @@ class DistanceMap: Get the distance map """ def get(self) -> np.ndarray: - return self.distance_map - """ - Compute the distance map - """ - def compute(self, agents: List[EnvAgent], rail: GridTransitionMap): + if self.reset_was_called: + self.reset_was_called = False + + nb_agents = len(self.agents) + compute_distance_map = True + if self.agents_previous_computation is not None and nb_agents == len(self.agents_previous_computation): + compute_distance_map = False + for i in range(nb_agents): + if self.agents[i].target != self.agents_previous_computation[i].target: + compute_distance_map = True + # Don't compute the distance map if it was loaded + if self.agents_previous_computation is None and self.distance_map is not None: + compute_distance_map = False - nb_agents = len(agents) - compute_distance_map = True - if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset): - compute_distance_map = False - for i in range(nb_agents): - if agents[i].target != self.agents_previous_reset[i].target: - compute_distance_map = True - # Don't compute the distance map if it was loaded - if self.agents_previous_reset is None and self.distance_map is not None: - compute_distance_map = False + if compute_distance_map: + self._compute(self.agents, self.rail) - if compute_distance_map: - self._compute(agents, rail) + elif self.distance_map is None: + self._compute(self.agents, self.rail) - self.agents_previous_reset = agents + return self.distance_map + + """ + Reset the distance map + """ + def reset(self, agents: List[EnvAgent], rail: GridTransitionMap): + self.reset_was_called = True + self.agents = agents + self.rail = rail def _compute(self, agents: List[EnvAgent], rail: GridTransitionMap): - self.distance_map_computed = True + self.agents_previous_computation = self.agents self.distance_map = np.inf * np.ones(shape=(len(agents), self.env_height, self.env_width, diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 2d358a29fa322539ed408c0a5c84cc3c6da986ac..f5881836072ec09a2fe4ee9a70482566867362c4 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -276,7 +276,7 @@ class RailEnv(Environment): # Reset the state of the observation builder with the new environment self.obs_builder.reset() self.observation_space = self.obs_builder.observation_space # <-- change on reset? - self.distance_map.compute(self.agents, self.rail) + self.distance_map.reset(self.agents, self.rail) # Return the new observation vectors for each agent return self._get_observations() diff --git a/tests/tests_generators.py b/tests/tests_generators.py index 8918a585e64e073a1af3efe6b1cc45ea2818d102..4c925789e6560077d637e2a594c736df8850d00a 100644 --- a/tests/tests_generators.py +++ b/tests/tests_generators.py @@ -148,7 +148,6 @@ def tests_rail_from_file(): assert agents_initial == agents_loaded # Check that distance map was not recomputed - assert env.distance_map.distance_map_computed is False assert np.shape(env.distance_map.get()) == dist_map_shape assert env.distance_map.get() is not None @@ -222,6 +221,5 @@ def tests_rail_from_file(): assert agents_initial_2 == agents_loaded_4 # Check that distance map was generated with correct shape - assert env4.distance_map.distance_map_computed is True assert env4.distance_map.get() is not None assert np.shape(env4.distance_map.get()) == dist_map_shape