diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 2d393fd43db84042760b2b1bf9f0455b9d53d7b2..961e36a88a0f992b3f1997628c7741e94bd597b3 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -28,16 +28,16 @@ speed_ration_map = {1.: 0.25, # Fast passenger train 1. / 3.: 0.25, # Slow commuter train 1. / 4.: 0.25} # Slow freight train -env = RailEnv(width=50, - height=50, - rail_generator=sparse_rail_generator(num_cities=12, # Number of cities in map (where train stations are) +env = RailEnv(width=100, + height=100, + rail_generator=sparse_rail_generator(num_cities=20, # Number of cities in map (where train stations are) seed=0, # Random seed grid_mode=False, max_inter_city_rails=2, max_tracks_in_city=4, ), schedule_generator=sparse_schedule_generator(), - number_of_agents=10, + number_of_agents=50, stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=GlobalObsForRailEnv()) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 83be49b140720258b2a08400928fd24cbc95d0da..7ea9371527952eb28fdd2656f641cc593d62f36d 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -551,7 +551,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, rail_array = grid_map.grid rail_array.fill(0) np.random.seed(seed + num_resets) - node_radius = int(np.ceil(max_tracks_in_city / 2)) + 2 + node_radius = int(np.ceil((max_tracks_in_city + 2) / 2.0)) + 2 max_inter_city_rails_allowed = max_inter_city_rails if max_inter_city_rails_allowed > max_tracks_in_city: max_inter_city_rails_allowed = max_tracks_in_city @@ -768,13 +768,15 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, for current_city in range(len(node_positions)): all_outer_connection_points = [item for sublist in outer_connection_points[current_city] for item in sublist] + city_boarder = _city_boarder(node_positions[current_city], node_radius) - for boarder in range(4): + random_boarders = np.random.choice(np.arange(4), 4, False) + # TODO: Only look at the relevant boarders (Only two at the moment) + for boarder in random_boarders: for source in inner_connection_points[current_city][boarder]: - for other_boarder in range(4): + for other_boarder in random_boarders: if boarder != other_boarder and len(inner_connection_points[current_city][other_boarder]) > 0: for target in inner_connection_points[current_city][other_boarder]: - city_boarder = _city_boarder(node_positions[current_city], node_radius) current_track = connect_cities(rail_trans, grid_map, source, target, city_boarder) if target in all_outer_connection_points and source in \ all_outer_connection_points and len(through_path_cells[current_city]) < 1: