test_flatland_malfunction.py 17.8 KB
Newer Older
u214892's avatar
u214892 committed
1
import random
2
from typing import Dict, List
u214892's avatar
u214892 committed
3

4
import numpy as np
5
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
6

7
from flatland.core.env_observation_builder import ObservationBuilder
u214892's avatar
u214892 committed
8
from flatland.core.grid.grid4 import Grid4TransitionsEnum
9
from flatland.core.grid.grid4_utils import get_new_position
u214892's avatar
u214892 committed
10
from flatland.envs.agent_utils import RailAgentStatus
u214892's avatar
u214892 committed
11
from flatland.envs.rail_env import RailEnv, RailEnvActions
12
13
14
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.simple_rail import make_simple_rail2
15
16


17
class SingleAgentNavigationObs(ObservationBuilder):
    """
    Observation builder producing a 3-component binary vector per agent,
    flagging which of the relative branches [Left, Forward, Right] lies on
    the shortest path to the agent's target.
    E.g., if taking the Left branch (if available) is the shortest route to
    the agent's target, the observation vector will be [1, 0, 0].
    """

    def __init__(self):
        super().__init__()

    def reset(self):
        # Stateless builder: nothing to re-initialize between episodes.
        pass

    def get(self, handle: int = 0) -> List[int]:
        agent = self.env.agents[handle]

        # Resolve the cell the agent effectively occupies for planning,
        # depending on its lifecycle status.
        status = agent.status
        if status == RailAgentStatus.READY_TO_DEPART:
            virtual_pos = agent.initial_position
        elif status == RailAgentStatus.ACTIVE:
            virtual_pos = agent.position
        elif status == RailAgentStatus.DONE:
            virtual_pos = agent.target
        else:
            # Agent removed from the grid: no observation.
            return None

        transitions = self.env.rail.get_transitions(*virtual_pos, agent.direction)

        # Relative to the current orientation, branches are [left, forward, right].
        # With a single possible transition, it is treated as the forward branch.
        if np.count_nonzero(transitions) == 1:
            return [0, 1, 0]

        distances = []
        for offset in (-1, 0, 1):
            heading = (agent.direction + offset) % 4
            if transitions[heading]:
                nxt = get_new_position(virtual_pos, heading)
                distances.append(
                    self.env.distance_map.get()[handle, nxt[0], nxt[1], heading])
            else:
                # Unavailable branch: never the argmin.
                distances.append(np.inf)

        obs = [0, 0, 0]
        obs[np.argmin(distances)] = 1
        return obs


def test_malfunction_process():
    """Check that a forced malfunction halts the agent and is counted correctly."""
    random.seed(0)
    np.random.seed(0)

    # NOTE(review): an earlier stochastic_data dict (malfunction_rate=1000,
    # min/max duration 3) was assigned and immediately shadowed by the dict
    # below before any use; it has been removed as dead code.
    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # reset to initialize agents_static
    env.reset()

    obs = env.reset(False, False, True)

    # Check that an initial duration for malfunction was assigned
    assert env.agents[0].malfunction_data['next_malfunction'] > 0
    for agent in env.agents:
        agent.status = RailAgentStatus.ACTIVE

    agent_halts = 0
    total_down_time = 0
    agent_old_position = env.agents[0].position
    for step in range(100):
        actions = {}

        # Observation index + 1 maps the [L, F, R] one-hot onto the
        # corresponding movement action.
        for i in range(len(obs)):
            actions[i] = np.argmax(obs[i]) + 1

        if step % 5 == 0:
            # Stop the agent and set it to be malfunctioning
            env.agents[0].malfunction_data['malfunction'] = -1
            env.agents[0].malfunction_data['next_malfunction'] = 0
            agent_halts += 1

        obs, all_rewards, done, _ = env.step(actions)

        if env.agents[0].malfunction_data['malfunction'] > 0:
            agent_malfunctioning = True
        else:
            agent_malfunctioning = False

        if agent_malfunctioning:
            # Check that agent is not moving while malfunctioning
            assert agent_old_position == env.agents[0].position

        agent_old_position = env.agents[0].position
        total_down_time += env.agents[0].malfunction_data['malfunction']

    # Check that the appropriate number of malfunctions is achieved
    assert env.agents[0].malfunction_data['nr_malfunctions'] == 21, "Actual {}".format(
        env.agents[0].malfunction_data['nr_malfunctions'])

    # Check that 20 stops where performed
    assert agent_halts == 20

    # Check that malfunctioning data was standing around
    assert total_down_time > 0


def test_malfunction_process_statistically():
    """Tests that malfunctions are produced by stochastic_data!"""
    # Set fixed malfunction duration for this test
    stochastic_data = {'prop_malfunction': 1.,
                       'malfunction_rate': 2,
                       'min_duration': 3,
                       'max_duration': 3}

    random.seed(0)
    np.random.seed(0)

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # reset to initialize agents_static
    env.reset(False, False, False)
    env.agents[0].target = (0, 0)
    # NOTE(review): an unused counter `nb_malfunction` was removed here.
    for step in range(20):
        action_dict: Dict[int, RailEnvActions] = {}
        for agent in env.agents:
            # We randomly select an action
            action_dict[agent.handle] = RailEnvActions(np.random.randint(4))

        env.step(action_dict)

    # check that generation of malfunctions works as expected
    assert env.agents[0].malfunction_data["nr_malfunctions"] == 4
180
def test_initial_malfunction():
    """An agent malfunctioning from the very first step must stand still until repaired."""
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {
        'prop_malfunction': 1.,  # Percentage of defective agents
        'malfunction_rate': 70,  # Rate of malfunction occurence
        'min_duration': 2,  # Minimal duration of malfunction
        'max_duration': 5  # Max duration of malfunction
    }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(
        width=25,
        height=30,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=1,
        stochastic_data=stochastic_data,  # Malfunction data generator
        obs_builder_object=SingleAgentNavigationObs(),
    )

    # reset to initialize agents_static
    env.reset(False, False, True)

    set_penalties_for_replay(env)

    # Expected step-by-step trace: three malfunctioning steps in place,
    # then normal forward movement at speed 1.0.
    expected_steps = [
        Replay(
            position=(3, 2),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            set_malfunction=3,
            malfunction=3,
            reward=env.step_penalty  # full step penalty when malfunctioning
        ),
        Replay(
            position=(3, 2),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=2,
            reward=env.step_penalty  # full step penalty when malfunctioning
        ),
        # malfunction stops in the next step and we're still at the beginning of the cell
        # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
        Replay(
            position=(3, 2),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=1,
            reward=env.start_penalty + env.step_penalty * 1.0
            # malfunctioning ends: starting and running at speed 1.0
        ),
        Replay(
            position=(3, 3),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=0,
            reward=env.step_penalty * 1.0  # running at speed 1.0
        ),
        Replay(
            position=(3, 4),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=0,
            reward=env.step_penalty * 1.0  # running at speed 1.0
        )
    ]

    replay_config = ReplayConfig(
        replay=expected_steps,
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [replay_config])


def test_initial_malfunction_stop_moving():
    """After an initial malfunction, STOP_MOVING/DO_NOTHING must keep the agent in place."""
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  obs_builder_object=SingleAgentNavigationObs()
                  )
    # NOTE(review): a leftover debug print of the agent's state was removed here.
    # No env.reset() is called in this test; run_replay_config(...,
    # activate_agents=False) is expected to set up the episode -- confirm.

    set_penalties_for_replay(env)
    replay_config = ReplayConfig(
        replay=[
            Replay(
                position=None,
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                set_malfunction=3,
                malfunction=3,
                reward=env.step_penalty,  # full step penalty when stopped
                status=RailAgentStatus.READY_TO_DEPART
            ),
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=2,
                reward=env.step_penalty,  # full step penalty when stopped
                status=RailAgentStatus.ACTIVE
            ),
            # malfunction stops in the next step and we're still at the beginning of the cell
            # --> if we take action STOP_MOVING, agent should restart without moving
            #
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.STOP_MOVING,
                malfunction=1,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we have stopped and do nothing --> should stand still
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.DO_NOTHING,
                malfunction=0,
                reward=env.step_penalty,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            # we start to move forward --> should go to next cell now
            Replay(
                position=(3, 2),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.start_penalty + env.step_penalty * 1.0,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            ),
            Replay(
                position=(3, 3),
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                malfunction=0,
                reward=env.step_penalty * 1.0,  # full step penalty while stopped
                status=RailAgentStatus.ACTIVE
            )
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )

    run_replay_config(env, [replay_config], activate_agents=False)


347
def test_initial_malfunction_do_nothing():
    """After an initial malfunction, DO_NOTHING must keep the un-started agent in place."""
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {
        'prop_malfunction': 1.,  # Percentage of defective agents
        'malfunction_rate': 70,  # Rate of malfunction occurence
        'min_duration': 2,  # Minimal duration of malfunction
        'max_duration': 5  # Max duration of malfunction
    }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(
        width=25,
        height=30,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=1,
        stochastic_data=stochastic_data,  # Malfunction data generator
    )
    # reset to initialize agents_static
    env.reset()
    set_penalties_for_replay(env)

    # Step-by-step expectation: three malfunctioning steps, two idle steps,
    # then movement once MOVE_FORWARD is issued.
    expected_steps = [
        Replay(
            position=None,
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            set_malfunction=3,
            malfunction=3,
            reward=env.step_penalty,  # full step penalty while malfunctioning
            status=RailAgentStatus.READY_TO_DEPART
        ),
        Replay(
            position=(3, 2),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.DO_NOTHING,
            malfunction=2,
            reward=env.step_penalty,  # full step penalty while malfunctioning
            status=RailAgentStatus.ACTIVE
        ),
        # malfunction stops in the next step and we're still at the beginning of the cell
        # --> if we take action DO_NOTHING, agent should restart without moving
        #
        Replay(
            position=(3, 2),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.DO_NOTHING,
            malfunction=1,
            reward=env.step_penalty,  # full step penalty while stopped
            status=RailAgentStatus.ACTIVE
        ),
        # we haven't started moving yet --> stay here
        Replay(
            position=(3, 2),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.DO_NOTHING,
            malfunction=0,
            reward=env.step_penalty,  # full step penalty while stopped
            status=RailAgentStatus.ACTIVE
        ),
        Replay(
            position=(3, 2),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=0,
            reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
            status=RailAgentStatus.ACTIVE
        ),  # we start to move forward --> should go to next cell now
        Replay(
            position=(3, 3),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=0,
            reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
            status=RailAgentStatus.ACTIVE
        )
    ]

    replay_config = ReplayConfig(
        replay=expected_steps,
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [replay_config], activate_agents=False)


def test_initial_nextmalfunction_not_below_zero():
    """Regression test: next_malfunction must never drop below zero.

    Before the fix for https://gitlab.aicrowd.com/flatland/flatland/issues/186
    it could become -1 after the first step.
    """
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {
        'prop_malfunction': 1.,  # Percentage of defective agents
        'malfunction_rate': 70,  # Rate of malfunction occurence
        'min_duration': 2,  # Minimal duration of malfunction
        'max_duration': 5  # Max duration of malfunction
    }

    rail, rail_map = make_simple_rail2()

    env = RailEnv(
        width=25,
        height=30,
        rail_generator=rail_from_grid_transition_map(rail),
        schedule_generator=random_schedule_generator(),
        number_of_agents=1,
        stochastic_data=stochastic_data,  # Malfunction data generator
        obs_builder_object=SingleAgentNavigationObs(),
    )
    # reset to initialize agents_static
    env.reset()

    agent = env.agents[0]
    env.step({})
    assert agent.malfunction_data['next_malfunction'] >= 0, \
        "next_malfunction should be >=0, found {}".format(agent.malfunction_data['next_malfunction'])