From 36d8f9e02ab4d045f7d6d4308c15ceffb2e675bb Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Mon, 19 Aug 2019 08:36:54 -0400 Subject: [PATCH] added realistic_mode for less random levels --- flatland/envs/generators.py | 58 +++++++++++++------ ...test_flatland_env_sparse_rail_generator.py | 14 ++--- 2 files changed, 46 insertions(+), 26 deletions(-) diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 2826ee51..dad07a60 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -813,8 +813,7 @@ def realistic_rail_generator(nr_start_goal=1, seed=0): def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstations=2, min_node_dist=20, node_radius=2, - num_neighb=4, - seed=0): + num_neighb=4, realistic_mode=False, seed=0): ''' :param nr_train_stations: @@ -843,26 +842,46 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation city_positions = [] intersection_positions = [] + # Evenly distribute cities and intersections + if realistic_mode: + tot_num_node = num_intersections + num_cities + nodes_ratio = height / width + nodes_per_row = int(np.ceil(np.sqrt(tot_num_node * nodes_ratio))) + nodes_per_col = int(np.ceil(tot_num_node / nodes_per_row)) + x_positions = np.linspace(2, height - 2, nodes_per_row, dtype=int) + y_positions = np.linspace(2, width - 2, nodes_per_col, dtype=int) for node_idx in range(num_cities + num_intersections): to_close = True tries = 0 - while to_close: - x_tmp = 1 + np.random.randint(height - 2) - y_tmp = 1 + np.random.randint(width - 2) - to_close = False - for node_pos in node_positions: - if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: - to_close = True - if not to_close: - node_positions.append((x_tmp, y_tmp)) - if node_idx < num_cities: - city_positions.append((x_tmp, y_tmp)) - else: - intersection_positions.append((x_tmp, y_tmp)) - tries += 1 - if tries > 100: - warnings.warn("Could not set nodes, please change initial parameters!!!!") - break + if not realistic_mode: + while to_close: + x_tmp = 1 + np.random.randint(height - 2) + y_tmp = 1 + np.random.randint(width - 2) + to_close = False + for node_pos in node_positions: + if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: + to_close = True + if not to_close: + node_positions.append((x_tmp, y_tmp)) + if node_idx < num_cities: + city_positions.append((x_tmp, y_tmp)) + else: + intersection_positions.append((x_tmp, y_tmp)) + tries += 1 + if tries > 100: + warnings.warn("Could not set nodes, please change initial parameters!!!!") + break + else: + x_tmp = x_positions[node_idx % nodes_per_row] + y_tmp = y_positions[node_idx // nodes_per_row] + if len(city_positions) < num_cities and (node_idx % (tot_num_node // num_cities)) == 0: + city_positions.append((x_tmp, y_tmp)) + else: + intersection_positions.append((x_tmp, y_tmp)) + + if realistic_mode: + node_positions = city_positions + intersection_positions + # Chose node connection available_nodes_full = np.arange(num_cities + num_intersections) available_cities = np.arange(num_cities) @@ -886,6 +905,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation available_intersections = np.delete(available_intersections, delete_idx, 0) else: available_nodes = available_nodes_full + # Sort available neighbors according to their distance. node_dist = [] for av_node in available_nodes: diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py index 92744080..74513b72 100644 --- a/tests/test_flatland_env_sparse_rail_generator.py +++ b/tests/test_flatland_env_sparse_rail_generator.py @@ -23,17 +23,17 @@ def test_realistic_rail_generator(): def test_sparse_rail_generator(): - env = RailEnv(width=20, - height=20, - rail_generator=sparse_rail_generator(num_cities=5, # Number of cities in map - num_intersections=2, # Number of interesections in map - num_trainstations=20, # Number of possible start/targets on map + env = RailEnv(width=50, + height=50, + rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map + num_intersections=10, # Number of interesections in map + num_trainstations=50, # Number of possible start/targets on map min_node_dist=6, # Minimal distance of nodes node_radius=3, # Proximity of stations to city center - num_neighb=2, # Number of connections to other cities + num_neighb=4, # Number of connections to other cities seed=5, # Random seed ), - number_of_agents=1, + number_of_agents=45, obs_builder_object=GlobalObsForRailEnv()) # reset to initialize agents_static env_renderer = RenderTool(env, gl="PILSVG", ) -- GitLab