diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 2aa148999d2398ac6ab0696388f06bf4e93b8fab..27508c222868448d53755012b3ef2ce0da501232 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -838,7 +838,10 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation np.random.seed(seed + num_resets) # Generate a set of nodes for the sparse network + # Try to connect cities to nodes first node_positions = [] + city_positions = [] + intersection_positions = [] for node_idx in range(num_cities + num_intersections): to_close = True tries = 0 @@ -851,21 +854,34 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation to_close = True if not to_close: node_positions.append((x_tmp, y_tmp)) + if node_idx < num_cities: + city_positions.append((x_tmp, y_tmp)) + else: + intersection_positions.append((x_tmp, y_tmp)) tries += 1 if tries > 100: warnings.warn("Could not set nodes, please change initial parameters!!!!") break # Chose node connection - available_nodes = np.arange(num_cities + num_intersections) + available_nodes_full = np.arange(num_cities + num_intersections) + available_cities = np.arange(num_cities) + available_intersections = np.arange(num_cities, num_cities + num_intersections) current_node = 0 node_stack = [current_node] while len(node_stack) > 0: current_node = node_stack[0] - delete_idx = np.where(available_nodes == current_node) - available_nodes = np.delete(available_nodes, delete_idx, 0) - + delete_idx = np.where(available_nodes_full == current_node) + available_nodes_full = np.delete(available_nodes_full, delete_idx, 0) + if current_node < num_cities and len(available_intersections) > 0: + available_nodes = available_intersections + available_cities = np.delete(available_cities, delete_idx, 0) + elif len(available_intersections) > 0: + available_nodes = available_cities + available_intersections = np.delete(available_intersections, delete_idx, 0) + else: + available_nodes = available_nodes_full # Sort available neighbors according to their distance. node_dist = [] for av_node in available_nodes: @@ -885,30 +901,36 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation node_stack.append(neighb) connect_nodes(rail_trans, rail_array, node_positions[current_node], node_positions[neighb]) node_stack.pop(0) + # Place train stations close to the node # We currently place them uniformly distirbuted among all cities - train_stations = [[] for i in range(num_cities)] - - for station in range(num_trainstations): - trainstation_node = int(station / num_trainstations * num_cities) - - station_x = np.clip(node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius), 0, - height - 1) - station_y = np.clip(node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius), 0, - width - 1) - while (station_x, station_y) in train_stations or (station_x, station_y) == node_positions[ - trainstation_node] or \ - rail_array[(station_x, station_y)] != 0: + if num_cities > 1: + train_stations = [[] for i in range(num_cities)] + + for station in range(num_trainstations): + trainstation_node = int(station / num_trainstations * num_cities) + station_x = np.clip(node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius), 0, height - 1) station_y = np.clip(node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius), 0, width - 1) - train_stations[trainstation_node].append((station_x, station_y)) - - # Connect train station to the correct node - connect_from_nodes(rail_trans, rail_array, node_positions[trainstation_node], (station_x, station_y)) + while (station_x, station_y) in train_stations or (station_x, station_y) == node_positions[ + trainstation_node] or \ + rail_array[(station_x, station_y)] != 0: + station_x = np.clip( + node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius), + 0, + height - 1) + station_y = np.clip( + node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius), + 0, + width - 1) + train_stations[trainstation_node].append((station_x, station_y)) + + # Connect train station to the correct node + connect_from_nodes(rail_trans, rail_array, node_positions[trainstation_node], (station_x, station_y)) # Fix all nodes with illegal transition maps for current_node in node_positions: diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py index c4430e6e4d962d9011e2c973e8f342f09f647c38..b20754be9fdc76ee57c25da4fe211ecdceb55d04 100644 --- a/tests/test_flatland_env_sparse_rail_generator.py +++ b/tests/test_flatland_env_sparse_rail_generator.py @@ -1,3 +1,5 @@ +import time + import numpy as np from flatland.envs.generators import sparse_rail_generator, realistic_rail_generator @@ -23,17 +25,17 @@ def test_realistic_rail_generator(): def test_sparse_rail_generator(): env = RailEnv(width=50, height=50, - rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map - num_intersections=3, # Number of interesections in map + rail_generator=sparse_rail_generator(num_cities=5, # Number of cities in map + num_intersections=2, # Number of interesections in map num_trainstations=10, # Number of possible start/targets on map min_node_dist=10, # Minimal distance of nodes node_radius=2, # Proximity of stations to city center - num_neighb=4, # Number of connections to other cities + num_neighb=2, # Number of connections to other cities seed=15, # Random seed ), - number_of_agents=1, + number_of_agents=0, obs_builder_object=GlobalObsForRailEnv()) # reset to initialize agents_static env_renderer = RenderTool(env, gl="PILSVG", ) env_renderer.render_env(show=True, show_observations=True, show_predictions=False) - + time.sleep(5)