diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 45e823b78374d8a0b4ac130a8c707e874177ddda..1520e598de6a351f87367ccba2fc853a86ab64e7 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -34,10 +34,10 @@ env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(num_cities=9, # Number of cities in map (where train stations are) num_intersections=0, # Number of intersections (no start / target) - num_trainstations=100, # Number of possible start/targets on map + num_trainstations=50, # Number of possible start/targets on map min_node_dist=3, # Minimal distance of nodes node_radius=4, # Proximity of stations to city center - num_neighb=3, # Number of connections to other cities/intersections + num_neighb=4, # Number of connections to other cities/intersections seed=15, # Random seed grid_mode=True, enhance_intersection=False diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index a780cfc9b260672e34d5e9c19d4b5cd57bd4db5e..e4a7e292b3c4358149e04ddfc1ed2733c8b68b3c 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -606,6 +606,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 allowed_connections = len(connection_points[current_node]) first_node = True i = 0 + boarder_connections = set() while len(node_stack) > 0: current_node = node_stack[0] delete_idx = np.where(available_nodes_full == current_node) @@ -666,6 +667,8 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 if distance_on_rail(tmp_out_connection_point, neighb_connection_point) < center_distance: i += 1 connect_nodes(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point) + boarder_connections.add(tmp_out_connection_point) + boarder_connections.add(neighb_connection_point) node_stack.pop(0) @@ -712,11 +715,16 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 # Connect train station to random nodes - rand_corner_nodes = np.random.choice(range(len(connection_points[trainstation_node])), 3, replace=False) + rand_corner_nodes = np.random.choice(range(len(connection_points[trainstation_node])), 2, replace=False) + for corner_node_idx in rand_corner_nodes: connection = connect_nodes(rail_trans, grid_map, connection_points[trainstation_node][corner_node_idx], (station_x, station_y)) + if len(connection) != 0: + if connection_points[trainstation_node][corner_node_idx] in boarder_connections: + boarder_connections.remove(connection_points[trainstation_node][corner_node_idx]) + grid_map.fix_transitions((station_x, station_y)) # Check if connection was made if len(connection) == 0: @@ -724,7 +732,6 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 train_stations[trainstation_node].pop(-1) else: built_num_trainstation += 1 - # Adjust the number of agents if you could not build enough trainstations if num_agents > built_num_trainstation: num_agents = built_num_trainstation