diff --git a/flatland/action_plan/action_plan_player.py b/flatland/action_plan/action_plan_player.py
index a1608139eb634af9098a6d9a79dca7ff08ec2a93..adf3dbec6662edad048dde89d9cbdf49edbff25a 100644
--- a/flatland/action_plan/action_plan_player.py
+++ b/flatland/action_plan/action_plan_player.py
@@ -1,23 +1,26 @@
+from typing import Callable
+
 from flatland.action_plan.action_plan import ControllerFromTrainruns
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_trainrun_data_structures import Waypoint
-from flatland.utils.rendertools import RenderTool, AgentRenderVariant
 
 
 class ControllerFromTrainrunsReplayer():
     """Allows to verify a `DeterministicController` by replaying it against a FLATland env without malfunction."""
 
     @staticmethod
-    def replay_verify(ctl: ControllerFromTrainruns, env: RailEnv, rendering: bool):
-        """Replays this deterministic `ActionPlan` and verifies whether it is feasible."""
-        if rendering:
-            renderer = RenderTool(env, gl="PILSVG",
-                                  agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX,
-                                  show_debug=True,
-                                  clear_debug_text=True,
-                                  screen_height=1000,
-                                  screen_width=1000)
-            renderer.render_env(show=True, show_observations=False, show_predictions=False)
+    def replay_verify(ctl: ControllerFromTrainruns, env: RailEnv,
+                      call_back: Callable[[RailEnv], None] = lambda *a, **k: None):
+        """Replays this deterministic `ActionPlan` and verifies whether it is feasible.
+
+        Parameters
+        ----------
+        ctl
+        env
+        call_back
+            Called before/after each step() call. The env is passed to it.
+        """
+        call_back(env)
         i = 0
         while not env.dones['__all__'] and i <= env._max_episode_steps:
             for agent_id, agent in enumerate(env.agents):
@@ -30,7 +33,6 @@ class ControllerFromTrainrunsReplayer():
 
             obs, all_rewards, done, _ = env.step(actions)
 
-            if rendering:
-                renderer.render_env(show=True, show_observations=False, show_predictions=False)
+            call_back(env)
 
             i += 1
diff --git a/tests/test_action_plan.py b/tests/test_action_plan.py
index 6caa9fc3cff4f4410ed1a32f043da810c40a92c7..9a03eb8a9f89e1edaac558e563a3c0544b4d6b5c 100644
--- a/tests/test_action_plan.py
+++ b/tests/test_action_plan.py
@@ -7,6 +7,7 @@ from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import rail_from_grid_transition_map
 from flatland.envs.rail_trainrun_data_structures import Waypoint
 from flatland.envs.schedule_generators import random_schedule_generator
+from flatland.utils.rendertools import RenderTool, AgentRenderVariant
 from flatland.utils.simple_rail import make_simple_rail
 
 
@@ -84,4 +85,16 @@ def test_action_plan(rendering: bool = False):
     deterministic_controller = ControllerFromTrainruns(env, chosen_path_dict)
     deterministic_controller.print_action_plan()
     ControllerFromTrainruns.assert_actions_plans_equal(expected_action_plan, deterministic_controller.action_plan)
-    ControllerFromTrainrunsReplayer.replay_verify(deterministic_controller, env, rendering)
+    if rendering:
+        renderer = RenderTool(env, gl="PILSVG",
+                              agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX,
+                              show_debug=True,
+                              clear_debug_text=True,
+                              screen_height=1000,
+                              screen_width=1000)
+
+    def render(*argv):
+        if rendering:
+            renderer.render_env(show=True, show_observations=False, show_predictions=False)
+
+    ControllerFromTrainrunsReplayer.replay_verify(deterministic_controller, env, call_back=render)