From 3e902906f2d1bf5928cf8dd0d395ee2510b10d4b Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Wed, 14 Aug 2019 18:32:40 -0400
Subject: [PATCH] checking that targets don't end up on a connecting rail

---
 flatland/envs/generators.py                   | 44 ++++++++-----------
 ...test_flatland_env_sparse_rail_generator.py |  6 +--
 2 files changed, 22 insertions(+), 28 deletions(-)

diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index 6d1205b3..81ee9a92 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -693,7 +693,7 @@ def realistic_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dis
     return generator
 
 
-def sparse_rail_generator(nr_train_stations=1, nr_nodes=100, max_neigbours=2, min_node_dist=20, node_radius=2,
+def sparse_rail_generator(nr_nodes=100, max_neigbours=2, min_node_dist=20, node_radius=2,
                           seed=0):
     '''
 
@@ -706,10 +706,6 @@ def sparse_rail_generator(nr_train_stations=1, nr_nodes=100, max_neigbours=2, mi
     '''
 
     def generator(width, height, num_agents, num_resets=0):
-
-        if num_agents > nr_train_stations:
-            num_agents = nr_train_stations
-            warnings.warn("complex_rail_generator: num_agents > nr_start_goal, changing num_agents")
         rail_trans = RailEnvTransitions()
         grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
         rail_array = grid_map.grid
@@ -780,26 +776,24 @@ def sparse_rail_generator(nr_train_stations=1, nr_nodes=100, max_neigbours=2, mi
                                height - 1)
             target_y = np.clip(node_positions[start_target[1]][1] + np.random.randint(-node_radius, node_radius), 0,
                                width - 1)
-            if agent_idx == 0:
-                agents_position.append((start_x, start_y))
-                agents_target.append((target_x, target_y))
-            else:
-                # Make sure we don't put to starts or targets on same cell
-                while (start_x, start_y) in agents_position or (target_x, target_y) in agents_target:
-                    start_x = np.clip(node_positions[start_target[0]][0] + np.random.randint(-node_radius, node_radius),
-                                      0,
-                                      height - 1)
-                    start_y = np.clip(node_positions[start_target[0]][1] + np.random.randint(-node_radius, node_radius),
-                                      0,
-                                      width - 1)
-                    target_x = np.clip(
-                        node_positions[start_target[1]][0] + np.random.randint(-node_radius, node_radius), 0,
-                        height - 1)
-                    target_y = np.clip(
-                        node_positions[start_target[1]][1] + np.random.randint(-node_radius, node_radius), 0,
-                        width - 1)
-                agents_position.append((start_x, start_y))
-                agents_target.append((target_x, target_y))
+            # Make sure we don't put to starts or targets on same cell
+            while (start_x, start_y) in agents_position or (start_x, start_y) == node_positions[start_target[0]]:
+                start_x = np.clip(node_positions[start_target[0]][0] + np.random.randint(-node_radius, node_radius),
+                                  0,
+                                  height - 1)
+                start_y = np.clip(node_positions[start_target[0]][1] + np.random.randint(-node_radius, node_radius),
+                                  0,
+                                  width - 1)
+            while (target_x, target_y) in agents_target or (target_x, target_y) == node_positions[start_target[1]] or \
+                rail_array[(target_x, target_y)] != 0:
+                target_x = np.clip(
+                    node_positions[start_target[1]][0] + np.random.randint(-node_radius, node_radius), 0,
+                    height - 1)
+                target_y = np.clip(
+                    node_positions[start_target[1]][1] + np.random.randint(-node_radius, node_radius), 0,
+                    width - 1)
+            agents_position.append((start_x, start_y))
+            agents_target.append((target_x, target_y))
             new_path = connect_to_nodes(rail_trans, rail_array, agents_position[agent_idx],
                                         node_positions[start_target[0]])
             new_path = connect_from_nodes(rail_trans, rail_array, node_positions[start_target[1]],
diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py
index b64aaa64..54bee175 100644
--- a/tests/test_flatland_env_sparse_rail_generator.py
+++ b/tests/test_flatland_env_sparse_rail_generator.py
@@ -10,9 +10,9 @@ def test_sparse_rail_generator():
 
     env = RailEnv(width=20,
                   height=20,
-                  rail_generator=sparse_rail_generator(nr_train_stations=3, nr_nodes=2, min_node_dist=5,
-                                                       node_radius=4),
-                  number_of_agents=3,
+                  rail_generator=sparse_rail_generator(nr_nodes=5, min_node_dist=8,
+                                                       node_radius=3),
+                  number_of_agents=10,
                   obs_builder_object=GlobalObsForRailEnv())
     # reset to initialize agents_static
     env_renderer = RenderTool(env, gl="PILSVG", )
-- 
GitLab