test_flatland_malfunction.py 27.6 KB
Newer Older
u214892's avatar
u214892 committed
1
import random
2
from typing import Dict, List
u214892's avatar
u214892 committed
3

4
import numpy as np
5
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
6

7
from flatland.core.env_observation_builder import ObservationBuilder
u214892's avatar
u214892 committed
8
from flatland.core.grid.grid4 import Grid4TransitionsEnum
9
from flatland.core.grid.grid4_utils import get_new_position
u214892's avatar
u214892 committed
10
from flatland.envs.agent_utils import RailAgentStatus
u214892's avatar
u214892 committed
11
from flatland.envs.rail_env import RailEnv, RailEnvActions
12
13
14
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.simple_rail import make_simple_rail2
15
16


17
class SingleAgentNavigationObs(ObservationBuilder):
    """
    Builds a 3-component binary observation vector indicating which of the three
    relative branches (Left, Forward, Right) leads to the shortest path towards
    the agent's target.
    E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
    will be [1, 0, 0].
    """

    def __init__(self):
        super().__init__()

    def reset(self):
        # No per-episode state to clear.
        pass

    def get(self, handle: int = 0) -> List[int]:
        agent = self.env.agents[handle]

        # Resolve the cell the agent (virtually) occupies, depending on its status.
        if agent.status == RailAgentStatus.READY_TO_DEPART:
            virtual_cell = agent.initial_position
        elif agent.status == RailAgentStatus.ACTIVE:
            virtual_cell = agent.position
        elif agent.status == RailAgentStatus.DONE:
            virtual_cell = agent.target
        else:
            # Agent already removed from the grid: no observation.
            return None

        transitions = self.env.rail.get_transitions(*virtual_cell, agent.direction)

        # If exactly one transition is available it is, by convention, the forward branch.
        if np.count_nonzero(transitions) == 1:
            return [0, 1, 0]

        distance_map = self.env.distance_map.get()
        branch_distances = []
        # Relative headings in [left, forward, right] order w.r.t. the current direction.
        for heading in ((agent.direction + offset) % 4 for offset in (-1, 0, 1)):
            if transitions[heading]:
                next_cell = get_new_position(virtual_cell, heading)
                branch_distances.append(
                    distance_map[handle, next_cell[0], next_cell[1], heading])
            else:
                # Branch not available from this cell/orientation.
                branch_distances.append(np.inf)

        observation = [0, 0, 0]
        # Mark the first branch with minimal remaining distance.
        observation[np.argmin(branch_distances)] = 1
        return observation


def test_malfunction_process():
    """Check that a malfunctioning agent stands still and that down time accumulates over an episode."""
    # Set fixed malfunction duration for this test
    stochastic_data = {'malfunction_rate': 1,
                       'min_duration': 3,
                       'max_duration': 3}

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # reset to initialize agents_static
    obs, info = env.reset(False, False, True, random_seed=10)

    total_down_time = 0
    agent_old_position = env.agents[0].position

    # Move target to unreachable position in order to not interfere with test
    env.agents[0].target = (0, 0)
    for step in range(100):
        actions = {}

        # Follow the shortest-path hint from the observation (argmax of the
        # [left, forward, right] vector, shifted by 1 to the action encoding).
        for i in range(len(obs)):
            actions[i] = np.argmax(obs[i]) + 1

        obs, all_rewards, done, _ = env.step(actions)

        # While malfunctioning the agent must not move.
        if env.agents[0].malfunction_data['malfunction'] > 0:
            assert agent_old_position == env.agents[0].position

        agent_old_position = env.agents[0].position
        total_down_time += env.agents[0].malfunction_data['malfunction']

    # Check that the appropriate number of malfunctions is achieved
    assert env.agents[0].malfunction_data['nr_malfunctions'] == 30, "Actual {}".format(
        env.agents[0].malfunction_data['nr_malfunctions'])

    # Check that the agent was actually standing still during malfunctions
    assert total_down_time > 0
u214892's avatar
u214892 committed
118
119
120
121
122
123


def test_malfunction_process_statistically():
    """Tests that malfunctions are produced by stochastic_data!

    Runs 20 steps with a fixed seed and compares each agent's remaining
    malfunction duration against a pre-recorded expected table.
    """
    # Set fixed malfunction duration for this test
    stochastic_data = {'prop_malfunction': 1.,
                       'malfunction_rate': 5,
                       'min_duration': 5,
                       'max_duration': 5}

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=10,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )

    # reset to initialize agents_static
    env.reset(True, True, False, random_seed=10)

    env.agents[0].target = (0, 0)
    # Next line only for test generation
    # agent_malfunction_list = [[] for i in range(20)]
    # Expected malfunction countdown per step (rows) and agent (columns),
    # recorded from a reference run with the same seeds.
    agent_malfunction_list = [[0, 0, 0, 0, 5, 5, 0, 0, 0, 0], [0, 0, 0, 0, 5, 5, 0, 0, 0, 0],
                              [0, 0, 0, 0, 4, 4, 0, 0, 0, 0],
                              [0, 0, 0, 0, 3, 3, 0, 0, 0, 0], [0, 0, 0, 0, 2, 2, 0, 0, 0, 5],
                              [0, 0, 0, 0, 1, 1, 5, 0, 0, 4],
                              [0, 0, 0, 5, 0, 0, 4, 5, 0, 3], [5, 0, 0, 4, 0, 0, 3, 4, 0, 2],
                              [4, 5, 0, 3, 5, 5, 2, 3, 5, 1],
                              [3, 4, 0, 2, 4, 4, 1, 2, 4, 0], [2, 3, 5, 1, 3, 3, 0, 1, 3, 0],
                              [1, 2, 4, 0, 2, 2, 0, 0, 2, 0],
                              [0, 1, 3, 0, 1, 1, 5, 0, 1, 0], [0, 0, 2, 0, 0, 0, 4, 0, 0, 0],
                              [5, 0, 1, 0, 0, 0, 3, 5, 0, 5],
                              [4, 0, 0, 0, 5, 0, 2, 4, 0, 4], [3, 0, 0, 0, 4, 0, 1, 3, 5, 3],
                              [2, 0, 0, 0, 3, 0, 0, 2, 4, 2],
                              [1, 0, 5, 5, 2, 0, 0, 1, 3, 1], [0, 5, 4, 4, 1, 0, 5, 0, 2, 0]]

    for step in range(20):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent_idx in range(env.get_num_agents()):
            # We randomly select an action
            action_dict[agent_idx] = RailEnvActions(np.random.randint(4))
            # For generating tests only:
            # agent_malfunction_list[step].append(env.agents[agent_idx].malfunction_data['malfunction'])
            assert env.agents[agent_idx].malfunction_data['malfunction'] == agent_malfunction_list[step][agent_idx]
        env.step(action_dict)
    # For generating test only
    # print(agent_malfunction_list)
170

u214892's avatar
u214892 committed
171

172
def test_malfunction_before_entry():
    """Tests that malfunctions are working properly for agents before entering the environment!"""
    # Set fixed malfunction duration for this test
    stochastic_data = {'prop_malfunction': 1.,
                       'malfunction_rate': 5,
                       'min_duration': 10,
                       'max_duration': 10}

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=2),  # seed 12
                  number_of_agents=10,
                  random_seed=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  )
    # reset to initialize agents_static
    env.reset(False, False, False, random_seed=10)

    # Test initial malfunction values for all agents:
    # some agents should already be malfunctioning and some should be working,
    # and the agents should have different next_malfunction values.
    expected_initial_malfunctions = [0, 0, 0, 0, 10, 10, 0, 0, 0, 0]
    for handle, expected in enumerate(expected_initial_malfunctions):
        assert env.agents[handle].malfunction_data['malfunction'] == expected
206

207

208
209
210
211
212
def test_next_malfunction_counter():
    """
    Test that the next_malfunction counter decreases by one per step while the
    agent is healthy.
    Returns
    -------

    """
    # No stochastic malfunction data is configured; the counters are set by hand below.
    rail, rail_map = make_simple_rail2()
    action_dict: Dict[int, RailEnvActions] = {}

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=2),  # seed 12
                  number_of_agents=1,
                  random_seed=1,
                  )
    # reset to initialize agents_static
    env.reset(False, False, activate_agents=True, random_seed=10)
    # Pin the malfunction state so the countdown is fully deterministic.
    env.agents[0].malfunction_data['next_malfunction'] = 5
    env.agents[0].malfunction_data['malfunction_rate'] = 5
    env.agents[0].malfunction_data['malfunction'] = 0
    # BUGFIX: the original line ended with a stray comma, assigning the
    # 1-tuple ((0, 0),) instead of the coordinate (0, 0).
    env.agents[0].target = (0, 0)  # Move the target out of range

    for time_step in range(1, 6):
        # Move in the env
        env.step(action_dict)

        # Check that next_malfunction decreases as expected
        assert env.agents[0].malfunction_data['next_malfunction'] == 5 - time_step

242

243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
def test_malfunction_values_and_behavior():
    """
    Test that the next malfunction occurs when desired.

    The next_malfunction counter must count down to zero, then the malfunction
    duration (10 steps, started at step 6) must count down to zero.
    Returns
    -------

    """
    # Set fixed malfunction duration for this test

    rail, rail_map = make_simple_rail2()
    action_dict: Dict[int, RailEnvActions] = {}
    stochastic_data = {'prop_malfunction': 1.,
                       'malfunction_rate': 5,
                       'min_duration': 10,
                       'max_duration': 10}
    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=2),  # seed 12
                  stochastic_data=stochastic_data,
                  number_of_agents=1,
                  random_seed=1,
                  )
    # reset to initialize agents_static
    env.reset(False, False, activate_agents=True, random_seed=10)
    # Pin the malfunction state so the countdown is fully deterministic.
    env.agents[0].malfunction_data['next_malfunction'] = 5
    env.agents[0].malfunction_data['malfunction_rate'] = 50
    env.agents[0].malfunction_data['malfunction'] = 0
    # BUGFIX: the original line ended with a stray comma, assigning the
    # 1-tuple ((0, 0),) instead of the coordinate (0, 0).
    env.agents[0].target = (0, 0)  # Move the target out of range

    for time_step in range(1, 16):
        # Move in the env
        env.step(action_dict)
        # Check that the counters decrease as expected
        if env.agents[0].malfunction_data['malfunction'] < 1:
            assert env.agents[0].malfunction_data['next_malfunction'] == np.clip(5 - time_step, 0, 100)
        else:
            assert env.agents[0].malfunction_data['malfunction'] == np.clip(10 - (time_step - 6), 0, 100)

284

285
def test_initial_malfunction():
    """Replay test: an agent that malfunctions on its first step stands still for the
    malfunction duration and then resumes moving with MOVE_FORWARD."""
    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 100,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=10),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # reset to initialize agents_static
    env.reset(False, False, True, random_seed=10)
    env.agents[0].target = (0, 5)
    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty  # full step penalty when malfunctioning
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=2,
                reward=env.step_penalty  # full step penalty when malfunctioning
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=1,
                reward=env.step_penalty * 1.0

            ),  # malfunctioning ends: starting and running at speed 1.0
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0  # running at speed 1.0
            ),
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0  # running at speed 1.0
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [replay_config])
355
356
357


def test_initial_malfunction_stop_moving():
    """Replay test: a malfunctioning agent that receives STOP_MOVING while broken
    stays put after recovery until it is explicitly told to MOVE_FORWARD."""
    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    env.reset()

    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=None,
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty,  # full step penalty when stopped
                status=RailAgentStatus.READY_TO_DEPART
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=3,
                reward=env.step_penalty,  # full step penalty when stopped
                status=RailAgentStatus.ACTIVE
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action STOP_MOVING, agent should restart without moving
            #
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.STOP_MOVING,
                malfunction=2,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we have stopped and do nothing --> should stand still
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=1,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we start to move forward --> should go to next cell now
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            ),
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )

    run_replay_config(env, [replay_config], activate_agents=False)
443
444


445
def test_initial_malfunction_do_nothing():
    """Replay test: a malfunctioning agent that only receives DO_NOTHING stays in
    place after recovery until it is explicitly told to MOVE_FORWARD."""
    # Seed both RNGs so the generated schedule/malfunctions are reproducible.
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  )
    # reset to initialize agents_static
    env.reset()
    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=None,
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty,  # full step penalty while malfunctioning
                status=RailAgentStatus.READY_TO_DEPART
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=3,
                reward=env.step_penalty,  # full step penalty while malfunctioning
                status=RailAgentStatus.ACTIVE
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action DO_NOTHING, agent should restart without moving
            #
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=2,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we haven't started moving yet --> stay here
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=1,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),

            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            ),  # we start to move forward --> should go to next cell now
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [replay_config], activate_agents=False)
530
531
532
533
534
535
536


def test_initial_nextmalfunction_not_below_zero():
    """Regression test: next_malfunction must never go negative after a step."""
    # Seed both RNGs so the generated schedule/malfunctions are reproducible.
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # reset to initialize agents_static
    env.reset()
    agent = env.agents[0]
    env.step({})
    # next_malfunction was -1 before the bugfix https://gitlab.aicrowd.com/flatland/flatland/issues/186
    assert agent.malfunction_data['next_malfunction'] >= 0, \
        "next_malfunction should be >=0, found {}".format(agent.malfunction_data['next_malfunction'])
Erik Nygren's avatar
Erik Nygren committed
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578


def tests_random_interference_from_outside():
    """Tests that the environment is deterministic and unaffected by outside RNG use.

    Runs the same seeded episode twice; in the second run the global `random`
    and `numpy.random` generators are exercised between steps, and the rewards
    and positions must still match the first run exactly.
    """
    # Set fixed malfunction duration for this test
    stochastic_data = {'prop_malfunction': 1.,
                       'malfunction_rate': 1,
                       'min_duration': 10,
                       'max_duration': 10}

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=2),  # seed 12
                  number_of_agents=1,
                  random_seed=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  )
    env.reset()
    # reset to initialize agents_static
    env.agents[0].speed_data['speed'] = 0.33
    env.agents[0].initial_position = (3, 0)
    env.agents[0].target = (3, 9)
    env.reset(False, False, False)
    env_data = []

    # First run: record (reward, position) per step as the reference trace.
    for step in range(200):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # We randomly select an action
            action_dict[agent.handle] = RailEnvActions(2)

        _, reward, _, _ = env.step(action_dict)
        # Append the rewards of the first trial
        env_data.append((reward[0], env.agents[0].position))
        assert reward[0] == env_data[step][0]
        assert env.agents[0].position == env_data[step][1]
    # Run the same test as above but with an external random generator running
    # Check that the reward stays the same

    rail, rail_map = make_simple_rail2()
    random.seed(47)
    np.random.seed(1234)
    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=2),  # seed 12
                  number_of_agents=1,
                  random_seed=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  )
    env.reset()
    # reset to initialize agents_static
    env.agents[0].speed_data['speed'] = 0.33
    env.agents[0].initial_position = (3, 0)
    env.agents[0].target = (3, 9)
    env.reset(False, False, False)

    # Print for test generation
    dummy_list = [1, 2, 6, 7, 8, 9, 4, 5, 4]
    # Second run: interleave global-RNG calls and assert the trace is unchanged.
    for step in range(200):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # We randomly select an action
            action_dict[agent.handle] = RailEnvActions(2)

            # Do dummy random number generations
            random.shuffle(dummy_list)
            np.random.rand()

        _, reward, _, _ = env.step(action_dict)
        assert reward[0] == env_data[step][0]
        assert env.agents[0].position == env_data[step][1]
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675


def test_last_malfunction_step():
    """
    Test to check that agent moves when it is not malfunctioning

    """

    # Set fixed malfunction duration for this test
    stochastic_data = {'prop_malfunction': 1.,
                       'malfunction_rate': 5,
                       'min_duration': 4,
                       'max_duration': 4}

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=2),  # seed 12
                  number_of_agents=1,
                  random_seed=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  )
    env.reset()
    # reset to initialize agents_static
    env.agents[0].speed_data['speed'] = 0.33
    env.agents_static[0].target = (0, 0)

    env.reset(False, False, True)
    # Force malfunction to be off at beginning and next malfunction to happen in 2 steps
    env.agents[0].malfunction_data['next_malfunction'] = 2
    env.agents[0].malfunction_data['malfunction'] = 0

    for step in range(20):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # Go forward all the time
            action_dict[agent.handle] = RailEnvActions(2)

        # The agent may move only if it is neither malfunctioning now nor about
        # to malfunction in this step
        agent_can_move = not (env.agents[0].malfunction_data['malfunction'] > 0
                              or env.agents[0].malfunction_data['next_malfunction'] < 1)

        # Store the position (fraction) before and after the step
        pre_position = env.agents[0].speed_data['position_fraction']
        _, reward, _, _ = env.step(action_dict)
        post_position = env.agents[0].speed_data['position_fraction']

        # Assert that the agent moved while it was still allowed
        if agent_can_move:
            assert pre_position != post_position
        else:
            assert post_position == pre_position