diff --git a/flatland/core/grid/grid4_astar.py b/flatland/core/grid/grid4_astar.py
index a049ae260860e946ac8b27f9e83ab11bf4ed2920..3a75aa81193d2355f71a05d8825bc64da4547f6f 100644
--- a/flatland/core/grid/grid4_astar.py
+++ b/flatland/core/grid/grid4_astar.py
@@ -46,8 +46,6 @@ def a_star(grid_map: GridTransitionMap,
     """
     rail_shape = grid_map.grid.shape
 
-    tmp = np.zeros(rail_shape) - 10
-
     start_node = AStarNode(start, None)
     end_node = AStarNode(end, None)
     open_nodes = OrderedSet()
@@ -114,8 +112,6 @@ def a_star(grid_map: GridTransitionMap,
             child.h = a_star_distance_function(child.pos, end_node.pos)
             child.f = child.g + child.h
 
-            tmp[child.pos[0]][child.pos[1]] = child.f
-
             # already in the open list?
             if child in open_nodes:
                 continue
diff --git a/flatland/core/grid/grid4_utils.py b/flatland/core/grid/grid4_utils.py
index 98652459d7a7ac7f1694ac53fe1d0a12880ab8b2..75cef7b4d3aea783140a5c08c3498a0bc321fb62 100644
--- a/flatland/core/grid/grid4_utils.py
+++ b/flatland/core/grid/grid4_utils.py
@@ -1,8 +1,8 @@
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.core.grid.grid_utils import IntVector2DArray
+from flatland.core.grid.grid_utils import IntVector2D
 
 
-def get_direction(pos1: IntVector2DArray, pos2: IntVector2DArray) -> Grid4TransitionsEnum:
+def get_direction(pos1: IntVector2D, pos2: IntVector2D) -> Grid4TransitionsEnum:
     """
     Assumes pos1 and pos2 are adjacent location on grid.
     Returns direction (int) that can be used with transitions.
@@ -10,13 +10,13 @@ def get_direction(pos1: IntVector2DArray, pos2: IntVector2DArray) -> Grid4Transi
     diff_0 = pos2[0] - pos1[0]
     diff_1 = pos2[1] - pos1[1]
     if diff_0 < 0:
-        return 0
+        return Grid4TransitionsEnum.NORTH
     if diff_0 > 0:
-        return 2
+        return Grid4TransitionsEnum.SOUTH
     if diff_1 > 0:
-        return 1
+        return Grid4TransitionsEnum.EAST
     if diff_1 < 0:
-        return 3
+        return Grid4TransitionsEnum.WEST
     raise Exception("Could not determine direction {}->{}".format(pos1, pos2))
 
 
diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py
index d6f47abfd85cfa1cc7e72e27aeb4f7ededa975dd..fce3ffdf320a3c38d7f0551151ffdc8debe6ab5d 100644
--- a/flatland/envs/grid4_generators_utils.py
+++ b/flatland/envs/grid4_generators_utils.py
@@ -7,22 +7,25 @@ a GridTransitionMap object.
 
 from flatland.core.grid.grid4_astar import a_star
 from flatland.core.grid.grid4_utils import get_direction, mirror
-from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance
+from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance, IntVector2DArray
 from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
 from flatland.core.transition_map import GridTransitionMap, RailEnvTransitions
 
 
-def connect_basic_operation(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
-                            start: IntVector2D,
-                            end: IntVector2D,
-                            flip_start_node_trans=False,
-                            flip_end_node_trans=False,
-                            a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance):
+def connect_basic_operation(
+    rail_trans: RailEnvTransitions,
+    grid_map: GridTransitionMap,
+    start: IntVector2D,
+    end: IntVector2D,
+    flip_start_node_trans=False,
+    flip_end_node_trans=False,
+    a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
     """
-    Creates a new path [start,end] in grid_map, based on rail_trans.
+    Creates a new path [start,end] in `grid_map.grid`, based on rail_trans, and
+    returns the path created as a list of positions.
     """
     # in the worst case we will need to do a A* search, so we might as well set that up
-    path = a_star(grid_map, start, end, a_star_distance_function)
+    path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function)
     if len(path) < 2:
         return []
     current_dir = get_direction(path[0], path[1])
@@ -71,23 +74,24 @@ def connect_basic_operation(rail_trans: RailEnvTransitions, grid_map: GridTransi
 
 def connect_rail(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
                  start: IntVector2D, end: IntVector2D,
-                 a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance):
+                 a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
     return connect_basic_operation(rail_trans, grid_map, start, end, True, True, a_star_distance_function)
 
 
 def connect_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
                   start: IntVector2D, end: IntVector2D,
-                  a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance):
+                  a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
     return connect_basic_operation(rail_trans, grid_map, start, end, False, False, a_star_distance_function)
 
 
 def connect_from_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
                        start: IntVector2D, end: IntVector2D,
-                       a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance):
+                       a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance
+                       ) -> IntVector2DArray:
     return connect_basic_operation(rail_trans, grid_map, start, end, False, True, a_star_distance_function)
 
 
 def connect_to_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
                      start: IntVector2D, end: IntVector2D,
-                     a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance):
+                     a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
     return connect_basic_operation(rail_trans, grid_map, start, end, True, False, a_star_distance_function)
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index c81ef9dc82df0817f3f3fc42798392d7ffdbcf5e..862774319ce58c2625b227b89b77940c12016e89 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -237,7 +237,8 @@ class RailEnv(Environment):
             Relies on the rail_generator returning agent_static lists (pos, dir, target)
         """
 
-        # TODO can we not put 'self.rail_generator(..)' into 'if regen_rail or self.rail is None' condition?
+        # TODO https://gitlab.aicrowd.com/flatland/flatland/issues/172
+        #  can we not put 'self.rail_generator(..)' into 'if regen_rail or self.rail is None' condition?
         rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
 
         if optionals and 'distance_map' in optionals:
@@ -257,6 +258,9 @@ class RailEnv(Environment):
             agents_hints = None
             if optionals and 'agents_hints' in optionals:
                 agents_hints = optionals['agents_hints']
+
+            # TODO https://gitlab.aicrowd.com/flatland/flatland/issues/185
+            #  why do we need static agents? could we do it more elegantly?
             self.agents_static = EnvAgentStatic.from_lists(
                 *self.schedule_generator(self.rail, self.get_num_agents(), agents_hints))
         self.restart_agents()
@@ -408,13 +412,14 @@ class RailEnv(Environment):
         # is the agent malfunctioning?
         malfunction = self._agent_malfunction(i_agent)
 
-        # if agent is broken, actions are ignored and agent does not move,
-        # the agent is not penalized in this step!
+        # if agent is broken, actions are ignored and agent does not move.
+        # full step penalty in this case
         if malfunction:
             self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
             return
 
         # Is the agent at the beginning of the cell? Then, it can take an action.
+        # As long as the agent is malfunctioning or stopped at the beginning of the cell, different actions may be taken!
         if agent.speed_data['position_fraction'] == 0.0:
             # No action has been supplied for this agent -> set DO_NOTHING as default
             if action is None:
@@ -463,9 +468,9 @@ class RailEnv(Environment):
                             _action_stored = True
 
                 if not _action_stored:
+
                     # If the agent cannot move due to an invalid transition, we set its state to not moving
                     self.rewards_dict[i_agent] += self.invalid_action_penalty
-                    self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
                     self.rewards_dict[i_agent] += self.stop_penalty
                     agent.moving = False
 
@@ -498,6 +503,9 @@ class RailEnv(Environment):
                 agent.moving = False
             else:
                 self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
+        else:
+            # step penalty if not moving (stopped now or before)
+            self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
 
     def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent):
         """
diff --git a/tests/test_flatland_envs_env_utils.py b/tests/test_flatland_envs_env_utils.py
index b95922cf67febdaa0aad396459bc446bc31adfea..cf5c8592708eef237bcf29308032df49753860bd 100644
--- a/tests/test_flatland_envs_env_utils.py
+++ b/tests/test_flatland_envs_env_utils.py
@@ -2,8 +2,8 @@ import numpy as np
 import pytest
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.core.grid.grid_utils import position_to_coordinate, coordinate_to_position
 from flatland.core.grid.grid4_utils import get_direction
+from flatland.core.grid.grid_utils import position_to_coordinate, coordinate_to_position
 
 depth_to_test = 5
 positions_to_test = [0, 5, 1, 6, 20, 30]
@@ -31,4 +31,4 @@ def test_get_direction():
     assert get_direction((1, 0), (0, 0)) == Grid4TransitionsEnum.NORTH
     assert get_direction((1, 0), (0, 0)) == Grid4TransitionsEnum.NORTH
     with pytest.raises(Exception, match="Could not determine direction"):
-        get_direction((0, 0), (0, 0)) == Grid4TransitionsEnum.NORTH
+        get_direction((0, 0), (0, 0))
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index 55d3526757123230fb351dbf67dbfc269e58b6ac..1b3c6adead4d0d82fd676efcc051fc66b4486ef8 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -1,16 +1,15 @@
 import random
+from typing import Dict
 
 import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator
 from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator
-from flatland.utils.rendertools import RenderTool
-from test_utils import Replay
+from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
 
 
 class SingleAgentNavigationObs(TreeObsForRailEnv):
@@ -54,7 +53,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
                     min_distances.append(np.inf)
 
             observation = [0, 0, 0]
-            observation[np.argmin(min_distances)] = 1
+            observation[np.argmin(min_distances)] = 1
 
         return observation
 
@@ -83,7 +82,6 @@ def test_malfunction_process():
 
     agent_halts = 0
     total_down_time = 0
-    agent_malfunctioning = False
     agent_old_position = env.agents[0].position
     for step in range(100):
         actions = {}
@@ -142,12 +140,12 @@ def test_malfunction_process_statistically():
     env.reset()
     nb_malfunction = 0
     for step in range(100):
-        action_dict = {}
+        action_dict: Dict[int, RailEnvActions] = {}
         for agent in env.agents:
             if agent.malfunction_data['malfunction'] > 0:
                 nb_malfunction += 1
             # We randomly select an action
-            action_dict[agent.handle] = np.random.randint(4)
+            action_dict[agent.handle] = RailEnvActions(np.random.randint(4))
 
         env.step(action_dict)
 
@@ -155,7 +153,7 @@ def test_malfunction_process_statistically():
     assert nb_malfunction == 156, "nb_malfunction={}".format(nb_malfunction)
 
 
-def test_initial_malfunction(rendering=True):
+def test_initial_malfunction():
     random.seed(0)
     np.random.seed(0)
 
@@ -189,75 +187,56 @@ def test_initial_malfunction(rendering=True):
                   number_of_agents=1,
                   stochastic_data=stochastic_data,  # Malfunction data generator
                   )
-
-    if rendering:
-        renderer = RenderTool(env)
-        renderer.render_env(show=True, frames=False, show_observations=False)
-    _action = dict()
-
-    replay_steps = [
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.MOVE_FORWARD,
-            malfunction=3
-        ),
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.MOVE_FORWARD,
-            malfunction=2
-        ),
-        # malfunction stops in the next step and we're still at the beginning of the cell
-        # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.MOVE_FORWARD,
-            malfunction=1
-        ),
-        Replay(
-            position=(28, 4),
-            direction=Grid4TransitionsEnum.WEST,
-            action=RailEnvActions.MOVE_FORWARD,
-            malfunction=0
-        ),
-        Replay(
-            position=(27, 4),
-            direction=Grid4TransitionsEnum.NORTH,
-            action=RailEnvActions.MOVE_FORWARD,
-            malfunction=0
-        )
-    ]
-
-    info_dict = {
-        'action_required': [True]
-    }
-
-    for i, replay in enumerate(replay_steps):
-
-        def _assert(actual, expected, msg):
-            assert actual == expected, "[{}] {}:  actual={}, expected={}".format(i, msg, actual, expected)
-
-        agent: EnvAgent = env.agents[0]
-
-        _assert(agent.position, replay.position, 'position')
-        _assert(agent.direction, replay.direction, 'direction')
-        _assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
-
-        if replay.action is not None:
-            assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
-            _, _, _, info_dict = env.step({0: replay.action})
-
-        else:
-            assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
-            _, _, _, info_dict = env.step({})
-
-        if rendering:
-            renderer.render_env(show=True, show_observations=True)
-
-
-def test_initial_malfunction_stop_moving(rendering=True):
+    set_penalties_for_replay(env)
+    replay_config = ReplayConfig(
+        replay=[
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.MOVE_FORWARD,
+                set_malfunction=3,
+                malfunction=3,
+                reward=env.step_penalty  # full step penalty when malfunctioning
+            ),
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.MOVE_FORWARD,
+                malfunction=2,
+                reward=env.step_penalty  # full step penalty when malfunctioning
+            ),
+            # malfunction stops in the next step and we're still at the beginning of the cell
+            # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.MOVE_FORWARD,
+                malfunction=1,
+                reward=env.start_penalty + env.step_penalty * 1.0
+                # malfunctioning ends: starting and running at speed 1.0
+            ),
+            Replay(
+                position=(28, 4),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_FORWARD,
+                malfunction=0,
+                reward=env.step_penalty * 1.0  # running at speed 1.0
+            ),
+            Replay(
+                position=(27, 4),
+                direction=Grid4TransitionsEnum.NORTH,
+                action=RailEnvActions.MOVE_FORWARD,
+                malfunction=0,
+                reward=env.step_penalty * 1.0  # running at speed 1.0
+            )
+        ],
+        speed=env.agents[0].speed_data['speed'],
+        target=env.agents[0].target
+    )
+    run_replay_config(env, [replay_config])
+
+
+def test_initial_malfunction_stop_moving():
     random.seed(0)
     np.random.seed(0)
 
@@ -291,84 +270,66 @@ def test_initial_malfunction_stop_moving(rendering=True):
                   number_of_agents=1,
                   stochastic_data=stochastic_data,  # Malfunction data generator
                   )
-
-    if rendering:
-        renderer = RenderTool(env)
-        renderer.render_env(show=True, frames=False, show_observations=False)
-    _action = dict()
-
-    replay_steps = [
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.DO_NOTHING,
-            malfunction=3
-        ),
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.DO_NOTHING,
-            malfunction=2
-        ),
-        # malfunction stops in the next step and we're still at the beginning of the cell
-        # --> if we take action DO_NOTHING, agent should restart without moving
-        #
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.STOP_MOVING,
-            malfunction=1
-        ),
-        # we have stopped and do nothing --> should stand still
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.DO_NOTHING,
-            malfunction=0
-        ),
-        # we start to move forward --> should go to next cell now
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.MOVE_FORWARD,
-            malfunction=0
-        ),
-        Replay(
-            position=(28, 4),
-            direction=Grid4TransitionsEnum.WEST,
-            action=RailEnvActions.MOVE_FORWARD,
-            malfunction=0
-        )
-    ]
-
-    info_dict = {
-        'action_required': [True]
-    }
-
-    for i, replay in enumerate(replay_steps):
-
-        def _assert(actual, expected, msg):
-            assert actual == expected, "[{}] {}:  actual={}, expected={}".format(i, msg, actual, expected)
-
-        agent: EnvAgent = env.agents[0]
-
-        _assert(agent.position, replay.position, 'position')
-        _assert(agent.direction, replay.direction, 'direction')
-        _assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
-
-        if replay.action is not None:
-            assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
-            _, _, _, info_dict = env.step({0: replay.action})
-
-        else:
-            assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
-            _, _, _, info_dict = env.step({})
-
-        if rendering:
-            renderer.render_env(show=True, show_observations=True)
-
-
-def test_initial_malfunction_do_nothing(rendering=True):
+    set_penalties_for_replay(env)
+    replay_config = ReplayConfig(
+        replay=[
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.DO_NOTHING,
+                set_malfunction=3,
+                malfunction=3,
+                reward=env.step_penalty  # full step penalty while malfunctioning
+            ),
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.DO_NOTHING,
+                malfunction=2,
+                reward=env.step_penalty  # full step penalty while malfunctioning
+            ),
+            # malfunction stops in the next step and we're still at the beginning of the cell
+            # --> if we take action STOP_MOVING, agent should restart without moving
+            #
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.STOP_MOVING,
+                malfunction=1,
+                reward=env.step_penalty  # full step penalty while stopped
+            ),
+            # we have stopped and do nothing --> should stand still
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.DO_NOTHING,
+                malfunction=0,
+                reward=env.step_penalty  # full step penalty while stopped
+            ),
+            # we start to move forward --> should go to next cell now
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.MOVE_FORWARD,
+                malfunction=0,
+                reward=env.start_penalty + env.step_penalty * 1.0  # start penalty + step penalty for speed 1.0
+            ),
+            Replay(
+                position=(28, 4),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_FORWARD,
+                malfunction=0,
+                reward=env.step_penalty * 1.0  # step penalty for speed 1.0
+            )
+        ],
+        speed=env.agents[0].speed_data['speed'],
+        target=env.agents[0].target
+    )
+
+    run_replay_config(env, [replay_config])
+
+
+def test_initial_malfunction_do_nothing():
     random.seed(0)
     np.random.seed(0)
 
@@ -402,78 +363,59 @@ def test_initial_malfunction_do_nothing(rendering=True):
                   number_of_agents=1,
                   stochastic_data=stochastic_data,  # Malfunction data generator
                   )
-
-    if rendering:
-        renderer = RenderTool(env)
-        renderer.render_env(show=True, frames=False, show_observations=False)
-    _action = dict()
-
-    replay_steps = [
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.DO_NOTHING,
-            malfunction=3
-        ),
-        Replay(
+    set_penalties_for_replay(env)
+    replay_config = ReplayConfig(
+        replay=[Replay(
             position=(28, 5),
             direction=Grid4TransitionsEnum.EAST,
             action=RailEnvActions.DO_NOTHING,
-            malfunction=2
+            set_malfunction=3,
+            malfunction=3,
+            reward=env.step_penalty  # full step penalty while malfunctioning
         ),
-        # malfunction stops in the next step and we're still at the beginning of the cell
-        # --> if we take action DO_NOTHING, agent should restart without moving
-        #
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.DO_NOTHING,
-            malfunction=1
-        ),
-        # we haven't started moving yet --> stay here
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.DO_NOTHING,
-            malfunction=0
-        ),
-        # we start to move forward --> should go to next cell now
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.MOVE_FORWARD,
-            malfunction=0
-        ),
-        Replay(
-            position=(28, 4),
-            direction=Grid4TransitionsEnum.WEST,
-            action=RailEnvActions.MOVE_FORWARD,
-            malfunction=0
-        )
-    ]
-
-    info_dict = {
-        'action_required': [True]
-    }
-
-    for i, replay in enumerate(replay_steps):
-
-        def _assert(actual, expected, msg):
-            assert actual == expected, "[{}] {}:  actual={}, expected={}".format(i, msg, actual, expected)
-
-        agent: EnvAgent = env.agents[0]
-
-        _assert(agent.position, replay.position, 'position')
-        _assert(agent.direction, replay.direction, 'direction')
-        _assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
-
-        if replay.action is not None:
-            assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
-            _, _, _, info_dict = env.step({0: replay.action})
-
-        else:
-            assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
-            _, _, _, info_dict = env.step({})
-
-        if rendering:
-            renderer.render_env(show=True, show_observations=True)
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.DO_NOTHING,
+                malfunction=2,
+                reward=env.step_penalty  # full step penalty while malfunctioning
+            ),
+            # malfunction stops in the next step and we're still at the beginning of the cell
+            # --> if we take action DO_NOTHING, agent should restart without moving
+            #
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.DO_NOTHING,
+                malfunction=1,
+                reward=env.step_penalty  # full step penalty while stopped
+            ),
+            # we haven't started moving yet --> stay here
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.DO_NOTHING,
+                malfunction=0,
+                reward=env.step_penalty  # full step penalty while stopped
+            ),
+            # we start to move forward --> should go to next cell now
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.MOVE_FORWARD,
+                malfunction=0,
+                reward=env.start_penalty + env.step_penalty * 1.0  # start penalty + step penalty for speed 1.0
+            ),
+            Replay(
+                position=(28, 4),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_FORWARD,
+                malfunction=0,
+                reward=env.step_penalty * 1.0  # step penalty for speed 1.0
+            )
+        ],
+        speed=env.agents[0].speed_data['speed'],
+        target=env.agents[0].target
+    )
+
+    run_replay_config(env, [replay_config])
diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py
index 1cf0c325ac48e9e3d5ac04fb51b5f8462c867726..b0f274ba4c4b5453140fcc50bc6137e39e8e4f04 100644
--- a/tests/test_multi_speed.py
+++ b/tests/test_multi_speed.py
@@ -1,15 +1,13 @@
 import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.envs.agent_utils import EnvAgent, EnvAgentStatic
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import complex_rail_generator, rail_from_grid_transition_map
 from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator
-from flatland.utils.rendertools import RenderTool
 from flatland.utils.simple_rail import make_simple_rail
-from test_utils import ReplayConfig, Replay
+from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay
 
 np.random.seed(1)
 
@@ -95,9 +93,7 @@ def test_multi_speed_init():
                 old_pos[i_agent] = env.agents[i_agent].position
 
 
-# TODO test penalties!
-# TODO test invalid actions!
-def test_multispeed_actions_no_malfunction_no_blocking(rendering=True):
+def test_multispeed_actions_no_malfunction_no_blocking():
     """Test that actions are correctly performed on cell exit for a single agent."""
     rail, rail_map = make_simple_rail()
     env = RailEnv(width=rail_map.shape[1],
@@ -108,126 +104,97 @@ def test_multispeed_actions_no_malfunction_no_blocking(rendering=True):
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
 
-    # initialize agents_static
-    env.reset()
-
-    # reset to set agents from agents_static
-    env.reset(False, False)
-
-    if rendering:
-        renderer = RenderTool(env, gl="PILSVG")
-
+    set_penalties_for_replay(env)
     test_config = ReplayConfig(
         replay=[
             Replay(
                 position=(3, 9),  # east dead-end
                 direction=Grid4TransitionsEnum.EAST,
-                action=RailEnvActions.MOVE_FORWARD
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.start_penalty + env.step_penalty * 0.5  # starting and running at speed 0.5
             ),
             Replay(
                 position=(3, 9),
                 direction=Grid4TransitionsEnum.EAST,
-                action=None
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(3, 8),
                 direction=Grid4TransitionsEnum.WEST,
-                action=RailEnvActions.MOVE_FORWARD
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(3, 8),
                 direction=Grid4TransitionsEnum.WEST,
-                action=None
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(3, 7),
                 direction=Grid4TransitionsEnum.WEST,
-                action=RailEnvActions.MOVE_FORWARD
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(3, 7),
                 direction=Grid4TransitionsEnum.WEST,
-                action=None
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(3, 6),
                 direction=Grid4TransitionsEnum.WEST,
-                action=RailEnvActions.MOVE_LEFT
+                action=RailEnvActions.MOVE_LEFT,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(3, 6),
                 direction=Grid4TransitionsEnum.WEST,
-                action=None
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(4, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=RailEnvActions.STOP_MOVING
+                action=RailEnvActions.STOP_MOVING,
+                reward=env.stop_penalty + env.step_penalty * 0.5  # stopping and step penalty
             ),
             #
             Replay(
                 position=(4, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=RailEnvActions.STOP_MOVING
+                action=RailEnvActions.STOP_MOVING,
+                reward=env.step_penalty * 0.5  # step penalty for speed 0.5 when stopped
             ),
             Replay(
                 position=(4, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=RailEnvActions.MOVE_FORWARD
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.start_penalty + env.step_penalty * 0.5  # starting + running at speed 0.5
             ),
             Replay(
                 position=(4, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=None
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(5, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=RailEnvActions.MOVE_FORWARD
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
-
         ],
         target=(3, 0),  # west dead-end
         speed=0.5
     )
 
-    agentStatic: EnvAgentStatic = env.agents_static[0]
-    info_dict = {
-        'action_required': [True]
-    }
-    for i, replay in enumerate(test_config.replay):
-        if i == 0:
-            # set the initial position
-            agentStatic.position = replay.position
-            agentStatic.direction = replay.direction
-            agentStatic.target = test_config.target
-            agentStatic.moving = True
-            agentStatic.speed_data['speed'] = test_config.speed
-
-            # reset to set agents from agents_static
-            env.reset(False, False)
-
-        def _assert(actual, expected, msg):
-            assert actual == expected, "[{}] {}:  actual={}, expected={}".format(i, msg, actual, expected)
-
-        agent: EnvAgent = env.agents[0]
-
-        _assert(agent.position, replay.position, 'position')
-        _assert(agent.direction, replay.direction, 'direction')
-
-        if replay.action is not None:
-            assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
-            _, _, _, info_dict = env.step({0: replay.action})
+    run_replay_config(env, [test_config])
 
-        else:
-            assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
-            _, _, _, info_dict = env.step({})
 
-        if rendering:
-            renderer.render_env(show=True, show_observations=True)
-
-
-def test_multispeed_actions_no_malfunction_blocking(rendering=True):
+def test_multispeed_actions_no_malfunction_blocking():
     """The second agent blocks the first because it is slower."""
     rail, rail_map = make_simple_rail()
     env = RailEnv(width=rail_map.shape[1],
@@ -237,81 +204,84 @@ def test_multispeed_actions_no_malfunction_blocking(rendering=True):
                   number_of_agents=2,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
-
-    # initialize agents_static
-    env.reset()
-
-    # reset to set agents from agents_static
-    env.reset(False, False)
-
-    if rendering:
-        renderer = RenderTool(env, gl="PILSVG")
-
+    set_penalties_for_replay(env)
     test_configs = [
         ReplayConfig(
             replay=[
                 Replay(
                     position=(3, 8),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=RailEnvActions.MOVE_FORWARD
+                    action=RailEnvActions.MOVE_FORWARD,
+                    reward=env.start_penalty + env.step_penalty * 1.0 / 3.0  # starting and running at speed 1/3
                 ),
                 Replay(
                     position=(3, 8),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                 ),
                 Replay(
                     position=(3, 8),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                 ),
 
                 Replay(
                     position=(3, 7),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=RailEnvActions.MOVE_FORWARD
+                    action=RailEnvActions.MOVE_FORWARD,
+                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                 ),
                 Replay(
                     position=(3, 7),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                 ),
                 Replay(
                     position=(3, 7),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                 ),
 
                 Replay(
                     position=(3, 6),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=RailEnvActions.MOVE_FORWARD
+                    action=RailEnvActions.MOVE_FORWARD,
+                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                 ),
                 Replay(
                     position=(3, 6),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                 ),
                 Replay(
                     position=(3, 6),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                 ),
 
                 Replay(
                     position=(3, 5),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=RailEnvActions.MOVE_FORWARD
+                    action=RailEnvActions.MOVE_FORWARD,
+                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                 ),
                 Replay(
                     position=(3, 5),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                 ),
                 Replay(
                     position=(3, 5),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                 )
             ],
             target=(3, 0),  # west dead-end
@@ -321,69 +291,81 @@ def test_multispeed_actions_no_malfunction_blocking(rendering=True):
                 Replay(
                     position=(3, 9),  # east dead-end
                     direction=Grid4TransitionsEnum.EAST,
-                    action=RailEnvActions.MOVE_FORWARD
+                    action=RailEnvActions.MOVE_FORWARD,
+                    reward=env.start_penalty + env.step_penalty * 0.5  # starting and running at speed 0.5
                 ),
                 Replay(
                     position=(3, 9),
                     direction=Grid4TransitionsEnum.EAST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 0.5  # running at speed 0.5
                 ),
                 # blocked although fraction >= 1.0
                 Replay(
                     position=(3, 9),
                     direction=Grid4TransitionsEnum.EAST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 0.5  # running at speed 0.5
                 ),
 
                 Replay(
                     position=(3, 8),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=RailEnvActions.MOVE_FORWARD
+                    action=RailEnvActions.MOVE_FORWARD,
+                    reward=env.step_penalty * 0.5  # running at speed 0.5
                 ),
                 Replay(
                     position=(3, 8),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 0.5  # running at speed 0.5
                 ),
                 # blocked although fraction >= 1.0
                 Replay(
                     position=(3, 8),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 0.5  # running at speed 0.5
                 ),
 
                 Replay(
                     position=(3, 7),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=RailEnvActions.MOVE_FORWARD
+                    action=RailEnvActions.MOVE_FORWARD,
+                    reward=env.step_penalty * 0.5  # running at speed 0.5
                 ),
                 Replay(
                     position=(3, 7),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 0.5  # running at speed 0.5
                 ),
                 # blocked although fraction >= 1.0
                 Replay(
                     position=(3, 7),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 0.5  # running at speed 0.5
                 ),
 
                 Replay(
                     position=(3, 6),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=RailEnvActions.MOVE_LEFT
+                    action=RailEnvActions.MOVE_LEFT,
+                    reward=env.step_penalty * 0.5  # running at speed 0.5
                 ),
                 Replay(
                     position=(3, 6),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 0.5  # running at speed 0.5
                 ),
                 # not blocked, action required!
                 Replay(
                     position=(4, 6),
                     direction=Grid4TransitionsEnum.SOUTH,
-                    action=RailEnvActions.MOVE_FORWARD
+                    action=RailEnvActions.MOVE_FORWARD,
+                    reward=env.step_penalty * 0.5  # running at speed 0.5
                 ),
             ],
             target=(3, 0),  # west dead-end
@@ -391,52 +373,10 @@ def test_multispeed_actions_no_malfunction_blocking(rendering=True):
         )
 
     ]
+    run_replay_config(env, test_configs)
+
 
-    # TODO test penalties!
-    info_dict = {
-        'action_required': [True for _ in test_configs]
-    }
-    for step in range(len(test_configs[0].replay)):
-        if step == 0:
-            for a, test_config in enumerate(test_configs):
-                agentStatic: EnvAgentStatic = env.agents_static[a]
-                replay = test_config.replay[0]
-                # set the initial position
-                agentStatic.position = replay.position
-                agentStatic.direction = replay.direction
-                agentStatic.target = test_config.target
-                agentStatic.moving = True
-                agentStatic.speed_data['speed'] = test_config.speed
-
-            # reset to set agents from agents_static
-            env.reset(False, False)
-
-        def _assert(a, actual, expected, msg):
-            assert actual == expected, "[{}] {} {}:  actual={}, expected={}".format(step, a, msg, actual, expected)
-
-        action_dict = {}
-
-        for a, test_config in enumerate(test_configs):
-            agent: EnvAgent = env.agents[a]
-            replay = test_config.replay[step]
-
-            _assert(a, agent.position, replay.position, 'position')
-            _assert(a, agent.direction, replay.direction, 'direction')
-
-            if replay.action is not None:
-                assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format(
-                    step, a, True)
-                action_dict[a] = replay.action
-            else:
-                assert info_dict['action_required'][a] == False, "[{}] agent {} expecting action_required={}".format(
-                    step, a, False)
-        _, _, _, info_dict = env.step(action_dict)
-
-        if rendering:
-            renderer.render_env(show=True, show_observations=True)
-
-
-def test_multispeed_actions_malfunction_no_blocking(rendering=True):
+def test_multispeed_actions_malfunction_no_blocking():
     """Test on a single agent whether action on cell exit work correctly despite malfunction."""
     rail, rail_map = make_simple_rail()
     env = RailEnv(width=rail_map.shape[1],
@@ -447,107 +387,202 @@ def test_multispeed_actions_malfunction_no_blocking(rendering=True):
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
 
-    # initialize agents_static
-    env.reset()
-
-    # reset to set agents from agents_static
-    env.reset(False, False)
-
-    if rendering:
-        renderer = RenderTool(env, gl="PILSVG")
-
+    set_penalties_for_replay(env)
     test_config = ReplayConfig(
         replay=[
             Replay(
                 position=(3, 9),  # east dead-end
                 direction=Grid4TransitionsEnum.EAST,
-                action=RailEnvActions.MOVE_FORWARD
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.start_penalty + env.step_penalty * 0.5  # starting and running at speed 0.5
             ),
             Replay(
                 position=(3, 9),
                 direction=Grid4TransitionsEnum.EAST,
-                action=None
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(3, 8),
                 direction=Grid4TransitionsEnum.WEST,
-                action=RailEnvActions.MOVE_FORWARD
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             # add additional step in the cell
             Replay(
                 position=(3, 8),
                 direction=Grid4TransitionsEnum.WEST,
                 action=None,
-                malfunction=2  # recovers in two steps from now!
+                set_malfunction=2,  # recovers in two steps from now!
+                malfunction=2,
+                reward=env.step_penalty * 0.5  # step penalty for speed 0.5 when malfunctioning
             ),
             # agent recovers in this step
             Replay(
                 position=(3, 8),
                 direction=Grid4TransitionsEnum.WEST,
-                action=None
+                action=None,
+                malfunction=1,
+                reward=env.step_penalty * 0.5  # recovered: running at speed 0.5
             ),
             Replay(
                 position=(3, 7),
                 direction=Grid4TransitionsEnum.WEST,
-                action=RailEnvActions.MOVE_FORWARD
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(3, 7),
                 direction=Grid4TransitionsEnum.WEST,
-                action=None
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(3, 6),
                 direction=Grid4TransitionsEnum.WEST,
                 action=RailEnvActions.MOVE_FORWARD,
-                malfunction=2  # recovers in two steps from now!
+                set_malfunction=2,  # recovers in two steps from now!
+                malfunction=2,
+                reward=env.step_penalty * 0.5  # step penalty for speed 0.5 when malfunctioning
             ),
             # agent recovers in this step; since we're at the beginning, we provide a different action although we're broken!
             Replay(
                 position=(3, 6),
                 direction=Grid4TransitionsEnum.WEST,
                 action=RailEnvActions.MOVE_LEFT,
+                malfunction=1,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(3, 6),
                 direction=Grid4TransitionsEnum.WEST,
-                action=None
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(4, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=RailEnvActions.STOP_MOVING
+                action=RailEnvActions.STOP_MOVING,
+                reward=env.stop_penalty + env.step_penalty * 0.5  # stopping and step penalty for speed 0.5
             ),
             Replay(
                 position=(4, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=RailEnvActions.STOP_MOVING
+                action=RailEnvActions.STOP_MOVING,
+                reward=env.step_penalty * 0.5  # step penalty for speed 0.5 while stopped
             ),
             Replay(
                 position=(4, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=RailEnvActions.MOVE_FORWARD
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.start_penalty + env.step_penalty * 0.5  # starting and running at speed 0.5
             ),
             Replay(
                 position=(4, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=None
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             # DO_NOTHING keeps moving!
             Replay(
                 position=(5, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=RailEnvActions.DO_NOTHING
+                action=RailEnvActions.DO_NOTHING,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(5, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=None
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(6, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=RailEnvActions.MOVE_FORWARD
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
+            ),
+
+        ],
+        target=(3, 0),  # west dead-end
+        speed=0.5
+    )
+    run_replay_config(env, [test_config])
+
+
+# TODO invalid action penalty seems only given when forward is not possible - is this the intended behaviour?
+def test_multispeed_actions_no_malfunction_invalid_actions():
+    """Test that actions are correctly performed on cell exit for a single agent."""
+    rail, rail_map = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
+                  number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  )
+
+    set_penalties_for_replay(env)
+    test_config = ReplayConfig(
+        replay=[
+            Replay(
+                position=(3, 9),  # east dead-end
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.MOVE_LEFT,
+                reward=env.start_penalty + env.step_penalty * 0.5  # auto-correction left to forward without penalty!
+            ),
+            Replay(
+                position=(3, 9),
+                direction=Grid4TransitionsEnum.EAST,
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
+            ),
+            Replay(
+                position=(3, 8),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
+            ),
+            Replay(
+                position=(3, 8),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
+            ),
+            Replay(
+                position=(3, 7),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
+            ),
+            Replay(
+                position=(3, 7),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
+            ),
+            Replay(
+                position=(3, 6),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_RIGHT,
+                reward=env.step_penalty * 0.5  # wrong action is corrected to forward without penalty!
+            ),
+            Replay(
+                position=(3, 6),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
+            ),
+            Replay(
+                position=(3, 5),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_RIGHT,
+                reward=env.step_penalty * 0.5  # wrong action is corrected to forward without penalty!
+            ), Replay(
+                position=(3, 5),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
 
         ],
@@ -555,42 +590,4 @@ def test_multispeed_actions_malfunction_no_blocking(rendering=True):
         speed=0.5
     )
 
-    # TODO test penalties!
-    agentStatic: EnvAgentStatic = env.agents_static[0]
-    info_dict = {
-        'action_required': [True]
-    }
-    for i, replay in enumerate(test_config.replay):
-        if i == 0:
-            # set the initial position
-            agentStatic.position = replay.position
-            agentStatic.direction = replay.direction
-            agentStatic.target = test_config.target
-            agentStatic.moving = True
-            agentStatic.speed_data['speed'] = test_config.speed
-
-            # reset to set agents from agents_static
-            env.reset(False, False)
-
-        def _assert(actual, expected, msg):
-            assert actual == expected, "[{}] {}:  actual={}, expected={}".format(i, msg, actual, expected)
-
-        agent: EnvAgent = env.agents[0]
-
-        _assert(agent.position, replay.position, 'position')
-        _assert(agent.direction, replay.direction, 'direction')
-
-        if replay.malfunction > 0:
-            agent.malfunction_data['malfunction'] = replay.malfunction
-            agent.malfunction_data['moving_before_malfunction'] = agent.moving
-
-        if replay.action is not None:
-            assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
-            _, _, _, info_dict = env.step({0: replay.action})
-
-        else:
-            assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
-            _, _, _, info_dict = env.step({})
-
-        if rendering:
-            renderer.render_env(show=True, show_observations=True)
+    run_replay_config(env, [test_config])
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 6347bd0f5048350c099ba2568dac7caba74baf2d..903120d868aa65833e7c2393ddfcc821c26da4f6 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,10 +1,13 @@
 """Test Utils."""
-from typing import List, Tuple
+from typing import List, Tuple, Optional
 
+import numpy as np
 from attr import attrs, attrib
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.envs.rail_env import RailEnvActions
+from flatland.envs.agent_utils import EnvAgent
+from flatland.envs.rail_env import RailEnvActions, RailEnv
+from flatland.utils.rendertools import RenderTool
 
 
 @attrs
@@ -13,6 +16,8 @@ class Replay(object):
     direction = attrib(type=Grid4TransitionsEnum)
     action = attrib(type=RailEnvActions)
     malfunction = attrib(default=0, type=int)
+    set_malfunction = attrib(default=None, type=Optional[int])
+    reward = attrib(default=None, type=Optional[float])
 
 
 @attrs
@@ -20,3 +25,89 @@ class ReplayConfig(object):
     replay = attrib(type=List[Replay])
     target = attrib(type=Tuple[int, int])
     speed = attrib(type=float)
+
+
+# ensure that env is working correctly with start/stop/invalid-action penalties different from 0
+def set_penalties_for_replay(env: RailEnv):
+    env.step_penalty = -7
+    env.start_penalty = -13
+    env.stop_penalty = -19
+    env.invalid_action_penalty = -29
+
+
+def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: bool = False):
+    """
+    Runs the replay configs and checks assertions.
+
+    *Initially*
+    - the position, direction, target and speed of the initial step are taken to initialize the agents
+
+    *Before each step*
+    - action must only be provided if action_required from previous step (initially all True)
+    - position, direction before step are verified
+    - optionally, set_malfunction is applied
+    - malfunction is verified
+
+    *After each step*
+    - reward is verified after step
+
+    Parameters
+    ----------
+    env
+    test_configs
+    rendering
+    """
+    if rendering:
+        renderer = RenderTool(env)
+        renderer.render_env(show=True, frames=False, show_observations=False)
+    info_dict = {
+        'action_required': [True for _ in test_configs]
+    }
+
+    for step in range(len(test_configs[0].replay)):
+        if step == 0:
+            for a, test_config in enumerate(test_configs):
+                agent: EnvAgent = env.agents[a]
+                replay = test_config.replay[0]
+                # set the initial position
+                agent.position = replay.position
+                agent.direction = replay.direction
+                agent.target = test_config.target
+                agent.speed_data['speed'] = test_config.speed
+
+        def _assert(a, actual, expected, msg):
+            assert np.allclose(actual, expected), "[{}] agent {} {}:  actual={}, expected={}".format(step, a, msg,
+                                                                                                    actual,
+                                                                                                    expected)
+
+        action_dict = {}
+
+        for a, test_config in enumerate(test_configs):
+            agent: EnvAgent = env.agents[a]
+            replay = test_config.replay[step]
+
+            _assert(a, agent.position, replay.position, 'position')
+            _assert(a, agent.direction, replay.direction, 'direction')
+
+            if replay.action is not None:
+                assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format(
+                    step, a, True)
+                action_dict[a] = replay.action
+            else:
+                assert info_dict['action_required'][a] == False, "[{}] agent {} expecting action_required={}".format(
+                    step, a, False)
+
+            if replay.set_malfunction is not None:
+                agent.malfunction_data['malfunction'] = replay.set_malfunction
+                agent.malfunction_data['moving_before_malfunction'] = agent.moving
+            _assert(a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
+
+        _, rewards_dict, _, info_dict = env.step(action_dict)
+        if rendering:
+            renderer.render_env(show=True, show_observations=True)
+
+        for a, test_config in enumerate(test_configs):
+            replay = test_config.replay[step]
+            _assert(a, rewards_dict[a], replay.reward, 'reward')
+
+