test_flatland_malfunction.py 18.2 KB
Newer Older
u214892's avatar
u214892 committed
1
import random
2
from typing import Dict, List
u214892's avatar
u214892 committed
3

4
import numpy as np
5
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
6

7
from flatland.core.env_observation_builder import ObservationBuilder
u214892's avatar
u214892 committed
8
from flatland.core.grid.grid4 import Grid4TransitionsEnum
9
from flatland.core.grid.grid4_utils import get_new_position
u214892's avatar
u214892 committed
10
from flatland.envs.agent_utils import RailAgentStatus
u214892's avatar
u214892 committed
11
from flatland.envs.rail_env import RailEnv, RailEnvActions
12
13
14
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.simple_rail import make_simple_rail2
15
16


17
class SingleAgentNavigationObs(ObservationBuilder):
    """
    We build a representation vector with 3 binary components, indicating which of the 3 available directions
    for each agent (Left, Forward, Right) lead to the shortest path to its target.
    E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
    will be [1, 0, 0].
    """

    def __init__(self):
        super().__init__()

    def reset(self):
        # No per-episode state to clear for this builder.
        pass

    def get(self, handle: int = 0) -> List[int]:
        agent = self.env.agents[handle]

        # Pick the cell the agent effectively occupies, depending on its status.
        if agent.status == RailAgentStatus.READY_TO_DEPART:
            virtual_position = agent.initial_position
        elif agent.status == RailAgentStatus.ACTIVE:
            virtual_position = agent.position
        elif agent.status == RailAgentStatus.DONE:
            virtual_position = agent.target
        else:
            return None

        transitions = self.env.rail.get_transitions(*virtual_position, agent.direction)
        nb_transitions = np.count_nonzero(transitions)

        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right], relative to the current orientation.
        # If only one transition is possible, the forward branch is aligned with it.
        if nb_transitions == 1:
            return [0, 1, 0]

        branch_distances = []
        for branch_direction in [(agent.direction + offset) % 4 for offset in (-1, 0, 1)]:
            if transitions[branch_direction]:
                neighbour = get_new_position(virtual_position, branch_direction)
                branch_distances.append(
                    self.env.distance_map.get()[handle, neighbour[0], neighbour[1], branch_direction])
            else:
                branch_distances.append(np.inf)

        observation = [0, 0, 0]
        observation[np.argmin(branch_distances)] = 1
        return observation


def test_malfunction_process():
    """Check that a malfunctioning agent stands still and that its down time accumulates.

    The agent is forced into a malfunction every 5 steps; while malfunctioning
    its position must not change, and the env must count the malfunctions.
    """
    random.seed(0)
    np.random.seed(0)

    # NOTE: the original test first built a throwaway stochastic_data dict that
    # was immediately overwritten before use; the dead assignment is removed.
    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # reset to initialize agents_static
    env.reset()

    obs = env.reset(False, False, True)

    # Check that a initial duration for malfunction was assigned
    assert env.agents[0].malfunction_data['next_malfunction'] > 0
    for agent in env.agents:
        agent.status = RailAgentStatus.ACTIVE

    agent_halts = 0
    total_down_time = 0
    agent_old_position = env.agents[0].position
    for step in range(100):
        actions = {}

        # Observation is [left, forward, right]; argmax + 1 maps onto the
        # corresponding RailEnvActions movement action.
        for i in range(len(obs)):
            actions[i] = np.argmax(obs[i]) + 1

        if step % 5 == 0:
            # Stop the agent and set it to be malfunctioning
            env.agents[0].malfunction_data['malfunction'] = -1
            env.agents[0].malfunction_data['next_malfunction'] = 0
            agent_halts += 1

        obs, all_rewards, done, _ = env.step(actions)

        agent_malfunctioning = env.agents[0].malfunction_data['malfunction'] > 0

        if agent_malfunctioning:
            # Check that agent is not moving while malfunctioning
            assert agent_old_position == env.agents[0].position

        agent_old_position = env.agents[0].position
        total_down_time += env.agents[0].malfunction_data['malfunction']

    # Check that the appropriate number of malfunctions is achieved
    assert env.agents[0].malfunction_data['nr_malfunctions'] == 11, "Actual {}".format(
        env.agents[0].malfunction_data['nr_malfunctions'])

    # Check that 20 stops where performed
    assert agent_halts == 20

    # Check that malfunctioning data was standing around
    assert total_down_time > 0
u214892's avatar
u214892 committed
140
141
142
143
144
145
146
147
148
149


def test_malfunction_process_statistically():
    """Tests that malfunctions are produced by stochastic_data!

    With fixed seeds, the number of observed malfunction steps over 100 env
    steps must match the expected count exactly.
    """
    random.seed(0)
    np.random.seed(0)

    # NOTE: the original test first built a throwaway stochastic_data dict
    # (malfunction_rate=2) that was immediately overwritten before use; the
    # dead assignment is removed.
    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # reset to initialize agents_static
    env.reset(False, False, True)

    nb_malfunction = 0
    for step in range(100):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            if agent.malfunction_data['malfunction'] > 0:
                nb_malfunction += 1
            # We randomly select an action
            action_dict[agent.handle] = RailEnvActions(np.random.randint(4))

        env.step(action_dict)

    # check that generation of malfunctions works as expected
    assert nb_malfunction == 3, "nb_malfunction={}".format(nb_malfunction)
u214892's avatar
u214892 committed
184
185


186
def test_initial_malfunction():
    """Replay-check an agent that starts malfunctioning and then recovers while active."""
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # reset to initialize agents_static
    env.reset(False, False, True)

    set_penalties_for_replay(env)

    # Expected step-by-step trace: three malfunctioning steps in place,
    # then the agent restarts and advances one cell per step.
    expected_steps = [
        Replay(
            position=(3, 2),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            set_malfunction=3,
            malfunction=3,
            reward=env.step_penalty  # full step penalty when malfunctioning
        ),
        Replay(
            position=(3, 2),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=2,
            reward=env.step_penalty  # full step penalty when malfunctioning
        ),
        # malfunction stops in the next step and we're still at the beginning of the cell
        # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
        Replay(
            position=(3, 2),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=1,
            reward=env.start_penalty + env.step_penalty * 1.0
            # malfunctioning ends: starting and running at speed 1.0
        ),
        Replay(
            position=(3, 3),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=0,
            reward=env.step_penalty * 1.0  # running at speed 1.0
        ),
        Replay(
            position=(3, 4),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=0,
            reward=env.step_penalty * 1.0  # running at speed 1.0
        )
    ]

    replay_config = ReplayConfig(
        replay=expected_steps,
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [replay_config])
260
261
262


def test_initial_malfunction_stop_moving():
    """Replay-check an agent that starts malfunctioning before departure and is
    then stopped/restarted via STOP_MOVING / DO_NOTHING / MOVE_FORWARD.

    Fix: removed a leftover debug `print` of the agent's initial state.
    """
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # reset to initialize agents_static
    env.reset(False, False, True)

    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=None,
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty,  # full step penalty when stopped
                status=RailAgentStatus.READY_TO_DEPART
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=2,
                reward=env.step_penalty,  # full step penalty when stopped
                status=RailAgentStatus.ACTIVE
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action STOP_MOVING, agent should restart without moving
            #
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.STOP_MOVING,
                malfunction=1,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we have stopped and do nothing --> should stand still
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=0,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we start to move forward --> should go to next cell now
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )

    run_replay_config(env, [replay_config], activate_agents=False)
351
352


353
def test_initial_malfunction_do_nothing():
    """Replay-check an agent that starts malfunctioning and only receives
    DO_NOTHING until it finally moves forward."""
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  )
    # reset to initialize agents_static
    env.reset()
    set_penalties_for_replay(env)

    # Expected trace: three malfunctioning steps, two idle steps, then motion.
    expected_steps = [
        Replay(
            position=None,
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            set_malfunction=3,
            malfunction=3,
            reward=env.step_penalty,  # full step penalty while malfunctioning
            status=RailAgentStatus.READY_TO_DEPART
        ),
        Replay(
            position=(3, 2),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.DO_NOTHING,
            malfunction=2,
            reward=env.step_penalty,  # full step penalty while malfunctioning
            status=RailAgentStatus.ACTIVE
        ),
        # malfunction stops in the next step and we're still at the beginning of the cell
        # --> if we take action DO_NOTHING, agent should restart without moving
        #
        Replay(
            position=(3, 2),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.DO_NOTHING,
            malfunction=1,
            reward=env.step_penalty,  # full step penalty while stopped
            status=RailAgentStatus.ACTIVE
        ),
        # we haven't started moving yet --> stay here
        Replay(
            position=(3, 2),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.DO_NOTHING,
            malfunction=0,
            reward=env.step_penalty,  # full step penalty while stopped
            status=RailAgentStatus.ACTIVE
        ),
        Replay(
            position=(3, 2),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=0,
            reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
            status=RailAgentStatus.ACTIVE
        ),  # we start to move forward --> should go to next cell now
        Replay(
            position=(3, 3),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=0,
            reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
            status=RailAgentStatus.ACTIVE
        )
    ]

    replay_config = ReplayConfig(
        replay=expected_steps,
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [replay_config], activate_agents=False)
439
440
441
442
443
444
445


def test_initial_nextmalfunction_not_below_zero():
    """Regression test: next_malfunction must never go below zero after a step."""
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # reset to initialize agents_static
    env.reset()

    agent = env.agents[0]
    env.step({})
    # next_malfunction was -1 before the bugfix https://gitlab.aicrowd.com/flatland/flatland/issues/186
    assert agent.malfunction_data['next_malfunction'] >= 0, \
        "next_malfunction should be >=0, found {}".format(agent.malfunction_data['next_malfunction'])