From 71aea64e9bc5fb0b5db3d3ca48f3b54b0e1a07fd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mattias=20Ljungstr=C3=B6m?= <ml@mljx.io>
Date: Thu, 9 May 2019 12:23:50 +0200
Subject: [PATCH] complex rail gen: added extra connections

---
 examples/play_model.py      |  2 +-
 flatland/envs/generators.py | 24 ++++++++++++++++++++++--
 2 files changed, 23 insertions(+), 3 deletions(-)

diff --git a/examples/play_model.py b/examples/play_model.py
index e69b312..7f92cb3 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -97,7 +97,7 @@ def main(render=True, delay=0.0):
 
     # Example generate a random rail
     env = RailEnv(width=15, height=15,
-                  rail_generator=complex_rail_generator(nr_start_goal=20, min_dist=12),
+                  rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12),
                   number_of_agents=5)
 
     if render:
diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index 4f356e1..7452d32 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -9,7 +9,7 @@ from flatland.envs.env_utils import distance_on_rail, connect_rail, get_directio
 from flatland.envs.env_utils import get_rnd_agents_pos_tgt_dir_on_rail
 
 
-def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
+def complex_rail_generator(nr_start_goal=1, nr_extra=10, min_dist=2, max_dist=99999, seed=0):
     """
     Parameters
     -------
@@ -123,7 +123,27 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
                 # print("failed...")
                 created_sanity += 1
 
-        print("\n> Complex Rail Gen: Created #", len(start_goal), "pairs")
+        # add extra connections between existing rail
+        created_sanity = 0
+        nr_created = 0
+        while nr_created < nr_extra and created_sanity < sanity_max:
+            all_ok = False
+            for _ in range(sanity_max):
+                start = (np.random.randint(0, width), np.random.randint(0, height))
+                goal = (np.random.randint(0, height), np.random.randint(0, height))
+                # check to make sure start,goal pos are not empty
+                if rail_array[goal] == 0 or rail_array[start] == 0:
+                    continue
+                else:
+                    all_ok = True
+                    break
+            if not all_ok:
+                break
+            new_path = connect_rail(rail_trans, rail_array, start, goal)
+            if len(new_path) >= 2:
+                nr_created += 1
+
+        print("\n> Complex Rail Gen: Created #", len(start_goal), "pairs and #", nr_created, "extra connections")
         # print(start_goal)
 
         agents_position = [sg[0] for sg in start_goal]
-- 
GitLab