test_flatland_malfunction.py 23.8 KB
Newer Older
u214892's avatar
u214892 committed
1
import random
2
from typing import Dict, List
u214892's avatar
u214892 committed
3

4
5
import numpy as np

6
from flatland.core.env_observation_builder import ObservationBuilder
u214892's avatar
u214892 committed
7
from flatland.core.grid.grid4 import Grid4TransitionsEnum
8
from flatland.core.grid.grid4_utils import get_new_position
u214892's avatar
u214892 committed
9
from flatland.envs.agent_utils import RailAgentStatus
Erik Nygren's avatar
Erik Nygren committed
10
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
u214892's avatar
u214892 committed
11
from flatland.envs.rail_env import RailEnv, RailEnvActions
12
from flatland.envs.rail_generators import rail_from_grid_transition_map
13
from flatland.envs.line_generators import sparse_line_generator
14
from flatland.utils.simple_rail import make_simple_rail2
15
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
16
17


18
class SingleAgentNavigationObs(ObservationBuilder):
    """Observation builder producing a 3-element one-hot vector per agent.

    The components correspond to the relative branches [Left, Forward, Right];
    the component set to 1 marks the branch lying on the shortest path to the
    agent's target (e.g. [1, 0, 0] when turning left is shortest).
    """

    def __init__(self):
        super().__init__()

    def reset(self):
        pass

    def get(self, handle: int = 0) -> List[int]:
        agent = self.env.agents[handle]

        # Choose the cell from which transitions are evaluated, based on the
        # agent's lifecycle status; unknown statuses yield no observation.
        if agent.status == RailAgentStatus.READY_TO_DEPART:
            ref_position = agent.initial_position
        elif agent.status == RailAgentStatus.ACTIVE:
            ref_position = agent.position
        elif agent.status == RailAgentStatus.DONE:
            ref_position = agent.target
        else:
            return None

        allowed = self.env.rail.get_transitions(*ref_position, agent.direction)

        if np.count_nonzero(allowed) == 1:
            # Exactly one exit: report it as the "forward" branch.
            return [0, 1, 0]

        # Shortest-path distance for each relative branch [left, forward, right];
        # blocked branches get +inf so argmin never selects them.
        branch_distances = []
        for offset in (-1, 0, 1):
            heading = (agent.direction + offset) % 4
            if allowed[heading]:
                nxt = get_new_position(ref_position, heading)
                branch_distances.append(
                    self.env.distance_map.get()[handle, nxt[0], nxt[1], heading])
            else:
                branch_distances.append(np.inf)

        one_hot = [0, 0, 0]
        one_hot[np.argmin(branch_distances)] = 1
        return one_hot


def test_malfunction_process():
    """Check that a malfunctioning agent stands still and accumulates down time."""
    # Set fixed malfunction duration for this test
    stochastic_data = MalfunctionParameters(malfunction_rate=1,  # Rate of malfunction occurrence
                                            min_duration=3,  # Minimal duration of malfunction
                                            max_duration=3  # Max duration of malfunction
                                            )

    rail, rail_map, optionals = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(),
                  number_of_agents=1,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    obs, info = env.reset(False, False, random_seed=10)
    # Activate every agent directly on its initial position so malfunctions
    # affect actual movement from the first step.
    for a_idx in range(len(env.agents)):
        env.agents[a_idx].position = env.agents[a_idx].initial_position
        env.agents[a_idx].status = RailAgentStatus.ACTIVE

    total_down_time = 0
    agent_old_position = env.agents[0].position

    # Move target to unreachable position in order to not interfere with test
    env.agents[0].target = (0, 0)

    # Add in max episode steps because the schedule generator sets it to 0 for dummy data
    env._max_episode_steps = 200

    for step in range(100):
        # Follow the shortest-path observation; RailEnv actions are 1-based (L/F/R).
        actions = {i: np.argmax(obs[i]) + 1 for i in range(len(obs))}

        obs, all_rewards, done, _ = env.step(actions)
        if done["__all__"]:
            break

        # While malfunctioning, the agent must not have moved this step.
        agent_malfunctioning = env.agents[0].malfunction_data['malfunction'] > 0
        if agent_malfunctioning:
            assert agent_old_position == env.agents[0].position

        agent_old_position = env.agents[0].position
        total_down_time += env.agents[0].malfunction_data['malfunction']

    # Check that the appropriate number of malfunctions is achieved
    # Dipam: The number of malfunctions varies by seed
    assert env.agents[0].malfunction_data['nr_malfunctions'] == 21, "Actual {}".format(
        env.agents[0].malfunction_data['nr_malfunctions'])

    # Check that malfunctioning data was standing around
    assert total_down_time > 0
u214892's avatar
u214892 committed
127
128
129


def test_malfunction_process_statistically():
    """Verify that the stochastic malfunction generator reproduces a fixed trace."""
    # Fixed 5-step malfunctions at a rate of roughly one every 5 steps.
    malfunction_params = MalfunctionParameters(malfunction_rate=1 / 5,
                                               min_duration=5,
                                               max_duration=5)

    rail, rail_map, optionals = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(),
                  number_of_agents=2,
                  malfunction_generator_and_process_data=malfunction_from_params(malfunction_params),
                  obs_builder_object=SingleAgentNavigationObs())

    env.reset(True, True, random_seed=10)
    env._max_episode_steps = 1000

    env.agents[0].target = (0, 0)
    # Reference trace of the malfunction counter per agent, recorded once via
    # the commented-out collection code below.
    # expected_malfunctions = [[] for i in range(2)]
    expected_malfunctions = [[0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0],
                             [5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 5]]

    for t in range(20):
        action_dict: Dict[int, RailEnvActions] = {}
        for handle in range(env.get_num_agents()):
            # Random action; keeps the RNG call sequence identical to the run
            # that produced the reference trace.
            action_dict[handle] = RailEnvActions(np.random.randint(4))
            # For regenerating the reference trace only:
            # expected_malfunctions[handle].append(env.agents[handle].malfunction_data['malfunction'])
            assert env.agents[handle].malfunction_data['malfunction'] == expected_malfunctions[handle][t]
        env.step(action_dict)
167

u214892's avatar
u214892 committed
168

169
def test_malfunction_before_entry():
    """Malfunctions must also affect agents that have not yet entered the grid."""
    # Fixed 10-step malfunctions, triggered roughly every other step.
    malfunction_params = MalfunctionParameters(malfunction_rate=1 / 2,
                                               min_duration=10,
                                               max_duration=10)

    rail, rail_map, optionals = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(),
                  number_of_agents=2,
                  malfunction_generator_and_process_data=malfunction_from_params(malfunction_params),
                  obs_builder_object=SingleAgentNavigationObs())
    env.reset(False, False, random_seed=10)
    env.agents[0].target = (0, 0)

    # Right after reset, with this seed, we expect a mix of states:
    # agent 0 starts healthy while agent 1 is already malfunctioning
    # for the full 10 steps.
    assert env.agents[0].malfunction_data['malfunction'] == 0
    assert env.agents[1].malfunction_data['malfunction'] == 10
195
196


197
198
def test_malfunction_values_and_behavior():
    """
    Test that the malfunction counter counts down as desired.

    With an extremely high malfunction rate the single agent is effectively
    always broken; the counter must follow ``assert_list`` step by step
    (counting 10 -> 0, then a fresh 10-step malfunction starting immediately).
    """
    rail, rail_map, optionals = make_simple_rail2()
    action_dict: Dict[int, RailEnvActions] = {}
    # Malfunction practically every step, each lasting exactly 10 steps.
    stochastic_data = MalfunctionParameters(malfunction_rate=1/0.001,  # Rate of malfunction occurrence
                                            min_duration=10,  # Minimal duration of malfunction
                                            max_duration=10  # Max duration of malfunction
                                            )
    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(),
                  number_of_agents=1,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  obs_builder_object=SingleAgentNavigationObs()
                  )

    env.reset(False, False, random_seed=10)

    # Expected malfunction countdown, one entry per time step.
    assert_list = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 10, 9, 8, 7, 6, 5]
    # (Removed leftover debug print from test-trace generation.)
    for time_step in range(15):
        # Move in the env
        _, _, dones, _ = env.step(action_dict)
        # Check that the malfunction countdown decreases as expected
        assert env.agents[0].malfunction_data['malfunction'] == assert_list[time_step]
        if dones['__all__']:
            break
233

234

235
def test_initial_malfunction():
    """Replay test: an agent forced into a malfunction on its first step must
    stand still until the counter reaches zero, then move off normally."""
    stochastic_data = MalfunctionParameters(malfunction_rate=1/1000,  # Rate of malfunction occurrence
                                            min_duration=2,  # Minimal duration of malfunction
                                            max_duration=5  # Max duration of malfunction
                                            )

    rail, rail_map, optionals = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(seed=10),
                  number_of_agents=1,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # reset to initialize agents_static
    env.reset(False, False, random_seed=10)
    env._max_episode_steps = 1000
    env.agents[0].target = (0, 5)
    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty  # full step penalty when malfunctioning
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=2,
                reward=env.step_penalty  # full step penalty when malfunctioning
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=1,
                reward=env.step_penalty

            ),  # malfunctioning ends: starting and running at speed 1.0
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0  # running at speed 1.0
            ),
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty  # running at speed 1.0
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [replay_config], skip_reward_check=True)
306
307
308


def test_initial_malfunction_stop_moving():
    """Replay test: a malfunction before entry followed by STOP_MOVING /
    DO_NOTHING must keep the agent in place until MOVE_FORWARD is issued."""
    rail, rail_map, optionals = make_simple_rail2()

    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(), number_of_agents=1,
                  obs_builder_object=SingleAgentNavigationObs())
    env.reset()

    env._max_episode_steps = 1000

    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=None,
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty,  # full step penalty when stopped
                status=RailAgentStatus.READY_TO_DEPART
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=2,
                reward=env.step_penalty,  # full step penalty when stopped
                status=RailAgentStatus.ACTIVE
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action STOP_MOVING, agent should restart without moving
            #
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.STOP_MOVING,
                malfunction=1,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we have stopped and do nothing --> should stand still
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=0,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we start to move forward --> should go to next cell now
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            ),
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )

    run_replay_config(env, [replay_config], activate_agents=False, skip_reward_check=True)
385
386


387
def test_initial_malfunction_do_nothing():
    """Replay test: an agent that malfunctions before entering and then issues
    DO_NOTHING must stay put until it explicitly moves forward."""
    # Fixed malfunction parameters for this test.
    malfunction_params = MalfunctionParameters(malfunction_rate=1 / 70,
                                               min_duration=2,
                                               max_duration=5)

    rail, rail_map, optionals = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(),
                  number_of_agents=1,
                  malfunction_generator_and_process_data=malfunction_from_params(malfunction_params),
                  )
    env.reset()
    env._max_episode_steps = 1000
    set_penalties_for_replay(env)

    # Expected step-by-step trace: malfunction before entry, DO_NOTHING while
    # broken and after recovery, then MOVE_FORWARD to advance.
    expected_steps = [
        Replay(
            position=None,
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            set_malfunction=3,
            malfunction=3,
            reward=env.step_penalty,  # full step penalty while malfunctioning
            status=RailAgentStatus.READY_TO_DEPART
        ),
        Replay(
            position=(3, 2),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.DO_NOTHING,
            malfunction=2,
            reward=env.step_penalty,  # full step penalty while malfunctioning
            status=RailAgentStatus.ACTIVE
        ),
        # Malfunction ends next step and we are still at the cell entry:
        # DO_NOTHING must leave the agent in place.
        Replay(
            position=(3, 2),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.DO_NOTHING,
            malfunction=1,
            reward=env.step_penalty,  # full step penalty while stopped
            status=RailAgentStatus.ACTIVE
        ),
        # Not moving yet --> stay here.
        Replay(
            position=(3, 2),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.DO_NOTHING,
            malfunction=0,
            reward=env.step_penalty,  # full step penalty while stopped
            status=RailAgentStatus.ACTIVE
        ),
        # Start moving --> advance to the next cell.
        Replay(
            position=(3, 2),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=0,
            reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
            status=RailAgentStatus.ACTIVE
        ),
        Replay(
            position=(3, 3),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=0,
            reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
            status=RailAgentStatus.ACTIVE
        ),
    ]

    replay_config = ReplayConfig(
        replay=expected_steps,
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [replay_config], activate_agents=False, skip_reward_check=True)
469
470


Erik Nygren's avatar
Erik Nygren committed
471
472
473
def tests_random_interference_from_outside():
    """Check that the env is deterministic w.r.t. external RNG usage.

    Runs the same episode twice: once untouched, and once while the global
    ``random`` and ``np.random`` generators are advanced between steps.
    Rewards and positions must be identical in both runs.
    """
    # --- First run: record rewards/positions with no outside RNG activity ---
    rail, rail_map, optionals = make_simple_rail2()
    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
    env.reset()
    env.agents[0].speed_data['speed'] = 0.33
    env.reset(False, False, random_seed=10)
    env_data = []

    for step in range(200):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # Always MOVE_FORWARD (action 2); no per-step randomness here.
            action_dict[agent.handle] = RailEnvActions(2)

        _, reward, dones, _ = env.step(action_dict)
        # Record the (reward, position) trace of the first trial
        env_data.append((reward[0], env.agents[0].position))
        # NOTE(review): these two asserts compare the value just appended with
        # itself, so in this first loop they are tautological.
        assert reward[0] == env_data[step][0]
        assert env.agents[0].position == env_data[step][1]
        if dones['__all__']:
            break

    # --- Second run: same episode, but with an external random generator running ---
    # Check that the reward stays the same
    rail, rail_map, optionals = make_simple_rail2()
    # Seed the *global* RNGs; the env's own seeding must be independent of them.
    random.seed(47)
    np.random.seed(1234)
    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
    env.reset()
    env.agents[0].speed_data['speed'] = 0.33
    env.reset(False, False, random_seed=10)

    dummy_list = [1, 2, 6, 7, 8, 9, 4, 5, 4]
    for step in range(200):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # Same fixed action as in the first run
            action_dict[agent.handle] = RailEnvActions(2)

            # Do dummy random number generations to perturb the global RNG state
            random.shuffle(dummy_list)
            np.random.rand()

        _, reward, dones, _ = env.step(action_dict)
        # The trace must match the first (undisturbed) run step for step
        assert reward[0] == env_data[step][0]
        assert env.agents[0].position == env_data[step][1]
        if dones['__all__']:
            break
523
524
525
526
527
528
529
530
531
532


def test_last_malfunction_step():
    """
    Check that the agent moves while it is not malfunctioning and stands
    still (position fraction unchanged) for every malfunctioning step.
    """
    rail, rail_map, optionals = make_simple_rail2()

    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
    env.reset()
    env.agents[0].speed_data['speed'] = 1. / 3.
    env.agents[0].initial_position = (6, 6)
    env.agents[0].initial_direction = 2
    env.agents[0].target = (0, 3)

    env._max_episode_steps = 1000

    env.reset(False, False)
    # Activate the agents directly on their initial positions.
    for a_idx in range(len(env.agents)):
        env.agents[a_idx].position = env.agents[a_idx].initial_position
        env.agents[a_idx].status = RailAgentStatus.ACTIVE
    # Force malfunction to be off at beginning and next malfunction to happen in 2 steps
    env.agents[0].malfunction_data['next_malfunction'] = 2
    env.agents[0].malfunction_data['malfunction'] = 0

    # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
    for _ in range(max([agent.earliest_departure for agent in env.agents])):
        env.step({})  # DO_NOTHING for all agents

    # Defensive default: the original code left this unbound until the first
    # healthy step, risking UnboundLocalError if the agent is already
    # malfunctioning on the first checked step.
    agent_can_move = True
    for step in range(20):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # Go forward all the time
            action_dict[agent.handle] = RailEnvActions(2)

        if env.agents[0].malfunction_data['malfunction'] < 1:
            agent_can_move = True

        # Store the position fraction before and after the step
        pre_position = env.agents[0].speed_data['position_fraction']
        _, reward, _, _ = env.step(action_dict)

        # A malfunction active after the step forbids movement in this step
        if env.agents[0].malfunction_data['malfunction'] > 0:
            agent_can_move = False
        post_position = env.agents[0].speed_data['position_fraction']

        # Assert that the agent moved while it was still allowed
        if agent_can_move:
            assert pre_position != post_position
        else:
            assert post_position == pre_position