diff --git a/flatland/action_plan/action_plan_player.py b/flatland/action_plan/action_plan_player.py index a1608139eb634af9098a6d9a79dca7ff08ec2a93..adf3dbec6662edad048dde89d9cbdf49edbff25a 100644 --- a/flatland/action_plan/action_plan_player.py +++ b/flatland/action_plan/action_plan_player.py @@ -1,23 +1,26 @@ +from typing import Callable + from flatland.action_plan.action_plan import ControllerFromTrainruns from flatland.envs.rail_env import RailEnv from flatland.envs.rail_trainrun_data_structures import Waypoint -from flatland.utils.rendertools import RenderTool, AgentRenderVariant class ControllerFromTrainrunsReplayer(): """Allows to verify a `DeterministicController` by replaying it against a FLATland env without malfunction.""" @staticmethod - def replay_verify(ctl: ControllerFromTrainruns, env: RailEnv, rendering: bool): - """Replays this deterministic `ActionPlan` and verifies whether it is feasible.""" - if rendering: - renderer = RenderTool(env, gl="PILSVG", - agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX, - show_debug=True, - clear_debug_text=True, - screen_height=1000, - screen_width=1000) - renderer.render_env(show=True, show_observations=False, show_predictions=False) + def replay_verify(ctl: ControllerFromTrainruns, env: RailEnv, + call_back: Callable[[RailEnv], None] = lambda *a, **k: None): + """Replays this deterministic `ActionPlan` and verifies whether it is feasible. + + Parameters + ---------- + ctl + env + call_back + Called before/after each step() call. The env is passed to it. + """ + call_back(env) i = 0 while not env.dones['__all__'] and i <= env._max_episode_steps: for agent_id, agent in enumerate(env.agents): @@ -30,7 +33,6 @@ class ControllerFromTrainrunsReplayer(): obs, all_rewards, done, _ = env.step(actions) - if rendering: - renderer.render_env(show=True, show_observations=False, show_predictions=False) + call_back(env) i += 1 diff --git a/tests/test_action_plan.py b/tests/test_action_plan.py index 6caa9fc3cff4f4410ed1a32f043da810c40a92c7..9a03eb8a9f89e1edaac558e563a3c0544b4d6b5c 100644 --- a/tests/test_action_plan.py +++ b/tests/test_action_plan.py @@ -7,6 +7,7 @@ from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import rail_from_grid_transition_map from flatland.envs.rail_trainrun_data_structures import Waypoint from flatland.envs.schedule_generators import random_schedule_generator +from flatland.utils.rendertools import RenderTool, AgentRenderVariant from flatland.utils.simple_rail import make_simple_rail @@ -84,4 +85,16 @@ def test_action_plan(rendering: bool = False): deterministic_controller = ControllerFromTrainruns(env, chosen_path_dict) deterministic_controller.print_action_plan() ControllerFromTrainruns.assert_actions_plans_equal(expected_action_plan, deterministic_controller.action_plan) - ControllerFromTrainrunsReplayer.replay_verify(deterministic_controller, env, rendering) + if rendering: + renderer = RenderTool(env, gl="PILSVG", + agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX, + show_debug=True, + clear_debug_text=True, + screen_height=1000, + screen_width=1000) + + def render(*argv): + if rendering: + renderer.render_env(show=True, show_observations=False, show_predictions=False) + + ControllerFromTrainrunsReplayer.replay_verify(deterministic_controller, env, call_back=render)