import random
from typing import Dict, List

import numpy as np
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay

from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.simple_rail import make_simple_rail2

17
class SingleAgentNavigationObs(ObservationBuilder):
    """
    We build a representation vector with 3 binary components, indicating which of the 3 available directions
    for each agent (Left, Forward, Right) lead to the shortest path to its target.
    E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
    will be [1, 0, 0].
    """

    def __init__(self):
        super().__init__()

    def reset(self):
        # Nothing to reset: the observation is recomputed from scratch on every get().
        pass

    def get(self, handle: int = 0) -> List[int]:
        """Return a [left, forward, right] one-hot vector pointing towards the shortest path.

        Returns None when the agent is in a state without a usable position
        (e.g. DONE_REMOVED).
        """
        agent = self.env.agents[handle]

        # Pick the position the agent effectively occupies for planning purposes.
        if agent.status == RailAgentStatus.READY_TO_DEPART:
            agent_virtual_position = agent.initial_position
        elif agent.status == RailAgentStatus.ACTIVE:
            agent_virtual_position = agent.position
        elif agent.status == RailAgentStatus.DONE:
            agent_virtual_position = agent.target
        else:
            return None

        possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
        num_transitions = np.count_nonzero(possible_transitions)

        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right], relative to the current orientation
        # If only one transition is possible, the forward branch is aligned with it.
        if num_transitions == 1:
            observation = [0, 1, 0]
        else:
            min_distances = []
            # (agent.direction - 1, agent.direction, agent.direction + 1) mod 4
            # corresponds to the Left/Forward/Right branches respectively.
            for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                if possible_transitions[direction]:
                    new_position = get_new_position(agent_virtual_position, direction)
                    min_distances.append(
                        self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
                else:
                    # Blocked branch: infinite cost so argmin never selects it.
                    min_distances.append(np.inf)

            observation = [0, 0, 0]
            observation[np.argmin(min_distances)] = 1

        return observation


def test_malfunction_process():
    """Forced malfunctions must freeze the agent in place and be counted correctly."""
    # Set fixed malfunction duration for this test
    stochastic_data = {'prop_malfunction': 1.,
                       'malfunction_rate': 1000,
                       'min_duration': 3,
                       'max_duration': 3}

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # reset to initialize agents_static
    obs, info = env.reset(False, False, True, random_seed=10)

    # Check that an initial duration for malfunction was assigned
    assert env.agents[0].malfunction_data['next_malfunction'] > 0
    for agent in env.agents:
        agent.status = RailAgentStatus.ACTIVE

    agent_halts = 0
    total_down_time = 0
    agent_old_position = env.agents[0].position

    # Move target to unreachable position in order to not interfere with test
    env.agents[0].target = (0, 0)
    for step in range(100):
        actions = {}

        # Observation is a one-hot [left, forward, right]; +1 maps it onto
        # RailEnvActions.MOVE_LEFT / MOVE_FORWARD / MOVE_RIGHT.
        for i in range(len(obs)):
            actions[i] = np.argmax(obs[i]) + 1

        if step % 5 == 0:
            # Stop the agent and set it to be malfunctioning
            env.agents[0].malfunction_data['malfunction'] = -1
            env.agents[0].malfunction_data['next_malfunction'] = 0
            agent_halts += 1

        obs, all_rewards, done, _ = env.step(actions)

        agent_malfunctioning = env.agents[0].malfunction_data['malfunction'] > 0

        if agent_malfunctioning:
            # Check that agent is not moving while malfunctioning
            assert agent_old_position == env.agents[0].position

        agent_old_position = env.agents[0].position
        total_down_time += env.agents[0].malfunction_data['malfunction']

    # Check that the appropriate number of malfunctions is achieved
    assert env.agents[0].malfunction_data['nr_malfunctions'] == 21, "Actual {}".format(
        env.agents[0].malfunction_data['nr_malfunctions'])

    # Check that 20 stops where performed
    assert agent_halts == 20

    # Check that malfunctioning data was standing around
    assert total_down_time > 0

def test_malfunction_process_statistically():
    """Tests that malfunctions are produced by stochastic_data!"""
    # Set fixed malfunction duration for this test
    stochastic_data = {'prop_malfunction': 1.,
                       'malfunction_rate': 2,
                       'min_duration': 3,
                       'max_duration': 3}

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # reset to initialize agents_static
    env.reset(True, True, False, random_seed=10)

    # Unreachable target so the agent keeps wandering for the whole run.
    env.agents[0].target = (0, 0)
    for step in range(20):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # We randomly select an action
            action_dict[agent.handle] = RailEnvActions(np.random.randint(4))

        env.step(action_dict)
    # check that generation of malfunctions works as expected
    assert env.agents[0].malfunction_data["nr_malfunctions"] == 4
def test_malfunction_before_entry():
    """Tests that malfunctions are produced by stochastic_data!"""
    # Set fixed malfunction duration for this test
    stochastic_data = {'prop_malfunction': 1.,
                       'malfunction_rate': 1,
                       'min_duration': 10,
                       'max_duration': 10}

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=2),  # seed 12
                  number_of_agents=10,
                  random_seed=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  )
    # reset to initialize agents_static
    env.reset(False, False, False, random_seed=10)
    env.agents[0].target = (0, 0)

    # All agents except agent 0 start with an 11-step malfunction before entry.
    for a in range(1, 10):
        assert env.agents[a].malfunction_data['malfunction'] == 11, \
            "agent {} expected malfunction 11, found {}".format(
                a, env.agents[a].malfunction_data['malfunction'])

    for step in range(20):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # We randomly select an action
            action_dict[agent.handle] = RailEnvActions(2)
            # Hold everyone back for the first 10 steps.
            if step < 10:
                action_dict[agent.handle] = RailEnvActions(0)

        env.step(action_dict)

    # After 20 steps every delayed agent should be down to its last malfunction step.
    for a in range(1, 10):
        assert env.agents[a].malfunction_data['malfunction'] == 1, \
            "agent {} expected malfunction 1, found {}".format(
                a, env.agents[a].malfunction_data['malfunction'])

    # Print for test generation
    # for a in range(env.get_num_agents()):
    #    print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a,
    #                                                                               env.agents[a].malfunction_data[
    #                                                                                   'malfunction']))
def test_initial_malfunction():
    """Replay-checks reward/position while an initial malfunction runs out and the agent restarts."""
    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 100,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=10),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )

    # reset to initialize agents_static
    env.reset(False, False, True, random_seed=10)
    print(env.agents[0].malfunction_data)
    env.agents[0].target = (0, 5)
    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty  # full step penalty when malfunctioning
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=2,
                reward=env.step_penalty  # full step penalty when malfunctioning
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=1,
                reward=env.start_penalty + env.step_penalty * 1.0
                # malfunctioning ends: starting and running at speed 1.0
            ),
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0  # running at speed 1.0
            ),
            Replay(
                position=(3, 4),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0  # running at speed 1.0
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [replay_config])
def test_initial_malfunction_stop_moving():
    """Replay-checks that STOP_MOVING/DO_NOTHING keep a malfunctioning agent in place until it restarts."""
    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # reset to initialize agents_static

    print(env.agents[0].initial_position, env.agents[0].direction, env.agents[0].position, env.agents[0].status)

    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=None,
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty,  # full step penalty when stopped
                status=RailAgentStatus.READY_TO_DEPART
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=2,
                reward=env.step_penalty,  # full step penalty when stopped
                status=RailAgentStatus.ACTIVE
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action STOP_MOVING, agent should restart without moving
            #
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.STOP_MOVING,
                malfunction=1,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we have stopped and do nothing --> should stand still
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=0,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we start to move forward --> should go to next cell now
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )

    run_replay_config(env, [replay_config], activate_agents=False)
def test_initial_malfunction_do_nothing():
    """Replay-checks that DO_NOTHING before departure leaves a malfunctioning agent in place."""
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  )
    # reset to initialize agents_static
    env.reset()
    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=None,
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty,  # full step penalty while malfunctioning
                status=RailAgentStatus.READY_TO_DEPART
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=2,
                reward=env.step_penalty,  # full step penalty while malfunctioning
                status=RailAgentStatus.ACTIVE
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action DO_NOTHING, agent should restart without moving
            #
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=1,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we haven't started moving yet --> stay here
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=0,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),

            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            ),  # we start to move forward --> should go to next cell now
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [replay_config], activate_agents=False)
def test_initial_nextmalfunction_not_below_zero():
    """Regression test: next_malfunction must never go below zero after a step."""
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # reset to initialize agents_static
    env.reset()
    agent = env.agents[0]
    env.step({})
    # next_malfunction was -1 before the bugfix https://gitlab.aicrowd.com/flatland/flatland/issues/186
    assert agent.malfunction_data['next_malfunction'] >= 0, \
        "next_malfunction should be >=0, found {}".format(agent.malfunction_data['next_malfunction'])