# test_flatland_malfunction.py
import random
from typing import Dict, List, Optional

import numpy as np
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay

from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.simple_rail import make_simple_rail2


class SingleAgentNavigationObs(ObservationBuilder):
    """
    We build a representation vector with 3 binary components, indicating which of the 3 available directions
    for each agent (Left, Forward, Right) lead to the shortest path to its target.
    E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
    will be [1, 0, 0].
    """

    def __init__(self):
        super().__init__()

    def reset(self):
        """Nothing to do: this builder defines no attributes of its own."""
        pass

    def get(self, handle: int = 0) -> Optional[List[int]]:
        """
        Build the observation for the agent with the given handle.

        Returns a one-hot list ordered as [left, forward, right], or None when the agent's status is
        none of READY_TO_DEPART, ACTIVE or DONE.
        """
        agent = self.env.agents[handle]

        # Choose the cell used as the agent's position, depending on its status.
        if agent.status == RailAgentStatus.READY_TO_DEPART:
            agent_virtual_position = agent.initial_position
        elif agent.status == RailAgentStatus.ACTIVE:
            agent_virtual_position = agent.position
        elif agent.status == RailAgentStatus.DONE:
            agent_virtual_position = agent.target
        else:
            return None

        possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
        num_transitions = np.count_nonzero(possible_transitions)

        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right], relative to the current orientation
        # If only one transition is possible, the forward branch is aligned with it.
        if num_transitions == 1:
            return [0, 1, 0]

        # Loop-invariant lookup: fetch the distance map once instead of once per branch.
        distance_map = self.env.distance_map.get()

        min_distances = []
        for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
            if possible_transitions[direction]:
                new_position = get_new_position(agent_virtual_position, direction)
                min_distances.append(distance_map[handle, new_position[0], new_position[1], direction])
            else:
                # Branch not available: it can never be the argmin unless all branches are unavailable.
                min_distances.append(np.inf)

        observation = [0, 0, 0]
        observation[np.argmin(min_distances)] = 1
        return observation


def test_malfunction_process():
    """
    Force a malfunction on the single agent every 5th step over 100 steps and check that the agent
    does not change position while malfunctioning and that the malfunction counter has the expected value.
    """
    # Set fixed malfunction duration for this test (min_duration == max_duration)
    stochastic_data = {
        'prop_malfunction': 1.,
        'malfunction_rate': 1000,
        'min_duration': 3,
        'max_duration': 3,
    }

    rail, _ = make_simple_rail2()
    env = RailEnv(
        width=25,
        height=30,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=1,
        stochastic_data=stochastic_data,  # Malfunction data generator
        obs_builder_object=SingleAgentNavigationObs(),
    )
    # reset to initialize agents_static
    obs, _ = env.reset(False, False, True, random_seed=0)
    print(env.agents[0].malfunction_data)

    # Check that an initial time until the next malfunction was assigned
    assert env.agents[0].malfunction_data['next_malfunction'] > 0

    for agent in env.agents:
        agent.status = RailAgentStatus.ACTIVE

    forced_halts = 0
    accumulated_malfunction = 0
    previous_position = env.agents[0].position

    # Move target to unreachable position in order to not interfere with test
    env.agents[0].target = (0, 0)

    for step in range(100):
        # Index of the largest observation component, shifted by one, is used as the action value
        actions = {i: np.argmax(obs[i]) + 1 for i in range(len(obs))}

        if step % 5 == 0:
            # Stop the agent and set it to be malfunctioning
            env.agents[0].malfunction_data['malfunction'] = -1
            env.agents[0].malfunction_data['next_malfunction'] = 0
            forced_halts += 1

        obs, _, _, _ = env.step(actions)

        if env.agents[0].malfunction_data['malfunction'] > 0:
            # Check that agent is not moving while malfunctioning
            assert previous_position == env.agents[0].position

        previous_position = env.agents[0].position
        accumulated_malfunction += env.agents[0].malfunction_data['malfunction']

    # Check that the appropriate number of malfunctions is achieved
    assert env.agents[0].malfunction_data['nr_malfunctions'] == 21, "Actual {}".format(
        env.agents[0].malfunction_data['nr_malfunctions'])

    # Check that 20 stops were performed
    assert forced_halts == 20

    # Check that the agent spent some time malfunctioning
    assert accumulated_malfunction > 0


def test_malfunction_process_statistically():
    """Tests that malfunctions are produced by stochastic_data!"""
    # Set fixed malfunction duration for this test
    stochastic_data = {'prop_malfunction': 1.,
                       'malfunction_rate': 2,
                       'min_duration': 3,
                       'max_duration': 3}

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # reset to initialize agents_static
    env.reset(True, True, False, random_seed=0)

    # Same trick as in the other tests of this file: move the target to (0, 0)
    env.agents[0].target = (0, 0)

    for step in range(20):
        # We randomly select an action for every agent.
        # NOTE(review): np.random is not seeded in this test, so the actions differ between runs;
        # the expected count below presumably does not depend on them -- confirm.
        action_dict: Dict[int, RailEnvActions] = {
            agent.handle: RailEnvActions(np.random.randint(4)) for agent in env.agents
        }
        env.step(action_dict)

    # check that generation of malfunctions works as expected
    assert env.agents[0].malfunction_data["nr_malfunctions"] == 4


def test_malfunction_before_entry():
    """
    Tests that no malfunction is reported during the first 10 steps, in which action value 0 is sent,
    and that the agent is malfunctioning after 10 further steps with action value 2.
    """
    # Set fixed malfunction duration for this test
    stochastic_data = {'prop_malfunction': 1.,
                       'malfunction_rate': 2,
                       'min_duration': 10,
                       'max_duration': 10}

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # reset to initialize agents_static
    env.reset(False, False, False, random_seed=0)
    env.agents[0].target = (0, 0)

    for step in range(20):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # Fixed action schedule (not random): value 0 for the first 10 steps, value 2 afterwards.
            # NOTE(review): presumably 0 is DO_NOTHING (agent stays off the grid) and 2 is MOVE_FORWARD -- confirm.
            if step < 10:
                action_dict[agent.handle] = RailEnvActions(0)
                # No malfunction may be active during these steps
                assert env.agents[0].malfunction_data['malfunction'] == 0
            else:
                action_dict[agent.handle] = RailEnvActions(2)

        print(env.agents[0].malfunction_data)
        env.step(action_dict)

    # After the second phase the agent must be malfunctioning
    assert env.agents[0].malfunction_data['malfunction'] > 0


def test_initial_malfunction():
    """
    Replay a scenario in which a malfunction of 3 is set on the first step and MOVE_FORWARD is sent on every step;
    the expected position, malfunction value and reward of each step are listed in the replay.
    """
    stochastic_data = {
        'prop_malfunction': 1.,  # Percentage of defective agents
        'malfunction_rate': 70,  # Rate of malfunction occurence
        'min_duration': 2,  # Minimal duration of malfunction
        'max_duration': 5,  # Max duration of malfunction
    }

    rail, _ = make_simple_rail2()
    env = RailEnv(
        width=25,
        height=30,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=1,
        stochastic_data=stochastic_data,  # Malfunction data generator
        obs_builder_object=SingleAgentNavigationObs(),
    )

    # reset to initialize agents_static
    env.reset(False, False, True, random_seed=0)

    set_penalties_for_replay(env)

    east = Grid4TransitionsEnum.EAST
    forward = RailEnvActions.MOVE_FORWARD
    replay_steps = [
        # full step penalty when malfunctioning
        Replay(position=(3, 2), direction=east, action=forward,
               set_malfunction=3, malfunction=3,
               reward=env.step_penalty),
        # full step penalty when malfunctioning
        Replay(position=(3, 2), direction=east, action=forward,
               malfunction=2,
               reward=env.step_penalty),
        # malfunction stops in the next step and we're still at the beginning of the cell
        # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
        # malfunctioning ends: starting and running at speed 1.0
        Replay(position=(3, 2), direction=east, action=forward,
               malfunction=1,
               reward=env.start_penalty + env.step_penalty * 1.0),
        # running at speed 1.0
        Replay(position=(3, 3), direction=east, action=forward,
               malfunction=0,
               reward=env.step_penalty * 1.0),
        # running at speed 1.0
        Replay(position=(3, 4), direction=east, action=forward,
               malfunction=0,
               reward=env.step_penalty * 1.0),
    ]

    replay_config = ReplayConfig(
        replay=replay_steps,
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=east,
    )
    run_replay_config(env, [replay_config])


def test_initial_malfunction_stop_moving():
    """
    Replay a scenario in which a malfunction of 3 is set while the agent is READY_TO_DEPART and STOP_MOVING is
    sent on the step where the malfunction value is 1; the expected position, malfunction value, reward and
    status of each step are listed in the replay.
    """
    stochastic_data = {
        'prop_malfunction': 1.,  # Percentage of defective agents
        'malfunction_rate': 70,  # Rate of malfunction occurence
        'min_duration': 2,  # Minimal duration of malfunction
        'max_duration': 5,  # Max duration of malfunction
    }

    rail, _ = make_simple_rail2()
    env = RailEnv(
        width=25,
        height=30,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=1,
        stochastic_data=stochastic_data,  # Malfunction data generator
        obs_builder_object=SingleAgentNavigationObs(),
    )
    # NOTE(review): unlike the sibling tests, env.reset() is not called here before env.agents is read;
    # presumably the agents already exist after construction -- confirm.

    print(env.agents[0].initial_position, env.agents[0].direction, env.agents[0].position, env.agents[0].status)

    set_penalties_for_replay(env)

    east = Grid4TransitionsEnum.EAST
    active = RailAgentStatus.ACTIVE
    replay_steps = [
        # malfunction is set while the agent is still ready to depart: full step penalty
        Replay(position=None, direction=east, action=RailEnvActions.MOVE_FORWARD,
               set_malfunction=3, malfunction=3,
               reward=env.step_penalty,
               status=RailAgentStatus.READY_TO_DEPART),
        # full step penalty when stopped
        Replay(position=(3, 2), direction=east, action=RailEnvActions.DO_NOTHING,
               malfunction=2,
               reward=env.step_penalty,
               status=active),
        # malfunction stops in the next step and we're still at the beginning of the cell
        # --> if we take action STOP_MOVING, agent should restart without moving
        Replay(position=(3, 2), direction=east, action=RailEnvActions.STOP_MOVING,
               malfunction=1,
               reward=env.step_penalty,  # full step penalty while stopped
               status=active),
        # we have stopped and do nothing --> should stand still
        Replay(position=(3, 2), direction=east, action=RailEnvActions.DO_NOTHING,
               malfunction=0,
               reward=env.step_penalty,  # full step penalty while stopped
               status=active),
        # we start to move forward --> should go to next cell now
        Replay(position=(3, 2), direction=east, action=RailEnvActions.MOVE_FORWARD,
               malfunction=0,
               reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
               status=active),
        Replay(position=(3, 3), direction=east, action=RailEnvActions.MOVE_FORWARD,
               malfunction=0,
               reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
               status=active),
    ]

    replay_config = ReplayConfig(
        replay=replay_steps,
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=east,
    )

    run_replay_config(env, [replay_config], activate_agents=False)


def test_initial_malfunction_do_nothing():
    """
    Replay a scenario in which a malfunction of 3 is set while the agent is READY_TO_DEPART and DO_NOTHING is
    sent until after the malfunction value reaches 0; the expected position, malfunction value, reward and
    status of each step are listed in the replay.
    """
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {
        'prop_malfunction': 1.,  # Percentage of defective agents
        'malfunction_rate': 70,  # Rate of malfunction occurence
        'min_duration': 2,  # Minimal duration of malfunction
        'max_duration': 5,  # Max duration of malfunction
    }

    rail, _ = make_simple_rail2()
    env = RailEnv(
        width=25,
        height=30,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=1,
        stochastic_data=stochastic_data,  # Malfunction data generator
    )
    # reset to initialize agents_static
    env.reset()
    set_penalties_for_replay(env)

    east = Grid4TransitionsEnum.EAST
    active = RailAgentStatus.ACTIVE
    replay_steps = [
        Replay(position=None, direction=east, action=RailEnvActions.MOVE_FORWARD,
               set_malfunction=3, malfunction=3,
               reward=env.step_penalty,  # full step penalty while malfunctioning
               status=RailAgentStatus.READY_TO_DEPART),
        Replay(position=(3, 2), direction=east, action=RailEnvActions.DO_NOTHING,
               malfunction=2,
               reward=env.step_penalty,  # full step penalty while malfunctioning
               status=active),
        # malfunction stops in the next step and we're still at the beginning of the cell
        # --> if we take action DO_NOTHING, agent should restart without moving
        Replay(position=(3, 2), direction=east, action=RailEnvActions.DO_NOTHING,
               malfunction=1,
               reward=env.step_penalty,  # full step penalty while stopped
               status=active),
        # we haven't started moving yet --> stay here
        Replay(position=(3, 2), direction=east, action=RailEnvActions.DO_NOTHING,
               malfunction=0,
               reward=env.step_penalty,  # full step penalty while stopped
               status=active),
        # we start to move forward --> should go to next cell now
        Replay(position=(3, 2), direction=east, action=RailEnvActions.MOVE_FORWARD,
               malfunction=0,
               reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
               status=active),
        Replay(position=(3, 3), direction=east, action=RailEnvActions.MOVE_FORWARD,
               malfunction=0,
               reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
               status=active),
    ]

    replay_config = ReplayConfig(
        replay=replay_steps,
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=east,
    )
    run_replay_config(env, [replay_config], activate_agents=False)


def test_initial_nextmalfunction_not_below_zero():
    """Check that 'next_malfunction' of the agent is not negative after the first environment step."""
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {
        'prop_malfunction': 1.,  # Percentage of defective agents
        'malfunction_rate': 70,  # Rate of malfunction occurence
        'min_duration': 2,  # Minimal duration of malfunction
        'max_duration': 5,  # Max duration of malfunction
    }

    rail, _ = make_simple_rail2()
    env = RailEnv(
        width=25,
        height=30,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=1,
        stochastic_data=stochastic_data,  # Malfunction data generator
        obs_builder_object=SingleAgentNavigationObs(),
    )
    # reset to initialize agents_static
    env.reset()

    # Keep a reference to the agent taken before stepping, then step once with no actions
    agent = env.agents[0]
    env.step({})

    # next_malfunction was -1 before the bugfix https://gitlab.aicrowd.com/flatland/flatland/issues/186
    next_malfunction_ok = agent.malfunction_data['next_malfunction'] >= 0
    assert next_malfunction_ok, \
        "next_malfunction should be >=0, found {}".format(agent.malfunction_data['next_malfunction'])