From ef5ad8e95f540f5d92078715da9e699070ecc197 Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Wed, 25 Sep 2019 18:17:26 -0400
Subject: [PATCH] updated tests to reflect changes to sparse_rail_generator

---
 flatland/envs/rail_generators.py                  | 13 ++++++-------
 tests/test_flatland_envs_sparse_rail_generator.py |  3 +--
 2 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index 266d2bec..de780402 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -628,8 +628,6 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
             else:
                 connected_neighb_idx = available_nodes
 
-            print(current_node, connected_neighb_idx)
-
             # Connect to the neighboring nodes
             for neighb in connected_neighb_idx:
                 if neighb not in node_stack and neighb in open_nodes:
@@ -660,13 +658,13 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
         # Place train stations close to the node
         # We currently place them uniformly distributed among all cities
         built_num_trainstation = 0
-        train_stations = [[] for i in range(_num_cities)]
-        if _num_cities > 1:
+        train_stations = [[] for i in range(nb_nodes)]
+        if nb_nodes > 1:
 
             for station in range(num_trainstations):
                 spot_found = True
                 reduced_node_radius = node_radius - 1
-                trainstation_node = int(station / num_trainstations * _num_cities)
+                trainstation_node = int(station / num_trainstations * nb_nodes)
 
                 station_x = np.clip(
                     node_positions[trainstation_node][0] + np.random.randint(-reduced_node_radius, reduced_node_radius),
@@ -778,12 +776,12 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
         # Slot availability in node
         node_available_start = []
         node_available_target = []
-        for node_idx in range(_num_cities):
+        for node_idx in range(nb_nodes):
             node_available_start.append(len(train_stations[node_idx]))
             node_available_target.append(len(train_stations[node_idx]))
 
         # Assign agents to slots
-        for agent_idx in range(num_agents):
+        for agent_idx in range(nb_nodes):
             avail_start_nodes = [idx for idx, val in enumerate(node_available_start) if val > 0]
             avail_target_nodes = [idx for idx, val in enumerate(node_available_target) if val > 0]
             start_node = np.random.choice(avail_start_nodes)
@@ -905,4 +903,5 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
     def argsort(seq):
         # http://stackoverflow.com/questions/3071415/efficient-method-to-calculate-the-rank-vector-of-a-list-in-python
         return sorted(range(len(seq)), key=seq.__getitem__)
+
     return generator
diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py
index e1647524..fd6e7b88 100644
--- a/tests/test_flatland_envs_sparse_rail_generator.py
+++ b/tests/test_flatland_envs_sparse_rail_generator.py
@@ -743,8 +743,7 @@ def test_sparse_rail_generator_deterministic():
                                                        num_neighb=3,
                                                        # Number of connections to other cities/intersections
                                                        seed=215545,  # Random seed
-                                                       grid_mode=True,
-                                                       enhance_intersection=False
+                                                       grid_mode=True
                                                        ),
                   schedule_generator=sparse_schedule_generator(speed_ration_map),
                   number_of_agents=1,
-- 
GitLab