From 0766f7ed1d5159d2504f45af5858fa3fed87bbf7 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Sun, 1 Sep 2019 16:35:55 -0400
Subject: [PATCH] removed bias in level generation where switches of certain
 orientations where more common

---
 flatland/core/transition_map.py | 34 +++++++++++++++++----------------
 1 file changed, 18 insertions(+), 16 deletions(-)

diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py
index bb954998..232d6fda 100644
--- a/flatland/core/transition_map.py
+++ b/flatland/core/transition_map.py
@@ -8,6 +8,7 @@ from numpy import array
 
 from flatland.core.grid.grid4 import Grid4Transitions
 from flatland.core.grid.grid4_utils import get_new_position
+from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transitions import Transitions
 
 
@@ -482,6 +483,14 @@ class GridTransitionMap(TransitionMap):
         grcPos = array(rcPos)
         grcMax = self.grid.shape
 
+        # Transition elements
+        transitions = RailEnvTransitions()
+        cells = transitions.transition_list
+        simple_switch_east_south = transitions.rotate_transition(cells[10], 90)
+        simple_switch_west_south = transitions.rotate_transition(cells[2], 270)
+        symmetrical = cells[6]
+        double_slip = cells[5]
+        three_way_transitions = [simple_switch_east_south, simple_switch_west_south, symmetrical]
         # loop over available outbound directions (indices) for rcPos
         self.set_transitions(rcPos, 0)
 
@@ -517,25 +526,18 @@ class GridTransitionMap(TransitionMap):
             self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1)
             self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1)
 
-        # Find feasible connection fro three entries
+        # Find feasible connection for three entries
         if number_of_incoming == 3:
+            transition = np.random.choice(three_way_transitions, 1)
             hole = np.argwhere(incoming_connections < 1)[0][0]
-            connect_directions = [(hole + 1) % 4, (hole + 2) % 4, (hole + 3) % 4]
-            self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1)
-            self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[2], 1)
-            self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1)
-            self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[2])), connect_directions[0], 1)
-        # Make a cross
+            transition = transitions.rotate_transition(transition, int(hole * 90))
+            self.set_transitions((rcPos[0], rcPos[1]), transition)
+
+        # Make a double slip switch
         if number_of_incoming == 4:
-            connect_directions = np.arange(4)
-            self.set_transition((rcPos[0], rcPos[1], connect_directions[0]), connect_directions[0], 1)
-            self.set_transition((rcPos[0], rcPos[1], connect_directions[0]), connect_directions[1], 1)
-            self.set_transition((rcPos[0], rcPos[1], connect_directions[1]), connect_directions[0], 1)
-            self.set_transition((rcPos[0], rcPos[1], connect_directions[1]), connect_directions[1], 1)
-            self.set_transition((rcPos[0], rcPos[1], connect_directions[2]), connect_directions[2], 1)
-            self.set_transition((rcPos[0], rcPos[1], connect_directions[2]), connect_directions[3], 1)
-            self.set_transition((rcPos[0], rcPos[1], connect_directions[3]), connect_directions[2], 1)
-            self.set_transition((rcPos[0], rcPos[1], connect_directions[3]), connect_directions[3], 1)
+            rotation = np.random.randint(2)
+            transition = transitions.rotate_transition(double_slip, int(rotation * 90))
+            self.set_transitions((rcPos[0], rcPos[1]), transition)
         return True
 
 
-- 
GitLab