From 1d832003ea008c94855ff2e53d8dbbcee342d226 Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Wed, 25 Sep 2019 09:15:29 -0400 Subject: [PATCH] connection updated for disjunct nodes --- examples/flatland_2_0_example.py | 2 +- flatland/envs/rail_generators.py | 20 ++++++++++++++------ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 2dc6efa4..0d86561f 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -34,7 +34,7 @@ env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(num_cities=9, # Number of cities in map (where train stations are) num_intersections=0, # Number of intersections (no start / target) - num_trainstations=10, # Number of possible start/targets on map + num_trainstations=15, # Number of possible start/targets on map min_node_dist=3, # Minimal distance of nodes node_radius=4, # Proximity of stations to city center num_neighb=4, # Number of connections to other cities/intersections diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 937b8c88..62aa8fb7 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -743,12 +743,22 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 if len(boarder_connections) > 0: to_be_deleted = [] for disjunct_node in boarder_connections: - print(disjunct_node) - conn = connect_nodes(rail_trans, grid_map, - disjunct_node[0], - train_stations[disjunct_node[1]][0]) + if len(train_stations[disjunct_node[1]]) > 0: + conn = connect_nodes(rail_trans, grid_map, + disjunct_node[0], + train_stations[disjunct_node[1]][-1]) + else: + conn = connect_nodes(rail_trans, grid_map, + disjunct_node[0], + node_positions[disjunct_node[1]]) if len(conn) > 0: to_be_deleted.append(disjunct_node) + else: + conn = connect_nodes(rail_trans, grid_map, + disjunct_node[0], + node_positions[disjunct_node[1]]) + if len(conn) > 0: + to_be_deleted.append(disjunct_node) for tbd in to_be_deleted: boarder_connections.remove(tbd) @@ -758,8 +768,6 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 for cell_to_fix in flat_trainstation_list: grid_map.fix_transitions(cell_to_fix) - grid_map.fix_transitions((station_x, station_y)) - flat_list = [item for sublist in connection_points for item in sublist] for cell_to_fix in flat_list: -- GitLab