test_flatland_envs_observations.py 11.9 KB
Newer Older
1
2
3
#!/usr/bin/env python
# -*- coding: utf-8 -*-

spiglerg's avatar
spiglerg committed
4
5
import numpy as np

u214892's avatar
u214892 committed
6
from flatland.core.grid.grid4 import Grid4TransitionsEnum
7
from flatland.core.grid.grid4_utils import get_new_position
u214892's avatar
u214892 committed
8
from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
u214892's avatar
u214892 committed
9
10
11
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
u214892's avatar
u214892 committed
12
from flatland.envs.rail_generators import rail_from_grid_transition_map
13
from flatland.envs.line_generators import random_line_generator
u214892's avatar
u214892 committed
14
from flatland.utils.rendertools import RenderTool
15
from flatland.utils.simple_rail import make_simple_rail
16
17
18
19
20

"""Tests for `flatland` package."""


def test_global_obs():
    """Global observation sanity check on a fixed, hand-built rail.

    Verifies that channel 0 of ``GlobalObsForRailEnv`` is the 16-bit one-hot
    encoding of the grid's transition map (decodable back to ``rail_map``)
    and that channel 1 marks the agent on a non-empty rail cell.
    """
    rail, rail_map = make_simple_rail()

    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
                  line_generator=random_line_generator(), number_of_agents=1,
                  obs_builder_object=GlobalObsForRailEnv())

    global_obs, info = env.reset()

    # we have to take step for the agent to enter the grid.
    global_obs, _, _, _ = env.step({0: RailEnvActions.MOVE_FORWARD})

    # One 16-bit transition vector per grid cell.
    assert (global_obs[0][0].shape == rail_map.shape + (16,))

    rail_map_recons = np.zeros_like(rail_map)
    for i in range(global_obs[0][0].shape[0]):
        for j in range(global_obs[0][0].shape[1]):
            # Re-pack the 16 one-hot bits of the cell into a single integer
            # (most significant bit first), mirroring the uint16 rail encoding.
            rail_map_recons[i, j] = int(
                ''.join(global_obs[0][0][i, j].astype(int).astype(str)), 2)

    # BUG FIX: the original `rail_map_recons.all() == rail_map.all()` reduced
    # each array to a single scalar bool before comparing, so the assertion
    # held even for a wrong reconstruction. Compare the grids element-wise.
    assert np.array_equal(rail_map_recons, rail_map)

    # If this assertion is wrong, it means that the observation returned
    # places the agent on an empty cell
    obs_agents_state = global_obs[0][1]
    obs_agents_state = obs_agents_state + 1
    assert (np.sum(rail_map * obs_agents_state[:, :, :4].sum(2)) > 0)
u214892's avatar
u214892 committed
47
48
49
50
51
52
53
54
55
56


def _step_along_shortest_path(env, obs_builder, rail):
    """Advance every agent one step greedily along its shortest path.

    For each agent, all four exit directions are examined; the reachable
    neighbour cell with the smallest distance-to-target wins.  The winning
    exit direction is translated into a rail action relative to the agent's
    current heading, and the environment is stepped once with all chosen
    actions.  Returns the per-agent rewards from that step.
    """
    # Maps (exit_direction - agent.direction) % 4 onto the action realising it.
    relative_direction_to_action = {
        Grid4TransitionsEnum.NORTH: RailEnvActions.MOVE_FORWARD,
        Grid4TransitionsEnum.EAST: RailEnvActions.MOVE_RIGHT,
        Grid4TransitionsEnum.WEST: RailEnvActions.MOVE_LEFT,
        Grid4TransitionsEnum.SOUTH: RailEnvActions.DO_NOTHING,
    }
    num_directions = len(rail.transitions.get_direction_enum())

    actions = {}
    expected_next_position = {}
    for agent in env.agents:
        agent: EnvAgent
        best_distance = np.inf

        for exit_direction in range(4):
            neighbour = get_new_position(agent.position, exit_direction)
            row, col = neighbour

            # Skip neighbours that fall outside the grid.
            if not (0 <= row < env.height and 0 <= col < env.width):
                continue
            desired_movement_from_new_cell = (exit_direction + 2) % 4

            # Check all possible transitions in new_cell
            for agent_orientation in range(4):
                # Is a transition along movement `entry_direction` to the neighbour possible?
                if not obs_builder.env.rail.get_transition((row, col, agent_orientation),
                                                           desired_movement_from_new_cell):
                    continue

                distance_to_target = obs_builder.env.distance_map.get()[
                    (agent.handle, *agent.position, exit_direction)]
                print("agent {} at {} facing {} taking {} distance {}".format(agent.handle, agent.position,
                                                                              agent.direction,
                                                                              exit_direction,
                                                                              distance_to_target))

                if distance_to_target < best_distance:
                    best_distance = distance_to_target
                    print("   improved (direction) -> {}".format(exit_direction))

                    actions[agent.handle] = relative_direction_to_action[
                        (exit_direction - agent.direction) % num_directions]
                    expected_next_position[agent.handle] = neighbour
                    print("   improved (action) -> {}".format(actions[agent.handle]))

    _, rewards, _, _ = env.step(actions)
    return rewards


def test_reward_function_conflict(rendering=False):
    """Head-on conflict: two agents approach each other on the same track.

    Both agents keep receiving the per-step penalty (-1) while they move and
    also after they deadlock; their cells must match ``expected_positions``
    at every iteration (agent 0 advances two cells, agent 1 one cell, then
    both are stuck from iteration 2/3 onwards).
    """
    rail, rail_map = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
                  line_generator=random_line_generator(), number_of_agents=2,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
    obs_builder: TreeObsForRailEnv = env.obs_builder
    env.reset()

    # set the initial position
    agent = env.agents[0]
    agent.position = (5, 6)  # south dead-end
    agent.initial_position = (5, 6)  # south dead-end
    agent.direction = 0  # north
    agent.initial_direction = 0  # north
    agent.target = (3, 9)  # east dead-end
    agent.moving = True
    agent.status = RailAgentStatus.ACTIVE

    agent = env.agents[1]
    agent.position = (3, 8)  # east dead-end
    agent.initial_position = (3, 8)  # east dead-end
    agent.direction = 3  # west
    agent.initial_direction = 3  # west
    agent.target = (6, 6)  # south dead-end
    agent.moving = True
    agent.status = RailAgentStatus.ACTIVE

    # Reset without regenerating rail or schedule, then re-apply the
    # hand-crafted agent state (reset() re-initialises the agents).
    env.reset(False, False)
    env.agents[0].moving = True
    env.agents[1].moving = True
    env.agents[0].status = RailAgentStatus.ACTIVE
    env.agents[1].status = RailAgentStatus.ACTIVE
    env.agents[0].position = (5, 6)
    env.agents[1].position = (3, 8)
    print("\n")
    print(env.agents[0])
    print(env.agents[1])

    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.render_env(show=True, show_observations=True)

    iteration = 0
    # Expected agent positions per iteration: key 0 is the initial state,
    # keys 1..5 are checked after each greedy shortest-path step.
    expected_positions = {
        0: {
            0: (5, 6),
            1: (3, 8)
        },
        # both can move
        1: {
            0: (4, 6),
            1: (3, 7)
        },
        # first can move, second stuck
        2: {
            0: (3, 6),
            1: (3, 7)
        },
        # both stuck from now on
        3: {
            0: (3, 6),
            1: (3, 7)
        },
        4: {
            0: (3, 6),
            1: (3, 7)
        },
        5: {
            0: (3, 6),
            1: (3, 7)
        },
    }
    while iteration < 5:
        rewards = _step_along_shortest_path(env, obs_builder, rail)

        for agent in env.agents:
            # Conflicting agents never reach their target: always the step penalty.
            assert rewards[agent.handle] == -1
            expected_position = expected_positions[iteration + 1][agent.handle]
            assert agent.position == expected_position, "[{}] agent {} at {}, expected {}".format(iteration + 1,
                                                                                                  agent.handle,
                                                                                                  agent.position,
                                                                                                  expected_position)
        if rendering:
            renderer.render_env(show=True, show_observations=True)

        iteration += 1

u214892's avatar
u214892 committed
180
181
182

def test_reward_function_waiting(rendering=False):
    """One agent must wait for the other, then both reach their targets.

    Agent 1 has to yield at (4, 6) for one step while agent 0 passes
    (iteration 2); afterwards both proceed, agent 1 arriving first.  A done
    agent's reward switches from -1 to 0, and once all agents are done both
    receive +1 (agents stay on grid: ``remove_agents_at_target=False``).
    """
    rail, rail_map = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
                  line_generator=random_line_generator(), number_of_agents=2,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                  remove_agents_at_target=False)
    obs_builder: TreeObsForRailEnv = env.obs_builder
    env.reset()

    # set the initial position
    agent = env.agents[0]
    agent.initial_position = (3, 8)  # east dead-end
    agent.position = (3, 8)  # east dead-end
    agent.direction = 3  # west
    agent.initial_direction = 3  # west
    agent.target = (3, 1)  # west dead-end
    agent.moving = True
    agent.status = RailAgentStatus.ACTIVE

    agent = env.agents[1]
    agent.initial_position = (5, 6)  # south dead-end
    agent.position = (5, 6)  # south dead-end
    agent.direction = 0  # north
    agent.initial_direction = 0  # north
    agent.target = (3, 8)  # east dead-end
    agent.moving = True
    agent.status = RailAgentStatus.ACTIVE

    # Reset without regenerating rail or schedule, then re-apply the
    # hand-crafted agent state (reset() re-initialises the agents).
    env.reset(False, False)
    env.agents[0].moving = True
    env.agents[1].moving = True
    env.agents[0].status = RailAgentStatus.ACTIVE
    env.agents[1].status = RailAgentStatus.ACTIVE
    env.agents[0].position = (3, 8)
    env.agents[1].position = (5, 6)

    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.render_env(show=True, show_observations=True)

    iteration = 0
    # Expected per-iteration positions and rewards; key 0 is the initial
    # state, keys 1..7 are checked after each greedy shortest-path step
    # (entry 8 is unused by the 7-iteration loop below).
    expectations = {
        0: {
            'positions': {
                0: (3, 8),
                1: (5, 6),
            },
            'rewards': [-1, -1],
        },
        1: {
            'positions': {
                0: (3, 7),
                1: (4, 6),
            },
            'rewards': [-1, -1],
        },
        # second agent has to wait for first, first can continue
        2: {
            'positions': {
                0: (3, 6),
                1: (4, 6),
            },
            'rewards': [-1, -1],
        },
        # both can move again
        3: {
            'positions': {
                0: (3, 5),
                1: (3, 6),
            },
            'rewards': [-1, -1],
        },
        4: {
            'positions': {
                0: (3, 4),
                1: (3, 7),
            },
            'rewards': [-1, -1],
        },
        # second reached target
        5: {
            'positions': {
                0: (3, 3),
                1: (3, 8),
            },
            'rewards': [-1, 0],
        },
        6: {
            'positions': {
                0: (3, 2),
                1: (3, 8),
            },
            'rewards': [-1, 0],
        },
        # first reaches, target too
        7: {
            'positions': {
                0: (3, 1),
                1: (3, 8),
            },
            'rewards': [1, 1],
        },
        8: {
            'positions': {
                0: (3, 1),
                1: (3, 8),
            },
            'rewards': [1, 1],
        },
    }
    while iteration < 7:

        rewards = _step_along_shortest_path(env, obs_builder, rail)

        if rendering:
            renderer.render_env(show=True, show_observations=True)

        print(env.dones["__all__"])
        for agent in env.agents:
            agent: EnvAgent
            print("[{}] agent {} at {}, target {} ".format(iteration + 1, agent.handle, agent.position, agent.target))
        print(np.all([np.array_equal(agent2.position, agent2.target) for agent2 in env.agents]))
        for agent in env.agents:
            expected_position = expectations[iteration + 1]['positions'][agent.handle]
            assert agent.position == expected_position, \
                "[{}] agent {} at {}, expected {}".format(iteration + 1,
                                                          agent.handle,
                                                          agent.position,
                                                          expected_position)
            expected_reward = expectations[iteration + 1]['rewards'][agent.handle]
            actual_reward = rewards[agent.handle]
            assert expected_reward == actual_reward, "[{}] agent {} reward {}, expected {}".format(iteration + 1,
                                                                                                   agent.handle,
                                                                                                   actual_reward,
                                                                                                   expected_reward)
        iteration += 1