diff --git a/benchmarks/benchmark_all_examples.py b/benchmarks/benchmark_all_examples.py index 676bfe16ba94d53682ff6e52d7b714d8fa1c2d8b..e45537a4dd2dca52482f53f9b865f2302cb660c6 100644 --- a/benchmarks/benchmark_all_examples.py +++ b/benchmarks/benchmark_all_examples.py @@ -1,7 +1,6 @@ import runpy import sys from io import StringIO -from test.support import swap_attr from time import sleep import importlib_resources @@ -9,6 +8,8 @@ import pkg_resources from benchmarker import Benchmarker from importlib_resources import path +from benchmarks.benchmark_utils import swap_attr + for entry in [entry for entry in importlib_resources.contents('examples') if not pkg_resources.resource_isdir('examples', entry) and entry.endswith(".py") diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py index 9055dd4ccc5adf75e673e6d3685daec4304f3e9e..853d5f5e32514ed6d61d10daa89dfc694130f15f 100644 --- a/examples/simple_example_3.py +++ b/examples/simple_example_3.py @@ -19,7 +19,7 @@ env = RailEnv(width=7, # Print the observation vector for agent 0 obs, all_rewards, done, _ = env.step({0: 0}) for i in range(env.get_num_agents()): - env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=7) + env.obs_builder.util_print_obs_subtree(tree=obs[i]) env_renderer = RenderTool(env) env_renderer.renderEnv(show=True, frames=True) diff --git a/flatland/core/grid/grid4.py b/flatland/core/grid/grid4.py index 5c09f0ac8ba86ed7987aefcf92a541f2ea5d1de4..714123ed5c5d2dd704c6f1262df5cc18fba8ef10 100644 --- a/flatland/core/grid/grid4.py +++ b/flatland/core/grid/grid4.py @@ -11,6 +11,13 @@ class Grid4TransitionsEnum(IntEnum): SOUTH = 2 WEST = 3 + @staticmethod + def to_char(int: int): + return {0: 'N', + 1: 'E', + 2: 'S', + 3: 'W'}[int] + class Grid4Transitions(Transitions): """ diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index a4903853bab30d185293892010c714d050f63439..a10c58e6f5eac8ebd04990505894a917c2212b3f 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -1,6 +1,7 @@ """ Collection of environment-specific ObservationBuilder. """ +import pprint from collections import deque import numpy as np @@ -19,6 +20,8 @@ class TreeObsForRailEnv(ObservationBuilder): network to simplify the representation of the state of the environment for each agent. """ + observation_dim = 9 + def __init__(self, max_depth, predictor=None): self.max_depth = max_depth @@ -28,12 +31,13 @@ class TreeObsForRailEnv(ObservationBuilder): for i in range(self.max_depth + 1): size += pow4 pow4 *= 4 - self.observation_dim = 9 self.observation_space = [size * self.observation_dim] self.location_has_agent = {} self.location_has_agent_direction = {} self.predictor = predictor self.agents_previous_reset = None + self.tree_explored_actions = [1, 2, 3, 0] + self.tree_explorted_actions_char = ['L', 'F', 'R', 'B'] def reset(self): agents = self.env.agents @@ -126,19 +130,6 @@ class TreeObsForRailEnv(ObservationBuilder): desired_movement_from_new_cell = (neigh_direction + 2) % 4 - """ - # Is the next cell a dead-end? - isNextCellDeadEnd = False - nbits = 0 - tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1])) - while tmp > 0: - nbits += (tmp & 1) - tmp = tmp >> 1 - if nbits == 1: - # Dead-end! - isNextCellDeadEnd = True - """ - # Check all possible transitions in new_cell for agent_orientation in range(4): # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible? @@ -213,7 +204,7 @@ class TreeObsForRailEnv(ObservationBuilder): [... from 'right] + [... from 'back'] - Finally, each node information is composed of 5 floating point values: + Finally, each node information is composed of 8 floating point values: #1: if own target lies on the explored branch the current distance from the agent in number of cells is stored. @@ -240,7 +231,7 @@ class TreeObsForRailEnv(ObservationBuilder): (possible future use: number of other agents in the same direction in this branch) 0 = no agent present same direction - #9: agent in the opposite drection + #9: agent in the opposite direction n = number of agents present other direction than myself (so conflict) (possible future use: number of other agents in other direction in this branch, ie. number of conflicts) 0 = no agent present other direction than myself @@ -273,7 +264,6 @@ class TreeObsForRailEnv(ObservationBuilder): # Start from the current orientation, and see which transitions are available; # organize them as [left, forward, right, back], relative to the current orientation # If only one transition is possible, the tree is oriented with this transition as the forward branch. - # TODO: Test if this works as desired! orientation = agent.direction if num_transitions == 1: @@ -287,15 +277,20 @@ class TreeObsForRailEnv(ObservationBuilder): observation = observation + branch_observation visited = visited.union(branch_visited) else: - num_cells_to_fill_in = 0 - pow4 = 1 - for i in range(self.max_depth): - num_cells_to_fill_in += pow4 - pow4 *= 4 - observation = observation + ([-np.inf] * self.observation_dim) * num_cells_to_fill_in + # add cells filled with infinity if no transition is possible + observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth) self.env.dev_obs_dict[handle] = visited return observation + def _num_cells_to_fill_in(self, remaining_depth): + """Computes the length of observation vector: sum_{i=0,depth-1} 2^i * observation_dim.""" + num_observations = 0 + pow4 = 1 + for i in range(remaining_depth): + num_observations += pow4 + pow4 *= 4 + return num_observations * self.observation_dim + def _explore_branch(self, handle, position, direction, root_observation, tot_dist, depth): """ Utility function to compute tree-based observations. @@ -343,7 +338,7 @@ class TreeObsForRailEnv(ObservationBuilder): # Cummulate the number of agents on branch with other direction other_agent_opposite_direction += 1 - # Register possible conflict + # Register possible future conflict if self.predictor and num_steps < self.max_prediction_depth: int_position = coordinate_to_position(self.env.width, [position]) if tot_dist < self.max_prediction_depth: @@ -505,41 +500,47 @@ class TreeObsForRailEnv(ObservationBuilder): if len(branch_visited) != 0: visited = visited.union(branch_visited) else: - num_cells_to_fill_in = 0 - pow4 = 1 - for i in range(self.max_depth - depth): - num_cells_to_fill_in += pow4 - pow4 *= 4 - observation = observation + ([-np.inf] * self.observation_dim) * num_cells_to_fill_in + # no exploring possible, add just cells with infinity + observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth - depth) return observation, visited - def util_print_obs_subtree(self, tree, num_features_per_node=9, prompt='', current_depth=0): + def util_print_obs_subtree(self, tree): + """ + Utility function to pretty-print tree observations returned by this object. + """ + pp = pprint.PrettyPrinter(indent=4) + pp.pprint(self.unfold_observation_tree(tree)) + + def unfold_observation_tree(self, tree, current_depth=0, actions_for_display=True): """ Utility function to pretty-print tree observations returned by this object. """ - if len(tree) < num_features_per_node: + if len(tree) < self.observation_dim: return depth = 0 - tmp = len(tree) / num_features_per_node - 1 + tmp = len(tree) / self.observation_dim - 1 pow4 = 4 while tmp > 0: tmp -= pow4 depth += 1 pow4 *= 4 - prompt_ = ['L:', 'F:', 'R:', 'B:'] - - print(" " * current_depth + prompt, tree[0:num_features_per_node]) - child_size = (len(tree) - num_features_per_node) // 4 - for children in range(4): - child_tree = tree[(num_features_per_node + children * child_size): - (num_features_per_node + (children + 1) * child_size)] - self.util_print_obs_subtree(child_tree, - num_features_per_node, - prompt=prompt_[children], - current_depth=current_depth + 1) + unfolded = {} + unfolded[''] = tree[0:self.observation_dim] + child_size = (len(tree) - self.observation_dim) // 4 + for child in range(4): + child_tree = tree[(self.observation_dim + child * child_size): + (self.observation_dim + (child + 1) * child_size)] + observation_tree = self.unfold_observation_tree(child_tree, current_depth=current_depth + 1) + if observation_tree is not None: + if actions_for_display: + label = self.tree_explorted_actions_char[child] + else: + label = self.tree_explored_actions[child] + unfolded[label] = observation_tree + return unfolded def _set_env(self, env): self.env = env @@ -708,8 +709,6 @@ class LocalObsForRailEnv(ObservationBuilder): bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]] bitlist = [0] * (16 - len(bitlist)) + bitlist self.rail_obs[i + self.view_radius, j + self.view_radius] = np.array(bitlist) - # self.rail_obs[i+self.view_radius, j+self.view_radius] = np.array( - # list(f'{self.env.rail.get_transitions((i, j)):016b}')).astype(int) def get(self, handle): agents = self.env.agents diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 26bccf6a9b0ec0a59e7161c865082634d017ce3e..b4a56a8d241a4a1d6706ff67bad6030ce619ebf0 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -19,12 +19,22 @@ from flatland.envs.observations import TreeObsForRailEnv class RailEnvActions(IntEnum): - DO_NOTHING = 0 + DO_NOTHING = 0 # implies change of direction in a dead-end! MOVE_LEFT = 1 MOVE_FORWARD = 2 MOVE_RIGHT = 3 STOP_MOVING = 4 + @staticmethod + def to_char(a: int): + return { + 0: 'B', + 1: 'L', + 2: 'F', + 3: 'R', + 4: 'S', + }[a] + class RailEnv(Environment): """ diff --git a/notebooks/simple_example_3_manual_control.ipynb b/notebooks/simple_example_3_manual_control.ipynb index 7f492e196d5cad28fb4f6251ba3a9b983980a5b4..951fa8953b732d00e993711259726596f95f7273 100644 --- a/notebooks/simple_example_3_manual_control.ipynb +++ b/notebooks/simple_example_3_manual_control.ipynb @@ -114,7 +114,7 @@ "# Print the observation vector for agent 0\n", "obs, all_rewards, done, _ = env.step({0: 0})\n", "for i in range(env.get_num_agents()):\n", - " env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=7)\n", + " env.obs_builder.util_print_obs_subtree(tree=obs[i])\n", "\n", "env_renderer = RenderTool(env, gl=\"PIL\")\n", "# env_renderer = RenderTool(env, gl=\"PILSVG\")\n", diff --git a/tests/simple_rail.py b/tests/simple_rail.py new file mode 100644 index 0000000000000000000000000000000000000000..894864acc8d7435681636f2b0c60da238265194e --- /dev/null +++ b/tests/simple_rail.py @@ -0,0 +1,48 @@ +import numpy as np + +from flatland.core.grid.grid4 import Grid4Transitions +from flatland.core.transition_map import GridTransitionMap + + +def make_simple_rail(): + # We instantiate a very simple rail network on a 7x10 grid: + # | + # | + # | + # _ _ _ /_\ _ _ _ _ _ _ + # \ / + # | + # | + # | + cells = [int('0000000000000000', 2), # empty cell - Case 0 + int('1000000000100000', 2), # Case 1 - straight + int('1001001000100000', 2), # Case 2 - simple switch + int('1000010000100001', 2), # Case 3 - diamond drossing + int('1001011000100001', 2), # Case 4 - single slip switch + int('1100110000110011', 2), # Case 5 - double slip switch + int('0101001000000010', 2), # Case 6 - symmetrical switch + int('0010000000000000', 2)] # Case 7 - dead end + transitions = Grid4Transitions([]) + empty = cells[0] + dead_end_from_south = cells[7] + dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) + dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180) + dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270) + vertical_straight = cells[1] + horizontal_straight = transitions.rotate_transition(vertical_straight, 90) + double_switch_south_horizontal_straight = horizontal_straight + cells[6] + double_switch_north_horizontal_straight = transitions.rotate_transition( + double_switch_south_horizontal_straight, 180) + rail_map = np.array( + [[empty] * 3 + [dead_end_from_south] + [empty] * 6] + + [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 + + [[dead_end_from_east] + [horizontal_straight] * 2 + + [double_switch_north_horizontal_straight] + + [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] + + [horizontal_straight] * 2 + [dead_end_from_west]] + + [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 + + [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) + rail = GridTransitionMap(width=rail_map.shape[1], + height=rail_map.shape[0], transitions=transitions) + rail.grid = rail_map + return rail, rail_map diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py index ce224736a6fa12f46a23c8eb98c6e02f1c3d294d..fb45e5fa162d1323533a62349922b569a415395f 100644 --- a/tests/test_flatland_envs_observations.py +++ b/tests/test_flatland_envs_observations.py @@ -3,62 +3,17 @@ import numpy as np -from flatland.core.transition_map import GridTransitionMap, Grid4Transitions from flatland.envs.generators import rail_from_GridTransitionMap_generator from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv +from tests.simple_rail import make_simple_rail """Tests for `flatland` package.""" def test_global_obs(): - # We instantiate a very simple rail network on a 7x10 grid: - # | - # | - # | - # _ _ _ /_\ _ _ _ _ _ _ - # \ / - # | - # | - # | + rail, rail_map = make_simple_rail() - cells = [int('0000000000000000', 2), # empty cell - Case 0 - int('1000000000100000', 2), # Case 1 - straight - int('1001001000100000', 2), # Case 2 - simple switch - int('1000010000100001', 2), # Case 3 - diamond drossing - int('1001011000100001', 2), # Case 4 - single slip switch - int('1100110000110011', 2), # Case 5 - double slip switch - int('0101001000000010', 2), # Case 6 - symmetrical switch - int('0010000000000000', 2)] # Case 7 - dead end - - transitions = Grid4Transitions([]) - empty = cells[0] - - dead_end_from_south = cells[7] - dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) - dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180) - dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270) - - vertical_straight = cells[1] - horizontal_straight = transitions.rotate_transition(vertical_straight, 90) - - double_switch_south_horizontal_straight = horizontal_straight + cells[6] - double_switch_north_horizontal_straight = transitions.rotate_transition( - double_switch_south_horizontal_straight, 180) - - rail_map = np.array( - [[empty] * 3 + [dead_end_from_south] + [empty] * 6] + - [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 + - [[dead_end_from_east] + [horizontal_straight] * 2 + - [double_switch_north_horizontal_straight] + - [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] + - [horizontal_straight] * 2 + [dead_end_from_west]] + - [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 + - [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) - - rail = GridTransitionMap(width=rail_map.shape[1], - height=rail_map.shape[0], transitions=transitions) - rail.grid = rail_map env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_GridTransitionMap_generator(rail), diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index c5514bd09b296796c93cb4586c5640eb05569164..c90f91a041b16cee2dc55a58562d34a0b9100560 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -1,63 +1,20 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +import pprint import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum -from flatland.core.transition_map import GridTransitionMap, Grid4Transitions from flatland.envs.generators import rail_from_GridTransitionMap_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool +from tests.simple_rail import make_simple_rail """Test predictions for `flatland` package.""" -def make_simple_rail(): - # We instantiate a very simple rail network on a 7x10 grid: - # | - # | - # | - # _ _ _ /_\ _ _ _ _ _ _ - # \ / - # | - # | - # | - cells = [int('0000000000000000', 2), # empty cell - Case 0 - int('1000000000100000', 2), # Case 1 - straight - int('1001001000100000', 2), # Case 2 - simple switch - int('1000010000100001', 2), # Case 3 - diamond drossing - int('1001011000100001', 2), # Case 4 - single slip switch - int('1100110000110011', 2), # Case 5 - double slip switch - int('0101001000000010', 2), # Case 6 - symmetrical switch - int('0010000000000000', 2)] # Case 7 - dead end - transitions = Grid4Transitions([]) - empty = cells[0] - dead_end_from_south = cells[7] - dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) - dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180) - dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270) - vertical_straight = cells[1] - horizontal_straight = transitions.rotate_transition(vertical_straight, 90) - double_switch_south_horizontal_straight = horizontal_straight + cells[6] - double_switch_north_horizontal_straight = transitions.rotate_transition( - double_switch_south_horizontal_straight, 180) - rail_map = np.array( - [[empty] * 3 + [dead_end_from_south] + [empty] * 6] + - [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 + - [[dead_end_from_east] + [horizontal_straight] * 2 + - [double_switch_north_horizontal_straight] + - [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] + - [horizontal_straight] * 2 + [dead_end_from_west]] + - [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 + - [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) - rail = GridTransitionMap(width=rail_map.shape[1], - height=rail_map.shape[0], transitions=transitions) - rail.grid = rail_map - return rail, rail_map - - def test_dummy_predictor(rendering=False): rail, rail_map = make_simple_rail() @@ -67,12 +24,16 @@ def test_dummy_predictor(rendering=False): number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)), ) + # reset to initialize agents_static env.reset() # set initial position and direction for testing... - env.agents[0].position = (5, 6) - env.agents[0].direction = 0 - env.agents[0].target = (3, 0) + env.agents_static[0].position = (5, 6) + env.agents_static[0].direction = 0 + env.agents_static[0].target = (3, 0) + + # reset to set agents from agents_static + env.reset(False, False) if rendering: renderer = RenderTool(env, gl="PILSVG") @@ -153,40 +114,38 @@ def test_shortest_path_predictor(rendering=False): number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) + + # reset to initialize agents_static env.reset() - agent = env.agents[0] + # set the initial position + agent = env.agents_static[0] agent.position = (5, 6) # south dead-end agent.direction = 0 # north agent.target = (3, 9) # east dead-end - agent.moving = True + # reset to set agents from agents_static + env.reset(False, False) + if rendering: renderer = RenderTool(env, gl="PILSVG") renderer.renderEnv(show=True, show_observations=False) input("Continue?") - agent = env.agents[0] - assert agent.position == (5, 6) - assert agent.direction == 0 - assert agent.target == (3, 9) - assert agent.moving - - env.obs_builder._compute_distance_map() - + # compute the observations and predictions distance_map = env.obs_builder.distance_map - assert distance_map[agent.handle, agent.position[0], agent.position[ + assert distance_map[0, agent.position[0], agent.position[ 1], agent.direction] == 5.0, "found {} instead of {}".format( distance_map[agent.handle, agent.position[0], agent.position[1], agent.direction], 5.0) - # test assertions - env.obs_builder.get_many() + # extract the data predictions = env.obs_builder.predictions positions = np.array(list(map(lambda prediction: [*prediction[1:3]], predictions[0]))) directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0]))) time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0]))) + # test if data meets expectations expected_positions = [ [5, 6], [4, 6], @@ -264,3 +223,59 @@ def test_shortest_path_predictor(rendering=False): "directions {}, expected {}".format(directions, expected_directions) assert np.array_equal(time_offsets, expected_time_offsets), \ "time_offsets {}, expected {}".format(time_offsets, expected_time_offsets) + + +def test_shortest_path_predictor_conflicts(rendering=False): + rail, rail_map = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_GridTransitionMap_generator(rail), + number_of_agents=2, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) + # initialize agents_static + env.reset() + + # set the initial position + agent = env.agents_static[0] + agent.position = (5, 6) # south dead-end + agent.direction = 0 # north + agent.target = (3, 9) # east dead-end + agent.moving = True + + agent = env.agents_static[1] + agent.position = (3, 8) # east dead-end + agent.direction = 3 # west + agent.target = (6, 6) # south dead-end + agent.moving = True + + # reset to set agents from agents_static + observations = env.reset(False, False) + + if rendering: + renderer = RenderTool(env, gl="PILSVG") + renderer.renderEnv(show=True, show_observations=False) + input("Continue?") + + # get the trees to test + obs_builder: TreeObsForRailEnv = env.obs_builder + pp = pprint.PrettyPrinter(indent=4) + tree_0 = obs_builder.unfold_observation_tree(observations[0]) + tree_1 = obs_builder.unfold_observation_tree(observations[1]) + pp.pprint(tree_0) + + # check the expectations + expected_conflicts_0 = [('F', 'R')] + expected_conflicts_1 = [('F', 'L')] + _check_expected_conflicts(expected_conflicts_0, obs_builder, tree_0, "agent[0]: ") + _check_expected_conflicts(expected_conflicts_1, obs_builder, tree_1, "agent[1]: ") + + +def _check_expected_conflicts(expected_conflicts, obs_builder, tree_0, prompt=''): + assert (tree_0[''][8] > 0) == (() in expected_conflicts), "{}[]".format(prompt) + for a_1 in obs_builder.tree_explorted_actions_char: + conflict = tree_0[a_1][''][8] + assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1) + for a_2 in obs_builder.tree_explorted_actions_char: + conflict = tree_0[a_1][a_2][''][8] + assert (conflict > 0) == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)