diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py
index 9550b71348c177adbaa30b4bd3e5307ba2e855ce..7a673bcf9ba46c574db0983e3c52257e4a07358e 100644
--- a/flatland/core/transition_map.py
+++ b/flatland/core/transition_map.py
@@ -414,13 +414,13 @@ class GridTransitionMap(TransitionMap):
         # loop over available outbound directions (indices) for rcPos
         self.set_transitions(rcPos, 0)
 
-        incomping_connections = np.zeros(4)
+        incoming_connections = np.zeros(4)
         for iDirOut in np.arange(4):
             gdRC = gDir2dRC[iDirOut]  # row,col increment
             gPos2 = grcPos + gdRC  # next cell in that direction
 
             # Check the adjacent cell is within bounds
-            # if not, then this transition is invalid!
+            # if not, then ignore it for the count of incoming connections
             if np.any(gPos2 < 0):
                 continue
             if np.any(gPos2 >= grcMax):
@@ -432,23 +432,23 @@ class GridTransitionMap(TransitionMap):
             for orientation in range(4):
                 connected += self.get_transition((gPos2[0], gPos2[1], orientation), mirror(iDirOut))
             if connected > 0:
-                incomping_connections[iDirOut] = 1
+                incoming_connections[iDirOut] = 1
 
-        number_of_incoming = np.sum(incomping_connections)
+        number_of_incoming = np.sum(incoming_connections)
         # Only one incoming direction --> Straight line
         if number_of_incoming == 1:
             for direction in range(4):
-                if incomping_connections[direction] > 0:
+                if incoming_connections[direction] > 0:
                     self.set_transition((rcPos[0], rcPos[1], mirror(direction)), direction, 1)
         # Connect all incoming connections
         if number_of_incoming == 2:
-            connect_directions = np.argwhere(incomping_connections > 0)
+            connect_directions = np.argwhere(incoming_connections > 0)
             self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1)
             self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1)
 
         # Find feasible connection fro three entries
         if number_of_incoming == 3:
-            hole = np.argwhere(incomping_connections < 1)[0][0]
+            hole = np.argwhere(incoming_connections < 1)[0][0]
             connect_directions = [(hole + 1) % 4, (hole + 2) % 4, (hole + 3) % 4]
             self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1)
             self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[2], 1)
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index deaabd02fc1f842fae8c36d9056db632791e67a8..6e6665af88c9e8b31e1a689815edb7aaada342f9 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -2,7 +2,7 @@
 Definition of the RailEnv environment.
 """
 # TODO:  _ this is a global method --> utils or remove later
-
+import warnings
 from enum import IntEnum
 
 import msgpack
@@ -228,7 +228,7 @@ class RailEnv(Environment):
                     rcPos = (r, c)
                     check = self.rail.cell_neighbours_valid(rcPos, True)
                     if not check:
-                        print("WARNING: Invalid grid at {} -> {}".format(rcPos, check))
+                        warnings.warn("Invalid grid at {} -> {}".format(rcPos, check))
 
         if replace_agents:
             self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5])
diff --git a/flatland/utils/simple_rail.py b/flatland/utils/simple_rail.py
index 28978ca3e5b958a51c578f6be5e8c87b77baaa97..c5fe4860783f242f21c97c55a9119d8918454a96 100644
--- a/flatland/utils/simple_rail.py
+++ b/flatland/utils/simple_rail.py
@@ -2,11 +2,87 @@ from typing import Tuple
 
 import numpy as np
 
-from flatland.core.grid.grid4 import Grid4Transitions
+from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
 
 
 def make_simple_rail() -> Tuple[GridTransitionMap, np.array]:
+    # We instantiate a very simple rail network on a 7x10 grid:
+    # Note that that cells have invalid RailEnvTransitions!
+    #        |
+    #        |
+    #        |
+    # _ _ _ _\ _ _  _  _ _ _
+    #                /
+    #                |
+    #                |
+    #                |
+    transitions = RailEnvTransitions()
+    cells = transitions.transition_list
+    empty = cells[0]
+    dead_end_from_south = cells[7]
+    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
+    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
+    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
+    vertical_straight = cells[1]
+    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
+    simple_switch_north_left = cells[2]
+    simple_switch_north_right = cells[10]
+    simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270)
+    simple_switch_east_west_south = transitions.rotate_transition(simple_switch_north_left, 270)
+    rail_map = np.array(
+        [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
+        [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
+        [[dead_end_from_east] + [horizontal_straight] * 2 +
+         [simple_switch_east_west_north] +
+         [horizontal_straight] * 2 + [simple_switch_east_west_south] +
+         [horizontal_straight] * 2 + [dead_end_from_west]] +
+        [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
+        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
+    rail = GridTransitionMap(width=rail_map.shape[1],
+                             height=rail_map.shape[0], transitions=transitions)
+    rail.grid = rail_map
+    return rail, rail_map
+
+
+def make_simple_rail2() -> Tuple[GridTransitionMap, np.array]:
+    # We instantiate a very simple rail network on a 7x10 grid:
+    #        |
+    #        |
+    #        |
+    # _ _ _ _\ _ _  _  _ _ _
+    #               \
+    #                |
+    #                |
+    #                |
+    transitions = RailEnvTransitions()
+    cells = transitions.transition_list
+    empty = cells[0]
+    dead_end_from_south = cells[7]
+    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
+    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
+    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
+    vertical_straight = cells[1]
+    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
+    simple_switch_north_right = cells[10]
+    simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270)
+    simple_switch_west_east_south = transitions.rotate_transition(simple_switch_north_right, 90)
+    rail_map = np.array(
+        [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
+        [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
+        [[dead_end_from_east] + [horizontal_straight] * 2 +
+         [simple_switch_east_west_north] +
+         [horizontal_straight] * 2 + [simple_switch_west_east_south] +
+         [horizontal_straight] * 2 + [dead_end_from_west]] +
+        [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
+        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
+    rail = GridTransitionMap(width=rail_map.shape[1],
+                             height=rail_map.shape[0], transitions=transitions)
+    rail.grid = rail_map
+    return rail, rail_map
+
+
+def make_invalid_simple_rail() -> Tuple[GridTransitionMap, np.array]:
     # We instantiate a very simple rail network on a 7x10 grid:
     #        |
     #        |
@@ -16,15 +92,9 @@ def make_simple_rail() -> Tuple[GridTransitionMap, np.array]:
     #                |
     #                |
     #                |
-    cells = [int('0000000000000000', 2),  # empty cell - Case 0
-             int('1000000000100000', 2),  # Case 1 - straight
-             int('1001001000100000', 2),  # Case 2 - simple switch
-             int('1000010000100001', 2),  # Case 3 - diamond drossing
-             int('1001011000100001', 2),  # Case 4 - single slip switch
-             int('1100110000110011', 2),  # Case 5 - double slip switch
-             int('0101001000000010', 2),  # Case 6 - symmetrical switch
-             int('0010000000000000', 2)]  # Case 7 - dead end
-    transitions = Grid4Transitions([])
+    transitions = RailEnvTransitions()
+    cells = transitions.transition_list
+
     empty = cells[0]
     dead_end_from_south = cells[7]
     dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
diff --git a/tests/test_distance_map.py b/tests/test_distance_map.py
index 12e0c092a37a475ab6e7dde21c665778e06f5e59..62df397d34b06755465ca0c9f664b9117c87243f 100644
--- a/tests/test_distance_map.py
+++ b/tests/test_distance_map.py
@@ -1,6 +1,6 @@
 import numpy as np
 
-from flatland.core.grid.grid4 import Grid4Transitions
+from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.generators import rail_from_grid_transition_map
 from flatland.envs.observations import TreeObsForRailEnv
@@ -11,15 +11,8 @@ from flatland.envs.rail_env import RailEnv
 def test_walker():
     # _ _ _
 
-    cells = [int('0000000000000000', 2),  # empty cell - Case 0
-             int('1000000000100000', 2),  # Case 1 - straight
-             int('1001001000100000', 2),  # Case 2 - simple switch
-             int('1000010000100001', 2),  # Case 3 - diamond drossing
-             int('1001011000100001', 2),  # Case 4 - single slip switch
-             int('1100110000110011', 2),  # Case 5 - double slip switch
-             int('0101001000000010', 2),  # Case 6 - symmetrical switch
-             int('0010000000000000', 2)]  # Case 7 - dead end
-    transitions = Grid4Transitions([])
+    transitions = RailEnvTransitions()
+    cells = transitions.transition_list
     dead_end_from_south = cells[7]
     dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
     dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py
index 7acd58ed0337745f645db6dcc24a70ecb0b64305..98c276f894b51685ce0edf43f6bd1b1137d46eb0 100644
--- a/tests/test_flatland_envs_predictions.py
+++ b/tests/test_flatland_envs_predictions.py
@@ -10,13 +10,13 @@ from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
-from flatland.utils.simple_rail import make_simple_rail
+from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail
 
 """Test predictions for `flatland` package."""
 
 
 def test_dummy_predictor(rendering=False):
-    rail, rail_map = make_simple_rail()
+    rail, rail_map = make_simple_rail2()
 
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
@@ -89,7 +89,7 @@ def test_dummy_predictor(rendering=False):
     expected_actions = np.array([[0.],
                                  [2.],
                                  [2.],
-                                 [1.],
+                                 [2.],
                                  [2.],
                                  [2.],
                                  [2.],
@@ -226,7 +226,7 @@ def test_shortest_path_predictor(rendering=False):
 
 
 def test_shortest_path_predictor_conflicts(rendering=False):
-    rail, rail_map = make_simple_rail()
+    rail, rail_map = make_invalid_simple_rail()
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py
index 71dc87ceddde986be763491d28dd2b70673632f4..7ebbbb1461e24aae4c2319f51a9bb4abb2d3b25c 100644
--- a/tests/test_flatland_envs_rail_env.py
+++ b/tests/test_flatland_envs_rail_env.py
@@ -2,7 +2,6 @@
 # -*- coding: utf-8 -*-
 import numpy as np
 
-from flatland.core.grid.grid4 import Grid4Transitions
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.agent_utils import EnvAgent
@@ -49,15 +48,6 @@ def test_save_load():
 
 
 def test_rail_environment_single_agent():
-    cells = [int('0000000000000000', 2),  # empty cell - Case 0
-             int('1000000000100000', 2),  # Case 1 - straight
-             int('1001001000100000', 2),  # Case 2 - simple switch
-             int('1000010000100001', 2),  # Case 3 - diamond drossing
-             int('1001011000100001', 2),  # Case 4 - single slip switch
-             int('1100110000110011', 2),  # Case 5 - double slip switch
-             int('0101001000000010', 2),  # Case 6 - symmetrical switch
-             int('0010000000000000', 2)]  # Case 7 - dead end
-
     # We instantiate the following map on a 3x3 grid
     #  _  _
     # / \/ \
@@ -65,6 +55,7 @@ def test_rail_environment_single_agent():
     # \_/\_/
 
     transitions = RailEnvTransitions()
+    cells = transitions.transition_list
     vertical_line = cells[1]
     south_symmetrical_switch = cells[6]
     north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180)
@@ -139,7 +130,7 @@ test_rail_environment_single_agent()
 
 
 def test_dead_end():
-    transitions = Grid4Transitions([])
+    transitions = RailEnvTransitions()
 
     straight_vertical = int('1000000000100000', 2)  # Case 1 - straight
     straight_horizontal = transitions.rotate_transition(straight_vertical,