#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os

import numpy as np

from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map, rail_from_file, complex_rail_generator, \
    random_rail_generator, empty_rail_generator
from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator, \
    schedule_from_file
from flatland.utils.simple_rail import make_simple_rail


Erik Nygren's avatar
Erik Nygren committed
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def test_empty_rail_generator():
    """An empty rail generator yields a blank grid of the requested size and places no agents."""
    np.random.seed(0)
    agent_count = 1
    width, height = 5, 10

    # One agent is requested; the empty generator is expected to place none.
    env = RailEnv(
        width=width,
        height=height,
        number_of_agents=agent_count,
        rail_generator=empty_rail_generator(),
    )

    grid = env.rail.grid
    # The grid shape is (height, width).
    assert grid.shape == (height, width)
    # Not a single cell holds a transition.
    assert not np.any(grid)
    # No agents ended up in the environment.
    assert env.get_num_agents() == 0


def test_random_rail_generator():
    """A random rail generator builds a grid of the requested size holding the requested agents."""
    np.random.seed(0)
    agent_count = 1
    width, height = 5, 10

    env = RailEnv(
        width=width,
        height=height,
        number_of_agents=agent_count,
        rail_generator=random_rail_generator(),
    )

    # The grid shape is (height, width).
    assert env.rail.grid.shape == (height, width)
    assert env.get_num_agents() == agent_count


def test_complex_rail_generator():
    """Check the agent count produced by ``complex_rail_generator`` with ``complex_schedule_generator``.

    The number of agents actually placed follows the generated level rather than the
    ``number_of_agents`` passed to ``RailEnv``. The grid always has the requested size.
    """
    n_agents = 10
    n_start = 2
    x_dim = 10
    y_dim = 10
    min_dist = 4

    # More agents are requested than there are start/goal pairs: the agent count is
    # changed to fit the generated level (one agent per start/goal pair).
    env = RailEnv(width=x_dim,
                  height=y_dim,
                  number_of_agents=n_agents,
                  rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist),
                  schedule_generator=complex_schedule_generator()
                  )
    assert env.get_num_agents() == n_start
    assert env.rail.grid.shape == (y_dim, x_dim)

    # A minimum start/goal distance of twice the grid width cannot be satisfied on
    # this grid, so the level cannot be generated and no agents are placed.
    min_dist = 2 * x_dim

    env = RailEnv(width=x_dim,
                  height=y_dim,
                  number_of_agents=n_agents,
                  rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist),
                  schedule_generator=complex_schedule_generator()
                  )
    assert env.get_num_agents() == 0
    assert env.rail.grid.shape == (y_dim, x_dim)

    # With satisfiable parameters and exactly as many agents as start/goal pairs,
    # the requested agent count is kept unchanged.
    min_dist = 2
    n_start = 5
    n_agents = 5

    env = RailEnv(width=x_dim,
                  height=y_dim,
                  number_of_agents=n_agents,
                  rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist),
                  schedule_generator=complex_schedule_generator()
                  )
    assert env.get_num_agents() == n_agents
    assert env.rail.grid.shape == (y_dim, x_dim)


def test_rail_from_grid_transition_map():
    """Build an environment from the fixed ``make_simple_rail`` transition map.

    Checks the number of non-empty rail cells, that every agent starts on a rail
    cell, and that the requested number of agents is placed.
    """
    rail, rail_map = make_simple_rail()
    n_agents = 3
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=n_agents
                  )
    nr_rail_elements = np.count_nonzero(env.rail.grid)

    # The simple rail layout is expected to contain exactly 16 non-empty cells.
    assert nr_rail_elements == 16

    # Every agent must start on a non-zero (rail) cell.
    for a in env.agents:
        assert env.rail.grid[a.position] != 0

    assert env.get_num_agents() == n_agents


def tests_rail_from_file():
    """Round-trip environments through ``RailEnv.save`` and the ``*_from_file`` generators.

    Four combinations are covered:

    1. saved with a distance map (tree observation), loaded with a tree observation;
    2. saved without one (global observation), loaded with a global observation;
    3. saved with a distance map, loaded with a global observation;
    4. saved without one, loaded with a tree observation (distance map must be computed).

    The pickle files written to the working directory are removed afterwards, even
    if an assertion fails.
    """
    file_name = "test_with_distance_map.pkl"
    file_name_2 = "test_without_distance_map.pkl"

    rail, rail_map = make_simple_rail()

    try:
        # --- 1. Save and load a file that includes the distance map. ---
        env = RailEnv(width=rail_map.shape[1],
                      height=rail_map.shape[0],
                      rail_generator=rail_from_grid_transition_map(rail),
                      schedule_generator=random_schedule_generator(),
                      number_of_agents=3,
                      obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                      )
        env.save(file_name)
        dist_map_shape = np.shape(env.distance_map.get())
        rails_initial = env.rail.grid
        agents_initial = env.agents

        # The width/height/agent count passed here are dummies; the loaded rail and
        # agents are asserted to match the saved ones below.
        env = RailEnv(width=1,
                      height=1,
                      rail_generator=rail_from_file(file_name),
                      schedule_generator=schedule_from_file(file_name),
                      number_of_agents=1,
                      obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                      )
        rails_loaded = env.rail.grid
        agents_loaded = env.agents

        assert np.array_equal(rails_initial, rails_loaded)
        assert agents_initial == agents_loaded

        # The distance map stored in the file is used; it was not recomputed.
        assert env.distance_map.distance_map_computed is False
        assert np.shape(env.distance_map.get()) == dist_map_shape
        assert env.distance_map.get() is not None

        # --- 2. Save and load a file without a distance map. ---
        env2 = RailEnv(width=rail_map.shape[1],
                       height=rail_map.shape[0],
                       rail_generator=rail_from_grid_transition_map(rail),
                       schedule_generator=random_schedule_generator(),
                       number_of_agents=3,
                       obs_builder_object=GlobalObsForRailEnv(),
                       )
        env2.save(file_name_2)
        rails_initial_2 = env2.rail.grid
        agents_initial_2 = env2.agents

        env2 = RailEnv(width=1,
                       height=1,
                       rail_generator=rail_from_file(file_name_2),
                       schedule_generator=schedule_from_file(file_name_2),
                       number_of_agents=1,
                       obs_builder_object=GlobalObsForRailEnv(),
                       )
        rails_loaded_2 = env2.rail.grid
        agents_loaded_2 = env2.agents

        assert np.array_equal(rails_initial_2, rails_loaded_2)
        assert agents_initial_2 == agents_loaded_2
        assert not hasattr(env2.obs_builder, "distance_map")

        # --- 3. Save with a distance map, load with an observation that does not use one. ---
        env3 = RailEnv(width=1,
                       height=1,
                       rail_generator=rail_from_file(file_name),
                       schedule_generator=schedule_from_file(file_name),
                       number_of_agents=1,
                       obs_builder_object=GlobalObsForRailEnv(),
                       )
        rails_loaded_3 = env3.rail.grid
        agents_loaded_3 = env3.agents

        assert np.array_equal(rails_initial, rails_loaded_3)
        assert agents_initial == agents_loaded_3
        # Check the environment under test (env3), not the one from section 2.
        assert not hasattr(env3.obs_builder, "distance_map")

        # --- 4. Save without a distance map, load with a tree observation. ---
        env4 = RailEnv(width=1,
                       height=1,
                       rail_generator=rail_from_file(file_name_2),
                       schedule_generator=schedule_from_file(file_name_2),
                       number_of_agents=1,
                       obs_builder_object=TreeObsForRailEnv(max_depth=2),
                       )
        rails_loaded_4 = env4.rail.grid
        agents_loaded_4 = env4.agents

        # The file loaded here was saved from env2, which carried no distance map.
        assert not hasattr(env2.obs_builder, "distance_map")
        assert np.array_equal(rails_initial_2, rails_loaded_4)
        assert agents_initial_2 == agents_loaded_4

        # A distance map was computed on load, with the same shape as the one saved in section 1.
        assert env4.distance_map.distance_map_computed is True
        assert env4.distance_map.get() is not None
        assert np.shape(env4.distance_map.get()) == dist_map_shape
    finally:
        # Do not leave the pickles behind in the working directory.
        for path in (file_name, file_name_2):
            if os.path.exists(path):
                os.remove(path)