From 65844b60aabe2040b267a50848e2a91e3a62bdcc Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Sat, 28 Sep 2019 16:11:25 -0400 Subject: [PATCH] extended size of rails to get more possible train station positions --- examples/flatland_2_0_example.py | 2 +- flatland/envs/rail_generators.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 4a614b2f..2d393fd4 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -34,7 +34,7 @@ env = RailEnv(width=50, seed=0, # Random seed grid_mode=False, max_inter_city_rails=2, - max_tracks_in_city=8, + max_tracks_in_city=4, ), schedule_generator=sparse_schedule_generator(), number_of_agents=10, diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 74d17b83..725943a4 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -551,7 +551,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, rail_array = grid_map.grid rail_array.fill(0) np.random.seed(seed + num_resets) - node_radius = int(np.ceil(max_tracks_in_city / 2)) + node_radius = int(np.ceil(max_tracks_in_city / 2)) + 2 max_inter_city_rails_allowed = max_inter_city_rails if max_inter_city_rails_allowed > max_tracks_in_city: max_inter_city_rails_allowed = max_tracks_in_city @@ -616,8 +616,8 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, tries = 0 while to_close: - x_tmp = node_radius + 1 + np.random.randint(height - 2 * node_radius) - y_tmp = node_radius + 1 + np.random.randint(width - 2 * node_radius) + x_tmp = node_radius + 1 + np.random.randint(height - 2 * node_radius - 1) + y_tmp = node_radius + 1 + np.random.randint(width - 2 * node_radius - 1) to_close = False # Check distance to nodes for node_pos in node_positions: @@ -806,6 +806,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, if 1 <= nbits <= 2: built_num_trainstations += 1 train_stations[current_city].append(possible_location) + print(built_num_trainstations) return train_stations, built_num_trainstations def _generate_start_target_pairs(num_agents, nb_nodes, train_stations): -- GitLab