diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 31eef79a195d42dbcc54279104eaa2e21fdf40bc..c7d97803d0262e5de3872b49578501187bbc2e3f 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -1,4 +1,5 @@ """Rail generators (infrastructure manager, "Infrastrukturbetreiber").""" +import time import warnings from typing import Callable, Tuple, Optional, Dict, List, Any @@ -561,6 +562,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, intersection_positions = [] # Evenly distribute cities and intersections + node_time_start = time.time() node_positions: List[Any] = None nb_nodes = num_cities if grid_mode: @@ -570,24 +572,30 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, # reduce nb_nodes, _num_cities, _num_intersections if less were generated in not_grid_mode nb_nodes = len(node_positions) - + print("City position time", time.time() - node_time_start, "Seconds") # Set up connection points for all cities + node_connection_time = time.time() inner_connection_points, outer_connection_points, connection_info = _generate_node_connection_points( node_positions, node_radius, max_inter_city_rails_allowed, max_tracks_in_city) + print("Connection points", time.time() - node_connection_time) # Connect the cities through the connection points + city_connection_time = time.time() _connect_cities(node_positions, outer_connection_points, connection_info, city_cells, rail_trans, grid_map) - + print("City connection time", time.time() - city_connection_time) # Build inner cities + city_build_time = time.time() through_tracks = _build_inner_cities(node_positions, inner_connection_points, outer_connection_points, node_radius, rail_trans, grid_map) - + print("City build time", time.time() - city_build_time) # Populate cities + train_station_time = time.time() train_stations, built_num_trainstation = _set_trainstation_positions(node_positions, through_tracks, node_radius, grid_map) + print("Trainstation placing time", time.time() - train_station_time) # Adjust the number of agents if you could not build enough trainstations if num_agents > built_num_trainstation: @@ -595,11 +603,14 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents") # Fix all transition elements + grid_fix_time = time.time() _fix_transitions(grid_map) + print("Grid fix time", time.time() - grid_fix_time) # Generate start target pairs + schedule_time = time.time() agent_start_targets_nodes, num_agents = _generate_start_target_pairs(num_agents, nb_nodes, train_stations) - + print("Schedule time", time.time() - schedule_time) return grid_map, {'agents_hints': { 'num_agents': num_agents, 'agent_start_targets_nodes': agent_start_targets_nodes, @@ -616,8 +627,8 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, tries = 0 while to_close: - x_tmp = node_radius + 1 + np.random.randint(height - 2 * (node_radius - 1)) - y_tmp = node_radius + 1 + np.random.randint(width - 2 * (node_radius - 1)) + x_tmp = node_radius + 1 + np.random.randint(height - 2 * (node_radius + 1)) + y_tmp = node_radius + 1 + np.random.randint(width - 2 * (node_radius + 1)) to_close = False # Check distance to nodes for node_pos in node_positions: @@ -629,7 +640,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, city_cells.extend(_city_cells(node_positions[-1], node_radius)) tries += 1 - if tries > 1000: + if tries > 200: warnings.warn( "Could not only set {} nodes after {} tries, although {} of nodes required to be generated!".format( len(node_positions), @@ -786,6 +797,9 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, source = inner_connection_points[current_city][opposite_boarder][track_id] for target in inner_connection_points[current_city][boarder]: current_track = connect_cities(rail_trans, grid_map, source, target, city_boarder) + if target in all_outer_connection_points and source in \ + all_outer_connection_points and len(through_path_cells[current_city]) < 1: + through_path_cells[current_city].extend(current_track) return through_path_cells