From c7d695c4701d0230335cf0d1e2ff31bf62dbf2f2 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Fri, 5 Jul 2019 13:13:43 +0200 Subject: [PATCH] unit test for conflicts of multiple agents --- examples/simple_example_3.py | 2 +- flatland/core/grid/grid4.py | 7 ++ flatland/envs/observations.py | 122 +++++++------------- flatland/envs/rail_env.py | 12 +- tests/simple_rail.py | 48 ++++++++ tests/test_flatland_envs_observations.py | 49 +------- tests/test_flatland_envs_predictions.py | 138 +++++++++++++---------- 7 files changed, 188 insertions(+), 190 deletions(-) create mode 100644 tests/simple_rail.py diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py index 9055dd4c..853d5f5e 100644 --- a/examples/simple_example_3.py +++ b/examples/simple_example_3.py @@ -19,7 +19,7 @@ env = RailEnv(width=7, # Print the observation vector for agent 0 obs, all_rewards, done, _ = env.step({0: 0}) for i in range(env.get_num_agents()): - env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=7) + env.obs_builder.util_print_obs_subtree(tree=obs[i]) env_renderer = RenderTool(env) env_renderer.renderEnv(show=True, frames=True) diff --git a/flatland/core/grid/grid4.py b/flatland/core/grid/grid4.py index 5c09f0ac..714123ed 100644 --- a/flatland/core/grid/grid4.py +++ b/flatland/core/grid/grid4.py @@ -11,6 +11,13 @@ class Grid4TransitionsEnum(IntEnum): SOUTH = 2 WEST = 3 + @staticmethod + def to_char(int: int): + return {0: 'N', + 1: 'E', + 2: 'S', + 3: 'W'}[int] + class Grid4Transitions(Transitions): """ diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 8ed455ca..e9833b0d 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -1,6 +1,7 @@ """ Collection of environment-specific ObservationBuilder. 
""" +import pprint from collections import deque import numpy as np @@ -34,6 +35,8 @@ class TreeObsForRailEnv(ObservationBuilder): self.location_has_agent_direction = {} self.predictor = predictor self.agents_previous_reset = None + self.tree_explored_actions = [1, 2, 3, 0] + self.tree_explorted_actions_char = ['L', 'F', 'R', 'B'] def reset(self): agents = self.env.agents @@ -126,19 +129,6 @@ class TreeObsForRailEnv(ObservationBuilder): desired_movement_from_new_cell = (neigh_direction + 2) % 4 - """ - # Is the next cell a dead-end? - isNextCellDeadEnd = False - nbits = 0 - tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1])) - while tmp > 0: - nbits += (tmp & 1) - tmp = tmp >> 1 - if nbits == 1: - # Dead-end! - isNextCellDeadEnd = True - """ - # Check all possible transitions in new_cell for agent_orientation in range(4): # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible? @@ -213,7 +203,7 @@ class TreeObsForRailEnv(ObservationBuilder): [... from 'right] + [... from 'back'] - Finally, each node information is composed of 5 floating point values: + Finally, each node information is composed of 8 floating point values: #1: if own target lies on the explored branch the current distance from the agent in number of cells is stored. @@ -268,7 +258,6 @@ class TreeObsForRailEnv(ObservationBuilder): # Start from the current orientation, and see which transitions are available; # organize them as [left, forward, right, back], relative to the current orientation # If only one transition is possible, the tree is oriented with this transition as the forward branch. - # TODO: Test if this works as desired! 
orientation = agent.direction if num_transitions == 1: @@ -282,15 +271,20 @@ class TreeObsForRailEnv(ObservationBuilder): observation = observation + branch_observation visited = visited.union(branch_visited) else: - num_cells_to_fill_in = 0 - pow4 = 1 - for i in range(self.max_depth): - num_cells_to_fill_in += pow4 - pow4 *= 4 - observation = observation + ([-np.inf] * self.observation_dim) * num_cells_to_fill_in + # add cells filled with infinity if no transition is possible + observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth) self.env.dev_obs_dict[handle] = visited return observation + def _num_cells_to_fill_in(self, remaining_depth): + """Computes the length of observation vector: sum_{i=0,depth-1} 4^i * observation_dim.""" + num_observations = 0 + pow4 = 1 + for i in range(remaining_depth): + num_observations += pow4 + pow4 *= 4 + return num_observations * self.observation_dim + def _explore_branch(self, handle, position, direction, root_observation, tot_dist, depth): """ Utility function to compute tree-based observations. @@ -334,7 +328,7 @@ class TreeObsForRailEnv(ObservationBuilder): # Cummulate the number of agents on branch with other direction other_agent_opposite_direction += 1 - # Register possible conflict + # Register possible future conflict if self.predictor and num_steps < self.max_prediction_depth: int_position = coordinate_to_position(self.env.width, [position]) if tot_dist < self.max_prediction_depth: @@ -422,42 +416,6 @@ class TreeObsForRailEnv(ObservationBuilder): # ############################# # ############################# # Modify here to append new / different features for each visited cell!
- """ - other_agent_same_direction = \ - 1 if position in self.location_has_agent and self.location_has_agent_direction[position] == direction else 0 - other_agent_opposite_direction = \ - 1 if position in self.location_has_agent and self.location_has_agent_direction[position] != direction else 0 - - if last_isTarget: - observation = [0, - other_target_encountered, - other_agent_encountered, - root_observation[3] + num_steps, - 0, - other_agent_same_direction, - other_agent_opposite_direction - ] - - elif last_isTerminal: - observation = [0, - other_target_encountered, - other_agent_encountered, - np.inf, - np.inf, - other_agent_same_direction, - other_agent_opposite_direction - ] - else: - observation = [0, - other_target_encountered, - other_agent_encountered, - root_observation[3] + num_steps, - self.distance_map[handle, position[0], position[1], direction], - other_agent_same_direction, - other_agent_opposite_direction - ] - """ - if last_isTarget: observation = [own_target_encountered, other_target_encountered, @@ -522,41 +480,47 @@ class TreeObsForRailEnv(ObservationBuilder): if len(branch_visited) != 0: visited = visited.union(branch_visited) else: - num_cells_to_fill_in = 0 - pow4 = 1 - for i in range(self.max_depth - depth): - num_cells_to_fill_in += pow4 - pow4 *= 4 - observation = observation + ([-np.inf] * self.observation_dim) * num_cells_to_fill_in + # no exploring possible, add just cells with infinity + observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth - depth) return observation, visited - def util_print_obs_subtree(self, tree, num_features_per_node=8, prompt='', current_depth=0): + def util_print_obs_subtree(self, tree): """ Utility function to pretty-print tree observations returned by this object. 
""" - if len(tree) < num_features_per_node: + pp = pprint.PrettyPrinter(indent=4) + pp.pprint(self.unfold_observation_tree(tree)) + + def unfold_observation_tree(self, tree, current_depth=0, actions_for_display=True): + """ + Utility function to unfold a flat tree observation into a nested dict keyed by branch label. + """ + if len(tree) < self.observation_dim: return depth = 0 - tmp = len(tree) / num_features_per_node - 1 + tmp = len(tree) / self.observation_dim - 1 pow4 = 4 while tmp > 0: tmp -= pow4 depth += 1 pow4 *= 4 - prompt_ = ['L:', 'F:', 'R:', 'B:'] - - print(" " * current_depth + prompt, tree[0:num_features_per_node]) - child_size = (len(tree) - num_features_per_node) // 4 - for children in range(4): - child_tree = tree[(num_features_per_node + children * child_size): - (num_features_per_node + (children + 1) * child_size)] - self.util_print_obs_subtree(child_tree, - num_features_per_node, - prompt=prompt_[children], - current_depth=current_depth + 1) + unfolded = {} + unfolded[''] = tree[0:self.observation_dim] + child_size = (len(tree) - self.observation_dim) // 4 + for child in range(4): + child_tree = tree[(self.observation_dim + child * child_size): + (self.observation_dim + (child + 1) * child_size)] + observation_tree = self.unfold_observation_tree(child_tree, current_depth=current_depth + 1) + if observation_tree is not None: + if actions_for_display: + label = self.tree_explorted_actions_char[child] + else: + label = self.tree_explored_actions[child] + unfolded[label] = observation_tree + return unfolded def _set_env(self, env): self.env = env @@ -725,8 +689,6 @@ class LocalObsForRailEnv(ObservationBuilder): bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]] bitlist = [0] * (16 - len(bitlist)) + bitlist self.rail_obs[i + self.view_radius, j + self.view_radius] = np.array(bitlist) - # self.rail_obs[i+self.view_radius, j+self.view_radius] = np.array( - # list(f'{self.env.rail.get_transitions((i, j)):016b}')).astype(int) def
get(self, handle): agents = self.env.agents diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 26bccf6a..b4a56a8d 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -19,12 +19,22 @@ from flatland.envs.observations import TreeObsForRailEnv class RailEnvActions(IntEnum): - DO_NOTHING = 0 + DO_NOTHING = 0 # implies change of direction in a dead-end! MOVE_LEFT = 1 MOVE_FORWARD = 2 MOVE_RIGHT = 3 STOP_MOVING = 4 + @staticmethod + def to_char(a: int): + return { + 0: 'B', + 1: 'L', + 2: 'F', + 3: 'R', + 4: 'S', + }[a] + class RailEnv(Environment): """ diff --git a/tests/simple_rail.py b/tests/simple_rail.py new file mode 100644 index 00000000..894864ac --- /dev/null +++ b/tests/simple_rail.py @@ -0,0 +1,48 @@ +import numpy as np + +from flatland.core.grid.grid4 import Grid4Transitions +from flatland.core.transition_map import GridTransitionMap + + +def make_simple_rail(): + # We instantiate a very simple rail network on a 7x10 grid: + # | + # | + # | + # _ _ _ /_\ _ _ _ _ _ _ + # \ / + # | + # | + # | + cells = [int('0000000000000000', 2), # empty cell - Case 0 + int('1000000000100000', 2), # Case 1 - straight + int('1001001000100000', 2), # Case 2 - simple switch + int('1000010000100001', 2), # Case 3 - diamond drossing + int('1001011000100001', 2), # Case 4 - single slip switch + int('1100110000110011', 2), # Case 5 - double slip switch + int('0101001000000010', 2), # Case 6 - symmetrical switch + int('0010000000000000', 2)] # Case 7 - dead end + transitions = Grid4Transitions([]) + empty = cells[0] + dead_end_from_south = cells[7] + dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) + dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180) + dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270) + vertical_straight = cells[1] + horizontal_straight = transitions.rotate_transition(vertical_straight, 90) + double_switch_south_horizontal_straight = 
horizontal_straight + cells[6] + double_switch_north_horizontal_straight = transitions.rotate_transition( + double_switch_south_horizontal_straight, 180) + rail_map = np.array( + [[empty] * 3 + [dead_end_from_south] + [empty] * 6] + + [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 + + [[dead_end_from_east] + [horizontal_straight] * 2 + + [double_switch_north_horizontal_straight] + + [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] + + [horizontal_straight] * 2 + [dead_end_from_west]] + + [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 + + [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) + rail = GridTransitionMap(width=rail_map.shape[1], + height=rail_map.shape[0], transitions=transitions) + rail.grid = rail_map + return rail, rail_map diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py index ce224736..5ee5b4a6 100644 --- a/tests/test_flatland_envs_observations.py +++ b/tests/test_flatland_envs_observations.py @@ -3,62 +3,17 @@ import numpy as np -from flatland.core.transition_map import GridTransitionMap, Grid4Transitions from flatland.envs.generators import rail_from_GridTransitionMap_generator from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv +from simple_rail import make_simple_rail """Tests for `flatland` package.""" def test_global_obs(): - # We instantiate a very simple rail network on a 7x10 grid: - # | - # | - # | - # _ _ _ /_\ _ _ _ _ _ _ - # \ / - # | - # | - # | + rail, rail_map = make_simple_rail() - cells = [int('0000000000000000', 2), # empty cell - Case 0 - int('1000000000100000', 2), # Case 1 - straight - int('1001001000100000', 2), # Case 2 - simple switch - int('1000010000100001', 2), # Case 3 - diamond drossing - int('1001011000100001', 2), # Case 4 - single slip switch - int('1100110000110011', 2), # Case 5 - double slip switch - int('0101001000000010', 2), # Case 6 - symmetrical switch - 
int('0010000000000000', 2)] # Case 7 - dead end - - transitions = Grid4Transitions([]) - empty = cells[0] - - dead_end_from_south = cells[7] - dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) - dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180) - dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270) - - vertical_straight = cells[1] - horizontal_straight = transitions.rotate_transition(vertical_straight, 90) - - double_switch_south_horizontal_straight = horizontal_straight + cells[6] - double_switch_north_horizontal_straight = transitions.rotate_transition( - double_switch_south_horizontal_straight, 180) - - rail_map = np.array( - [[empty] * 3 + [dead_end_from_south] + [empty] * 6] + - [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 + - [[dead_end_from_east] + [horizontal_straight] * 2 + - [double_switch_north_horizontal_straight] + - [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] + - [horizontal_straight] * 2 + [dead_end_from_west]] + - [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 + - [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) - - rail = GridTransitionMap(width=rail_map.shape[1], - height=rail_map.shape[0], transitions=transitions) - rail.grid = rail_map env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_GridTransitionMap_generator(rail), diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index 16850672..1bf564ed 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -1,64 +1,21 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +import pprint import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum -from flatland.core.transition_map import GridTransitionMap, Grid4Transitions from flatland.envs.generators import rail_from_GridTransitionMap_generator from flatland.envs.observations 
import TreeObsForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnvActions from flatland.utils.rendertools import RenderTool +from simple_rail import make_simple_rail """Test predictions for `flatland` package.""" -def make_simple_rail(): - # We instantiate a very simple rail network on a 7x10 grid: - # | - # | - # | - # _ _ _ /_\ _ _ _ _ _ _ - # \ / - # | - # | - # | - cells = [int('0000000000000000', 2), # empty cell - Case 0 - int('1000000000100000', 2), # Case 1 - straight - int('1001001000100000', 2), # Case 2 - simple switch - int('1000010000100001', 2), # Case 3 - diamond drossing - int('1001011000100001', 2), # Case 4 - single slip switch - int('1100110000110011', 2), # Case 5 - double slip switch - int('0101001000000010', 2), # Case 6 - symmetrical switch - int('0010000000000000', 2)] # Case 7 - dead end - transitions = Grid4Transitions([]) - empty = cells[0] - dead_end_from_south = cells[7] - dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) - dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180) - dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270) - vertical_straight = cells[1] - horizontal_straight = transitions.rotate_transition(vertical_straight, 90) - double_switch_south_horizontal_straight = horizontal_straight + cells[6] - double_switch_north_horizontal_straight = transitions.rotate_transition( - double_switch_south_horizontal_straight, 180) - rail_map = np.array( - [[empty] * 3 + [dead_end_from_south] + [empty] * 6] + - [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 + - [[dead_end_from_east] + [horizontal_straight] * 2 + - [double_switch_north_horizontal_straight] + - [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] + - [horizontal_straight] * 2 + [dead_end_from_west]] + - [[empty] * 6 + [vertical_straight] + 
[empty] * 3] * 2 + - [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) - rail = GridTransitionMap(width=rail_map.shape[1], - height=rail_map.shape[0], transitions=transitions) - rail.grid = rail_map - return rail, rail_map - - def test_dummy_predictor(rendering=False): rail, rail_map = make_simple_rail() @@ -68,12 +25,16 @@ def test_dummy_predictor(rendering=False): number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)), ) + # reset to initialize agents_static env.reset() # set initial position and direction for testing... - env.agents[0].position = (5, 6) - env.agents[0].direction = 0 - env.agents[0].target = (3, 0) + env.agents_static[0].position = (5, 6) + env.agents_static[0].direction = 0 + env.agents_static[0].target = (3, 0) + + # reset to set agents from agents_static + env.reset(False, False) if rendering: renderer = RenderTool(env, gl="PILSVG") @@ -154,41 +115,39 @@ def test_shortest_path_predictor(rendering=False): number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) + + # reset to initialize agents_static env.reset() - agent = env.agents[0] + # set the initial position + agent = env.agents_static[0] agent.position = (5, 6) # south dead-end agent.direction = 0 # north agent.target = (3, 9) # east dead-end - agent.moving = True + # reset to set agents from agents_static + env.reset(False, False) + if rendering: renderer = RenderTool(env, gl="PILSVG") renderer.renderEnv(show=True, show_observations=False) input("Continue?") - agent = env.agents[0] - assert agent.position == (5, 6) - assert agent.direction == 0 - assert agent.target == (3, 9) - assert agent.moving - - env.obs_builder._compute_distance_map() - + # compute the observations and predictions distance_map = env.obs_builder.distance_map - assert distance_map[agent.handle, agent.position[0], agent.position[ + assert distance_map[0, agent.position[0], 
agent.position[ 1], agent.direction] == 5.0, "found {} instead of {}".format( distance_map[agent.handle, agent.position[0], agent.position[1], agent.direction], 5.0) - # test assertions - env.obs_builder.get_many() + # extract the data predictions = env.obs_builder.predictions positions = np.array(list(map(lambda prediction: [*prediction[1:3]], predictions[0]))) directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0]))) time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0]))) actions = np.array(list(map(lambda prediction: [prediction[4]], predictions[0]))) + # test if data meets expectations expected_positions = [ [5, 6], [4, 6], @@ -292,3 +251,60 @@ def test_shortest_path_predictor(rendering=False): "time_offsets {}, expected {}".format(time_offsets, expected_time_offsets) assert np.array_equal(actions, expected_actions), \ "actions {}, expected {}".format(actions, expected_actions) + + +def test_shortest_path_predictor_conflicts(rendering=False): + rail, rail_map = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_GridTransitionMap_generator(rail), + number_of_agents=2, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) + # initialize agents_static + env.reset() + + # set the initial position + agent = env.agents_static[0] + agent.position = (5, 6) # south dead-end + agent.direction = 0 # north + agent.target = (3, 9) # east dead-end + agent.moving = True + + agent = env.agents_static[1] + agent.position = (3, 8) # east dead-end + agent.direction = 3 # west + agent.target = (6, 6) # south dead-end + agent.moving = True + + # reset to set agents from agents_static + observations = env.reset(False, False) + + if rendering: + renderer = RenderTool(env, gl="PILSVG") + renderer.renderEnv(show=True, show_observations=False) + input("Continue?") + + # get the trees to test + obs_builder: 
TreeObsForRailEnv = env.obs_builder + pp = pprint.PrettyPrinter(indent=4) + tree_0 = obs_builder.unfold_observation_tree(observations[0]) + tree_1 = obs_builder.unfold_observation_tree(observations[1]) + pp.pprint(tree_0) + + # check the expectations + # TODO check with Erik, this should be symmetric, should it not? + expected_conflicts_0 = [('F', 'R'), ('F', 'L')] + expected_conflicts_1 = [('F'), ('F', 'L')] + _check_expected_conflicts(expected_conflicts_0, obs_builder, tree_0, "agent[0]: ") + _check_expected_conflicts(expected_conflicts_1, obs_builder, tree_1, "agent[1]: ") + + +def _check_expected_conflicts(expected_conflicts, obs_builder, tree_0, prompt=''): + assert (tree_0[''][7] > 0) == (() in expected_conflicts), "{}[]".format(prompt) + for a_1 in obs_builder.tree_explorted_actions_char: + conflict = tree_0[a_1][''][7] + assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1) + for a_2 in obs_builder.tree_explorted_actions_char: + conflict = tree_0[a_1][a_2][''][7] + assert (conflict > 0) == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2) -- GitLab