From 123806b141b59bcf80a8c1e73e36d421a4bba6bf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mattias=20Ljungstr=C3=B6m?= <ml@mljx.io>
Date: Mon, 29 Apr 2019 10:36:42 +0200
Subject: [PATCH] level gen: optimized validation

---
 flatland/core/transitions.py | 28 +++++++++++++++-------------
 flatland/envs/rail_env.py    |  9 +++++----
 2 files changed, 20 insertions(+), 17 deletions(-)

diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py
index 0afec98..cdc657c 100644
--- a/flatland/core/transitions.py
+++ b/flatland/core/transitions.py
@@ -532,14 +532,25 @@ class RailEnvTransitions(Grid4Transitions):
                        int('1100110000110011', 2),  # Case 5 - double slip
                        int('0101001000000010', 2),  # Case 6 - symmetrical
                        int('0010000000000000', 2),  # Case 7 - dead end
-                       int('0100000000000010', 2),  # Case 1b - simple turn right
-                       int('0001001000000000', 2),  # Case 1c - simple turn left
-                       int('1100000000100010', 2)]  # Case 2b - simple switch mirrored
+                       int('0100000000000010', 2),  # Case 1b (8)  - simple turn right
+                       int('0001001000000000', 2),  # Case 1c (9)  - simple turn left
+                       int('1100000000100010', 2)]  # Case 2b (10) - simple switch mirrored
 
     def __init__(self):
         super(RailEnvTransitions, self).__init__(
             transitions=self.transition_list
         )
+        # create this to make validation faster
+        self.transitions_all = []
+        for index, trans in enumerate(self.transitions):
+            self.transitions_all.append(trans)
+            if index in (2, 4, 6, 7, 8, 9, 10):
+                for _ in range(3):
+                    trans = self.rotate_transition(trans, rotation=90)
+                    self.transitions_all.append(trans)
+            elif index in (1, 5):
+                trans = self.rotate_transition(trans, rotation=90)
+                self.transitions_all.append(trans)
 
     def print(self, cell_transition):
         print("  NESW")
@@ -562,17 +573,8 @@ class RailEnvTransitions(Grid4Transitions):
         Boolean
             True or False
         """
-        # i = 0
-        for trans in self.transitions:
-            # print(">", i)
-            # i += 1
-            # self.print(trans)
+        for trans in self.transitions_all:
             if cell_transition == trans:
                 return True
-            for _ in range(3):
-                trans = self.rotate_transition(trans, rotation=90)
-                # self.print(trans)
-                if cell_transition == trans:
-                    return True
         return False
 
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index dc209a6..de952d3 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -73,7 +73,7 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p
         else:
             # check if matches existing layout
             new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
-            new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
+            # new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
             # rail_trans.print(new_trans)
     else:
         # set the forward path
@@ -89,7 +89,7 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p
         else:
             # check if matches existing layout
             new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
-            new_trans_e = rail_trans.set_transition(new_trans_e, mirror(new_dir), mirror(new_dir), 1)
+            # new_trans_e = rail_trans.set_transition(new_trans_e, mirror(new_dir), mirror(new_dir), 1)
             # print("end:", end_pos, current_pos)
             # rail_trans.print(new_trans_e)
 
@@ -244,7 +244,8 @@ def connect_rail(rail_trans, rail_array, start, end):
             else:
                 # into existing rail
                 new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
-                new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
+                # new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
+                pass
         else:
             # set the forward path
             new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
@@ -261,7 +262,7 @@ def connect_rail(rail_trans, rail_array, start, end):
             else:
                 # into existing rail
                 new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
-                new_trans_e = rail_trans.set_transition(new_trans_e, mirror(new_dir), mirror(new_dir), 1)
+                # new_trans_e = rail_trans.set_transition(new_trans_e, mirror(new_dir), mirror(new_dir), 1)
             rail_array[end_pos] = new_trans_e
 
         current_dir = new_dir
-- 
GitLab