From 6e367dd68c7e59abe25c4762be61503d8949e7ac Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Thu, 29 Aug 2019 13:44:10 +0200 Subject: [PATCH] bugfix #141: check_path_exists and tests --- flatland/core/transition_map.py | 29 +++++- flatland/envs/schedule_generators.py | 34 ++---- flatland/utils/simple_rail.py | 37 +++++++ tests/test_flatland_core_transition_map.py | 116 ++++++++++++++++++++- 4 files changed, 185 insertions(+), 31 deletions(-) diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index e32a5623..1de8d68f 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -7,6 +7,7 @@ from importlib_resources import path from numpy import array from flatland.core.grid.grid4 import Grid4Transitions +from flatland.core.grid.grid4_utils import get_new_position from flatland.core.transitions import Transitions @@ -298,8 +299,7 @@ class GridTransitionMap(TransitionMap): self.height = new_height self.grid = new_grid - - def is_dead_end(self,rcPos): + def is_dead_end(self, rcPos): """ Check if the cell is a dead-end :param rcPos: tuple(row, column) with grid coordinate @@ -310,7 +310,30 @@ class GridTransitionMap(TransitionMap): while tmp > 0: nbits += (tmp & 1) tmp = tmp >> 1 - return nbits==1 + return nbits == 1 + + def _path_exists(self, start, direction, end): + # print("_path_exists({},{},{}".format(start, direction, end)) + # BFS - Check if a path exists between the 2 nodes + + visited = set() + stack = [(start, direction)] + while stack: + node = stack.pop() + node_position = node[0] + node_direction = node[1] + if node_position[0] == end[0] and node_position[1] == end[1]: + return True + if node not in visited: + visited.add(node) + + moves = self.get_transitions(node_position[0], node_position[1], node_direction) + for move_index in range(4): + if moves[move_index]: + stack.append((get_new_position(node_position, move_index), + move_index)) + + return False def cell_neighbours_valid(self, rcPos, check_this_cell=False): """ diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index c9f97ea4..59bc3e5d 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -131,29 +131,6 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> """ def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct: - def _path_exists(rail, start, direction, end): - # BFS - Check if a path exists between the 2 nodes - - visited = set() - stack = [(start, direction)] - while stack: - node = stack.pop() - if node[0][0] == end[0] and node[0][1] == end[1]: - return True - if node not in visited: - visited.add(node) - moves = rail.get_transitions(node[0][0], node[0][1], node[1]) - for move_index in range(4): - if moves[move_index]: - stack.append((get_new_position(node[0], move_index), - move_index)) - - # If cell is a dead-end, append previous node with reversed - # orientation! - if rail.is_dead_end(node[0]): - stack.append((node[0], (node[1] + 2) % 4)) - - return False valid_positions = [] for r in range(rail.height): @@ -194,6 +171,8 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> re_generate = False for i in range(num_agents): valid_movements = [] + if rail.is_dead_end(agents_position[i]): + print(" dead_end", agents_position[i]) for direction in range(4): position = agents_position[i] moves = rail.get_transitions(position[0], position[1], direction) @@ -204,14 +183,15 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> valid_starting_directions = [] for m in valid_movements: new_position = get_new_position(agents_position[i], m[1]) - if m[0] not in valid_starting_directions and _path_exists(rail, new_position, m[0], - agents_target[i]): + if m[0] not in valid_starting_directions and rail._path_exists(new_position, m[0], + agents_target[i]): valid_starting_directions.append(m[0]) if len(valid_starting_directions) == 0: - re_generate = True update_agents[i] = 1 - print("reset position for agents:",i, agents_position[i],agents_target[i]) + print("reset position for agents:", i, agents_position[i], agents_target[i]) + print(" dead_end", rail.is_dead_end(agents_position[i])) + re_generate = True break else: agents_direction[i] = valid_starting_directions[ diff --git a/flatland/utils/simple_rail.py b/flatland/utils/simple_rail.py index c5fe4860..67bd93dd 100644 --- a/flatland/utils/simple_rail.py +++ b/flatland/utils/simple_rail.py @@ -81,6 +81,43 @@ def make_simple_rail2() -> Tuple[GridTransitionMap, np.array]: rail.grid = rail_map return rail, rail_map +def make_simple_rail_unconnected() -> Tuple[GridTransitionMap, np.array]: + # We instantiate a very simple rail network on a 7x10 grid: + # Note that that cells have invalid RailEnvTransitions! + # | + # | + # | + # _ _ _ _ _ _ _ _ _ _ + # / + # | + # | + # | + transitions = RailEnvTransitions() + cells = transitions.transition_list + empty = cells[0] + dead_end_from_south = cells[7] + dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) + dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180) + dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270) + vertical_straight = cells[1] + horizontal_straight = transitions.rotate_transition(vertical_straight, 90) + simple_switch_north_left = cells[2] + simple_switch_north_right = cells[10] + simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270) + simple_switch_east_west_south = transitions.rotate_transition(simple_switch_north_left, 270) + rail_map = np.array( + [[empty] * 3 + [dead_end_from_south] + [empty] * 6] + + [[empty] * 3 + [vertical_straight] + [empty] * 6] + + [[empty] * 3 + [dead_end_from_north] + [empty] * 6] + + [[dead_end_from_east] + [horizontal_straight] * 5 + [simple_switch_east_west_south] + + [horizontal_straight] * 2 + [dead_end_from_west]] + + [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 + + [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) + rail = GridTransitionMap(width=rail_map.shape[1], + height=rail_map.shape[0], transitions=transitions) + rail.grid = rail_map + return rail, rail_map + def make_invalid_simple_rail() -> Tuple[GridTransitionMap, np.array]: # We instantiate a very simple rail network on a 7x10 grid: diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py index a4142316..eb35856a 100644 --- a/tests/test_flatland_core_transition_map.py +++ b/tests/test_flatland_core_transition_map.py @@ -1,6 +1,13 @@ from flatland.core.grid.grid4 import Grid4Transitions, Grid4TransitionsEnum from flatland.core.grid.grid8 import Grid8Transitions, Grid8TransitionsEnum from flatland.core.transition_map import GridTransitionMap +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import rail_from_grid_transition_map +from flatland.envs.schedule_generators import random_schedule_generator +from flatland.utils.rendertools import RenderTool +from flatland.utils.simple_rail import make_simple_rail, make_simple_rail_unconnected def test_grid4_get_transitions(): @@ -43,4 +50,111 @@ def test_grid8_set_transitions(): grid8_map.set_transition((0, 0, Grid8TransitionsEnum.NORTH), Grid8TransitionsEnum.NORTH, 0) assert grid8_map.get_transitions(0, 0, Grid8TransitionsEnum.NORTH) == (0, 0, 0, 0, 0, 0, 0, 0) -# TODO GridTransitionMap + +def check_path(env, rail, position, direction, target, expected, rendering=False): + agent = env.agents_static[0] + agent.position = position # south dead-end + agent.direction = direction # north + agent.target = target # east dead-end + agent.moving = True + # reset to set agents from agents_static + # env.reset(False, False) + if rendering: + renderer = RenderTool(env, gl="PILSVG") + renderer.render_env(show=True, show_observations=False) + input("Continue?") + assert rail._path_exists(agent.position, agent.direction, agent.target) == expected + + +def test_path_exists(rendering=False): + rail, rail_map = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) + + # reset to initialize agents_static + env.reset() + + check_path( + env, + rail, + (5, 6), # north of south dead-end + 0, # north + (3, 9), # east dead-end + True + ) + + check_path( + env, + rail, + (6, 6), # south dead-end + 2, # south + (3, 9), # east dead-end + True + ) + + check_path( + env, + rail, + (3, 0), # east dead-end + 3, # west + (0, 3), # north dead-end + True + ) + check_path( + env, + rail, + (5, 6), # east dead-end + 0, # west + (1, 3), # north dead-end + True) + + check_path( + env, + rail, + (1,3), # east dead-end + 2, # south + (3,3), # north dead-end + True + ) + + check_path( + env, + rail, + (1,3), # east dead-end + 0, # north + (3,3), # north dead-end + True + ) + + +def test_path_not_exists(rendering=False): + rail, rail_map = make_simple_rail_unconnected() + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) + + # reset to initialize agents_static + env.reset() + + check_path( + env, + rail, + (5, 6), # south dead-end + 0, # north + (0, 3), # north dead-end + False + ) + + if rendering: + renderer = RenderTool(env, gl="PILSVG") + renderer.render_env(show=True, show_observations=False) + input("Continue?") -- GitLab