diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 91e3c6ce7b189e1ba0e1626b819974cae8bf7a79..5c92fd1cde5024b330bc10050203d74cd0d74af5 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -31,7 +31,7 @@ speed_ration_map = {1.: 0.25, # Fast passenger train env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(num_cities=20, # Number of cities in map (where train stations are) - seed=0, # Random seed + seed=1, # Random seed grid_mode=False, max_inter_city_rails=2, max_tracks_in_city=4, diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index 166094aa1f886b606e2faf91917903494f48f7c3..d0e75e27418cfb179ae907f4a3036d6d124618f4 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -90,7 +90,7 @@ def connect_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, def connect_cities(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, - start: IntVector2D, end: IntVector2D, forbidden_cells, + start: IntVector2D, end: IntVector2D, forbidden_cells=None, a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray: return connect_basic_operation(rail_trans, grid_map, start, end, False, False, False, a_star_distance_function, forbidden_cells) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index c7d97803d0262e5de3872b49578501187bbc2e3f..d6e14459215fbcf2ea8086b3ef82833384e45b4a 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -783,23 +783,33 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, # This part only works if we have keep same number of connection points for both directions # Also only works with two connection direction at each city - for boarder in range(4): - opposite_boarder = (boarder + 2) % 4 - for track_id in range(len(inner_connection_points[current_city][boarder])): - if track_id % 2 == 0: - source = inner_connection_points[current_city][boarder][track_id] - for target in inner_connection_points[current_city][opposite_boarder]: - current_track = connect_cities(rail_trans, grid_map, source, target, city_boarder) - if target in all_outer_connection_points and source in \ - all_outer_connection_points and len(through_path_cells[current_city]) < 1: - through_path_cells[current_city].extend(current_track) - else: - source = inner_connection_points[current_city][opposite_boarder][track_id] - for target in inner_connection_points[current_city][boarder]: - current_track = connect_cities(rail_trans, grid_map, source, target, city_boarder) - if target in all_outer_connection_points and source in \ - all_outer_connection_points and len(through_path_cells[current_city]) < 1: - through_path_cells[current_city].extend(current_track) + for i in range(4): + if len(inner_connection_points[current_city][i]) > 0: + boarder = i + break + + opposite_boarder = (boarder + 2) % 4 + boarder_one = inner_connection_points[current_city][boarder] + boarder_two = inner_connection_points[current_city][opposite_boarder] + connect_cities(rail_trans, grid_map, boarder_one[0], boarder_one[-1]) + connect_cities(rail_trans, grid_map, boarder_two[0], boarder_two[-1]) + + for track_id in range(len(inner_connection_points[current_city][boarder])): + if track_id % 2 == 0: + source = inner_connection_points[current_city][boarder][track_id] + target = inner_connection_points[current_city][opposite_boarder][track_id] + current_track = connect_cities(rail_trans, grid_map, source, target, city_boarder) + if target in all_outer_connection_points and source in \ + all_outer_connection_points and len(through_path_cells[current_city]) < 1: + through_path_cells[current_city].extend(current_track) + else: + source = inner_connection_points[current_city][opposite_boarder][track_id] + target = inner_connection_points[current_city][boarder][track_id] + + current_track = connect_cities(rail_trans, grid_map, source, target, city_boarder) + if target in all_outer_connection_points and source in \ + all_outer_connection_points and len(through_path_cells[current_city]) < 1: + through_path_cells[current_city].extend(current_track) return through_path_cells