test_flatland_malfunction.py 23.2 KB
Newer Older
u214892's avatar
u214892 committed
1
import random
2
from typing import Dict, List
u214892's avatar
u214892 committed
3

4
import numpy as np
5
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
6

7
from flatland.core.env_observation_builder import ObservationBuilder
u214892's avatar
u214892 committed
8
from flatland.core.grid.grid4 import Grid4TransitionsEnum
9
from flatland.core.grid.grid4_utils import get_new_position
u214892's avatar
u214892 committed
10
from flatland.envs.agent_utils import RailAgentStatus
11
from flatland.envs.malfunction_generators import malfunction_from_params
u214892's avatar
u214892 committed
12
from flatland.envs.rail_env import RailEnv, RailEnvActions
13
14
15
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.simple_rail import make_simple_rail2
16
17


18
class SingleAgentNavigationObs(ObservationBuilder):
    """
    Builds a 3-component binary observation vector for each agent, indicating
    which of the three relative directions (Left, Forward, Right) leads to the
    shortest path to the agent's target.
    E.g., if taking the Left branch (if available) is the shortest route to the
    agent's target, the observation vector will be [1, 0, 0].
    """

    def __init__(self):
        super().__init__()

    def reset(self):
        # No per-episode state to clear.
        pass

    def get(self, handle: int = 0) -> List[int]:
        agent = self.env.agents[handle]

        # Resolve the cell the agent effectively occupies for planning:
        # its spawn cell before departure, its target once done.
        if agent.status == RailAgentStatus.READY_TO_DEPART:
            planning_position = agent.initial_position
        elif agent.status == RailAgentStatus.ACTIVE:
            planning_position = agent.position
        elif agent.status == RailAgentStatus.DONE:
            planning_position = agent.target
        else:
            return None

        transitions = self.env.rail.get_transitions(*planning_position, agent.direction)

        # If only one transition is possible, the forward branch is aligned
        # with it by convention.
        if np.count_nonzero(transitions) == 1:
            return [0, 1, 0]

        # Shortest-path distances via [left, forward, right] relative to the
        # current orientation; blocked branches get infinite distance.
        branch_distances = []
        for offset in (-1, 0, 1):
            heading = (agent.direction + offset) % 4
            if transitions[heading]:
                next_cell = get_new_position(planning_position, heading)
                branch_distances.append(
                    self.env.distance_map.get()[handle, next_cell[0], next_cell[1], heading])
            else:
                branch_distances.append(np.inf)

        observation = [0, 0, 0]
        observation[np.argmin(branch_distances)] = 1
        return observation


def test_malfunction_process():
    """Check that a malfunctioning agent stands still and that down time accumulates.

    With a fixed seed and a fixed malfunction duration, the total number of
    malfunctions over 100 steps is deterministic (expected: 23).
    """
    # Set fixed malfunction duration for this test
    stochastic_data = {'malfunction_rate': 1,
                       'min_duration': 3,
                       'max_duration': 3}

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(), number_of_agents=1,
                  obs_builder_object=SingleAgentNavigationObs(), malfunction_generator=malfunction_from_params(stochastic_data))
    # reset to initialize agents_static
    obs, info = env.reset(False, False, True, random_seed=10)

    total_down_time = 0
    agent_old_position = env.agents[0].position

    # Move target to unreachable position in order to not interfere with test
    env.agents[0].target = (0, 0)
    for step in range(100):
        # Follow the shortest-path hint from the observation; the +1 offset
        # maps [left, forward, right] onto the corresponding MOVE_* actions.
        actions = {i: np.argmax(obs[i]) + 1 for i in range(len(obs))}

        obs, all_rewards, done, _ = env.step(actions)

        # Check that agent is not moving while malfunctioning
        if env.agents[0].malfunction_data['malfunction'] > 0:
            assert agent_old_position == env.agents[0].position

        agent_old_position = env.agents[0].position
        total_down_time += env.agents[0].malfunction_data['malfunction']

    # Check that the appropriate number of malfunctions is achieved
    assert env.agents[0].malfunction_data['nr_malfunctions'] == 23, "Actual {}".format(
        env.agents[0].malfunction_data['nr_malfunctions'])

    # Check that malfunctioning data was standing around
    assert total_down_time > 0
u214892's avatar
u214892 committed
114
115
116
117
118


def test_malfunction_process_statistically():
    """Tests that malfunctions are produced by stochastic_data!

    With fixed seeds, the per-agent malfunction countdowns observed over
    20 steps must match the hard-coded table below exactly.
    """
    # Set fixed malfunction duration for this test
    stochastic_data = {'malfunction_rate': 5,
                       'min_duration': 5,
                       'max_duration': 5}

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(), number_of_agents=10,
                  obs_builder_object=SingleAgentNavigationObs(), malfunction_generator=malfunction_from_params(stochastic_data))
    # reset to initialize agents_static
    env.reset(True, True, False, random_seed=10)

    env.agents[0].target = (0, 0)
    # Next line only for test generation
    #agent_malfunction_list = [[] for i in range(10)]
    # Expected malfunction countdown per agent (rows) and step (columns),
    # recorded from a reference run with the seeds above.
    agent_malfunction_list = [[0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1],
     [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 0],
     [5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0],
     [0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1],
     [0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1],
     [0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2],
     [5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [5, 4, 3, 2, 1, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0]]

    for step in range(20):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent_idx in range(env.get_num_agents()):
            # We randomly select an action
            # NOTE: one np.random.randint draw per agent per step — the RNG
            # stream order must not change or the table above becomes invalid.
            action_dict[agent_idx] = RailEnvActions(np.random.randint(4))
            # For generating tests only:
            #agent_malfunction_list[agent_idx].append(env.agents[agent_idx].malfunction_data['malfunction'])
            assert env.agents[agent_idx].malfunction_data['malfunction'] == agent_malfunction_list[agent_idx][step]
        env.step(action_dict)
    #print(agent_malfunction_list)
156

u214892's avatar
u214892 committed
157

158
def test_malfunction_before_entry():
    """Tests that malfunctions are working properly for agents before entering the environment!"""
    # Fixed malfunction duration so the expected values below are stable.
    stochastic_data = {'malfunction_rate': 2,
                       'min_duration': 10,
                       'max_duration': 10}

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=1), number_of_agents=10,
                  malfunction_generator=malfunction_from_params(stochastic_data), random_seed=1)
    # reset to initialize agents_static
    env.reset(False, False, False, random_seed=10)
    env.agents[0].target = (0, 0)

    # Test initial malfunction values for all agents:
    # with the seeds above, some agents spawn already malfunctioning (value 10)
    # while the rest start healthy (value 0).
    expected_initial_malfunctions = [0, 0, 10, 0, 0, 0, 0, 0, 10, 10]
    for handle, expected in enumerate(expected_initial_malfunctions):
        assert env.agents[handle].malfunction_data['malfunction'] == expected
190
191


192
193
def test_malfunction_values_and_behavior():
    """
    Test that the malfunction counter counts down as desired.

    The agent starts with a malfunction of 10 which decrements each step to 0,
    then (with the seeds above) a new malfunction of 10 begins and counts down
    again — matching the hard-coded expectation list.
    """
    # Set fixed malfunction duration for this test
    rail, rail_map = make_simple_rail2()
    action_dict: Dict[int, RailEnvActions] = {}
    stochastic_data = {'malfunction_rate': 0.001,
                       'min_duration': 10,
                       'max_duration': 10}
    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=2), number_of_agents=1,
                  malfunction_generator=malfunction_from_params(stochastic_data), random_seed=1)

    # reset to initialize agents_static
    env.reset(False, False, activate_agents=True, random_seed=10)

    # Expected malfunction countdown per time step (reference run).
    assert_list = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 10, 9, 8, 7, 6, 5]
    for time_step in range(15):
        # Move in the env
        env.step(action_dict)
        # Check that the malfunction countdown decreases as expected
        assert env.agents[0].malfunction_data['malfunction'] == assert_list[time_step]
221

222

223
def test_initial_malfunction():
    """Replay test: an agent that spawns malfunctioning must stand still until
    the countdown reaches zero, then restart and move on MOVE_FORWARD."""
    stochastic_data = {'malfunction_rate': 1000,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=10), number_of_agents=1,
                  obs_builder_object=SingleAgentNavigationObs(), malfunction_generator=malfunction_from_params(stochastic_data))
    # reset to initialize agents_static
    env.reset(False, False, True, random_seed=10)
    env.agents[0].target = (0, 5)
    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty  # full step penalty when malfunctioning
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=2,
                reward=env.step_penalty  # full step penalty when malfunctioning
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=1,
                reward=env.step_penalty

            ),  # malfunctioning ends: starting and running at speed 1.0
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0  # running at speed 1.0
            ),
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty  # running at speed 1.0
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [replay_config])
287
288
289


def test_initial_malfunction_stop_moving():
    """Replay test: a malfunctioning agent ignores STOP_MOVING/DO_NOTHING while
    broken, stays put once stopped, and only moves again on MOVE_FORWARD."""
    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(), number_of_agents=1,
                  obs_builder_object=SingleAgentNavigationObs(), malfunction_generator=malfunction_from_params(stochastic_data))
    env.reset()

    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=None,
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty,  # full step penalty when stopped
                status=RailAgentStatus.READY_TO_DEPART
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=2,
                reward=env.step_penalty,  # full step penalty when stopped
                status=RailAgentStatus.ACTIVE
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action STOP_MOVING, agent should restart without moving
            #
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.STOP_MOVING,
                malfunction=1,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we have stopped and do nothing --> should stand still
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=0,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we start to move forward --> should go to next cell now
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )

    run_replay_config(env, [replay_config], activate_agents=False)
370
371


372
def test_initial_malfunction_do_nothing():
    """Replay test: a malfunctioning agent stays in place under DO_NOTHING and
    only starts moving once healthy and given MOVE_FORWARD."""
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(), number_of_agents=1,
                  malfunction_generator=malfunction_from_params(stochastic_data))
    # reset to initialize agents_static
    env.reset()
    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            # Forced 3-step malfunction before the agent even enters the grid.
            Replay(
                position=None,
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty,  # full step penalty while malfunctioning
                status=RailAgentStatus.READY_TO_DEPART
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=2,
                reward=env.step_penalty,  # full step penalty while malfunctioning
                status=RailAgentStatus.ACTIVE
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action DO_NOTHING, agent should restart without moving
            #
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=1,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we haven't started moving yet --> stay here
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=0,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),

            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            ),  # we start to move forward --> should go to next cell now
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [replay_config], activate_agents=False)
453
454


Erik Nygren's avatar
Erik Nygren committed
455
456
457
def tests_random_interference_from_outside():
    """Tests that the env is deterministic and unaffected by outside use of the
    global random number generators: two identically-seeded runs must yield the
    same rewards and positions even when random/np.random are consumed between
    steps in the second run."""
    # Set fixed malfunction duration for this test
    stochastic_data = {'malfunction_rate': 1,
                       'min_duration': 10,
                       'max_duration': 10}

    rail, rail_map = make_simple_rail2()
    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=2), number_of_agents=1,
                  malfunction_generator=malfunction_from_params(stochastic_data), random_seed=1)
    env.reset()
    # reset to initialize agents_static
    env.agents[0].speed_data['speed'] = 0.33
    env.reset(False, False, False, random_seed=10)
    env_data = []

    # First trial: record (reward, position) of agent 0 at every step.
    for step in range(200):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # We always move forward
            action_dict[agent.handle] = RailEnvActions(2)

        _, reward, _, _ = env.step(action_dict)
        env_data.append((reward[0], env.agents[0].position))
    # Run the same test as above but with an external random generator running
    # Check that the reward stays the same

    rail, rail_map = make_simple_rail2()
    random.seed(47)
    np.random.seed(1234)
    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=2), number_of_agents=1,
                  malfunction_generator=malfunction_from_params(stochastic_data), random_seed=1)
    env.reset()
    # reset to initialize agents_static
    env.agents[0].speed_data['speed'] = 0.33
    env.reset(False, False, False, random_seed=10)

    dummy_list = [1, 2, 6, 7, 8, 9, 4, 5, 4]
    for step in range(200):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # We always move forward
            action_dict[agent.handle] = RailEnvActions(2)

            # Do dummy random number generations to interfere with the
            # global RNG state between env steps.
            random.shuffle(dummy_list)
            np.random.rand()

        _, reward, _, _ = env.step(action_dict)
        # Second trial must reproduce the first trial exactly.
        assert reward[0] == env_data[step][0]
        assert env.agents[0].position == env_data[step][1]
511
512
513
514
515
516
517
518
519


def test_last_malfunction_step():
    """
    Test to check that agent moves when it is not malfunctioning

    Compares the agent's position fraction before and after each step:
    it must advance whenever it was healthy going into the step and stand
    still whenever it comes out of the step malfunctioning.
    """

    # Set fixed malfunction duration for this test
    stochastic_data = {'malfunction_rate': 5,
                       'min_duration': 4,
                       'max_duration': 4}

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=2), number_of_agents=1,
                  malfunction_generator=malfunction_from_params(stochastic_data), random_seed=1)
    env.reset()
    # reset to initialize agents_static
    env.agents[0].speed_data['speed'] = 1. / 3.
    env.agents_static[0].target = (0, 0)

    env.reset(False, False, True)
    # Force malfunction to be off at beginning and next malfunction to happen in 2 steps.
    # This also guarantees agent_can_move is assigned on the first loop iteration.
    env.agents[0].malfunction_data['next_malfunction'] = 2
    env.agents[0].malfunction_data['malfunction'] = 0
    env_data = []
    for step in range(20):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # Go forward all the time
            action_dict[agent.handle] = RailEnvActions(2)

        # Agent is allowed to move if it is healthy going into this step;
        # otherwise agent_can_move keeps its value from the previous iteration
        # until the post-step check below overrides it.
        if env.agents[0].malfunction_data['malfunction'] < 1:
            agent_can_move = True
        # Store the position before and after the step
        pre_position = env.agents[0].speed_data['position_fraction']
        _, reward, _, _ = env.step(action_dict)
        # Check if the agent is still allowed to move in this step
        if env.agents[0].malfunction_data['malfunction'] > 0:
            agent_can_move = False
        post_position = env.agents[0].speed_data['position_fraction']
        # Assert that the agent moved while it was still allowed
        if agent_can_move:
            assert pre_position != post_position
        else:
            assert post_position == pre_position