test_flatland_rail_agent_status.py
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_simple_rail
from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay


def test_initial_status():
    """Test that agent lifecycle works correctly ready-to-depart -> active -> done."""
    rail, rail_map, optionals = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(), number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                  remove_agents_at_target=False)
    env.reset()

    # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
    for _ in range(max([agent.earliest_departure for agent in env.agents])):
        env.step({})  # DO_NOTHING for all agents

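    # Each Replay entry below describes the agent's expected state *before* the
    # step (position, direction, status), the action to submit, and the reward
    # expected for that step; run_replay_config from test_utils is assumed to
    # replay the actions and assert these expectations against the environment.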
    set_penalties_for_replay(env)
    test_config = ReplayConfig(
        replay=[
            Replay(
                position=None,  # not entered grid yet
                direction=Grid4TransitionsEnum.EAST,
                status=RailAgentStatus.READY_TO_DEPART,
                action=RailEnvActions.DO_NOTHING,
                reward=env.step_penalty * 0.5,
            ),
            Replay(
                position=None,  # not entered grid yet before step
                direction=Grid4TransitionsEnum.EAST,
                status=RailAgentStatus.READY_TO_DEPART,
                action=RailEnvActions.MOVE_LEFT,
                reward=env.step_penalty * 0.5,  # auto-correction left to forward without penalty!
            ),
            Replay(
                position=(3, 9),
                direction=Grid4TransitionsEnum.EAST,
                status=RailAgentStatus.ACTIVE,
                action=RailEnvActions.MOVE_LEFT,
                reward=env.start_penalty + env.step_penalty * 0.5,  # running at speed 0.5
            ),
            Replay(
                position=(3, 9),
                direction=Grid4TransitionsEnum.EAST,
                status=RailAgentStatus.ACTIVE,
                action=None,
                reward=env.step_penalty * 0.5,  # running at speed 0.5
            ),
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                status=RailAgentStatus.ACTIVE,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5,  # running at speed 0.5
            ),
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                status=RailAgentStatus.ACTIVE,
                action=None,
                reward=env.step_penalty * 0.5,  # running at speed 0.5
            ),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5,  # running at speed 0.5
                status=RailAgentStatus.ACTIVE
            ),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5,  # running at speed 0.5
                status=RailAgentStatus.ACTIVE
            ),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_RIGHT,
                reward=env.step_penalty * 0.5,  # running at speed 0.5
                status=RailAgentStatus.ACTIVE
            ),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.global_reward,  # reaches target (3, 5) in this step
                status=RailAgentStatus.ACTIVE
            ),
            Replay(
                position=(3, 5),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.global_reward,  # already done
                status=RailAgentStatus.DONE
            ),
            Replay(
                position=(3, 5),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.global_reward,  # already done
                status=RailAgentStatus.DONE
            )

        ],
        initial_position=(3, 9),  # east dead-end
        initial_direction=Grid4TransitionsEnum.EAST,
        target=(3, 5),
        speed=0.5
    )

    run_replay_config(env, [test_config], activate_agents=False)


def test_status_done_remove():
    """Test that agent lifecycle works correctly ready-to-depart -> active -> done."""
    rail, rail_map, optionals = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(), number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                  remove_agents_at_target=True)
    env.reset()

    # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
    for _ in range(max([agent.earliest_departure for agent in env.agents])):
        env.step({})  # DO_NOTHING for all agents

    set_penalties_for_replay(env)
    test_config = ReplayConfig(
        replay=[
            Replay(
                position=None,  # not entered grid yet
                direction=Grid4TransitionsEnum.EAST,
                status=RailAgentStatus.READY_TO_DEPART,
                action=RailEnvActions.DO_NOTHING,
                reward=env.step_penalty * 0.5,
            ),
            Replay(
                position=None,  # not entered grid yet before step
                direction=Grid4TransitionsEnum.EAST,
                status=RailAgentStatus.READY_TO_DEPART,
                action=RailEnvActions.MOVE_LEFT,
                reward=env.step_penalty * 0.5,  # auto-correction left to forward without penalty!
            ),
            Replay(
                position=(3, 9),
                direction=Grid4TransitionsEnum.EAST,
                status=RailAgentStatus.ACTIVE,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.start_penalty + env.step_penalty * 0.5,  # running at speed 0.5
            ),
            Replay(
                position=(3, 9),
                direction=Grid4TransitionsEnum.EAST,
                status=RailAgentStatus.ACTIVE,
                action=None,
                reward=env.step_penalty * 0.5,  # running at speed 0.5
            ),
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                status=RailAgentStatus.ACTIVE,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5,  # running at speed 0.5
            ),
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                status=RailAgentStatus.ACTIVE,
                action=None,
                reward=env.step_penalty * 0.5,  # running at speed 0.5

            ),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_RIGHT,
                reward=env.step_penalty * 0.5,  # running at speed 0.5
                status=RailAgentStatus.ACTIVE
            ),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5,  # running at speed 0.5
                status=RailAgentStatus.ACTIVE
            ),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5,  # done
                status=RailAgentStatus.ACTIVE
            ),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.global_reward,  # already done
                status=RailAgentStatus.ACTIVE
            ),
            Replay(
                position=None,
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.global_reward,  # already done
                status=RailAgentStatus.DONE_REMOVED
            ),
            Replay(
                position=None,
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.global_reward,  # already done
                status=RailAgentStatus.DONE_REMOVED
            )

        ],
        initial_position=(3, 9),  # east dead-end
        initial_direction=Grid4TransitionsEnum.EAST,
        target=(3, 5),
        speed=0.5
    )

    run_replay_config(env, [test_config], activate_agents=False)