import random
from typing import Dict, List

import numpy as np

from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_simple_rail2
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay


class SingleAgentNavigationObs(ObservationBuilder):
    """
    Build an observation vector with 3 binary components indicating which of the three
    available directions (Left, Forward, Right) leads along the shortest path to the
    agent's target.
    E.g., if taking the Left branch (if available) is the shortest route to the agent's
    target, the observation vector will be [1, 0, 0].
    """

    def __init__(self):
        super().__init__()

    def reset(self):
        pass

    def get(self, handle: int = 0) -> List[int]:
        agent = self.env.agents[handle]
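        # Use a "virtual" position: agents that have not yet entered the grid
        # (READY_TO_DEPART) have no position, so fall back to the initial position;
        # agents that are DONE are mapped to their target.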

        if agent.status == RailAgentStatus.READY_TO_DEPART:
            agent_virtual_position = agent.initial_position
        elif agent.status == RailAgentStatus.ACTIVE:
            agent_virtual_position = agent.position
        elif agent.status == RailAgentStatus.DONE:
            agent_virtual_position = agent.target
        else:
            return None

        possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
        num_transitions = np.count_nonzero(possible_transitions)

        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right], relative to the current orientation
        # If only one transition is possible, the forward branch is aligned with it.
        if num_transitions == 1:
            observation = [0, 1, 0]
        else:
            min_distances = []
            for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                if possible_transitions[direction]:
                    new_position = get_new_position(agent_virtual_position, direction)
                    min_distances.append(
                        self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
                else:
                    min_distances.append(np.inf)

            observation = [0, 0, 0]
            observation[np.argmin(min_distances)] = 1

        return observation


def test_malfunction_process():
    # Set fixed malfunction duration for this test
    stochastic_data = MalfunctionParameters(malfunction_rate=1,  # Rate of malfunction occurrence
                                            min_duration=3,  # Minimal duration of malfunction
                                            max_duration=3  # Max duration of malfunction
                                            )

    rail, rail_map, optionals = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(),
                  number_of_agents=1,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    obs, info = env.reset(False, False, True, random_seed=10)

    agent_halts = 0
    total_down_time = 0
    agent_old_position = env.agents[0].position

    # Move target to an unreachable position in order to not interfere with the test
    env.agents[0].target = (0, 0)

    # Add max episode steps because the schedule generator sets it to 0 for dummy data
    env._max_episode_steps = 200

    for step in range(100):
        actions = {}
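        # The observation is [left, forward, right]; adding 1 to the argmax maps it to
        # RailEnvActions.MOVE_LEFT / MOVE_FORWARD / MOVE_RIGHT (0 is DO_NOTHING).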
        for i in range(len(obs)):
            actions[i] = np.argmax(obs[i]) + 1

        obs, all_rewards, done, _ = env.step(actions)
        if done["__all__"]:
            break

        agent_malfunctioning = env.agents[0].malfunction_data['malfunction'] > 0

        if agent_malfunctioning:
            # Check that agent is not moving while malfunctioning
            assert agent_old_position == env.agents[0].position

        agent_old_position = env.agents[0].position
        total_down_time += env.agents[0].malfunction_data['malfunction']
    # Check that the expected number of malfunctions occurred
    # Note: the number of malfunctions varies with the random seed
    assert env.agents[0].malfunction_data['nr_malfunctions'] == 21, "Actual {}".format(
        env.agents[0].malfunction_data['nr_malfunctions'])

    # Check that the agent accumulated down time while malfunctioning
    assert total_down_time > 0


def test_malfunction_process_statistically():
    """Tests that malfunctions are produced by stochastic_data!"""
    # Set fixed malfunction duration for this test
    stochastic_data = MalfunctionParameters(malfunction_rate=1/5,  # Rate of malfunction occurrence
                                            min_duration=5,  # Minimal duration of malfunction
                                            max_duration=5  # Max duration of malfunction
                                            )

    rail, rail_map, optionals = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(),
                  number_of_agents=10,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  obs_builder_object=SingleAgentNavigationObs()
                  )

    env.reset(True, True, False, random_seed=10)

    env.agents[0].target = (0, 0)
    # Next line only for test generation
    # agent_malfunction_list = [[] for i in range(10)]
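    # Expected remaining malfunction duration for each of the 10 agents at each of the
    # 20 steps below, recorded for random_seed=10 (regenerate via the commented lines).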
    agent_malfunction_list = [[0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4],
                              [0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                              [0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                              [0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2],
                              [0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1],
                              [0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1],
                              [0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0],
                              [5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 5],
                              [5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2],
                              [5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4]]

    for step in range(20):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent_idx in range(env.get_num_agents()):
            # We randomly select an action
            action_dict[agent_idx] = RailEnvActions(np.random.randint(4))
            # For generating tests only:
            # agent_malfunction_list[agent_idx].append(env.agents[agent_idx].malfunction_data['malfunction'])
            assert env.agents[agent_idx].malfunction_data['malfunction'] == agent_malfunction_list[agent_idx][step]
        env.step(action_dict)
    # print(agent_malfunction_list)


def test_malfunction_before_entry():
    """Tests that malfunctions are working properly for agents before entering the environment!"""
    # Set fixed malfunction duration for this test
    stochastic_data = MalfunctionParameters(malfunction_rate=1/2,  # Rate of malfunction occurrence
                                            min_duration=10,  # Minimal duration of malfunction
                                            max_duration=10  # Max duration of malfunction
                                            )

    rail, rail_map, optionals = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(),
                  number_of_agents=10,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    env.reset(False, False, False, random_seed=10)
    env.agents[0].target = (0, 0)

    # Test initial malfunction values for all agents:
    # we want some agents to be malfunctioning already and some to be working,
    # and we want different next_malfunction values for the agents
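    # With malfunction_rate=1/2 and this seed, agents 0 and 2 start without a malfunction,
    # while all other agents start in a malfunction of the full 10-step duration.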
    assert env.agents[0].malfunction_data['malfunction'] == 0
    assert env.agents[1].malfunction_data['malfunction'] == 10
    assert env.agents[2].malfunction_data['malfunction'] == 0
    assert env.agents[3].malfunction_data['malfunction'] == 10
    assert env.agents[4].malfunction_data['malfunction'] == 10
    assert env.agents[5].malfunction_data['malfunction'] == 10
    assert env.agents[6].malfunction_data['malfunction'] == 10
    assert env.agents[7].malfunction_data['malfunction'] == 10
    assert env.agents[8].malfunction_data['malfunction'] == 10
    assert env.agents[9].malfunction_data['malfunction'] == 10

    # for a in range(10):
    #     print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a, env.agents[a].malfunction_data['malfunction']))


def test_malfunction_values_and_behavior():
    """
214
    Test the malfunction counts down as desired
215
216
217
218
219
220
    Returns
    -------

    """
    # Set fixed malfunction duration for this test

    rail, rail_map, optionals = make_simple_rail2()
    action_dict: Dict[int, RailEnvActions] = {}
    stochastic_data = MalfunctionParameters(malfunction_rate=1/0.001,  # Rate of malfunction occurrence
                                            min_duration=10,  # Minimal duration of malfunction
                                            max_duration=10  # Max duration of malfunction
                                            )
    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(),
                  number_of_agents=1,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  obs_builder_object=SingleAgentNavigationObs()
                  )

    env.reset(False, False, activate_agents=True, random_seed=10)

    # Assertions
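    # Expected countdown: the initial 10-step malfunction decreases by one per step down
    # to 0; because the malfunction rate is extremely high, a new 10-step malfunction
    # starts immediately afterwards.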
    assert_list = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 10, 9, 8, 7, 6, 5]
    print("[")
    for time_step in range(15):
        # Move in the env
        env.step(action_dict)
        # Check that the malfunction countdown decreases as expected
        assert env.agents[0].malfunction_data['malfunction'] == assert_list[time_step]


def test_initial_malfunction():
    stochastic_data = MalfunctionParameters(malfunction_rate=1/1000,  # Rate of malfunction occurrence
                                            min_duration=2,  # Minimal duration of malfunction
                                            max_duration=5  # Max duration of malfunction
                                            )

    rail, rail_map, optionals = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(seed=10),
                  number_of_agents=1,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # reset to initialize agents_static
    env.reset(False, False, True, random_seed=10)
    print(env.agents[0].malfunction_data)
    env.agents[0].target = (0, 5)
    set_penalties_for_replay(env)
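    # Each Replay entry below gives the expected position, direction, malfunction counter
    # and reward for one time step; run_replay_config (from test_utils) replays the listed
    # actions and asserts that the environment matches these expectations.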
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty  # full step penalty when malfunctioning
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=2,
                reward=env.step_penalty  # full step penalty when malfunctioning
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=1,
                reward=env.step_penalty
            ),  # malfunctioning ends: starting and running at speed 1.0
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0  # running at speed 1.0
            ),
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty  # running at speed 1.0
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [replay_config])


def test_initial_malfunction_stop_moving():
    rail, rail_map, optionals = make_simple_rail2()

    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(), number_of_agents=1,
                  obs_builder_object=SingleAgentNavigationObs())
    env.reset()

    print(env.agents[0].initial_position, env.agents[0].direction, env.agents[0].position, env.agents[0].status)

    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=None,
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty,  # full step penalty when stopped
                status=RailAgentStatus.READY_TO_DEPART
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=2,
                reward=env.step_penalty,  # full step penalty when stopped
                status=RailAgentStatus.ACTIVE
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action STOP_MOVING, agent should restart without moving
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.STOP_MOVING,
                malfunction=1,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we have stopped and do nothing --> should stand still
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=0,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we start to move forward --> should go to next cell now
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            ),
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )

    run_replay_config(env, [replay_config], activate_agents=False)


def test_initial_malfunction_do_nothing():
    stochastic_data = MalfunctionParameters(malfunction_rate=1/70,  # Rate of malfunction occurrence
                                            min_duration=2,  # Minimal duration of malfunction
                                            max_duration=5  # Max duration of malfunction
                                            )

    rail, rail_map, optionals = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(),
                  number_of_agents=1,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  # Malfunction data generator
                  )
    env.reset()
    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=None,
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty,  # full step penalty while malfunctioning
                status=RailAgentStatus.READY_TO_DEPART
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=2,
                reward=env.step_penalty,  # full step penalty while malfunctioning
                status=RailAgentStatus.ACTIVE
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action DO_NOTHING, agent should restart without moving
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=1,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we haven't started moving yet --> stay here
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=0,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            ),  # we start to move forward --> should go to next cell now
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [replay_config], activate_agents=False)


def tests_random_interference_from_outside():
    """Tests that the environment is not influenced by external calls to the random number generators."""
    rail, rail_map, optionals = make_simple_rail2()
    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
    env.reset()
    env.agents[0].speed_data['speed'] = 0.33
    env.reset(False, False, False, random_seed=10)
    env_data = []

    for step in range(200):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # Go forward all the time
            action_dict[agent.handle] = RailEnvActions(2)

        _, reward, _, _ = env.step(action_dict)
        # Append the rewards of the first trial
        env_data.append((reward[0], env.agents[0].position))
        assert reward[0] == env_data[step][0]
        assert env.agents[0].position == env_data[step][1]
    # Run the same test as above but with an external random generator running
    # Check that the reward stays the same

    rail, rail_map, optionals = make_simple_rail2()
    random.seed(47)
    np.random.seed(1234)
    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
    env.reset()
    env.agents[0].speed_data['speed'] = 0.33
    env.reset(False, False, False, random_seed=10)

    dummy_list = [1, 2, 6, 7, 8, 9, 4, 5, 4]
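    # Exercise the global random number generators between steps; the trajectory recorded
    # in env_data above must remain identical.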
    for step in range(200):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # Go forward all the time
            action_dict[agent.handle] = RailEnvActions(2)

            # Do dummy random number generations
            random.shuffle(dummy_list)
            np.random.rand()

        _, reward, _, _ = env.step(action_dict)
        assert reward[0] == env_data[step][0]
        assert env.agents[0].position == env_data[step][1]


def test_last_malfunction_step():
    """
    Test to check that agent moves when it is not malfunctioning

    """

    # Set fixed malfunction duration for this test

    rail, rail_map, optionals = make_simple_rail2()

    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
    env.reset()
    env.agents[0].speed_data['speed'] = 1. / 3.
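    # With speed 1/3 the agent advances its intra-cell 'position_fraction' by one third per
    # step, so movement can be detected even when the grid position stays the same.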
    env.agents[0].target = (0, 0)

    env.reset(False, False, True)
    # Force malfunction to be off at beginning and next malfunction to happen in 2 steps
    env.agents[0].malfunction_data['next_malfunction'] = 2
    env.agents[0].malfunction_data['malfunction'] = 0
    env_data = []
    for step in range(20):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # Go forward all the time
            action_dict[agent.handle] = RailEnvActions(2)

        if env.agents[0].malfunction_data['malfunction'] < 1:
            agent_can_move = True
        # Store the position before and after the step
        pre_position = env.agents[0].speed_data['position_fraction']
        _, reward, _, _ = env.step(action_dict)
        # Check if the agent is still allowed to move in this step
        if env.agents[0].malfunction_data['malfunction'] > 0:
            agent_can_move = False
        post_position = env.agents[0].speed_data['position_fraction']
        # Assert that the agent moved while it was still allowed
        if agent_can_move:
            assert pre_position != post_position
        else:
            assert post_position == pre_position