From dae7c424b94380d78195fafbc50a14f7f5c6165d Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Sat, 31 Aug 2019 19:54:30 -0400 Subject: [PATCH] minor bugfixes --- examples/flatland_2_0_example.py | 8 ++++---- flatland/envs/rail_generators.py | 15 ++++++++------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index caa6e80c..a99009a5 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -30,18 +30,18 @@ speed_ration_map = {1.: 0.25, # Fast passenger train env = RailEnv(width=50, height=50, - rail_generator=sparse_rail_generator(num_cities=30, # Number of cities in map (where train stations are) + rail_generator=sparse_rail_generator(num_cities=20, # Number of cities in map (where train stations are) num_intersections=5, # Number of intersections (no start / target) - num_trainstations=20, # Number of possible start/targets on map + num_trainstations=15, # Number of possible start/targets on map min_node_dist=3, # Minimal distance of nodes node_radius=2, # Proximity of stations to city center - num_neighb=3, # Number of connections to other cities/intersections + num_neighb=4, # Number of connections to other cities/intersections seed=15, # Random seed realistic_mode=True, enhance_intersection=True ), schedule_generator=sparse_schedule_generator(speed_ration_map), - number_of_agents=40, + number_of_agents=10, stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=TreeObservation) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 17db869e..d915fdc7 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -573,14 +573,15 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int) y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int) - fraction = 0 - city_fraction = num_cities / tot_num_node - step = np.gcd(num_intersections, num_cities) / tot_num_node + fraction = 0 + city_fraction = num_cities / tot_num_node + step = np.gcd(num_intersections, num_cities) / tot_num_node + for node_idx in range(num_cities + num_intersections): to_close = True tries = 0 - fraction = (fraction + step) % 1. + if not realistic_mode: while to_close: x_tmp = node_radius + np.random.randint(height - node_radius) @@ -608,13 +609,13 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 warnings.warn("Could not set nodes, please change initial parameters!!!!") break else: + fraction = (fraction + step) % 1. x_tmp = x_positions[node_idx % nodes_per_row] y_tmp = y_positions[node_idx // nodes_per_row] if len(city_positions) < num_cities and fraction < city_fraction: city_positions.append((x_tmp, y_tmp)) else: intersection_positions.append((x_tmp, y_tmp)) - print(len(city_positions)) node_positions = city_positions + intersection_positions # Chose node connection @@ -634,13 +635,13 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 available_nodes_full = np.delete(available_nodes_full, delete_idx, 0) # Priority city to intersection connections - if current_node < num_cities and len(available_intersections) > 0: + if False and current_node < num_cities and len(available_intersections) > 0: available_nodes = available_intersections delete_idx = np.where(available_cities == current_node) available_cities = np.delete(available_cities, delete_idx, 0) # Priority intersection to city connections - elif current_node >= num_cities and len(available_cities) > 0: + elif False and current_node >= num_cities and len(available_cities) > 0: available_nodes = available_cities delete_idx = np.where(available_intersections == current_node) available_intersections = np.delete(available_intersections, delete_idx, 0) -- GitLab