diff --git a/flatland/core/grid/grid4.py b/flatland/core/grid/grid4.py
index 2d7284dc74cf2411ee2f0466d0d653bffb2132dc..819fc8da9102b688138cb06f59833e718188a411 100644
--- a/flatland/core/grid/grid4.py
+++ b/flatland/core/grid/grid4.py
@@ -1,5 +1,5 @@
 from enum import IntEnum
-from typing import Type
+from typing import Type, List
 
 import numpy as np
 
@@ -238,5 +238,6 @@ class Grid4Transitions(Transitions):
         cell_transition &= cell_transition & (~self.maskDeadEnds) & 0xffff
         return cell_transition
 
-    def get_entry_directions(self, cell_transition):
+    @staticmethod
+    def get_entry_directions(cell_transition) -> List[int]:
         return [(cell_transition >> ((3 - orientation) * 4)) & 15 > 0 for orientation in range(4)]
diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py
index 07678add5549c3ac13df876132ed3bcdbf5bec5e..9db7f3c7775a01824a849d8dc126fbbb3955d212 100644
--- a/flatland/core/transition_map.py
+++ b/flatland/core/transition_map.py
@@ -14,7 +14,7 @@ from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transitions import Transitions
 from flatland.utils.ordered_set import OrderedSet
 
-
+# TODO are these general classes or for grid4 only?
 class TransitionMap:
     """
     Base TransitionMap class.
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index b5bc44f2e698a4ffa23ca31d34ac14f613340d04..6e021183a77642c8547c5efc3a8c97764fa078d0 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -12,7 +12,7 @@ import numpy as np
 
 from flatland.core.env import Environment
 from flatland.core.env_observation_builder import ObservationBuilder
-from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
@@ -592,6 +592,9 @@ class RailEnv(Environment):
         self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
         return self.obs_dict
 
+    def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]:
+        return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row,col))
+
     def get_full_state_msg(self):
         grid_data = self.rail.grid.tolist()
         agent_static_data = [agent.to_list() for agent in self.agents_static]
diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py
index 930cc24fb4a9be817c14f2cca149747ac6ca370b..0913e45959d08230a815c33d98fb6de8eb99d956 100644
--- a/tests/test_flatland_core_transition_map.py
+++ b/tests/test_flatland_core_transition_map.py
@@ -174,7 +174,7 @@ def test_get_entry_directions():
     north_west_turn = transitions.rotate_transition(south_east_turn, 180)
 
     def _assert(transition, expected):
-        actual = transitions.get_entry_directions(transition)
+        actual = Grid4Transitions.get_entry_directions(transition)
         assert actual == expected, "Found {}, expected {}.".format(actual, expected)
 
     _assert(south_east_turn, [True, False, False, True])
diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py
index 0114730a2ac1d0df618eea773dfcf1cd7175dee2..0fefd3e212ddb5f084c1e219f4063079e03dabdf 100644
--- a/tests/test_flatland_envs_rail_env.py
+++ b/tests/test_flatland_envs_rail_env.py
@@ -6,11 +6,13 @@ from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.agent_utils import EnvAgentStatic
-from flatland.envs.observations import GlobalObsForRailEnv
+from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
 from flatland.envs.rail_generators import rail_from_grid_transition_map
 from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator
+from flatland.utils.simple_rail import make_simple_rail
 
 """Tests for `flatland` package."""
 
@@ -212,3 +214,36 @@ def test_dead_end():
 
     rail_env.reset()
     rail_env.agents = [EnvAgent(position=(2, 0), direction=0, target=(4, 0), moving=False)]
+
+
+def test_get_entry_directions():
+    rail, rail_map = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
+                  number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  )
+
+    def _assert(position, expected):
+        actual = env.get_valid_directions_on_grid(*position)
+        assert actual == expected, "[{},{}] actual={}, expected={}".format(*position, actual, expected)
+
+    # north dead end
+    _assert((0, 3), [True, False, False, False])
+
+    # west dead end
+    _assert((3, 0), [False, False, False, True])
+
+    # switch
+    _assert((3, 3), [False, True, True, True])
+
+    # horizontal
+    _assert((3, 2), [False, True, False, True])
+
+    # vertical
+    _assert((2, 3), [True, False, True, False])
+
+    # nowhere
+    _assert((0, 0), [False, False, False, False])