diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 2dc6efa467dcd3a8d1856a968d43b5b4f4dc4b09..0d86561f0b2d3317f88dc25ef8d8fb824ade74d1 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -34,7 +34,7 @@ env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(num_cities=9, # Number of cities in map (where train stations are) num_intersections=0, # Number of intersections (no start / target) - num_trainstations=10, # Number of possible start/targets on map + num_trainstations=15, # Number of possible start/targets on map min_node_dist=3, # Minimal distance of nodes node_radius=4, # Proximity of stations to city center num_neighb=4, # Number of connections to other cities/intersections diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 937b8c88feb086627a0fa14d4c73d98e7bac8d9f..62aa8fb7aa276f15e69a749aec0a4159d0978740 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -743,12 +743,22 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 if len(boarder_connections) > 0: to_be_deleted = [] for disjunct_node in boarder_connections: - print(disjunct_node) - conn = connect_nodes(rail_trans, grid_map, - disjunct_node[0], - train_stations[disjunct_node[1]][0]) + if len(train_stations[disjunct_node[1]]) > 0: + conn = connect_nodes(rail_trans, grid_map, + disjunct_node[0], + train_stations[disjunct_node[1]][-1]) + else: + conn = connect_nodes(rail_trans, grid_map, + disjunct_node[0], + node_positions[disjunct_node[1]]) if len(conn) > 0: to_be_deleted.append(disjunct_node) + else: + conn = connect_nodes(rail_trans, grid_map, + disjunct_node[0], + node_positions[disjunct_node[1]]) + if len(conn) > 0: + to_be_deleted.append(disjunct_node) for tbd in to_be_deleted: boarder_connections.remove(tbd) @@ -758,8 +768,6 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 for cell_to_fix in flat_trainstation_list: grid_map.fix_transitions(cell_to_fix) - grid_map.fix_transitions((station_x, station_y)) - flat_list = [item for sublist in connection_points for item in sublist] for cell_to_fix in flat_list: