From 62395962ef2c4342b41053fd55b30a2160678842 Mon Sep 17 00:00:00 2001
From: Guillaume Mollard <guillaume.mollard2@gmail.com>
Date: Mon, 8 Apr 2019 17:49:59 +0200
Subject: [PATCH] test valid environment initialization

---
 flatland/core/transitions.py | 17 ++++++------
 tests/test_environments.py   | 50 ++++++++++++++++++++++++++++++++----
 2 files changed, 54 insertions(+), 13 deletions(-)

diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py
index 467ad67..dc4ffe8 100644
--- a/flatland/core/transitions.py
+++ b/flatland/core/transitions.py
@@ -438,14 +438,15 @@ class RailEnvTransitions(GridTransitions):
     transitions available as a function of the agent's orientation
     (north, east, south, west)
     """
-    transition_list = [int('0000000000000000', 2),
-                       int('1000000000100000', 2),
-                       int('1001001000100000', 2),
-                       int('1000010000100001', 2),
-                       int('1001011000100001', 2),
-                       int('1100110000110011', 2),
-                       int('0101001000000010', 2),
-                       int('0000000000100000', 2)]
+
+    transition_list = [int('0000000000000000', 2),  # empty cell - Case 0
+             int('1000000000100000', 2),  # Case 1 - straight
+             int('1001001000000000', 2),  # Case 2 - simple switch
+             int('1000010000100001', 2),  # Case 3 - diamond drossing
+             int('1001011000100001', 2),  # Case 4 - single slip switch
+             int('1100110000110011', 2),  # Case 5 - double slip switch
+             int('0101001000000010', 2),  # Case 6 - symmetrical switch
+             int('0010000000000000', 2)]  # Case 7 - dead end
 
     def __init__(self):
         super(RailEnvTransitions, self).__init__(
diff --git a/tests/test_environments.py b/tests/test_environments.py
index ae4d524..e95c398 100644
--- a/tests/test_environments.py
+++ b/tests/test_environments.py
@@ -1,11 +1,51 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
+from flatland.core.env import RailEnv
+from flatland.core.transitions import GridTransitions
+import numpy as np
+
 """Tests for `flatland` package."""
 
 
-def test_base_environment():
-    """Test example Transition."""
-    a = True
-    b = True
-    assert a == b
+
+def test_rail_environment():
+    cells = [int('0000000000000000', 2),  # empty cell - Case 0
+                       int('1000000000100000', 2),  # Case 1 - straight
+                       int('1001001000000000', 2),  # Case 2 - simple switch
+                       int('1000010000100001', 2),  # Case 3 - diamond drossing
+                       int('1001011000100001', 2),  # Case 4 - single slip switch
+                       int('1100110000110011', 2),  # Case 5 - double slip switch
+                       int('0101001000000010', 2),  # Case 6 - symmetrical switch
+                       int('0010000000000000', 2)]  # Case 7 - dead end
+
+    # We instantiate the following map on a 3x3 grid
+    #  _  _
+    # / \/ \
+    # | |  |
+    # \_/\_/
+
+    transitions = GridTransitions([], False)
+    vertical_line = cells[1]
+    south_symmetrical_switch = cells[6]
+    north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180)
+    south_east_turn = int('0100000000100000', 2)  # Simple turn not in the base transitions ?
+    south_west_turn = transitions.rotate_transition(south_east_turn, 90)
+    north_east_turn = transitions.rotate_transition(south_east_turn, 270)
+    north_west_turn = transitions.rotate_transition(south_east_turn, 180)
+
+    rail_map = np.array([[south_east_turn, south_symmetrical_switch, south_west_turn],
+                    [vertical_line, vertical_line, vertical_line],
+                    [north_east_turn, north_symmetrical_switch, north_west_turn]],
+                   dtype=np.uint16)
+
+    rail_env = RailEnv(rail_map, number_of_agents=1)
+
+    # Check that trains are always initialized at a consistent position / direction.
+    # They should always be able to go somewhere.
+    for _ in range(1000):
+        obs = rail_env.reset()
+        assert(transitions.get_transitions_from_orientation(
+            rail_map[rail_env.agents_position[0]],
+            rail_env.agents_direction[0]) != (0, 0, 0, 0))
+
-- 
GitLab