From 8e3d61a971fb55289666ca106375c135ebc13be4 Mon Sep 17 00:00:00 2001
From: spiglerg <spiglerg@gmail.com>
Date: Tue, 7 May 2019 12:00:48 +0200
Subject: [PATCH] random generator with turns

---
 examples/temporary_example.py | 6 +++++-
 flatland/envs/generators.py   | 7 +++++--
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/examples/temporary_example.py b/examples/temporary_example.py
index 28a94db3..c2720a93 100644
--- a/examples/temporary_example.py
+++ b/examples/temporary_example.py
@@ -17,7 +17,11 @@ transition_probability = [1.0,  # empty cell - Case 0
                           0.5,  # Case 4 - single slip
                           0.5,  # Case 5 - double slip
                           0.2,  # Case 6 - symmetrical
-                          0.0]  # Case 7 - dead end
+                          0.0,  # Case 7 - dead end
+                          0.2,  # Case 8 - turn left
+                          0.2,  # Case 9 - turn right
+                          1.0]  # Case 10 - mirrored switch
+
 """
 # Example generate a random rail
 env = RailEnv(width=20,
diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index 2c26076d..de75b70b 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -227,7 +227,7 @@ def generate_rail_from_list_of_manual_specifications(list_of_specifications)
 """
 
 
-def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
+def random_rail_generator(cell_type_relative_proportion=[1.0] * 11):
     """
     Dummy random level generator:
     - fill in cells at random in [width-2, height-2]
@@ -266,7 +266,10 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
 
         transitions_templates_ = []
         transition_probabilities = []
-        for i in range(len(t_utils.transitions) - 4):  # don't include dead-ends
+        for i in range(len(t_utils.transitions)):  # don't include dead-ends
+            if t_utils.transitions[i] == int('0010000000000000', 2):
+                continue
+
             all_transitions = 0
             for dir_ in range(4):
                 trans = t_utils.get_transitions(t_utils.transitions[i], dir_)
-- 
GitLab