test_flatland_malfunction.py 23.4 KB
Newer Older
u214892's avatar
u214892 committed
1
import random
2
from typing import Dict, List
u214892's avatar
u214892 committed
3

4
import numpy as np
5
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
6

7
from flatland.core.env_observation_builder import ObservationBuilder
u214892's avatar
u214892 committed
8
from flatland.core.grid.grid4 import Grid4TransitionsEnum
9
from flatland.core.grid.grid4_utils import get_new_position
u214892's avatar
u214892 committed
10
from flatland.envs.agent_utils import RailAgentStatus
11
from flatland.envs.malfunction_generators import malfunction_from_params
u214892's avatar
u214892 committed
12
from flatland.envs.rail_env import RailEnv, RailEnvActions
13
14
15
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.simple_rail import make_simple_rail2
16
17


18
class SingleAgentNavigationObs(ObservationBuilder):
    """
    Observation builder producing a 3-element binary vector per agent, marking
    which of the relative branches (Left, Forward, Right) lies on the shortest
    path to the agent's target.
    E.g. if taking the Left branch (when available) is the shortest route to the
    target, the observation vector will be [1, 0, 0].
    """

    def __init__(self):
        super().__init__()

    def reset(self):
        # No per-episode state to re-initialize.
        pass

    def get(self, handle: int = 0) -> List[int]:
        agent = self.env.agents[handle]

        # Choose the cell the agent effectively occupies for its lifecycle state.
        if agent.status == RailAgentStatus.READY_TO_DEPART:
            virtual_position = agent.initial_position
        elif agent.status == RailAgentStatus.ACTIVE:
            virtual_position = agent.position
        elif agent.status == RailAgentStatus.DONE:
            virtual_position = agent.target
        else:
            # Agent no longer on the grid: no observation to build.
            return None

        transitions = self.env.rail.get_transitions(*virtual_position, agent.direction)

        # A single possible transition is, by convention, the forward branch.
        if np.count_nonzero(transitions) == 1:
            return [0, 1, 0]

        # Distance to target for each relative branch [left, forward, right];
        # branches that are not available count as infinitely far.
        branch_distances = []
        for offset in (-1, 0, 1):
            heading = (agent.direction + offset) % 4
            if transitions[heading]:
                next_cell = get_new_position(virtual_position, heading)
                branch_distances.append(
                    self.env.distance_map.get()[handle, next_cell[0], next_cell[1], heading])
            else:
                branch_distances.append(np.inf)

        observation = [0, 0, 0]
        observation[np.argmin(branch_distances)] = 1
        return observation


def test_malfunction_process():
    """Check that a malfunctioning agent stands still and that malfunctions accumulate as expected.

    Runs 100 steps with a fixed 3-step malfunction duration and asserts
    (a) the agent never changes position while malfunctioning,
    (b) the seeded run produces exactly 23 malfunctions with non-zero total down time.
    """
    # Set fixed malfunction duration for this test
    stochastic_data = {'malfunction_rate': 1,
                       'min_duration': 3,
                       'max_duration': 3}

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    obs, info = env.reset(False, False, True, random_seed=10)

    total_down_time = 0
    agent_old_position = env.agents[0].position

    # Move target to unreachable position in order to not interfere with test
    env.agents[0].target = (0, 0)
    for step in range(100):
        actions = {}

        # Follow the shortest-path observation: argmax over [left, forward, right] maps to actions 1..3.
        for i in range(len(obs)):
            actions[i] = np.argmax(obs[i]) + 1

        obs, all_rewards, done, _ = env.step(actions)

        agent_malfunctioning = env.agents[0].malfunction_data['malfunction'] > 0

        if agent_malfunctioning:
            # Check that agent is not moving while malfunctioning
            assert agent_old_position == env.agents[0].position

        agent_old_position = env.agents[0].position
        total_down_time += env.agents[0].malfunction_data['malfunction']

    # Check that the appropriate number of malfunctions is achieved
    assert env.agents[0].malfunction_data['nr_malfunctions'] == 23, "Actual {}".format(
        env.agents[0].malfunction_data['nr_malfunctions'])

    # Check that malfunctioning data was standing around
    assert total_down_time > 0
u214892's avatar
u214892 committed
118
119
120
121
122


def test_malfunction_process_statistically():
    """Tests that malfunctions are produced by stochastic_data!

    With a fixed seed, compares each agent's per-step malfunction countdown
    against a pre-recorded expectation table for 20 steps.
    """
    # Set fixed malfunction duration for this test
    stochastic_data = {'malfunction_rate': 5,
                       'min_duration': 5,
                       'max_duration': 5}

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=10,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  obs_builder_object=SingleAgentNavigationObs()
                  )

    # Fixed seed so the malfunction draws below are reproducible.
    env.reset(True, True, False, random_seed=10)

    env.agents[0].target = (0, 0)
    # Next line only for test generation
    # agent_malfunction_list = [[] for i in range(10)]
    # Expected malfunction countdown per agent (rows) and step (columns),
    # recorded from a reference run with random_seed=10.
    agent_malfunction_list = [[0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1],
     [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 0],
     [5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0],
     [0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1],
     [0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1],
     [0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2],
     [5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [5, 4, 3, 2, 1, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0]]

    for step in range(20):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent_idx in range(env.get_num_agents()):
            # We randomly select an action
            action_dict[agent_idx] = RailEnvActions(np.random.randint(4))
            # For generating tests only:
            # agent_malfunction_list[agent_idx].append(env.agents[agent_idx].malfunction_data['malfunction'])
            assert env.agents[agent_idx].malfunction_data['malfunction'] == agent_malfunction_list[agent_idx][step]
        env.step(action_dict)
    # print(agent_malfunction_list)
164

u214892's avatar
u214892 committed
165

166
def test_malfunction_before_entry():
    """Tests that malfunctions are working properly for agents before entering the environment!"""
    # Set fixed malfunction duration for this test
    stochastic_data = {'malfunction_rate': 2,
                       'min_duration': 10,
                       'max_duration': 10}

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=10,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # Fixed seed: the per-agent expectations below were recorded for random_seed=10.
    env.reset(False, False, False, random_seed=10)
    env.agents[0].target = (0, 0)

    # Test initial malfunction values for all agents
    # we want some agents to be malfunctioning already and some to be working
    # we want different next_malfunction values for the agents
    assert env.agents[0].malfunction_data['malfunction'] == 0
    assert env.agents[1].malfunction_data['malfunction'] == 0
    assert env.agents[2].malfunction_data['malfunction'] == 10
    assert env.agents[3].malfunction_data['malfunction'] == 0
    assert env.agents[4].malfunction_data['malfunction'] == 0
    assert env.agents[5].malfunction_data['malfunction'] == 0
    assert env.agents[6].malfunction_data['malfunction'] == 0
    assert env.agents[7].malfunction_data['malfunction'] == 0
    assert env.agents[8].malfunction_data['malfunction'] == 10
    assert env.agents[9].malfunction_data['malfunction'] == 10

    # Generator snippet used to produce the assertions above:
    # for a in range(10):
    #  print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a,env.agents[a].malfunction_data['malfunction']))
202
203


204
205
def test_malfunction_values_and_behavior():
    """
    Test that the malfunction counter counts down as desired.

    With a forced initial malfunction, the agent's 'malfunction' value must
    follow the pre-recorded countdown in `assert_list` over 15 steps
    (a second malfunction starts at step 10 in this seeded run).
    """
    # Set fixed malfunction duration for this test

    rail, rail_map = make_simple_rail2()
    action_dict: Dict[int, RailEnvActions] = {}
    stochastic_data = {'malfunction_rate': 0.001,
                       'min_duration': 10,
                       'max_duration': 10}
    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  obs_builder_object=SingleAgentNavigationObs()
                  )

    env.reset(False, False, activate_agents=True, random_seed=10)

    # Expected malfunction countdown per step for random_seed=10.
    assert_list = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 10, 9, 8, 7, 6, 5]
    for time_step in range(15):
        # Move in the env (empty action dict: agents receive no action).
        env.step(action_dict)
        # Check that the countdown decreases as expected
        assert env.agents[0].malfunction_data['malfunction'] == assert_list[time_step]
237

238

239
def test_initial_malfunction():
    """Replay test: an agent that starts in malfunction stays in place until the
    countdown reaches zero, then moves forward, with the expected penalties."""
    stochastic_data = {'malfunction_rate': 1000,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=10),
                  number_of_agents=1,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # reset to initialize agents_static
    env.reset(False, False, True, random_seed=10)
    env.agents[0].target = (0, 5)
    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty  # full step penalty when malfunctioning
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=2,
                reward=env.step_penalty  # full step penalty when malfunctioning
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=1,
                reward=env.step_penalty

            ),  # malfunctioning ends: starting and running at speed 1.0
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0  # running at speed 1.0
            ),
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty  # running at speed 1.0
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [replay_config])
308
309
310


def test_initial_malfunction_stop_moving():
311
312
313
314
315
316
    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

317
    rail, rail_map = make_simple_rail2()
318

319
320
    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(), number_of_agents=1,
Erik Nygren's avatar
Erik Nygren committed
321
                  obs_builder_object=SingleAgentNavigationObs())
322
    env.reset()
323
324
325

    print(env.agents[0].initial_position, env.agents[0].direction, env.agents[0].position, env.agents[0].status)

326
    set_penalties_for_replay(env)
327
328
329
    replay_config = ReplayConfig(
        replay=[
            Replay(
u214892's avatar
u214892 committed
330
                position=None,
331
                direction=Grid4TransitionsEnum.EAST,
u214892's avatar
u214892 committed
332
                action=RailEnvActions.MOVE_FORWARD,
333
334
                set_malfunction=3,
                malfunction=3,
u214892's avatar
u214892 committed
335
336
                reward=env.step_penalty,  # full step penalty when stopped
                status=RailAgentStatus.READY_TO_DEPART
337
338
            ),
            Replay(
339
                position=(3, 2),
340
341
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
342
                malfunction=2,
u214892's avatar
u214892 committed
343
344
                reward=env.step_penalty,  # full step penalty when stopped
                status=RailAgentStatus.ACTIVE
345
346
347
348
349
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action STOP_MOVING, agent should restart without moving
            #
            Replay(
350
                position=(3, 2),
351
352
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.STOP_MOVING,
353
                malfunction=1,
u214892's avatar
u214892 committed
354
355
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
356
357
358
            ),
            # we have stopped and do nothing --> should stand still
            Replay(
359
                position=(3, 2),
360
361
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
362
                malfunction=0,
u214892's avatar
u214892 committed
363
364
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
365
366
367
            ),
            # we start to move forward --> should go to next cell now
            Replay(
368
                position=(3, 2),
369
370
371
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
u214892's avatar
u214892 committed
372
373
                reward=env.start_penalty + env.step_penalty * 1.0,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
374
375
            ),
            Replay(
376
                position=(3, 3),
377
                direction=Grid4TransitionsEnum.EAST,
378
379
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
u214892's avatar
u214892 committed
380
381
                reward=env.step_penalty * 1.0,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
382
383
384
            )
        ],
        speed=env.agents[0].speed_data['speed'],
u214892's avatar
u214892 committed
385
        target=env.agents[0].target,
386
        initial_position=(3, 2),
u214892's avatar
u214892 committed
387
        initial_direction=Grid4TransitionsEnum.EAST,
388
    )
389
390

    run_replay_config(env, [replay_config], activate_agents=False)
391
392


393
def test_initial_malfunction_do_nothing():
    """Replay test: a malfunctioning agent that keeps receiving DO_NOTHING stays
    in place after the malfunction ends, and only moves once MOVE_FORWARD is given."""
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),  # Malfunction data generator
                  )
    env.reset()
    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=None,
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty,  # full step penalty while malfunctioning
                status=RailAgentStatus.READY_TO_DEPART
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=2,
                reward=env.step_penalty,  # full step penalty while malfunctioning
                status=RailAgentStatus.ACTIVE
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action DO_NOTHING, agent should restart without moving
            #
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=1,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we haven't started moving yet --> stay here
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=0,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),

            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            ),  # we start to move forward --> should go to next cell now
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [replay_config], activate_agents=False)
477
478


Erik Nygren's avatar
Erik Nygren committed
479
480
481
def tests_random_interference_from_outside():
    """Tests that the environment is deterministic under its own seeds: running
    the same seeded episode twice — the second time while an external consumer
    draws from `random` and `np.random` — must yield identical rewards and
    agent positions each step."""
    # Set fixed malfunction duration for this test
    stochastic_data = {'malfunction_rate': 1,
                       'min_duration': 10,
                       'max_duration': 10}

    rail, rail_map = make_simple_rail2()
    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=2), number_of_agents=1, random_seed=1)
    env.reset()
    env.agents[0].speed_data['speed'] = 0.33
    env.reset(False, False, False, random_seed=10)
    # Reference trace: (reward, position) per step from the first run.
    env_data = []

    for step in range(200):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # We randomly select an action
            action_dict[agent.handle] = RailEnvActions(2)

        _, reward, _, _ = env.step(action_dict)
        # Append the rewards of the first trial
        env_data.append((reward[0], env.agents[0].position))
        assert reward[0] == env_data[step][0]
        assert env.agents[0].position == env_data[step][1]
    # Run the same test as above but with an external random generator running
    # Check that the reward stays the same

    rail, rail_map = make_simple_rail2()
    random.seed(47)
    np.random.seed(1234)
    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=2), number_of_agents=1, random_seed=1)
    env.reset()
    env.agents[0].speed_data['speed'] = 0.33
    env.reset(False, False, False, random_seed=10)

    dummy_list = [1, 2, 6, 7, 8, 9, 4, 5, 4]
    for step in range(200):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # We randomly select an action
            action_dict[agent.handle] = RailEnvActions(2)

            # Do dummy random number generations
            random.shuffle(dummy_list)
            np.random.rand()

        # The second run must reproduce the reference trace exactly.
        _, reward, _, _ = env.step(action_dict)
        assert reward[0] == env_data[step][0]
        assert env.agents[0].position == env_data[step][1]
531
532
533
534
535
536
537
538
539


def test_last_malfunction_step():
    """
    Test to check that agent moves when it is not malfunctioning

    Tracks `position_fraction` before and after each step: it must change while
    the agent is healthy and stay constant while the agent is malfunctioning.
    """

    # Set fixed malfunction duration for this test
    # NOTE(review): stochastic_data is not passed to RailEnv below — malfunctions
    # come from the forced malfunction_data writes after reset; confirm intent.
    stochastic_data = {'malfunction_rate': 5,
                       'min_duration': 4,
                       'max_duration': 4}

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=2), number_of_agents=1, random_seed=1)
    env.reset()
    # Fractional speed: the agent needs 3 steps to cross a cell.
    env.agents[0].speed_data['speed'] = 1. / 3.
    env.agents[0].target = (0, 0)

    env.reset(False, False, True)
    # Force malfunction to be off at beginning and next malfunction to happen in 2 steps
    env.agents[0].malfunction_data['next_malfunction'] = 2
    env.agents[0].malfunction_data['malfunction'] = 0
    # NOTE(review): env_data is never used in this test.
    env_data = []
    for step in range(20):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # Go forward all the time
            action_dict[agent.handle] = RailEnvActions(2)

        # Agent is allowed to move if it enters this step without a malfunction.
        if env.agents[0].malfunction_data['malfunction'] < 1:
            agent_can_move = True
        # Store the position before and after the step
        pre_position = env.agents[0].speed_data['position_fraction']
        _, reward, _, _ = env.step(action_dict)
        # Check if the agent is still allowed to move in this step

        if env.agents[0].malfunction_data['malfunction'] > 0:
            agent_can_move = False
        post_position = env.agents[0].speed_data['position_fraction']
        # Assert that the agent moved while it was still allowed
        if agent_can_move:
            assert pre_position != post_position
        else:
            assert post_position == pre_position