import random
from typing import Dict, List

import numpy as np
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay

from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.malfunction_generators import malfunction_from_params
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.simple_rail import make_simple_rail2
class SingleAgentNavigationObs(ObservationBuilder):
    """
    We build a representation vector with 3 binary components, indicating which of the 3 available directions
    for each agent (Left, Forward, Right) lead to the shortest path to its target.
    E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
    will be [1, 0, 0].
    """

    def __init__(self):
        super().__init__()

    def reset(self):
        # No per-episode state to reset; the distance map is queried lazily in get().
        pass

    def get(self, handle: int = 0) -> List[int]:
        """Return the [left, forward, right] shortest-path indicator for one agent.

        Parameters
        ----------
        handle: index of the agent in ``self.env.agents``.

        Returns
        -------
        A 3-element 0/1 list, or ``None`` for agents that are neither ready to
        depart, active, nor done (e.g. already removed from the grid).
        """
        agent = self.env.agents[handle]

        # Pick the cell from which to evaluate transitions, depending on the
        # agent's lifecycle status.
        if agent.status == RailAgentStatus.READY_TO_DEPART:
            agent_virtual_position = agent.initial_position
        elif agent.status == RailAgentStatus.ACTIVE:
            agent_virtual_position = agent.position
        elif agent.status == RailAgentStatus.DONE:
            agent_virtual_position = agent.target
        else:
            return None

        possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
        num_transitions = np.count_nonzero(possible_transitions)

        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right], relative to the current orientation
        # If only one transition is possible, the forward branch is aligned with it.
        if num_transitions == 1:
            observation = [0, 1, 0]
        else:
            min_distances = []
            # (direction - 1, direction, direction + 1) mod 4 == (left, forward, right)
            for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                if possible_transitions[direction]:
                    new_position = get_new_position(agent_virtual_position, direction)
                    min_distances.append(
                        self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
                else:
                    # Blocked branch: infinite distance so argmin never picks it.
                    min_distances.append(np.inf)

            observation = [0, 0, 0]
            observation[np.argmin(min_distances)] = 1

        return observation
def test_malfunction_process():
    """Test that an agent stands still while malfunctioning and accumulates down time."""
    # Set fixed malfunction duration for this test
    stochastic_data = {'malfunction_rate': 1,
                       'min_duration': 3,
                       'max_duration': 3}

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(), number_of_agents=1,
                  obs_builder_object=SingleAgentNavigationObs(),
                  malfunction_generator=malfunction_from_params(stochastic_data))
    # reset to initialize agents_static
    obs, info = env.reset(False, False, True, random_seed=10)

    total_down_time = 0
    agent_old_position = env.agents[0].position

    # Move target to unreachable position in order to not interfere with test
    env.agents[0].target = (0, 0)
    for step in range(100):
        actions = {}

        for i in range(len(obs)):
            # Observation is [left, forward, right]; +1 maps argmax onto
            # RailEnvActions MOVE_LEFT/MOVE_FORWARD/MOVE_RIGHT.
            actions[i] = np.argmax(obs[i]) + 1

        obs, all_rewards, done, _ = env.step(actions)

        if env.agents[0].malfunction_data['malfunction'] > 0:
            agent_malfunctioning = True
        else:
            agent_malfunctioning = False

        if agent_malfunctioning:
            # Check that agent is not moving while malfunctioning
            assert agent_old_position == env.agents[0].position

        agent_old_position = env.agents[0].position
        total_down_time += env.agents[0].malfunction_data['malfunction']

    # Check that the appropriate number of malfunctions is achieved
    assert env.agents[0].malfunction_data['nr_malfunctions'] == 23, "Actual {}".format(
        env.agents[0].malfunction_data['nr_malfunctions'])

    # Check that malfunctioning data was standing around
    assert total_down_time > 0
def test_malfunction_process_statistically():
    """Tests that malfunctions are produced by stochastic_data!"""
    # Set fixed malfunction duration for this test
    stochastic_data = {'malfunction_rate': 5,
                       'min_duration': 5,
                       'max_duration': 5}

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(), number_of_agents=10,
                  obs_builder_object=SingleAgentNavigationObs(),
                  malfunction_generator=malfunction_from_params(stochastic_data))

    # reset to initialize agents_static
    env.reset(True, True, False, random_seed=10)

    env.agents[0].target = (0, 0)
    # Next line only for test generation
    # agent_malfunction_list = [[] for i in range(10)]
    # Expected malfunction countdown per agent per step (golden data generated
    # with the commented-out recording code below).
    agent_malfunction_list = [[0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1],
                              [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 0],
                              [5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0],
                              [0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1],
                              [0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0],
                              [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                              [0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1],
                              [0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2],
                              [5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                              [5, 4, 3, 2, 1, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0]]

    for step in range(20):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent_idx in range(env.get_num_agents()):
            # We randomly select an action
            action_dict[agent_idx] = RailEnvActions(np.random.randint(4))
            # For generating tests only:
            # agent_malfunction_list[agent_idx].append(env.agents[agent_idx].malfunction_data['malfunction'])
            assert env.agents[agent_idx].malfunction_data['malfunction'] == agent_malfunction_list[agent_idx][step]
        env.step(action_dict)
    # print(agent_malfunction_list)
def test_malfunction_before_entry():
    """Tests that malfunctions are working properly for agents before entering the environment!"""
    # Set fixed malfunction duration for this test
    stochastic_data = {'malfunction_rate': 2,
                       'min_duration': 10,
                       'max_duration': 10}

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=1), number_of_agents=10,
                  malfunction_generator=malfunction_from_params(stochastic_data), random_seed=1)
    # reset to initialize agents_static
    env.reset(False, False, False, random_seed=10)
    env.agents[0].target = (0, 0)

    # Test initial malfunction values for all agents
    # we want some agents to be malfunctioning already and some to be working
    # we want different next_malfunction values for the agents
    assert env.agents[0].malfunction_data['malfunction'] == 0
    assert env.agents[1].malfunction_data['malfunction'] == 0
    assert env.agents[2].malfunction_data['malfunction'] == 10
    assert env.agents[3].malfunction_data['malfunction'] == 0
    assert env.agents[4].malfunction_data['malfunction'] == 0
    assert env.agents[5].malfunction_data['malfunction'] == 0
    assert env.agents[6].malfunction_data['malfunction'] == 0
    assert env.agents[7].malfunction_data['malfunction'] == 0
    assert env.agents[8].malfunction_data['malfunction'] == 10
    assert env.agents[9].malfunction_data['malfunction'] == 10

    # Golden values above were generated with:
    # for a in range(10):
    #  print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a,env.agents[a].malfunction_data['malfunction']))
def test_malfunction_values_and_behavior():
    """
    Test the malfunction counts down as desired

    Returns
    -------

    """
    # Set fixed malfunction duration for this test

    rail, rail_map = make_simple_rail2()
    action_dict: Dict[int, RailEnvActions] = {}
    stochastic_data = {'malfunction_rate': 0.001,
                       'min_duration': 10,
                       'max_duration': 10}
    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=2), number_of_agents=1,
                  malfunction_generator=malfunction_from_params(stochastic_data), random_seed=1)

    # reset to initialize agents_static
    env.reset(False, False, activate_agents=True, random_seed=10)

    # Expected per-step countdown: 9..0, then a new 10-step malfunction starts.
    assert_list = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 10, 9, 8, 7, 6, 5]
    for time_step in range(15):
        # Move in the env (empty action dict: agents keep their default behavior)
        env.step(action_dict)
        # Check that next_step decreases as expected
        assert env.agents[0].malfunction_data['malfunction'] == assert_list[time_step]
def test_initial_malfunction():
    """Replay test: an agent that starts in malfunction stands still until it ends, then moves."""
    stochastic_data = {'malfunction_rate': 1000,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=10), number_of_agents=1,
                  obs_builder_object=SingleAgentNavigationObs(),
                  malfunction_generator=malfunction_from_params(stochastic_data))
    # reset to initialize agents_static
    env.reset(False, False, True, random_seed=10)
    env.agents[0].target = (0, 5)
    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty  # full step penalty when malfunctioning
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=2,
                reward=env.step_penalty  # full step penalty when malfunctioning
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=1,
                reward=env.step_penalty

            ),  # malfunctioning ends: starting and running at speed 1.0
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0  # running at speed 1.0
            ),
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty  # running at speed 1.0
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [replay_config])
def test_initial_malfunction_stop_moving():
    """Replay test: STOP_MOVING/DO_NOTHING during and after an initial malfunction keep the agent in place."""
    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(), number_of_agents=1,
                  obs_builder_object=SingleAgentNavigationObs(),
                  malfunction_generator=malfunction_from_params(stochastic_data))
    env.reset()

    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=None,
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty,  # full step penalty when stopped
                status=RailAgentStatus.READY_TO_DEPART
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=2,
                reward=env.step_penalty,  # full step penalty when stopped
                status=RailAgentStatus.ACTIVE
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action STOP_MOVING, agent should restart without moving
            #
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.STOP_MOVING,
                malfunction=1,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we have stopped and do nothing --> should stand still
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=0,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we start to move forward --> should go to next cell now
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )

    run_replay_config(env, [replay_config], activate_agents=False)
def test_initial_malfunction_do_nothing():
    """Replay test: DO_NOTHING during and after an initial malfunction leaves the agent standing."""
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(), number_of_agents=1,
                  malfunction_generator=malfunction_from_params(stochastic_data))
    # reset to initialize agents_static
    env.reset()
    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=None,
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty,  # full step penalty while malfunctioning
                status=RailAgentStatus.READY_TO_DEPART
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=2,
                reward=env.step_penalty,  # full step penalty while malfunctioning
                status=RailAgentStatus.ACTIVE
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action DO_NOTHING, agent should restart without moving
            #
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=1,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we haven't started moving yet --> stay here
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=0,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),

            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            ),  # we start to move forward --> should go to next cell now
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [replay_config], activate_agents=False)
def tests_random_interference_from_outside():
    """Tests that env randomness is isolated: outside RNG usage must not change rewards/positions."""
    # Set fixed malfunction duration for this test
    stochastic_data = {'malfunction_rate': 1,
                       'min_duration': 10,
                       'max_duration': 10}

    rail, rail_map = make_simple_rail2()
    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=2), number_of_agents=1,
                  malfunction_generator=malfunction_from_params(stochastic_data), random_seed=1)
    env.reset()
    # reset to initialize agents_static
    env.agents[0].speed_data['speed'] = 0.33
    env.reset(False, False, False, random_seed=10)
    env_data = []

    # First trial: record (reward, position) per step as the baseline.
    for step in range(200):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # We randomly select an action
            action_dict[agent.handle] = RailEnvActions(2)

        _, reward, _, _ = env.step(action_dict)
        # Append the rewards of the first trial
        env_data.append((reward[0], env.agents[0].position))
        assert reward[0] == env_data[step][0]
        assert env.agents[0].position == env_data[step][1]
    # Run the same test as above but with an external random generator running
    # Check that the reward stays the same

    rail, rail_map = make_simple_rail2()
    random.seed(47)
    np.random.seed(1234)
    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=2), number_of_agents=1,
                  malfunction_generator=malfunction_from_params(stochastic_data), random_seed=1)
    env.reset()
    # reset to initialize agents_static
    env.agents[0].speed_data['speed'] = 0.33
    env.reset(False, False, False, random_seed=10)

    dummy_list = [1, 2, 6, 7, 8, 9, 4, 5, 4]
    for step in range(200):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # We randomly select an action
            action_dict[agent.handle] = RailEnvActions(2)

            # Do dummy random number generations
            random.shuffle(dummy_list)
            np.random.rand()

        _, reward, _, _ = env.step(action_dict)
        assert reward[0] == env_data[step][0]
        assert env.agents[0].position == env_data[step][1]
def test_last_malfunction_step():
    """
    Test to check that agent moves when it is not malfunctioning

    """

    # Set fixed malfunction duration for this test
    stochastic_data = {'malfunction_rate': 5,
                       'min_duration': 4,
                       'max_duration': 4}

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=2), number_of_agents=1,
                  malfunction_generator=malfunction_from_params(stochastic_data), random_seed=1)
    env.reset()
    # reset to initialize agents_static
    env.agents[0].speed_data['speed'] = 1. / 3.
    env.agents_static[0].target = (0, 0)

    env.reset(False, False, True)
    # Force malfunction to be off at beginning and next malfunction to happen in 2 steps
    env.agents[0].malfunction_data['next_malfunction'] = 2
    env.agents[0].malfunction_data['malfunction'] = 0
    env_data = []
    for step in range(20):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # Go forward all the time
            action_dict[agent.handle] = RailEnvActions(2)

        if env.agents[0].malfunction_data['malfunction'] < 1:
            agent_can_move = True
        # Store the position before and after the step
        pre_position = env.agents[0].speed_data['position_fraction']
        _, reward, _, _ = env.step(action_dict)
        # Check if the agent is still allowed to move in this step

        if env.agents[0].malfunction_data['malfunction'] > 0:
            agent_can_move = False
        post_position = env.agents[0].speed_data['position_fraction']
        # Assert that the agent moved while it was still allowed
        if agent_can_move:
            assert pre_position != post_position
        else:
            assert post_position == pre_position