From a85b1df0c609f425a19434708ee68852b3ab9662 Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Fri, 23 Aug 2019 17:34:50 -0400
Subject: [PATCH] added stability measures for trainstation setting and filling

---
 flatland/envs/generators.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index e8b4f1fe..f79c5600 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -1190,6 +1190,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
             start_node = np.random.choice(avail_start_nodes)
             target_node = np.random.choice(avail_target_nodes)
             tries = 0
+            found_agent_pair = True
             while target_node == start_node:
                 target_node = np.random.choice(avail_target_nodes)
                 tries += 1
@@ -1197,13 +1198,15 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
                 if (tries + 1) % 10 == 0:
                     start_node = np.random.choice(avail_start_nodes)
                 if tries > 100:
-                    warnings.warn("Could not set trainstations, please change initial parameters!!!!")
+                    warnings.warn("Could not set trainstations, removing agent!")
+                    found_agent_pair = False
                     break
-
-            node_available_start[start_node] -= 1
-            node_available_target[target_node] -= 1
-
-            agent_start_targets_nodes.append((start_node, target_node))
+            if found_agent_pair:
+                node_available_start[start_node] -= 1
+                node_available_target[target_node] -= 1
+                agent_start_targets_nodes.append((start_node, target_node))
+            else:
+                num_agents -= 1
 
         # Place agents and targets within available train stations
         agents_position = []
@@ -1211,7 +1214,6 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
         agents_direction = []
 
         for agent_idx in range(num_agents):
-
             # Set target for agent
             current_target_node = agent_start_targets_nodes[agent_idx][1]
             target_station_idx = np.random.randint(len(train_stations[current_target_node]))
@@ -1222,7 +1224,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
                 target = train_stations[current_target_node][target_station_idx]
                 tries += 1
                 if tries > 100:
-                    warnings.warn("Could not set target position, please change initial parameters!!!!")
+                    warnings.warn("Could not set target position, removing an agent")
                     break
             agents_target.append((target[0], target[1]))
 
-- 
GitLab