diff --git a/changelog.md b/changelog.md index 91e1ab6181ad1c7cd5b28e9ee06711ecd837324a..2b6123322f06ac13d5a8c503de7fdb5010714397 100644 --- a/changelog.md +++ b/changelog.md @@ -4,6 +4,12 @@ Changelog Changes since Flatland 0.3 -------------------------- +### Changes in `Environment` +- moving of member variable `distance_map_computed` to new class `DistanceMap` + +### Changes in rail generator and `RailEnv` +- renaming of `distance_maps` into `distance_map` + ### Changes in stock predictors The stock `ShortestPathPredictorForRailEnv` now respects the different agent speeds and updates their prediction accordingly. @@ -25,12 +31,12 @@ The stock `ShortestPathPredictorForRailEnv` now respects the different agent spe ### Changes in level generation -- Separation of `schedule_generator` from `rail_generator`: +- Separation of `schedule_generator` from `rail_generator`: - Renaming of `flatland/envs/generators.py` to `flatland/envs/rail_generators.py` - `rail_generator` now only returns the grid and optionally hints (a python dictionary); the hints are currently use for distance_map and communication of start and goal position in complex rail generator. - `schedule_generator` takes a `GridTransitionMap` and the number of agents and optionally the `agents_hints` field of the hints dictionary. - - Inrodcution of types hints: -``` + - Inrodcution of types hints: +``` RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Any]] RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct] AgentPosition = Tuple[int, int] diff --git a/flatland/envs/distance_map.py b/flatland/envs/distance_map.py index 278202067d822037861f7fea41603b476c127e78..d155042e7c045780872974c86c0655d59178b930 100644 --- a/flatland/envs/distance_map.py +++ b/flatland/envs/distance_map.py @@ -12,28 +12,44 @@ class DistanceMap: def __init__(self, agents: List[EnvAgent], env_height: int, env_width: int): self.env_height = env_height self.env_width = env_width - self.distance_map = np.inf * np.ones(shape=(len(agents), - self.env_height, - self.env_width, - 4)) + self.distance_map = None self.distance_map_computed = False + self.agents_previous_reset = None """ Set the distance map """ - def set(self, distance_map: np.array): + def set(self, distance_map: np.ndarray): self.distance_map = distance_map """ Get the distance map """ - def get(self) -> np.array: + def get(self) -> np.ndarray: return self.distance_map """ Compute the distance map """ def compute(self, agents: List[EnvAgent], rail: GridTransitionMap): + + nb_agents = len(agents) + compute_distance_map = True + if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset): + compute_distance_map = False + for i in range(nb_agents): + if agents[i].target != self.agents_previous_reset[i].target: + compute_distance_map = True + # Don't compute the distance map if it was loaded + if self.agents_previous_reset is None and self.distance_map is not None: + compute_distance_map = False + + if compute_distance_map: + self._compute(agents, rail) + + self.agents_previous_reset = agents + + def _compute(self, agents: List[EnvAgent], rail: GridTransitionMap): self.distance_map_computed = True self.distance_map = np.inf * np.ones(shape=(len(agents), self.env_height, diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index ce0ce9b1985c026539f45ed29331bd4b69cb37f7..a833fc01949d4184d5ca2442c6bb429d697318f3 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -36,28 +36,12 @@ class TreeObsForRailEnv(ObservationBuilder): self.location_has_agent = {} self.location_has_agent_direction = {} self.predictor = predictor - self.agents_previous_reset = None + self.location_has_target = None self.tree_explored_actions = [1, 2, 3, 0] self.tree_explorted_actions_char = ['L', 'F', 'R', 'B'] def reset(self): - agents = self.env.agents - nb_agents = len(agents) - compute_distance_map = True - if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset): - compute_distance_map = False - for i in range(nb_agents): - if agents[i].target != self.agents_previous_reset[i].target: - compute_distance_map = True - # Don't compute the distance map if it was loaded - if self.agents_previous_reset is None and self.env.distance_map.get() is not None: - self.location_has_target = {tuple(agent.target): 1 for agent in agents} - compute_distance_map = False - - if compute_distance_map: - self.env.compute_distance_map() - - self.agents_previous_reset = agents + self.location_has_target = {tuple(agent.target): 1 for agent in self.env.agents} def get_many(self, handles=None): """ diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 0dfd4535a91899835affb4f00e2444b42a3e77af..2d358a29fa322539ed408c0a5c84cc3c6da986ac 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -276,6 +276,7 @@ class RailEnv(Environment): # Reset the state of the observation builder with the new environment self.obs_builder.reset() self.observation_space = self.obs_builder.observation_space # <-- change on reset? + self.distance_map.compute(self.agents, self.rail) # Return the new observation vectors for each agent return self._get_observations() @@ -625,9 +626,3 @@ class RailEnv(Environment): from importlib_resources import read_binary load_data = read_binary(package, resource) self.set_full_state_msg(load_data) - - def compute_distance_map(self): - self.distance_map.compute(self.agents, self.rail) - # Update local lookup table for all agents' target locations - self.obs_builder.location_has_target = {tuple(agent.target): 1 for agent in self.agents} -