diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 0bbac92e50b0224f765c77935669328150ab1348..f3bfef2874f8170af2386dc20c9a8990840dd84d 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -36,7 +36,7 @@ env = RailEnv(width=50, # Number of cities in map (where train stations are) seed=1, # Random seed grid_mode=False, - max_rails_between_cities=2, + max_rails_between_cities=3, max_rails_in_city=6, ), schedule_generator=sparse_schedule_generator(speed_ration_map), diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index 2a3db9df346faea5a8734e3cb4a78194133cb319..d65663d1a9ac63169cae10c35cb5e909fe32d338 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -561,11 +561,11 @@ class GridTransitionMap(TransitionMap): if number_of_incoming == 3: self.set_transitions(rcPos, 0) hole = np.argwhere(incoming_connections < 1)[0][0] - if direction > 0: + if direction >= 0: switch_type_idx = (direction - hole + 3) % 4 - if switch_type_idx == 2: - transition = simple_switch_west_south if switch_type_idx == 0: + transition = simple_switch_west_south + elif switch_type_idx == 2: transition = simple_switch_east_south else: transition = np.random.choice(three_way_transitions, 1) diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index 04903d3ad02ec41e9615c2bb488f24c01e163a1d..37fdcf5d5055e9b17310a06ddd3be13f14b8d1c7 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -9,7 +9,7 @@ import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4_astar import a_star -from flatland.core.grid.grid4_utils import get_direction, mirror, direction_to_point +from flatland.core.grid.grid4_utils import get_direction, mirror, direction_to_point, get_new_position from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance, IntVector2DArray from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d from flatland.core.transition_map import GridTransitionMap, RailEnvTransitions @@ -126,3 +126,33 @@ def connect_straight_line_in_grid_map(grid_map: GridTransitionMap, start: IntVec grid_map.grid[cell] = transition return path + + +def fix_inner_nodes(grid_map: GridTransitionMap, inner_node_pos: IntVector2D, rail_trans: RailEnvTransitions): + """ + Fix inner city nodes + :param grid_map: + :param start: + :param rail_trans: + :return: + """ + corner_directions = [] + for direction in range(4): + tmp_pos = get_new_position(inner_node_pos, direction) + if grid_map.grid[tmp_pos] > 0: + corner_directions.append(direction) + if len(corner_directions) == 2: + transition = 0 + transition = rail_trans.set_transition(transition, mirror(corner_directions[0]), corner_directions[1], 1) + transition = rail_trans.set_transition(transition, mirror(corner_directions[1]), corner_directions[0], 1) + grid_map.grid[inner_node_pos] = transition + tmp_pos = get_new_position(inner_node_pos, corner_directions[0]) + transition = grid_map.grid[tmp_pos] + transition = rail_trans.set_transition(transition, corner_directions[0], mirror(corner_directions[0]), 1) + grid_map.grid[tmp_pos] = transition + tmp_pos = get_new_position(inner_node_pos, corner_directions[1]) + transition = grid_map.grid[tmp_pos] + transition = rail_trans.set_transition(transition, corner_directions[1], mirror(corner_directions[1]), + 1) + grid_map.grid[tmp_pos] = transition + return diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index d32d1521c9fa98918a1744f6e5874c6ae88bbbcf..324e21b633b272e72c67e0a4b2b88aa024dcdaa4 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -12,7 +12,8 @@ from flatland.core.grid.grid_utils import distance_on_rail, IntVector2DArray, In Vec2dOperations from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap -from flatland.envs.grid4_generators_utils import connect_rail_in_grid_map, connect_straight_line_in_grid_map +from flatland.envs.grid4_generators_utils import connect_rail_in_grid_map, connect_straight_line_in_grid_map, \ + fix_inner_nodes RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]] RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct] @@ -698,22 +699,35 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ start_idx = int((nr_of_connection_points - number_of_out_rails) / 2) for direction in range(4): connection_slots = np.arange(nr_of_connection_points) - start_idx + inner_point_offset = np.abs(connection_slots) + np.clip(connection_slots, 0, 1) for connection_idx in range(connections_per_direction[direction]): if direction == 0: tmp_coordinates = ( + city_position[0] - city_radius + inner_point_offset[connection_idx], + city_position[1] + connection_slots[connection_idx]) + out_tmp_coordinates = ( city_position[0] - city_radius, city_position[1] + connection_slots[connection_idx]) if direction == 1: tmp_coordinates = ( + city_position[0] + connection_slots[connection_idx], + city_position[1] + city_radius - inner_point_offset[connection_idx]) + out_tmp_coordinates = ( city_position[0] + connection_slots[connection_idx], city_position[1] + city_radius) if direction == 2: tmp_coordinates = ( + city_position[0] + city_radius - inner_point_offset[connection_idx], + city_position[1] + connection_slots[connection_idx]) + out_tmp_coordinates = ( city_position[0] + city_radius, city_position[1] + connection_slots[connection_idx]) if direction == 3: tmp_coordinates = ( + city_position[0] + connection_slots[connection_idx], + city_position[1] - city_radius + inner_point_offset[connection_idx]) + out_tmp_coordinates = ( city_position[0] + connection_slots[connection_idx], city_position[1] - city_radius) connection_points_coordinates_inner[direction].append(tmp_coordinates) if connection_idx in range(start_idx, start_idx + number_of_out_rails): - connection_points_coordinates_outer[direction].append(tmp_coordinates) + connection_points_coordinates_outer[direction].append(out_tmp_coordinates) inner_connection_points.append(connection_points_coordinates_inner) outer_connection_points.append(connection_points_coordinates_outer) @@ -789,11 +803,9 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ :param grid_map: :return: Returns the cells of the through path which cannot be occupied by trainstations """ - through_path_cells: List[IntVector2DArray] = [[] for i in range(len(city_positions))] free_rails: List[List[List[IntVector2D]]] = [[] for i in range(len(city_positions))] for current_city in range(len(city_positions)): - all_outer_connection_points = [item for sublist in outer_connection_points[current_city] for item in - sublist] + # This part only works if we have keep same number of connection points for both directions # Also only works with two connection direction at each city for i in range(4): @@ -802,18 +814,29 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ break opposite_boarder = (boarder + 2) % 4 - boarder_one = inner_connection_points[current_city][boarder] - boarder_two = inner_connection_points[current_city][opposite_boarder] - - # Connect the ends of the tracks - connect_straight_line_in_grid_map(grid_map, boarder_one[0], boarder_one[-1], rail_trans) - connect_straight_line_in_grid_map(grid_map, boarder_two[0], boarder_two[-1], rail_trans) - + nr_of_connection_points = len(inner_connection_points[current_city][boarder]) + number_of_out_rails = len(outer_connection_points[current_city][boarder]) + start_idx = int((nr_of_connection_points - number_of_out_rails) / 2) # Connect parallel tracks - for track_id in range(len(inner_connection_points[current_city][boarder])): + for track_id in range(nr_of_connection_points): source = inner_connection_points[current_city][boarder][track_id] target = inner_connection_points[current_city][opposite_boarder][track_id] current_track = connect_straight_line_in_grid_map(grid_map, source, target, rail_trans) + + for track_id in range(nr_of_connection_points): + source = inner_connection_points[current_city][boarder][track_id] + target = inner_connection_points[current_city][opposite_boarder][track_id] + fix_inner_nodes( + grid_map, source, rail_trans) + fix_inner_nodes( + grid_map, target, rail_trans) + if start_idx <= track_id < start_idx + number_of_out_rails: + source_outer = outer_connection_points[current_city][boarder][track_id - start_idx] + target_outer = outer_connection_points[current_city][opposite_boarder][track_id - start_idx] + connect_straight_line_in_grid_map(grid_map, source, source_outer, rail_trans) + connect_straight_line_in_grid_map(grid_map, target, target_outer, rail_trans) + + free_rails[current_city].append(current_track) return free_rails @@ -870,18 +893,20 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ cells_to_fix = city_cells + inter_city_lines for cell in cells_to_fix: cell_valid = grid_map.cell_neighbours_valid(cell, True) - if grid_map.grid[cell] == int('1000010000100001', 2): - grid_map.fix_transitions(cell) + # cell_valid = grid_map.transitions.is_valid(cell) + # if grid_map.grid[cell] == int('1000010000100001', 2): + # grid_map.fix_transitions(cell) + # if bin(grid_map.grid[cell]).count("1") == 4: + # cell_valid = False + # print("fixing cell", cell, vector_field[cell]) if not cell_valid: rails_to_fix[3 * rails_to_fix_cnt] = cell[0] rails_to_fix[3 * rails_to_fix_cnt + 1] = cell[1] - rails_to_fix[3 * rails_to_fix_cnt + 2] = vector_field[(cell[0], cell[1])] + rails_to_fix[3 * rails_to_fix_cnt + 2] = vector_field[cell] rails_to_fix_cnt += 1 - # Fix all other cells for cell in range(rails_to_fix_cnt): - grid_map.fix_transitions((rails_to_fix[3 * cell], rails_to_fix[3 * cell + 1]), - rails_to_fix[3 * rails_to_fix_cnt + 2]) + grid_map.fix_transitions((rails_to_fix[3 * cell], rails_to_fix[3 * cell + 1]), rails_to_fix[3 * cell + 2]) def _closest_neighbour_in_grid4_directions(current_city_idx: int, city_positions: IntVector2DArray) -> List[int]: """ diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py index 699f1e91cce54d6486b5736a4884a4a5d804f1f6..08b8bdc9789df1a178b997860521630c468373d5 100644 --- a/tests/test_flatland_envs_sparse_rail_generator.py +++ b/tests/test_flatland_envs_sparse_rail_generator.py @@ -570,8 +570,8 @@ def test_sparse_rail_generator(): for a in range(env.get_num_agents()): s0 = Vec2d.get_manhattan_distance(env.agents[a].initial_position, (0, 0)) s1 = Vec2d.get_chebyshev_distance(env.agents[a].initial_position, (0, 0)) - assert s0 == 61, "actual={}".format(s0) - assert s1 == 42, "actual={}".format(s1) + assert s0 == 39, "actual={}".format(s0) + assert s1 == 27, "actual={}".format(s1) def test_sparse_rail_generator_deterministic(): diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index ace603422006d7f6229e854e24b6f244b7fe183c..7cf6812654a8d26fd19dd174f07bc0901785463b 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -158,7 +158,7 @@ def test_malfunction_process_statistically(): env.step(action_dict) # check that generation of malfunctions works as expected - assert nb_malfunction == 128, "nb_malfunction={}".format(nb_malfunction) + assert nb_malfunction == 152, "nb_malfunction={}".format(nb_malfunction) def test_initial_malfunction():