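"""Tests for stochastic agent malfunctions in the Flatland RailEnv.

Covers malfunction generation and statistics, the malfunction countdown, malfunctions
occurring before an agent has entered the grid, replayed malfunction scenarios, and
independence of the environment from random number generation happening outside of it.
"""
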
import random
from typing import Dict, List

import numpy as np

from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import random_line_generator
from flatland.utils.simple_rail import make_simple_rail2
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay


class SingleAgentNavigationObs(ObservationBuilder):
    """
    We build a representation vector with 3 binary components, indicating which of the 3 available directions
    for each agent (Left, Forward, Right) leads to the shortest path to its target.
    E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
    will be [1, 0, 0].
    """

    def __init__(self):
        super().__init__()

    def reset(self):
        pass

    def get(self, handle: int = 0) -> List[int]:
        agent = self.env.agents[handle]

        if agent.status == RailAgentStatus.READY_TO_DEPART:
            agent_virtual_position = agent.initial_position
        elif agent.status == RailAgentStatus.ACTIVE:
            agent_virtual_position = agent.position
        elif agent.status == RailAgentStatus.DONE:
            agent_virtual_position = agent.target
        else:
            return None

        possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
        num_transitions = np.count_nonzero(possible_transitions)

        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right], relative to the current orientation
        # If only one transition is possible, the forward branch is aligned with it.
        if num_transitions == 1:
            observation = [0, 1, 0]
        else:
            min_distances = []
            for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                if possible_transitions[direction]:
                    new_position = get_new_position(agent_virtual_position, direction)
                    min_distances.append(
                        self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
                else:
                    min_distances.append(np.inf)

            observation = [0, 0, 0]
            observation[np.argmin(min_distances)] = 1

        return observation


def test_malfunction_process():
    # Set fixed malfunction duration for this test
    stochastic_data = MalfunctionParameters(malfunction_rate=1,  # Rate of malfunction occurrence
                                            min_duration=3,  # Minimal duration of malfunction
                                            max_duration=3  # Max duration of malfunction
                                            )

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  line_generator=random_line_generator(),
                  number_of_agents=1,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    obs, info = env.reset(False, False, True, random_seed=10)

    agent_halts = 0
    total_down_time = 0
    agent_old_position = env.agents[0].position

    # Move target to unreachable position in order to not interfere with test
    env.agents[0].target = (0, 0)
    for step in range(100):
        actions = {}
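        # The observation is [left, forward, right]; adding 1 to the argmax maps the chosen
        # branch onto RailEnvActions.MOVE_LEFT / MOVE_FORWARD / MOVE_RIGHT.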
        for i in range(len(obs)):
            actions[i] = np.argmax(obs[i]) + 1

        obs, all_rewards, done, _ = env.step(actions)
        if done["__all__"]:
            break

        if env.agents[0].malfunction_data['malfunction'] > 0:
            agent_malfunctioning = True
        else:
            agent_malfunctioning = False

        if agent_malfunctioning:
            # Check that agent is not moving while malfunctioning
            assert agent_old_position == env.agents[0].position

        agent_old_position = env.agents[0].position
        total_down_time += env.agents[0].malfunction_data['malfunction']

    # Check that the appropriate number of malfunctions is achieved
    assert env.agents[0].malfunction_data['nr_malfunctions'] == 23, "Actual {}".format(
        env.agents[0].malfunction_data['nr_malfunctions'])

    # Check that the agent actually accumulated down time while malfunctioning
    assert total_down_time > 0


def test_malfunction_process_statistically():
    """Tests that malfunctions are produced by stochastic_data!"""
    # Set fixed malfunction duration for this test
    stochastic_data = MalfunctionParameters(malfunction_rate=1/5,  # Rate of malfunction occurrence
                                            min_duration=5,  # Minimal duration of malfunction
                                            max_duration=5  # Max duration of malfunction
                                            )

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  line_generator=random_line_generator(),
                  number_of_agents=10,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  obs_builder_object=SingleAgentNavigationObs()
                  )

    env.reset(True, True, False, random_seed=10)

    env.agents[0].target = (0, 0)
    # Next line only for test generation
    # agent_malfunction_list = [[] for i in range(10)]
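    # Expected malfunction countdown per agent and time step, recorded once with the
    # generation code above for random_seed=10.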
    agent_malfunction_list = [[0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4],
                              [0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                              [0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                              [0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2],
                              [0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1],
                              [0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1],
                              [0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0],
                              [5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 5],
                              [5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2],
                              [5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4]]

    for step in range(20):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent_idx in range(env.get_num_agents()):
            # We randomly select an action
            action_dict[agent_idx] = RailEnvActions(np.random.randint(4))
            # For generating tests only:
            # agent_malfunction_list[agent_idx].append(env.agents[agent_idx].malfunction_data['malfunction'])
            assert env.agents[agent_idx].malfunction_data['malfunction'] == agent_malfunction_list[agent_idx][step]
        env.step(action_dict)
    # print(agent_malfunction_list)


def test_malfunction_before_entry():
    """Tests that malfunctions are working properly for agents before entering the environment!"""
    # Set fixed malfunction duration for this test
    stochastic_data = MalfunctionParameters(malfunction_rate=1/2,  # Rate of malfunction occurrence
                                            min_duration=10,  # Minimal duration of malfunction
                                            max_duration=10  # Max duration of malfunction
                                            )

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  line_generator=random_line_generator(),
                  number_of_agents=10,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    env.reset(False, False, False, random_seed=10)
    env.agents[0].target = (0, 0)

    # Test initial malfunction values for all agents
    # we want some agents to be malfunctioning already and some to be working
    # we want different next_malfunction values for the agents
    assert env.agents[0].malfunction_data['malfunction'] == 0
    assert env.agents[1].malfunction_data['malfunction'] == 10
    assert env.agents[2].malfunction_data['malfunction'] == 0
    assert env.agents[3].malfunction_data['malfunction'] == 10
    assert env.agents[4].malfunction_data['malfunction'] == 10
    assert env.agents[5].malfunction_data['malfunction'] == 10
    assert env.agents[6].malfunction_data['malfunction'] == 10
    assert env.agents[7].malfunction_data['malfunction'] == 10
    assert env.agents[8].malfunction_data['malfunction'] == 10
    assert env.agents[9].malfunction_data['malfunction'] == 10

    # for a in range(10):
    # print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a,env.agents[a].malfunction_data['malfunction']))


def test_malfunction_values_and_behavior():
    """
    Test that the malfunction counter counts down as expected

    Returns
    -------

    """
    # Set fixed malfunction duration for this test

    rail, rail_map = make_simple_rail2()
    action_dict: Dict[int, RailEnvActions] = {}
    stochastic_data = MalfunctionParameters(malfunction_rate=1/0.001,  # Rate of malfunction occurrence
                                            min_duration=10,  # Minimal duration of malfunction
                                            max_duration=10  # Max duration of malfunction
                                            )
    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  line_generator=random_line_generator(),
                  number_of_agents=1,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  obs_builder_object=SingleAgentNavigationObs()
                  )

    env.reset(False, False, activate_agents=True, random_seed=10)

    # Assertions
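    # Expected malfunction countdown over the 15 steps below: the first 10-step malfunction
    # is already running after reset and, with this very high malfunction rate, the next one
    # starts as soon as the previous one has ended.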
    assert_list = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 10, 9, 8, 7, 6, 5]
    print("[")
    for time_step in range(15):
        # Move in the env
        env.step(action_dict)
        # Check that the malfunction counter decreases as expected
        assert env.agents[0].malfunction_data['malfunction'] == assert_list[time_step]


def test_initial_malfunction():
    stochastic_data = MalfunctionParameters(malfunction_rate=1/1000,  # Rate of malfunction occurrence
                                            min_duration=2,  # Minimal duration of malfunction
                                            max_duration=5  # Max duration of malfunction
                                            )

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  line_generator=random_line_generator(seed=10),
                  number_of_agents=1,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # reset to initialize agents_static
    env.reset(False, False, True, random_seed=10)
    print(env.agents[0].malfunction_data)
    env.agents[0].target = (0, 5)
    set_penalties_for_replay(env)
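    # Each Replay entry describes one expected time step: the agent's position and direction,
    # the action to take, the remaining malfunction duration and the resulting reward;
    # run_replay_config steps the environment and asserts these values.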
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty  # full step penalty when malfunctioning
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=2,
                reward=env.step_penalty  # full step penalty when malfunctioning
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=1,
                reward=env.step_penalty
            ),  # malfunctioning ends: starting and running at speed 1.0
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0  # running at speed 1.0
            ),
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty  # running at speed 1.0
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [replay_config])


def test_initial_malfunction_stop_moving():
    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  line_generator=random_line_generator(), number_of_agents=1,
                  obs_builder_object=SingleAgentNavigationObs())
    env.reset()

    print(env.agents[0].initial_position, env.agents[0].direction, env.agents[0].position, env.agents[0].status)

    set_penalties_for_replay(env)
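    # The agent is given a forced 3-step malfunction before departing: it enters the grid at
    # (3, 2) but must not move until the counter reaches 0, and afterwards only moves once
    # MOVE_FORWARD is issued again.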
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=None,
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty,  # full step penalty when stopped
                status=RailAgentStatus.READY_TO_DEPART
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=2,
                reward=env.step_penalty,  # full step penalty when stopped
                status=RailAgentStatus.ACTIVE
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action STOP_MOVING, agent should restart without moving
            #
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.STOP_MOVING,
                malfunction=1,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we have stopped and do nothing --> should stand still
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=0,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we start to move forward --> should go to next cell now
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            ),
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )

    run_replay_config(env, [replay_config], activate_agents=False)


def test_initial_malfunction_do_nothing():
    stochastic_data = MalfunctionParameters(malfunction_rate=1/70,  # Rate of malfunction occurrence
                                            min_duration=2,  # Minimal duration of malfunction
                                            max_duration=5  # Max duration of malfunction
                                            )

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  line_generator=random_line_generator(),
                  number_of_agents=1,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  # Malfunction data generator
                  )
    env.reset()
    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=None,
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty,  # full step penalty while malfunctioning
                status=RailAgentStatus.READY_TO_DEPART
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=2,
                reward=env.step_penalty,  # full step penalty while malfunctioning
                status=RailAgentStatus.ACTIVE
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action DO_NOTHING, agent should restart without moving
            #
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=1,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we haven't started moving yet --> stay here
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=0,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),

            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            ),  # we start to move forward --> should go to next cell now
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [replay_config], activate_agents=False)


def tests_random_interference_from_outside():
    """Tests that the environment is not influenced by random number generation outside of it!"""
    rail, rail_map = make_simple_rail2()
    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  line_generator=random_line_generator(seed=2), number_of_agents=1, random_seed=1)
    env.reset()
    env.agents[0].speed_data['speed'] = 0.33
    env.reset(False, False, False, random_seed=10)
    env_data = []

    for step in range(200):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # Go forward all the time
            action_dict[agent.handle] = RailEnvActions(2)

        _, reward, _, _ = env.step(action_dict)
        # Append the rewards of the first trial
        env_data.append((reward[0], env.agents[0].position))
        assert reward[0] == env_data[step][0]
        assert env.agents[0].position == env_data[step][1]
    # Run the same test as above but with an external random generator running
    # Check that the reward stays the same

    rail, rail_map = make_simple_rail2()
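    # Seed the global random number generators: the results below must match the first run,
    # i.e. random numbers drawn outside the env must not influence its behaviour.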
    random.seed(47)
    np.random.seed(1234)
    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  line_generator=random_line_generator(seed=2), number_of_agents=1, random_seed=1)
    env.reset()
    env.agents[0].speed_data['speed'] = 0.33
    env.reset(False, False, False, random_seed=10)

    dummy_list = [1, 2, 6, 7, 8, 9, 4, 5, 4]
    for step in range(200):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # Go forward all the time
            action_dict[agent.handle] = RailEnvActions(2)

            # Do dummy random number generations
            random.shuffle(dummy_list)
            np.random.rand()

        _, reward, _, _ = env.step(action_dict)
        assert reward[0] == env_data[step][0]
        assert env.agents[0].position == env_data[step][1]


def test_last_malfunction_step():
    """
    Test that the agent moves when it is not malfunctioning and stands still while it is

    """

    # Set fixed malfunction duration for this test

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                  line_generator=random_line_generator(seed=2), number_of_agents=1, random_seed=1)
    env.reset()
    env.agents[0].speed_data['speed'] = 1. / 3.
    env.agents[0].target = (0, 0)

    env.reset(False, False, True)
    # Force malfunction to be off at beginning and next malfunction to happen in 2 steps
    env.agents[0].malfunction_data['next_malfunction'] = 2
    env.agents[0].malfunction_data['malfunction'] = 0
    env_data = []
    for step in range(20):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # Go forward all the time
            action_dict[agent.handle] = RailEnvActions(2)

        if env.agents[0].malfunction_data['malfunction'] < 1:
            agent_can_move = True
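        # position_fraction advances every step for this fractional-speed agent, so comparing
        # it before and after the step detects movement even within a single grid cell.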
        # Store the position before and after the step
        pre_position = env.agents[0].speed_data['position_fraction']
        _, reward, _, _ = env.step(action_dict)
        # Check if the agent is still allowed to move in this step
        if env.agents[0].malfunction_data['malfunction'] > 0:
            agent_can_move = False
        post_position = env.agents[0].speed_data['position_fraction']
        # Assert that the agent moved while it was still allowed
        if agent_can_move:
            assert pre_position != post_position
        else:
            assert post_position == pre_position