test_flatland_malfunction.py 19.1 KB
Newer Older
u214892's avatar
u214892 committed
1
import random
2
from typing import Dict, List
u214892's avatar
u214892 committed
3

4
import numpy as np
5
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
6

7
from flatland.core.env_observation_builder import ObservationBuilder
u214892's avatar
u214892 committed
8
from flatland.core.grid.grid4 import Grid4TransitionsEnum
9
from flatland.core.grid.grid4_utils import get_new_position
u214892's avatar
u214892 committed
10
from flatland.envs.agent_utils import RailAgentStatus
u214892's avatar
u214892 committed
11
from flatland.envs.rail_env import RailEnv, RailEnvActions
12
13
14
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.simple_rail import make_simple_rail2
15
16


17
class SingleAgentNavigationObs(ObservationBuilder):
    """
    We build a representation vector with 3 binary components, indicating which of the 3 available directions
    for each agent (Left, Forward, Right) lead to the shortest path to its target.
    E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
    will be [1, 0, 0].
    """

    def __init__(self):
        super().__init__()

    def reset(self):
        # Nothing to re-initialize between episodes; distances are read from env.distance_map on demand.
        pass

    def get(self, handle: int = 0) -> List[int]:
        # Returns a 3-element one-hot [left, forward, right] pointing at the
        # shortest-path branch, or None if the agent is in a terminal state.
        agent = self.env.agents[handle]

        # Pick the cell the agent "occupies" for planning purposes: its spawn
        # point before departure, its live position while active, its target
        # once done. Any other status (e.g. DONE_REMOVED) yields no observation.
        if agent.status == RailAgentStatus.READY_TO_DEPART:
            agent_virtual_position = agent.initial_position
        elif agent.status == RailAgentStatus.ACTIVE:
            agent_virtual_position = agent.position
        elif agent.status == RailAgentStatus.DONE:
            agent_virtual_position = agent.target
        else:
            return None

        possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
        num_transitions = np.count_nonzero(possible_transitions)

        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right], relative to the current orientation
        # If only one transition is possible, the forward branch is aligned with it.
        if num_transitions == 1:
            observation = [0, 1, 0]
        else:
            min_distances = []
            # direction - 1, direction, direction + 1 (mod 4) == left, forward, right
            for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                if possible_transitions[direction]:
                    new_position = get_new_position(agent_virtual_position, direction)
                    min_distances.append(
                        self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
                else:
                    # Blocked branch: infinite distance so argmin never picks it.
                    min_distances.append(np.inf)

            observation = [0, 0, 0]
            observation[np.argmin(min_distances)] = 1

        return observation


def test_malfunction_process():
    """Check that a malfunctioning agent stands still and that the expected
    number of malfunctions / halts is produced over 100 steps.

    NOTE(review): the original code first built a fixed-duration
    stochastic_data dict and then immediately rebuilt the variable with the
    values below — the dead first assignment (and commented-out seed calls)
    have been removed; behavior is unchanged.
    """
    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # reset to initialize agents_static
    obs = env.reset(False, False, True, random_seed=0)

    # Check that a initial duration for malfunction was assigned
    assert env.agents[0].malfunction_data['next_malfunction'] > 0
    for agent in env.agents:
        agent.status = RailAgentStatus.ACTIVE

    agent_halts = 0
    total_down_time = 0
    agent_old_position = env.agents[0].position
    for step in range(100):
        actions = {}

        # Observation is a one-hot [left, forward, right]; actions 1..3 map to those moves.
        for i in range(len(obs)):
            actions[i] = np.argmax(obs[i]) + 1

        if step % 5 == 0:
            # Stop the agent and set it to be malfunctioning
            env.agents[0].malfunction_data['malfunction'] = -1
            env.agents[0].malfunction_data['next_malfunction'] = 0
            agent_halts += 1

        obs, all_rewards, done, _ = env.step(actions)

        if env.agents[0].malfunction_data['malfunction'] > 0:
            agent_malfunctioning = True
        else:
            agent_malfunctioning = False

        if agent_malfunctioning:
            # Check that agent is not moving while malfunctioning
            assert agent_old_position == env.agents[0].position

        agent_old_position = env.agents[0].position
        total_down_time += env.agents[0].malfunction_data['malfunction']

    # Check that the appropriate number of malfunctions is achieved
    assert env.agents[0].malfunction_data['nr_malfunctions'] == 21, "Actual {}".format(
        env.agents[0].malfunction_data['nr_malfunctions'])

    # Check that 20 stops where performed
    assert agent_halts == 20

    # Check that malfunctioning data was standing around
    assert total_down_time > 0

def test_malfunction_process_statistically():
    """Tests that malfunctions are produced by stochastic_data!

    With a seeded reset and random actions for 20 steps, the stochastic
    malfunction generator must produce exactly 4 malfunctions.
    (Removed an unused `nb_malfunction` counter from the original.)
    """
    # Set fixed malfunction duration for this test
    stochastic_data = {'prop_malfunction': 1.,
                       'malfunction_rate': 2,
                       'min_duration': 3,
                       'max_duration': 3}

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # reset to initialize agents_static
    env.reset(False, False, False, random_seed=0)

    env.agents[0].target = (0, 0)
    for _ in range(20):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # We randomly select an action
            action_dict[agent.handle] = RailEnvActions(np.random.randint(4))

        env.step(action_dict)

    # check that generation of malfunctions works as expected
    assert env.agents[0].malfunction_data["nr_malfunctions"] == 4
192
def test_malfunction_before_entry():
    """Tests that malfunctions are produced by stochastic_data!

    An agent waiting outside the grid (DO_NOTHING for the first 10 steps)
    must not malfunction; once it enters and moves, a malfunction must occur.
    (Removed a leftover debug print and an unused `nb_malfunction` counter.)
    """
    # Set fixed malfunction duration for this test
    stochastic_data = {'prop_malfunction': 1.,
                       'malfunction_rate': 2,
                       'min_duration': 10,
                       'max_duration': 10}

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # reset to initialize agents_static
    env.reset(False, False, False, random_seed=0)

    env.agents[0].target = (0, 0)
    for step in range(20):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # We randomly select an action
            if step < 10:
                # Still outside the grid: no malfunction may have been triggered yet.
                action_dict[agent.handle] = RailEnvActions(0)
                assert env.agents[0].malfunction_data['malfunction'] == 0
            else:
                action_dict[agent.handle] = RailEnvActions(2)

        env.step(action_dict)
    # After entering the grid, the stochastic generator must have produced a malfunction.
    assert env.agents[0].malfunction_data['malfunction'] > 0

211
def test_initial_malfunction():
    """Replay check: an agent whose malfunction is set at the very first step
    stays on its initial cell (full step penalty each step) until the
    malfunction counter reaches 0, then restarts and moves forward.
    """

    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )

    # reset to initialize agents_static
    env.reset(False, False, True, random_seed=0)

    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            # Step 1: malfunction injected (set_malfunction=3) -- agent frozen at (3, 2).
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty  # full step penalty when malfunctioning
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=2,
                reward=env.step_penalty  # full step penalty when malfunctioning
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=1,
                reward=env.start_penalty + env.step_penalty * 1.0
                # malfunctioning ends: starting and running at speed 1.0
            ),
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0  # running at speed 1.0
            ),
            Replay(
                position=(3, 4),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0  # running at speed 1.0
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [replay_config])

def test_initial_malfunction_stop_moving():
    """Replay check: a malfunction before departure keeps the agent in
    READY_TO_DEPART; after activation it sits out the malfunction, can be
    explicitly stopped/do-nothing'd in place, and only moves on MOVE_FORWARD.

    NOTE(review): the original had the "reset to initialize agents_static"
    comment but no reset call (apparently lost), plus a leftover debug print;
    the reset is restored to match the sibling tests and the print removed.
    """
    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # reset to initialize agents_static
    env.reset(False, False, False, random_seed=0)

    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            # Not yet on the grid: malfunction strikes while READY_TO_DEPART.
            Replay(
                position=None,
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty,  # full step penalty when stopped
                status=RailAgentStatus.READY_TO_DEPART
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=2,
                reward=env.step_penalty,  # full step penalty when stopped
                status=RailAgentStatus.ACTIVE
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action STOP_MOVING, agent should restart without moving
            #
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.STOP_MOVING,
                malfunction=1,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we have stopped and do nothing --> should stand still
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=0,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we start to move forward --> should go to next cell now
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )

    run_replay_config(env, [replay_config], activate_agents=False)
def test_initial_malfunction_do_nothing():
    """Replay check: like the stop_moving variant, but the agent only ever
    issues DO_NOTHING while malfunctioning/stopped — it must stay put until
    an explicit MOVE_FORWARD is given.

    NOTE(review): this test seeds `random`/`np.random` but calls `env.reset()`
    without a `random_seed`, unlike the sibling tests — presumably the global
    seeds make it deterministic; confirm before relying on it.
    """
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  )
    # reset to initialize agents_static
    env.reset()
    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            # Malfunction strikes before the agent enters the grid.
            Replay(
                position=None,
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty,  # full step penalty while malfunctioning
                status=RailAgentStatus.READY_TO_DEPART
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=2,
                reward=env.step_penalty,  # full step penalty while malfunctioning
                status=RailAgentStatus.ACTIVE
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action DO_NOTHING, agent should restart without moving
            #
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=1,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we haven't started moving yet --> stay here
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=0,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),

            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            ),  # we start to move forward --> should go to next cell now
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
                status=RailAgentStatus.ACTIVE
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [replay_config], activate_agents=False)

def test_initial_nextmalfunction_not_below_zero():
    """Regression test: after stepping, the agent's `next_malfunction`
    counter must never be negative."""
    random.seed(0)
    np.random.seed(0)

    malfunction_settings = {
        'prop_malfunction': 1.,  # Percentage of defective agents
        'malfunction_rate': 70,  # Rate of malfunction occurence
        'min_duration': 2,  # Minimal duration of malfunction
        'max_duration': 5,  # Max duration of malfunction
    }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=malfunction_settings,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # reset to initialize agents_static
    env.reset()
    agent = env.agents[0]
    env.step({})
    # was next_malfunction was -1 befor the bugfix https://gitlab.aicrowd.com/flatland/flatland/issues/186
    assert agent.malfunction_data['next_malfunction'] >= 0, \
        "next_malfunction should be >=0, found {}".format(agent.malfunction_data['next_malfunction'])