From 1664efd1a4d9d60341358741db949257d5c4cadd Mon Sep 17 00:00:00 2001
From: Mattias Ljungstrom <mattias.ljungstrom@gmail.com>
Date: Sun, 28 Apr 2019 20:19:54 +0200
Subject: [PATCH] level gen: bug fixes

---
 examples/play_model.py       |  2 +-
 flatland/core/transitions.py |  3 +-
 flatland/envs/rail_env.py    | 59 ++++++++++++++++++++++++++++--------
 3 files changed, 49 insertions(+), 15 deletions(-)

diff --git a/examples/play_model.py b/examples/play_model.py
index 82458e3..1f654c1 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -103,7 +103,7 @@ def main(render=True, delay=0.0):
 
     # Example generate a random rail
     env = RailEnv(width=15, height=15,
-                  rail_generator=complex_rail_generator(nr_start_goal=20, min_dist=5),
+                  rail_generator=complex_rail_generator(nr_start_goal=15, min_dist=5),
                   number_of_agents=1)
 
     if render:
diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py
index 9262089..0afec98 100644
--- a/flatland/core/transitions.py
+++ b/flatland/core/transitions.py
@@ -533,7 +533,8 @@ class RailEnvTransitions(Grid4Transitions):
                        int('0101001000000010', 2),  # Case 6 - symmetrical
                        int('0010000000000000', 2),  # Case 7 - dead end
                        int('0100000000000010', 2),  # Case 1b - simple turn right
-                       int('0001001000000000', 2)]  # Case 1c - simple turn left
+                       int('0001001000000000', 2),  # Case 1c - simple turn left
+                       int('1100000000100010', 2)]  # Case 2b - simple switch mirrored
 
     def __init__(self):
         super(RailEnvTransitions, self).__init__(
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index a7af40e..edcd069 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -67,8 +67,14 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p
     # create new transition that would go to child
     new_trans = rail_array[current_pos]
     if prev_pos is None:
-        # need to flip direction because of how end points are defined
-        new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
+        if new_trans == 0:
+            # need to flip direction because of how end points are defined
+            new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
+        else:
+            # check if matches existing layout
+            new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
+            new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
+            # rail_trans.print(new_trans)
     else:
         # set the forward path
         new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
@@ -77,11 +83,24 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p
     if new_pos == end_pos:
         # need to validate end pos setup as well
         new_trans_e = rail_array[end_pos]
-        new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
+        if new_trans_e == 0:
+            # need to flip direction because of how end points are defined
+            new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
+        else:
+            # check if matches existing layout
+            new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
+            new_trans_e = rail_trans.set_transition(new_trans_e, mirror(new_dir), mirror(new_dir), 1)
+            # print("end:", end_pos, current_pos)
+            # rail_trans.print(new_trans_e)
+
         # print("========> end trans")
         # rail_trans.print(new_trans_e)
         if not rail_trans.is_valid(new_trans_e):
+            # print("end failed", end_pos, current_pos)
             return False
+        # else:
+        #    print("end ok!", end_pos, current_pos)
+
     # is transition is valid?
     # print("=======> trans")
     # rail_trans.print(new_trans)
@@ -196,6 +215,7 @@ def a_star(rail_trans, rail_array, start, end):
                 path.append(current.pos)
                 current = current.parent
             # return reversed path
+            print("partial:", start, end, path[::-1])
             return path[::-1]
 
 
@@ -217,8 +237,14 @@ def connect_rail(rail_trans, rail_array, start, end):
 
         new_trans = rail_array[current_pos]
         if index == 0:
-            # need to flip direction because of how end points are defined
-            new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
+            if new_trans == 0:
+                # end-point
+                # need to flip direction because of how end points are defined
+                new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
+            else:
+                # into existing rail
+                new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
+                new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
         else:
             # set the forward path
             new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
@@ -227,9 +253,15 @@ def connect_rail(rail_trans, rail_array, start, end):
         rail_array[current_pos] = new_trans
 
         if new_pos == end_pos:
-            # need to validate end pos setup as well
+            # setup end pos setup
             new_trans_e = rail_array[end_pos]
-            new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
+            if new_trans_e == 0:
+                # end-point
+                new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
+            else:
+                # into existing rail
+                new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
+                new_trans_e = rail_trans.set_transition(new_trans_e, mirror(new_dir), mirror(new_dir), 1)
             rail_array[end_pos] = new_trans_e
 
         current_dir = new_dir
@@ -299,13 +331,15 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
         # - return transition map + list of [start, goal] points
         #
 
-        # step 1:
         start_goal = []
         for _ in range(nr_start_goal):
             sanity_max = 9000
             for _ in range(sanity_max):
                 start = (np.random.randint(0, width), np.random.randint(0, height))
                 goal = (np.random.randint(0, height), np.random.randint(0, height))
+                # check to make sure start,goal pos is empty?
+                # if rail_array[goal] != 0: # or rail_array[start] != 0:
+                #     continue
                 # check min/max distance
                 dist_sg = distance_on_rail(start, goal)
                 if dist_sg < min_dist:
@@ -327,11 +361,10 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
                 if check_all_dist(sg_new):
                     break
             start_goal.append([start, goal])
-        print("Created #", len(start_goal), "pairs")
+            connect_rail(rail_trans, rail_array, start, goal)
 
-        # step 3:
-        for sg in start_goal:
-            connect_rail(rail_trans, rail_array, sg[0], sg[1])
+        print("Created #", len(start_goal), "pairs")
+        # print(start_goal)
 
         return_rail = GridTransitionMap(width=width, height=height, transitions=rail_trans)
         return_rail.grid = rail_array
@@ -476,7 +509,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
 
         transitions_templates_ = []
         transition_probabilities = []
-        for i in range(len(t_utils.transitions) - 3):  # don't include dead-ends
+        for i in range(len(t_utils.transitions) - 4):  # don't include dead-ends
             all_transitions = 0
             for dir_ in range(4):
                 trans = t_utils.get_transitions(t_utils.transitions[i], dir_)
-- 
GitLab