diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 621f2c797be255cf8530d5d079932ed5daee0db4..92542c9919c1578793a92860bb8f9ccc8b6c1e56 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -30,9 +30,9 @@ speed_ration_map = {1.: 0.25, # Fast passenger train env = RailEnv(width=50, height=50, - rail_generator=sparse_rail_generator(num_cities=8, # Number of cities in map (where train stations are) + rail_generator=sparse_rail_generator(num_cities=12, # Number of cities in map (where train stations are) seed=1, # Random seed - grid_mode=False, + grid_mode=True, max_inter_city_rails=2, max_tracks_in_city=4, ), diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 337b301a1c9005adaff311dd442310386d287b21..94e6ddb6ec4f1575c9ab5a5a09e94576d4cef38d 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -692,7 +692,10 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, # Store the directions to these neighbours and orient city to face closest neighbour connection_sides_idx = [] idx = 1 - current_closest_direction = direction_to_point(node_position, node_positions[closest_neighb_idx[idx]]) + if grid_mode: + current_closest_direction = np.random.randint(4) + else: + current_closest_direction = direction_to_point(node_position, node_positions[closest_neighb_idx[idx]]) connection_sides_idx.append(current_closest_direction) connection_sides_idx.append((current_closest_direction + 2) % 4) city_orientations.append(current_closest_direction)