diff --git a/docs/intro_observationbuilder.rst b/docs/intro_observationbuilder.rst index c94a2ce18e80c9b6fbdf34b824dfc2062374c82d..563dd113e9df748a99d41dacdef19193dc0f1c01 100644 --- a/docs/intro_observationbuilder.rst +++ b/docs/intro_observationbuilder.rst @@ -63,7 +63,7 @@ cells and orientations to the target of each agent, i.e. a distance map for each In this example we exploit these distance maps by implementing an observation builder that shows the current shortest path for each agent as a one-hot observation vector of length 3, whose components represent the possible directions an agent can take (LEFT, FORWARD, RIGHT). All values of the observation vector are set to :code:`0` except for the shortest direction where it is set to :code:`1`. -Using this observation with highly engineered features indicating the agent's shortest path, an agent can then learn to take the corresponding action at each time-step; or we could even hardcode the optimal policy. +Using this observation with highly engineered features indicating the agent's shortest path, an agent can then learn to take the corresponding action at each time-step; or we could even hardcode the optimal policy. Note that this simple strategy fails when multiple agents are present, as each agent would only attempt its greedy solution, which is not usually `Pareto-optimal <https://en.wikipedia.org/wiki/Pareto_efficiency>`_ in this context. .. _TreeObsForRailEnv: https://gitlab.aicrowd.com/flatland/flatland/blob/master/flatland/envs/observations.py#L14 @@ -71,7 +71,7 @@ Note that this simple strategy fails when multiple agents are present, as each a .. code-block:: python from flatland.envs.observations import TreeObsForRailEnv - + class SingleAgentNavigationObs(TreeObsForRailEnv): """ We derive our observation builder from TreeObsForRailEnv, to exploit the existing implementation to compute @@ -84,7 +84,7 @@ Note that this simple strategy fails when multiple agents are present, as each a """ def __init__(self): super().__init__(max_depth=0) - # We set max_depth=0 in because we only need to look at the current + # We set max_depth=0 in because we only need to look at the current # position of the agent to decide what direction is shortest. self.observation_space = [3] @@ -110,7 +110,7 @@ Note that this simple strategy fails when multiple agents are present, as each a for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]: if possible_transitions[direction]: new_position = self._new_position(agent.position, direction) - min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction]) + min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction]) else: min_distances.append(np.inf) @@ -175,28 +175,28 @@ In contrast to the previous examples we also implement the :code:`def get_many(s .. _example: https://gitlab.aicrowd.com/flatland/flatland/blob/master/examples/custom_observation_example.py#L110 .. code-block:: python - + class ObservePredictions(TreeObsForRailEnv): """ We use the provided ShortestPathPredictor to illustrate the usage of predictors in your custom observation. - + We derive our observation builder from TreeObsForRailEnv, to exploit the existing implementation to compute the minimum distances from each grid node to each agent's target. - + This is necessary so that we can pass the distance map to the ShortestPathPredictor - + Here we also want to highlight how you can visualize your observation """ - + def __init__(self, predictor): super().__init__(max_depth=0) self.observation_space = [10] self.predictor = predictor - + def reset(self): # Recompute the distance map, if the environment has changed. super().reset() - + def get_many(self, handles=None): ''' Because we do not want to call the predictor seperately for every agent we implement the get_many function @@ -204,9 +204,9 @@ In contrast to the previous examples we also implement the :code:`def get_many(s :param handles: :return: ''' - - self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map}) - + + self.predictions = self.predictor.get(custom_args={'distance_map': self.env.distance_map}) + self.predicted_pos = {} for t in range(len(self.predictions[0])): pos_list = [] @@ -215,47 +215,47 @@ In contrast to the previous examples we also implement the :code:`def get_many(s # We transform (x,y) coodrinates to a single integer number for simpler comparison self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)}) observations = {} - + # Collect all the different observation for all the agents for h in handles: observations[h] = self.get(h) return observations - + def get(self, handle): ''' Lets write a simple observation which just indicates whether or not the own predicted path overlaps with other predicted paths at any time. This is useless for the task of navigation but might help when looking for conflicts. A more complex implementation can be found in the TreeObsForRailEnv class - + Each agent recieves an observation of length 10, where each element represents a prediction step and its value is: - 0 if no overlap is happening - 1 where n i the number of other paths crossing the predicted cell - + :param handle: handeled as an index of an agent :return: Observation of handle ''' - + observation = np.zeros(10) - + # We are going to track what cells where considered while building the obervation and make them accesible # For rendering - + visited = set() for _idx in range(10): # Check if any of the other prediction overlap with agents own predictions x_coord = self.predictions[handle][_idx][1] y_coord = self.predictions[handle][_idx][2] - + # We add every observed cell to the observation rendering visited.add((x_coord, y_coord)) if self.predicted_pos[_idx][handle] in np.delete(self.predicted_pos[_idx], handle, 0): # We detect if another agent is predicting to pass through the same cell at the same predicted time observation[handle] = 1 - + # This variable will be access by the renderer to visualize the observation self.env.dev_obs_dict[handle] = visited - + return observation We can then use this new observation builder and the renderer to visualize the observation of each agent. @@ -265,23 +265,23 @@ We can then use this new observation builder and the renderer to visualize the o # Initiate the Predictor CustomPredictor = ShortestPathPredictorForRailEnv(10) - + # Pass the Predictor to the observation builder CustomObsBuilder = ObservePredictions(CustomPredictor) - + # Initiate Environment env = RailEnv(width=10, height=10, rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999, seed=0), number_of_agents=3, obs_builder_object=CustomObsBuilder) - + obs = env.reset() env_renderer = RenderTool(env, gl="PILSVG") - + # We render the initial step and show the obsered cells as colored boxes env_renderer.render_env(show=True, frames=True, show_observations=True, show_predictions=False) - + action_dict = {} for step in range(100): for a in range(env.get_num_agents()): @@ -321,7 +321,7 @@ These two objects can be used for example to detect switches that are usable by cell_transitions = self.env.rail.get_transitions(*position, direction) transition_bit = bin(self.env.rail.get_full_transitions(*position)) - + total_transitions = transition_bit.count("1") num_transitions = np.count_nonzero(cell_transitions) @@ -357,7 +357,7 @@ Beyond the basic agent information we can also access more details about the age Similar to the speed data you can also access individual data about the malfunctions of an agent. All data is available through :code:`agent.malfunction_data` with: -- Indication how long the agent is still malfunctioning :code:`'malfunction'` by an integer counting down at each time step. 0 means the agent is ok and can move. +- Indication how long the agent is still malfunctioning :code:`'malfunction'` by an integer counting down at each time step. 0 means the agent is ok and can move. - Possion rate at which malfunctions happen for this agent :code:`'malfunction_rate'` - Number of steps untill next malfunction will occur :code:`'next_malfunction'` - Number of malfunctions an agent have occured for this agent so far :code:`nr_malfunctions'` diff --git a/examples/custom_observation_example.py b/examples/custom_observation_example.py index 8b1de6aa4e303469d30983d30333fbfda89c1d1e..03b0ff17f23d4a1046237082bfdcb946de2f87fb 100644 --- a/examples/custom_observation_example.py +++ b/examples/custom_observation_example.py @@ -79,8 +79,8 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): min_distances = [] for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]: if possible_transitions[direction]: - new_position = self._new_position(agent.position, direction) - min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction]) + new_position = self.new_position(agent.position, direction) + min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction]) else: min_distances.append(np.inf) @@ -140,7 +140,7 @@ class ObservePredictions(TreeObsForRailEnv): :return: ''' - self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map}) + self.predictions = self.predictor.get(custom_args={'distance_map': self.env.distance_map}) self.predicted_pos = {} for t in range(len(self.predictions[0])): diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py index 50ea74b84ac9851e88e48bcd32b914e69bc7dd34..e3683d893f4feb9979d86f3ca3507100d86d1813 100644 --- a/examples/debugging_example_DELETE.py +++ b/examples/debugging_example_DELETE.py @@ -47,8 +47,8 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): min_distances = [] for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]: if possible_transitions[direction]: - new_position = self._new_position(agent.position, direction) - min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction]) + new_position = self.new_position(agent.position, direction) + min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction]) else: min_distances.append(np.inf) diff --git a/flatland/core/env.py b/flatland/core/env.py index 1bc5b6f3eba4ee4713bd3c8d6b88440006c215a5..f1f1b270820b5ccaa9e0644eede703a62a40aad9 100644 --- a/flatland/core/env.py +++ b/flatland/core/env.py @@ -45,6 +45,8 @@ class Environment: def __init__(self): self.action_space = () self.observation_space = () + self.distance_map_computed = False + self.distance_map = None pass def reset(self): diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 706e9bdeaf04d60fe076af6f50cd07077548a32f..f479e07890fc9c53538cc4f95acac9559202354f 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -2,7 +2,6 @@ Collection of environment-specific ObservationBuilder. """ import pprint -from collections import deque import numpy as np @@ -39,8 +38,6 @@ class TreeObsForRailEnv(ObservationBuilder): self.agents_previous_reset = None self.tree_explored_actions = [1, 2, 3, 0] self.tree_explorted_actions_char = ['L', 'F', 'R', 'B'] - self.distance_map = None - self.distance_map_computed = False def reset(self): agents = self.env.agents @@ -52,110 +49,16 @@ class TreeObsForRailEnv(ObservationBuilder): if agents[i].target != self.agents_previous_reset[i].target: compute_distance_map = True # Don't compute the distance map if it was loaded - if self.agents_previous_reset is None and self.distance_map is not None: + if self.agents_previous_reset is None and self.env.distance_map is not None: self.location_has_target = {tuple(agent.target): 1 for agent in agents} compute_distance_map = False if compute_distance_map: - self._compute_distance_map() + self.env.compute_distance_map() self.agents_previous_reset = agents - def _compute_distance_map(self): - agents = self.env.agents - # For testing only --> To assert if a distance map need to be recomputed. - self.distance_map_computed = True - nb_agents = len(agents) - self.distance_map = np.inf * np.ones(shape=(nb_agents, - self.env.height, - self.env.width, - 4)) - self.max_dist = np.zeros(nb_agents) - self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)] - # Update local lookup table for all agents' target locations - self.location_has_target = {tuple(agent.target): 1 for agent in agents} - - def _distance_map_walker(self, position, target_nr): - """ - Utility function to compute distance maps from each cell in the rail network (and each possible - orientation within it) to each agent's target cell. - """ - # Returns max distance to target, from the farthest away node, while filling in distance_map - self.distance_map[target_nr, position[0], position[1], :] = 0 - - # Fill in the (up to) 4 neighboring nodes - # direction is the direction of movement, meaning that at least a possible orientation of an agent - # in cell (row,col) allows a movement in direction `direction' - nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1)) - - # BFS from target `position' to all the reachable nodes in the grid - # Stop the search if the target position is re-visited, in any direction - visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2), - (position[0], position[1], 3)} - - max_distance = 0 - - while nodes_queue: - node = nodes_queue.popleft() - - node_id = (node[0], node[1], node[2]) - - if node_id not in visited: - visited.add(node_id) - - # From the list of possible neighbors that have at least a path to the current node, only keep those - # whose new orientation in the current cell would allow a transition to direction node[2] - valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2]) - - for n in valid_neighbors: - nodes_queue.append(n) - - if len(valid_neighbors) > 0: - max_distance = max(max_distance, node[3] + 1) - - return max_distance - - def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1): - """ - Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the - minimum distances from each target cell. - """ - neighbors = [] - - possible_directions = [0, 1, 2, 3] - if enforce_target_direction >= 0: - # The agent must land into the current cell with orientation `enforce_target_direction'. - # This is only possible if the agent has arrived from the cell in the opposite direction! - possible_directions = [(enforce_target_direction + 2) % 4] - - for neigh_direction in possible_directions: - new_cell = self._new_position(position, neigh_direction) - - if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width: - - desired_movement_from_new_cell = (neigh_direction + 2) % 4 - - # Check all possible transitions in new_cell - for agent_orientation in range(4): - # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible? - is_valid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation), - desired_movement_from_new_cell) - - if is_valid: - """ - # TODO: check that it works with deadends! -- still bugged! - movement = desired_movement_from_new_cell - if isNextCellDeadEnd: - movement = (desired_movement_from_new_cell+2) % 4 - """ - new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation], - current_distance + 1) - neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance)) - self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance - - return neighbors - - def _new_position(self, position, movement): + def new_position(self, position, movement): """ Utility function that converts a compass movement over a 2D grid to new positions (r, c). """ @@ -180,7 +83,7 @@ class TreeObsForRailEnv(ObservationBuilder): self.max_prediction_depth = 0 self.predicted_pos = {} self.predicted_dir = {} - self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map}) + self.predictions = self.predictor.get(custom_args={'distance_map': self.env.distance_map}) if self.predictions: for t in range(len(self.predictions[0])): @@ -276,7 +179,7 @@ class TreeObsForRailEnv(ObservationBuilder): # Root node - current position # Here information about the agent itself is stored - observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0, + observation = [0, 0, 0, 0, 0, 0, self.env.distance_map[(handle, *agent.position, agent.direction)], 0, 0, agent.malfunction_data['malfunction'], agent.speed_data['speed']] visited = set() @@ -291,7 +194,7 @@ class TreeObsForRailEnv(ObservationBuilder): for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]: if possible_transitions[branch_direction]: - new_cell = self._new_position(agent.position, branch_direction) + new_cell = self.new_position(agent.position, branch_direction) branch_observation, branch_visited = \ self._explore_branch(handle, new_cell, branch_direction, 1, 1) observation = observation + branch_observation @@ -464,7 +367,7 @@ class TreeObsForRailEnv(ObservationBuilder): exploring = True # convert one-hot encoding to 0,1,2,3 direction = np.argmax(cell_transitions) - position = self._new_position(position, direction) + position = self.new_position(position, direction) num_steps += 1 tot_dist += 1 elif num_transitions > 0: @@ -506,7 +409,7 @@ class TreeObsForRailEnv(ObservationBuilder): potential_conflict, unusable_switch, np.inf, - self.distance_map[handle, position[0], position[1], direction], + self.env.distance_map[handle, position[0], position[1], direction], other_agent_same_direction, other_agent_opposite_direction, malfunctioning_agent, @@ -520,7 +423,7 @@ class TreeObsForRailEnv(ObservationBuilder): potential_conflict, unusable_switch, tot_dist, - self.distance_map[handle, position[0], position[1], direction], + self.env.distance_map[handle, position[0], position[1], direction], other_agent_same_direction, other_agent_opposite_direction, malfunctioning_agent, @@ -537,7 +440,7 @@ class TreeObsForRailEnv(ObservationBuilder): (branch_direction + 2) % 4): # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes # it back - new_cell = self._new_position(position, (branch_direction + 2) % 4) + new_cell = self.new_position(position, (branch_direction + 2) % 4) branch_observation, branch_visited = self._explore_branch(handle, new_cell, (branch_direction + 2) % 4, @@ -547,7 +450,7 @@ class TreeObsForRailEnv(ObservationBuilder): if len(branch_visited) != 0: visited = visited.union(branch_visited) elif last_is_switch and possible_transitions[branch_direction]: - new_cell = self._new_position(position, branch_direction) + new_cell = self.new_position(position, branch_direction) branch_observation, branch_visited = self._explore_branch(handle, new_cell, branch_direction, diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index e4d693065aec0307bee8eaf3acd1c9e9e9df0e93..10770210b073bf86e1c813895948be3354bb291a 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -5,6 +5,7 @@ Definition of the RailEnv environment. import warnings from enum import IntEnum from typing import List +from collections import deque import msgpack import msgpack_numpy as m @@ -143,6 +144,7 @@ class RailEnv(Environment): file_name: you can load a pickle file. from previously saved *.pkl file """ + super().__init__() self.rail_generator: RailGenerator = rail_generator self.schedule_generator: ScheduleGenerator = schedule_generator @@ -233,7 +235,7 @@ class RailEnv(Environment): rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets) if optionals and 'distance_maps' in optionals: - self.obs_builder.distance_map = optionals['distance_maps'] + self.distance_map = optionals['distance_maps'] if regen_rail or self.rail is None: self.rail = rail @@ -573,8 +575,8 @@ class RailEnv(Environment): # agents are always reset as not moving self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data["agents_static"]] self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8]) for d in data["agents"]] - if hasattr(self.obs_builder, 'distance_map') and "distance_maps" in data.keys(): - self.obs_builder.distance_map = data["distance_maps"] + if hasattr(self, 'distance_map') and "distance_maps" in data.keys(): + self.distance_map = data["distance_maps"] # setup with loaded data self.height, self.width = self.rail.grid.shape self.rail.height = self.height @@ -588,8 +590,8 @@ class RailEnv(Environment): msgpack.packb(grid_data, use_bin_type=True) msgpack.packb(agent_data, use_bin_type=True) msgpack.packb(agent_static_data, use_bin_type=True) - if hasattr(self.obs_builder, 'distance_map'): - distance_map_data = self.obs_builder.distance_map + if hasattr(self, 'distance_map'): + distance_map_data = self.distance_map msgpack.packb(distance_map_data, use_bin_type=True) msg_data = { "grid": grid_data, @@ -605,8 +607,8 @@ class RailEnv(Environment): return msgpack.packb(msg_data, use_bin_type=True) def save(self, filename): - if hasattr(self.obs_builder, 'distance_map'): - if len(self.obs_builder.distance_map) > 0: + if hasattr(self, 'distance_map') and self.distance_map is not None: + if len(self.distance_map) > 0: with open(filename, "wb") as file_out: file_out.write(self.get_full_state_dist_msg()) else: @@ -617,7 +619,7 @@ class RailEnv(Environment): file_out.write(self.get_full_state_msg()) def load(self, filename): - if hasattr(self.obs_builder, 'distance_map'): + if hasattr(self, 'distance_map'): with open(filename, "rb") as file_in: load_data = file_in.read() self.set_full_state_dist_msg(load_data) @@ -633,3 +635,97 @@ class RailEnv(Environment): from importlib_resources import read_binary load_data = read_binary(package, resource) self.set_full_state_msg(load_data) + + def compute_distance_map(self): + agents = self.agents + # For testing only --> To assert if a distance map need to be recomputed. + self.distance_map_computed = True + nb_agents = len(agents) + self.distance_map = np.inf * np.ones(shape=(nb_agents, + self.height, + self.width, + 4)) + max_dist = np.zeros(nb_agents) + max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)] + # Update local lookup table for all agents' target locations + self.obs_builder.location_has_target = {tuple(agent.target): 1 for agent in agents} + + def _distance_map_walker(self, position, target_nr): + """ + Utility function to compute distance maps from each cell in the rail network (and each possible + orientation within it) to each agent's target cell. + """ + # Returns max distance to target, from the farthest away node, while filling in distance_map + self.distance_map[target_nr, position[0], position[1], :] = 0 + + # Fill in the (up to) 4 neighboring nodes + # direction is the direction of movement, meaning that at least a possible orientation of an agent + # in cell (row,col) allows a movement in direction `direction' + nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1)) + + # BFS from target `position' to all the reachable nodes in the grid + # Stop the search if the target position is re-visited, in any direction + visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2), + (position[0], position[1], 3)} + + max_distance = 0 + + while nodes_queue: + node = nodes_queue.popleft() + + node_id = (node[0], node[1], node[2]) + + if node_id not in visited: + visited.add(node_id) + + # From the list of possible neighbors that have at least a path to the current node, only keep those + # whose new orientation in the current cell would allow a transition to direction node[2] + valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2]) + + for n in valid_neighbors: + nodes_queue.append(n) + + if len(valid_neighbors) > 0: + max_distance = max(max_distance, node[3] + 1) + + return max_distance + + def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1): + """ + Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the + minimum distances from each target cell. + """ + neighbors = [] + + possible_directions = [0, 1, 2, 3] + if enforce_target_direction >= 0: + # The agent must land into the current cell with orientation `enforce_target_direction'. + # This is only possible if the agent has arrived from the cell in the opposite direction! + possible_directions = [(enforce_target_direction + 2) % 4] + + for neigh_direction in possible_directions: + new_cell = self.obs_builder.new_position(position, neigh_direction) + + if new_cell[0] >= 0 and new_cell[0] < self.height and new_cell[1] >= 0 and new_cell[1] < self.width: + + desired_movement_from_new_cell = (neigh_direction + 2) % 4 + + # Check all possible transitions in new_cell + for agent_orientation in range(4): + # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible? + is_valid = self.rail.get_transition((new_cell[0], new_cell[1], agent_orientation), + desired_movement_from_new_cell) + + if is_valid: + """ + # TODO: check that it works with deadends! -- still bugged! + movement = desired_movement_from_new_cell + if isNextCellDeadEnd: + movement = (desired_movement_from_new_cell+2) % 4 + """ + new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation], + current_distance + 1) + neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance)) + self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance + + return neighbors diff --git a/tests/test_distance_map.py b/tests/test_distance_map.py index e5e89f76428bb881d0f72aa60aada97ab02167a5..2653aaec9db483903988a09aa4f69314ba7967c4 100644 --- a/tests/test_distance_map.py +++ b/tests/test_distance_map.py @@ -43,9 +43,8 @@ def test_walker(): # reset to set agents from agents_static env.reset(False, False) - obs_builder: TreeObsForRailEnv = env.obs_builder - print(obs_builder.distance_map[(0, *[0, 1], 1)]) - assert obs_builder.distance_map[(0, *[0, 1], 1)] == 3 - print(obs_builder.distance_map[(0, *[0, 2], 3)]) - assert obs_builder.distance_map[(0, *[0, 2], 1)] == 2 + print(env.distance_map[(0, *[0, 1], 1)]) + assert env.distance_map[(0, *[0, 1], 1)] == 3 + print(env.distance_map[(0, *[0, 2], 3)]) + assert env.distance_map[(0, *[0, 2], 1)] == 2 diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py index c96e8db00fe721f42667aed4833d034a47f19156..eb056012c59f3e136b583911f86b3952d4435eae 100644 --- a/tests/test_flatland_envs_observations.py +++ b/tests/test_flatland_envs_observations.py @@ -51,7 +51,7 @@ def _step_along_shortest_path(env, obs_builder, rail): shortest_distance = np.inf for exit_direction in range(4): - neighbour = obs_builder._new_position(agent.position, exit_direction) + neighbour = obs_builder.new_position(agent.position, exit_direction) if neighbour[0] >= 0 and neighbour[0] < env.height and neighbour[1] >= 0 and neighbour[1] < env.width: desired_movement_from_new_cell = (exit_direction + 2) % 4 @@ -62,7 +62,7 @@ def _step_along_shortest_path(env, obs_builder, rail): is_valid = obs_builder.env.rail.get_transition((neighbour[0], neighbour[1], agent_orientation), desired_movement_from_new_cell) if is_valid: - distance_to_target = obs_builder.distance_map[ + distance_to_target = obs_builder.env.distance_map[ (agent.handle, *agent.position, exit_direction)] print("agent {} at {} facing {} taking {} distance {}".format(agent.handle, agent.position, agent.direction, diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index 09f7e5e67a15c55b5070ac8679e43ecc9a14b9da..0221bf6d31c9505607250e43ca49c4e34787ed70 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -137,7 +137,7 @@ def test_shortest_path_predictor(rendering=False): input("Continue?") # compute the observations and predictions - distance_map = env.obs_builder.distance_map + distance_map = env.distance_map assert distance_map[0, agent.position[0], agent.position[ 1], agent.direction] == 5.0, "found {} instead of {}".format( distance_map[agent.handle, agent.position[0], agent.position[1], agent.direction], 5.0) diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index a63e97229457d606027c88e5189d0cb680f25c9b..a0f97c3e7d3248b22cbc5228ddf49417b146bc67 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -40,8 +40,8 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): min_distances = [] for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]: if possible_transitions[direction]: - new_position = self._new_position(agent.position, direction) - min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction]) + new_position = self.new_position(agent.position, direction) + min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction]) else: min_distances.append(np.inf) diff --git a/tests/tests_generators.py b/tests/tests_generators.py index 8b0480c887a53ade155c28aa6199db3d32f19603..43e6e720b0b993b17e11877d57478b3dbdeee6a0 100644 --- a/tests/tests_generators.py +++ b/tests/tests_generators.py @@ -129,7 +129,7 @@ def tests_rail_from_file(): obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) env.save(file_name) - dist_map_shape = np.shape(env.obs_builder.distance_map) + dist_map_shape = np.shape(env.distance_map) # initialize agents_static rails_initial = env.rail.grid agents_initial = env.agents @@ -148,9 +148,9 @@ def tests_rail_from_file(): assert agents_initial == agents_loaded # Check that distance map was not recomputed - assert env.obs_builder.distance_map_computed is False - assert np.shape(env.obs_builder.distance_map) == dist_map_shape - assert env.obs_builder.distance_map is not None + assert env.distance_map_computed is False + assert np.shape(env.distance_map) == dist_map_shape + assert env.distance_map is not None # Test to save and load file without distance map. @@ -222,6 +222,6 @@ def tests_rail_from_file(): assert agents_initial_2 == agents_loaded_4 # Check that distance map was generated with correct shape - assert env4.obs_builder.distance_map_computed is True - assert env4.obs_builder.distance_map is not None - assert np.shape(env4.obs_builder.distance_map) == dist_map_shape + assert env4.distance_map_computed is True + assert env4.distance_map is not None + assert np.shape(env4.distance_map) == dist_map_shape