From 1588496816810b7563aba21669688445efaa866c Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Mon, 17 Jun 2019 14:42:59 +0200 Subject: [PATCH] 66 shortest-path-predictor: cleanup and unit test --- flatland/core/env_prediction_builder.py | 8 +- flatland/core/transitions.py | 28 ++++ flatland/envs/observations.py | 28 ++-- flatland/envs/predictions.py | 59 ++++---- tests/test_env_observation_builder.py | 8 -- tests/test_env_prediction_builder.py | 179 +++++++++++++++++++++--- tests/test_environments.py | 5 - tests/test_player.py | 4 - 8 files changed, 244 insertions(+), 75 deletions(-) diff --git a/flatland/core/env_prediction_builder.py b/flatland/core/env_prediction_builder.py index 060dbfc3..5ce69a81 100644 --- a/flatland/core/env_prediction_builder.py +++ b/flatland/core/env_prediction_builder.py @@ -13,6 +13,7 @@ case of multi-agent environments. class PredictionBuilder: """ PredictionBuilder base class. + """ def __init__(self, max_depth: int = 20): @@ -27,12 +28,15 @@ class PredictionBuilder: """ pass - def get(self, handle=0): + def get(self, custom_args=None, handle=0): """ - Called whenever predict is called on the environment. + Called whenever get_many in the observation build is called. Parameters ------- + custom_args: dict + Implementation-dependent custom arguments, see the sub-classes. + handle : int (optional) Handle of the agent for which to compute the observation vector. diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py index a6d1bb07..6c38a39c 100644 --- a/flatland/core/transitions.py +++ b/flatland/core/transitions.py @@ -3,6 +3,7 @@ The transitions module defines the base Transitions class and a derived GridTransitions class, which allows for the specification of possible transitions over a 2D grid. """ +from enum import IntEnum import numpy as np @@ -129,6 +130,16 @@ class Transitions: """ raise NotImplementedError() + def get_direction_enum(self) -> IntEnum: + raise NotImplementedError() + + +class Grid4TransitionsEnum(IntEnum): + NORTH = 0 + EAST = 1 + SOUTH = 2 + WEST = 3 + class Grid4Transitions(Transitions): """ @@ -323,6 +334,20 @@ class Grid4Transitions(Transitions): cell_transition = value return cell_transition + def get_direction_enum(self) -> IntEnum: + return Grid4TransitionsEnum + + +class Grid8TransitionsEnum(IntEnum): + NORTH = 0 + NORTH_EAST = 1 + EAST = 2 + SOUTH_EAST = 3 + SOUTH = 4 + SOUTH_WEST = 5 + WEST = 6 + NORTH_WEST = 7 + class Grid8Transitions(Transitions): """ @@ -504,6 +529,9 @@ class Grid8Transitions(Transitions): return cell_transition + def get_direction_enum(self) -> IntEnum: + return Grid8TransitionsEnum + class RailEnvTransitions(Grid4Transitions): """ diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 4b0049f6..d7fdcee7 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -48,16 +48,19 @@ class TreeObsForRailEnv(ObservationBuilder): self.agents_previous_reset = agents if compute_distance_map: - self.distance_map = np.inf * np.ones(shape=(nAgents, # self.env.number_of_agents, - self.env.height, - self.env.width, - 4)) - self.max_dist = np.zeros(nAgents) + self._compute_distance_map() - self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)] - - # Update local lookup table for all agents' target locations - self.location_has_target = {tuple(agent.target): 1 for agent in agents} + def _compute_distance_map(self): + agents = self.env.agents + nAgents = len(agents) + self.distance_map = np.inf * np.ones(shape=(nAgents, # self.env.number_of_agents, + self.env.height, + self.env.width, + 4)) + self.max_dist = np.zeros(nAgents) + self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)] + # Update local lookup table for all agents' target locations + self.location_has_target = {tuple(agent.target): 1 for agent in agents} def _distance_map_walker(self, position, target_nr): """ @@ -177,7 +180,7 @@ class TreeObsForRailEnv(ObservationBuilder): if self.predictor: self.predicted_pos = {} self.predicted_dir = {} - self.predictions = self.predictor.get(self.distance_map) + self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map}) for t in range(len(self.predictions[0])): pos_list = [] dir_list = [] @@ -796,8 +799,3 @@ class LocalObsForRailEnv(ObservationBuilder): direction = self._get_one_hot_for_agent_direction(agent) return local_rail_obs, obs_map_state, obs_other_agents_state, direction - -# class LocalObsForRailEnvImproved(ObservationBuilder): -# """ -# Returns a local observation around the given agent -# """ diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index b6fe8631..3910fa1b 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -16,24 +16,28 @@ class DummyPredictorForRailEnv(PredictionBuilder): The prediction acts as if no other agent is in the environment and always takes the forward action. """ - def get(self, distancemap, handle=None): + def get(self, custom_args=None, handle=None): """ - Called whenever predict is called on the environment. + Called whenever get_many in the observation build is called. Parameters ------- + custom_args: dict + Not used in this dummy implementation. handle : int (optional) Handle of the agent for which to compute the observation vector. Returns ------- - function - Returns a dictionary index by the agent handle and for each agent a vector of 5 elements: + np.array + Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1) x 5 elements: - time_offset - position axis 0 - position axis 1 - direction - action taken to come here + The prediction at 0 is the current position, direction etc. + """ agents = self.env.agents if handle: @@ -46,12 +50,12 @@ class DummyPredictorForRailEnv(PredictionBuilder): _agent_initial_position = agent.position _agent_initial_direction = agent.direction prediction = np.zeros(shape=(self.max_depth + 1, 5)) - prediction[0] = [0, _agent_initial_position[0], _agent_initial_position[1], _agent_initial_direction, 0] + prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0] for index in range(1, self.max_depth + 1): action_done = False # if we're at the target, stop moving... if agent.position == agent.target: - prediction[index] = [index, agent.target[0], agent.target[1], agent.direction, + prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING] continue @@ -63,7 +67,7 @@ class DummyPredictorForRailEnv(PredictionBuilder): # performed agent.position = new_position agent.direction = new_direction - prediction[index] = [index, new_position[0], new_position[1], new_direction, action] + prediction[index] = [index, *new_position, new_direction, action] action_done = True break if not action_done: @@ -76,51 +80,55 @@ class DummyPredictorForRailEnv(PredictionBuilder): class ShortestPathPredictorForRailEnv(PredictionBuilder): """ - DummyPredictorForRailEnv object. + ShortestPathPredictorForRailEnv object. - This object returns predictions for agents in the RailEnv environment. + This object returns shortest-path predictions for agents in the RailEnv environment. The prediction acts as if no other agent is in the environment and always takes the forward action. """ - def get(self, distancemap, handle=None): + def get(self, custom_args=None, handle=None): """ - Called whenever predict is called on the environment. + Called whenever get_many in the observation build is called. + Requires distance_map to extract the shortest path. Parameters ------- + custom_args: dict + - distance_map : dict handle : int (optional) Handle of the agent for which to compute the observation vector. Returns ------- - function - Returns a dictionary index by the agent handle and for each agent a vector of 5 elements: + np.array + Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1) x 5 elements: - time_offset - position axis 0 - position axis 1 - direction - action taken to come here + The prediction at 0 is the current position, direction etc. """ agents = self.env.agents if handle: agents = [self.env.agents[handle]] + assert custom_args + distance_map = custom_args.get('distance_map') + assert distance_map is not None prediction_dict = {} - agent_idx = 0 for agent in agents: _agent_initial_position = agent.position _agent_initial_direction = agent.direction prediction = np.zeros(shape=(self.max_depth + 1, 5)) - prediction[0] = [0, _agent_initial_position[0], _agent_initial_position[1], _agent_initial_direction, 0] + prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0] for index in range(1, self.max_depth + 1): # if we're at the target, stop moving... if agent.position == agent.target: - prediction[index] = [index, agent.target[0], agent.target[1], agent.direction, - RailEnvActions.STOP_MOVING] + prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING] continue if not agent.moving: - prediction[index] = [index, agent.position[0], agent.position[1], agent.direction, - RailEnvActions.STOP_MOVING] + prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING] continue # Take shortest possible path cell_transitions = self.env.rail.get_transitions((*agent.position, agent.direction)) @@ -130,24 +138,25 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): new_position = self._new_position(agent.position, new_direction) elif np.sum(cell_transitions) > 1: min_dist = np.inf - for direct in range(4): - if cell_transitions[direct] == 1: - target_dist = distancemap[agent_idx, agent.position[0], agent.position[1], direct] + for direction in range(4): + if cell_transitions[direction] == 1: + target_dist = distance_map[agent.handle, agent.position[0], agent.position[1], direction] if target_dist < min_dist: min_dist = target_dist - new_direction = direct + new_direction = direction new_position = self._new_position(agent.position, new_direction) agent.position = new_position agent.direction = new_direction - prediction[index] = [index, new_position[0], new_position[1], new_direction, 0] + prediction[index] = [index, *new_position, new_direction, RailEnvActions.MOVE_FORWARD] action_done = True if not action_done: raise Exception("Cannot move further. Something is wrong") prediction_dict[agent.handle] = prediction + + # cleanup: reset initial position agent.position = _agent_initial_position agent.direction = _agent_initial_direction - agent_idx += 1 return prediction_dict diff --git a/tests/test_env_observation_builder.py b/tests/test_env_observation_builder.py index 2e86477b..ce224736 100644 --- a/tests/test_env_observation_builder.py +++ b/tests/test_env_observation_builder.py @@ -80,11 +80,3 @@ def test_global_obs(): # If this assertion is wrong, it means that the observation returned # places the agent on an empty cell assert (np.sum(rail_map * global_obs[0][1][:, :, :4].sum(2)) > 0) - - -def main(): - test_global_obs() - - -if __name__ == "__main__": - main() diff --git a/tests/test_env_prediction_builder.py b/tests/test_env_prediction_builder.py index 5f5cea35..4d4078c3 100644 --- a/tests/test_env_prediction_builder.py +++ b/tests/test_env_prediction_builder.py @@ -4,15 +4,18 @@ import numpy as np from flatland.core.transition_map import GridTransitionMap, Grid4Transitions +from flatland.core.transitions import Grid4TransitionsEnum from flatland.envs.generators import rail_from_GridTransitionMap_generator from flatland.envs.observations import TreeObsForRailEnv -from flatland.envs.predictions import DummyPredictorForRailEnv +from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_env import RailEnvActions +from flatland.utils.rendertools import RenderTool """Test predictions for `flatland` package.""" -def test_predictions(): +def make_simple_rail(): # We instantiate a very simple rail network on a 7x10 grid: # | # | @@ -22,7 +25,6 @@ def test_predictions(): # | # | # | - cells = [int('0000000000000000', 2), # empty cell - Case 0 int('1000000000100000', 2), # Case 1 - straight int('1001001000100000', 2), # Case 2 - simple switch @@ -31,22 +33,17 @@ def test_predictions(): int('1100110000110011', 2), # Case 5 - double slip switch int('0101001000000010', 2), # Case 6 - symmetrical switch int('0010000000000000', 2)] # Case 7 - dead end - transitions = Grid4Transitions([]) empty = cells[0] - dead_end_from_south = cells[7] dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180) dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270) - vertical_straight = cells[1] horizontal_straight = transitions.rotate_transition(vertical_straight, 90) - double_switch_south_horizontal_straight = horizontal_straight + cells[6] double_switch_north_horizontal_straight = transitions.rotate_transition( double_switch_south_horizontal_straight, 180) - rail_map = np.array( [[empty] * 3 + [dead_end_from_south] + [empty] * 6] + [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 + @@ -56,26 +53,36 @@ def test_predictions(): [horizontal_straight] * 2 + [dead_end_from_west]] + [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 + [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) - rail = GridTransitionMap(width=rail_map.shape[1], height=rail_map.shape[0], transitions=transitions) rail.grid = rail_map + return rail, rail_map + + +def test_dummy_predictor(rendering=False): + rail, rail_map = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_GridTransitionMap_generator(rail), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)), ) - env.reset() # set initial position and direction for testing... env.agents[0].position = (5, 6) env.agents[0].direction = 0 - env.agents[0].target = (3., 0.) + env.agents[0].target = (3, 0) + + if rendering: + renderer = RenderTool(env, gl="PILSVG") + renderer.renderEnv(show=True, show_observations=False) + input("Continue?") + # test assertions predictions = env.obs_builder.predictor.get(None) - positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0]))) + positions = np.array(list(map(lambda prediction: [*prediction[1:3]], predictions[0]))) directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0]))) time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0]))) actions = np.array(list(map(lambda prediction: [prediction[4]], predictions[0]))) @@ -139,9 +146,149 @@ def test_predictions(): assert np.array_equal(actions, expected_actions) -def main(): - test_predictions() +def test_shortest_path_predictor(rendering=False): + rail, rail_map = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_GridTransitionMap_generator(rail), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) + env.reset() + + agent = env.agents[0] + agent.position = (5, 6) # south dead-end + agent.direction = 0 # north + agent.target = (3, 9) # east dead-end + + agent.moving = True + + if rendering: + renderer = RenderTool(env, gl="PILSVG") + renderer.renderEnv(show=True, show_observations=False) + # input("Continue?") + + agent = env.agents[0] + assert agent.position == (5, 6) + assert agent.direction == 0 + assert agent.target == (3, 9) + assert agent.moving + + env.obs_builder._compute_distance_map() + + distance_map = env.obs_builder.distance_map + assert distance_map[agent.handle, agent.position[0], agent.position[ + 1], agent.direction] == 5.0, "found {} instead of {}".format( + distance_map[agent.handle, agent.position[0], agent.position[1], agent.direction], 5.0) + + # test assertions + env.obs_builder.get_many() + predictions = env.obs_builder.predictions + positions = np.array(list(map(lambda prediction: [*prediction[1:3]], predictions[0]))) + directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0]))) + time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0]))) + actions = np.array(list(map(lambda prediction: [prediction[4]], predictions[0]))) + + expected_positions = [ + [5, 6], + [4, 6], + [3, 6], + [3, 7], + [3, 8], + [3, 9], + [3, 9], + [3, 9], + [3, 9], + [3, 9], + [3, 9], + [3, 9], + [3, 9], + [3, 9], + [3, 9], + [3, 9], + [3, 9], + [3, 9], + [3, 9], + [3, 9], + [3, 9], + ] + expected_directions = [ + [Grid4TransitionsEnum.NORTH], # next is [5,6] heading north + [Grid4TransitionsEnum.NORTH], # next is [4,6] heading north + [Grid4TransitionsEnum.NORTH], # next is [3,6] heading north + [Grid4TransitionsEnum.EAST], # next is [3,7] heading east + [Grid4TransitionsEnum.EAST], + [Grid4TransitionsEnum.EAST], + [Grid4TransitionsEnum.EAST], + [Grid4TransitionsEnum.EAST], + [Grid4TransitionsEnum.EAST], + [Grid4TransitionsEnum.EAST], + [Grid4TransitionsEnum.EAST], + [Grid4TransitionsEnum.EAST], + [Grid4TransitionsEnum.EAST], + [Grid4TransitionsEnum.EAST], + [Grid4TransitionsEnum.EAST], + [Grid4TransitionsEnum.EAST], + [Grid4TransitionsEnum.EAST], + [Grid4TransitionsEnum.EAST], + [Grid4TransitionsEnum.EAST], + [Grid4TransitionsEnum.EAST], + [Grid4TransitionsEnum.EAST], + ] + + expected_time_offsets = np.array([ + [0.], + [1.], + [2.], + [3.], + [4.], + [5.], + [6.], + [7.], + [8.], + [9.], + [10.], + [11.], + [12.], + [13.], + [14.], + [15.], + [16.], + [17.], + [18.], + [19.], + [20.], + ]) + expected_actions = np.array([ + [RailEnvActions.DO_NOTHING], # next [5,6] + [RailEnvActions.MOVE_FORWARD], # next [4,6] + [RailEnvActions.MOVE_FORWARD], # next [3,6] + [RailEnvActions.MOVE_RIGHT], # next [3,7] + [RailEnvActions.MOVE_FORWARD], # next [3,8] + [RailEnvActions.MOVE_FORWARD], # next [3,9] + [RailEnvActions.STOP_MOVING], # at [3,9] == target + [RailEnvActions.STOP_MOVING], + [RailEnvActions.STOP_MOVING], + [RailEnvActions.STOP_MOVING], + [RailEnvActions.STOP_MOVING], + [RailEnvActions.STOP_MOVING], + [RailEnvActions.STOP_MOVING], + [RailEnvActions.STOP_MOVING], + [RailEnvActions.STOP_MOVING], + [RailEnvActions.STOP_MOVING], + [RailEnvActions.STOP_MOVING], + [RailEnvActions.STOP_MOVING], + [RailEnvActions.STOP_MOVING], + [RailEnvActions.STOP_MOVING], + [RailEnvActions.STOP_MOVING], + ]) -if __name__ == "__main__": - main() + assert np.array_equal(positions, expected_positions), \ + "positions {}, expected {}".format(positions, expected_positions) + assert np.array_equal(directions, expected_directions), \ + "directions {}, expected {}".format(directions, expected_directions) + assert np.array_equal(time_offsets, expected_time_offsets), \ + "time_offsets {}, expected {}".format(time_offsets, expected_time_offsets) + assert np.array_equal(actions, expected_actions), \ + "actions {}, expected {}".format(actions, expected_actions) diff --git a/tests/test_environments.py b/tests/test_environments.py index 2131e08b..11f0acba 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -204,8 +204,3 @@ def test_dead_end(): rail_env.reset() rail_env.agents = [EnvAgent(position=(2, 0), direction=0, target=(4, 0), moving=False)] - - -if __name__ == "__main__": - test_rail_environment_single_agent() - test_dead_end() diff --git a/tests/test_player.py b/tests/test_player.py index 21ff62c3..757fc90d 100644 --- a/tests/test_player.py +++ b/tests/test_player.py @@ -4,7 +4,3 @@ from examples.play_model import main def test_main(): main(render=True, n_steps=20, n_trials=2, sGL="PIL") main(render=True, n_steps=20, n_trials=2, sGL="PILSVG") - - -if __name__ == "__main__": - test_main() -- GitLab