From a481cdaa52987ae6e953e27de9fcd28463c36ba3 Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Mon, 30 Sep 2019 14:48:50 -0400 Subject: [PATCH] updated inner city construciton algorithm --- examples/flatland_2_0_example.py | 2 +- flatland/envs/grid4_generators_utils.py | 2 +- flatland/envs/rail_generators.py | 44 +++++++++++++++---------- 3 files changed, 29 insertions(+), 19 deletions(-) diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 91e3c6ce..5c92fd1c 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -31,7 +31,7 @@ speed_ration_map = {1.: 0.25, # Fast passenger train env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(num_cities=20, # Number of cities in map (where train stations are) - seed=0, # Random seed + seed=1, # Random seed grid_mode=False, max_inter_city_rails=2, max_tracks_in_city=4, diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index 166094aa..d0e75e27 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -90,7 +90,7 @@ def connect_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, def connect_cities(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, - start: IntVector2D, end: IntVector2D, forbidden_cells, + start: IntVector2D, end: IntVector2D, forbidden_cells=None, a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray: return connect_basic_operation(rail_trans, grid_map, start, end, False, False, False, a_star_distance_function, forbidden_cells) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index c7d97803..d6e14459 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -783,23 +783,33 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, # This part only works if we have keep same number of connection points for both directions # Also only works with two connection direction at each city - for boarder in range(4): - opposite_boarder = (boarder + 2) % 4 - for track_id in range(len(inner_connection_points[current_city][boarder])): - if track_id % 2 == 0: - source = inner_connection_points[current_city][boarder][track_id] - for target in inner_connection_points[current_city][opposite_boarder]: - current_track = connect_cities(rail_trans, grid_map, source, target, city_boarder) - if target in all_outer_connection_points and source in \ - all_outer_connection_points and len(through_path_cells[current_city]) < 1: - through_path_cells[current_city].extend(current_track) - else: - source = inner_connection_points[current_city][opposite_boarder][track_id] - for target in inner_connection_points[current_city][boarder]: - current_track = connect_cities(rail_trans, grid_map, source, target, city_boarder) - if target in all_outer_connection_points and source in \ - all_outer_connection_points and len(through_path_cells[current_city]) < 1: - through_path_cells[current_city].extend(current_track) + for i in range(4): + if len(inner_connection_points[current_city][i]) > 0: + boarder = i + break + + opposite_boarder = (boarder + 2) % 4 + boarder_one = inner_connection_points[current_city][boarder] + boarder_two = inner_connection_points[current_city][opposite_boarder] + connect_cities(rail_trans, grid_map, boarder_one[0], boarder_one[-1]) + connect_cities(rail_trans, grid_map, boarder_two[0], boarder_two[-1]) + + for track_id in range(len(inner_connection_points[current_city][boarder])): + if track_id % 2 == 0: + source = inner_connection_points[current_city][boarder][track_id] + target = inner_connection_points[current_city][opposite_boarder][track_id] + current_track = connect_cities(rail_trans, grid_map, source, target, city_boarder) + if target in all_outer_connection_points and source in \ + all_outer_connection_points and len(through_path_cells[current_city]) < 1: + through_path_cells[current_city].extend(current_track) + else: + source = inner_connection_points[current_city][opposite_boarder][track_id] + target = inner_connection_points[current_city][boarder][track_id] + + current_track = connect_cities(rail_trans, grid_map, source, target, city_boarder) + if target in all_outer_connection_points and source in \ + all_outer_connection_points and len(through_path_cells[current_city]) < 1: + through_path_cells[current_city].extend(current_track) return through_path_cells -- GitLab