diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 98ae3995b4acd57235db9d613c25c42f02a5f2ff..3938a01d1235921fab66c765edb227c027331b4e 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -684,8 +684,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ number_of_out_rails = np.random.randint(1, min(rails_between_cities, nr_of_connection_points) + 1) start_idx = int((nr_of_connection_points - number_of_out_rails) / 2) for direction in range(4): - connection_slots = np.arange(connections_per_direction[direction]) - int( - connections_per_direction[direction] / 2) + connection_slots = nr_of_connection_points - start_idx for connection_idx in range(connections_per_direction[direction]): if direction == 0: tmp_coordinates = ( @@ -700,7 +699,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ tmp_coordinates = ( city_position[0] + connection_slots[connection_idx], city_position[1] - city_radius) connection_points_coordinates_inner[direction].append(tmp_coordinates) - if connection_idx in range(start_idx, start_idx + number_of_out_rails + 1): + if connection_idx in range(start_idx, start_idx + number_of_out_rails): connection_points_coordinates_outer[direction].append(tmp_coordinates) inner_connection_points.append(connection_points_coordinates_inner)