Skip to content
Snippets Groups Projects
Commit cff2a8de authored by u229589's avatar u229589
Browse files

Refactoring: compute distance_map directly from RailEnv.reset() and and not via obs_builder.reset()

parent 8137be64
No related branches found
No related tags found
No related merge requests found
......@@ -4,6 +4,12 @@ Changelog
Changes since Flatland 0.3
--------------------------
### Changes in `Environment`
- moving of member variable `distance_map_computed` to new class `DistanceMap`
### Changes in rail generator and `RailEnv`
- renaming of `distance_maps` into `distance_map`
### Changes in stock predictors
The stock `ShortestPathPredictorForRailEnv` now respects the different agent speeds and updates their prediction accordingly.
......@@ -25,12 +31,12 @@ The stock `ShortestPathPredictorForRailEnv` now respects the different agent spe
### Changes in level generation
- Separation of `schedule_generator` from `rail_generator`:
- Separation of `schedule_generator` from `rail_generator`:
- Renaming of `flatland/envs/generators.py` to `flatland/envs/rail_generators.py`
- `rail_generator` now only returns the grid and optionally hints (a python dictionary); the hints are currently use for distance_map and communication of start and goal position in complex rail generator.
- `schedule_generator` takes a `GridTransitionMap` and the number of agents and optionally the `agents_hints` field of the hints dictionary.
- Inrodcution of types hints:
```
- Inrodcution of types hints:
```
RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Any]]
RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
AgentPosition = Tuple[int, int]
......
......@@ -12,28 +12,44 @@ class DistanceMap:
def __init__(self, agents: List[EnvAgent], env_height: int, env_width: int):
self.env_height = env_height
self.env_width = env_width
self.distance_map = np.inf * np.ones(shape=(len(agents),
self.env_height,
self.env_width,
4))
self.distance_map = None
self.distance_map_computed = False
self.agents_previous_reset = None
"""
Set the distance map
"""
def set(self, distance_map: np.array):
def set(self, distance_map: np.ndarray):
self.distance_map = distance_map
"""
Get the distance map
"""
def get(self) -> np.array:
def get(self) -> np.ndarray:
return self.distance_map
"""
Compute the distance map
"""
def compute(self, agents: List[EnvAgent], rail: GridTransitionMap):
nb_agents = len(agents)
compute_distance_map = True
if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset):
compute_distance_map = False
for i in range(nb_agents):
if agents[i].target != self.agents_previous_reset[i].target:
compute_distance_map = True
# Don't compute the distance map if it was loaded
if self.agents_previous_reset is None and self.distance_map is not None:
compute_distance_map = False
if compute_distance_map:
self._compute(agents, rail)
self.agents_previous_reset = agents
def _compute(self, agents: List[EnvAgent], rail: GridTransitionMap):
self.distance_map_computed = True
self.distance_map = np.inf * np.ones(shape=(len(agents),
self.env_height,
......
......@@ -36,28 +36,12 @@ class TreeObsForRailEnv(ObservationBuilder):
self.location_has_agent = {}
self.location_has_agent_direction = {}
self.predictor = predictor
self.agents_previous_reset = None
self.location_has_target = None
self.tree_explored_actions = [1, 2, 3, 0]
self.tree_explorted_actions_char = ['L', 'F', 'R', 'B']
def reset(self):
agents = self.env.agents
nb_agents = len(agents)
compute_distance_map = True
if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset):
compute_distance_map = False
for i in range(nb_agents):
if agents[i].target != self.agents_previous_reset[i].target:
compute_distance_map = True
# Don't compute the distance map if it was loaded
if self.agents_previous_reset is None and self.env.distance_map.get() is not None:
self.location_has_target = {tuple(agent.target): 1 for agent in agents}
compute_distance_map = False
if compute_distance_map:
self.env.compute_distance_map()
self.agents_previous_reset = agents
self.location_has_target = {tuple(agent.target): 1 for agent in self.env.agents}
def get_many(self, handles=None):
"""
......
......@@ -276,6 +276,7 @@ class RailEnv(Environment):
# Reset the state of the observation builder with the new environment
self.obs_builder.reset()
self.observation_space = self.obs_builder.observation_space # <-- change on reset?
self.distance_map.compute(self.agents, self.rail)
# Return the new observation vectors for each agent
return self._get_observations()
......@@ -625,9 +626,3 @@ class RailEnv(Environment):
from importlib_resources import read_binary
load_data = read_binary(package, resource)
self.set_full_state_msg(load_data)
def compute_distance_map(self):
self.distance_map.compute(self.agents, self.rail)
# Update local lookup table for all agents' target locations
self.obs_builder.location_has_target = {tuple(agent.target): 1 for agent in self.agents}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment