From 474d13e494379368936af55b6a1006670e9bfa65 Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Tue, 20 Aug 2019 16:11:34 -0400
Subject: [PATCH] removed constraint on start location

---
 flatland/core/transition_map.py |  1 -
 flatland/envs/generators.py     | 22 +++++++++++++++++++---
 2 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py
index 048593d9..2aff675d 100644
--- a/flatland/core/transition_map.py
+++ b/flatland/core/transition_map.py
@@ -437,7 +437,6 @@ class GridTransitionMap(TransitionMap):
         number_of_incoming = np.sum(incomping_connections)
         # Only one incoming direction --> Straight line
         if number_of_incoming == 1:
-
             for direction in range(4):
                 if incomping_connections[direction] > 0:
                     self.set_transition((rcPos[0], rcPos[1], mirror(direction)), direction, 1)
diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index bf9679f9..37779d22 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -1032,6 +1032,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
                 connect_nodes(rail_trans, rail_array, node_positions[current_node], node_positions[neighb])
             node_stack.pop(0)
 
+
         # Place train stations close to the node
         # We currently place them uniformly distirbuted among all cities
         if num_cities > 1:
@@ -1046,6 +1047,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
                 station_y = np.clip(node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius),
                                     0,
                                     width - 1)
+                tries = 0
                 while (station_x, station_y) in train_stations or (station_x, station_y) == node_positions[
                         trainstation_node] or rail_array[(station_x, station_y)] != 0:
                     station_x = np.clip(
@@ -1056,6 +1058,10 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
                         node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius),
                         0,
                         width - 1)
+                    tries += 1
+                    if tries > 100:
+                        warnings.warn("Could not set trainstations, please change initial parameters!!!!")
+                        break
                 train_stations[trainstation_node].append((station_x, station_y))
 
                 # Connect train station to the correct node
@@ -1122,8 +1128,10 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
                 tries += 1
                 # Test again with new start node if no pair is found (This code needs to be improved)
                 if tries > 10:
-                    break
                     start_node = np.random.choice(avail_start_nodes)
+                if tries > 100:
+                    warnings.warn("Could not set trainstations, please change initial parameters!!!!")
+                    break
 
             node_available_start[start_node] -= 1
             node_available_target[target_node] -= 1
@@ -1141,18 +1149,26 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
             current_target_node = agent_start_targets_nodes[agent_idx][1]
             target_station_idx = np.random.randint(len(train_stations[current_target_node]))
             target = train_stations[current_target_node][target_station_idx]
+            tries = 0
             while (target[0], target[1]) in agents_target:
                 target_station_idx = np.random.randint(len(train_stations[current_target_node]))
                 target = train_stations[current_target_node][target_station_idx]
-
+                tries += 1
+                if tries > 100:
+                    warnings.warn("Could not set target position, please change initial parameters!!!!")
+                    break
             agents_target.append((target[0], target[1]))
 
             # Set start for agent
             current_start_node = agent_start_targets_nodes[agent_idx][0]
             start_station_idx = np.random.randint(len(train_stations[current_start_node]))
             start = train_stations[current_start_node][start_station_idx]
-
+            tries = 0
             while (start[0], start[1]) in agents_position:
+                tries += 1
+                if tries > 100:
+                    warnings.warn("Could not set start position, please change initial parameters!!!!")
+                    break
                 start_station_idx = np.random.randint(len(train_stations[current_start_node]))
                 start = train_stations[current_start_node][start_station_idx]
 
-- 
GitLab