diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 2a6d90b02de79ca40c09bdcd0c80888386b4c8f8..3bc31626fd017e0e8ad0b33165c3cc90cdd6b8d0 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -37,7 +37,7 @@ env = RailEnv(width=50, max_tracks_in_city=4, ), schedule_generator=sparse_schedule_generator(), - number_of_agents=50, + number_of_agents=15, stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=GlobalObsForRailEnv()) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 08982fc6e9786d980d95e820a22d472149d3b632..0bdf5165ac8ccecbd3b6f373f07fc1bef5895442 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -4,6 +4,7 @@ import warnings from typing import Callable, Tuple, Optional, Dict, List, Any import msgpack +import networkx as nx import numpy as np from flatland.core.grid.grid4_utils import get_direction, mirror @@ -544,6 +545,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, :param seed: Random seed to initiate rail :return: generator """ + G = nx.Graph() def generator(width, height, num_agents, num_resets=0) -> RailGeneratorProduct: @@ -552,10 +554,19 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, rail_array = grid_map.grid rail_array.fill(0) np.random.seed(seed + num_resets) + + # Graph to be able to create correct start/end pairs for schedule + + + node_radius = int(np.ceil((max_tracks_in_city + 2) / 2.0)) + 1 + if 3 > max_tracks_in_city: + rail_in_city = 3 + else: + rail_in_city = 3 max_inter_city_rails_allowed = max_inter_city_rails - if max_inter_city_rails_allowed > max_tracks_in_city: - max_inter_city_rails_allowed = max_tracks_in_city + if max_inter_city_rails_allowed > rail_in_city: + max_inter_city_rails_allowed = rail_in_city # Generate a set of nodes for the sparse network # Try to connect cities to nodes first city_positions = [] @@ -577,7 +588,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, node_connection_time = time.time() inner_connection_points, outer_connection_points, connection_info, city_orientations = _generate_node_connection_points( node_positions, node_radius, max_inter_city_rails_allowed, - max_tracks_in_city) + rail_in_city) print("Connection points", time.time() - node_connection_time) # Connect the cities through the connection points @@ -651,7 +662,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, len(node_positions), tries, nb_nodes)) break - + G.add_node(node_idx) return node_positions, city_cells def _generate_node_positions_grid_mode(nb_nodes, node_radius, height, width): @@ -667,6 +678,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, y_tmp = y_positions[node_idx // nodes_per_row] node_positions.append((x_tmp, y_tmp)) city_cells.extend(_city_cells(node_positions[-1], node_radius)) + G.add_node(node_idx) return node_positions, city_cells def _generate_node_connection_points(node_positions, node_size, max_inter_city_rails_allowed, tracks_in_city=2): @@ -691,7 +703,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, city_orientations.append(current_closest_direction) # set the number of tracks within a city, at least 2 tracks per city connections_per_direction = np.zeros(4, dtype=int) - nr_of_connection_points = np.random.randint(2, tracks_in_city + 1) + nr_of_connection_points = np.random.randint(3, tracks_in_city + 1) for idx in connection_sides_idx: connections_per_direction[idx] = nr_of_connection_points connection_points_coordinates_inner = [[] for i in range(4)] @@ -735,6 +747,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, :return: """ all_paths = [] + for current_node in np.arange(len(node_positions)): direction = 0 connected_to_city = [] @@ -767,6 +780,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, neighb_connection_point = tmp_in_connection_point new_line = connect_cities(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point, city_cells) + G.add_edge(current_node, neighb_idx) all_paths.extend(new_line) direction += 1 @@ -899,6 +913,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, node_available_start[start_node] -= 1 node_available_target[target_node] -= 1 agent_start_targets_nodes.append((start_node, target_node)) + print(agent_idx, "has connection", nx.astar_path(G, start_node, target_node)) else: num_agents -= 1 return agent_start_targets_nodes, num_agents diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index eb56cc56c434a6b53aff08b98a780a5195806319..86a41af0ec5ef83ec7d306b25a420d2cac32f31a 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -106,10 +106,14 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> # Orient the agent correctly if current_track_nbr % 2 != 0: - agents_direction.append(city_orientations[current_start_node]) + current_orientation = city_orientations[current_start_node] + agents_direction.append(current_orientation) else: - agents_direction.append((city_orientations[current_start_node] + 2) % 2) + current_orientation = (city_orientations[current_start_node] + 2) % 2 + agents_direction.append(current_orientation) + if not rail.check_path_exists(start, current_orientation, target): + print("No path") if speed_ratio_map: speeds = speed_initialization_helper(num_agents, speed_ratio_map)