diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 898f6a9906e998e70ac41f0833e57a4df2fbcac1..ca9346fc71a8d5485d7b71098dc91558d09bd929 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -32,16 +32,16 @@ speed_ration_map = {1.: 0.25, # Fast passenger train env = RailEnv(width=50, height=50, - rail_generator=sparse_rail_generator(num_cities=5, # Number of cities in map (where train stations are) - num_intersections=4, # Number of intersections (no start / target) + rail_generator=sparse_rail_generator(num_cities=9, # Number of cities in map (where train stations are) num_trainstations=100, # Number of possible start/targets on map min_node_dist=10, # Minimal distance of nodes node_radius=4, # Proximity of stations to city center num_neighb=2, # Number of connections to other cities/intersections seed=15, # Random seed grid_mode=True, - nr_inter_connections=2, - max_nr_connection_points=12 + nr_parallel_tracks=2, + connectin_points_per_side=3, + max_nr_connection_directions=2, ), schedule_generator=sparse_schedule_generator(), number_of_agents=50, diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 147b1bd14a5943f87f4d78f37e9ea7e1722c9087..117862a71593f8483928323574d8977898cfdc9c 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -528,8 +528,9 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGener return generator -def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2, min_node_dist=20, node_radius=2, - num_neighb=3, nr_inter_connections=2, grid_mode=False, max_nr_connection_points=4, +def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, node_radius=2, + num_neighb=3, nr_parallel_tracks=2, grid_mode=False, connectin_points_per_side=4, + max_nr_connection_directions=2, seed=0) -> RailGenerator: """ This is a level generator which generates complex sparse rail configurations @@ -566,7 +567,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 # Evenly distribute cities and intersections node_positions: List[Any] = None - nb_nodes = num_cities + num_intersections + nb_nodes = num_cities if grid_mode: nodes_ratio = height / width nodes_per_row = int(np.ceil(np.sqrt(nb_nodes * nodes_ratio))) @@ -591,14 +592,14 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 # reduce nb_nodes, _num_cities, _num_intersections if less were generated in not_grid_mode nb_nodes = len(node_positions) _num_cities = len(city_positions) - _num_intersections = len(intersection_positions) # Chose node connection # Set up list of available nodes to connect to available_nodes = np.arange(nb_nodes) # Set up connection points for all cities - connection_points = _generate_node_connection_points(node_positions, node_radius, max_nr_connection_points) + connection_points = _generate_node_connection_points(node_positions, node_radius, connectin_points_per_side, + max_nr_connection_directions) # Start at some node current_node = np.random.randint(len(available_nodes)) @@ -639,8 +640,8 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 tmp_dist_to_node = distance_on_rail(tmp_out_connection_point, node_positions[neighb]) connection_distances.append(tmp_dist_to_node) possible_connection_points = argsort(connection_distances) - for sort_idx in possible_connection_points[:nr_inter_connections]: - # Find closes connection point + for sort_idx in possible_connection_points[:nr_parallel_tracks]: + # Find closest connection point tmp_out_connection_point = connection_points[current_node][sort_idx] min_connection_dist = np.inf for tmp_in_connection_point in connection_points[neighb]: @@ -705,7 +706,6 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 connection_points[trainstation_node][corner_node_idx], (station_x, station_y)) if len(connection) != 0: - if (connection_points[trainstation_node][corner_node_idx], trainstation_node) in boarder_connections: boarder_connections.remove( @@ -748,6 +748,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 for tbd in to_be_deleted: boarder_connections.remove(tbd) print(boarder_connections) + # Fix all nodes with illegal transition maps empty_to_fix = [] rails_to_fix = [] @@ -866,24 +867,28 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 node_positions = city_positions + intersection_positions return node_positions - def _generate_node_connection_points(node_positions, node_size, max_nr_connection_points=2): + def _generate_node_connection_points(node_positions, node_size, max_nr_connection_points=2, + max_nr_connection_directions=2): connection_points = [] for node_position in node_positions: - n_connection_points = max_nr_connection_points # np.random.randint(1, max_nr_connection_points) - connection_per_direction = n_connection_points // 4 - connection_point_vector = [connection_per_direction, connection_per_direction, connection_per_direction, - n_connection_points - 3 * connection_per_direction] + connection_sides_idx = np.sort( + np.random.choice(np.arange(4), size=max_nr_connection_directions, replace=False)) + connections_per_direction = np.zeros(4, dtype=int) + + # set the number of connection points for each direction + for idx in connection_sides_idx: + connections_per_direction[idx] = max_nr_connection_points connection_points_coordinates = [] random_connection_slots = False for direction in range(4): if random_connection_slots: connection_slots = np.random.choice(np.arange(-node_size, node_size), - size=connection_point_vector[direction], + size=connections_per_direction[direction], replace=False) else: - connection_slots = np.arange(connection_point_vector[direction]) - int( - connection_point_vector[direction] / 2) - for connection_idx in range(connection_point_vector[direction]): + connection_slots = np.arange(connections_per_direction[direction]) - int( + connections_per_direction[direction] / 2) + for connection_idx in range(connections_per_direction[direction]): if direction == 0: connection_points_coordinates.append( (node_position[0] - node_size, node_position[1] + connection_slots[connection_idx]))