test_flatland_envs_predictions.py 12.4 KB
Newer Older
u214892's avatar
u214892 committed
1
2
#!/usr/bin/env python
# -*- coding: utf-8 -*-
3
import pprint
u214892's avatar
u214892 committed
4
5
6

import numpy as np

u214892's avatar
u214892 committed
7
from flatland.core.grid.grid4 import Grid4TransitionsEnum
8
from flatland.envs.observations import TreeObsForRailEnv, Node
9
from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
10
from flatland.envs.rail_env import RailEnv
11
from flatland.envs.rail_env_shortest_paths import get_shortest_paths
u214892's avatar
u214892 committed
12
from flatland.envs.rail_generators import rail_from_grid_transition_map
13
from flatland.envs.rail_trainrun_data_structures import Waypoint
14
from flatland.envs.line_generators import sparse_line_generator
15
from flatland.utils.rendertools import RenderTool
u214892's avatar
u214892 committed
16
from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail
17
from flatland.envs.rail_env_action import RailEnvActions
18
from flatland.envs.step_utils.states import TrainState
u214892's avatar
u214892 committed
19

20

u214892's avatar
u214892 committed
21
"""Test predictions for `flatland` package."""
u214892's avatar
u214892 committed
22
23


24
def test_dummy_predictor(rendering=False):
    """Check the predictions of `DummyPredictorForRailEnv` for a single agent.

    The environment is built from `make_simple_rail2()`; agent 0 is placed at
    (5, 6) with direction 0 and target (3, 0), and two environment steps are
    executed before the predictor is queried.  Each prediction row is read as
    `[time_offset, row, column, direction, action]` (indices 0, 1:3, 3 and 4)
    and compared against hard-coded expectations.

    :param rendering: if true, render the environment with `RenderTool` and
        wait for keyboard input before running the assertions.
    """
    rail, rail_map, optionals = make_simple_rail2()

    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)),
                  )
    env.reset()

    # set initial position and direction for testing...
    env.agents[0].initial_position = (5, 6)
    env.agents[0].initial_direction = 0
    env.agents[0].direction = 0
    env.agents[0].target = (3, 0)

    # second reset with both flags False, so the values set above are used
    # (presumably "do not regenerate rail / schedule" -- confirm against RailEnv.reset)
    env.reset(False, False)
    env.agents[0].earliest_departure = 1
    env._max_episode_steps = 100
    # Make Agent 0 active
    env.step({})
    env.step({0: RailEnvActions.MOVE_FORWARD})

    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.render_env(show=True, show_observations=False)
        input("Continue?")

    # test assertions: split every prediction row of agent 0 into its columns
    agent_predictions = env.obs_builder.predictor.get(None)[0]
    positions = np.array([[*prediction[1:3]] for prediction in agent_predictions])
    directions = np.array([[prediction[3]] for prediction in agent_predictions])
    time_offsets = np.array([[prediction[0]] for prediction in agent_predictions])
    actions = np.array([[prediction[4]] for prediction in agent_predictions])

    # compare against expected values
    expected_positions = np.array([[5., 6.],
                                   [4., 6.],
                                   [3., 6.],
                                   [3., 5.],
                                   [3., 4.],
                                   [3., 3.],
                                   [3., 2.],
                                   [3., 1.],
                                   # at target (3,0): stay in this position from here on
                                   [3., 0.],
                                   [3., 0.],
                                   [3., 0.],
                                   ])
    expected_directions = np.array([[0.],
                                    [0.],
                                    [0.],
                                    [3.],
                                    [3.],
                                    [3.],
                                    [3.],
                                    [3.],
                                    # at target (3,0): stay in this position from here on
                                    [3.],
                                    [3.],
                                    [3.]
                                    ])
    expected_time_offsets = np.array([[0.],
                                      [1.],
                                      [2.],
                                      [3.],
                                      [4.],
                                      [5.],
                                      [6.],
                                      [7.],
                                      [8.],
                                      [9.],
                                      [10.],
                                      ])
    expected_actions = np.array([[0.],
                                 [2.],
                                 [2.],
                                 [2.],
                                 [2.],
                                 [2.],
                                 [2.],
                                 [2.],
                                 # reaching target by straight
                                 [2.],
                                 # at target: stopped moving
                                 [4.],
                                 [4.],
                                 ])

    # failure messages mirror the ones used in test_shortest_path_predictor
    assert np.array_equal(positions, expected_positions), \
        "positions {}, expected {}".format(positions, expected_positions)
    assert np.array_equal(directions, expected_directions), \
        "directions {}, expected {}".format(directions, expected_directions)
    assert np.array_equal(time_offsets, expected_time_offsets), \
        "time_offsets {}, expected {}".format(time_offsets, expected_time_offsets)
    assert np.array_equal(actions, expected_actions), \
        "actions {}, expected {}".format(actions, expected_actions)
u214892's avatar
u214892 committed
118
119


120
def test_shortest_path_predictor(rendering=False):
    """Check `ShortestPathPredictorForRailEnv` for one agent on `make_simple_rail`.

    The agent starts at (5, 6) heading north with target (3, 9).  The test
    verifies the distance-map entry for the start cell, the shortest path
    returned by `get_shortest_paths`, and the positions, directions and time
    offsets stored in the observation builder's predictions.

    :param rendering: if true, render the environment with `RenderTool` and
        wait for keyboard input before running the assertions.
    """
    rail, rail_map, optionals = make_simple_rail()
    height, width = rail_map.shape[0], rail_map.shape[1]
    tree_obs = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
    env = RailEnv(
        width=width,
        height=height,
        rail_generator=rail_from_grid_transition_map(rail, optionals),
        line_generator=sparse_line_generator(),
        number_of_agents=1,
        obs_builder_object=tree_obs,
    )
    env.reset()

    # place the single agent by hand
    agent = env.agents[0]
    start_cell = (5, 6)  # south dead-end
    agent.initial_position = start_cell
    agent.position = start_cell
    agent.direction = 0  # north
    agent.initial_direction = 0  # north
    agent.target = (3, 9)  # east dead-end
    agent.moving = True
    agent._set_state(TrainState.MOVING)

    env.reset(False, False)
    env.distance_map._compute(env.agents, env.rail)

    # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
    latest_departure = max(a.earliest_departure for a in env.agents)
    for _ in range(latest_departure):
        env.step({})  # DO_NOTHING for all agents

    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.render_env(show=True, show_observations=False)
        input("Continue?")

    # distance from the start cell (in the agent's direction) to the target
    start_row, start_col = agent.initial_position[0], agent.initial_position[1]
    distance_on_map = env.distance_map.get()[0, start_row, start_col, agent.direction]
    assert distance_on_map == 5.0, "found {} instead of {}".format(distance_on_map, 5.0)

    # shortest path of agent 0 as a list of waypoints
    route = [
        ((5, 6), 0),
        ((4, 6), 0),
        ((3, 6), 0),
        ((3, 7), 1),
        ((3, 8), 1),
        ((3, 9), 1),
    ]
    paths = get_shortest_paths(env.distance_map)[0]
    assert paths == [Waypoint(cell, heading) for cell, heading in route]

    # extract the data: one row per predicted step of agent 0
    steps = env.obs_builder.predictions[0]
    positions = np.array([[*step[1:3]] for step in steps])
    directions = np.array([[step[3]] for step in steps])
    time_offsets = np.array([[step[0]] for step in steps])

    # test if data meets expectations: 21 rows, the 6 route cells followed by
    # 15 rows that stay on the target cell
    expected_positions = [[row, col] for (row, col), _ in route] + [[3, 9]] * 15
    expected_directions = (
        # the three cells [5,6], [4,6], [3,6] are entered heading north
        [[Grid4TransitionsEnum.NORTH]] * 3
        # from [3,7] onwards the heading is east
        + [[Grid4TransitionsEnum.EAST]] * 18
    )
    expected_time_offsets = np.arange(21.0).reshape(-1, 1)

    assert np.array_equal(time_offsets, expected_time_offsets), \
        "time_offsets {}, expected {}".format(time_offsets, expected_time_offsets)

    assert np.array_equal(positions, expected_positions), \
        "positions {}, expected {}".format(positions, expected_positions)
    assert np.array_equal(directions, expected_directions), \
        "directions {}, expected {}".format(directions, expected_directions)
253
254
255


def test_shortest_path_predictor_conflicts(rendering=False):
    """Check that the tree observation reports opposing agents on the expected branches.

    Two agents are placed on `make_invalid_simple_rail()`.  After forcing both
    into the MOVING state, the tree observations of both agents are computed
    and `_check_expected_conflicts` verifies that
    `num_agents_opposite_direction` is positive only on branch ('F', 'R') for
    agent 0 and only on branch ('F', 'L') for agent 1.

    :param rendering: if true, render the environment with `RenderTool` and
        wait for keyboard input before running the assertions.
    """
    rail, rail_map, optionals = make_invalid_simple_rail()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(),
                  number_of_agents=2,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                  )
    env.reset()

    # (start cell, direction, target) per agent handle
    agent_setups = [
        # agent 0: starts at the south dead-end heading north (0), targets the east dead-end
        ((5, 6), 0, (3, 9)),
        # agent 1: starts at (3, 8) heading west (3).
        # NOTE(review): the original comments labelled (3, 8) "east dead-end" and
        # (6, 6) "south dead-end", which contradicts agent 0's comments above --
        # confirm the cell roles against make_invalid_simple_rail.
        ((3, 8), 3, (6, 6)),
    ]

    # set the initial position
    for handle, (start, direction, target) in enumerate(agent_setups):
        agent = env.agents[handle]
        agent.initial_position = start
        agent.position = start
        agent.direction = direction
        agent.initial_direction = direction
        agent.target = target
        agent.moving = True
        agent._set_state(TrainState.MOVING)

    env.reset(False, False)

    # the reset above is followed by putting both agents back on the grid by
    # hand and registering them in the environment's position lookup
    for handle, (start, _, _) in enumerate(agent_setups):
        env.agents[handle].position = start
        env.agent_positions[env.agents[handle].position] = handle
    for handle in range(len(agent_setups)):
        env.agents[handle]._set_state(TrainState.MOVING)

    observations = env._get_observations()

    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.render_env(show=True, show_observations=False)
        input("Continue?")

    # get the trees to test
    obs_builder: TreeObsForRailEnv = env.obs_builder
    tree_0 = observations[0]
    tree_1 = observations[1]
    obs_builder.util_print_obs_subtree(tree_0)
    obs_builder.util_print_obs_subtree(tree_1)

    # check the expectations
    expected_conflicts_0 = [('F', 'R')]
    expected_conflicts_1 = [('F', 'L')]
    _check_expected_conflicts(expected_conflicts_0, obs_builder, tree_0, "agent[0]: ")
    _check_expected_conflicts(expected_conflicts_1, obs_builder, tree_1, "agent[1]: ")


315
def _check_expected_conflicts(expected_conflicts, obs_builder, tree: Node, prompt=''):
    """Assert that opposite-direction agents are seen exactly on the expected tree paths.

    The tree is walked two levels deep using the action labels in
    `obs_builder.tree_explored_actions_char`.  For every visited node,
    `num_agents_opposite_direction > 0` must hold if and only if the node's
    path is listed in `expected_conflicts`.  A child equal to `-np.inf` is
    treated as a missing branch: its path must not be listed, and nothing
    below a missing depth-1 branch is inspected.

    :param expected_conflicts: collection of paths expected to have a conflict:
        `()` for the root, `(a_1, a_2)` for depth 2, and for depth 1 either
        the tuple `(a_1,)` or the bare label `a_1` (the form the previous
        implementation matched, kept for backward compatibility).
    :param obs_builder: object providing `tree_explored_actions_char`.
    :param tree: root node, with `num_agents_opposite_direction` and `childs`.
    :param prompt: prefix for the assertion messages.
    :raises AssertionError: if any node disagrees with `expected_conflicts`.
    """
    assert (tree.num_agents_opposite_direction > 0) == (() in expected_conflicts), "{}[]".format(prompt)

    action_labels = obs_builder.tree_explored_actions_char
    for a_1 in action_labels:
        message_1 = "{}[{}]".format(prompt, a_1)
        # `(a_1)` is just `a_1`, not a tuple; accept the proper 1-tuple as well
        # so that depth-1 paths can be written like every other path.
        expected_1 = a_1 in expected_conflicts or (a_1,) in expected_conflicts

        child = tree.childs[a_1]
        if child == -np.inf:
            assert not expected_1, message_1
            continue
        assert (child.num_agents_opposite_direction > 0) == expected_1, message_1

        for a_2 in action_labels:
            message_2 = "{}[{}][{}]".format(prompt, a_1, a_2)
            expected_2 = (a_1, a_2) in expected_conflicts

            grandchild = child.childs[a_2]
            if grandchild == -np.inf:
                assert not expected_2, message_2
            else:
                assert (grandchild.num_agents_opposite_direction > 0) == expected_2, message_2