diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 1520e598de6a351f87367ccba2fc853a86ab64e7..2dc6efa467dcd3a8d1856a968d43b5b4f4dc4b09 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -34,7 +34,7 @@ env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(num_cities=9, # Number of cities in map (where train stations are) num_intersections=0, # Number of intersections (no start / target) - num_trainstations=50, # Number of possible start/targets on map + num_trainstations=10, # Number of possible start/targets on map min_node_dist=3, # Minimal distance of nodes node_radius=4, # Proximity of stations to city center num_neighb=4, # Number of connections to other cities/intersections diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index e4a7e292b3c4358149e04ddfc1ed2733c8b68b3c..937b8c88feb086627a0fa14d4c73d98e7bac8d9f 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -667,8 +667,8 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 if distance_on_rail(tmp_out_connection_point, neighb_connection_point) < center_distance: i += 1 connect_nodes(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point) - boarder_connections.add(tmp_out_connection_point) - boarder_connections.add(neighb_connection_point) + boarder_connections.add((tmp_out_connection_point, current_node)) + boarder_connections.add((neighb_connection_point, neighb)) node_stack.pop(0) @@ -722,10 +722,11 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 connection_points[trainstation_node][corner_node_idx], (station_x, station_y)) if len(connection) != 0: - if connection_points[trainstation_node][corner_node_idx] in boarder_connections: - boarder_connections.remove(connection_points[trainstation_node][corner_node_idx]) + if (connection_points[trainstation_node][corner_node_idx], + trainstation_node) in boarder_connections: + boarder_connections.remove( + (connection_points[trainstation_node][corner_node_idx], trainstation_node)) - grid_map.fix_transitions((station_x, station_y)) # Check if connection was made if len(connection) == 0: if len(train_stations[trainstation_node]) > 0: @@ -737,8 +738,30 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 num_agents = built_num_trainstation warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents") + # Connect all disjunct parts of the network + + if len(boarder_connections) > 0: + to_be_deleted = [] + for disjunct_node in boarder_connections: + print(disjunct_node) + conn = connect_nodes(rail_trans, grid_map, + disjunct_node[0], + train_stations[disjunct_node[1]][0]) + if len(conn) > 0: + to_be_deleted.append(disjunct_node) + + for tbd in to_be_deleted: + boarder_connections.remove(tbd) + print(boarder_connections) # Fix all nodes with illegal transition maps + flat_trainstation_list = [item for sublist in train_stations for item in sublist] + for cell_to_fix in flat_trainstation_list: + grid_map.fix_transitions(cell_to_fix) + + grid_map.fix_transitions((station_x, station_y)) + flat_list = [item for sublist in connection_points for item in sublist] + for cell_to_fix in flat_list: grid_map.fix_transitions(cell_to_fix)