test_flatland_malfunction.py 23.6 KB
Newer Older
u214892's avatar
u214892 committed
1
import random
2
from typing import Dict, List
u214892's avatar
u214892 committed
3

4
import numpy as np
5
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
6

7
from flatland.core.env_observation_builder import ObservationBuilder
u214892's avatar
u214892 committed
8
from flatland.core.grid.grid4 import Grid4TransitionsEnum
9
from flatland.core.grid.grid4_utils import get_new_position
u214892's avatar
u214892 committed
10
from flatland.envs.agent_utils import RailAgentStatus
Erik Nygren's avatar
Erik Nygren committed
11
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
u214892's avatar
u214892 committed
12
from flatland.envs.rail_env import RailEnv, RailEnvActions
13
14
15
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.simple_rail import make_simple_rail2
16
17


18
class SingleAgentNavigationObs(ObservationBuilder):
    """
    Observation builder producing a 3-element one-hot vector per agent.

    The components correspond to the (Left, Forward, Right) branches relative to the agent's current
    direction; the component set to 1 marks the branch with the smallest distance-map value towards the
    agent's target. E.g., if taking the Left branch (if available) is the shortest route to the agent's
    target, the observation vector will be [1, 0, 0].
    """

    def __init__(self):
        super().__init__()

    def reset(self):
        # Stateless builder: nothing to reset.
        pass

    def get(self, handle: int = 0) -> List[int]:
        """Return the [left, forward, right] one-hot observation for agent ``handle``.

        The agent's "virtual" position is its initial position while READY_TO_DEPART, its current position
        while ACTIVE and its target once DONE. For any other status ``None`` is returned (note: the
        annotation does not reflect this).
        """
        agent = self.env.agents[handle]

        if agent.status == RailAgentStatus.READY_TO_DEPART:
            agent_virtual_position = agent.initial_position
        elif agent.status == RailAgentStatus.ACTIVE:
            agent_virtual_position = agent.position
        elif agent.status == RailAgentStatus.DONE:
            agent_virtual_position = agent.target
        else:
            return None

        possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
        num_transitions = np.count_nonzero(possible_transitions)

        # If only one transition is possible, the forward branch is aligned with it.
        if num_transitions == 1:
            return [0, 1, 0]

        # Fetch the distance map once instead of once per candidate direction.
        distance_map = self.env.distance_map.get()

        # Start from the current orientation and look at the available transitions,
        # organised as [left, forward, right] relative to the current orientation.
        min_distances = []
        for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
            if possible_transitions[direction]:
                new_position = get_new_position(agent_virtual_position, direction)
                min_distances.append(distance_map[handle, new_position[0], new_position[1], direction])
            else:
                # Unavailable branch: never the minimum unless no branch is available.
                min_distances.append(np.inf)

        observation = [0, 0, 0]
        observation[np.argmin(min_distances)] = 1
        return observation


def test_malfunction_process():
    """Check that a malfunctioning agent never moves and that the seeded run yields the expected malfunctions.

    A single agent follows the ``SingleAgentNavigationObs`` recommendation for 100 steps with an
    unreachable target. Whenever its malfunction counter is positive after a step, its position must be
    unchanged. The run must produce exactly 23 malfunctions and a non-zero accumulated down time.
    """
    # Set fixed malfunction duration for this test
    stochastic_data = MalfunctionParameters(malfunction_rate=1,  # Rate of malfunction occurrence
                                            min_duration=3,  # Minimal duration of malfunction
                                            max_duration=3  # Max duration of malfunction
                                            )

    rail, _ = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    obs, _ = env.reset(False, False, True, random_seed=10)

    total_down_time = 0
    agent_old_position = env.agents[0].position

    # Move target to unreachable position in order to not interfere with test
    env.agents[0].target = (0, 0)
    for _ in range(100):
        # Each observation is a one-hot [left, forward, right] vector; argmax + 1 maps it onto
        # RailEnvActions MOVE_LEFT (1) / MOVE_FORWARD (2) / MOVE_RIGHT (3).
        actions = {handle: np.argmax(agent_obs) + 1 for handle, agent_obs in obs.items()}

        obs, _, _, _ = env.step(actions)

        malfunction = env.agents[0].malfunction_data['malfunction']
        if malfunction > 0:
            # Check that agent is not moving while malfunctioning
            assert agent_old_position == env.agents[0].position

        agent_old_position = env.agents[0].position
        total_down_time += malfunction

    # Check that the appropriate number of malfunctions is achieved
    nr_malfunctions = env.agents[0].malfunction_data['nr_malfunctions']
    assert nr_malfunctions == 23, "Actual {}".format(nr_malfunctions)

    # Check that the agent actually spent some time malfunctioning
    assert total_down_time > 0
u214892's avatar
u214892 committed
119
120
121
122
123


def test_malfunction_process_statistically():
    """Tests that malfunctions are produced by stochastic_data!

    Ten agents take random actions for 20 steps. Before every step, each agent's malfunction counter must
    equal the value recorded for that agent and step in ``expected_malfunctions`` (seeded reset).
    """
    # Fixed malfunction duration of 5 steps
    malfunction_params = MalfunctionParameters(malfunction_rate=5,  # Rate of malfunction occurrence
                                               min_duration=5,  # Minimal duration of malfunction
                                               max_duration=5  # Max duration of malfunction
                                               )

    rail, _ = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=10,
                  malfunction_generator_and_process_data=malfunction_from_params(malfunction_params),
                  obs_builder_object=SingleAgentNavigationObs()
                  )

    env.reset(True, True, False, random_seed=10)

    env.agents[0].target = (0, 0)

    # Row = agent handle, column = step: malfunction counter observed before that step.
    # To regenerate: start from one empty list per agent, append the observed counter instead of
    # asserting it, and print the collected lists after the loop.
    expected_malfunctions = [
        [0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1],
        [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 0],
        [5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0],
        [0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1],
        [0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2],
        [5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [5, 4, 3, 2, 1, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0],
    ]

    for step in range(20):
        action_dict: Dict[int, RailEnvActions] = {}
        for handle in range(env.get_num_agents()):
            # Pick a random action for this agent
            action_dict[handle] = RailEnvActions(np.random.randint(4))
            observed_malfunction = env.agents[handle].malfunction_data['malfunction']
            assert observed_malfunction == expected_malfunctions[handle][step]
        env.step(action_dict)
166

u214892's avatar
u214892 committed
167

168
def test_malfunction_before_entry():
    """Tests that malfunctions are working properly for agents before entering the environment!

    After a seeded reset that does not activate the agents, some agents must already be malfunctioning
    (counter 10) while the others are not (counter 0).
    """
    # Fixed malfunction duration of 10 steps
    malfunction_params = MalfunctionParameters(malfunction_rate=2,  # Rate of malfunction occurrence
                                               min_duration=10,  # Minimal duration of malfunction
                                               max_duration=10  # Max duration of malfunction
                                               )

    rail, _ = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=10,
                  malfunction_generator_and_process_data=malfunction_from_params(malfunction_params),
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    env.reset(False, False, False, random_seed=10)
    env.agents[0].target = (0, 0)

    # Expected initial malfunction counter per agent handle: a mix of malfunctioning (10) and
    # working (0) agents.
    expected_initial_malfunction = [0, 0, 10, 0, 0, 0, 0, 0, 10, 10]
    for handle, expected in enumerate(expected_initial_malfunction):
        assert env.agents[handle].malfunction_data['malfunction'] == expected
205
206


207
208
def test_malfunction_values_and_behavior():
    """
    Test that the malfunction counter counts down as desired.

    A single active agent receives no actions. After each of 15 steps its malfunction counter is compared
    with the seeded expectation: it counts down from 9 to 0 and then a new malfunction of length 10 starts.
    """
    rail, _ = make_simple_rail2()
    action_dict: Dict[int, RailEnvActions] = {}
    # Fixed malfunction duration of 10 steps
    stochastic_data = MalfunctionParameters(malfunction_rate=0.001,  # Rate of malfunction occurrence
                                            min_duration=10,  # Minimal duration of malfunction
                                            max_duration=10  # Max duration of malfunction
                                            )
    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  obs_builder_object=SingleAgentNavigationObs()
                  )

    env.reset(False, False, activate_agents=True, random_seed=10)

    # Expected malfunction counter after each step; only the first 15 entries are checked.
    assert_list = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 10, 9, 8, 7, 6, 5]
    for expected_malfunction in assert_list[:15]:
        # Move in the env (no actions given)
        env.step(action_dict)
        # Check that the malfunction counter decreases as expected
        assert env.agents[0].malfunction_data['malfunction'] == expected_malfunction
241

242

243
def test_initial_malfunction():
    """Replay a forced 3-step malfunction at the start cell followed by normal forward movement.

    The agent starts at (3, 2) facing east and always issues MOVE_FORWARD. It must stay on (3, 2) while
    the malfunction counts down and then continue to (3, 3), paying the start penalty once.
    """
    stochastic_data = MalfunctionParameters(malfunction_rate=1000,  # Rate of malfunction occurrence
                                            min_duration=2,  # Minimal duration of malfunction
                                            max_duration=5  # Max duration of malfunction
                                            )

    rail, _ = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=10),
                  number_of_agents=1,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # Seeded reset (activating the agents) before the agent is modified below
    env.reset(False, False, True, random_seed=10)
    env.agents[0].target = (0, 5)
    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty  # full step penalty when malfunctioning
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=2,
                reward=env.step_penalty  # full step penalty when malfunctioning
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=1,
                reward=env.step_penalty
            ),  # malfunctioning ends: starting and running at speed 1.0
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0  # running at speed 1.0
            ),
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty  # running at speed 1.0
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [replay_config])
312
313
314


def test_initial_malfunction_stop_moving():
    """Replay an initial malfunction ended with STOP_MOVING: the agent must wait until told to move forward.

    The agent is not yet departed when a 3-step malfunction is forced. It enters at (3, 2) facing east,
    issues STOP_MOVING while the malfunction ends and DO_NOTHING afterwards, so it must stay on (3, 2)
    until a MOVE_FORWARD brings it to (3, 3).
    """
    rail, _ = make_simple_rail2()

    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(), number_of_agents=1,
                  obs_builder_object=SingleAgentNavigationObs())
    env.reset()

    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=None,
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty,  # full step penalty while malfunctioning (not yet departed)
                status=RailAgentStatus.READY_TO_DEPART
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=2,
                reward=env.step_penalty,  # full step penalty while malfunctioning
                status=RailAgentStatus.ACTIVE
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action STOP_MOVING, agent should restart without moving
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.STOP_MOVING,
                malfunction=1,
                reward=env.step_penalty,  # full step penalty while malfunctioning
                status=RailAgentStatus.ACTIVE
            ),
            # we have stopped and do nothing --> should stand still
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=0,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we start to move forward --> should go to next cell now
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            ),
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )

    run_replay_config(env, [replay_config], activate_agents=False)
389
390


391
def test_initial_malfunction_do_nothing():
    """Replay an initial malfunction during which (and after which) the agent only issues DO_NOTHING.

    The agent is not yet departed when a 3-step malfunction is forced. It enters at (3, 2) facing east and
    must remain there while issuing DO_NOTHING, until a MOVE_FORWARD brings it to (3, 3).
    """
    # Fix the global random generators before building the env.
    random.seed(0)
    np.random.seed(0)

    malfunction_params = MalfunctionParameters(malfunction_rate=70,  # Rate of malfunction occurrence
                                               min_duration=2,  # Minimal duration of malfunction
                                               max_duration=5  # Max duration of malfunction
                                               )

    rail, _ = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  malfunction_generator_and_process_data=malfunction_from_params(malfunction_params),  # Malfunction data generator
                  )
    env.reset()
    set_penalties_for_replay(env)

    east = Grid4TransitionsEnum.EAST
    start_cell = (3, 2)
    step_penalty = env.step_penalty
    start_penalty = env.start_penalty

    steps = [
        # Not yet departed; a 3-step malfunction is forced on the first step.
        Replay(position=None, direction=east, action=RailEnvActions.MOVE_FORWARD,
               set_malfunction=3, malfunction=3,
               reward=step_penalty,  # full step penalty while malfunctioning
               status=RailAgentStatus.READY_TO_DEPART),
        Replay(position=start_cell, direction=east, action=RailEnvActions.DO_NOTHING,
               malfunction=2,
               reward=step_penalty,  # full step penalty while malfunctioning
               status=RailAgentStatus.ACTIVE),
        # The malfunction stops in the next step while we are still at the beginning of the cell
        # --> with DO_NOTHING the agent should restart without moving.
        Replay(position=start_cell, direction=east, action=RailEnvActions.DO_NOTHING,
               malfunction=1,
               reward=step_penalty,  # full step penalty while stopped
               status=RailAgentStatus.ACTIVE),
        # We have not started moving yet --> stay here.
        Replay(position=start_cell, direction=east, action=RailEnvActions.DO_NOTHING,
               malfunction=0,
               reward=step_penalty,  # full step penalty while stopped
               status=RailAgentStatus.ACTIVE),
        # We start to move forward --> should reach the next cell now.
        Replay(position=start_cell, direction=east, action=RailEnvActions.MOVE_FORWARD,
               malfunction=0,
               reward=start_penalty + step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
               status=RailAgentStatus.ACTIVE),
        Replay(position=(3, 3), direction=east, action=RailEnvActions.MOVE_FORWARD,
               malfunction=0,
               reward=step_penalty * 1.0,  # step penalty for speed 1.0
               status=RailAgentStatus.ACTIVE),
    ]

    replay_config = ReplayConfig(
        replay=steps,
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=start_cell,
        initial_direction=east,
    )
    run_replay_config(env, [replay_config], activate_agents=False)
474
475


Erik Nygren's avatar
Erik Nygren committed
476
477
478
479
def tests_random_interference_from_outside():
    """Tests that external use of ``random`` / ``np.random`` does not influence the env dynamics.

    The same single-agent env (speed 0.33, always moving forward) is run twice with identical env seeds.
    The first run records the agent's reward and position after every step. The second run seeds and
    consumes the global ``random`` and ``np.random`` generators between steps and must reproduce exactly
    the same sequence.
    """

    def _make_env(rail):
        # Identical setup for both runs: build, slow the agent down, then reset with a fixed seed.
        env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                      schedule_generator=random_schedule_generator(seed=2), number_of_agents=1, random_seed=1)
        env.reset()
        env.agents[0].speed_data['speed'] = 0.33
        env.reset(False, False, False, random_seed=10)
        return env

    n_steps = 200

    # First run: undisturbed reference trajectory.
    rail, _ = make_simple_rail2()
    env = _make_env(rail)
    env_data = []
    for _ in range(n_steps):
        action_dict: Dict[int, RailEnvActions] = {agent.handle: RailEnvActions.MOVE_FORWARD
                                                  for agent in env.agents}
        _, reward, _, _ = env.step(action_dict)
        env_data.append((reward[0], env.agents[0].position))

    # Second run: same setup, but the global random generators are seeded and consumed from outside.
    # Check that rewards and positions stay the same.
    rail, _ = make_simple_rail2()
    random.seed(47)
    np.random.seed(1234)
    env = _make_env(rail)

    dummy_list = [1, 2, 6, 7, 8, 9, 4, 5, 4]
    for expected_reward, expected_position in env_data:
        action_dict = {}
        for agent in env.agents:
            action_dict[agent.handle] = RailEnvActions.MOVE_FORWARD
            # Dummy random number generations that must not affect the env
            random.shuffle(dummy_list)
            np.random.rand()

        _, reward, _, _ = env.step(action_dict)
        assert reward[0] == expected_reward
        assert env.agents[0].position == expected_position
524
525
526
527
528
529
530
531
532
533
534
535


def test_last_malfunction_step():
    """
    Test to check that agent moves when it is not malfunctioning.

    The single agent (speed 1/3) is driven forward for 20 steps. In every step its ``position_fraction``
    must change if it was allowed to move and stay unchanged otherwise. The agent is allowed to move only
    if its malfunction counter is below 1 before ``env.step`` and not positive after it. In particular, it
    does not move in the step in which a malfunction ends (counter 1 -> 0).
    """
    rail, _ = make_simple_rail2()

    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=2), number_of_agents=1, random_seed=1)
    env.reset()
    env.agents[0].speed_data['speed'] = 1. / 3.
    env.agents[0].target = (0, 0)

    env.reset(False, False, True)
    # Force malfunction to be off at beginning and next malfunction to happen in 2 steps
    env.agents[0].malfunction_data['next_malfunction'] = 2
    env.agents[0].malfunction_data['malfunction'] = 0
    for _ in range(20):
        # Go forward all the time
        action_dict: Dict[int, RailEnvActions] = {agent.handle: RailEnvActions.MOVE_FORWARD
                                                  for agent in env.agents}

        # Store the malfunction counter and the position before and after the step
        malfunction_before = env.agents[0].malfunction_data['malfunction']
        pre_position = env.agents[0].speed_data['position_fraction']
        env.step(action_dict)
        malfunction_after = env.agents[0].malfunction_data['malfunction']
        post_position = env.agents[0].speed_data['position_fraction']

        # Computed per step instead of carried across iterations: the pre-step counter of one step is the
        # post-step counter of the previous step, so this matches the stateful formulation exactly.
        agent_can_move = malfunction_before < 1 and not malfunction_after > 0

        # Assert that the agent moved while it was still allowed
        if agent_can_move:
            assert pre_position != post_position
        else:
            assert post_position == pre_position