test_flatland_malfunction.py 21.3 KB
Newer Older
u214892's avatar
u214892 committed
1
import random
2
from typing import Dict, List
u214892's avatar
u214892 committed
3

4
5
import numpy as np

6
from flatland.core.env_observation_builder import ObservationBuilder
u214892's avatar
u214892 committed
7
from flatland.core.grid.grid4 import Grid4TransitionsEnum
8
from flatland.core.grid.grid4_utils import get_new_position
u214892's avatar
u214892 committed
9
from flatland.envs.agent_utils import RailAgentStatus
u214892's avatar
u214892 committed
10
11
12
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator
u214892's avatar
u214892 committed
13
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
14
15


16
class SingleAgentNavigationObs(ObservationBuilder):
    """
    We build a representation vector with 3 binary components, indicating which of the 3 available directions
    for each agent (Left, Forward, Right) lead to the shortest path to its target.
    E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
    will be [1, 0, 0].
    """

    def __init__(self):
        super().__init__()

    def reset(self):
        # Stateless builder: nothing to reset between episodes.
        pass

    def get(self, handle: int = 0) -> List[int]:
        """Return the one-hot [left, forward, right] shortest-path indicator for agent `handle`."""
        agent = self.env.agents[handle]

        possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
        num_transitions = np.count_nonzero(possible_transitions)

        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right], relative to the current orientation
        # If only one transition is possible, the forward branch is aligned with it.
        if num_transitions == 1:
            observation = [0, 1, 0]
        else:
            min_distances = []
            # (agent.direction - 1, agent.direction, agent.direction + 1) mod 4 == left, forward, right
            for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                if possible_transitions[direction]:
                    new_position = get_new_position(agent.position, direction)
                    min_distances.append(
                        self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
                else:
                    # Blocked branch: infinite distance so argmin never selects it.
                    min_distances.append(np.inf)

            observation = [0, 0, 0]
            # BUG FIX: np.argmin over a 1-D sequence returns a scalar index, not an
            # array, so the original `np.argmin(...)[0]` raised
            # "IndexError: invalid index to scalar variable" on this branch.
            observation[int(np.argmin(min_distances))] = 1

        return observation


def test_malfunction_process():
    """Force a malfunction every 5 steps and check the agent stands still while broken.

    Uses fixed seeds, so the expected malfunction count (21) and halt count (20)
    are pinned to this exact environment construction order — do not reorder.
    """
    # Set fixed malfunction duration for this test
    stochastic_data = {'prop_malfunction': 1.,
                       'malfunction_rate': 1000,
                       'min_duration': 3,
                       'max_duration': 3}

    np.random.seed(5)

    env = RailEnv(width=20,
                  height=20,
                  rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999,
                                                        seed=0),
                  schedule_generator=complex_schedule_generator(),
                  number_of_agents=2,
                  obs_builder_object=SingleAgentNavigationObs(),
                  stochastic_data=stochastic_data)

    obs = env.reset()

    # Check that a initial duration for malfunction was assigned
    assert env.agents[0].malfunction_data['next_malfunction'] > 0
    # Activate all agents so env.step moves them (they would otherwise be READY_TO_DEPART).
    for agent in env.agents:
        agent.status = RailAgentStatus.ACTIVE

    agent_halts = 0
    total_down_time = 0
    agent_old_position = env.agents[0].position
    for step in range(100):
        actions = {}

        # Observation is one-hot [left, forward, right]; +1 maps its argmax onto
        # the corresponding RailEnvActions movement action.
        for i in range(len(obs)):
            actions[i] = np.argmax(obs[i]) + 1

        if step % 5 == 0:
            # Stop the agent and set it to be malfunctioning
            env.agents[0].malfunction_data['malfunction'] = -1
            env.agents[0].malfunction_data['next_malfunction'] = 0
            agent_halts += 1

        obs, all_rewards, done, _ = env.step(actions)

        if env.agents[0].malfunction_data['malfunction'] > 0:
            agent_malfunctioning = True
        else:
            agent_malfunctioning = False

        if agent_malfunctioning:
            # Check that agent is not moving while malfunctioning
            assert agent_old_position == env.agents[0].position

        agent_old_position = env.agents[0].position
        total_down_time += env.agents[0].malfunction_data['malfunction']

    # Check that the appropriate number of malfunctions is achieved
    assert env.agents[0].malfunction_data['nr_malfunctions'] == 21, "Actual {}".format(
        env.agents[0].malfunction_data['nr_malfunctions'])

    # Check that 20 stops where performed
    assert agent_halts == 20

    # Check that malfunctioning data was standing around
    assert total_down_time > 0
u214892's avatar
u214892 committed
119
120
121
122
123
124
125
126
127
128


def test_malfunction_process_statistically():
    """Tests that malfunctions are produced by stochastic_data!

    Runs 100 random steps with a high malfunction rate and checks the total
    number of observed agent-malfunction steps against a seed-pinned value.
    """
    # Set fixed malfunction duration for this test
    stochastic_data = {'prop_malfunction': 1.,
                       'malfunction_rate': 2,
                       'min_duration': 3,
                       'max_duration': 3}
    # Both RNGs must be seeded before env construction for the count below to hold.
    np.random.seed(5)
    random.seed(0)

    env = RailEnv(width=20,
                  height=20,
                  rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999,
                                                        seed=0),
                  schedule_generator=complex_schedule_generator(),
                  number_of_agents=2,
                  obs_builder_object=SingleAgentNavigationObs(),
                  stochastic_data=stochastic_data)

    env.reset()
    nb_malfunction = 0
    for step in range(100):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # Count every (agent, step) pair spent in a malfunction.
            if agent.malfunction_data['malfunction'] > 0:
                nb_malfunction += 1
            # We randomly select an action
            action_dict[agent.handle] = RailEnvActions(np.random.randint(4))

        env.step(action_dict)

    # check that generation of malfunctions works as expected
    assert nb_malfunction == 156, "nb_malfunction={}".format(nb_malfunction)
u214892's avatar
u214892 committed
154
155


156
def test_initial_malfunction():
    """Replay test: agent malfunctions on its very first step while already placed.

    Checks it stands still (full step penalty) for the malfunction duration,
    then restarts with MOVE_FORWARD and advances normally.
    """
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    # Only speed-1.0 trains, so one env.step always traverses a full cell.
    speed_ration_map = {1.: 1.,  # Fast passenger train
                        1. / 2.: 0.,  # Fast freight train
                        1. / 3.: 0.,  # Slow commuter train
                        1. / 4.: 0.}  # Slow freight train

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=sparse_rail_generator(num_cities=5,
                                                       # Number of cities in map (where train stations are)
                                                       num_intersections=4,
                                                       # Number of intersections (no start / target)
                                                       num_trainstations=25,  # Number of possible start/targets on map
                                                       min_node_dist=6,  # Minimal distance of nodes
                                                       node_radius=3,  # Proximity of stations to city center
                                                       num_neighb=3,
                                                       # Number of connections to other cities/intersections
                                                       seed=215545,  # Random seed
                                                       grid_mode=True,
                                                       enhance_intersection=False
                                                       ),
                  schedule_generator=sparse_schedule_generator(speed_ration_map),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  )
    # Use deterministic penalty values so the per-step reward assertions below hold.
    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=(28, 5),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,  # inject a 3-step malfunction at this step
                malfunction=3,
                reward=env.step_penalty  # full step penalty when malfunctioning
            ),
            Replay(
                position=(28, 5),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=2,
                reward=env.step_penalty  # full step penalty when malfunctioning
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
            Replay(
                position=(28, 5),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=1,
                reward=env.start_penalty + env.step_penalty * 1.0
                # malfunctioning ends: starting and running at speed 1.0
            ),
            Replay(
                position=(28, 4),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0  # running at speed 1.0
            ),
            Replay(
                position=(27, 4),
                direction=Grid4TransitionsEnum.NORTH,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0  # running at speed 1.0
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target
    )
    run_replay_config(env, [replay_config])


def test_initial_malfunction_stop_moving():
    """Replay test: agent malfunctions before departing, then is explicitly stopped.

    After the malfunction ends, STOP_MOVING / DO_NOTHING keep it in place
    (full step penalty each step) until MOVE_FORWARD restarts it.
    """
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    # Only speed-1.0 trains, so one env.step always traverses a full cell.
    speed_ration_map = {1.: 1.,  # Fast passenger train
                        1. / 2.: 0.,  # Fast freight train
                        1. / 3.: 0.,  # Slow commuter train
                        1. / 4.: 0.}  # Slow freight train

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=sparse_rail_generator(num_cities=5,
                                                       # Number of cities in map (where train stations are)
                                                       num_intersections=4,
                                                       # Number of intersections (no start / target)
                                                       num_trainstations=25,  # Number of possible start/targets on map
                                                       min_node_dist=6,  # Minimal distance of nodes
                                                       node_radius=3,  # Proximity of stations to city center
                                                       num_neighb=3,
                                                       # Number of connections to other cities/intersections
                                                       seed=215545,  # Random seed
                                                       grid_mode=True,
                                                       enhance_intersection=False
                                                       ),
                  schedule_generator=sparse_schedule_generator(speed_ration_map),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  )
    # Use deterministic penalty values so the per-step reward assertions below hold.
    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=(28, 5),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,  # inject a 3-step malfunction at this step
                malfunction=3,
                reward=env.step_penalty,  # full step penalty when stopped
                status=RailAgentStatus.READY_TO_DEPART
            ),
            Replay(
                position=(28, 5),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=2,
                reward=env.step_penalty,  # full step penalty when stopped
                status=RailAgentStatus.ACTIVE
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action STOP_MOVING, agent should restart without moving
            #
            Replay(
                position=(28, 5),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.STOP_MOVING,
                malfunction=1,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we have stopped and do nothing --> should stand still
            Replay(
                position=(28, 5),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=0,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we start to move forward --> should go to next cell now
            Replay(
                position=(28, 5),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            Replay(
                position=(28, 4),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target
    )

    run_replay_config(env, [replay_config])
336
337


338
def test_initial_malfunction_do_nothing():
    """Replay test: agent malfunctions before departing and then only does nothing.

    DO_NOTHING keeps the never-started agent in place after the malfunction
    ends; it only advances once MOVE_FORWARD is issued.
    """
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    # Only speed-1.0 trains, so one env.step always traverses a full cell.
    speed_ration_map = {1.: 1.,  # Fast passenger train
                        1. / 2.: 0.,  # Fast freight train
                        1. / 3.: 0.,  # Slow commuter train
                        1. / 4.: 0.}  # Slow freight train

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=sparse_rail_generator(num_cities=5,
                                                       # Number of cities in map (where train stations are)
                                                       num_intersections=4,
                                                       # Number of intersections (no start / target)
                                                       num_trainstations=25,  # Number of possible start/targets on map
                                                       min_node_dist=6,  # Minimal distance of nodes
                                                       node_radius=3,  # Proximity of stations to city center
                                                       num_neighb=3,
                                                       # Number of connections to other cities/intersections
                                                       seed=215545,  # Random seed
                                                       grid_mode=True,
                                                       enhance_intersection=False
                                                       ),
                  schedule_generator=sparse_schedule_generator(speed_ration_map),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  )
    # Use deterministic penalty values so the per-step reward assertions below hold.
    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[Replay(
            position=(28, 5),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            set_malfunction=3,  # inject a 3-step malfunction at this step
            malfunction=3,
            reward=env.step_penalty,  # full step penalty while malfunctioning
            status=RailAgentStatus.READY_TO_DEPART
        ),
            Replay(
                position=(28, 5),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=2,
                reward=env.step_penalty,  # full step penalty while malfunctioning
                status=RailAgentStatus.ACTIVE
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action DO_NOTHING, agent should restart without moving
            #
            Replay(
                position=(28, 5),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=1,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we haven't started moving yet --> stay here
            Replay(
                position=(28, 5),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=0,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we start to move forward --> should go to next cell now
            Replay(
                position=(28, 5),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            ),
            Replay(
                position=(28, 4),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target
    )

    run_replay_config(env, [replay_config])
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474


def test_initial_nextmalfunction_not_below_zero():
    """Regression test: next_malfunction must never go negative after a step."""
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 0.5,  # Rate of malfunction occurence
                       'min_duration': 5,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    speed_ration_map = {1.: 1.,  # Fast passenger train
                        1. / 2.: 0.,  # Fast freight train
                        1. / 3.: 0.,  # Slow commuter train
                        1. / 4.: 0.}  # Slow freight train

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=sparse_rail_generator(num_cities=5,
                                                       # Number of cities in map (where train stations are)
                                                       num_intersections=4,
                                                       # Number of intersections (no start / target)
                                                       num_trainstations=25,  # Number of possible start/targets on map
                                                       min_node_dist=6,  # Minimal distance of nodes
                                                       node_radius=3,  # Proximity of stations to city center
                                                       num_neighb=3,
                                                       # Number of connections to other cities/intersections
                                                       seed=215545,  # Random seed
                                                       grid_mode=True,
                                                       enhance_intersection=False
                                                       ),
                  schedule_generator=sparse_schedule_generator(speed_ration_map),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  )
    agent = env.agents[0]
    env.step({})
    # next_malfunction was -1 before the bugfix
    # https://gitlab.aicrowd.com/flatland/flatland/issues/186
    assert agent.malfunction_data['next_malfunction'] >= 0, \
        "next_malfunction should be >=0, found {}".format(agent.malfunction_data['next_malfunction'])