"""Test Utils."""
from typing import List, Tuple, Optional

import numpy as np
from attr import attrs, attrib

from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.malfunction_generators import MalfunctionParameters, malfunction_from_params
from flatland.envs.rail_env import RailEnvActions, RailEnv
from flatland.envs.rail_generators import RailGenerator
from flatland.envs.line_generators import LineGenerator
from flatland.utils.rendertools import RenderTool
from flatland.envs.persistence import RailEnvPersister
from flatland.envs.step_utils.states import TrainState
from flatland.envs.step_utils.speed_counter import SpeedCounter

@attrs
class Replay(object):
    position = attrib(type=Tuple[int, int])
    direction = attrib(type=Grid4TransitionsEnum)
    action = attrib(type=RailEnvActions)
    malfunction = attrib(default=0, type=int)
    set_malfunction = attrib(default=None, type=Optional[int])
    reward = attrib(default=None, type=Optional[float])
    state = attrib(default=None, type=Optional[TrainState])


@attrs
class ReplayConfig(object):
    replay = attrib(type=List[Replay])
    target = attrib(type=Tuple[int, int])
    speed = attrib(type=float)
    initial_position = attrib(type=Tuple[int, int])
    initial_direction = attrib(type=Grid4TransitionsEnum)
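

def _example_replay_config() -> ReplayConfig:
    # Illustrative sketch (not part of the original test suite): how a
    # single-agent, single-step ReplayConfig might be assembled. The
    # coordinates, direction, action and reward below are hypothetical and
    # assume a rail grid on which they are valid.
    return ReplayConfig(
        replay=[
            Replay(position=(3, 4), direction=Grid4TransitionsEnum.EAST,
                   action=RailEnvActions.MOVE_FORWARD, reward=-1),
        ],
        target=(3, 8),
        speed=1.0,
        initial_position=(3, 4),
        initial_direction=Grid4TransitionsEnum.EAST,
    )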


# ensure that the env is working correctly with start/stop/invalid-action penalties different from 0
def set_penalties_for_replay(env: RailEnv):
    env.step_penalty = -7
    env.start_penalty = -13
    env.stop_penalty = -19
    env.invalid_action_penalty = -29
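

# Typical usage (a hypothetical sketch, not taken from an actual test):
# apply the non-zero penalties right after constructing the environment and
# before replaying, so that the expected rewards in a ReplayConfig can
# distinguish the step/start/stop/invalid-action penalties.
#
#   env = RailEnv(...)  # some small test environment
#   set_penalties_for_replay(env)
#   run_replay_config(env, [my_replay_config])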


def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: bool = False, activate_agents=True, 
                      skip_reward_check=False, set_ready_to_depart=False, skip_action_required_check=False):
    """
    Runs the replay configs and checks assertions.

    *Initially*
    - The `initial_position`, `initial_direction`, `target` and `speed` are taken from the `ReplayConfig` to initialize the agents.

    *Before each step*
    - `position` is verified
    - `direction` is verified
    - `state` is verified (optionally, only if not `None` in `Replay`)
    - `set_malfunction` is applied (optionally, only if not `None` in `Replay`)
    - `malfunction` is verified
    - `action` must only be provided if `action_required` was set in the previous step (initially all `True`)

    *Step*
    - performed with the given `action`

    *After each step*
    - `reward` is verified after step


    Parameters
    ----------
    activate_agents: should the agents be activated directly when the environment is initially set up by `reset()`?
    env: the environment; it is `reset()` to set the agents' initial position, direction, target and speed
    test_configs: the `ReplayConfig`s, one for each agent
    rendering: should the replay be rendered?
    skip_reward_check: if `True`, the per-step reward assertions are skipped
    set_ready_to_depart: if `True`, all agents are forced into state `READY_TO_DEPART` (with `earliest_departure = 0`) after `reset()`
    skip_action_required_check: if `True`, the `action_required` consistency assertions are skipped
    """
    if rendering:
        renderer = RenderTool(env)
        renderer.render_env(show=True, frames=False, show_observations=False)
    info_dict = {
        'action_required': [True for _ in test_configs]
    }

    for step in range(len(test_configs[0].replay)):
        if step == 0:
            for a, test_config in enumerate(test_configs):
                agent: EnvAgent = env.agents[a]
                # set the agent's initial position, direction, target and speed from the test config
                agent.initial_position = test_config.initial_position
                agent.initial_direction = test_config.initial_direction
                agent.direction = test_config.initial_direction
                agent.target = test_config.target
                agent.speed_counter = SpeedCounter(speed=test_config.speed)
            env.reset(False, False)

            if set_ready_to_depart:
                # Set all agents to ready to depart
                for i_agent in range(len(env.agents)):
                    env.agents[i_agent].earliest_departure = 0
                    env.agents[i_agent]._set_state(TrainState.READY_TO_DEPART)

            elif activate_agents:
                for a_idx in range(len(env.agents)):
                    env.agents[a_idx].position = env.agents[a_idx].initial_position
                    env.agents[a_idx]._set_state(TrainState.MOVING)

        def _assert(a, actual, expected, msg):
            print("[{}] verifying {} on agent {}: actual={}, expected={}".format(step, msg, a, actual, expected))
            assert (actual == expected) or (
                np.allclose(actual, expected)), "[{}] agent {} {}:  actual={}, expected={}".format(step, a, msg,
                                                                                                   actual,
                                                                                                   expected)

        action_dict = {}

        for a, test_config in enumerate(test_configs):
            agent: EnvAgent = env.agents[a]
            replay = test_config.replay[step]
            _assert(a, agent.position, replay.position, 'position')
            _assert(a, agent.direction, replay.direction, 'direction')
            if replay.state is not None:
                _assert(a, agent.state, replay.state, 'state')

            if replay.action is not None:
                if not skip_action_required_check:
                    assert info_dict['action_required'][a] or agent.state == TrainState.READY_TO_DEPART, \
                        "[{}] agent {} expecting action_required={} or agent state READY_TO_DEPART".format(step, a, True)
                action_dict[a] = replay.action
            else:
                if not skip_action_required_check:
                    assert not info_dict['action_required'][a], \
                        "[{}] agent {} expecting action_required={}, but found {}".format(
                            step, a, False, info_dict['action_required'][a])

            if replay.set_malfunction is not None:
                # Force the malfunction by setting the handler's down counter
                # directly, so the env treats the agent as malfunctioning for the
                # requested number of steps without relying on the stochastic
                # malfunction generator (which would interfere with the tests).
                env.agents[a].malfunction_handler._set_malfunction_down_counter(replay.set_malfunction)
            _assert(a, agent.malfunction_handler.malfunction_down_counter, replay.malfunction, 'malfunction')
        print("[{}] stepping env".format(step))
        _, rewards_dict, _, info_dict = env.step(action_dict)
        if rendering:
            renderer.render_env(show=True, show_observations=True)

        for a, test_config in enumerate(test_configs):
            replay = test_config.replay[step]

            if not skip_reward_check:
                _assert(a, rewards_dict[a], replay.reward, 'reward')
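

# Hypothetical end-to-end usage of run_replay_config (a sketch, not an
# actual test): build a small env, describe each agent's expected trajectory
# as a ReplayConfig, then replay and assert step by step. my_rail_generator,
# my_line_generator and my_replay_config are assumed to exist.
#
#   env = RailEnv(width=30, height=30,
#                 rail_generator=my_rail_generator,
#                 line_generator=my_line_generator,
#                 number_of_agents=1)
#   env.reset()
#   set_penalties_for_replay(env)
#   run_replay_config(env, [my_replay_config])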

def create_and_save_env(file_name: str, line_generator: LineGenerator, rail_generator: RailGenerator):
    """Create a 30x30 `RailEnv` with 10 agents and stochastic malfunctions, save it under `file_name` and return it."""
    stochastic_data = MalfunctionParameters(malfunction_rate=1000,  # Rate of malfunction occurrence
                                            min_duration=15,  # Minimal duration of malfunction
                                            max_duration=50  # Max duration of malfunction
                                            )

    env = RailEnv(width=30,
                  height=30,
                  rail_generator=rail_generator,
                  line_generator=line_generator,
                  number_of_agents=10,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  remove_agents_at_target=True)
    env.reset(True, True)
    RailEnvPersister.save(env, file_name)
    return env
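

# Hypothetical usage (a sketch): persist a reproducible environment to disk.
# sparse_rail_generator and sparse_line_generator are flatland's standard
# generators; the file name is arbitrary.
#
#   from flatland.envs.rail_generators import sparse_rail_generator
#   from flatland.envs.line_generators import sparse_line_generator
#   env = create_and_save_env("test_env.pkl", sparse_line_generator(), sparse_rail_generator())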