test_flatland_malfunction.py 24.3 KB
Newer Older
u214892's avatar
u214892 committed
1
import random
2
from typing import Dict, List
u214892's avatar
u214892 committed
3

4
import numpy as np
5
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
6

7
from flatland.core.env_observation_builder import ObservationBuilder
u214892's avatar
u214892 committed
8
from flatland.core.grid.grid4 import Grid4TransitionsEnum
9
from flatland.core.grid.grid4_utils import get_new_position
u214892's avatar
u214892 committed
10
from flatland.envs.agent_utils import RailAgentStatus
u214892's avatar
u214892 committed
11
from flatland.envs.rail_env import RailEnv, RailEnvActions
12
13
14
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.simple_rail import make_simple_rail2
15
16


17
class SingleAgentNavigationObs(ObservationBuilder):
    """
    Observation builder steering a single agent along its shortest path.

    The observation is a list with 3 binary components ordered as
    (Left, Forward, Right) relative to the agent's current direction. The
    component set to 1 marks the branch with the smallest distance-map value.
    E.g., if taking the Left branch (if available) is the shortest route to
    the agent's target, the observation is [1, 0, 0].
    """

    def __init__(self):
        super().__init__()

    def reset(self):
        pass

    def get(self, handle: int = 0) -> List[int]:
        """
        Build the observation for the agent with the given handle.

        Returns None when the agent's status is none of READY_TO_DEPART,
        ACTIVE or DONE.
        """
        agent = self.env.agents[handle]
        status = agent.status

        # Choose the cell the agent is treated as standing on for its status.
        if status == RailAgentStatus.READY_TO_DEPART:
            cell = agent.initial_position
        elif status == RailAgentStatus.ACTIVE:
            cell = agent.position
        elif status == RailAgentStatus.DONE:
            cell = agent.target
        else:
            return None

        possible_transitions = self.env.rail.get_transitions(*cell, agent.direction)

        # If only one transition is possible, the forward branch is aligned with it.
        if np.count_nonzero(possible_transitions) == 1:
            return [0, 1, 0]

        def distance_via(direction):
            # Distance-map entry for the cell reached by leaving `cell` in `direction`.
            row, col = get_new_position(cell, direction)
            return self.env.distance_map.get()[handle, row, col, direction]

        # Candidate directions organized as [left, forward, right] relative to the
        # current orientation; unavailable branches count as infinitely far away.
        candidates = [(agent.direction + offset) % 4 for offset in range(-1, 2)]
        distances = [
            distance_via(direction) if possible_transitions[direction] else np.inf
            for direction in candidates
        ]

        observation = [0, 0, 0]
        observation[np.argmin(distances)] = 1
        return observation


def test_malfunction_process():
    """
    Check that a malfunctioning agent does not move and that malfunctions recur.

    A single agent follows the branch flagged by SingleAgentNavigationObs for
    100 steps. Whenever its malfunction counter is positive after a step, its
    position must equal the position recorded after the previous step. At the
    end the agent must have registered exactly 30 malfunctions and a positive
    accumulated down time.
    """
    # Set fixed malfunction duration for this test (min == max)
    stochastic_data = {'malfunction_rate': 1,
                       'min_duration': 3,
                       'max_duration': 3}

    rail, _ = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    obs, _ = env.reset(False, False, True, random_seed=10)

    total_down_time = 0
    agent_old_position = env.agents[0].position

    # Move target to unreachable position in order to not interfere with test
    env.agents[0].target = (0, 0)

    for _step in range(100):
        # Observation index 0..2 (left/forward/right) maps to action value 1..3.
        actions = {i: np.argmax(obs[i]) + 1 for i in range(len(obs))}

        obs, _, _, _ = env.step(actions)

        agent = env.agents[0]
        remaining_malfunction = agent.malfunction_data['malfunction']

        if remaining_malfunction > 0:
            # Check that agent is not moving while malfunctioning
            assert agent_old_position == agent.position

        agent_old_position = agent.position
        total_down_time += remaining_malfunction

    # Check that the appropriate number of malfunctions is achieved
    nr_malfunctions = env.agents[0].malfunction_data['nr_malfunctions']
    assert nr_malfunctions == 30, "Actual {}".format(nr_malfunctions)

    # Check that the agent actually spent time malfunctioning
    assert total_down_time > 0
u214892's avatar
u214892 committed
118
119
120
121
122


def test_malfunction_process_statistically():
    """Tests that malfunctions are produced by stochastic_data!"""
    # Set fixed malfunction duration for this test (min == max)
    stochastic_data = {'malfunction_rate': 5,
                       'min_duration': 5,
                       'max_duration': 5}

    rail, _ = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=10,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )

    env.reset(True, True, False, random_seed=10)

    env.agents[0].target = (0, 0)

    # expected_malfunctions[step][agent_idx] is the malfunction counter read
    # before env.step() of that step.
    # For regenerating the table only: start from [[] for _ in range(20)], append
    # env.agents[agent_idx].malfunction_data['malfunction'] in the inner loop
    # instead of asserting, and print the result after the outer loop.
    expected_malfunctions = [
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 4],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 3],
        [4, 0, 0, 0, 0, 0, 0, 0, 0, 2],
        [3, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [2, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 4, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 3, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 2, 4, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 3, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 2, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 4],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 3],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 2],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [4, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    ]

    for expected_row in expected_malfunctions:
        action_dict: Dict[int, RailEnvActions] = {}
        for agent_idx in range(env.get_num_agents()):
            # We randomly select an action
            action_dict[agent_idx] = RailEnvActions(np.random.randint(4))
            current = env.agents[agent_idx].malfunction_data['malfunction']
            assert current == expected_row[agent_idx]
        env.step(action_dict)
169

u214892's avatar
u214892 committed
170

171
def test_malfunction_before_entry():
    """Tests that malfunctions are working properly for agents before entering the environment!"""
    rail, _ = make_simple_rail2()

    # Set fixed malfunction duration for this test (min == max)
    malfunction_config = {'malfunction_rate': 1,
                          'min_duration': 10,
                          'max_duration': 10}

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=2),  # seed 12
                  number_of_agents=10,
                  random_seed=1,
                  stochastic_data=malfunction_config,  # Malfunction data generator
                  )
    env.reset(False, False, False, random_seed=10)

    # Test initial malfunction values right after reset, keyed by agent index.
    # We want some agents to be malfunctioning already and some to be working.
    # Agent 0 is not checked here.
    expected_initial_malfunction = {
        1: 0,
        2: 0,
        3: 0,
        4: 0,
        5: 0,
        6: 0,
        7: 0,
        8: 0,
        9: 9,
    }
    for agent_idx, expected in expected_initial_malfunction.items():
        assert env.agents[agent_idx].malfunction_data['malfunction'] == expected
203

204

205
206
207
208
209
210
211
212
213
214
215
def test_malfunction_values_and_behavior():
    """
    Test that the next malfunction occurs when desired.

    After every environment step, the single agent's
    malfunction_data['malfunction'] value is compared against a fixed
    sequence of expected values.
    """
    rail, _ = make_simple_rail2()
    # No actions are submitted: the same empty dict is passed to every step.
    action_dict: Dict[int, RailEnvActions] = {}
    # Set fixed malfunction duration for this test (min == max)
    stochastic_data = {'malfunction_rate': 5,
                       'min_duration': 10,
                       'max_duration': 10}
    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=2),  # seed 12
                  stochastic_data=stochastic_data,
                  number_of_agents=1,
                  random_seed=1,
                  )

    env.reset(False, False, activate_agents=True, random_seed=10)

    # Expected malfunction value after each successive step
    assert_list = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 0, 9, 8, 7, 6]
    for expected in assert_list:
        # Move in the env
        env.step(action_dict)
        # Check that the malfunction value follows the expected sequence
        assert env.agents[0].malfunction_data['malfunction'] == expected
239

240

241
def test_initial_malfunction():
    """Replay an agent that is set to malfunction on its first step while always requesting MOVE_FORWARD."""
    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 100,  # Rate of malfunction occurrence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, _ = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=10),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    env.reset(False, False, True, random_seed=10)
    print(env.agents[0].malfunction_data)
    env.agents[0].target = (0, 5)
    set_penalties_for_replay(env)

    def forward_step(position, malfunction, reward, **extra):
        # Every step of this replay heads EAST and requests MOVE_FORWARD.
        return Replay(
            position=position,
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=malfunction,
            reward=reward,
            **extra
        )

    steps = [
        # full step penalty when malfunctioning
        forward_step((3, 2), 3, env.step_penalty, set_malfunction=3),
        # full step penalty when malfunctioning
        forward_step((3, 2), 2, env.step_penalty),
        # malfunction stops in the next step and we're still at the beginning of the cell
        # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
        forward_step((3, 2), 1, env.step_penalty * 1.0),
        # malfunctioning ends: starting and running at speed 1.0
        forward_step((3, 2), 0, env.start_penalty + env.step_penalty * 1.0),
        # running at speed 1.0
        forward_step((3, 3), 0, env.step_penalty * 1.0),
    ]

    replay_config = ReplayConfig(
        replay=steps,
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [replay_config])
311
312
313


def test_initial_malfunction_stop_moving():
    """Replay an agent that malfunctions before departing and is told to stop or do nothing until it ends."""
    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurrence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, _ = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    env.reset()

    print(env.agents[0].initial_position, env.agents[0].direction, env.agents[0].position, env.agents[0].status)

    set_penalties_for_replay(env)

    east = Grid4TransitionsEnum.EAST

    def active_step(position, action, malfunction, reward):
        # Replay step for an agent that is ACTIVE and heading EAST.
        return Replay(
            position=position,
            direction=east,
            action=action,
            malfunction=malfunction,
            reward=reward,
            status=RailAgentStatus.ACTIVE
        )

    steps = [
        # not yet on the grid: the malfunction is set here, full step penalty
        Replay(
            position=None,
            direction=east,
            action=RailEnvActions.MOVE_FORWARD,
            set_malfunction=3,
            malfunction=3,
            reward=env.step_penalty,
            status=RailAgentStatus.READY_TO_DEPART
        ),
        # full step penalty when stopped
        active_step((3, 2), RailEnvActions.DO_NOTHING, 3, env.step_penalty),
        # malfunction stops in the next step and we're still at the beginning of the cell
        # --> if we take action STOP_MOVING, agent should restart without moving
        active_step((3, 2), RailEnvActions.STOP_MOVING, 2, env.step_penalty),
        # we have stopped and do nothing --> should stand still
        active_step((3, 2), RailEnvActions.DO_NOTHING, 1, env.step_penalty),
        # we start to move forward --> should go to next cell now
        active_step((3, 2), RailEnvActions.MOVE_FORWARD, 0, env.start_penalty + env.step_penalty * 1.0),
        # moving on: step penalty only
        active_step((3, 3), RailEnvActions.MOVE_FORWARD, 0, env.step_penalty * 1.0),
    ]

    replay_config = ReplayConfig(
        replay=steps,
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=east,
    )

    run_replay_config(env, [replay_config], activate_agents=False)
399
400


401
def test_initial_malfunction_do_nothing():
    """Replay an agent that malfunctions before departing and is given DO_NOTHING until the malfunction ends."""
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurrence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, _ = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  )
    env.reset()
    set_penalties_for_replay(env)

    east = Grid4TransitionsEnum.EAST

    def active_step(position, action, malfunction, reward):
        # Replay step for an agent that is ACTIVE and heading EAST.
        return Replay(
            position=position,
            direction=east,
            action=action,
            malfunction=malfunction,
            reward=reward,
            status=RailAgentStatus.ACTIVE
        )

    steps = [
        # not yet on the grid: full step penalty while malfunctioning
        Replay(
            position=None,
            direction=east,
            action=RailEnvActions.MOVE_FORWARD,
            set_malfunction=3,
            malfunction=3,
            reward=env.step_penalty,
            status=RailAgentStatus.READY_TO_DEPART
        ),
        # full step penalty while malfunctioning
        active_step((3, 2), RailEnvActions.DO_NOTHING, 3, env.step_penalty),
        # malfunction stops in the next step and we're still at the beginning of the cell
        # --> if we take action DO_NOTHING, agent should restart without moving
        active_step((3, 2), RailEnvActions.DO_NOTHING, 2, env.step_penalty),
        # we haven't started moving yet --> stay here
        active_step((3, 2), RailEnvActions.DO_NOTHING, 1, env.step_penalty),
        # we start to move forward --> should go to next cell now
        # (start penalty + step penalty for speed 1.0)
        active_step((3, 2), RailEnvActions.MOVE_FORWARD, 0, env.start_penalty + env.step_penalty * 1.0),
        # step penalty for speed 1.0
        active_step((3, 3), RailEnvActions.MOVE_FORWARD, 0, env.step_penalty * 1.0),
    ]

    replay_config = ReplayConfig(
        replay=steps,
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=east,
    )
    run_replay_config(env, [replay_config], activate_agents=False)
486
487


Erik Nygren's avatar
Erik Nygren committed
488
489
490
def _make_interference_env(rail, stochastic_data):
    """Build and reset the single-agent env used by tests_random_interference_from_outside."""
    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=2),  # seed 12
                  number_of_agents=1,
                  random_seed=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  )
    # First reset creates the agents so that the speed can be set on agent 0.
    env.reset()
    env.agents[0].speed_data['speed'] = 0.33
    env.reset(False, False, False, random_seed=10)
    return env


def tests_random_interference_from_outside():
    """
    Tests that random number generation outside the env does not alter a run.

    The same 200-step episode is executed twice with a fixed action. The first
    run records the reward and position of agent 0 after every step. The
    second run seeds and repeatedly draws from `random` and `np.random` while
    stepping, and must reproduce the recorded rewards and positions exactly.
    """
    num_steps = 200

    # Set fixed malfunction duration for this test (min == max)
    stochastic_data = {'malfunction_rate': 1,
                       'min_duration': 10,
                       'max_duration': 10}

    # Reference run: record (reward, position) of agent 0 after every step.
    rail, _ = make_simple_rail2()
    env = _make_interference_env(rail, stochastic_data)

    env_data = []
    for _step in range(num_steps):
        # Every agent always gets the same fixed action (value 2).
        action_dict: Dict[int, RailEnvActions] = {agent.handle: RailEnvActions(2) for agent in env.agents}

        _, reward, _, _ = env.step(action_dict)
        env_data.append((reward[0], env.agents[0].position))

    # Run the same test as above but with an external random generator running
    # Check that the reward and position stay the same
    rail, _ = make_simple_rail2()
    random.seed(47)
    np.random.seed(1234)
    env = _make_interference_env(rail, stochastic_data)

    dummy_list = [1, 2, 6, 7, 8, 9, 4, 5, 4]
    for expected_reward, expected_position in env_data:
        action_dict = {}
        for agent in env.agents:
            action_dict[agent.handle] = RailEnvActions(2)

            # Do dummy random number generations
            random.shuffle(dummy_list)
            np.random.rand()

        _, reward, _, _ = env.step(action_dict)
        assert reward[0] == expected_reward
        assert env.agents[0].position == expected_position
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595


def test_last_malfunction_step():
    """
    Test to check that agent moves when it is not malfunctioning

    Before each of 20 steps the agent's malfunction data is inspected to decide
    whether it is expected to move. Its speed_data['position_fraction'] must
    then change across the step exactly when movement was expected.
    """

    # Set fixed malfunction duration for this test (min == max)
    stochastic_data = {'prop_malfunction': 1.,
                       'malfunction_rate': 5,
                       'min_duration': 4,
                       'max_duration': 4}

    rail, _ = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=2),  # seed 12
                  number_of_agents=1,
                  random_seed=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  )
    # First reset creates the agents so that speed and target can be set.
    env.reset()
    env.agents[0].speed_data['speed'] = 0.33
    env.agents_static[0].target = (0, 0)

    env.reset(False, False, True)
    # Force malfunction to be off at beginning and next malfunction to happen in 2 steps
    env.agents[0].malfunction_data['next_malfunction'] = 2
    env.agents[0].malfunction_data['malfunction'] = 0

    for _step in range(20):
        # Go forward all the time
        action_dict: Dict[int, RailEnvActions] = {agent.handle: RailEnvActions(2) for agent in env.agents}

        # Check if the agent is still allowed to move in this step: it must not
        # be malfunctioning and the next malfunction must not be due yet.
        malfunction_data = env.agents[0].malfunction_data
        agent_can_move = not (malfunction_data['malfunction'] > 0 or malfunction_data['next_malfunction'] < 1)

        # Store the position fraction before and after the step
        pre_position = env.agents[0].speed_data['position_fraction']
        env.step(action_dict)
        post_position = env.agents[0].speed_data['position_fraction']

        # Assert that the agent moved while it was still allowed
        if agent_can_move:
            assert pre_position != post_position
        else:
            assert post_position == pre_position