diff --git a/examples/custom_railmap_example.py b/examples/custom_railmap_example.py index 04da66904fda1a58847a4acc510d7fc4e4e86887..6162b918734eb311752675e75e203d90e5558c1c 100644 --- a/examples/custom_railmap_example.py +++ b/examples/custom_railmap_example.py @@ -43,6 +43,7 @@ def custom_schedule_generator() -> ScheduleGenerator: env = RailEnv(width=6, height=4, rail_generator=custom_rail_generator(), + schedule_generator=custom_schedule_generator(), number_of_agents=1) env.reset() diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index 7a673bcf9ba46c574db0983e3c52257e4a07358e..bb954998688772a7ce69e5228cff3e16d037f2af 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -7,6 +7,7 @@ from importlib_resources import path from numpy import array from flatland.core.grid.grid4 import Grid4Transitions +from flatland.core.grid.grid4_utils import get_new_position from flatland.core.transitions import Transitions @@ -298,6 +299,76 @@ class GridTransitionMap(TransitionMap): self.height = new_height self.grid = new_grid + def is_dead_end(self, rcPos): + """ + Check if the cell is a dead-end. + + Parameters + ---------- + rcPos: Tuple[int,int] + tuple(row, column) with grid coordinate + Returns + ------- + boolean + True if and only if the cell is a dead-end. + """ + nbits = 0 + tmp = self.get_full_transitions(rcPos[0], rcPos[1]) + while tmp > 0: + nbits += (tmp & 1) + tmp = tmp >> 1 + return nbits == 1 + + def is_simple_turn(self, rcPos): + """ + Check if the cell is a left/right simple turn + + Parameters + ---------- + rcPos: Tuple[int,int] + tuple(row, column) with grid coordinate + Returns + ------- + boolean + True if and only if the cell is a left/right simple turn. + """ + tmp = self.get_full_transitions(rcPos[0], rcPos[1]) + + def is_simple_turn(trans): + all_simple_turns = set() + for trans in [int('0100000000000010', 2), # Case 1b (8) - simple turn right + int('0001001000000000', 2) # Case 1c (9) - simple turn left]: + ]: + for _ in range(3): + trans = self.transitions.rotate_transition(trans, rotation=90) + all_simple_turns.add(trans) + return trans in all_simple_turns + + return is_simple_turn(tmp) + + def check_path_exists(self, start, direction, end): + # print("_path_exists({},{},{}".format(start, direction, end)) + # BFS - Check if a path exists between the 2 nodes + + visited = set() + stack = [(start, direction)] + while stack: + node = stack.pop() + node_position = node[0] + node_direction = node[1] + if node_position[0] == end[0] and node_position[1] == end[1]: + return True + if node not in visited: + visited.add(node) + + moves = self.get_transitions(node_position[0], node_position[1], node_direction) + for move_index in range(4): + if moves[move_index]: + stack.append((get_new_position(node_position, move_index), + move_index)) + + return False + def cell_neighbours_valid(self, rcPos, check_this_cell=False): """ Check validity of cell at rcPos = tuple(row, column) diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 4158675cae63394ed768bfc36aaef9cd5f44da7e..c4fed2e07a97b00b786a7db8eb06af247d3ede8a 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -383,8 +383,9 @@ class TreeObsForRailEnv(ObservationBuilder): elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0): conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position) for ca in conflicting_agent[0]: - if direction != self.predicted_dir[pre_step][ca] and cell_transitions[self._reverse_dir( - self.predicted_dir[pre_step][ca])] == 1 and tot_dist < potential_conflict: + if direction != self.predicted_dir[pre_step][ca] \ + and cell_transitions[self._reverse_dir(self.predicted_dir[pre_step][ca])] == 1 \ + and tot_dist < potential_conflict: # noqa: E125 potential_conflict = tot_dist if self.env.dones[ca] and tot_dist < potential_conflict: potential_conflict = tot_dist @@ -394,7 +395,8 @@ class TreeObsForRailEnv(ObservationBuilder): conflicting_agent = np.where(self.predicted_pos[post_step] == int_position) for ca in conflicting_agent[0]: if direction != self.predicted_dir[post_step][ca] and cell_transitions[self._reverse_dir( - self.predicted_dir[post_step][ca])] == 1 and tot_dist < potential_conflict: + self.predicted_dir[post_step][ca])] == 1 \ + and tot_dist < potential_conflict: # noqa: E125 potential_conflict = tot_dist if self.env.dones[ca] and tot_dist < potential_conflict: potential_conflict = tot_dist diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index a61ef02207174d04489b5311dc042b7c06db1412..62efbdc5bc0781c4c7482412dafd98710ed9d14e 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -271,8 +271,7 @@ class RailEnv(Environment): agent.malfunction_data['next_malfunction'] -= 1 # Only agents that have a positive rate for malfunctions and are not currently broken are considered - if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data[ - 'malfunction']: + if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data['malfunction']: # If counter has come to zero --> Agent has malfunction # set next malfunction time and duration of current malfunction diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 8f2c5231232c3cc17f50de4b89b88f1c9fdc5d60..40ec2e0df89a3de48bf0a2a4430de3e30fa556e5 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -683,8 +683,10 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 0, width - 1) tries = 0 - while (station_x, station_y) in train_stations or (station_x, station_y) == node_positions[ - trainstation_node] or rail_array[(station_x, station_y)] != 0: + while (station_x, station_y) in train_stations \ + or (station_x, station_y) == node_positions[trainstation_node] \ + or rail_array[(station_x, station_y)] != 0: # noqa: E125 + station_x = np.clip( node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius), 0, diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index 2ef6dab85fc6145bbfb57b25994903bbd861f65f..4843e0040d80b79de54e8ed57674a37884ef6809 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -131,34 +131,6 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> """ def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct: - def _path_exists(rail, start, direction, end): - # BFS - Check if a path exists between the 2 nodes - - visited = set() - stack = [(start, direction)] - while stack: - node = stack.pop() - if node[0][0] == end[0] and node[0][1] == end[1]: - return 1 - if node not in visited: - visited.add(node) - moves = rail.get_transitions(node[0][0], node[0][1], node[1]) - for move_index in range(4): - if moves[move_index]: - stack.append((get_new_position(node[0], move_index), - move_index)) - - # If cell is a dead-end, append previous node with reversed - # orientation! - nbits = 0 - tmp = rail.get_full_transitions(node[0][0], node[0][1]) - while tmp > 0: - nbits += (tmp & 1) - tmp = tmp >> 1 - if nbits == 1: - stack.append((node[0], (node[1] + 2) % 4)) - - return 0 valid_positions = [] for r in range(rail.height): @@ -167,14 +139,35 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> valid_positions.append((r, c)) if len(valid_positions) == 0: return [], [], [], [] + + if len(valid_positions) < num_agents: + warnings.warn("schedule_generators: len(valid_positions) < num_agents") + return [], [], [], [] + + agents_position_idx = [i for i in np.random.choice(len(valid_positions), num_agents, replace=False)] + agents_position = [valid_positions[agents_position_idx[i]] for i in range(num_agents)] + agents_target_idx = [i for i in np.random.choice(len(valid_positions), num_agents, replace=False)] + agents_target = [valid_positions[agents_target_idx[i]] for i in range(num_agents)] + update_agents = np.zeros(num_agents) + re_generate = True + cnt = 0 while re_generate: - agents_position = [ - valid_positions[i] for i in - np.random.choice(len(valid_positions), num_agents)] - agents_target = [ - valid_positions[i] for i in - np.random.choice(len(valid_positions), num_agents)] + cnt += 1 + if cnt > 1: + print("re_generate cnt={}".format(cnt)) + if cnt > 1000: + raise Exception("After 1000 re_generates still not success, giving up.") + # update position + for i in range(num_agents): + if update_agents[i] == 1: + x = np.setdiff1d(np.arange(len(valid_positions)), agents_position_idx) + agents_position_idx[i] = np.random.choice(x) + agents_position[i] = valid_positions[agents_position_idx[i]] + x = np.setdiff1d(np.arange(len(valid_positions)), agents_target_idx) + agents_target_idx[i] = np.random.choice(x) + agents_target[i] = valid_positions[agents_target_idx[i]] + update_agents = np.zeros(num_agents) # agents_direction must be a direction for which a solution is # guaranteed. @@ -192,12 +185,15 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> valid_starting_directions = [] for m in valid_movements: new_position = get_new_position(agents_position[i], m[1]) - if m[0] not in valid_starting_directions and _path_exists(rail, new_position, m[0], - agents_target[i]): + if m[0] not in valid_starting_directions and rail.check_path_exists(new_position, m[1], + agents_target[i]): valid_starting_directions.append(m[0]) if len(valid_starting_directions) == 0: + update_agents[i] = 1 + warnings.warn("reset position for agent[{}]: {} -> {}".format(i, agents_position[i], agents_target[i])) re_generate = True + break else: agents_direction[i] = valid_starting_directions[ np.random.choice(len(valid_starting_directions), 1)[0]] diff --git a/flatland/utils/simple_rail.py b/flatland/utils/simple_rail.py index c5fe4860783f242f21c97c55a9119d8918454a96..67bd93dd35c8f53ef3cdef23dbae0f0d785b9a64 100644 --- a/flatland/utils/simple_rail.py +++ b/flatland/utils/simple_rail.py @@ -81,6 +81,43 @@ def make_simple_rail2() -> Tuple[GridTransitionMap, np.array]: rail.grid = rail_map return rail, rail_map +def make_simple_rail_unconnected() -> Tuple[GridTransitionMap, np.array]: + # We instantiate a very simple rail network on a 7x10 grid: + # Note that that cells have invalid RailEnvTransitions! + # | + # | + # | + # _ _ _ _ _ _ _ _ _ _ + # / + # | + # | + # | + transitions = RailEnvTransitions() + cells = transitions.transition_list + empty = cells[0] + dead_end_from_south = cells[7] + dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) + dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180) + dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270) + vertical_straight = cells[1] + horizontal_straight = transitions.rotate_transition(vertical_straight, 90) + simple_switch_north_left = cells[2] + simple_switch_north_right = cells[10] + simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270) + simple_switch_east_west_south = transitions.rotate_transition(simple_switch_north_left, 270) + rail_map = np.array( + [[empty] * 3 + [dead_end_from_south] + [empty] * 6] + + [[empty] * 3 + [vertical_straight] + [empty] * 6] + + [[empty] * 3 + [dead_end_from_north] + [empty] * 6] + + [[dead_end_from_east] + [horizontal_straight] * 5 + [simple_switch_east_west_south] + + [horizontal_straight] * 2 + [dead_end_from_west]] + + [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 + + [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) + rail = GridTransitionMap(width=rail_map.shape[1], + height=rail_map.shape[0], transitions=transitions) + rail.grid = rail_map + return rail, rail_map + def make_invalid_simple_rail() -> Tuple[GridTransitionMap, np.array]: # We instantiate a very simple rail network on a 7x10 grid: diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py index a414231619cfa924c2d33776f9f140cf88280517..8812c847e61d81f6614f37d26489b4c17ea7fd14 100644 --- a/tests/test_flatland_core_transition_map.py +++ b/tests/test_flatland_core_transition_map.py @@ -1,6 +1,13 @@ from flatland.core.grid.grid4 import Grid4Transitions, Grid4TransitionsEnum from flatland.core.grid.grid8 import Grid8Transitions, Grid8TransitionsEnum from flatland.core.transition_map import GridTransitionMap +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import rail_from_grid_transition_map +from flatland.envs.schedule_generators import random_schedule_generator +from flatland.utils.rendertools import RenderTool +from flatland.utils.simple_rail import make_simple_rail, make_simple_rail_unconnected def test_grid4_get_transitions(): @@ -43,4 +50,111 @@ def test_grid8_set_transitions(): grid8_map.set_transition((0, 0, Grid8TransitionsEnum.NORTH), Grid8TransitionsEnum.NORTH, 0) assert grid8_map.get_transitions(0, 0, Grid8TransitionsEnum.NORTH) == (0, 0, 0, 0, 0, 0, 0, 0) -# TODO GridTransitionMap + +def check_path(env, rail, position, direction, target, expected, rendering=False): + agent = env.agents_static[0] + agent.position = position # south dead-end + agent.direction = direction # north + agent.target = target # east dead-end + agent.moving = True + # reset to set agents from agents_static + # env.reset(False, False) + if rendering: + renderer = RenderTool(env, gl="PILSVG") + renderer.render_env(show=True, show_observations=False) + input("Continue?") + assert rail.check_path_exists(agent.position, agent.direction, agent.target) == expected + + +def test_path_exists(rendering=False): + rail, rail_map = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) + + # reset to initialize agents_static + env.reset() + + check_path( + env, + rail, + (5, 6), # north of south dead-end + 0, # north + (3, 9), # east dead-end + True + ) + + check_path( + env, + rail, + (6, 6), # south dead-end + 2, # south + (3, 9), # east dead-end + True + ) + + check_path( + env, + rail, + (3, 0), # east dead-end + 3, # west + (0, 3), # north dead-end + True + ) + check_path( + env, + rail, + (5, 6), # east dead-end + 0, # west + (1, 3), # north dead-end + True) + + check_path( + env, + rail, + (1,3), # east dead-end + 2, # south + (3,3), # north dead-end + True + ) + + check_path( + env, + rail, + (1,3), # east dead-end + 0, # north + (3,3), # north dead-end + True + ) + + +def test_path_not_exists(rendering=False): + rail, rail_map = make_simple_rail_unconnected() + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) + + # reset to initialize agents_static + env.reset() + + check_path( + env, + rail, + (5, 6), # south dead-end + 0, # north + (0, 3), # north dead-end + False + ) + + if rendering: + renderer = RenderTool(env, gl="PILSVG") + renderer.render_env(show=True, show_observations=False) + input("Continue?")