test_utils.py 4.28 KB
Newer Older
u214892's avatar
u214892 committed
1
"""Test Utils."""
u214892's avatar
u214892 committed
2
from typing import List, Tuple, Optional
u214892's avatar
u214892 committed
3

4
import numpy as np
u214892's avatar
u214892 committed
5
6
from attr import attrs, attrib

u214892's avatar
u214892 committed
7
from flatland.core.grid.grid4 import Grid4TransitionsEnum
u214892's avatar
u214892 committed
8
from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
9
10
from flatland.envs.rail_env import RailEnvActions, RailEnv
from flatland.utils.rendertools import RenderTool
u214892's avatar
u214892 committed
11
12
13
14


@attrs
class Replay(object):
    """Expected state and input of a single agent at a single replay step.

    Instances are consumed by ``run_replay_config``: position, direction,
    malfunction and (if given) status are compared with the agent before the
    environment is stepped; the reward is compared with the value returned by
    the step.
    """
    # Expected agent position before the step; for the first step it is also
    # written to the agent as its initial position.
    position = attrib(type=Tuple[int, int])
    # Expected agent direction before the step; for the first step it is also
    # written to the agent as its initial direction.
    direction = attrib(type=Grid4TransitionsEnum)
    # Action submitted for the agent in this step. ``None`` means no action is
    # submitted, and the agent is then expected to have action_required False.
    action = attrib(type=RailEnvActions)
    # Expected value of agent.malfunction_data['malfunction'] before the step
    # (checked after set_malfunction has been applied).
    malfunction = attrib(default=0, type=int)
    # When not None, written to agent.malfunction_data['malfunction'] before
    # the malfunction value is checked.
    set_malfunction = attrib(default=None, type=Optional[int])
    # Expected reward for the agent as returned by the step.
    reward = attrib(default=None, type=Optional[float])
    # Expected agent status before the step; not checked when None.
    status = attrib(default=None, type=Optional[RailAgentStatus])
u214892's avatar
u214892 committed
22
23
24


@attrs
class ReplayConfig(object):
    """Replay script for one agent, as consumed by ``run_replay_config``."""
    # One entry per environment step, in execution order; entry 0 also
    # provides the agent's initial position and direction.
    replay = attrib(type=List[Replay])
    # Written to agent.target before the first step.
    target = attrib(type=Tuple[int, int])
    # Written to agent.speed_data['speed'] before the first step.
    speed = attrib(type=float)
29
30


31
32
def set_penalties_for_replay(env: RailEnv):
    """Give the environment distinct, non-zero penalties for replay tests.

    Ensures that env is working correctly with start/stop/invalid-action
    penalties different from 0. Each penalty gets its own value so that the
    rewards checked in a replay show which penalties were applied.

    Parameters
    ----------
    env
        Environment whose ``step_penalty``, ``start_penalty``,
        ``stop_penalty`` and ``invalid_action_penalty`` attributes are
        overwritten in place.
    """
    env.step_penalty = -7
    env.start_penalty = -13
    env.stop_penalty = -19
    env.invalid_action_penalty = -29
37
38


39
40
41
42
43
44
45
46
47
48
49
50
def _assert_replay_value(step, a, actual, expected, msg):
    """Assert that ``actual`` is numerically close to ``expected`` for agent ``a`` at ``step``."""
    assert np.allclose(actual, expected), "[{}] agent {} {}:  actual={}, expected={}".format(
        step, a, msg, actual, expected)


def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: bool = False):
    """
    Runs the replay configs and checks assertions.

    *Initially*
    - the position, direction, target and speed of the initial step are taken to initialize the agents

    *Before each step*
    - position, direction before step are verified
    - status is verified (optionally)
    - action must only be provided if action_required from previous step (initially all True)
    - optionally, set_malfunction is applied
    - malfunction is verified

    *After each step*
    - reward is verified after step (skipped if the replay step specifies no reward)

    Parameters
    ----------
    env
        Environment to step; ``env.agents[i]`` is driven by ``test_configs[i]``.
    test_configs
        One replay configuration per agent. The number of steps is taken from the first configuration.
    rendering
        If True, the environment is rendered once up front and again after every step.

    Raises
    ------
    AssertionError
        If a verified value differs from the expected one.
    """
    if rendering:
        renderer = RenderTool(env)
        renderer.render_env(show=True, frames=False, show_observations=False)

    # Before the first step every agent is expected to require an action.
    info_dict = {
        'action_required': [True for _ in test_configs]
    }

    # The first configuration determines how many steps are replayed for all agents.
    num_steps = len(test_configs[0].replay)

    if num_steps > 0:
        # One-time setup: place each agent as described by its first replay entry.
        for a, test_config in enumerate(test_configs):
            agent: EnvAgent = env.agents[a]
            initial = test_config.replay[0]
            agent.position = initial.position
            agent.direction = initial.direction
            agent.target = test_config.target
            agent.speed_data['speed'] = test_config.speed

    for step in range(num_steps):
        action_dict = {}

        for a, test_config in enumerate(test_configs):
            agent: EnvAgent = env.agents[a]
            replay = test_config.replay[step]

            _assert_replay_value(step, a, agent.position, replay.position, 'position')
            _assert_replay_value(step, a, agent.direction, replay.direction, 'direction')
            if replay.status is not None:
                _assert_replay_value(step, a, agent.status, replay.status, 'status')

            # An action may be given if and only if the previous step asked for one.
            action_required = info_dict['action_required'][a]
            if replay.action is not None:
                assert action_required, "[{}] agent {} expecting action_required={}".format(
                    step, a, True)
                action_dict[a] = replay.action
            else:
                assert not action_required, "[{}] agent {} expecting action_required={}".format(
                    step, a, False)

            if replay.set_malfunction is not None:
                agent.malfunction_data['malfunction'] = replay.set_malfunction
                agent.malfunction_data['moving_before_malfunction'] = agent.moving
            _assert_replay_value(step, a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')

        _, rewards_dict, _, info_dict = env.step(action_dict)
        if rendering:
            renderer.render_env(show=True, show_observations=True)

        for a, test_config in enumerate(test_configs):
            replay = test_config.replay[step]
            # The expected reward is optional; np.allclose cannot compare against None.
            if replay.reward is not None:
                _assert_replay_value(step, a, rewards_dict[a], replay.reward, 'reward')