diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 44a086b770fb3b276ab87e825f7d10802a86904b..53105ab15dbbbf911be688d05b0cad4fcc4a82ef 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -36,12 +36,12 @@ env = RailEnv(width=50, num_trainstations=100, # Number of possible start/targets on map min_node_dist=10, # Minimal distance of nodes node_radius=4, # Proximity of stations to city center - num_neighb=2, # Number of connections to other cities/intersections + num_neighb=3, # Number of connections to other cities/intersections seed=15, # Random seed grid_mode=True, nr_parallel_tracks=2, - connectin_points_per_side=5, - max_nr_connection_directions=2, + connectin_points_per_side=2, + max_nr_connection_directions=3, ), schedule_generator=sparse_schedule_generator(), number_of_agents=50, diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 33fc408d67e52557aa89b5f8c8526557dd608b2a..fdf6e2167fe1d7e4f698e5fdf5223d92ed26ec60 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -591,22 +591,154 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n # reduce nb_nodes, _num_cities, _num_intersections if less were generated in not_grid_mode nb_nodes = len(node_positions) - _num_cities = len(city_positions) - - # Chose node connection - # Set up list of available nodes to connect to - available_nodes = np.arange(nb_nodes) # Set up connection points for all cities - connection_points = _generate_node_connection_points(node_positions, node_radius, connectin_points_per_side, - max_nr_connection_directions) + connection_points, connection_info = _generate_node_connection_points(node_positions, node_radius, + connectin_points_per_side, + max_nr_connection_directions) + + # Connect the cities through the connection points + _connect_cities(node_positions, connection_points, connection_info, rail_trans, grid_map) + + # Build inner cities + train_stations, built_num_trainstation = _build_cities(node_positions, connection_points, rail_trans, grid_map) + + # Adjust the number of agents if you could not build enough trainstations + if num_agents > built_num_trainstation: + num_agents = built_num_trainstation + warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents") + + # Fix all transition elements + _fix_transitions(grid_map) + + # Generate start target paris + agent_start_targets_nodes = _generate_start_target_pairs(num_agents, nb_nodes, train_stations) + + return grid_map, {'agents_hints': { + 'num_agents': num_agents, + 'agent_start_targets_nodes': agent_start_targets_nodes, + 'train_stations': train_stations + }} + + def _generate_node_positions_not_grid_mode(city_positions, height, intersection_positions, nb_nodes, + width): + + node_positions = [] + for node_idx in range(nb_nodes): + to_close = True + tries = 0 + + while to_close: + x_tmp = node_radius + np.random.randint(height - 2 * node_radius - 1) + y_tmp = node_radius + np.random.randint(width - 2 * node_radius - 1) + to_close = False + + # Check distance to cities + for node_pos in city_positions: + if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: + to_close = True + + # Check distance to intersections + for node_pos in intersection_positions: + if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: + to_close = True + + if not to_close: + node_positions.append((x_tmp, y_tmp)) + if node_idx < num_cities: + city_positions.append((x_tmp, y_tmp)) + else: + intersection_positions.append((x_tmp, y_tmp)) + tries += 1 + if tries > 100: + warnings.warn( + "Could not only set {} nodes after {} tries, although {} of nodes required to be generated!".format( + len(node_positions), + tries, nb_nodes)) + break + + node_positions = city_positions + intersection_positions + return node_positions + + def _generate_node_positions_grid_mode(city_idx, city_positions, intersection_positions, nb_nodes, + nodes_per_row, x_positions, y_positions): + + for node_idx in range(nb_nodes): + + x_tmp = x_positions[node_idx % nodes_per_row] + y_tmp = y_positions[node_idx // nodes_per_row] + if node_idx in city_idx: + city_positions.append((x_tmp, y_tmp)) + + else: + intersection_positions.append((x_tmp, y_tmp)) + node_positions = city_positions + intersection_positions + return node_positions + + def _generate_node_connection_points(node_positions, node_size, max_nr_connection_points=2, + max_nr_connection_directions=2): + connection_points = [] + connection_info = [] + for node_position in node_positions: + + connection_sides_idx = np.sort( + np.random.choice(np.arange(4), size=max_nr_connection_directions, replace=False)) + + # Chose the directions where close cities are situated + neighb_dist = [] + for neighb_node in node_positions: + neighb_dist.append(distance_on_rail(node_position, neighb_node)) + closest_neighb_idx = argsort(neighb_dist) + + # Store the directions to these neighbours + connection_sides_idx = [] + for idx in range(1, max_nr_connection_directions + 1): + connection_sides_idx.append(closest_direction(node_position, node_positions[closest_neighb_idx[idx]])) + + # set the number of connection points for each direction + connections_per_direction = np.zeros(4, dtype=int) + + for idx in connection_sides_idx: + connections_per_direction[idx] = max_nr_connection_points + connection_points_coordinates = [] + + for direction in range(4): + connection_slots = np.arange(connections_per_direction[direction]) - int( + connections_per_direction[direction] / 2) + for connection_idx in range(connections_per_direction[direction]): + if direction == 0: + connection_points_coordinates.append( + (node_position[0] - node_size, node_position[1] + connection_slots[connection_idx])) + if direction == 1: + connection_points_coordinates.append( + (node_position[0] + connection_slots[connection_idx], node_position[1] + node_size)) + if direction == 2: + connection_points_coordinates.append( + (node_position[0] + node_size, node_position[1] + connection_slots[connection_idx])) + if direction == 3: + connection_points_coordinates.append( + (node_position[0] + connection_slots[connection_idx], node_position[1] - node_size)) + + connection_points.append(connection_points_coordinates) + connection_info.append(connections_per_direction) + return connection_points, connection_info + + def _connect_cities(node_positions, connection_points, connection_info, rail_trans, grid_map): + """ + Function to connect the different cities through their connection points + :param node_positions: Positions of city centers + :param connection_points: Boarder connection points of cities + :param connection_info: Number of connection points per direction NESW + :param rail_trans: Transitions + :param grid_map: Grid map + :return: + """ # Start at some node + available_nodes = np.arange(len(node_positions)) current_node = np.random.randint(len(available_nodes)) node_stack = [current_node] open_nodes = np.copy(available_nodes) - allowed_connections = num_neighb - i = 0 boarder_connections = set() while len(open_nodes) > 0: if len(node_stack) > 0: @@ -623,7 +755,9 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n node_dist.append(distance_on_rail(node_positions[current_node], node_positions[av_node])) available_nodes = available_nodes[np.argsort(node_dist)] - # Set number of neighboring nodes + # Set number of neighboring + allowed_connections = np.count_nonzero(connection_info[current_node]) + if len(available_nodes) >= allowed_connections: connected_neighb_idx = available_nodes[1:allowed_connections + 1] else: @@ -649,16 +783,18 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n if tmp_dist < min_connection_dist: min_connection_dist = tmp_dist neighb_connection_point = tmp_in_connection_point - i += 1 connect_nodes(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point) boarder_connections.add((tmp_out_connection_point, current_node)) boarder_connections.add((neighb_connection_point, neighb)) node_stack.pop(0) + def _build_cities(node_positions, connection_points, rail_trans, grid_map): # Place train stations close to the node # We currently place them uniformly distributed among all cities built_num_trainstation = 0 + nb_nodes = len(node_positions) + height, width = np.shape(grid_map.grid) train_stations = [[] for i in range(nb_nodes)] if nb_nodes > 1: @@ -705,12 +841,6 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n connection = connect_nodes(rail_trans, grid_map, connection_points[trainstation_node][corner_node_idx], (station_x, station_y)) - if len(connection) != 0: - if (connection_points[trainstation_node][corner_node_idx], - trainstation_node) in boarder_connections: - boarder_connections.remove( - (connection_points[trainstation_node][corner_node_idx], trainstation_node)) - # Check if connection was made if len(connection) == 0: if len(train_stations[trainstation_node]) > 0: @@ -718,40 +848,16 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n else: built_num_trainstation += 1 - # Adjust the number of agents if you could not build enough trainstations - if num_agents > built_num_trainstation: - num_agents = built_num_trainstation - warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents") - - # Connect all disjunct parts of the network - - if len(boarder_connections) > 0: - to_be_deleted = [] - for disjunct_node in boarder_connections: - if len(train_stations[disjunct_node[1]]) > 0: - conn = connect_nodes(rail_trans, grid_map, - disjunct_node[0], - train_stations[disjunct_node[1]][-1]) - else: - conn = connect_nodes(rail_trans, grid_map, - disjunct_node[0], - node_positions[disjunct_node[1]]) - if len(conn) > 0: - to_be_deleted.append(disjunct_node) - else: - conn = connect_nodes(rail_trans, grid_map, - disjunct_node[0], - node_positions[disjunct_node[1]]) - if len(conn) > 0: - to_be_deleted.append(disjunct_node) - - for tbd in to_be_deleted: - boarder_connections.remove(tbd) - print(boarder_connections) + return train_stations, built_num_trainstation + def _fix_transitions(grid_map): + """ + Function to fix all transition elements in environment + """ # Fix all nodes with illegal transition maps empty_to_fix = [] rails_to_fix = [] + height, width = np.shape(grid_map.grid) for r in range(height): for c in range(width): rc_pos = (r, c) @@ -770,6 +876,8 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n for cell in rails_to_fix: grid_map.fix_transitions(cell) + def _generate_start_target_pairs(num_agents, nb_nodes, train_stations): + # Generate start and target node directory for all agents. # Assure that start and target are not in the same node agent_start_targets_nodes = [] @@ -805,116 +913,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n agent_start_targets_nodes.append((start_node, target_node)) else: num_agents -= 1 - - return grid_map, {'agents_hints': { - 'num_agents': num_agents, - 'agent_start_targets_nodes': agent_start_targets_nodes, - 'train_stations': train_stations - }} - - def _generate_node_positions_not_grid_mode(city_positions, height, intersection_positions, nb_nodes, - width): - - node_positions = [] - for node_idx in range(nb_nodes): - to_close = True - tries = 0 - - while to_close: - x_tmp = node_radius + np.random.randint(height - 2 * node_radius - 1) - y_tmp = node_radius + np.random.randint(width - 2 * node_radius - 1) - to_close = False - - # Check distance to cities - for node_pos in city_positions: - if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: - to_close = True - - # Check distance to intersections - for node_pos in intersection_positions: - if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: - to_close = True - - if not to_close: - node_positions.append((x_tmp, y_tmp)) - if node_idx < num_cities: - city_positions.append((x_tmp, y_tmp)) - else: - intersection_positions.append((x_tmp, y_tmp)) - tries += 1 - if tries > 100: - warnings.warn( - "Could not only set {} nodes after {} tries, although {} of nodes required to be generated!".format( - len(node_positions), - tries, nb_nodes)) - break - - node_positions = city_positions + intersection_positions - return node_positions - - def _generate_node_positions_grid_mode(city_idx, city_positions, intersection_positions, nb_nodes, - nodes_per_row, x_positions, y_positions): - - for node_idx in range(nb_nodes): - - x_tmp = x_positions[node_idx % nodes_per_row] - y_tmp = y_positions[node_idx // nodes_per_row] - if node_idx in city_idx: - city_positions.append((x_tmp, y_tmp)) - - else: - intersection_positions.append((x_tmp, y_tmp)) - node_positions = city_positions + intersection_positions - return node_positions - - def _generate_node_connection_points(node_positions, node_size, max_nr_connection_points=2, - max_nr_connection_directions=2): - connection_points = [] - for node_position in node_positions: - - connection_sides_idx = np.sort( - np.random.choice(np.arange(4), size=max_nr_connection_directions, replace=False)) - - # Chose the directions where close cities are situated - neighb_dist = [] - for neighb_node in node_positions: - neighb_dist.append(distance_on_rail(node_position, neighb_node)) - closest_neighb_idx = argsort(neighb_dist) - connection_sides_idx = [] - for idx in range(1, max_nr_connection_directions + 1): - connection_sides_idx.append(closest_direction(node_position, node_positions[closest_neighb_idx[idx]])) - - connections_per_direction = np.zeros(4, dtype=int) - # set the number of connection points for each direction - for idx in connection_sides_idx: - connections_per_direction[idx] = max_nr_connection_points - connection_points_coordinates = [] - random_connection_slots = False - for direction in range(4): - if random_connection_slots: - connection_slots = np.random.choice(np.arange(-node_size, node_size), - size=connections_per_direction[direction], - replace=False) - else: - connection_slots = np.arange(connections_per_direction[direction]) - int( - connections_per_direction[direction] / 2) - for connection_idx in range(connections_per_direction[direction]): - if direction == 0: - connection_points_coordinates.append( - (node_position[0] - node_size, node_position[1] + connection_slots[connection_idx])) - if direction == 1: - connection_points_coordinates.append( - (node_position[0] + connection_slots[connection_idx], node_position[1] + node_size)) - if direction == 2: - connection_points_coordinates.append( - (node_position[0] + node_size, node_position[1] + connection_slots[connection_idx])) - if direction == 3: - connection_points_coordinates.append( - (node_position[0] + connection_slots[connection_idx], node_position[1] - node_size)) - - connection_points.append(connection_points_coordinates) - return connection_points - + return agent_start_targets_nodes def argsort(seq): # http://stackoverflow.com/questions/3071415/efficient-method-to-calculate-the-rank-vector-of-a-list-in-python return sorted(range(len(seq)), key=seq.__getitem__)