From 44b55f209230d50fb04724add06cf5f6d6e08b3c Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Thu, 19 Sep 2019 15:17:05 +0200
Subject: [PATCH] #168 #163 multispeed and penalty testing
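
Consolidate the hand-rolled replay loops of the malfunction and multi-speed
tests into a shared run_replay_config helper (tests/test_utils.py), and extend
Replay with set_malfunction and an expected reward so that the replays also
verify step, start, stop and invalid-action penalties. In RailEnv._step_agent,
the full step penalty is now applied while an agent is malfunctioning or
stopped, and the step penalty is no longer added on top of the invalid-action
and stop penalties.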

---
 flatland/envs/rail_env.py          |  17 +-
 tests/test_flatland_malfunction.py | 374 ++++++++++++-----------------
 tests/test_multi_speed.py          | 308 +++++++++---------------
 tests/test_utils.py                |  82 ++++++-
 4 files changed, 368 insertions(+), 413 deletions(-)

diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 294ffab2..0467fcd6 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -236,7 +236,8 @@ class RailEnv(Environment):
             Relies on the rail_generator returning agent_static lists (pos, dir, target)
         """
 
-        # TODO can we not put 'self.rail_generator(..)' into 'if regen_rail or self.rail is None' condition?
+        # TODO https://gitlab.aicrowd.com/flatland/flatland/issues/172
+        #  can we not put 'self.rail_generator(..)' into 'if regen_rail or self.rail is None' condition?
         rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
 
         if optionals and 'distance_map' in optionals:
@@ -256,6 +257,9 @@ class RailEnv(Environment):
             agents_hints = None
             if optionals and 'agents_hints' in optionals:
                 agents_hints = optionals['agents_hints']
+
+            # TODO https://gitlab.aicrowd.com/flatland/flatland/issues/185
+            #  why do we need static agents? could we do this more elegantly?
             self.agents_static = EnvAgentStatic.from_lists(
                 *self.schedule_generator(self.rail, self.get_num_agents(), agents_hints))
         self.restart_agents()
@@ -407,13 +413,14 @@ class RailEnv(Environment):
         # is the agent malfunctioning?
         malfunction = self._agent_malfunction(i_agent)
 
-        # if agent is broken, actions are ignored and agent does not move,
-        # the agent is not penalized in this step!
+        # if the agent is broken, actions are ignored and the agent does not move;
+        # the full step penalty is applied in this case
         if malfunction:
             self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
             return
 
         # Is the agent at the beginning of the cell? Then, it can take an action.
+        # As long as the agent is malfunctioning or stopped at the beginning of the cell, a different action may still be taken!
         if agent.speed_data['position_fraction'] == 0.0:
             # No action has been supplied for this agent -> set DO_NOTHING as default
             if action is None:
@@ -464,7 +471,6 @@ class RailEnv(Environment):
                 if not _action_stored:
                     # If the agent cannot move due to an invalid transition, we set its state to not moving
                     self.rewards_dict[i_agent] += self.invalid_action_penalty
-                    self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
                     self.rewards_dict[i_agent] += self.stop_penalty
                     agent.moving = False
 
@@ -497,6 +503,9 @@ class RailEnv(Environment):
                 agent.moving = False
             else:
                 self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
+        else:
+            # step penalty if not moving (stopped now or before)
+            self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
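+            # (e.g. a stopped agent with speed 0.5 thus accumulates step_penalty * 0.5
+            #  per step while standing still, plus stop_penalty in the step it stops)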
 
     def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent):
         """
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index 884a2a51..8bd023cf 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -4,13 +4,11 @@ import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator
 from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator
-from flatland.utils.rendertools import RenderTool
-from test_utils import Replay
+from test_utils import Replay, ReplayConfig, run_replay_config
 
 
 class SingleAgentNavigationObs(TreeObsForRailEnv):
@@ -155,7 +153,7 @@ def test_malfunction_process_statistically():
     assert nb_malfunction == 156, "nb_malfunction={}".format(nb_malfunction)
 
 
-def test_initial_malfunction(rendering=True):
+def test_initial_malfunction():
     random.seed(0)
     np.random.seed(0)
 
@@ -190,74 +188,55 @@ def test_initial_malfunction(rendering=True):
                   stochastic_data=stochastic_data,  # Malfunction data generator
                   )
 
-    if rendering:
-        renderer = RenderTool(env)
-        renderer.render_env(show=True, frames=False, show_observations=False)
-    _action = dict()
-
-    replay_steps = [
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.MOVE_FORWARD,
-            malfunction=3
-        ),
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.MOVE_FORWARD,
-            malfunction=2
-        ),
-        # malfunction stops in the next step and we're still at the beginning of the cell
-        # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.MOVE_FORWARD,
-            malfunction=1
-        ),
-        Replay(
-            position=(28, 4),
-            direction=Grid4TransitionsEnum.WEST,
-            action=RailEnvActions.MOVE_FORWARD,
-            malfunction=0
-        ),
-        Replay(
-            position=(27, 4),
-            direction=Grid4TransitionsEnum.NORTH,
-            action=RailEnvActions.MOVE_FORWARD,
-            malfunction=0
-        )
-    ]
-
-    info_dict = {
-        'action_required': [True]
-    }
-
-    for i, replay in enumerate(replay_steps):
-
-        def _assert(actual, expected, msg):
-            assert actual == expected, "[{}] {}:  actual={}, expected={}".format(i, msg, actual, expected)
-
-        agent: EnvAgent = env.agents[0]
-
-        _assert(agent.position, replay.position, 'position')
-        _assert(agent.direction, replay.direction, 'direction')
-        _assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
-
-        if replay.action is not None:
-            assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
-            _, _, _, info_dict = env.step({0: replay.action})
-
-        else:
-            assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
-            _, _, _, info_dict = env.step({})
-
-        if rendering:
-            renderer.render_env(show=True, show_observations=True)
-
-
-def test_initial_malfunction_stop_moving(rendering=True):
+    replay_config = ReplayConfig(
+        replay=[
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.MOVE_FORWARD,
+                set_malfunction=3,
+                malfunction=3,
+                reward=env.step_penalty  # full step penalty when malfunctioning
+            ),
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.MOVE_FORWARD,
+                malfunction=2,
+                reward=env.step_penalty  # full step penalty when malfunctioning
+            ),
+            # malfunction stops in the next step and we're still at the beginning of the cell
+            # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.MOVE_FORWARD,
+                malfunction=1,
+                reward=env.start_penalty + env.step_penalty * 1.0
+                # malfunctioning ends: starting and running at speed 1.0
+            ),
+            Replay(
+                position=(28, 4),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_FORWARD,
+                malfunction=0,
+                reward=env.step_penalty * 1.0  # running at speed 1.0
+            ),
+            Replay(
+                position=(27, 4),
+                direction=Grid4TransitionsEnum.NORTH,
+                action=RailEnvActions.MOVE_FORWARD,
+                malfunction=0,
+                reward=env.step_penalty * 1.0  # running at speed 1.0
+            )
+        ],
+        speed=env.agents[0].speed_data['speed'],
+        target=env.agents[0].target
+    )
+    run_replay_config(env, [replay_config])
+
+
+def test_initial_malfunction_stop_moving():
     random.seed(0)
     np.random.seed(0)
 
@@ -292,80 +271,62 @@ def test_initial_malfunction_stop_moving(rendering=True):
                   stochastic_data=stochastic_data,  # Malfunction data generator
                   )
 
-    if rendering:
-        renderer = RenderTool(env)
-        renderer.render_env(show=True, frames=False, show_observations=False)
-    _action = dict()
-
-    replay_steps = [
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.DO_NOTHING,
-            malfunction=3
-        ),
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.DO_NOTHING,
-            malfunction=2
-        ),
-        # malfunction stops in the next step and we're still at the beginning of the cell
-        # --> if we take action DO_NOTHING, agent should restart without moving
-        #
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.STOP_MOVING,
-            malfunction=1
-        ),
-        # we have stopped and do nothing --> should stand still
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.DO_NOTHING,
-            malfunction=0
-        ),
-        # we start to move forward --> should go to next cell now
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.MOVE_FORWARD,
-            malfunction=0
-        ),
-        Replay(
-            position=(28, 4),
-            direction=Grid4TransitionsEnum.WEST,
-            action=RailEnvActions.MOVE_FORWARD,
-            malfunction=0
-        )
-    ]
-
-    info_dict = {
-        'action_required': [True]
-    }
-
-    for i, replay in enumerate(replay_steps):
-
-        def _assert(actual, expected, msg):
-            assert actual == expected, "[{}] {}:  actual={}, expected={}".format(i, msg, actual, expected)
-
-        agent: EnvAgent = env.agents[0]
-
-        _assert(agent.position, replay.position, 'position')
-        _assert(agent.direction, replay.direction, 'direction')
-        _assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
-
-        if replay.action is not None:
-            assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
-            _, _, _, info_dict = env.step({0: replay.action})
-
-        else:
-            assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
-            _, _, _, info_dict = env.step({})
-
-        if rendering:
-            renderer.render_env(show=True, show_observations=True)
+    replay_config = ReplayConfig(
+        replay=[
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.DO_NOTHING,
+                set_malfunction=3,
+                malfunction=3,
+                reward=env.step_penalty  # full step penalty while malfunctioning
+            ),
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.DO_NOTHING,
+                malfunction=2,
+                reward=env.step_penalty  # full step penalty while malfunctioning
+            ),
+            # malfunction stops in the next step and we're still at the beginning of the cell
+            # --> if we take action STOP_MOVING, the agent should remain stopped without moving
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.STOP_MOVING,
+                malfunction=1,
+                reward=env.step_penalty  # full step penalty while stopped
+            ),
+            # we have stopped and do nothing --> should stand still
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.DO_NOTHING,
+                malfunction=0,
+                reward=env.step_penalty  # full step penalty while stopped
+            ),
+            # we start to move forward --> should go to next cell now
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.MOVE_FORWARD,
+                malfunction=0,
+                reward=env.start_penalty + env.step_penalty * 1.0  # start penalty + step penalty for speed 1.0
+            ),
+            Replay(
+                position=(28, 4),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_FORWARD,
+                malfunction=0,
+                reward=env.step_penalty * 1.0  # step penalty for speed 1.0
+            )
+        ],
+        speed=env.agents[0].speed_data['speed'],
+        target=env.agents[0].target
+    )
+
+    run_replay_config(env, [replay_config])
 
 
-def test_initial_malfunction_do_nothing(rendering=True):
+def test_initial_malfunction_do_nothing():
@@ -403,77 +364,58 @@ def test_initial_malfunction_do_nothing(rendering=True):
                   stochastic_data=stochastic_data,  # Malfunction data generator
                   )
 
-    if rendering:
-        renderer = RenderTool(env)
-        renderer.render_env(show=True, frames=False, show_observations=False)
-    _action = dict()
-
-    replay_steps = [
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.DO_NOTHING,
-            malfunction=3
-        ),
+    replay_config = ReplayConfig(
+        replay=[
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.DO_NOTHING,
+                set_malfunction=3,
+                malfunction=3,
+                reward=env.step_penalty  # full step penalty while malfunctioning
+            ),
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.DO_NOTHING,
-            malfunction=2
-        ),
-        # malfunction stops in the next step and we're still at the beginning of the cell
-        # --> if we take action DO_NOTHING, agent should restart without moving
-        #
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.DO_NOTHING,
-            malfunction=1
-        ),
-        # we haven't started moving yet --> stay here
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.DO_NOTHING,
-            malfunction=0
-        ),
-        # we start to move forward --> should go to next cell now
-        Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.MOVE_FORWARD,
-            malfunction=0
-        ),
-        Replay(
-            position=(28, 4),
-            direction=Grid4TransitionsEnum.WEST,
-            action=RailEnvActions.MOVE_FORWARD,
-            malfunction=0
-        )
-    ]
-
-    info_dict = {
-        'action_required': [True]
-    }
-
-    for i, replay in enumerate(replay_steps):
-
-        def _assert(actual, expected, msg):
-            assert actual == expected, "[{}] {}:  actual={}, expected={}".format(i, msg, actual, expected)
-
-        agent: EnvAgent = env.agents[0]
-
-        _assert(agent.position, replay.position, 'position')
-        _assert(agent.direction, replay.direction, 'direction')
-        _assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
-
-        if replay.action is not None:
-            assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
-            _, _, _, info_dict = env.step({0: replay.action})
-
-        else:
-            assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
-            _, _, _, info_dict = env.step({})
-
-        if rendering:
-            renderer.render_env(show=True, show_observations=True)
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.DO_NOTHING,
+                malfunction=2,
+                reward=env.step_penalty  # full step penalty while malfunctioning
+            ),
+            # malfunction stops in the next step and we're still at the beginning of the cell
+            # --> if we take action DO_NOTHING, the agent should stay put without moving
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.DO_NOTHING,
+                malfunction=1,
+                reward=env.step_penalty  # full step penalty while stopped
+            ),
+            # we haven't started moving yet --> stay here
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.DO_NOTHING,
+                malfunction=0,
+                reward=env.step_penalty  # full step penalty while stopped
+            ),
+            # we start to move forward --> should go to next cell now
+            Replay(
+                position=(28, 5),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.MOVE_FORWARD,
+                malfunction=0,
+                reward=env.start_penalty + env.step_penalty * 1.0  # start penalty + step penalty for speed 1.0
+            ),
+            Replay(
+                position=(28, 4),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_FORWARD,
+                malfunction=0,
+                reward=env.step_penalty * 1.0  # step penalty for speed 1.0
+            )
+        ],
+        speed=env.agents[0].speed_data['speed'],
+        target=env.agents[0].target
+    )
+
+    run_replay_config(env, [replay_config])
diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py
index 1cf0c325..f2fd3613 100644
--- a/tests/test_multi_speed.py
+++ b/tests/test_multi_speed.py
@@ -1,7 +1,6 @@
 import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.envs.agent_utils import EnvAgent, EnvAgentStatic
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions
@@ -9,7 +8,7 @@ from flatland.envs.rail_generators import complex_rail_generator, rail_from_grid
 from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator
 from flatland.utils.rendertools import RenderTool
 from flatland.utils.simple_rail import make_simple_rail
-from test_utils import ReplayConfig, Replay
+from test_utils import ReplayConfig, Replay, run_replay_config
 
 np.random.seed(1)
 
@@ -95,7 +94,6 @@ def test_multi_speed_init():
                 old_pos[i_agent] = env.agents[i_agent].position
 
 
-# TODO test penalties!
 # TODO test invalid actions!
-def test_multispeed_actions_no_malfunction_no_blocking(rendering=True):
+def test_multispeed_actions_no_malfunction_no_blocking():
     """Test that actions are correctly performed on cell exit for a single agent."""
@@ -108,123 +106,96 @@ def test_multispeed_actions_no_malfunction_no_blocking(rendering=True):
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
 
-    # initialize agents_static
-    env.reset()
-
-    # reset to set agents from agents_static
-    env.reset(False, False)
-
-    if rendering:
-        renderer = RenderTool(env, gl="PILSVG")
 
     test_config = ReplayConfig(
         replay=[
             Replay(
                 position=(3, 9),  # east dead-end
                 direction=Grid4TransitionsEnum.EAST,
-                action=RailEnvActions.MOVE_FORWARD
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.start_penalty + env.step_penalty * 0.5  # starting and running at speed 0.5
             ),
             Replay(
                 position=(3, 9),
                 direction=Grid4TransitionsEnum.EAST,
-                action=None
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(3, 8),
                 direction=Grid4TransitionsEnum.WEST,
-                action=RailEnvActions.MOVE_FORWARD
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(3, 8),
                 direction=Grid4TransitionsEnum.WEST,
-                action=None
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(3, 7),
                 direction=Grid4TransitionsEnum.WEST,
-                action=RailEnvActions.MOVE_FORWARD
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(3, 7),
                 direction=Grid4TransitionsEnum.WEST,
-                action=None
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(3, 6),
                 direction=Grid4TransitionsEnum.WEST,
-                action=RailEnvActions.MOVE_LEFT
+                action=RailEnvActions.MOVE_LEFT,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(3, 6),
                 direction=Grid4TransitionsEnum.WEST,
-                action=None
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(4, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=RailEnvActions.STOP_MOVING
+                action=RailEnvActions.STOP_MOVING,
+                reward=env.stop_penalty + env.step_penalty * 0.5  # stopping and step penalty for speed 0.5
             ),
             #
             Replay(
                 position=(4, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=RailEnvActions.STOP_MOVING
+                action=RailEnvActions.STOP_MOVING,
+                reward=env.step_penalty * 0.5  # step penalty for speed 0.5 when stopped
             ),
             Replay(
                 position=(4, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=RailEnvActions.MOVE_FORWARD
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.start_penalty + env.step_penalty * 0.5  # starting + running at speed 0.5
             ),
             Replay(
                 position=(4, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=None
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(5, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=RailEnvActions.MOVE_FORWARD
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
-
         ],
         target=(3, 0),  # west dead-end
         speed=0.5
     )
 
-    agentStatic: EnvAgentStatic = env.agents_static[0]
-    info_dict = {
-        'action_required': [True]
-    }
-    for i, replay in enumerate(test_config.replay):
-        if i == 0:
-            # set the initial position
-            agentStatic.position = replay.position
-            agentStatic.direction = replay.direction
-            agentStatic.target = test_config.target
-            agentStatic.moving = True
-            agentStatic.speed_data['speed'] = test_config.speed
-
-            # reset to set agents from agents_static
-            env.reset(False, False)
-
-        def _assert(actual, expected, msg):
-            assert actual == expected, "[{}] {}:  actual={}, expected={}".format(i, msg, actual, expected)
-
-        agent: EnvAgent = env.agents[0]
-
-        _assert(agent.position, replay.position, 'position')
-        _assert(agent.direction, replay.direction, 'direction')
-
-        if replay.action is not None:
-            assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
-            _, _, _, info_dict = env.step({0: replay.action})
-
-        else:
-            assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
-            _, _, _, info_dict = env.step({})
-
-        if rendering:
-            renderer.render_env(show=True, show_observations=True)
+    run_replay_config(env, [test_config])
 
 
-def test_multispeed_actions_no_malfunction_blocking(rendering=True):
+def test_multispeed_actions_no_malfunction_blocking():
@@ -238,80 +209,83 @@ def test_multispeed_actions_no_malfunction_blocking(rendering=True):
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
 
-    # initialize agents_static
-    env.reset()
-
-    # reset to set agents from agents_static
-    env.reset(False, False)
-
-    if rendering:
-        renderer = RenderTool(env, gl="PILSVG")
-
     test_configs = [
         ReplayConfig(
             replay=[
                 Replay(
                     position=(3, 8),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=RailEnvActions.MOVE_FORWARD
+                    action=RailEnvActions.MOVE_FORWARD,
+                    reward=env.start_penalty + env.step_penalty * 1.0 / 3.0  # starting and running at speed 1/3
                 ),
                 Replay(
                     position=(3, 8),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                 ),
                 Replay(
                     position=(3, 8),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                 ),
 
                 Replay(
                     position=(3, 7),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=RailEnvActions.MOVE_FORWARD
+                    action=RailEnvActions.MOVE_FORWARD,
+                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                 ),
                 Replay(
                     position=(3, 7),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                 ),
                 Replay(
                     position=(3, 7),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                 ),
 
                 Replay(
                     position=(3, 6),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=RailEnvActions.MOVE_FORWARD
+                    action=RailEnvActions.MOVE_FORWARD,
+                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                 ),
                 Replay(
                     position=(3, 6),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                 ),
                 Replay(
                     position=(3, 6),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                 ),
 
                 Replay(
                     position=(3, 5),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=RailEnvActions.MOVE_FORWARD
+                    action=RailEnvActions.MOVE_FORWARD,
+                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                 ),
                 Replay(
                     position=(3, 5),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                 ),
                 Replay(
                     position=(3, 5),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                 )
             ],
             target=(3, 0),  # west dead-end
@@ -321,69 +295,81 @@ def test_multispeed_actions_no_malfunction_blocking(rendering=True):
                 Replay(
                     position=(3, 9),  # east dead-end
                     direction=Grid4TransitionsEnum.EAST,
-                    action=RailEnvActions.MOVE_FORWARD
+                    action=RailEnvActions.MOVE_FORWARD,
+                    reward=env.start_penalty + env.step_penalty * 0.5  # starting and running at speed 0.5
                 ),
                 Replay(
                     position=(3, 9),
                     direction=Grid4TransitionsEnum.EAST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 0.5  # running at speed 0.5
                 ),
                 # blocked although fraction >= 1.0
                 Replay(
                     position=(3, 9),
                     direction=Grid4TransitionsEnum.EAST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 0.5  # running at speed 0.5
                 ),
 
                 Replay(
                     position=(3, 8),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=RailEnvActions.MOVE_FORWARD
+                    action=RailEnvActions.MOVE_FORWARD,
+                    reward=env.step_penalty * 0.5  # running at speed 0.5
                 ),
                 Replay(
                     position=(3, 8),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 0.5  # running at speed 0.5
                 ),
                 # blocked although fraction >= 1.0
                 Replay(
                     position=(3, 8),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 0.5  # running at speed 0.5
                 ),
 
                 Replay(
                     position=(3, 7),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=RailEnvActions.MOVE_FORWARD
+                    action=RailEnvActions.MOVE_FORWARD,
+                    reward=env.step_penalty * 0.5  # running at speed 0.5
                 ),
                 Replay(
                     position=(3, 7),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 0.5  # running at speed 0.5
                 ),
                 # blocked although fraction >= 1.0
                 Replay(
                     position=(3, 7),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 0.5  # running at speed 0.5
                 ),
 
                 Replay(
                     position=(3, 6),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=RailEnvActions.MOVE_LEFT
+                    action=RailEnvActions.MOVE_LEFT,
+                    reward=env.step_penalty * 0.5  # running at speed 0.5
                 ),
                 Replay(
                     position=(3, 6),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=None
+                    action=None,
+                    reward=env.step_penalty * 0.5  # running at speed 0.5
                 ),
                 # not blocked, action required!
                 Replay(
                     position=(4, 6),
                     direction=Grid4TransitionsEnum.SOUTH,
-                    action=RailEnvActions.MOVE_FORWARD
+                    action=RailEnvActions.MOVE_FORWARD,
+                    reward=env.step_penalty * 0.5  # running at speed 0.5
                 ),
             ],
             target=(3, 0),  # west dead-end
@@ -391,49 +377,7 @@ def test_multispeed_actions_no_malfunction_blocking(rendering=True):
         )
 
     ]
-
-    # TODO test penalties!
-    info_dict = {
-        'action_required': [True for _ in test_configs]
-    }
-    for step in range(len(test_configs[0].replay)):
-        if step == 0:
-            for a, test_config in enumerate(test_configs):
-                agentStatic: EnvAgentStatic = env.agents_static[a]
-                replay = test_config.replay[0]
-                # set the initial position
-                agentStatic.position = replay.position
-                agentStatic.direction = replay.direction
-                agentStatic.target = test_config.target
-                agentStatic.moving = True
-                agentStatic.speed_data['speed'] = test_config.speed
-
-            # reset to set agents from agents_static
-            env.reset(False, False)
-
-        def _assert(a, actual, expected, msg):
-            assert actual == expected, "[{}] {} {}:  actual={}, expected={}".format(step, a, msg, actual, expected)
-
-        action_dict = {}
-
-        for a, test_config in enumerate(test_configs):
-            agent: EnvAgent = env.agents[a]
-            replay = test_config.replay[step]
-
-            _assert(a, agent.position, replay.position, 'position')
-            _assert(a, agent.direction, replay.direction, 'direction')
-
-            if replay.action is not None:
-                assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format(
-                    step, a, True)
-                action_dict[a] = replay.action
-            else:
-                assert info_dict['action_required'][a] == False, "[{}] agent {} expecting action_required={}".format(
-                    step, a, False)
-        _, _, _, info_dict = env.step(action_dict)
-
-        if rendering:
-            renderer.render_env(show=True, show_observations=True)
+    run_replay_config(env, test_configs)
 
 
 def test_multispeed_actions_malfunction_no_blocking(rendering=True):
@@ -461,136 +405,118 @@ def test_multispeed_actions_malfunction_no_blocking(rendering=True):
             Replay(
                 position=(3, 9),  # east dead-end
                 direction=Grid4TransitionsEnum.EAST,
-                action=RailEnvActions.MOVE_FORWARD
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.start_penalty + env.step_penalty * 0.5  # starting and running at speed 0.5
             ),
             Replay(
                 position=(3, 9),
                 direction=Grid4TransitionsEnum.EAST,
-                action=None
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(3, 8),
                 direction=Grid4TransitionsEnum.WEST,
-                action=RailEnvActions.MOVE_FORWARD
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             # add additional step in the cell
             Replay(
                 position=(3, 8),
                 direction=Grid4TransitionsEnum.WEST,
                 action=None,
-                malfunction=2  # recovers in two steps from now!
+                set_malfunction=2,  # recovers in two steps from now!
+                malfunction=2,
+                reward=env.step_penalty * 0.5  # step penalty for speed 0.5 when malfunctioning
             ),
             # agent recovers in this step
             Replay(
                 position=(3, 8),
                 direction=Grid4TransitionsEnum.WEST,
-                action=None
+                action=None,
+                malfunction=1,
+                reward=env.step_penalty * 0.5  # recovered: running at speed 0.5
             ),
             Replay(
                 position=(3, 7),
                 direction=Grid4TransitionsEnum.WEST,
-                action=RailEnvActions.MOVE_FORWARD
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(3, 7),
                 direction=Grid4TransitionsEnum.WEST,
-                action=None
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(3, 6),
                 direction=Grid4TransitionsEnum.WEST,
                 action=RailEnvActions.MOVE_FORWARD,
-                malfunction=2  # recovers in two steps from now!
+                set_malfunction=2,  # recovers in two steps from now!
+                malfunction=2,
+                reward=env.step_penalty * 0.5  # step penalty for speed 0.5 when malfunctioning
             ),
             # agent recovers in this step; since we're at the beginning, we provide a different action although we're broken!
             Replay(
                 position=(3, 6),
                 direction=Grid4TransitionsEnum.WEST,
                 action=RailEnvActions.MOVE_LEFT,
+                malfunction=1,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(3, 6),
                 direction=Grid4TransitionsEnum.WEST,
-                action=None
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(4, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=RailEnvActions.STOP_MOVING
+                action=RailEnvActions.STOP_MOVING,
+                reward=env.stop_penalty + env.step_penalty * 0.5  # stopping and step penalty for speed 0.5
             ),
             Replay(
                 position=(4, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=RailEnvActions.STOP_MOVING
+                action=RailEnvActions.STOP_MOVING,
+                reward=env.step_penalty * 0.5  # step penalty for speed 0.5 while stopped
             ),
             Replay(
                 position=(4, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=RailEnvActions.MOVE_FORWARD
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.start_penalty + env.step_penalty * 0.5  # starting and running at speed 0.5
             ),
             Replay(
                 position=(4, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=None
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             # DO_NOTHING keeps moving!
             Replay(
                 position=(5, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=RailEnvActions.DO_NOTHING
+                action=RailEnvActions.DO_NOTHING,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(5, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=None
+                action=None,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
             Replay(
                 position=(6, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
-                action=RailEnvActions.MOVE_FORWARD
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.step_penalty * 0.5  # running at speed 0.5
             ),
 
         ],
         target=(3, 0),  # west dead-end
         speed=0.5
     )
-
-    # TODO test penalties!
-    agentStatic: EnvAgentStatic = env.agents_static[0]
-    info_dict = {
-        'action_required': [True]
-    }
-    for i, replay in enumerate(test_config.replay):
-        if i == 0:
-            # set the initial position
-            agentStatic.position = replay.position
-            agentStatic.direction = replay.direction
-            agentStatic.target = test_config.target
-            agentStatic.moving = True
-            agentStatic.speed_data['speed'] = test_config.speed
-
-            # reset to set agents from agents_static
-            env.reset(False, False)
-
-        def _assert(actual, expected, msg):
-            assert actual == expected, "[{}] {}:  actual={}, expected={}".format(i, msg, actual, expected)
-
-        agent: EnvAgent = env.agents[0]
-
-        _assert(agent.position, replay.position, 'position')
-        _assert(agent.direction, replay.direction, 'direction')
-
-        if replay.malfunction > 0:
-            agent.malfunction_data['malfunction'] = replay.malfunction
-            agent.malfunction_data['moving_before_malfunction'] = agent.moving
-
-        if replay.action is not None:
-            assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
-            _, _, _, info_dict = env.step({0: replay.action})
-
-        else:
-            assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
-            _, _, _, info_dict = env.step({})
-
-        if rendering:
-            renderer.render_env(show=True, show_observations=True)
+    run_replay_config(env, [test_config])
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 96761583..88d669fa 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -4,7 +4,9 @@ from typing import List, Tuple, Optional
 from attr import attrs, attrib
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.envs.rail_env import RailEnvActions
+from flatland.envs.agent_utils import EnvAgent
+from flatland.envs.rail_env import RailEnvActions, RailEnv
+from flatland.utils.rendertools import RenderTool
 
 
 @attrs
@@ -13,7 +15,8 @@ class Replay(object):
     direction = attrib(type=Grid4TransitionsEnum)
     action = attrib(type=RailEnvActions)
     malfunction = attrib(default=0, type=int)
-    penalty = attrib(default=None, type=Optional[float])
+    # if not None, the agent's malfunction counter is forced to this value before the step
+    set_malfunction = attrib(default=None, type=Optional[int])
+    # expected reward for the agent in this step
+    reward = attrib(default=None, type=Optional[float])
 
 
 @attrs
@@ -21,3 +24,78 @@ class ReplayConfig(object):
     replay = attrib(type=List[Replay])
     target = attrib(type=Tuple[int, int])
     speed = attrib(type=float)
+
+
+def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: bool = False):
+    """
+    Runs the replay configs and checks assertions.
+
+    *Initially*
+    - the position, direction, target and speed from the first replay step are used to initialize the agents
+
+    *Before each step*
+    - an action must be provided if and only if action_required was set in the previous step (initially all True)
+    - position and direction are verified
+    - optionally, set_malfunction is applied to force a malfunction
+    - malfunction is verified
+
+    *After each step*
+    - the reward is verified
+
+    Parameters
+    ----------
+    env
+    test_configs
+    rendering
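+
+    Example (a minimal sketch; assumes env is a single-agent RailEnv built as
+    in the tests above, with names taken from those tests)::
+
+        test_config = ReplayConfig(
+            replay=[Replay(position=(3, 9), direction=Grid4TransitionsEnum.EAST,
+                           action=RailEnvActions.MOVE_FORWARD,
+                           reward=env.start_penalty + env.step_penalty * 0.5)],
+            target=(3, 0), speed=0.5)
+        run_replay_config(env, [test_config])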
+    """
+    if rendering:
+        renderer = RenderTool(env)
+        renderer.render_env(show=True, frames=False, show_observations=False)
+    info_dict = {
+        'action_required': [True for _ in test_configs]
+    }
+    for step in range(len(test_configs[0].replay)):
+        if step == 0:
+            for a, test_config in enumerate(test_configs):
+                agent: EnvAgent = env.agents[a]
+                replay = test_config.replay[0]
+                # set the initial position
+                agent.position = replay.position
+                agent.direction = replay.direction
+                agent.target = test_config.target
+                agent.speed_data['speed'] = test_config.speed
+
+        def _assert(a, actual, expected, msg):
+            assert actual == expected, "[{}] agent {} {}:  actual={}, expected={}".format(step, a, msg, actual,
+                                                                                          expected)
+
+        action_dict = {}
+
+        for a, test_config in enumerate(test_configs):
+            agent: EnvAgent = env.agents[a]
+            replay = test_config.replay[step]
+
+            _assert(a, agent.position, replay.position, 'position')
+            _assert(a, agent.direction, replay.direction, 'direction')
+
+            if replay.action is not None:
+                assert info_dict['action_required'][a], "[{}] agent {} expecting action_required=True".format(step, a)
+                action_dict[a] = replay.action
+            else:
+                assert not info_dict['action_required'][a], "[{}] agent {} expecting action_required=False".format(
+                    step, a)
+
+            if replay.set_malfunction is not None:
+                agent.malfunction_data['malfunction'] = replay.set_malfunction
+                agent.malfunction_data['moving_before_malfunction'] = agent.moving
+            _assert(a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
+
+        _, rewards_dict, _, info_dict = env.step(action_dict)
+
+        for a, test_config in enumerate(test_configs):
+            replay = test_config.replay[step]
+            _assert(a, rewards_dict[a], replay.reward, 'reward')
+
+        if rendering:
+            renderer.render_env(show=True, show_observations=True)
-- 
GitLab