diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 5af4a079b1b6d210a395896df5d577c1c7a16267..a16fb6018a6354665a44c1b44cafd6975bb4e680 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -572,7 +572,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 nodes_per_col = int(np.ceil(nb_nodes / nodes_per_row)) x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int) y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int) - city_idx = np.random.choice(np.arange(nb_nodes), num_cities) + city_idx = np.random.choice(np.arange(nb_nodes), num_cities, False) node_positions = _generate_node_positions_grid_mode(city_idx, city_positions, intersection_positions, nb_nodes, @@ -666,7 +666,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 0, width - 1) tries = 0 - while (station_x, station_y) in train_stations \ + while (station_x, station_y) in train_stations[trainstation_node] \ or (station_x, station_y) == node_positions[trainstation_node] \ or rail_array[(station_x, station_y)] != 0: # noqa: E125