From 9635e71cce55085d823cfd2a5809b9dbddb94f7c Mon Sep 17 00:00:00 2001
From: Mattias Ljungstrom <mattias.ljungstrom@gmail.com>
Date: Sun, 28 Apr 2019 17:57:59 +0200
Subject: [PATCH] level gen: enforce min/max distances

---
 flatland/envs/rail_env.py | 36 +++++++++++++++++++++++++++++++-----
 1 file changed, 31 insertions(+), 5 deletions(-)

diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 86582f5..8132d93 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -235,7 +235,11 @@ def connect_rail(rail_trans, rail_array, start, end):
         current_dir = new_dir
 
 
-def complex_rail_generator(nr_start_goal=1, min_dist=0, max_dist=99999, seed=0):
+def distance_on_rail(pos1, pos2):
+    return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1])
+
+
+def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
     """
     Parameters
     -------
@@ -295,19 +299,41 @@ def complex_rail_generator(nr_start_goal=1, min_dist=0, max_dist=99999, seed=0):
         # - return transition map + list of [start, goal] points
         #
 
+        # step 1:
         start_goal = []
         for _ in range(nr_start_goal):
-            start = (np.random.randint(0, width), np.random.randint(0, height))
-            goal = (np.random.randint(0, height), np.random.randint(0, height))
-            # TODO: validate closeness with existing points
-            # TODO: make sure min/max distance condition is met
+            sanity_max = 9000
+            for _ in range(sanity_max):
+                start = (np.random.randint(0, width), np.random.randint(0, height))
+                goal = (np.random.randint(0, height), np.random.randint(0, height))
+                # check min/max distance
+                dist_sg = distance_on_rail(start, goal)
+                if dist_sg < min_dist:
+                    continue
+                if dist_sg > max_dist:
+                    continue
+                # check distance to existing points
+                sg_new = [start, goal]
+                def check_all_dist(sg_new):
+                    for sg in start_goal:
+                        for i in range(2):
+                            for j in range(2):
+                                dist = distance_on_rail(sg_new[i], sg[j])
+                                if dist < 2:
+                                    # print("too close:", dist, sg_new[i], sg[j])
+                                    return False
+                    return True
+                if check_all_dist(sg_new):
+                    break
             start_goal.append([start, goal])
 
+        # step 3:
         for sg in start_goal:
             connect_rail(rail_trans, rail_array, sg[0], sg[1])
 
         return_rail = GridTransitionMap(width=width, height=height, transitions=rail_trans)
         return_rail.grid = rail_array
+        # TODO: return start_goal
         return return_rail
 
     return generator
-- 
GitLab