From 5a13cff6e2a9e50cb48270b085ab678a051e2c9f Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Tue, 24 Sep 2019 18:32:39 -0400 Subject: [PATCH] updated inner city connections --- flatland/envs/rail_generators.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index a62f0184..d1acc464 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -9,7 +9,7 @@ from flatland.core.grid.grid4_utils import get_direction, mirror from flatland.core.grid.grid_utils import distance_on_rail from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap -from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes, connect_from_nodes +from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]] RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct] @@ -691,9 +691,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 0, width - 1) tries = 0 - while (station_x, station_y) in train_stations[trainstation_node] \ - or (station_x, station_y) == node_positions[trainstation_node] \ - or rail_array[(station_x, station_y)] != 0: # noqa: E125 + while (station_x, station_y) in train_stations[trainstation_node]: station_x = np.clip( node_positions[trainstation_node][0] + np.random.randint(-reduced_node_radius, @@ -717,14 +715,13 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 # Connect train station to random nodes rand_corner_nodes = np.random.choice(range(len(connection_points[trainstation_node])), 2, replace=False) - connection_1 = connect_from_nodes(rail_trans, grid_map, - connection_points[trainstation_node][rand_corner_nodes[0]], - (station_x, station_y)) - connection_2 = connect_from_nodes(rail_trans, grid_map, - connection_points[trainstation_node][rand_corner_nodes[1]], - (station_x, station_y)) + for corner_node_idx in rand_corner_nodes: + connection = connect_nodes(rail_trans, grid_map, + connection_points[trainstation_node][corner_node_idx], + (station_x, station_y)) + grid_map.fix_transitions((station_x, station_y)) # Check if connection was made - if len(connection_1) == 0 and len(connection_2) == 0: + if len(connection) == 0: if len(train_stations[trainstation_node]) > 0: train_stations[trainstation_node].pop(-1) else: @@ -777,7 +774,9 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 num_agents -= 1 return grid_map, {'agents_hints': { - 'num_agents': num_agents + 'num_agents': num_agents, + 'agent_start_targets_nodes': agent_start_targets_nodes, + 'train_stations': train_stations_slots }} def _generate_node_positions_not_grid_mode(city_positions, height, intersection_positions, nb_nodes, -- GitLab