test_flatland_envs_rail_env.py 14.7 KB
Newer Older
spiglerg's avatar
spiglerg committed
1
2
#!/usr/bin/env python
# -*- coding: utf-8 -*-
spiglerg's avatar
spiglerg committed
3
import numpy as np
4
import os
spiglerg's avatar
spiglerg committed
5

u214892's avatar
u214892 committed
6
from flatland.core.grid.rail_env_grid import RailEnvTransitions
7
from flatland.core.transition_map import GridTransitionMap
8
from flatland.envs.agent_utils import EnvAgent
u214892's avatar
u214892 committed
9
10
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
11
from flatland.envs.rail_env import RailEnv, RailEnvActions
Dipam Chakraborty's avatar
Dipam Chakraborty committed
12
from flatland.envs.rail_generators import sparse_rail_generator, rail_from_file
u214892's avatar
u214892 committed
13
from flatland.envs.rail_generators import rail_from_grid_transition_map
14
from flatland.envs.line_generators import sparse_line_generator, line_from_file
u214892's avatar
u214892 committed
15
from flatland.utils.simple_rail import make_simple_rail
16
from flatland.envs.persistence import RailEnvPersister
17
18
19
from flatland.utils.rendertools import RenderTool

import pytest
20
import time
21

22

spiglerg's avatar
spiglerg committed
23
24
"""Tests for `flatland` package."""

25
@pytest.mark.skip("Msgpack serializing not supported")
def test_load_env():
    """Load the packaged 10x10 test env and check that an agent can be added.

    Skipped: msgpack serialization is not supported.
    """
    env, env_dict = RailEnvPersister.load_resource("env_data.tests", "test-10x10.mpk")

    # Keyword arguments (the style test_dead_end in this module uses) make each
    # value land in the intended EnvAgent field instead of depending on the
    # positional field order. The legacy positional call here read as
    # (position, direction, target, moving).
    agent_static = EnvAgent(initial_position=(0, 0), initial_direction=2, direction=2,
                            target=(5, 5), moving=False)
    env.add_agent(agent_static)
    assert env.get_num_agents() == 1
spiglerg's avatar
spiglerg committed
36

37

maljx's avatar
maljx committed
38
def test_save_load():
    """Round-trip an env through RailEnvPersister and check the agents survive.

    The env is saved twice: once via ``RailEnvPersister.save`` and once via
    the env's own ``save`` method. The first file is reloaded, and each
    agent's position, direction and target are compared with the values
    captured before saving. All files are written into a temporary
    directory, so the test leaves nothing behind in the working directory.
    """
    import tempfile

    env = RailEnv(width=30, height=30,
                  rail_generator=sparse_rail_generator(seed=1),
                  line_generator=sparse_line_generator(), number_of_agents=2)
    env.reset()

    # Snapshot the per-agent state that must survive the round trip.
    # Explicit indexing keeps the original IndexError if fewer than 2 agents exist.
    expected = [(agent.position, agent.direction, agent.target)
                for agent in (env.agents[0], env.agents[1])]

    with tempfile.TemporaryDirectory() as tmp_dir:
        save_path = os.path.join(tmp_dir, "test_save.pkl")
        RailEnvPersister.save(env, save_path)
        # Also exercise the env-level save entry point.
        env.save(os.path.join(tmp_dir, "test_save_2.pkl"))

        env, env_dict = RailEnvPersister.load_new(save_path)

    assert env.width == 30
    assert env.height == 30
    assert len(env.agents) == 2

    for handle, (position, direction, target) in enumerate(expected):
        loaded = env.agents[handle]
        assert position == loaded.position
        assert direction == loaded.direction
        assert target == loaded.target
maljx's avatar
maljx committed
67

68
@pytest.mark.skip("Msgpack serializing not supported")
def test_save_load_mpk():
    """Round-trip an env through the msgpack (``.mpk``) format and compare agents.

    Skipped: msgpack serialization is not supported. The file is written into
    a temporary directory, so nothing is left behind in the working directory.
    """
    import tempfile

    env = RailEnv(width=30, height=30,
                  rail_generator=sparse_rail_generator(seed=1),
                  line_generator=sparse_line_generator(), number_of_agents=2)
    env.reset()

    with tempfile.TemporaryDirectory() as tmp_dir:
        save_path = os.path.join(tmp_dir, "test_save.mpk")
        RailEnvPersister.save(env, save_path)
        env2, env_dict = RailEnvPersister.load_new(save_path)

    assert env.width == env2.width
    assert env.height == env2.height
    assert len(env2.agents) == len(env.agents)

    for agent1, agent2 in zip(env.agents, env2.agents):
        assert agent1.position == agent2.position
        assert agent1.direction == agent2.direction
        assert agent1.target == agent2.target


Dipam Chakraborty's avatar
Dipam Chakraborty committed
91
@pytest.mark.skip(reason="Old file used to create env, not sure how to regenerate")
def test_rail_environment_single_agent(show=False):
    """Drive a single agent with random actions on a pre-saved rail env.

    The test runs 5 rounds. In each round the agent must make 6 moves that
    change its position in fewer than 100 steps. Then, for 10 fresh resets,
    a random walk must end the episode (``dones['__all__']``) within 100
    steps.

    Args:
        show: if True, print debug information and render each step.
    """
    # The loaded env is presumably this map on a 3x3 grid. It used to be built
    # by hand from RailEnvTransitions, but that construction "doesn't quite
    # work right", so a pre-saved environment is loaded instead.
    #  _  _
    # / \/ \
    # | |  |
    # \_/\_/
    transitions = RailEnvTransitions()
    rail_env, env_dict = RailEnvPersister.load_new("test_env_loop.pkl", "env_data.tests")
    rail_map = rail_env.rail.grid

    rail_env._max_episode_steps = 1000
    _ = rail_env.reset(False, False, True)

    actions = [int(a) for a in RailEnvActions]
    env_renderer = RenderTool(rail_env)

    for _ in range(5):
        _ = rail_env.reset(False, False, True)

        # We do not care about the target for the moment.
        agent = rail_env.agents[0]
        agent.target = [-1, -1]

        if show:
            print("After reset - agent pos:", agent.position, "dir: ", agent.direction)
            print(transitions.get_transitions(rail_map[agent.position], agent.direction))

        # HACK: instead of checking that the agent starts on a cell with a
        # usable outgoing transition, force a direction known to be good.
        agent.initial_direction = agent.direction = 0

        if show:
            print("handle:", agent.handle)

        valid_active_actions_done = 0
        pos = agent.position

        if show:
            env_renderer.render_env(show=show, show_agents=True)
            time.sleep(0.01)

        step_count = 0
        while valid_active_actions_done < 6:
            # We randomly select an action.
            action = np.random.choice(actions)
            _, _, dict_done, _ = rail_env.step({0: action})

            prev_pos = pos
            pos = agent.position

            # Debug output only when requested, to keep the test log quiet.
            if show:
                print("action:", action, "pos:", agent.position, "prev:", prev_pos, agent.direction)
                print(dict_done)

            # A step only counts once the agent has actually moved.
            if prev_pos != pos:
                valid_active_actions_done += 1
            step_count += 1

            if show:
                env_renderer.render_env(show=show, show_agents=True, step=step_count)
                time.sleep(0.01)
            assert step_count < 100, "valid actions should have been performed by now - hung agent"

        # After a fresh reset, a random walk must end the episode in time.
        for _ in range(10):
            _ = rail_env.reset()
            rail_env.agents[0].direction = 0

            # At most 100 steps. If the 100th step is still not done, the
            # assert below fails.
            for step_count in range(1, 101):
                action = np.random.choice(actions)
                _, _, dones, _ = rail_env.step({0: action})
                if dones['__all__']:
                    break
                assert step_count < 100, "agent should have finished by now"
                env_renderer.render_env(show=show)
gmollard's avatar
gmollard committed
222
223
224


def _make_dead_end_env(rail_map, transitions):
    """Build a one-agent RailEnv over a hand-made rail map.

    The same fixed agents_hints (city positions, stations, orientations) are
    passed as optionals to ``rail_from_grid_transition_map`` for every map.
    """
    rail = GridTransitionMap(width=rail_map.shape[1],
                             height=rail_map.shape[0],
                             transitions=transitions)
    rail.grid = rail_map

    agents_hints = {'num_agents': 2,
                    'city_positions': [(0, 0), (0, 3)],
                    'train_stations': [
                        [((0, 0), 0)],
                        [((0, 0), 0)],
                    ],
                    'city_orientations': [0, 2],
                    }
    optionals = {'agents_hints': agents_hints}

    return RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(), number_of_agents=1,
                   obs_builder_object=GlobalObsForRailEnv())


def test_dead_end():
    """Set up a 5-cell line of track with a dead end at each end.

    The horizontal and vertical layouts each get two agent configurations,
    one facing each end. No assertions yet: the test currently only checks
    that env construction, reset and agent assignment do not raise.
    """
    transitions = RailEnvTransitions()

    straight_vertical = int('1000000000100000', 2)  # Case 1 - straight
    straight_horizontal = transitions.rotate_transition(straight_vertical, 90)
    dead_end_from_south = int('0010000000000000', 2)  # Case 7 - dead end

    # Horizontal layout: O->-- where > is the train and O the target.
    rail_map = np.array(
        [[transitions.rotate_transition(dead_end_from_south, 270)] +
         [straight_horizontal] * 3 +
         [transitions.rotate_transition(dead_end_from_south, 90)]],
        dtype=np.uint16)
    rail_env = _make_dead_end_env(rail_map, transitions)

    for direction, target in ((1, (0, 0)), (3, (0, 4))):
        rail_env.reset()
        rail_env.agents = [EnvAgent(initial_position=(0, 2), initial_direction=direction,
                                    direction=direction, target=target, moving=False)]

    # Vertical layout: same track rotated onto a single column.
    rail_map = np.array(
        [[dead_end_from_south]] + [[straight_vertical]] * 3 +
        [[transitions.rotate_transition(dead_end_from_south, 180)]],
        dtype=np.uint16)
    rail_env = _make_dead_end_env(rail_map, transitions)

    for direction, target in ((2, (0, 0)), (0, (4, 0))):
        rail_env.reset()
        rail_env.agents = [EnvAgent(initial_position=(2, 0), initial_direction=direction,
                                    direction=direction, target=target, moving=False)]

    # TODO make assertions (e.g. step the agent and check it reaches the dead end)
u214892's avatar
u214892 committed
310
311
312


def test_get_entry_directions():
    """Check ``get_valid_directions_on_grid`` on characteristic cells of the simple rail.

    Each case pairs a (row, col) cell of ``make_simple_rail`` with the
    4-element boolean list the env is expected to report for it.
    """
    rail, rail_map, optionals = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(), number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
    env.reset()

    cases = (
        ("north dead end", (0, 3), [True, False, False, False]),
        ("west dead end", (3, 0), [False, False, False, True]),
        ("switch", (3, 3), [False, True, True, True]),
        ("horizontal", (3, 2), [False, True, False, True]),
        ("vertical", (2, 3), [True, False, True, False]),
        ("nowhere", (0, 0), [False, False, False, False]),
    )

    for _label, (row, col), expected in cases:
        actual = env.get_valid_directions_on_grid(row, col)
        assert actual == expected, "[{},{}] actual={}, expected={}".format(row, col, actual, expected)
340

Erik Nygren's avatar
Erik Nygren committed
341

342
343
344
def test_rail_env_reset():
    """Save an env and check that every reload path reproduces rail and agents.

    Three reload paths are checked against the original grid and agents:
    ``RailEnvPersister.load_new``, and a fresh env built from
    ``rail_from_file`` / ``line_from_file`` under two different ``reset``
    argument combinations. The file lives in a temporary directory, so the
    test leaves nothing behind in the working directory.
    """
    import tempfile

    rail, rail_map, optionals = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(), number_of_agents=3,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
    env.reset()

    def _check(loaded_env):
        # np.array_equal already returns a single bool; no np.all needed.
        assert np.array_equal(rails_initial, loaded_env.rail.grid)
        assert agents_initial == loaded_env.agents

    with tempfile.TemporaryDirectory() as tmp_dir:
        file_name = os.path.join(tmp_dir, "test_rail_env_reset.pkl")
        RailEnvPersister.save(env, file_name)

        rails_initial = env.rail.grid
        agents_initial = env.agents

        env2, env2_dict = RailEnvPersister.load_new(file_name)
        _check(env2)

        # The file-based generators read file_name during reset(), so these
        # resets must happen before the temporary directory is removed.
        for reset_args in ((False, True, False), (True, False, False)):
            env_from_file = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
                                    line_generator=line_from_file(file_name), number_of_agents=1,
                                    obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
            env_from_file.reset(*reset_args)
            _check(env_from_file)
392
393
394


def main():
    """Run the single-agent environment test interactively, with rendering enabled."""
    interactive = True
    test_rail_environment_single_agent(show=interactive)


if __name__ == "__main__":
    main()