From 4570db2d89ef887c353c70162c9ef5553ac8df69 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Sat, 31 Aug 2019 19:24:57 -0400 Subject: [PATCH] dixing level generator city distribution --- examples/flatland_2_0_example.py | 10 +++++----- flatland/envs/rail_generators.py | 12 +++++++++--- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 082308e6..caa6e80c 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -30,18 +30,18 @@ speed_ration_map = {1.: 0.25, # Fast passenger train env = RailEnv(width=50, height=50, - rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map (where train stations are) - num_intersections=15, # Number of intersections (no start / target) - num_trainstations=50, # Number of possible start/targets on map + rail_generator=sparse_rail_generator(num_cities=30, # Number of cities in map (where train stations are) + num_intersections=5, # Number of intersections (no start / target) + num_trainstations=20, # Number of possible start/targets on map min_node_dist=3, # Minimal distance of nodes - node_radius=3, # Proximity of stations to city center + node_radius=2, # Proximity of stations to city center num_neighb=3, # Number of connections to other cities/intersections seed=15, # Random seed realistic_mode=True, enhance_intersection=True ), schedule_generator=sparse_schedule_generator(speed_ration_map), - number_of_agents=20, + number_of_agents=40, stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=TreeObservation) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 40ec2e0d..17db869e 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -573,9 +573,14 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int) y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int) + fraction = 0 + city_fraction = num_cities / tot_num_node + step = np.gcd(num_intersections, num_cities) / tot_num_node for node_idx in range(num_cities + num_intersections): to_close = True tries = 0 + + fraction = (fraction + step) % 1. if not realistic_mode: while to_close: x_tmp = node_radius + np.random.randint(height - node_radius) @@ -587,7 +592,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: to_close = True - # CHeck distance to intersections + # Check distance to intersections for node_pos in intersection_positions: if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: to_close = True @@ -605,11 +610,11 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 else: x_tmp = x_positions[node_idx % nodes_per_row] y_tmp = y_positions[node_idx // nodes_per_row] - if len(city_positions) < num_cities and (node_idx % (tot_num_node // num_cities)) == 0: + if len(city_positions) < num_cities and fraction < city_fraction: city_positions.append((x_tmp, y_tmp)) else: intersection_positions.append((x_tmp, y_tmp)) - + print(len(city_positions)) node_positions = city_positions + intersection_positions # Chose node connection @@ -627,6 +632,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 current_node = node_stack[0] delete_idx = np.where(available_nodes_full == current_node) available_nodes_full = np.delete(available_nodes_full, delete_idx, 0) + # Priority city to intersection connections if current_node < num_cities and len(available_intersections) > 0: available_nodes = available_intersections -- GitLab