import random
from typing import Dict, List

import numpy as np

from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_simple_rail2

from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay

18
class SingleAgentNavigationObs(ObservationBuilder):
    """
    We build a representation vector with 3 binary components, indicating which of the 3 available directions
    for each agent (Left, Forward, Right) lead to the shortest path to its target.
    E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
    will be [1, 0, 0].
    """

    def __init__(self):
        super().__init__()

    def reset(self):
        # No per-episode state to rebuild for this observation builder.
        pass

    def get(self, handle: int = 0) -> List[int]:
        """Return the [left, forward, right] one-hot shortest-path vector for agent `handle`.

        Returns None for agents in a terminal/removed state (neither ready to
        depart, active, nor done), since they have no meaningful position.
        """
        agent = self.env.agents[handle]

        # Pick the position the agent is (virtually) at, depending on its lifecycle status.
        if agent.status == RailAgentStatus.READY_TO_DEPART:
            agent_virtual_position = agent.initial_position
        elif agent.status == RailAgentStatus.ACTIVE:
            agent_virtual_position = agent.position
        elif agent.status == RailAgentStatus.DONE:
            agent_virtual_position = agent.target
        else:
            return None

        possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
        num_transitions = np.count_nonzero(possible_transitions)

        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right], relative to the current orientation
        # If only one transition is possible, the forward branch is aligned with it.
        if num_transitions == 1:
            observation = [0, 1, 0]
        else:
            min_distances = []
            # (agent.direction - 1, agent.direction, agent.direction + 1) mod 4 == left, forward, right
            for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                if possible_transitions[direction]:
                    new_position = get_new_position(agent_virtual_position, direction)
                    min_distances.append(
                        self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
                else:
                    # Blocked branch: infinite distance so argmin never selects it.
                    min_distances.append(np.inf)

            observation = [0, 0, 0]
            observation[np.argmin(min_distances)] = 1

        return observation


def test_malfunction_process():
    """Check that a malfunctioning agent stands still and accumulates down time."""
    # Set fixed malfunction duration for this test
    stochastic_data = MalfunctionParameters(malfunction_rate=1,  # Rate of malfunction occurrence
                                            min_duration=3,  # Minimal duration of malfunction
                                            max_duration=3  # Max duration of malfunction
                                            )

    rail, rail_map, optionals = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(),
                  number_of_agents=1,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    obs, info = env.reset(False, False, True, random_seed=10)

    total_down_time = 0
    agent_old_position = env.agents[0].position

    # Move target to unreachable position in order to not interfere with test
    env.agents[0].target = (0, 0)

    # Add in max episode steps because schedule generator sets it to 0 for dummy data
    env._max_episode_steps = 200

    for step in range(100):
        actions = {}

        # Observation is a [left, forward, right] one-hot; +1 maps it onto
        # MOVE_LEFT / MOVE_FORWARD / MOVE_RIGHT action codes.
        for i in range(len(obs)):
            actions[i] = np.argmax(obs[i]) + 1

        obs, all_rewards, done, _ = env.step(actions)
        if done["__all__"]:
            break

        if env.agents[0].malfunction_data['malfunction'] > 0:
            agent_malfunctioning = True
        else:
            agent_malfunctioning = False

        if agent_malfunctioning:
            # Check that agent is not moving while malfunctioning
            assert agent_old_position == env.agents[0].position

        agent_old_position = env.agents[0].position
        total_down_time += env.agents[0].malfunction_data['malfunction']

    # Check that the appropriate number of malfunctions is achieved
    # Dipam: The number of malfunctions varies by seed
    assert env.agents[0].malfunction_data['nr_malfunctions'] == 21, "Actual {}".format(
        env.agents[0].malfunction_data['nr_malfunctions'])

    # Check that malfunctioning data was standing around
    assert total_down_time > 0
u214892's avatar
u214892 committed
124
125
126


def test_malfunction_process_statistically():
    """Tests that malfunctions are produced by stochastic_data!

    Replays 20 steps with a fixed seed and checks the per-step malfunction
    counters of both agents against a recorded reference trace.
    """
    # Set fixed malfunction duration for this test
    stochastic_data = MalfunctionParameters(malfunction_rate=1/5,  # Rate of malfunction occurrence
                                            min_duration=5,  # Minimal duration of malfunction
                                            max_duration=5  # Max duration of malfunction
                                            )

    rail, rail_map, optionals = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(),
                  number_of_agents=2,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  obs_builder_object=SingleAgentNavigationObs()
                  )

    env.reset(True, True, False, random_seed=10)
    env._max_episode_steps = 1000

    env.agents[0].target = (0, 0)
    # Next line only for test generation
    # agent_malfunction_list = [[] for i in range(2)]
    # Reference malfunction countdowns per agent, recorded with random_seed=10.
    agent_malfunction_list = [[0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0],
                              [5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 5]]

    for step in range(20):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent_idx in range(env.get_num_agents()):
            # We randomly select an action
            action_dict[agent_idx] = RailEnvActions(np.random.randint(4))
            # For generating tests only:
            # agent_malfunction_list[agent_idx].append(env.agents[agent_idx].malfunction_data['malfunction'])
            assert env.agents[agent_idx].malfunction_data['malfunction'] == agent_malfunction_list[agent_idx][step]
        env.step(action_dict)
    # print(agent_malfunction_list)
u214892's avatar
u214892 committed
165

166
def test_malfunction_before_entry():
    """Tests that malfunctions are working properly for agents before entering the environment!"""
    # Set fixed malfunction duration for this test
    stochastic_data = MalfunctionParameters(malfunction_rate=1/2,  # Rate of malfunction occurrence
                                            min_duration=10,  # Minimal duration of malfunction
                                            max_duration=10  # Max duration of malfunction
                                            )

    rail, rail_map, optionals = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(),
                  number_of_agents=2,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    env.reset(False, False, False, random_seed=10)
    env.agents[0].target = (0, 0)

    # Test initial malfunction values for all agents
    # we want some agents to be malfunctioning already and some to be working
    # we want different next_malfunction values for the agents
    assert env.agents[0].malfunction_data['malfunction'] == 0
    assert env.agents[1].malfunction_data['malfunction'] == 10
193


194
195
def test_malfunction_values_and_behavior():
    """
    Test the malfunction counts down as desired

    With an (effectively) certain malfunction rate and a fixed 10-step
    duration, the per-step malfunction counter must follow `assert_list`.

    Returns
    -------

    """
    # Set fixed malfunction duration for this test
    rail, rail_map, optionals = make_simple_rail2()
    action_dict: Dict[int, RailEnvActions] = {}
    stochastic_data = MalfunctionParameters(malfunction_rate=1/0.001,  # Rate of malfunction occurrence
                                            min_duration=10,  # Minimal duration of malfunction
                                            max_duration=10  # Max duration of malfunction
                                            )
    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(),
                  number_of_agents=1,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  obs_builder_object=SingleAgentNavigationObs()
                  )

    env.reset(False, False, activate_agents=True, random_seed=10)

    # Expected malfunction countdown per step, recorded with random_seed=10.
    assert_list = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 10, 9, 8, 7, 6, 5]
    for time_step in range(15):
        # Move in the env (empty action dict --> no agent actions)
        _, _, dones, _ = env.step(action_dict)
        # Check that next_step decreases as expected
        assert env.agents[0].malfunction_data['malfunction'] == assert_list[time_step]
        if dones['__all__']:
            break
230

231

232
def test_initial_malfunction():
    """Replay test: an agent injected with a 3-step malfunction must stand
    still while the counter runs down, then restart and move forward."""
    stochastic_data = MalfunctionParameters(malfunction_rate=1/1000,  # Rate of malfunction occurrence
                                            min_duration=2,  # Minimal duration of malfunction
                                            max_duration=5  # Max duration of malfunction
                                            )

    rail, rail_map, optionals = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(seed=10),
                  number_of_agents=1,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # reset to initialize agents_static
    env.reset(False, False, True, random_seed=10)
    env._max_episode_steps = 1000
    env.agents[0].target = (0, 5)
    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty  # full step penalty when malfunctioning
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=2,
                reward=env.step_penalty  # full step penalty when malfunctioning
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=1,
                reward=env.step_penalty

            ),  # malfunctioning ends: starting and running at speed 1.0
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0  # running at speed 1.0
            ),
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty  # running at speed 1.0
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [replay_config], skip_reward_check=True)
303
304
305


def test_initial_malfunction_stop_moving():
    """Replay test: a malfunction hitting before departure keeps the agent in
    READY_TO_DEPART, and STOP_MOVING/DO_NOTHING after recovery keeps it in place
    until MOVE_FORWARD is issued."""
    rail, rail_map, optionals = make_simple_rail2()

    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(), number_of_agents=1,
                  obs_builder_object=SingleAgentNavigationObs())
    env.reset()

    env._max_episode_steps = 1000

    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=None,
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty,  # full step penalty when stopped
                status=RailAgentStatus.READY_TO_DEPART
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=2,
                reward=env.step_penalty,  # full step penalty when stopped
                status=RailAgentStatus.ACTIVE
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action STOP_MOVING, agent should restart without moving
            #
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.STOP_MOVING,
                malfunction=1,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we have stopped and do nothing --> should stand still
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=0,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we start to move forward --> should go to next cell now
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )

    run_replay_config(env, [replay_config], activate_agents=False, skip_reward_check=True)
382
383


384
def test_initial_malfunction_do_nothing():
    """Replay test: after an initial malfunction, DO_NOTHING must leave the
    agent standing at the cell entry until an explicit MOVE_FORWARD."""
    stochastic_data = MalfunctionParameters(malfunction_rate=1/70,  # Rate of malfunction occurrence
                                            min_duration=2,  # Minimal duration of malfunction
                                            max_duration=5  # Max duration of malfunction
                                            )

    rail, rail_map, optionals = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(),
                  number_of_agents=1,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  # Malfunction data generator
                  )
    env.reset()
    env._max_episode_steps = 1000
    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=None,
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty,  # full step penalty while malfunctioning
                status=RailAgentStatus.READY_TO_DEPART
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=2,
                reward=env.step_penalty,  # full step penalty while malfunctioning
                status=RailAgentStatus.ACTIVE
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action DO_NOTHING, agent should restart without moving
            #
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=1,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we haven't started moving yet --> stay here
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=0,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),

            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            ),  # we start to move forward --> should go to next cell now
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [replay_config], activate_agents=False, skip_reward_check=True)
466
467


Erik Nygren's avatar
Erik Nygren committed
468
469
470
def tests_random_interference_from_outside():
    """Tests that malfunctions are produced by stochastic_data!

    Records a (reward, position) trace for a seeded episode, then replays the
    identical episode while hammering the global `random`/`np.random` state and
    checks the trace is unchanged — i.e. the env uses its own RNG streams.
    """
    # Set fixed malfunction duration for this test
    rail, rail_map, optionals = make_simple_rail2()
    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
    env.reset()
    env.agents[0].speed_data['speed'] = 0.33
    env.reset(False, False, False, random_seed=10)
    env_data = []

    for step in range(200):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # We randomly select an action
            action_dict[agent.handle] = RailEnvActions(2)

        _, reward, dones, _ = env.step(action_dict)
        # Append the rewards of the first trial
        env_data.append((reward[0], env.agents[0].position))
        assert reward[0] == env_data[step][0]
        assert env.agents[0].position == env_data[step][1]
        if dones['__all__']:
            break

    # Run the same test as above but with an external random generator running
    # Check that the reward stays the same
    rail, rail_map, optionals = make_simple_rail2()
    random.seed(47)
    np.random.seed(1234)
    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
    env.reset()
    env.agents[0].speed_data['speed'] = 0.33
    env.reset(False, False, False, random_seed=10)

    dummy_list = [1, 2, 6, 7, 8, 9, 4, 5, 4]
    for step in range(200):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # We randomly select an action
            action_dict[agent.handle] = RailEnvActions(2)

            # Do dummy random number generations
            random.shuffle(dummy_list)
            np.random.rand()

        _, reward, dones, _ = env.step(action_dict)
        assert reward[0] == env_data[step][0]
        assert env.agents[0].position == env_data[step][1]
        if dones['__all__']:
            break
520
521
522
523
524
525
526
527
528
529


def test_last_malfunction_step():
    """
    Test to check that agent moves when it is not malfunctioning

    Tracks `position_fraction` before/after each step: it must change while
    the agent is healthy and stay put while a malfunction is active.
    """

    # Set fixed malfunction duration for this test
    rail, rail_map, optionals = make_simple_rail2()

    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
    env.reset()
    env.agents[0].speed_data['speed'] = 1. / 3.
    env.agents[0].initial_position = (6, 6)
    env.agents[0].initial_direction = 2
    env.agents[0].target = (0, 3)

    env._max_episode_steps = 1000

    env.reset(False, False, True)
    # Force malfunction to be off at beginning and next malfunction to happen in 2 steps
    env.agents[0].malfunction_data['next_malfunction'] = 2
    env.agents[0].malfunction_data['malfunction'] = 0

    # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
    for _ in range(max([agent.earliest_departure for agent in env.agents])):
        env.step({})  # DO_NOTHING for all agents

    # Initialize before the loop: the original code left this unbound until the
    # first healthy step, risking UnboundLocalError if a malfunction was active
    # at step 0.
    agent_can_move = False
    for step in range(20):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # Go forward all the time
            action_dict[agent.handle] = RailEnvActions(2)

        if env.agents[0].malfunction_data['malfunction'] < 1:
            agent_can_move = True

        # Store the position before and after the step
        pre_position = env.agents[0].speed_data['position_fraction']
        _, reward, _, _ = env.step(action_dict)
        # Check if the agent is still allowed to move in this step
        if env.agents[0].malfunction_data['malfunction'] > 0:
            agent_can_move = False
        post_position = env.agents[0].speed_data['position_fraction']

        # Assert that the agent moved while it was still allowed
        if agent_can_move:
            assert pre_position != post_position
        else:
            assert post_position == pre_position