diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 1a064c992803a6e0a3fad509b0632187624d8f83..437978833627874f79e2f42107c4f91c065ff413 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -440,7 +440,7 @@ class RailEnv(Environment):
         action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction)
 
         # Check transitions, bounts for executing the action in the given position and directon
-        if not check_valid_action(action, self.rail, current_position, current_direction):
+        if action.is_moving_action() and not check_valid_action(action, self.rail, current_position, current_direction):
             action = RailEnvActions.STOP_MOVING
 
         return action
diff --git a/tests/test_flatland_states_edge_cases.py b/tests/test_flatland_states_edge_cases.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0d5383880dc1df3c2e1e02eca6c535aa3de0c00
--- /dev/null
+++ b/tests/test_flatland_states_edge_cases.py
@@ -0,0 +1,83 @@
+from flatland.core.env_observation_builder import ObservationBuilder
+from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+from flatland.envs.rail_generators import rail_from_grid_transition_map
+from flatland.envs.line_generators import sparse_line_generator
+from flatland.utils.simple_rail import make_simple_rail
+from flatland.envs.step_utils.states import TrainState
+
+def test_return_to_ready_to_depart():
+    """
+    When going from ready to depart to malfunction off map, if do nothing is provided, should return to ready to depart
+    """
+    stochastic_data = MalfunctionParameters(malfunction_rate=0,  # Rate of malfunction occurence
+                                        min_duration=0,  # Minimal duration of malfunction
+                                        max_duration=0  # Max duration of malfunction
+                                        )
+
+    rail, _, optionals = make_simple_rail()
+
+    env = RailEnv(width=25,
+                  height=30,
+                  rail_generator=rail_from_grid_transition_map(rail, optionals),
+                  line_generator=sparse_line_generator(seed=10),
+                  number_of_agents=1,
+                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
+                  )
+    
+    env.reset(False, False, random_seed=10)
+    env._max_episode_steps = 100
+
+    for _ in range(3):
+        env.step({0: RailEnvActions.DO_NOTHING})
+
+    env.agents[0].malfunction_handler._set_malfunction_down_counter(2)
+    env.step({0: RailEnvActions.DO_NOTHING})
+
+    assert env.agents[0].state == TrainState.MALFUNCTION_OFF_MAP
+
+    for _ in range(2):
+        env.step({0: RailEnvActions.DO_NOTHING})
+
+    
+    assert env.agents[0].state == TrainState.READY_TO_DEPART
+
+def test_ready_to_depart_to_stopped():
+    """
+    When going from ready to depart to malfunction off map, if stopped is provided, should go to stopped
+    """
+    stochastic_data = MalfunctionParameters(malfunction_rate=0,  # Rate of malfunction occurence
+                                        min_duration=0,  # Minimal duration of malfunction
+                                        max_duration=0  # Max duration of malfunction
+                                        )
+
+    rail, _, optionals = make_simple_rail()
+
+    env = RailEnv(width=25,
+                  height=30,
+                  rail_generator=rail_from_grid_transition_map(rail, optionals),
+                  line_generator=sparse_line_generator(seed=10),
+                  number_of_agents=1,
+                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
+                  )
+    
+    env.reset(False, False, random_seed=10)
+    env._max_episode_steps = 100
+
+    for _ in range(3):
+        env.step({0: RailEnvActions.STOP_MOVING})
+
+    assert env.agents[0].state == TrainState.READY_TO_DEPART
+
+    env.agents[0].malfunction_handler._set_malfunction_down_counter(2)
+    env.step({0: RailEnvActions.STOP_MOVING})
+
+    assert env.agents[0].state == TrainState.MALFUNCTION_OFF_MAP
+
+    for _ in range(2):
+        env.step({0: RailEnvActions.STOP_MOVING})
+
+    
+    assert env.agents[0].state == TrainState.STOPPED
\ No newline at end of file