test_flatland_malfunction.py 21.1 KB
Newer Older
u214892's avatar
u214892 committed
1
import random
2
from typing import Dict, List
u214892's avatar
u214892 committed
3

4
5
import numpy as np

6
from flatland.core.env_observation_builder import ObservationBuilder
u214892's avatar
u214892 committed
7
from flatland.core.grid.grid4 import Grid4TransitionsEnum
8
from flatland.core.grid.grid4_utils import get_new_position
u214892's avatar
u214892 committed
9
from flatland.envs.agent_utils import RailAgentStatus
u214892's avatar
u214892 committed
10
11
12
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator
u214892's avatar
u214892 committed
13
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
14
15


16
class SingleAgentNavigationObs(ObservationBuilder):
    """
    We build a representation vector with 3 binary components, indicating which of the 3 available directions
    for each agent (Left, Forward, Right) lead to the shortest path to its target.
    E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
    will be [1, 0, 0].
    """

    def __init__(self):
        super().__init__()

    def reset(self):
        # No per-episode state to reset for this observation builder.
        pass

    def get(self, handle: int = 0) -> List[int]:
        """Return a one-hot [left, forward, right] vector for the agent's shortest-path direction.

        Parameters
        ----------
        handle : int
            Index of the agent in ``self.env.agents``.

        Returns
        -------
        List[int]
            Three binary components; the one set to 1 marks the branch with the
            minimal distance to the agent's target.
        """
        agent = self.env.agents[handle]

        possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
        num_transitions = np.count_nonzero(possible_transitions)

        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right], relative to the current orientation
        # If only one transition is possible, the forward branch is aligned with it.
        if num_transitions == 1:
            observation = [0, 1, 0]
        else:
            min_distances = []
            for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                if possible_transitions[direction]:
                    new_position = get_new_position(agent.position, direction)
                    min_distances.append(
                        self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
                else:
                    min_distances.append(np.inf)

            observation = [0, 0, 0]
            # Bugfix: np.argmin over a 1-D sequence returns a scalar index;
            # the previous `np.argmin(min_distances)[0]` raised
            # "IndexError: invalid index to scalar variable".
            observation[np.argmin(min_distances)] = 1

        return observation


def test_malfunction_process():
    """Force a malfunction every 5th step and verify halt counting and down time."""
    # Fixed malfunction duration so the expected totals are deterministic.
    stochastic_data = {'prop_malfunction': 1.,
                       'malfunction_rate': 1000,
                       'min_duration': 3,
                       'max_duration': 3}
    np.random.seed(5)

    env = RailEnv(width=20,
                  height=20,
                  rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999,
                                                        seed=0),
                  schedule_generator=complex_schedule_generator(),
                  number_of_agents=2,
                  obs_builder_object=SingleAgentNavigationObs(),
                  stochastic_data=stochastic_data)

    obs = env.reset()

    # An initial malfunction countdown must have been assigned on reset.
    assert env.agents[0].malfunction_data['next_malfunction'] > 0

    halt_count = 0
    down_time = 0
    prev_position = env.agents[0].position

    for step in range(100):
        # Pick the action suggested by the one-hot navigation observation.
        actions = {i: np.argmax(obs[i]) + 1 for i in range(len(obs))}

        if step % 5 == 0:
            # Force agent 0 into a malfunction right now.
            env.agents[0].malfunction_data['malfunction'] = -1
            env.agents[0].malfunction_data['next_malfunction'] = 0
            halt_count += 1

        obs, _rewards, _done, _ = env.step(actions)

        is_malfunctioning = env.agents[0].malfunction_data['malfunction'] > 0

        if is_malfunctioning:
            # A malfunctioning agent must not move.
            assert prev_position == env.agents[0].position

        prev_position = env.agents[0].position
        down_time += env.agents[0].malfunction_data['malfunction']

    # The expected number of malfunctions was triggered
    assert env.agents[0].malfunction_data['nr_malfunctions'] == 21

    # Exactly 20 forced stops were performed
    assert halt_count == 20

    # The agent accumulated some down time while standing around
    assert down_time > 0

def test_malfunction_process_statistically():
    """Tests that malfunctions are produced by stochastic_data!"""
    # Fixed malfunction duration keeps the expected count deterministic.
    stochastic_data = {'prop_malfunction': 1.,
                       'malfunction_rate': 2,
                       'min_duration': 3,
                       'max_duration': 3}
    np.random.seed(5)
    random.seed(0)

    env = RailEnv(width=20,
                  height=20,
                  rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999,
                                                        seed=0),
                  schedule_generator=complex_schedule_generator(),
                  number_of_agents=2,
                  obs_builder_object=SingleAgentNavigationObs(),
                  stochastic_data=stochastic_data)

    env.reset()
    malfunction_count = 0
    for _ in range(100):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            if agent.malfunction_data['malfunction'] > 0:
                malfunction_count += 1
            # Pick a random action for every agent.
            action_dict[agent.handle] = RailEnvActions(np.random.randint(4))

        env.step(action_dict)

    # check that generation of malfunctions works as expected
    assert malfunction_count == 156, "nb_malfunction={}".format(malfunction_count)

def test_initial_malfunction():
    """Replay check: an agent that starts in malfunction stands still, then resumes."""
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    # Only fast passenger trains so every replayed step advances one cell.
    speed_ration_map = {1.: 1.,  # Fast passenger train
                        1. / 2.: 0.,  # Fast freight train
                        1. / 3.: 0.,  # Slow commuter train
                        1. / 4.: 0.}  # Slow freight train

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=sparse_rail_generator(num_cities=5,
                                                       # Number of cities in map (where train stations are)
                                                       num_intersections=4,
                                                       # Number of intersections (no start / target)
                                                       num_trainstations=25,  # Number of possible start/targets on map
                                                       min_node_dist=6,  # Minimal distance of nodes
                                                       node_radius=3,  # Proximity of stations to city center
                                                       num_neighb=3,
                                                       # Number of connections to other cities/intersections
                                                       seed=215545,  # Random seed
                                                       grid_mode=True,
                                                       enhance_intersection=False
                                                       ),
                  schedule_generator=sparse_schedule_generator(speed_ration_map),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  )
    set_penalties_for_replay(env)

    expected_steps = [
        # malfunction is forced for 3 steps: agent stays put with step penalty
        Replay(
            position=(28, 5),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            set_malfunction=3,
            malfunction=3,
            reward=env.step_penalty  # full step penalty when malfunctioning
        ),
        Replay(
            position=(28, 5),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=2,
            reward=env.step_penalty  # full step penalty when malfunctioning
        ),
        # malfunction stops in the next step and we're still at the beginning of the cell
        # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
        Replay(
            position=(28, 5),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=1,
            reward=env.start_penalty + env.step_penalty * 1.0
            # malfunctioning ends: starting and running at speed 1.0
        ),
        Replay(
            position=(28, 4),
            direction=Grid4TransitionsEnum.WEST,
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=0,
            reward=env.step_penalty * 1.0  # running at speed 1.0
        ),
        Replay(
            position=(27, 4),
            direction=Grid4TransitionsEnum.NORTH,
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=0,
            reward=env.step_penalty * 1.0  # running at speed 1.0
        )
    ]

    config = ReplayConfig(
        replay=expected_steps,
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target
    )
    run_replay_config(env, [config])

def test_initial_malfunction_stop_moving():
    """Replay check: STOP_MOVING after an initial malfunction keeps the agent in place."""
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    # All agents run at full speed in this scenario.
    speed_ration_map = {1.: 1.,  # Fast passenger train
                        1. / 2.: 0.,  # Fast freight train
                        1. / 3.: 0.,  # Slow commuter train
                        1. / 4.: 0.}  # Slow freight train

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=sparse_rail_generator(num_cities=5,
                                                       # Number of cities in map (where train stations are)
                                                       num_intersections=4,
                                                       # Number of intersections (no start / target)
                                                       num_trainstations=25,  # Number of possible start/targets on map
                                                       min_node_dist=6,  # Minimal distance of nodes
                                                       node_radius=3,  # Proximity of stations to city center
                                                       num_neighb=3,
                                                       # Number of connections to other cities/intersections
                                                       seed=215545,  # Random seed
                                                       grid_mode=True,
                                                       enhance_intersection=False
                                                       ),
                  schedule_generator=sparse_schedule_generator(speed_ration_map),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  )
    set_penalties_for_replay(env)

    expected_steps = [
        Replay(
            position=(28, 5),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            set_malfunction=3,
            malfunction=3,
            reward=env.step_penalty,  # full step penalty when stopped
            status=RailAgentStatus.READY_TO_DEPART
        ),
        Replay(
            position=(28, 5),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.DO_NOTHING,
            malfunction=2,
            reward=env.step_penalty,  # full step penalty when stopped
            status=RailAgentStatus.ACTIVE
        ),
        # malfunction stops in the next step and we're still at the beginning of the cell
        # --> if we take action STOP_MOVING, agent should restart without moving
        Replay(
            position=(28, 5),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.STOP_MOVING,
            malfunction=1,
            reward=env.step_penalty,  # full step penalty while stopped
            status=RailAgentStatus.ACTIVE
        ),
        # we have stopped and do nothing --> should stand still
        Replay(
            position=(28, 5),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.DO_NOTHING,
            malfunction=0,
            reward=env.step_penalty,  # full step penalty while stopped
            status=RailAgentStatus.ACTIVE
        ),
        # we start to move forward --> should go to next cell now
        Replay(
            position=(28, 5),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=0,
            reward=env.start_penalty + env.step_penalty * 1.0,  # full step penalty while stopped
            status=RailAgentStatus.ACTIVE
        ),
        Replay(
            position=(28, 4),
            direction=Grid4TransitionsEnum.WEST,
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=0,
            reward=env.step_penalty * 1.0,  # full step penalty while stopped
            status=RailAgentStatus.ACTIVE
        )
    ]

    config = ReplayConfig(
        replay=expected_steps,
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target
    )

    run_replay_config(env, [config])

def test_initial_malfunction_do_nothing():
    """Replay check: DO_NOTHING after an initial malfunction keeps the agent in place."""
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    # All agents run at full speed in this scenario.
    speed_ration_map = {1.: 1.,  # Fast passenger train
                        1. / 2.: 0.,  # Fast freight train
                        1. / 3.: 0.,  # Slow commuter train
                        1. / 4.: 0.}  # Slow freight train

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=sparse_rail_generator(num_cities=5,
                                                       # Number of cities in map (where train stations are)
                                                       num_intersections=4,
                                                       # Number of intersections (no start / target)
                                                       num_trainstations=25,  # Number of possible start/targets on map
                                                       min_node_dist=6,  # Minimal distance of nodes
                                                       node_radius=3,  # Proximity of stations to city center
                                                       num_neighb=3,
                                                       # Number of connections to other cities/intersections
                                                       seed=215545,  # Random seed
                                                       grid_mode=True,
                                                       enhance_intersection=False
                                                       ),
                  schedule_generator=sparse_schedule_generator(speed_ration_map),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  )
    set_penalties_for_replay(env)

    expected_steps = [
        Replay(
            position=(28, 5),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            set_malfunction=3,
            malfunction=3,
            reward=env.step_penalty,  # full step penalty while malfunctioning
            status=RailAgentStatus.READY_TO_DEPART
        ),
        Replay(
            position=(28, 5),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.DO_NOTHING,
            malfunction=2,
            reward=env.step_penalty,  # full step penalty while malfunctioning
            status=RailAgentStatus.ACTIVE
        ),
        # malfunction stops in the next step and we're still at the beginning of the cell
        # --> if we take action DO_NOTHING, agent should restart without moving
        Replay(
            position=(28, 5),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.DO_NOTHING,
            malfunction=1,
            reward=env.step_penalty,  # full step penalty while stopped
            status=RailAgentStatus.ACTIVE
        ),
        # we haven't started moving yet --> stay here
        Replay(
            position=(28, 5),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.DO_NOTHING,
            malfunction=0,
            reward=env.step_penalty,  # full step penalty while stopped
            status=RailAgentStatus.ACTIVE
        ),
        # we start to move forward --> should go to next cell now
        Replay(
            position=(28, 5),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=0,
            reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
            status=RailAgentStatus.ACTIVE
        ),
        Replay(
            position=(28, 4),
            direction=Grid4TransitionsEnum.WEST,
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=0,
            reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
            status=RailAgentStatus.ACTIVE
        )
    ]

    config = ReplayConfig(
        replay=expected_steps,
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target
    )

    run_replay_config(env, [config])

def test_initial_nextmalfunction_not_below_zero():
    """Regression test: the next_malfunction countdown must never drop below zero."""
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 0.5,  # Rate of malfunction occurence
                       'min_duration': 5,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    speed_ration_map = {1.: 1.,  # Fast passenger train
                        1. / 2.: 0.,  # Fast freight train
                        1. / 3.: 0.,  # Slow commuter train
                        1. / 4.: 0.}  # Slow freight train

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=sparse_rail_generator(num_cities=5,
                                                       # Number of cities in map (where train stations are)
                                                       num_intersections=4,
                                                       # Number of intersections (no start / target)
                                                       num_trainstations=25,  # Number of possible start/targets on map
                                                       min_node_dist=6,  # Minimal distance of nodes
                                                       node_radius=3,  # Proximity of stations to city center
                                                       num_neighb=3,
                                                       # Number of connections to other cities/intersections
                                                       seed=215545,  # Random seed
                                                       grid_mode=True,
                                                       enhance_intersection=False
                                                       ),
                  schedule_generator=sparse_schedule_generator(speed_ration_map),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  )
    first_agent = env.agents[0]
    env.step({})
    # next_malfunction was -1 before the bugfix https://gitlab.aicrowd.com/flatland/flatland/issues/186
    assert first_agent.malfunction_data['next_malfunction'] >= 0, \
        "next_malfunction should be >=0, found {}".format(first_agent.malfunction_data['next_malfunction'])