diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index cec0542c520277903d0b0f6b7746a365fc9862e7..0c1e3f844b8cf543f94d52a7b01b1e3af6964ce9 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -28,6 +28,7 @@ from flatland.envs.observations import GlobalObsForRailEnv
 
 from flatland.envs.timetable_generators import timetable_generator
 from flatland.envs.step_utils.states import TrainState, StateTransitionSignals
+from flatland.envs.step_utils.transition_utils import check_valid_action
 from flatland.envs.step_utils import action_preprocessing
 from flatland.envs.step_utils import env_utils
 
@@ -437,6 +438,11 @@ class RailEnv(Environment):
             current_position, current_direction = agent.initial_position, agent.initial_direction
         
         action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction)
+
+        # Check transitions, bounts for executing the action in the given position and directon
+        if not check_valid_action(action, self.rail, current_position, current_direction):
+            action = RailEnvActions.STOP_MOVING
+
         return action
     
     def clear_rewards_dict(self):
@@ -579,14 +585,15 @@ class RailEnv(Environment):
             movement_allowed = movement_allowed and agent.state != TrainState.DONE
 
             # Agent is being added to map
-            if agent.state.is_on_map_state() and agent.state_machine.previous_state.is_off_map_state():
-                agent.position = agent.initial_position
-                agent.direction = agent.initial_direction
+            if agent.state.is_on_map_state():
+                if agent.state_machine.previous_state.is_off_map_state():
+                    agent.position = agent.initial_position
+                    agent.direction = agent.initial_direction
             # Speed counter completes
-            elif movement_allowed and (agent.speed_counter.is_cell_exit):
-                agent.position = agent_transition_data.position
-                agent.direction = agent_transition_data.direction
-                agent.state_machine.update_if_reached(agent.position, agent.target)
+                elif movement_allowed and (agent.speed_counter.is_cell_exit):
+                    agent.position = agent_transition_data.position
+                    agent.direction = agent_transition_data.direction
+                    agent.state_machine.update_if_reached(agent.position, agent.target)
 
             # Off map or on map state and position should match
             env_utils.state_position_sync_check(agent.state, agent.position, agent.handle)
diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py
index 628298f0ac99341016037092d7ab797a89a2c14f..a23bf4c6df135dfbf82d4e999e39e1ab68884c90 100644
--- a/tests/test_flatland_envs_observations.py
+++ b/tests/test_flatland_envs_observations.py
@@ -86,7 +86,7 @@ def _step_along_shortest_path(env, obs_builder, rail):
                             expected_next_position[agent.handle] = neighbour
                             print("   improved (action) -> {}".format(actions[agent.handle]))
     _, rewards, dones, _ = env.step(actions)
-    return rewards
+    return rewards, dones
 
 
 def test_reward_function_conflict(rendering=False):
@@ -162,8 +162,9 @@ def test_reward_function_conflict(rendering=False):
         },
     }
     while iteration < 5:
-        rewards = _step_along_shortest_path(env, obs_builder, rail)
-
+        rewards, dones = _step_along_shortest_path(env, obs_builder, rail)
+        if dones["__all__"]:
+            break
         for agent in env.agents:
             # assert rewards[agent.handle] == 0
             expected_position = expected_positions[iteration + 1][agent.handle]
@@ -289,7 +290,9 @@ def test_reward_function_waiting(rendering=False):
     }
     while iteration < 7:
 
-        rewards = _step_along_shortest_path(env, obs_builder, rail)
+        rewards, dones = _step_along_shortest_path(env, obs_builder, rail)
+        if dones["__all__"]:
+            break
 
         if rendering:
             renderer.render_env(show=True, show_observations=True)