diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index cc85604e9b44a740929ba05536af866fb293e014..0ca62f9310c0c7a70946fdafb495e79c6e208a70 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -346,29 +346,28 @@ class RailEnv(Environment):
             # Check if agent breaks at this step
             malfunction = self._agent_malfunction(i_agent, action)
 
-            # if we're at the beginning of the cell, store the action
-            # As long as we're broken down at the beginning of the cell, we can choose other actions!
-            # This is a design choice made by Erik and Christian.
-
-            # TODO refactor!!!
-            # If the agent can make an action
+            # Is the agent at the beginning of the cell? Then, it can take an action
+            # Design choice (Erik+Christian):
+            #  as long as we're broken down at the beginning of the cell, we can choose other actions!
             if agent.speed_data['position_fraction'] == 0.0:
                 if action == RailEnvActions.DO_NOTHING and agent.moving:
                     # Keep moving
                     action = RailEnvActions.MOVE_FORWARD
 
-                if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data['position_fraction'] == 0.0:
+                if action == RailEnvActions.STOP_MOVING and agent.moving:
                     # Only allow halting an agent on entering new cells.
                     agent.moving = False
                     self.rewards_dict[i_agent] += self.stop_penalty
 
-                if not agent.moving and not (action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
+                if not agent.moving and not (
+                    action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
                     # Allow agent to start with any forward or direction action
                     agent.moving = True
                     self.rewards_dict[i_agent] += self.start_penalty
 
-                if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING:
-                    cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
+                # Store the action
+                if agent.moving and action not in [RailEnvActions.DO_NOTHING, RailEnvActions.STOP_MOVING]:
+                    _, new_cell_valid, new_direction, new_position, transition_valid = \
                         self._check_action_on_agent(action, agent)
 
                     if all([new_cell_valid, transition_valid]):
@@ -377,7 +376,7 @@ class RailEnv(Environment):
                         # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving,
                         # try to keep moving forward!
                         if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT):
-                            cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
+                            _, new_cell_valid, new_direction, new_position, transition_valid = \
                                 self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)
 
                             if all([new_cell_valid, transition_valid]):
@@ -388,7 +387,6 @@ class RailEnv(Environment):
                                 self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
                                 self.rewards_dict[i_agent] += self.stop_penalty
                                 agent.moving = False
-                                action = RailEnvActions.DO_NOTHING
 
                         else:
                             # If the agent cannot move due to an invalid transition, we set its state to not moving
@@ -396,10 +394,10 @@ class RailEnv(Environment):
                             self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
                             self.rewards_dict[i_agent] += self.stop_penalty
                             agent.moving = False
-                            action = RailEnvActions.DO_NOTHING
                 else:
                     agent.speed_data['transition_action_on_cellexit'] = action
 
+            # if we're broken, nothing else to do
             if malfunction:
                 continue
 
@@ -422,16 +420,10 @@ class RailEnv(Environment):
                     # Nothing left to do with broken agent
                     continue
 
-
             # Now perform a movement.
-            # If the agent is in an initial position within a new cell (agent.speed_data['position_fraction']<eps)
-            #   store the desired action in `transition_action_on_cellexit' (only if the desired transition is
-            #   allowed! otherwise DO_NOTHING!)
-            # Then in any case (if agent.moving) and the `transition_action_on_cellexit' is valid, increment the
-            #   position_fraction by the speed of the agent   (regardless of action taken, as long as no
-            #   STOP_MOVING, but that makes agent.moving=False)
+            # If agent.moving, increment the position_fraction by the speed of the agent
             # If the new position fraction is >= 1, reset to 0, and perform the stored
-            #   transition_action_on_cellexit
+            #   transition_action_on_cellexit if the cell is free.
 
             if agent.moving:
 
@@ -445,9 +437,11 @@ class RailEnv(Environment):
                                                                              RailEnvActions.STOP_MOVING]:
                         agent.speed_data['position_fraction'] = 0.0
                     else:
+                        # cell and transition validity was checked when we stored transition_action_on_cellexit!
                         cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent(
                             agent.speed_data['transition_action_on_cellexit'], agent)
-                        assert cell_free == all([cell_free, new_cell_valid, transition_valid])
+                        if not cell_free == all([cell_free, new_cell_valid, transition_valid]):
+                            warnings.warn("Inconsistent state: cell or transition not valid although checked when we stored transition_action_on_cellexit!")
                         if cell_free:
                             agent.position = new_position
                             agent.direction = new_direction
diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py
index 8de36c81e4a13c0b7e7e5e556ad79234503ad31a..b8b1afaf433d7155e48d17a2c277a579c72798ce 100644
--- a/tests/test_multi_speed.py
+++ b/tests/test_multi_speed.py
@@ -1,8 +1,17 @@
-import numpy as np
+from typing import List
 
-from flatland.envs.rail_env import RailEnv
-from flatland.envs.rail_generators import complex_rail_generator
-from flatland.envs.schedule_generators import complex_schedule_generator
+import numpy as np
+from attr import attrib, attrs
+
+from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.envs.agent_utils import EnvAgent, EnvAgentStatic
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+from flatland.envs.rail_generators import complex_rail_generator, rail_from_grid_transition_map
+from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator
+from flatland.utils.rendertools import RenderTool
+from flatland.utils.simple_rail import make_simple_rail
 
 np.random.seed(1)
 
@@ -86,3 +95,158 @@ def test_multi_speed_init():
             if (step + 1) % (i_agent + 1) == 0:
                 print(step, i_agent, env.agents[i_agent].position)
                 old_pos[i_agent] = env.agents[i_agent].position
+
+
+# TODO test malfunction
+# TODO test other agent blocking
+def test_multispeed_actions_no_malfunction(rendering=True):
+    """Replay a fixed action sequence for a single agent and check its trajectory.
+
+    Before every env step, the agent's position and direction are compared with the
+    current replay entry, and ``action_required`` from the previous step's info dict
+    must be True exactly when the replay entry supplies an action.
+
+    :param rendering: if True, render the environment after every step.
+    """
+    rail, rail_map = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
+                  number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  )
+
+    # initialize agents_static
+    env.reset()
+
+    @attrs
+    class Replay(object):
+        # state expected before the step, and the action to submit (None: submit no action)
+        position = attrib()
+        direction = attrib()
+        action = attrib(type=RailEnvActions)
+
+    @attrs
+    class TestConfig(object):
+        replay = attrib(type=List[Replay])
+        target = attrib()
+        speed = attrib(type=float)
+
+    def _assert(step, actual, expected, msg):
+        assert actual == expected, "[{}] {}:  actual={}, expected={}".format(step, msg, actual, expected)
+
+    # reset to set agents from agents_static
+    env.reset(False, False)
+
+    if rendering:
+        renderer = RenderTool(env, gl="PILSVG")
+
+    test_configs = [
+        TestConfig(
+            replay=[
+                Replay(
+                    position=(3, 9),  # east dead-end
+                    direction=Grid4TransitionsEnum.EAST,
+                    action=RailEnvActions.MOVE_FORWARD
+                ),
+                Replay(
+                    position=(3, 9),
+                    direction=Grid4TransitionsEnum.EAST,
+                    action=None
+                ),
+                Replay(
+                    position=(3, 8),
+                    direction=Grid4TransitionsEnum.WEST,
+                    action=RailEnvActions.MOVE_FORWARD
+                ),
+                Replay(
+                    position=(3, 8),
+                    direction=Grid4TransitionsEnum.WEST,
+                    action=None
+                ),
+                Replay(
+                    position=(3, 7),
+                    direction=Grid4TransitionsEnum.WEST,
+                    action=RailEnvActions.MOVE_FORWARD
+                ),
+                Replay(
+                    position=(3, 7),
+                    direction=Grid4TransitionsEnum.WEST,
+                    action=None
+                ),
+                Replay(
+                    position=(3, 6),
+                    direction=Grid4TransitionsEnum.WEST,
+                    action=RailEnvActions.MOVE_LEFT
+                ),
+                Replay(
+                    position=(3, 6),
+                    direction=Grid4TransitionsEnum.WEST,
+                    action=None
+                ),
+                Replay(
+                    position=(4, 6),
+                    direction=Grid4TransitionsEnum.SOUTH,
+                    action=RailEnvActions.STOP_MOVING
+                ),
+                Replay(
+                    position=(4, 6),
+                    direction=Grid4TransitionsEnum.SOUTH,
+                    action=RailEnvActions.STOP_MOVING
+                ),
+                Replay(
+                    position=(4, 6),
+                    direction=Grid4TransitionsEnum.SOUTH,
+                    action=RailEnvActions.MOVE_FORWARD
+                ),
+                Replay(
+                    position=(4, 6),
+                    direction=Grid4TransitionsEnum.SOUTH,
+                    action=None
+                ),
+                Replay(
+                    position=(5, 6),
+                    direction=Grid4TransitionsEnum.SOUTH,
+                    action=RailEnvActions.MOVE_FORWARD
+                ),
+            ],
+            target=(3, 0),  # west dead-end
+            speed=0.5
+        )
+    ]
+
+    # TODO test penalties!
+    agent_static: EnvAgentStatic = env.agents_static[0]
+    for test_config in test_configs:
+        # nothing stepped yet: seed info_dict so the first replay entry is expected to supply an action
+        info_dict = {
+            'action_required': [True]
+        }
+        for i, replay in enumerate(test_config.replay):
+            if i == 0:
+                # set the initial position
+                agent_static.position = replay.position
+                agent_static.direction = replay.direction
+                agent_static.target = test_config.target
+                agent_static.moving = True
+                agent_static.speed_data['speed'] = test_config.speed
+
+                # reset to set agents from agents_static
+                env.reset(False, False)
+
+            agent: EnvAgent = env.agents[0]
+
+            _assert(i, agent.position, replay.position, 'position')
+            _assert(i, agent.direction, replay.direction, 'direction')
+
+            # compare against None explicitly: an action whose value is falsy must still be submitted
+            action_expected = replay.action is not None
+            _assert(i, info_dict['action_required'][0], action_expected, 'action_required')
+            if action_expected:
+                _, _, _, info_dict = env.step({0: replay.action})
+            else:
+                _, _, _, info_dict = env.step({})
+
+            if rendering:
+                renderer.render_env(show=True, show_observations=True)