From 5a13cff6e2a9e50cb48270b085ab678a051e2c9f Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Tue, 24 Sep 2019 18:32:39 -0400
Subject: [PATCH] updated inner city connections

---
 flatland/envs/rail_generators.py | 23 +++++++++++------------
 1 file changed, 11 insertions(+), 12 deletions(-)

diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index a62f0184..d1acc464 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -9,7 +9,7 @@ from flatland.core.grid.grid4_utils import get_direction, mirror
 from flatland.core.grid.grid_utils import distance_on_rail
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
-from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes, connect_from_nodes
+from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes
 
 RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]]
 RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
@@ -691,9 +691,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
                     0,
                     width - 1)
                 tries = 0
-                while (station_x, station_y) in train_stations[trainstation_node] \
-                    or (station_x, station_y) == node_positions[trainstation_node] \
-                    or rail_array[(station_x, station_y)] != 0:  # noqa: E125
+                while (station_x, station_y) in train_stations[trainstation_node]:
 
                     station_x = np.clip(
                         node_positions[trainstation_node][0] + np.random.randint(-reduced_node_radius,
@@ -717,14 +715,13 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
                 # Connect train station to random nodes
 
                 rand_corner_nodes = np.random.choice(range(len(connection_points[trainstation_node])), 2, replace=False)
-                connection_1 = connect_from_nodes(rail_trans, grid_map,
-                                                  connection_points[trainstation_node][rand_corner_nodes[0]],
-                                                (station_x, station_y))
-                connection_2 = connect_from_nodes(rail_trans, grid_map,
-                                                  connection_points[trainstation_node][rand_corner_nodes[1]],
-                                                  (station_x, station_y))
+                for corner_node_idx in rand_corner_nodes:
+                    connection = connect_nodes(rail_trans, grid_map,
+                                               connection_points[trainstation_node][corner_node_idx],
+                                               (station_x, station_y))
+                grid_map.fix_transitions((station_x, station_y))
                 # Check if connection was made
-                if len(connection_1) == 0 and len(connection_2) == 0:
+                if len(connection) == 0:
                     if len(train_stations[trainstation_node]) > 0:
                         train_stations[trainstation_node].pop(-1)
                 else:
@@ -777,7 +774,9 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
                 num_agents -= 1
 
         return grid_map, {'agents_hints': {
-            'num_agents': num_agents
+            'num_agents': num_agents,
+            'agent_start_targets_nodes': agent_start_targets_nodes,
+            'train_stations': train_stations_slots
         }}
 
     def _generate_node_positions_not_grid_mode(city_positions, height, intersection_positions, nb_nodes,
-- 
GitLab