test_flatland_malfunction.py 23.6 KB
Newer Older
u214892's avatar
u214892 committed
1
import random
2
from typing import Dict, List
u214892's avatar
u214892 committed
3

4
import numpy as np
5
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
6

7
from flatland.core.env_observation_builder import ObservationBuilder
u214892's avatar
u214892 committed
8
from flatland.core.grid.grid4 import Grid4TransitionsEnum
9
from flatland.core.grid.grid4_utils import get_new_position
u214892's avatar
u214892 committed
10
from flatland.envs.agent_utils import RailAgentStatus
u214892's avatar
u214892 committed
11
from flatland.envs.rail_env import RailEnv, RailEnvActions
12
13
14
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.simple_rail import make_simple_rail2
15
16


17
class SingleAgentNavigationObs(ObservationBuilder):
    """
    Observation builder producing a 3-component binary vector per agent.

    The components correspond to the directions [Left, Forward, Right]
    relative to the agent's current orientation; the component whose branch
    lies on the shortest path to the agent's target is set to 1. E.g. if the
    Left branch (when available) is the shortest route, the observation is
    [1, 0, 0].
    """

    def __init__(self):
        super().__init__()

    def reset(self):
        # No per-episode state to clear.
        pass

    def get(self, handle: int = 0) -> List[int]:
        agent = self.env.agents[handle]

        # Resolve which cell the agent effectively occupies for planning,
        # depending on its lifecycle status.
        status_to_position = {
            RailAgentStatus.READY_TO_DEPART: agent.initial_position,
            RailAgentStatus.ACTIVE: agent.position,
            RailAgentStatus.DONE: agent.target,
        }
        if agent.status not in status_to_position:
            return None
        virtual_position = status_to_position[agent.status]

        possible_transitions = self.env.rail.get_transitions(*virtual_position, agent.direction)
        num_transitions = np.count_nonzero(possible_transitions)

        # With a single feasible transition, it is by convention "forward".
        if num_transitions == 1:
            return [0, 1, 0]

        # Otherwise rank the [left, forward, right] branches by their
        # remaining distance to the target; unavailable branches get +inf.
        branch_distances = []
        for offset in range(-1, 2):
            direction = (agent.direction + offset) % 4
            if possible_transitions[direction]:
                next_cell = get_new_position(virtual_position, direction)
                branch_distances.append(
                    self.env.distance_map.get()[handle, next_cell[0], next_cell[1], direction])
            else:
                branch_distances.append(np.inf)

        observation = [0, 0, 0]
        observation[np.argmin(branch_distances)] = 1
        return observation


def test_malfunction_process():
    """Agent must stand still while malfunctioning and accumulate down time."""
    # Set fixed malfunction duration for this test
    stochastic_data = {'malfunction_rate': 1,
                       'min_duration': 3,
                       'max_duration': 3}

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # reset to initialize agents_static
    obs, info = env.reset(False, False, True, random_seed=10)

    total_down_time = 0
    agent_old_position = env.agents[0].position

    # Move target to unreachable position in order to not interfere with test
    env.agents[0].target = (0, 0)
    for step in range(100):
        # Follow the observation hint: argmax over the [L, F, R] one-hot,
        # +1 maps it onto the corresponding RailEnvActions value.
        # (Fix: dict comprehension replaces manual loop; unused `agent_halts`
        # counter removed; redundant if/else boolean flag collapsed.)
        actions = {i: np.argmax(obs[i]) + 1 for i in range(len(obs))}

        obs, all_rewards, done, _ = env.step(actions)

        if env.agents[0].malfunction_data['malfunction'] > 0:
            # Check that agent is not moving while malfunctioning
            assert agent_old_position == env.agents[0].position

        agent_old_position = env.agents[0].position
        total_down_time += env.agents[0].malfunction_data['malfunction']

    # Check that the appropriate number of malfunctions is achieved
    # (seed-sensitive golden value for random_seed=10)
    assert env.agents[0].malfunction_data['nr_malfunctions'] == 28, "Actual {}".format(
        env.agents[0].malfunction_data['nr_malfunctions'])

    # Check that malfunctioning data was standing around
    assert total_down_time > 0
u214892's avatar
u214892 committed
118
119
120
121
122


def test_malfunction_process_statistically():
    """Tests that malfunctions are produced by stochastic_data!"""
    # Set fixed malfunction duration for this test
    stochastic_data = {'malfunction_rate': 5,
                       'min_duration': 5,
                       'max_duration': 5}

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=10,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )

    # reset to initialize agents_static
    env.reset(True, True, False, random_seed=10)

    env.agents[0].target = (0, 0)
    # Next line only for test generation
    #agent_malfunction_list = [[] for i in range(20)]
    # Golden data: expected per-step malfunction countdown for each agent,
    # recorded from a fixed-seed run (random_seed=10). Seed-sensitive —
    # any change to the env's RNG consumption invalidates these values.
    agent_malfunction_list = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 3, 2, 1, 0],
     [0, 0, 0, 0, 0, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 0, 5, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [], [], [], [], [], [], [], [], [], []]

    for step in range(20):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent_idx in range(env.get_num_agents()):
            # We randomly select an action
            action_dict[agent_idx] = RailEnvActions(np.random.randint(4))
            # For generating tests only:
            #agent_malfunction_list[agent_idx].append(env.agents[agent_idx].malfunction_data['malfunction'])
            assert env.agents[agent_idx].malfunction_data['malfunction'] == agent_malfunction_list[agent_idx][step]
        env.step(action_dict)
    #print(agent_malfunction_list)
165

u214892's avatar
u214892 committed
166

167
def test_malfunction_before_entry():
    """Tests that malfunctions are working properly for agents before entering the environment!"""
    # Set fixed malfunction duration for this test
    stochastic_data = {'malfunction_rate': 0.0001,
                       'min_duration': 10,
                       'max_duration': 10}

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=1),  # seed 12
                  number_of_agents=10,
                  random_seed=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  )
    # reset to initialize agents_static
    env.reset(False, False, False, random_seed=10)
    env.agents[0].target = (0, 0)

    # Test initial malfunction values for all agents
    # we want some agents to be malfunctioning already and some to be working
    # we want different next_malfunction values for the agents
    # NOTE(review): the loop below only PRINTS generated assert statements —
    # it performs no actual assertion, so this test currently cannot fail on
    # the malfunction values. The printed lines appear intended to be pasted
    # back into the test body once the expected values are confirmed.
    for a in range(10):

        print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a,env.agents[a].malfunction_data['malfunction']))


197

198

199
200
201
202
203
204
205
206
207
208
209
def test_malfunction_values_and_behavior():
    """
    Test that a forced malfunction counts down by one per step and that the
    next malfunction occurs when desired (golden countdown for fixed seeds).
    """
    # Set fixed malfunction duration for this test
    rail, rail_map = make_simple_rail2()
    action_dict: Dict[int, RailEnvActions] = {}
    stochastic_data = {'malfunction_rate': 0.01,
                       'min_duration': 10,
                       'max_duration': 10}
    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=2),  # seed 12
                  stochastic_data=stochastic_data,
                  number_of_agents=1,
                  random_seed=1,
                  )

    # reset to initialize agents_static
    env.reset(False, False, activate_agents=True, random_seed=10)

    # Expected malfunction countdown per time step (seed-sensitive golden data).
    assert_list = [8, 7, 6, 5, 4, 3, 2, 1, 0, 9, 8, 7, 6, 5, 4]
    # Fix: removed stray debug statement `print("[")` left over from
    # generating the golden data above.
    for time_step in range(15):
        # Move in the env
        env.step(action_dict)
        # Check that the malfunction countdown decreases as expected
        assert env.agents[0].malfunction_data['malfunction'] == assert_list[time_step]
233

234

235
def test_initial_malfunction():
    """Replay test: an agent that spawns malfunctioning must stand still until
    the malfunction counts down, then start moving with the start penalty."""
    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 100,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=10),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # reset to initialize agents_static
    env.reset(False, False, True, random_seed=10)
    # Fix: removed leftover debug statement `print(env.agents[0].malfunction_data)`.
    env.agents[0].target = (0, 5)
    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty  # full step penalty when malfunctioning
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=2,
                reward=env.step_penalty  # full step penalty when malfunctioning
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=1,
                reward=env.step_penalty * 1.0

            ),  # malfunctioning ends: starting and running at speed 1.0
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0  # running at speed 1.0
            ),
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0  # running at speed 1.0
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [replay_config])
305
306
307


def test_initial_malfunction_stop_moving():
    """Replay test: a malfunction before entry keeps the agent READY_TO_DEPART;
    after entering, STOP_MOVING / DO_NOTHING during the countdown must not move
    the agent, and MOVE_FORWARD afterwards restarts it."""
    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    env.reset()
    # Fix: removed leftover debug statement printing the agent's initial state.

    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=None,
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty,  # full step penalty when stopped
                status=RailAgentStatus.READY_TO_DEPART
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=3,
                reward=env.step_penalty,  # full step penalty when stopped
                status=RailAgentStatus.ACTIVE
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action STOP_MOVING, agent should restart without moving
            #
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.STOP_MOVING,
                malfunction=2,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we have stopped and do nothing --> should stand still
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=1,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we start to move forward --> should go to next cell now
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                # Fix: comment corrected — this is the start penalty plus the
                # step penalty at speed 1.0, not a "stopped" penalty.
                reward=env.start_penalty + env.step_penalty * 1.0,
                status=RailAgentStatus.ACTIVE
            ),
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )

    run_replay_config(env, [replay_config], activate_agents=False)
393
394


395
def test_initial_malfunction_do_nothing():
    """Replay test: DO_NOTHING during an initial malfunction must keep the agent
    in place; a MOVE_FORWARD after the countdown restarts it normally."""
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  )
    # reset to initialize agents_static
    env.reset()
    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=None,
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty,  # full step penalty while malfunctioning
                status=RailAgentStatus.READY_TO_DEPART
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=3,
                reward=env.step_penalty,  # full step penalty while malfunctioning
                status=RailAgentStatus.ACTIVE
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action DO_NOTHING, agent should restart without moving
            #
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=2,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we haven't started moving yet --> stay here
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=1,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),

            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            ),  # we start to move forward --> should go to next cell now
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [replay_config], activate_agents=False)
480
481


Erik Nygren's avatar
Erik Nygren committed
482
483
484
def tests_random_interference_from_outside():
    """Tests that the env's own randomness is isolated: running external RNG
    calls (random / np.random) between steps must not change rewards or
    agent positions compared to a recording run with identical seeds."""
    # Set fixed malfunction duration for this test
    stochastic_data = {'malfunction_rate': 1,
                       'min_duration': 10,
                       'max_duration': 10}

    rail, rail_map = make_simple_rail2()
    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=2),  # seed 12
                  number_of_agents=1,
                  random_seed=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  )
    env.reset()
    # reset to initialize agents_static
    env.agents[0].speed_data['speed'] = 0.33
    env.reset(False, False, False, random_seed=10)
    env_data = []

    # First pass: RECORD (reward, position) per step into env_data.
    for step in range(200):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # We randomly select an action
            action_dict[agent.handle] = RailEnvActions(2)

        _, reward, _, _ = env.step(action_dict)
        # Append the rewards of the first trial
        env_data.append((reward[0], env.agents[0].position))
        # NOTE(review): these two asserts compare against the tuple appended
        # on the line above, so they can never fail — the real comparison
        # happens in the second pass below.
        assert reward[0] == env_data[step][0]
        assert env.agents[0].position == env_data[step][1]
    # Run the same test as above but with an external random generator running
    # Check that the reward stays the same

    rail, rail_map = make_simple_rail2()
    random.seed(47)
    np.random.seed(1234)
    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=2),  # seed 12
                  number_of_agents=1,
                  random_seed=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  )
    env.reset()
    # reset to initialize agents_static
    env.agents[0].speed_data['speed'] = 0.33
    env.reset(False, False, False, random_seed=10)

    # Second pass: REPLAY with interleaved external RNG noise and check the
    # recorded trajectory is reproduced exactly.
    dummy_list = [1, 2, 6, 7, 8, 9, 4, 5, 4]
    for step in range(200):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # We randomly select an action
            action_dict[agent.handle] = RailEnvActions(2)

            # Do dummy random number generations
            random.shuffle(dummy_list)
            np.random.rand()

        _, reward, _, _ = env.step(action_dict)
        assert reward[0] == env_data[step][0]
        assert env.agents[0].position == env_data[step][1]
548
549
550
551
552
553
554
555
556


def test_last_malfunction_step():
    """
    Test to check that the agent moves (position_fraction advances) exactly
    when it is not malfunctioning, and stands still otherwise.
    """

    # Set fixed malfunction duration for this test
    stochastic_data = {'malfunction_rate': 5,
                       'min_duration': 4,
                       'max_duration': 4}

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=2),  # seed 12
                  number_of_agents=1,
                  random_seed=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  )
    env.reset()
    # reset to initialize agents_static
    env.agents[0].speed_data['speed'] = 1. / 3.
    env.agents_static[0].target = (0, 0)

    env.reset(False, False, True)
    # Force malfunction to be off at beginning and next malfunction to happen in 2 steps
    env.agents[0].malfunction_data['next_malfunction'] = 2
    env.agents[0].malfunction_data['malfunction'] = 0
    # Fix: initialize the flag before the loop so it is always bound before use
    # (previously it was only assigned inside conditionals, risking NameError);
    # also removed the unused `env_data = []` local.
    agent_can_move = False
    for step in range(20):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # Go forward all the time
            action_dict[agent.handle] = RailEnvActions(2)

        if env.agents[0].malfunction_data['malfunction'] < 1:
            agent_can_move = True
        # Store the position before and after the step
        pre_position = env.agents[0].speed_data['position_fraction']
        _, reward, _, _ = env.step(action_dict)
        # Check if the agent is still allowed to move in this step
        if env.agents[0].malfunction_data['malfunction'] > 0:
            agent_can_move = False
        post_position = env.agents[0].speed_data['position_fraction']
        # Assert that the agent moved while it was still allowed
        if agent_can_move:
            assert pre_position != post_position
        else:
            assert post_position == pre_position