import numpy as np

from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import sparse_rail_generator, rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_simple_rail
from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay
from flatland.envs.step_utils.states import TrainState
from flatland.envs.step_utils.speed_counter import SpeedCounter


# Use the sparse_rail_generator to generate feasible network configurations with corresponding tasks
# Training on simple small tasks is the best way to get familiar with the environment
#


class RandomAgent:

    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size

    def act(self, state):
        """
        :param state: the observation of the agent
        :return: a randomly chosen action
        """
        return np.random.choice([1, 2, 3])

    def step(self, memories):
        """
        Step function to improve agent by adjusting policy given the observations

        :param memories: SARS (state, action, reward, next state) tuple used to improve the policy
        :return: None
        """
        return

    def save(self, filename):
        # Store the current policy
        return

    def load(self, filename):
        # Load a policy
        return


def test_multi_speed_init():
    env = RailEnv(width=50, height=50,
                  rail_generator=sparse_rail_generator(seed=2), line_generator=sparse_line_generator(),
                  random_seed=3,
                  number_of_agents=3)
    
    # Initialize the agent with the parameters corresponding to the environment and observation_builder
    agent = RandomAgent(218, 4)

    # Empty dictionary for all agent action
    action_dict = dict()

    # Set all the different speeds
    # Reset environment and get initial observations for all agents
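    # reset(False, False) keeps the already generated rail and line instead of regenerating them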
    env.reset(False, False)

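    # Place every agent on its initial position and force it into the MOVING state,
    # bypassing the READY_TO_DEPART phase of a regular episode.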
    for a_idx in range(len(env.agents)):
        env.agents[a_idx].position = env.agents[a_idx].initial_position
        env.agents[a_idx]._set_state(TrainState.MOVING)

    # Here you can also further enhance the provided observation by means of normalization
    # See training navigation example in the baseline repository
    old_pos = []
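    # Agent i gets speed 1/(i+1), i.e. it needs (i+1) steps to traverse a cell;
    # the modulo check at the end of the episode loop below relies on this.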
    for i_agent in range(env.get_num_agents()):
        env.agents[i_agent].speed_counter = SpeedCounter(speed=1. / (i_agent + 1))
        old_pos.append(env.agents[i_agent].position)
        print(env.agents[i_agent].position)
    # Run episode
    for step in range(100):

        # Choose an action for each agent in the environment
        for a in range(env.get_num_agents()):
            action = agent.act(0)
            action_dict.update({a: action})

            # Check that the agent did not move between its speed updates
            assert old_pos[a] == env.agents[a].position

        # Environment step which returns the observations for all agents, their corresponding
        # reward and whether they are done
        _, _, _, _ = env.step(action_dict)

        # Update old position whenever an agent was allowed to move
        for i_agent in range(env.get_num_agents()):
            if (step + 1) % (i_agent + 1) == 0:
                print(step, i_agent, env.agents[i_agent].position)
                old_pos[i_agent] = env.agents[i_agent].position


def test_multispeed_actions_no_malfunction_no_blocking():
    """Test that actions are correctly performed on cell exit for a single agent."""
    rail, rail_map, optionals = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(), number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
    env.reset()

    env._max_episode_steps = 1000

    set_penalties_for_replay(env)
    test_config = ReplayConfig(
        replay=[
            Replay(
                position=(3, 9),  # east dead-end
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.start_penalty + env.step_penalty * 0.5  # starting and running at speed 0.5
            ),
            Replay(
                position=(3, 9),
                direction=Grid4TransitionsEnum.EAST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_LEFT,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(4, 6),
                direction=Grid4TransitionsEnum.SOUTH,
                action=RailEnvActions.STOP_MOVING,
                reward=env.stop_penalty + env.step_penalty * 0.5  # stopping and step penalty
            ),
            # the agent remains stopped for one more step
            Replay(
                position=(4, 6),
                direction=Grid4TransitionsEnum.SOUTH,
                action=RailEnvActions.STOP_MOVING,
                reward=env.step_penalty * 0.5  # step penalty for speed 0.5 when stopped
            ),
            Replay(
                position=(4, 6),
                direction=Grid4TransitionsEnum.SOUTH,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.start_penalty + env.step_penalty * 0.5  # starting + running at speed 0.5
            ),
            Replay(
                position=(4, 6),
                direction=Grid4TransitionsEnum.SOUTH,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(5, 6),
                direction=Grid4TransitionsEnum.SOUTH,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
        ],
        target=(3, 0),  # west dead-end
        speed=0.5,
        initial_position=(3, 9),  # east dead-end
        initial_direction=Grid4TransitionsEnum.EAST,
    )

    run_replay_config(env, [test_config], skip_reward_check=True, skip_action_required_check=True)


def test_multispeed_actions_no_malfunction_blocking():
    """The second agent blocks the first because it is slower."""
    rail, rail_map, optionals = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(), number_of_agents=2,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
    env.reset()

    set_penalties_for_replay(env)
    test_configs = [
        ReplayConfig(
            replay=[
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.start_penalty + env.step_penalty * 1.0 / 3.0  # starting and running at speed 1/3
                ),
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),

                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),

                Replay(
                    position=(3, 6),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 6),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 6),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),

                Replay(
                    position=(3, 5),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 5),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                ),
                Replay(
                    position=(3, 5),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 1.0 / 3.0  # running at speed 1/3
                )
            ],
            target=(3, 0),  # west dead-end
            speed=1 / 3,
            initial_position=(3, 8),
            initial_direction=Grid4TransitionsEnum.WEST,
        ),
        ReplayConfig(
            replay=[
                Replay(
                    position=(3, 9),  # east dead-end
                    direction=Grid4TransitionsEnum.EAST,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.start_penalty + env.step_penalty * 0.5  # starting and running at speed 0.5
                ),
                Replay(
                    position=(3, 9),
                    direction=Grid4TransitionsEnum.EAST,
                    action=None,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                # blocked although fraction >= 1.0
                Replay(
                    position=(3, 9),
                    direction=Grid4TransitionsEnum.EAST,
                    action=None,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),

                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                # blocked although fraction >= 1.0
                Replay(
                    position=(3, 8),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),

                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                # blocked although fraction >= 1.0
                Replay(
                    position=(3, 7),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),

                Replay(
                    position=(3, 6),
                    direction=Grid4TransitionsEnum.WEST,
                    action=RailEnvActions.MOVE_LEFT,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                Replay(
                    position=(3, 6),
                    direction=Grid4TransitionsEnum.WEST,
                    action=None,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
                # not blocked, action required!
                Replay(
                    position=(4, 6),
                    direction=Grid4TransitionsEnum.SOUTH,
                    action=RailEnvActions.MOVE_FORWARD,
                    reward=env.step_penalty * 0.5  # running at speed 0.5
                ),
            ],
            target=(3, 0),  # west dead-end
            speed=0.5,
            initial_position=(3, 9),  # east dead-end
            initial_direction=Grid4TransitionsEnum.EAST,
        )

    ]
    run_replay_config(env, test_configs, skip_reward_check=True)


def test_multispeed_actions_malfunction_no_blocking():
    """Test on a single agent whether action on cell exit work correctly despite malfunction."""
    rail, rail_map, optionals = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(), number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
    env.reset()
    
    # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
    for _ in range(max([agent.earliest_departure for agent in env.agents])):
        env.step({}) # DO_NOTHING for all agents

    env._max_episode_steps = 10000
    
    set_penalties_for_replay(env)
    test_config = ReplayConfig(
        replay=[
            Replay( # 0
                position=(3, 9),  # east dead-end
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.start_penalty + env.step_penalty * 0.5  # starting and running at speed 0.5
            ),
            Replay( # 1
                position=(3, 9),
                direction=Grid4TransitionsEnum.EAST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay( # 2
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            # add additional step in the cell
            Replay( # 3
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                set_malfunction=2,  # recovers in two steps from now!
                malfunction=2,
                reward=env.step_penalty * 0.5  # step penalty for speed 0.5 when malfunctioning
            ),
            # agent recovers in this step
            Replay( # 4
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                malfunction=1,
                reward=env.step_penalty * 0.5  # recovered: running at speed 0.5
            ),
            Replay( # 5
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay( # 6
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay( # 7
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                set_malfunction=2,  # recovers in two steps from now!
                malfunction=2,
                reward=env.step_penalty * 0.5  # step penalty for speed 0.5 when malfunctioning
            ),
            # agent recovers in this step
            Replay( # 8
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                malfunction=1,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay( # 9
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay( # 10
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.STOP_MOVING,
                reward=env.stop_penalty + env.step_penalty * 0.5  # stopping and step penalty for speed 0.5
            ),
            Replay( # 11
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.STOP_MOVING,
                reward=env.step_penalty * 0.5  # step penalty for speed 0.5 while stopped
            ),
            Replay( # 12
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.start_penalty + env.step_penalty * 0.5  # starting and running at speed 0.5
            ),
            Replay( # 13
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            # DO_NOTHING keeps moving!
            Replay( # 14
                position=(3, 5),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.DO_NOTHING,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay( # 15
                position=(3, 5),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay( # 16
                position=(3, 4),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),

        ],
        target=(3, 0),  # west dead-end
        speed=0.5,
        initial_position=(3, 9),  # east dead-end
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [test_config], skip_reward_check=True)


# TODO the invalid-action penalty seems to be given only when moving forward is not possible - is this the intended behaviour?
def test_multispeed_actions_no_malfunction_invalid_actions():
    """Test that actions are correctly performed on cell exit for a single agent."""
    rail, rail_map, optionals = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(), number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
    env.reset()
    
    # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
    for _ in range(max([agent.earliest_departure for agent in env.agents])):
        env.step({}) # DO_NOTHING for all agents
    
    env._max_episode_steps = 10000

    set_penalties_for_replay(env)
    test_config = ReplayConfig(
        replay=[
            Replay(
                position=(3, 9),  # east dead-end
                direction=Grid4TransitionsEnum.EAST,
                action=RailEnvActions.MOVE_LEFT,
                reward=env.start_penalty + env.step_penalty * 0.5  # auto-correction left to forward without penalty!
            ),
            Replay(
                position=(3, 9),
                direction=Grid4TransitionsEnum.EAST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 8),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_FORWARD,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 7),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_RIGHT,
                reward=env.step_penalty * 0.5  # wrong action is corrected to forward without penalty!
            ),
            Replay(
                position=(3, 6),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),
            Replay(
                position=(3, 5),
                direction=Grid4TransitionsEnum.WEST,
                action=RailEnvActions.MOVE_RIGHT,
                reward=env.step_penalty * 0.5  # wrong action is corrected to forward without penalty!
            ),
            Replay(
                position=(3, 5),
                direction=Grid4TransitionsEnum.WEST,
                action=None,
                reward=env.step_penalty * 0.5  # running at speed 0.5
            ),

        ],
        target=(3, 0),  # west dead-end
        speed=0.5,
        initial_position=(3, 9),  # east dead-end
        initial_direction=Grid4TransitionsEnum.EAST,
    )

    run_replay_config(env, [test_config], skip_reward_check=True)