From 336348635a7213bcae3389fa22e5127086e4dce2 Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Wed, 25 Sep 2019 08:59:53 -0400 Subject: [PATCH] connection updated for disjunct nodes --- examples/flatland_2_0_example.py | 2 +- flatland/envs/rail_generators.py | 33 +++++++++++++++++++++++++++----- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 1520e598..2dc6efa4 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -34,7 +34,7 @@ env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(num_cities=9, # Number of cities in map (where train stations are) num_intersections=0, # Number of intersections (no start / target) - num_trainstations=50, # Number of possible start/targets on map + num_trainstations=10, # Number of possible start/targets on map min_node_dist=3, # Minimal distance of nodes node_radius=4, # Proximity of stations to city center num_neighb=4, # Number of connections to other cities/intersections diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index e4a7e292..937b8c88 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -667,8 +667,8 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 if distance_on_rail(tmp_out_connection_point, neighb_connection_point) < center_distance: i += 1 connect_nodes(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point) - boarder_connections.add(tmp_out_connection_point) - boarder_connections.add(neighb_connection_point) + boarder_connections.add((tmp_out_connection_point, current_node)) + boarder_connections.add((neighb_connection_point, neighb)) node_stack.pop(0) @@ -722,10 +722,11 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 connection_points[trainstation_node][corner_node_idx], (station_x, station_y)) if len(connection) != 0: - if connection_points[trainstation_node][corner_node_idx] in boarder_connections: - boarder_connections.remove(connection_points[trainstation_node][corner_node_idx]) + if (connection_points[trainstation_node][corner_node_idx], + trainstation_node) in boarder_connections: + boarder_connections.remove( + (connection_points[trainstation_node][corner_node_idx], trainstation_node)) - grid_map.fix_transitions((station_x, station_y)) # Check if connection was made if len(connection) == 0: if len(train_stations[trainstation_node]) > 0: @@ -737,8 +738,30 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 num_agents = built_num_trainstation warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents") + # Connect all disjunct parts of the network + + if len(boarder_connections) > 0: + to_be_deleted = [] + for disjunct_node in boarder_connections: + print(disjunct_node) + conn = connect_nodes(rail_trans, grid_map, + disjunct_node[0], + train_stations[disjunct_node[1]][0]) + if len(conn) > 0: + to_be_deleted.append(disjunct_node) + + for tbd in to_be_deleted: + boarder_connections.remove(tbd) + print(boarder_connections) # Fix all nodes with illegal transition maps + flat_trainstation_list = [item for sublist in train_stations for item in sublist] + for cell_to_fix in flat_trainstation_list: + grid_map.fix_transitions(cell_to_fix) + + grid_map.fix_transitions((station_x, station_y)) + flat_list = [item for sublist in connection_points for item in sublist] + for cell_to_fix in flat_list: grid_map.fix_transitions(cell_to_fix) -- GitLab