import numpy as np

from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import sparse_rail_generator, rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_simple_rail
from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay
from flatland.envs.step_utils.states import TrainState
from flatland.envs.step_utils.speed_counter import SpeedCounter


# Use the sparse_rail_generator to generate feasible network configurations with corresponding tasks
# Training on simple small tasks is the best way to get familiar with the environment
#


class RandomAgent:

    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.np_random = np.random.RandomState(seed=42)

    def act(self, state):
        """
        :param state: input is the observation of the agent
        :return: returns an action
        """
        return self.np_random.choice([1, 2, 3])

    def step(self, memories):
        """
        Step function to improve agent by adjusting policy given the observations

        :param memories: SARS tuple to learn from
        :return:
        """
        return

    def save(self, filename):
        # Store the current policy
        return

    def load(self, filename):
        # Load a policy
        return


def test_multi_speed_init():
    env = RailEnv(width=50, height=50,
                  rail_generator=sparse_rail_generator(seed=2), line_generator=sparse_line_generator(),
                  random_seed=3,
                  number_of_agents=3)

    # Initialize the agent with the parameters corresponding to the environment and observation_builder
    agent = RandomAgent(218, 4)
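    # (RandomAgent ignores both arguments; it always samples a random move)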

    # Empty dictionary for all agent actions
    action_dict = dict()

    # Reset environment and get initial observations for all agents
    env.reset(False, False)
    env._max_episode_steps = 1000

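    # Place all agents on their initial positions and force them into MOVING,
    # bypassing the usual READY_TO_DEPART dispatching phase.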
    for a_idx in range(len(env.agents)):
        env.agents[a_idx].position = env.agents[a_idx].initial_position
        env.agents[a_idx]._set_state(TrainState.MOVING)

    # Here you can also further enhance the provided observation by means of normalization
    # See training navigation example in the baseline repository
    # Set all the different speeds: agent i_agent runs at speed 1 / (i_agent + 1),
    # i.e. it completes one cell every i_agent + 1 environment steps
    old_pos = []
    for i_agent in range(env.get_num_agents()):
        env.agents[i_agent].speed_counter = SpeedCounter(speed=1. / (i_agent + 1))
        old_pos.append(env.agents[i_agent].position)
        print(env.agents[i_agent].position)
    # Run episode
    for step in range(100):

        # Choose an action for each agent in the environment
        for a in range(env.get_num_agents()):
            action = agent.act(0)
            action_dict.update({a: action})

            # Check that the agent did not move in between its speed updates
            assert old_pos[a] == env.agents[a].position

        # Environment step which returns the observations for all agents, their corresponding
        # reward and whether they are done
        _, _, _, _ = env.step(action_dict)

        # Update old position whenever an agent was allowed to move
        for i_agent in range(env.get_num_agents()):
            if (step + 1) % (i_agent + 1) == 0:
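                # an agent at speed 1 / (i_agent + 1) completes a cell
                # transition exactly every i_agent + 1 steps, so only then
                # may its position have changed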
                print(step, i_agent, env.agents[i_agent].position)
                old_pos[i_agent] = env.agents[i_agent].position


def test_multispeed_actions_no_malfunction_no_blocking():
    """Test that actions are correctly performed on cell exit for a single agent."""
    rail, rail_map, optionals = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(), number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
    env.reset()

    env._max_episode_steps = 1000

    set_penalties_for_replay(env)
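
    # Each Replay entry below pins the agent's expected position and direction
    # before the step, the action issued at that step, and the expected reward;
    # run_replay_config replays the sequence and checks these expectations
    # (reward checks are skipped here via skip_reward_check).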
    test_config = ReplayConfig(
        replay=[
            Replay(
                position=(3, 9),  # east dead-end
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.start_penalty + env.step_penalty * 0.5  # starting and running at speed 0.5
            ),
            Replay(
                position=(3, 9),
                direction=Grid4TransitionsEnum.EAST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_LEFT,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(4, 6),
                direction=Grid4TransitionsEnum.SOUTH,
                action=RailEnvActions.STOP_MOVING,
                reward=env.stop_penalty + env.step_penalty * 0.5  # stopping and step penalty
            ),
            #
            Replay(
                position=(4, 6),
                direction=Grid4TransitionsEnum.SOUTH,
                action=RailEnvActions.STOP_MOVING,
                reward=env.step_penalty * 0.5  # step penalty for speed 0.5 when stopped
            ),
            Replay(
                position=(4, 6),
                direction=Grid4TransitionsEnum.SOUTH,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.start_penalty + env.step_penalty * 0.5  # starting + running at speed 0.5
            ),
            Replay(
                position=(4, 6),
                direction=Grid4TransitionsEnum.SOUTH,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(5, 6),
                direction=Grid4TransitionsEnum.SOUTH,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
        ],
        target=(3, 0),  # west dead-end
        speed=0.5,
        initial_position=(3, 9),  # east dead-end
        initial_direction=Grid4TransitionsEnum.EAST,
    )

    run_replay_config(env, [test_config], skip_reward_check=True, skip_action_required_check=True)


def test_multispeed_actions_no_malfunction_blocking():
    """The second agent blocks the first because it is slower."""
    rail, rail_map, optionals = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(), number_of_agents=2,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                  random_seed=1)
    env.reset()

    set_penalties_for_replay(env)
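
    # The slow agent (speed 1/3) leads and the fast agent (speed 1/2) follows;
    # since a cell can hold only one agent, the follower is blocked even when
    # its speed fraction has reached 1.0.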
    test_configs = [
        ReplayConfig(
            replay=[
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.start_penalty + env.step_penalty * 1.0 / 3.0  # starting and running at speed 1/3
                ),
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),

                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),

                Replay(
                    position=(3, 6),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 6),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 6),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),

                Replay(
                    position=(3, 5),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 5),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 5),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                )
            ],
            target=(3, 0),  # west dead-end
            speed=1 / 3,
            initial_position=(3, 8),
            initial_direction=Grid4TransitionsEnum.WEST,
        ),
        ReplayConfig(
            replay=[
                Replay(
                    position=(3, 9),  # east dead-end
                    direction=Grid4TransitionsEnum.EAST,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.start_penalty + env.step_penalty * 0.5  # starting and running at speed 0.5
                ),
                Replay(
                    position=(3, 9),
                    direction=Grid4TransitionsEnum.EAST,
                    action=None,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                # blocked although fraction >= 1.0
                Replay(
                    position=(3, 9),
                    direction=Grid4TransitionsEnum.EAST,
                    action=None,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),

                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                # blocked although fraction >= 1.0
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),

                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                # blocked although fraction >= 1.0
                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),

                Replay(
                    position=(3, 6),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_LEFT,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                Replay(
                    position=(3, 6),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                # not blocked, action required!
                Replay(
                    position=(4, 6),
                    direction=Grid4TransitionsEnum.SOUTH,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
            ],
            target=(3, 0),  # west dead-end
            speed=0.5,
            initial_position=(3, 9),  # east dead-end
            initial_direction=Grid4TransitionsEnum.EAST,
        )
    ]
    run_replay_config(env, test_configs, skip_reward_check=True)


def test_multispeed_actions_malfunction_no_blocking():
    """Test on a single agent whether action on cell exit work correctly despite malfunction."""
    rail, rail_map, optionals = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(), number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
    env.reset()

    # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
    for _ in range(max([agent.earliest_departure for agent in env.agents])):
        env.step({}) # DO_NOTHING for all agents

    env._max_episode_steps = 10000

    set_penalties_for_replay(env)
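
    # In the Replay entries below, set_malfunction triggers a new malfunction
    # lasting that many steps, while malfunction is the expected remaining
    # malfunction duration when the entry is replayed.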
    test_config = ReplayConfig(
        replay=[
            Replay( # 0
                position=(3, 9),  # east dead-end
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.start_penalty + env.step_penalty * 0.5  # starting and running at speed 0.5
            ),
            Replay( # 1
                position=(3, 9),
                direction=Grid4TransitionsEnum.EAST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay( # 2
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            # add additional step in the cell
            Replay( # 3
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                set_malfunction=2,  # recovers in two steps from now!
                malfunction=2,
                reward=env.step_penalty * 0.5  # step penalty for speed 0.5 when malfunctioning
            ),
            # agent recovers in this step
            Replay( # 4
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                malfunction=1,
                reward=env.step_penalty * 0.5  # recovered: running at speed 0.5
            ),
            Replay( # 5
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay( # 6
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay( # 7
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                set_malfunction=2,  # recovers in two steps from now!
                malfunction=2,
                reward=env.step_penalty * 0.5  # step penalty for speed 0.5 when malfunctioning
            ),
            # agent recovers in this step
            Replay( # 8
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                malfunction=1,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay( # 9
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay( # 10
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.STOP_MOVING,
                reward=env.stop_penalty + env.step_penalty * 0.5  # stopping and step penalty for speed 0.5
            ),
            Replay( # 11
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.STOP_MOVING,
                reward=env.step_penalty * 0.5  # step penalty for speed 0.5 while stopped
            ),
            Replay( # 12
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.start_penalty + env.step_penalty * 0.5  # starting and running at speed 0.5
            ),
            Replay( # 13
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            # DO_NOTHING keeps moving!
            Replay( # 14
                position=(3, 5),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.DO_NOTHING,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay( # 15
                position=(3, 5),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay( # 16
                position=(3, 4),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),

        ],
        target=(3, 0),  # west dead-end
        speed=0.5,
        initial_position=(3, 9),  # east dead-end
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [test_config], skip_reward_check=True)


# TODO invalid action penalty seems only given when forward is not possible - is this the intended behaviour?
def test_multispeed_actions_no_malfunction_invalid_actions():
    """Test that actions are correctly performed on cell exit for a single agent."""
    rail, rail_map, optionals = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(), number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
    env.reset()

    # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
    for _ in range(max([agent.earliest_departure for agent in env.agents])):
        env.step({}) # DO_NOTHING for all agents

    env._max_episode_steps = 10000

    set_penalties_for_replay(env)
    test_config = ReplayConfig(
        replay=[
            Replay(
                position=(3, 9),  # east dead-end
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_LEFT,
                reward=env.start_penalty + env.step_penalty * 0.5  # invalid MOVE_LEFT is auto-corrected to forward without penalty
            ),
            Replay(
                position=(3, 9),
                direction=Grid4TransitionsEnum.EAST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_RIGHT,
                reward=env.step_penalty * 0.5  # wrong action is corrected to forward without penalty!
            ),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 5),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_RIGHT,
                reward=env.step_penalty * 0.5  # wrong action is corrected to forward without penalty!
            ),
            Replay(
                position=(3, 5),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),

        ],
        target=(3, 0),  # west dead-end
        speed=0.5,
        initial_position=(3, 9),  # east dead-end
        initial_direction=Grid4TransitionsEnum.EAST,
    )

    run_replay_config(env, [test_config], skip_reward_check=True)
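

# Convenience runner (an addition, not required by pytest, which discovers the
# test_* functions above): allows executing this file's scenarios directly.
if __name__ == "__main__":
    test_multi_speed_init()
    test_multispeed_actions_no_malfunction_no_blocking()
    test_multispeed_actions_no_malfunction_blocking()
    test_multispeed_actions_malfunction_no_blocking()
    test_multispeed_actions_no_malfunction_invalid_actions()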