diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 40b415915305df0384b31ad30037ec14ec0e985d..4e97ff11c71f3a9061696c4a5321c45abf02b821 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -30,11 +30,11 @@ speed_ration_map = {1.: 0.25, # Fast passenger train env = RailEnv(width=40, height=40, - rail_generator=sparse_rail_generator(num_cities=8, # Number of cities in map (where train stations are) + rail_generator=sparse_rail_generator(max_num_cities=8, # Number of cities in map (where train stations are) seed=1, # Random seed grid_mode=False, - max_inter_city_rails=2, - max_tracks_in_city=4, + max_rails_between_cities=2, + max_rails_in_city=4, ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=20, diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index d98b9644411ebc00d3a765ed32d11d0f22fdcdd6..b132a5caeaaf1cd9c144b48eafeb016763d8bb42 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -4,7 +4,6 @@ import warnings from typing import Callable, Tuple, Optional, Dict, List, Any import msgpack -import networkx as nx import numpy as np from flatland.core.grid.grid4_utils import get_direction, mirror @@ -534,88 +533,74 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGener return generator -def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, max_tracks_in_city=4, - seed=0) -> RailGenerator: +def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_rails_between_cities: int = 4, + max_rails_in_city: int = 4, seed: int = 0) -> RailGenerator: """ Generates railway networks with cities and inner city rails - :param num_cities: Number of city centers in the map - :param grid_mode: Arange cities in a grid or randomly - :param max_inter_city_rails: Maximum number of connecting rails going out from a city - :param max_tracks_in_city: maximum number of internal rails + :param max_num_cities: Number of city centers in the map + :param grid_mode: arrange cities in a grid or randomly + :param max_rails_between_cities: Maximum number of connecting rails going out from a city + :param max_rails_in_city: maximum number of internal rails :param seed: Random seed to initiate rail :return: generator """ - G = nx.DiGraph() DEBUG_PRINT_TIMING = False - def generator(width, height, num_agents, num_resets=0) -> RailGeneratorProduct: + def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct: + np.random.seed(seed + num_resets) rail_trans = RailEnvTransitions() grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) - rail_array = grid_map.grid - rail_array.fill(0) - np.random.seed(seed + num_resets) - # Graph to be able to create correct start/end pairs for schedule + city_radius = int(np.ceil((max_rails_in_city + 2) / 2.0)) + 1 - node_radius = int(np.ceil((max_tracks_in_city + 2) / 2.0)) + 1 - if 3 > max_tracks_in_city: - rail_in_city = 3 - else: - rail_in_city = max_tracks_in_city - max_inter_city_rails_allowed = max_inter_city_rails - if max_inter_city_rails_allowed > rail_in_city: - max_inter_city_rails_allowed = rail_in_city - # Generate a set of nodes for the sparse network - # Try to connect cities to nodes first - city_positions = [] - intersection_positions = [] - - # Evenly distribute cities and intersections + min_nr_rails_in_city = 3 + rails_in_city = min_nr_rails_in_city if max_rails_in_city < min_nr_rails_in_city else max_rails_in_city + rails_between_cities = rails_in_city if max_rails_between_cities > rails_in_city else max_rails_between_cities + + # Evenly distribute cities node_time_start = time.time() - node_positions: List[Any] = None - nb_nodes = num_cities if grid_mode: - node_positions, city_cells = _generate_node_positions_grid_mode(nb_nodes, node_radius, height, width) + city_positions, city_cells = _generate_evenly_distr_city_positions(max_num_cities, city_radius, width, height) else: - node_positions, city_cells = _generate_random_node_positions(nb_nodes, node_radius, height, width) + city_positions, city_cells = _generate_random_city_positions(max_num_cities, city_radius, width, height) - # reduce nb_nodes, _num_cities, _num_intersections if less were generated in not_grid_mode - nb_nodes = len(node_positions) + # reduce num_cities, _num_cities, _num_intersections if less were generated in not_grid_mode + num_cities = len(city_positions) if DEBUG_PRINT_TIMING: print("City position time", time.time() - node_time_start, "Seconds") + # Set up connection points for all cities node_connection_time = time.time() inner_connection_points, outer_connection_points, connection_info, city_orientations = _generate_node_connection_points( - node_positions, node_radius, max_inter_city_rails_allowed, - rail_in_city) + city_positions, city_radius, rails_between_cities, + rails_in_city) if DEBUG_PRINT_TIMING: print("Connection points", time.time() - node_connection_time) # Connect the cities through the connection points city_connection_time = time.time() - inter_city_lines = _connect_cities(node_positions, outer_connection_points, connection_info, city_cells, + inter_city_lines = _connect_cities(city_positions, outer_connection_points, connection_info, city_cells, rail_trans, grid_map) if DEBUG_PRINT_TIMING: print("City connection time", time.time() - city_connection_time) # Build inner cities city_build_time = time.time() - through_tracks, free_tracks = _build_inner_cities(node_positions, inner_connection_points, + through_tracks, free_tracks = _build_inner_cities(city_positions, inner_connection_points, outer_connection_points, - node_radius, + city_radius, rail_trans, grid_map) if DEBUG_PRINT_TIMING: print("City build time", time.time() - city_build_time) # Populate cities train_station_time = time.time() - train_stations, built_num_trainstation = _set_trainstation_positions(node_positions, node_radius, free_tracks, + train_stations, built_num_trainstation = _set_trainstation_positions(city_positions, city_radius, free_tracks, grid_map) if DEBUG_PRINT_TIMING: print("Trainstation placing time", time.time() - train_station_time) - # Fix all transition elements grid_fix_time = time.time() _fix_transitions(city_cells, inter_city_lines, grid_map) @@ -624,7 +609,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, # Generate start target pairs schedule_time = time.time() - agent_start_targets_nodes, num_agents = _generate_start_target_pairs(num_agents, nb_nodes, train_stations, + agent_start_targets_nodes, num_agents = _generate_start_target_pairs(num_agents, num_cities, train_stations, city_orientations) if DEBUG_PRINT_TIMING: print("Schedule time", time.time() - schedule_time) @@ -636,52 +621,50 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, 'city_orientations': city_orientations }} - def _generate_random_node_positions(nb_nodes, node_radius, height, width): + def _generate_random_city_positions(num_cities: int, city_radius: int, width: int, height: int) -> (List[Tuple[int, int]], List[Tuple[int, int]]): node_positions = [] city_cells = [] - for node_idx in range(nb_nodes): + for node_idx in range(num_cities): to_close = True tries = 0 while to_close: - x_tmp = node_radius + 1 + np.random.randint(height - 2 * (node_radius + 1)) - y_tmp = node_radius + 1 + np.random.randint(width - 2 * (node_radius + 1)) + x_tmp = city_radius + 1 + np.random.randint(height - 2 * (city_radius + 1)) + y_tmp = city_radius + 1 + np.random.randint(width - 2 * (city_radius + 1)) to_close = False # Check distance to nodes for node_pos in node_positions: - if _city_overlap((x_tmp, y_tmp), node_pos, 2 * (node_radius + 1) + 1): + if _city_overlap((x_tmp, y_tmp), node_pos, 2 * (city_radius + 1) + 1): to_close = True if not to_close: node_positions.append((x_tmp, y_tmp)) - city_cells.extend(_city_cells(node_positions[-1], node_radius)) + city_cells.extend(_city_cells(node_positions[-1], city_radius)) tries += 1 if tries > 200: warnings.warn( "Could not only set {} nodes after {} tries, although {} of nodes required to be generated!".format( len(node_positions), - tries, nb_nodes)) + tries, num_cities)) break - G.add_node(node_idx) return node_positions, city_cells - def _generate_node_positions_grid_mode(nb_nodes, node_radius, height, width): + def _generate_evenly_distr_city_positions(num_cities: int, city_radius: int, width: int, height: int) -> (List[Tuple[int, int]], List[Tuple[int, int]]): nodes_ratio = height / width - nodes_per_row = int(np.ceil(np.sqrt(nb_nodes * nodes_ratio))) - nodes_per_col = int(np.ceil(nb_nodes / nodes_per_row)) - x_positions = np.linspace(node_radius + 1, height - node_radius - 2, nodes_per_row, dtype=int) - y_positions = np.linspace(node_radius + 1, width - node_radius - 2, nodes_per_col, dtype=int) + nodes_per_row = int(np.ceil(np.sqrt(num_cities * nodes_ratio))) + nodes_per_col = int(np.ceil(num_cities / nodes_per_row)) + x_positions = np.linspace(city_radius + 1, height - city_radius - 2, nodes_per_row, dtype=int) + y_positions = np.linspace(city_radius + 1, width - city_radius - 2, nodes_per_col, dtype=int) node_positions = [] city_cells = [] - for node_idx in range(nb_nodes): + for node_idx in range(num_cities): x_tmp = x_positions[node_idx % nodes_per_row] y_tmp = y_positions[node_idx // nodes_per_row] node_positions.append((x_tmp, y_tmp)) - city_cells.extend(_city_cells(node_positions[-1], node_radius)) - G.add_node(node_idx) + city_cells.extend(_city_cells(node_positions[-1], city_radius)) return node_positions, city_cells def _generate_node_connection_points(node_positions, node_size, max_inter_city_rails_allowed, tracks_in_city=2): @@ -745,7 +728,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, rail_trans, grid_map): """ Function to connect the different cities through their connection points - :param node_positions: Positions of city centers + :param city_positions: Positions of city centers :param connection_points: Boarder connection points of cities :param connection_info: Number of connection points per direction NESW :param rail_trans: Transitions @@ -778,9 +761,6 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, new_line = connect_cities(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point, city_cells) - G.add_edge(current_node, neighb_idx, direction=out_direction, length=len(new_line)) - G.add_edge(neighb_idx, current_node, direction=neighbour_direction, length=len(new_line)) - all_paths.extend(new_line) return all_paths @@ -789,7 +769,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, grid_map): """ Builds inner city tracks. This current version connects all incoming connections to all outgoing connections - :param node_positions: Positions of the cities + :param city_positions: Positions of the cities :param inner_connection_points: Points on city boarder that are used to generate inner city track :param outer_connection_points: Points where the city is connected to neighboring cities :param rail_trans: @@ -833,7 +813,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, def _set_trainstation_positions(node_positions, node_radius, free_tracks, grid_map): """ - :param node_positions: + :param city_positions: :param num_trainstations: :return: """ @@ -906,8 +886,8 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, def _closest_neigh_in_direction(current_node, node_positions): """ Returns indices of closest neighbours in every direction NESW - :param current_node: Index of node in node_positions list - :param node_positions: list of all points being considered + :param current_node: Index of node in city_positions list + :param city_positions: list of all points being considered :return: list of index of closest neighbours in all directions """ node_dist = [] diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py index fd6e7b88562e88f13774ff07fdad8b663fcd5ab5..b94de82c4f7bdcda355316e65350c2b8b56a4ebe 100644 --- a/tests/test_flatland_envs_sparse_rail_generator.py +++ b/tests/test_flatland_envs_sparse_rail_generator.py @@ -13,7 +13,7 @@ from flatland.utils.rendertools import RenderTool def test_sparse_rail_generator(): env = RailEnv(width=50, height=50, - rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map + rail_generator=sparse_rail_generator(max_num_cities=10, # Number of cities in map num_intersections=10, # Number of interesections in map num_trainstations=50, # Number of possible start/targets on map min_node_dist=6, # Minimal distance of nodes @@ -733,7 +733,7 @@ def test_sparse_rail_generator_deterministic(): env = RailEnv(width=25, height=30, - rail_generator=sparse_rail_generator(num_cities=5, + rail_generator=sparse_rail_generator(max_num_cities=5, # Number of cities in map (where train stations are) num_intersections=4, # Number of intersections (no start / target) @@ -1509,7 +1509,7 @@ def test_rail_env_action_required_info(): 1. / 4.: 0.25} # Slow freight train env_always_action = RailEnv(width=50, height=50, - rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map + rail_generator=sparse_rail_generator(max_num_cities=10, # Number of cities in map num_intersections=10, # Number of interesections in map num_trainstations=50, @@ -1528,7 +1528,7 @@ def test_rail_env_action_required_info(): np.random.seed(0) env_only_if_action_required = RailEnv(width=50, height=50, - rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map + rail_generator=sparse_rail_generator(max_num_cities=10, # Number of cities in map num_intersections=10, # Number of interesections in map num_trainstations=50, @@ -1592,7 +1592,7 @@ def test_rail_env_malfunction_speed_info(): } env = RailEnv(width=50, height=50, - rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map + rail_generator=sparse_rail_generator(max_num_cities=10, # Number of cities in map num_intersections=10, # Number of interesections in map num_trainstations=50, @@ -1640,7 +1640,7 @@ def test_sparse_generator_with_too_man_cities_does_not_break_down(): RailEnv(width=50, height=50, rail_generator=sparse_rail_generator( - num_cities=100, # Number of cities in map + max_num_cities=100, # Number of cities in map num_intersections=10, # Number of interesections in map num_trainstations=50, # Number of possible start/targets on map min_node_dist=6, # Minimal distance of nodes diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index cc61150325ee86710ea3ee3820d8386c7b926da6..16f993a8373e8fb059c9b21af8e19fc06d53c517 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -165,7 +165,7 @@ def test_initial_malfunction(): env = RailEnv(width=25, height=30, - rail_generator=sparse_rail_generator(num_cities=5, + rail_generator=sparse_rail_generator(max_num_cities=5, # Number of cities in map (where train stations are) num_intersections=4, # Number of intersections (no start / target) @@ -247,7 +247,7 @@ def test_initial_malfunction_stop_moving(): env = RailEnv(width=25, height=30, - rail_generator=sparse_rail_generator(num_cities=5, + rail_generator=sparse_rail_generator(max_num_cities=5, # Number of cities in map (where train stations are) num_intersections=4, # Number of intersections (no start / target) @@ -339,7 +339,7 @@ def test_initial_malfunction_do_nothing(): env = RailEnv(width=25, height=30, - rail_generator=sparse_rail_generator(num_cities=5, + rail_generator=sparse_rail_generator(max_num_cities=5, # Number of cities in map (where train stations are) num_intersections=4, # Number of intersections (no start / target) @@ -430,7 +430,7 @@ def test_initial_nextmalfunction_not_below_zero(): env = RailEnv(width=25, height=30, - rail_generator=sparse_rail_generator(num_cities=5, + rail_generator=sparse_rail_generator(max_num_cities=5, # Number of cities in map (where train stations are) num_intersections=4, # Number of intersections (no start / target) diff --git a/tests/test_global_observation.py b/tests/test_global_observation.py index 7213560f9e9873ea4488b96d30223bab8128b37b..7f8f62c0f028a9a679ba100ce749841ea96ba1ce 100644 --- a/tests/test_global_observation.py +++ b/tests/test_global_observation.py @@ -23,7 +23,7 @@ def test_get_global_observation(): env = RailEnv(width=50, height=50, - rail_generator=sparse_rail_generator(num_cities=25, + rail_generator=sparse_rail_generator(max_num_cities=25, # Number of cities in map (where train stations are) num_intersections=10, # Number of intersections (no start / target)