# action_plan_player.py
from typing import Callable

3
4
5
6
from flatland.action_plan.action_plan import ControllerFromTrainruns
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_trainrun_data_structures import Waypoint

7
8
# Type of the optional hook accepted by the replayer: a callable that takes the
# `RailEnv` and returns nothing (e.g. to render the current state of the env).
ControllerFromTrainrunsReplayerRenderCallback = Callable[[RailEnv], None]

9
10
11
12
13

class ControllerFromTrainrunsReplayer:
    """Allows to verify a `DeterministicController` by replaying it against a FLATland env without malfunction."""

    @staticmethod
    def replay_verify(ctl: ControllerFromTrainruns, env: RailEnv,
                      call_back: ControllerFromTrainrunsReplayerRenderCallback = lambda *a, **k: None):
        """Replays this deterministic `ActionPlan` and verifies whether it is feasible.

        Before every step, the position of each agent in the env is compared
        with the position of the waypoint the controller reports for that agent
        and step. Then the controller's actions for the step are applied to the
        env.

        Parameters
        ----------
        ctl
            Controller under verification; must provide
            `get_waypoint_before_or_at_step(agent_id, step)` and `act(step)`.
        env
            Environment to replay against. It is stepped in place until
            `env.dones['__all__']` is set or the step counter exceeds
            `env._max_episode_steps`.
        call_back
            Called once before the first step and once after each step() call.
            The env is passed to it.

        Raises
        ------
        AssertionError
            If an agent is not at the position expected by the controller.
        """
        call_back(env)

        i = 0
        while not env.dones['__all__'] and i <= env._max_episode_steps:
            for agent_id, agent in enumerate(env.agents):
                waypoint: Waypoint = ctl.get_waypoint_before_or_at_step(agent_id, i)
                # Raise explicitly instead of using `assert`: the check is the whole
                # purpose of this method and must not vanish under `python -O`.
                if agent.position != waypoint.position:
                    raise AssertionError(
                        "before {}, agent {} at {}, expected {}".format(i, agent_id, agent.position,
                                                                        waypoint.position))

            actions = ctl.act(i)
            print("actions for {}: {}".format(i, actions))

            # The return value of step() is not needed; progress is read from env.dones.
            env.step(actions)

            call_back(env)

            i += 1