"""Test utilities: replay scripted per-agent steps on a RailEnv and assert on the outcome."""

from typing import List, Tuple, Optional

from attr import attrs, attrib

from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.rail_env import RailEnvActions, RailEnv
from flatland.utils.rendertools import RenderTool


@attrs
class Replay(object):
    """Scripted input and expected state for one agent at one step of a replay.

    Consumed by ``run_replay_config``. Position, direction and malfunction are
    checked against the agent before ``env.step``; the reward is checked after it.
    """

    # Expected agent position before the step; the first entry also sets the initial position.
    position = attrib(type=Tuple[int, int])
    # Expected agent direction before the step; the first entry also sets the initial direction.
    direction = attrib(type=Grid4TransitionsEnum)
    # Action submitted for this step, or None when the env is expected NOT to require an action.
    action = attrib(type=Optional[RailEnvActions])
    # Expected agent.malfunction_data['malfunction'] before the step
    # (checked after set_malfunction, if given, has been applied).
    malfunction = attrib(default=0, type=int)
    # If not None, written into agent.malfunction_data['malfunction'] before the step.
    set_malfunction = attrib(default=None, type=Optional[int])
    # Expected reward for this agent after the step. It is always compared, so the
    # default None only matches if the env reports a reward of None.
    reward = attrib(default=None, type=Optional[float])


@attrs
class ReplayConfig(object):
    """Scripted replay for a single agent, as consumed by ``run_replay_config``.

    Decorated with ``@attrs`` (like ``Replay``) so the fields below become
    per-instance attributes settable through the generated ``__init__``.
    """

    # One Replay per env step; replay[0] also provides the agent's initial position and direction.
    replay = attrib(type=List[Replay])
    # Assigned to agent.target before the replay starts.
    target = attrib(type=Tuple[int, int])
    # Assigned to agent.speed_data['speed'] before the replay starts.
    speed = attrib(type=float)


def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: bool = False):
    """
    Runs the replay configs and checks assertions.

    ``test_configs[a]`` drives ``env.agents[a]``. All configs must have replays of
    the same length, one entry per environment step.

    *Initially*
    - the position, direction, target and speed of the initial step are taken to initialize the agents

    *Before each step*
    - action must only be provided if action_required from previous step (initially all True)
    - position, direction before step are verified
    - optionally, set_malfunction is applied
    - malfunction is verified

    *After each step*
    - reward is verified after step

    Parameters
    ----------
    env
        The environment to drive; must contain at least ``len(test_configs)`` agents.
    test_configs
        One replay config per agent, indexed like ``env.agents``.
    rendering
        If True, render the environment once before replaying and once after the last step.

    Raises
    ------
    AssertionError
        If the replays differ in length, or any verified value differs from the expected one.
    """
    if rendering:
        renderer = RenderTool(env)
        renderer.render_env(show=True, frames=False, show_observations=False)

    # Agents advance in lock-step, so every replay must cover the same number of steps.
    # Otherwise a shorter replay would fail with a bare IndexError, and the extra
    # steps of a longer one would silently never be verified.
    n_steps = len(test_configs[0].replay) if test_configs else 0
    for a, test_config in enumerate(test_configs):
        assert len(test_config.replay) == n_steps, \
            "agent {} replay has {} steps, but agent 0 has {}".format(a, len(test_config.replay), n_steps)

    if n_steps > 0:
        _init_agents_from_replay(env, test_configs)

    # Before the very first step every agent is expected to require an action.
    info_dict = {
        'action_required': [True for _ in test_configs]
    }
    for step in range(n_steps):
        action_dict = {}

        for a, test_config in enumerate(test_configs):
            agent: EnvAgent = env.agents[a]
            replay = test_config.replay[step]

            _assert_replay_value(step, a, agent.position, replay.position, 'position')
            _assert_replay_value(step, a, agent.direction, replay.direction, 'direction')

            # An action must be scripted exactly when the previous step (or the
            # initial state) reported that this agent requires one.
            action_required = info_dict['action_required'][a]
            if replay.action is not None:
                assert action_required, "[{}] agent {} expecting action_required={}".format(step, a, True)
                action_dict[a] = replay.action
            else:
                assert not action_required, "[{}] agent {} expecting action_required={}".format(step, a, False)

            if replay.set_malfunction is not None:
                # Inject a scripted malfunction and record whether the agent was moving when it began.
                agent.malfunction_data['malfunction'] = replay.set_malfunction
                agent.malfunction_data['moving_before_malfunction'] = agent.moving
            _assert_replay_value(step, a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')

        _, rewards_dict, _, info_dict = env.step(action_dict)

        for a, test_config in enumerate(test_configs):
            _assert_replay_value(step, a, rewards_dict[a], test_config.replay[step].reward, 'reward')

    if rendering:
        renderer.render_env(show=True, show_observations=True)


def _init_agents_from_replay(env: RailEnv, test_configs: List[ReplayConfig]):
    """Set each agent's position/direction from its first replay entry, and its target and speed from its config."""
    for a, test_config in enumerate(test_configs):
        agent: EnvAgent = env.agents[a]
        first = test_config.replay[0]
        agent.position = first.position
        agent.direction = first.direction
        agent.target = test_config.target
        agent.speed_data['speed'] = test_config.speed


def _assert_replay_value(step, a, actual, expected, msg):
    """Assert ``actual == expected`` for agent ``a`` at ``step``, naming the checked quantity ``msg``."""
    assert actual == expected, "[{}] agent {} {}:  actual={}, expected={}".format(step, a, msg, actual,
                                                                                  expected)