#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np

from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail
from flatland.envs.step_utils.states import TrainState

"""Tests for `flatland` package."""


def test_global_obs():
    rail, rail_map, optionals = make_simple_rail()

    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(), number_of_agents=1,
                  obs_builder_object=GlobalObsForRailEnv())

    global_obs, info = env.reset()

    # We have to take a step for the agent to enter the grid.
    global_obs, _, _, _ = env.step({0: RailEnvActions.MOVE_FORWARD})

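    # The first component of the observation is the transition map:
    # one binary channel per bit of each cell's 16-bit transition code.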
    assert (global_obs[0][0].shape == rail_map.shape + (16,))

    rail_map_recons = np.zeros_like(rail_map)
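    # Reconstruct each cell's 16-bit transition code from the 16 binary
    # channels (most significant bit first).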
    for i in range(global_obs[0][0].shape[0]):
        for j in range(global_obs[0][0].shape[1]):
            rail_map_recons[i, j] = int(
                ''.join(global_obs[0][0][i, j].astype(int).astype(str)), 2)

    assert np.array_equal(rail_map_recons, rail_map)

    # If this assertion fails, the observation places the agent on an empty
    # (non-rail) cell.
    obs_agents_state = global_obs[0][1]
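    # The agent-state channels are initialised to -1 where no agent is
    # present, so after the +1 shift empty cells contribute 0 to the sum.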
    obs_agents_state = obs_agents_state + 1
    assert (np.sum(rail_map * obs_agents_state[:, :, :4].sum(2)) > 0)


def _step_along_shortest_path(env, obs_builder, rail):
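    """Greedily advance each agent one step along its shortest path.

    For every agent, pick the exit direction whose neighbouring cell has the
    smallest remaining distance to the target (according to the pre-computed
    distance map), translate that direction into a rail action and step the
    environment. Returns the rewards and dones dicts from the step.
    """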
    actions = {}
    expected_next_position = {}
    for agent in env.agents:
        shortest_distance = np.inf

        for exit_direction in range(4):
            neighbour = get_new_position(agent.position, exit_direction)

            if neighbour[0] >= 0 and neighbour[0] < env.height and neighbour[1] >= 0 and neighbour[1] < env.width:
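                # Direction pointing from the neighbour back towards the
                # current cell (the opposite of exit_direction).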
                desired_movement_from_new_cell = (exit_direction + 2) % 4

                # Check all possible transitions in the neighbour cell
                for agent_orientation in range(4):
                    # Is a transition from the neighbour back towards the current cell possible for this orientation?
                    is_valid = obs_builder.env.rail.get_transition((neighbour[0], neighbour[1], agent_orientation),
                                                                   desired_movement_from_new_cell)
                    if is_valid:
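                        # Remaining distance to the target when leaving the
                        # current cell in exit_direction.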
                        distance_to_target = obs_builder.env.distance_map.get()[
                            (agent.handle, *agent.position, exit_direction)]
                        print("agent {} at {} facing {} taking {} distance {}".format(agent.handle, agent.position,
                                                                                      agent.direction,
                                                                                      exit_direction,
                                                                                      distance_to_target))

                        if distance_to_target < shortest_distance:
                            shortest_distance = distance_to_target
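                            # Map the exit direction, expressed relative to the
                            # agent's heading (as seen by an agent facing
                            # north), to the corresponding rail action.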
                            actions_to_be_taken_when_facing_north = {
                                Grid4TransitionsEnum.NORTH: RailEnvActions.MOVE_FORWARD,
                                Grid4TransitionsEnum.EAST: RailEnvActions.MOVE_RIGHT,
                                Grid4TransitionsEnum.WEST: RailEnvActions.MOVE_LEFT,
                                Grid4TransitionsEnum.SOUTH: RailEnvActions.DO_NOTHING,
                            }
                            print("   improved (direction) -> {}".format(exit_direction))

                            actions[agent.handle] = actions_to_be_taken_when_facing_north[
                                (exit_direction - agent.direction) % len(rail.transitions.get_direction_enum())]
                            expected_next_position[agent.handle] = neighbour
                            print("   improved (action) -> {}".format(actions[agent.handle]))
    _, rewards, dones, _ = env.step(actions)
    return rewards, dones


def test_reward_function_conflict(rendering=False):
    rail, rail_map, optionals = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(), number_of_agents=2,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
    obs_builder: TreeObsForRailEnv = env.obs_builder
    env.reset()

    # set the initial position
    agent = env.agents[0]
    agent.position = (5, 6)  # south dead-end
    agent.initial_position = (5, 6)  # south dead-end
    agent.direction = 0  # north
    agent.initial_direction = 0  # north
    agent.target = (3, 9)  # east dead-end
    agent.moving = True
    agent._set_state(TrainState.MOVING)

    agent = env.agents[1]
    agent.position = (3, 8)  # east dead-end
    agent.initial_position = (3, 8)  # east dead-end
    agent.direction = 3  # west
    agent.initial_direction = 3  # west
    agent.target = (6, 6)  # south dead-end
    agent.moving = True
    agent._set_state(TrainState.MOVING)

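    # Reset without regenerating rail or schedule; the manual agent setup is
    # re-applied below.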
    env.reset(False, False)
    env.agents[0].moving = True
    env.agents[1].moving = True
    env.agents[0]._set_state(TrainState.MOVING)
    env.agents[1]._set_state(TrainState.MOVING)
    env.agents[0].position = (5, 6)
    env.agents[1].position = (3, 8)
    print("\n")
    print(env.agents[0])
    print(env.agents[1])

    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.render_env(show=True, show_observations=True)

    iteration = 0
    expected_positions = {
        0: {
            0: (5, 6),
            1: (3, 8)
        },
        # both can move
        1: {
            0: (4, 6),
            1: (3, 7)
        },
        # first can move, second stuck
        2: {
            0: (3, 6),
            1: (3, 7)
        },
        # both stuck from now on
        3: {
            0: (3, 6),
            1: (3, 7)
        },
        4: {
            0: (3, 6),
            1: (3, 7)
        },
        5: {
            0: (3, 6),
            1: (3, 7)
        },
    }
    while iteration < 5:
        rewards, dones = _step_along_shortest_path(env, obs_builder, rail)
        if dones["__all__"]:
            break
        for agent in env.agents:
            # assert rewards[agent.handle] == 0
            expected_position = expected_positions[iteration + 1][agent.handle]
            assert agent.position == expected_position, "[{}] agent {} at {}, expected {}".format(iteration + 1,
                                                                                                  agent.handle,
                                                                                                  agent.position,
                                                                                                  expected_position)
        if rendering:
            renderer.render_env(show=True, show_observations=True)

        iteration += 1


def test_reward_function_waiting(rendering=False):
    rail, rail_map, optionals = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(), number_of_agents=2,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                  remove_agents_at_target=False, random_seed=1)
    obs_builder: TreeObsForRailEnv = env.obs_builder
    env.reset()

    # set the initial position
    agent = env.agents[0]
    agent.initial_position = (3, 8)  # east dead-end
    agent.position = (3, 8)  # east dead-end
    agent.direction = 3  # west
    agent.initial_direction = 3  # west
    agent.target = (3, 1)  # west dead-end
    agent.moving = True
    agent._set_state(TrainState.MOVING)

    agent = env.agents[1]
    agent.initial_position = (5, 6)  # south dead-end
    agent.position = (5, 6)  # south dead-end
    agent.direction = 0  # north
    agent.initial_direction = 0  # north
    agent.target = (3, 8)  # east dead-end
    agent.moving = True
    agent._set_state(TrainState.MOVING)

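    # Reset without regenerating rail or schedule; agent setup is re-applied below.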
    env.reset(False, False)
    env.agents[0].moving = True
    env.agents[1].moving = True
    env.agents[0]._set_state(TrainState.MOVING)
    env.agents[1]._set_state(TrainState.MOVING)
    env.agents[0].position = (3, 8)
    env.agents[1].position = (5, 6)

    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.render_env(show=True, show_observations=True)

    iteration = 0
    expectations = {
        0: {
            'positions': {
                0: (3, 8),
                1: (5, 6),
            },
            'rewards': [0, 0],
        },
        1: {
            'positions': {
                0: (3, 7),
                1: (4, 6),
            },
            'rewards': [0, 0],
        },
        # second agent has to wait for first, first can continue
        2: {
            'positions': {
                0: (3, 6),
                1: (4, 6),
            },
            'rewards': [0, 0],
        },
        # both can move again
        3: {
            'positions': {
                0: (3, 5),
                1: (3, 6),
            },
            'rewards': [0, 0],
        },
        4: {
            'positions': {
                0: (3, 4),
                1: (3, 7),
            },
            'rewards': [0, 0],
        },
        # second reached target
        5: {
            'positions': {
                0: (3, 3),
                1: (3, 8),
            },
            'rewards': [0, 0],
        },
        6: {
            'positions': {
                0: (3, 2),
                1: (3, 8),
            },
            'rewards': [0, 0],
        },
        # first reached target, too
        7: {
            'positions': {
                0: (3, 1),
                1: (3, 8),
            },
            'rewards': [0, 0],
        },
        8: {
            'positions': {
                0: (3, 1),
                1: (3, 8),
            },
            'rewards': [0, 0],
        },
    }
    while iteration < 7:

        rewards, dones = _step_along_shortest_path(env, obs_builder, rail)
        if dones["__all__"]:
            break

        if rendering:
            renderer.render_env(show=True, show_observations=True)

        print(env.dones["__all__"])
        for agent in env.agents:
            print("[{}] agent {} at {}, target {} ".format(iteration + 1, agent.handle, agent.position, agent.target))
        print(np.all([np.array_equal(agent2.position, agent2.target) for agent2 in env.agents]))
        for agent in env.agents:
            expected_position = expectations[iteration + 1]['positions'][agent.handle]
            assert agent.position == expected_position, \
                "[{}] agent {} at {}, expected {}".format(iteration + 1,
                                                          agent.handle,
                                                          agent.position,
                                                          expected_position)
            # expected_reward = expectations[iteration + 1]['rewards'][agent.handle]
            # actual_reward = rewards[agent.handle]
            # assert expected_reward == actual_reward, "[{}] agent {} reward {}, expected {}".format(iteration + 1,
            #                                                                                        agent.handle,
            #                                                                                        actual_reward,
            #                                                                                        expected_reward)
        iteration += 1