From ab40849fb7ab7af0ccae8baf6d0838e844efe822 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Thu, 20 Jun 2019 10:56:44 +0200
Subject: [PATCH] #62 increase unit test coverage

---
 tests/test_flatland_core_transitions.py | 81 ++++++++++++++++++++-----
 1 file changed, 65 insertions(+), 16 deletions(-)

diff --git a/tests/test_flatland_core_transitions.py b/tests/test_flatland_core_transitions.py
index c32b0385..47def83c 100644
--- a/tests/test_flatland_core_transitions.py
+++ b/tests/test_flatland_core_transitions.py
@@ -10,27 +10,76 @@ from flatland.envs.env_utils import validate_new_transition
 
 def test_rotate_railenv_transition():
     rail_env_transitions = RailEnvTransitions()
+
+    # remove whitespace in string; keep whitespace below for easier reading
+    def rw(s):
+        return s.replace(" ", "")
+
+    # TODO test all cases
     transition_cycles = [
         # empty cell - Case 0
-        [int('0000000000000000', 2), int('0000000000000000', 2), int('0000000000000000', 2), int('0000000000000000', 2)],
-            # Case 1 - straight
-        [int('1000000000100000', 2), int('0000000100000100', 2)],
+        [int('0000000000000000', 2), int('0000000000000000', 2), int('0000000000000000', 2),
+         int('0000000000000000', 2)],
+        # Case 1 - straight
+        #     |
+        #     |
+        #     |
+        [int(rw('1000 0000 0010 0000'), 2), int(rw('0000 0100 0000 0001'), 2)],
+        # Case 1b (8)  - simple turn right
+        #      _
+        #     |
+        #     |
+        [
+            int(rw('0100 0000 0000 0010'), 2),
+            int(rw('0001 0010 0000 0000'), 2),
+            int(rw('0000 1000 0001 0000'), 2),
+            int(rw('0000 0000 0100 1000'), 2),
+        ],
+        # Case 1c (9)  - simple turn left
+        #    _
+        #     |
+        #     |
+
+        #                int('0001001000000000', 2),
+
+        # Case 2 - simple left switch
+        #  _ _|
+        #     |
+        #     |
+        [
+            int(rw('1001 0010 0010 0000'), 2),
+            int(rw('0000 1100 0001 0001'), 2),
+            int(rw('1000 0000 0110 1000'), 2),
+            int(rw('0100 0100 0000 0011'), 2),
+        ],
+        # Case 2b (10) - simple right switch
+        #     |
+        #     |
+        #     |
+        #                int('1100000000100010', 2)]
+        #                int('1000010000100001', 2),  # Case 3 - diamond drossing
+        #                int('1001011000100001', 2),  # Case 4 - single slip
+        #                int('1100110000110011', 2),  # Case 5 - double slip
+        #                int('0101001000000010', 2),  # Case 6 - symmetrical
+        #                int('0010000000000000', 2),  # Case 7 - dead end
+
     ]
 
-    for cycle in transition_cycles:
+    for index, cycle in enumerate(transition_cycles):
         for i in range(4):
-            assert rail_env_transitions.rotate_transition(cycle[0], i) == cycle[i % len(cycle)]
-
-    #
-    #                int('1001001000100000', 2),  # Case 2 - simple switch
-    #                int('1000010000100001', 2),  # Case 3 - diamond drossing
-    #                int('1001011000100001', 2),  # Case 4 - single slip
-    #                int('1100110000110011', 2),  # Case 5 - double slip
-    #                int('0101001000000010', 2),  # Case 6 - symmetrical
-    #                int('0010000000000000', 2),  # Case 7 - dead end
-    #                int('0100000000000010', 2),  # Case 1b (8)  - simple turn right
-    #                int('0001001000000000', 2),  # Case 1c (9)  - simple turn left
-    #                int('1100000000100010', 2)]  # Case 2b (10) - simple switch mirrored
+            actual_transition = rail_env_transitions.rotate_transition(cycle[0], i * 90)
+            expected_transition = cycle[i % len(cycle)]
+            try:
+                assert actual_transition == expected_transition, \
+                    "Case {}: rotate_transition({}, {}) should equal {} but was {}." \
+                        .format(i, cycle[0], i, expected_transition, actual_transition)
+            except Exception as e:
+                print("expected:")
+                rail_env_transitions.print(expected_transition)
+                print("actual:")
+                rail_env_transitions.print(actual_transition)
+
+                raise e
 
 
 def test_is_valid_railenv_transitions():
-- 
GitLab