test_flatland_envs_predictions.py 12.4 KB
Newer Older
u214892's avatar
u214892 committed
1
2
#!/usr/bin/env python
# -*- coding: utf-8 -*-
3
import pprint
u214892's avatar
u214892 committed
4
5
6

import numpy as np

u214892's avatar
u214892 committed
7
from flatland.core.grid.grid4 import Grid4TransitionsEnum
8
from flatland.envs.observations import TreeObsForRailEnv, Node
9
from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
10
from flatland.envs.rail_env import RailEnv
11
from flatland.envs.rail_env_shortest_paths import get_shortest_paths
u214892's avatar
u214892 committed
12
from flatland.envs.rail_generators import rail_from_grid_transition_map
13
from flatland.envs.rail_trainrun_data_structures import Waypoint
14
from flatland.envs.line_generators import sparse_line_generator
15
from flatland.utils.rendertools import RenderTool
u214892's avatar
u214892 committed
16
from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail
17
from flatland.envs.rail_env_action import RailEnvActions
18
from flatland.envs.step_utils.states import TrainState
u214892's avatar
u214892 committed
19

u214892's avatar
u214892 committed
20
"""Test predictions for `flatland` package."""
u214892's avatar
u214892 committed
21
22


23
def test_dummy_predictor(rendering=False):
    """Check the trajectory produced by ``DummyPredictorForRailEnv``.

    One agent is placed on ``make_simple_rail2()`` at (5, 6) with direction 0
    and target (3, 0). After the agent has been started, the predictor output
    for agent 0 is compared entry by entry with hard-coded expectations.

    :param rendering: if True, render the environment and wait for a key press.
    """
    rail, rail_map, optionals = make_simple_rail2()

    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)),
                  )
    env.reset()

    # Configure start and target of agent 0 before resetting with reset(False, False).
    first = env.agents[0]
    first.initial_position = (5, 6)
    first.initial_direction = 0
    first.direction = 0
    first.target = (3, 0)

    env.reset(False, False)

    # Index env.agents again after the reset instead of reusing the alias
    # (in case the reset replaced the agent objects).
    env.agents[0].earliest_departure = 1
    env._max_episode_steps = 100
    # Make agent 0 active: one step without actions, then MOVE_FORWARD.
    env.step({})
    env.step({0: RailEnvActions.MOVE_FORWARD})

    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.render_env(show=True, show_observations=False)
        input("Continue?")

    # Each prediction entry is read as (time_offset, row, col, direction, action).
    agent_prediction = env.obs_builder.predictor.get(None)[0]
    time_offsets = np.array([[entry[0]] for entry in agent_prediction])
    positions = np.array([list(entry[1:3]) for entry in agent_prediction])
    directions = np.array([[entry[3]] for entry in agent_prediction])
    actions = np.array([[entry[4]] for entry in agent_prediction])

    expected_positions = np.array([
        [5., 6.], [4., 6.], [3., 6.],
        [3., 5.], [3., 4.], [3., 3.], [3., 2.], [3., 1.],
        # at target (3,0): stay in this position from here on
        [3., 0.], [3., 0.], [3., 0.],
    ])
    # three steps with direction 0, then direction 3 until (and at) the target
    expected_directions = np.array([[0.]] * 3 + [[3.]] * 8)
    expected_time_offsets = np.array([[float(t)] for t in range(11)])
    # 0 first, 2 while driving (incl. reaching the target straight), 4 once stopped at target
    expected_actions = np.array([[0.]] + [[2.]] * 8 + [[4.]] * 2)

    assert np.array_equal(positions, expected_positions)
    assert np.array_equal(directions, expected_directions)
    assert np.array_equal(time_offsets, expected_time_offsets)
    assert np.array_equal(actions, expected_actions)


def test_shortest_path_predictor(rendering=False):
    """Check ``ShortestPathPredictorForRailEnv`` on ``make_simple_rail()``.

    One agent starts at (5, 6) with direction 0 (north) and has target (3, 9).
    The test checks the distance-map value at the start cell, the shortest
    path returned by ``get_shortest_paths`` and the predicted time offsets,
    positions and directions for agent 0.

    :param rendering: if True, render the environment and wait for a key press.
    """
    rail, rail_map, optionals = make_simple_rail()

    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                  )
    env.reset()

    agent = env.agents[0]
    # start in the south dead-end, heading north
    agent.initial_position = (5, 6)
    agent.position = (5, 6)
    agent.direction = 0
    agent.initial_direction = 0
    # target: the east dead-end
    agent.target = (3, 9)
    agent.moving = True
    agent._set_state(TrainState.MOVING)

    env.reset(False, False)
    env.distance_map._compute(env.agents, env.rail)

    # Step with empty action dicts (DO_NOTHING) until all trains get to READY_TO_DEPART.
    latest_departure = max(other.earliest_departure for other in env.agents)
    for _ in range(latest_departure):
        env.step({})

    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.render_env(show=True, show_observations=False)
        input("Continue?")

    distance_map = env.distance_map.get()
    start_row, start_col = agent.initial_position
    distance_on_map = distance_map[0, start_row, start_col, agent.direction]
    assert distance_on_map == 5.0, "found {} instead of {}".format(distance_on_map, 5.0)

    shortest_path = get_shortest_paths(env.distance_map)[0]
    assert shortest_path == [
        Waypoint((5, 6), 0),
        Waypoint((4, 6), 0),
        Waypoint((3, 6), 0),
        Waypoint((3, 7), 1),
        Waypoint((3, 8), 1),
        Waypoint((3, 9), 1)
    ]

    # Each prediction step is read as (time_offset, row, col, direction, ...).
    agent_prediction = env.obs_builder.predictions[0]
    time_offsets = np.array([[step[0]] for step in agent_prediction])
    positions = np.array([list(step[1:3]) for step in agent_prediction])
    directions = np.array([[step[3]] for step in agent_prediction])

    # north up to (3, 6), east to the target (3, 9), then stay there
    expected_positions = [[5, 6], [4, 6], [3, 6], [3, 7], [3, 8]] + [[3, 9]] * 16
    expected_directions = [[Grid4TransitionsEnum.NORTH]] * 3 + [[Grid4TransitionsEnum.EAST]] * 18
    expected_time_offsets = np.array([[float(t)] for t in range(21)])

    assert np.array_equal(time_offsets, expected_time_offsets), \
        "time_offsets {}, expected {}".format(time_offsets, expected_time_offsets)

    assert np.array_equal(positions, expected_positions), \
        "positions {}, expected {}".format(positions, expected_positions)
    assert np.array_equal(directions, expected_directions), \
        "directions {}, expected {}".format(directions, expected_directions)


def test_shortest_path_predictor_conflicts(rendering=False):
    """Check opposite-direction conflict reporting in ``TreeObsForRailEnv``.

    Two agents are placed on ``make_invalid_simple_rail()`` with the
    ``ShortestPathPredictorForRailEnv``. For each agent's tree observation,
    ``num_agents_opposite_direction`` must be positive exactly on the
    expected branch paths (see ``_check_expected_conflicts``).

    :param rendering: if True, render the environment and wait for a key press.
    """
    rail, rail_map, optionals = make_invalid_simple_rail()

    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(),
                  number_of_agents=2,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                  )
    env.reset()

    # set the initial position
    env.agents[0].initial_position = (5, 6)  # south dead-end
    env.agents[0].position = (5, 6)  # south dead-end
    env.agents[0].direction = 0  # north
    env.agents[0].initial_direction = 0  # north
    env.agents[0].target = (3, 9)  # east dead-end
    env.agents[0].moving = True
    env.agents[0]._set_state(TrainState.MOVING)

    env.agents[1].initial_position = (3, 8)  # east dead-end
    env.agents[1].position = (3, 8)  # east dead-end
    env.agents[1].direction = 3  # west
    env.agents[1].initial_direction = 3  # west
    env.agents[1].target = (6, 6)  # south dead-end
    env.agents[1].moving = True
    env.agents[1]._set_state(TrainState.MOVING)

    # The observations returned by reset are not used: they are recomputed
    # below after the agents have been placed again.
    env.reset(False, False)

    # Put both agents back on their cells, record their handles in
    # env.agent_positions, and mark them as moving.
    env.agents[0].position = (5, 6)  # south dead-end
    env.agent_positions[env.agents[0].position] = 0
    env.agents[1].position = (3, 8)  # east dead-end
    env.agent_positions[env.agents[1].position] = 1

    env.agents[0]._set_state(TrainState.MOVING)
    env.agents[1]._set_state(TrainState.MOVING)

    observations = env._get_observations()

    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.render_env(show=True, show_observations=False)
        input("Continue?")

    # get the trees to test
    obs_builder: TreeObsForRailEnv = env.obs_builder
    tree_0 = observations[0]
    tree_1 = observations[1]
    obs_builder.util_print_obs_subtree(tree_0)
    obs_builder.util_print_obs_subtree(tree_1)

    # Expected conflicts are action paths from the root: ('F', 'R') is the
    # child 'R' of the root's child 'F'.
    expected_conflicts_0 = [('F', 'R')]
    expected_conflicts_1 = [('F', 'L')]
    _check_expected_conflicts(expected_conflicts_0, obs_builder, tree_0, "agent[0]: ")
    _check_expected_conflicts(expected_conflicts_1, obs_builder, tree_1, "agent[1]: ")


def _check_expected_conflicts(expected_conflicts, obs_builder, tree: "Node", prompt=''):
    """Assert that opposite-direction conflicts appear exactly where expected.

    Visits the root of ``tree``, its children and its grandchildren (one level
    per character in ``obs_builder.tree_explored_actions_char``) and asserts
    that ``num_agents_opposite_direction > 0`` holds for a node if and only if
    that node's action path is listed in ``expected_conflicts``.

    Paths are tuples of action characters: ``()`` is the root, ``(a_1,)`` a
    child and ``(a_1, a_2)`` a grandchild. For backward compatibility a child
    may also be listed as the bare character ``a_1``. Branches stored as
    ``-np.inf`` (not explored) must not be listed.

    :param expected_conflicts: container of paths expected to report a conflict.
    :param obs_builder: object providing ``tree_explored_actions_char``.
    :param tree: root node of an agent's tree observation.
    :param prompt: prefix for assertion messages (e.g. the agent name).
    :raises AssertionError: on the first node whose conflict flag does not match.
    """
    def expects(path):
        # depth-1 paths: accept the 1-tuple as well as the legacy bare character
        if len(path) == 1 and path[0] in expected_conflicts:
            return True
        return path in expected_conflicts

    assert (tree.num_agents_opposite_direction > 0) == expects(()), "{}[]".format(prompt)
    for a_1 in obs_builder.tree_explored_actions_char:
        child = tree.childs[a_1]
        if child == -np.inf:
            # unexplored branch: it cannot report a conflict, nor has it children
            assert not expects((a_1,)), "{}[{}]".format(prompt, a_1)
            continue
        conflict = child.num_agents_opposite_direction
        assert (conflict > 0) == expects((a_1,)), "{}[{}]".format(prompt, a_1)
        for a_2 in obs_builder.tree_explored_actions_char:
            grandchild = child.childs[a_2]
            if grandchild == -np.inf:
                assert not expects((a_1, a_2)), "{}[{}][{}]".format(prompt, a_1, a_2)
                continue
            conflict = grandchild.num_agents_opposite_direction
            assert (conflict > 0) == expects((a_1, a_2)), "{}[{}][{}]".format(prompt, a_1, a_2)