diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py
index 2aff675dfb7353aa253dad5811a572f259e4ef54..9550b71348c177adbaa30b4bd3e5307ba2e855ce 100644
--- a/flatland/core/transition_map.py
+++ b/flatland/core/transition_map.py
@@ -422,9 +422,9 @@ class GridTransitionMap(TransitionMap):
             # Check the adjacent cell is within bounds
             # if not, then this transition is invalid!
             if np.any(gPos2 < 0):
-                return False
+                continue
             if np.any(gPos2 >= grcMax):
-                return False
+                continue
 
             # Get the transitions out of gPos2, using iDirOut as the inbound direction
             # if there are no available transitions, ie (0,0,0,0), then rcPos is invalid
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 2281282977d8c9d972f13526efa9d96abaf84a52..deaabd02fc1f842fae8c36d9056db632791e67a8 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -11,6 +11,7 @@ import numpy as np
 
 from flatland.core.env import Environment
 from flatland.core.grid.grid4_utils import get_new_position
+from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
 from flatland.envs.generators import random_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
@@ -132,7 +133,7 @@ class RailEnv(Environment):
         """
 
         self.rail_generator = rail_generator
-        self.rail = None
+        self.rail: GridTransitionMap = None
         self.width = width
         self.height = height
 
@@ -222,6 +223,12 @@ class RailEnv(Environment):
         if regen_rail or self.rail is None:
             self.rail = tRailAgents[0]
             self.height, self.width = self.rail.grid.shape
+            for r in range(self.height):
+                for c in range(self.width):
+                    rcPos = (r, c)
+                    check = self.rail.cell_neighbours_valid(rcPos, True)
+                    if not check:
+                        print("WARNING: Invalid grid at {} -> {}".format(rcPos, check))
 
         if replace_agents:
             self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5])