From b66688da63765dadabb7d9027c4cb8f87ed11928 Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Mon, 30 Sep 2019 14:20:25 -0400 Subject: [PATCH] updated limits for city center positioning --- examples/flatland_2_0_example.py | 4 ++-- flatland/envs/rail_generators.py | 33 +++++++++++++++++--------------- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 961e36a8..91e3c6ce 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -28,8 +28,8 @@ speed_ration_map = {1.: 0.25, # Fast passenger train 1. / 3.: 0.25, # Slow commuter train 1. / 4.: 0.25} # Slow freight train -env = RailEnv(width=100, - height=100, +env = RailEnv(width=50, + height=50, rail_generator=sparse_rail_generator(num_cities=20, # Number of cities in map (where train stations are) seed=0, # Random seed grid_mode=False, diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 7ea93715..31eef79a 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -616,8 +616,8 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, tries = 0 while to_close: - x_tmp = node_radius + 1 + np.random.randint(height - 2 * node_radius - 1) - y_tmp = node_radius + 1 + np.random.randint(width - 2 * node_radius - 1) + x_tmp = node_radius + 1 + np.random.randint(height - 2 * (node_radius - 1)) + y_tmp = node_radius + 1 + np.random.randint(width - 2 * (node_radius - 1)) to_close = False # Check distance to nodes for node_pos in node_positions: @@ -770,19 +770,22 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, sublist] city_boarder = _city_boarder(node_positions[current_city], node_radius) - random_boarders = np.random.choice(np.arange(4), 4, False) - # TODO: Only look at the relevant boarders (Only two at the moment) - for boarder in random_boarders: - for source in inner_connection_points[current_city][boarder]: - for other_boarder in random_boarders: - if boarder != other_boarder and len(inner_connection_points[current_city][other_boarder]) > 0: - for target in inner_connection_points[current_city][other_boarder]: - current_track = connect_cities(rail_trans, grid_map, source, target, city_boarder) - if target in all_outer_connection_points and source in \ - all_outer_connection_points and len(through_path_cells[current_city]) < 1: - through_path_cells[current_city].extend(current_track) - else: - continue + # This part only works if we have keep same number of connection points for both directions + # Also only works with two connection direction at each city + for boarder in range(4): + opposite_boarder = (boarder + 2) % 4 + for track_id in range(len(inner_connection_points[current_city][boarder])): + if track_id % 2 == 0: + source = inner_connection_points[current_city][boarder][track_id] + for target in inner_connection_points[current_city][opposite_boarder]: + current_track = connect_cities(rail_trans, grid_map, source, target, city_boarder) + if target in all_outer_connection_points and source in \ + all_outer_connection_points and len(through_path_cells[current_city]) < 1: + through_path_cells[current_city].extend(current_track) + else: + source = inner_connection_points[current_city][opposite_boarder][track_id] + for target in inner_connection_points[current_city][boarder]: + current_track = connect_cities(rail_trans, grid_map, source, target, city_boarder) return through_path_cells -- GitLab