diff --git a/flatland/core/grid/grid4.py b/flatland/core/grid/grid4.py
index 2d7284dc74cf2411ee2f0466d0d653bffb2132dc..819fc8da9102b688138cb06f59833e718188a411 100644
--- a/flatland/core/grid/grid4.py
+++ b/flatland/core/grid/grid4.py
@@ -1,5 +1,5 @@
 from enum import IntEnum
-from typing import Type
+from typing import Type, List
 
 import numpy as np
 
@@ -238,5 +238,6 @@ class Grid4Transitions(Transitions):
         cell_transition &= cell_transition & (~self.maskDeadEnds) & 0xffff
         return cell_transition
 
-    def get_entry_directions(self, cell_transition):
+    @staticmethod
+    def get_entry_directions(cell_transition) -> List[int]:
         return [(cell_transition >> ((3 - orientation) * 4)) & 15 > 0 for orientation in range(4)]
diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py
index 995e7d67ca87f306be7c87dbba685f46cde240d2..cf77cb580f2af7e4657140517e58825902be47e8 100644
--- a/flatland/core/transition_map.py
+++ b/flatland/core/transition_map.py
@@ -14,7 +14,7 @@ from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transitions import Transitions
 from flatland.utils.ordered_set import OrderedSet
 
-
+# TODO are these general classes or for grid4 only?
 class TransitionMap:
     """
     Base TransitionMap class.
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 901f318c0add611c4d0ebb0846c7a4c0c30f39dd..887f97aae7c74af8d986454f0c50dcaace59f0dc 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -12,7 +12,7 @@ import numpy as np
 
 from flatland.core.env import Environment
 from flatland.core.env_observation_builder import ObservationBuilder
-from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
@@ -593,6 +593,9 @@ class RailEnv(Environment):
         self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
         return self.obs_dict
 
+    def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]:
+        return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row,col))
+
     def get_full_state_msg(self):
         grid_data = self.rail.grid.tolist()
         agent_static_data = [agent.to_list() for agent in self.agents_static]
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index cc516de2174d262fdc5c5086b01921c183323031..0ec268dc1d3910967541ba0bcc357c8cc4638ec4 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -896,25 +896,20 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
         Function to fix all transition elements in environment
         """
         # Fix all nodes with illegal transition maps
-        empty_to_fix = []
-        rails_to_fix = []
+        rails_to_fix = np.zeros(2 * grid_map.height * grid_map.width * 2, dtype='int')
+        rails_to_fix_cnt = 0
         for cell in city_cells:
             check = grid_map.cell_neighbours_valid(cell, True)
             if grid_map.grid[cell] == int('1000010000100001', 2):
                 grid_map.fix_transitions(cell)
             if not check:
-                if grid_map.grid[cell] == 0:
-                    empty_to_fix.append(cell)
-                else:
-                    rails_to_fix.append(cell)
-
-        # Fix empty cells first to avoid cutting the network
-        for cell in empty_to_fix:
-            grid_map.fix_transitions(cell)
+                rails_to_fix[2 * rails_to_fix_cnt] = cell[0]
+                rails_to_fix[2 * rails_to_fix_cnt + 1] = cell[1]
+                rails_to_fix_cnt += 1
 
         # Fix all other cells
-        for cell in rails_to_fix:
-            grid_map.fix_transitions(cell)
+        for cell in range(rails_to_fix_cnt):
+            grid_map.fix_transitions((rails_to_fix[2 * cell], rails_to_fix[2 * cell + 1]))
 
     def _closest_neigh_in_direction(current_node, node_positions):
         """
diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py
index 930cc24fb4a9be817c14f2cca149747ac6ca370b..0913e45959d08230a815c33d98fb6de8eb99d956 100644
--- a/tests/test_flatland_core_transition_map.py
+++ b/tests/test_flatland_core_transition_map.py
@@ -174,7 +174,7 @@ def test_get_entry_directions():
     north_west_turn = transitions.rotate_transition(south_east_turn, 180)
 
     def _assert(transition, expected):
-        actual = transitions.get_entry_directions(transition)
+        actual = Grid4Transitions.get_entry_directions(transition)
         assert actual == expected, "Found {}, expected {}.".format(actual, expected)
 
     _assert(south_east_turn, [True, False, False, True])
diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py
index 0114730a2ac1d0df618eea773dfcf1cd7175dee2..0fefd3e212ddb5f084c1e219f4063079e03dabdf 100644
--- a/tests/test_flatland_envs_rail_env.py
+++ b/tests/test_flatland_envs_rail_env.py
@@ -6,11 +6,13 @@ from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.agent_utils import EnvAgentStatic
-from flatland.envs.observations import GlobalObsForRailEnv
+from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
 from flatland.envs.rail_generators import rail_from_grid_transition_map
 from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator
+from flatland.utils.simple_rail import make_simple_rail
 
 """Tests for `flatland` package."""
 
@@ -212,3 +214,36 @@ def test_dead_end():
 
     rail_env.reset()
     rail_env.agents = [EnvAgent(position=(2, 0), direction=0, target=(4, 0), moving=False)]
+
+
+def test_get_entry_directions():
+    rail, rail_map = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
+                  number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  )
+
+    def _assert(position, expected):
+        actual = env.get_valid_directions_on_grid(*position)
+        assert actual == expected, "[{},{}] actual={}, expected={}".format(*position, actual, expected)
+
+    # north dead end
+    _assert((0, 3), [True, False, False, False])
+
+    # west dead end
+    _assert((3, 0), [False, False, False, True])
+
+    # switch
+    _assert((3, 3), [False, True, True, True])
+
+    # horizontal
+    _assert((3, 2), [False, True, False, True])
+
+    # vertical
+    _assert((2, 3), [True, False, True, False])
+
+    # nowhere
+    _assert((0, 0), [False, False, False, False])