diff --git a/flatland/core/grid/grid4.py b/flatland/core/grid/grid4.py
index da721dd99fbc0ea46124dcd036425df3206aa339..2d7284dc74cf2411ee2f0466d0d653bffb2132dc 100644
--- a/flatland/core/grid/grid4.py
+++ b/flatland/core/grid/grid4.py
@@ -237,3 +237,6 @@ class Grid4Transitions(Transitions):
         """
         cell_transition &= cell_transition & (~self.maskDeadEnds) & 0xffff
         return cell_transition
+
+    def get_entry_directions(self, cell_transition):
+        return [(cell_transition >> ((3 - orientation) * 4)) & 15 > 0 for orientation in range(4)]
diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py
index 1137b8816973c12601029543c221810c9acd157c..4da1da4d23cc98ec7032530ee51820f29b637c17 100644
--- a/tests/test_flatland_core_transition_map.py
+++ b/tests/test_flatland_core_transition_map.py
@@ -1,5 +1,6 @@
 from flatland.core.grid.grid4 import Grid4Transitions, Grid4TransitionsEnum
 from flatland.core.grid.grid8 import Grid8Transitions, Grid8TransitionsEnum
+from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
@@ -158,3 +159,29 @@ def test_path_not_exists(rendering=False):
         renderer = RenderTool(env, gl="PILSVG")
         renderer.render_env(show=True, show_observations=False)
         input("Continue?")
+
+
+def test_get_entry_directions():
+    transitions = RailEnvTransitions()
+    cells = transitions.transition_list
+    vertical_line = cells[1]
+    south_symmetrical_switch = cells[6]
+    north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180)
+
+    # Simple turn not in the base transitions ?
+    south_east_turn = int('0100000000000010', 2)
+    south_west_turn = transitions.rotate_transition(south_east_turn, 90)
+    north_east_turn = transitions.rotate_transition(south_east_turn, 270)
+    north_west_turn = transitions.rotate_transition(south_east_turn, 180)
+
+    def _assert(transition, expected):
+        actual = transitions.get_entry_directions(transition)
+        assert actual == expected, "Found {}, expected {}.".format(actual, expected)
+
+    _assert(south_east_turn, [True, False, False, True])
+    _assert(south_west_turn, [True, True, False, False])
+    _assert(north_east_turn, [False, False, True, True])
+    _assert(north_west_turn, [False, True, True, False])
+    _assert(vertical_line, [True, False, True, False])
+    _assert(south_symmetrical_switch, [True, True, False, True])
+    _assert(north_symmetrical_switch, [False, True, True, True])