# env_observation_builder.py
"""
ObservationBuilder objects are objects that can be passed to environments designed for customizability.
The ObservationBuilder-derived custom classes implement 2 functions, reset() and get() or get(handle).

+ reset() is called after each environment reset, to allow for pre-computing relevant data.

+ get() is called whenever an observation has to be computed, potentially for each agent independently in
case of multi-agent environments.
"""

import numpy as np

from collections import deque

class ObservationBuilder:
    """
    ObservationBuilder base class.

    Derived classes must implement reset() and get(); the environment is
    attached through _set_env() before either is called.
    """
    def __init__(self):
        # The environment is attached later through _set_env().
        self.env = None

    def _set_env(self, env):
        # Attach the environment this builder computes observations for.
        self.env = env

    def reset(self):
        """
        Called after each environment reset, to allow for pre-computing relevant data.
        """
        raise NotImplementedError()

    def get(self, handle=0):
        """
        Called whenever an observation has to be computed for the `env' environment, possibly
        for each agent independently (agent id `handle').

        Parameters
        ----------
        handle : int (optional)
            Handle of the agent for which to compute the observation vector.

        Returns
        -------
        observation
            An observation structure, specific to the corresponding environment.
        """
        raise NotImplementedError()


class TreeObsForRailEnv(ObservationBuilder):
    """
    TreeObsForRailEnv object.

    This object returns observation vectors for agents in the RailEnv environment.
    The information is local to each agent and exploits the tree structure of the rail
    network to simplify the representation of the state of the environment for each agent.
    """
    def __init__(self, max_depth):
        super(TreeObsForRailEnv, self).__init__()
        # Maximum number of branching levels explored when building the tree observation.
        self.max_depth = max_depth

    def reset(self):
        """
        Precompute, for every agent, the distance map from each (cell, orientation) pair
        to that agent's target, and cache all target locations for fast lookup.
        """
        # distance_map[a, r, c, o] = minimum number of steps for agent a to reach its
        # target from cell (r, c) with orientation o (np.inf where unreachable).
        self.distance_map = np.inf * np.ones(shape=(self.env.number_of_agents,
                                                    self.env.height,
                                                    self.env.width,
                                                    4))
        self.max_dist = np.zeros(self.env.number_of_agents)

        for i in range(self.env.number_of_agents):
            self.max_dist[i] = self._distance_map_walker(self.env.agents_target[i], i)

        # Update local lookup table for all agents' target locations
        self.location_has_target = {}
        for loc in self.env.agents_target:
            self.location_has_target[(loc[0], loc[1])] = 1

    def _distance_map_walker(self, position, target_nr):
        """
        Utility function to compute distance maps from each cell in the rail network (and each
        possible orientation within it) to each agent's target cell.

        Returns the maximum distance to the target from the farthest reachable node, while
        filling in self.distance_map for agent `target_nr'.
        """
        # The target cell itself is at distance 0 for every orientation.
        for ori in range(4):
            self.distance_map[target_nr, position[0], position[1], ori] = 0

        # Fill in the (up to) 4 neighboring nodes.
        # Queue entries are tuples (row, col, direction, distance); `direction' is the
        # direction of movement, meaning that at least one possible orientation of an agent
        # in cell (row, col) allows a movement in direction `direction'.
        nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1))

        # BFS from target `position' to all the reachable nodes in the grid.
        # Stop the search if the target position is re-visited, in any direction.
        visited = set([(position[0], position[1], 0),
                       (position[0], position[1], 1),
                       (position[0], position[1], 2),
                       (position[0], position[1], 3)])

        max_distance = 0

        while nodes_queue:
            node = nodes_queue.popleft()

            node_id = (node[0], node[1], node[2])

            if node_id not in visited:
                visited.add(node_id)

                # From the list of possible neighbors that have at least a path to the current
                # node, only keep those whose new orientation in the current cell would allow
                # a transition to direction node[2].
                valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2])

                for n in valid_neighbors:
                    nodes_queue.append(n)

                if len(valid_neighbors) > 0:
                    max_distance = max(max_distance, node[3] + 1)

        return max_distance

    def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1):
        """
        Utility function used by _distance_map_walker to perform a BFS walk over the rail,
        filling in the minimum distances from each target cell.

        Returns the list of valid neighbor nodes as (row, col, orientation, distance) tuples.
        """
        neighbors = []

        possible_directions = [0, 1, 2, 3]
        if enforce_target_direction >= 0:
            # The agent must land into the current cell with orientation `enforce_target_direction'.
            # This is only possible if the agent has arrived from the cell in the opposite direction!
            possible_directions = [(enforce_target_direction + 2) % 4]

        for neigh_direction in possible_directions:
            new_cell = self._new_position(position, neigh_direction)

            if new_cell[0] >= 0 and new_cell[0] < self.env.height and \
               new_cell[1] >= 0 and new_cell[1] < self.env.width:

                desired_movement_from_new_cell = (neigh_direction + 2) % 4

                # TODO: dead-end cells are not handled specially here; the original note
                # said "check that it works with deadends! -- still bugged!".

                # Check all possible transitions in new_cell
                for agent_orientation in range(4):
                    # Is a transition along movement `desired_movement_from_new_cell' to the
                    # current cell possible?
                    isValid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
                                                           desired_movement_from_new_cell)

                    if isValid:
                        new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation],
                                           current_distance + 1)
                        neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance))
                        self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance

        return neighbors

    def _new_position(self, position, movement):
        """
        Utility function that converts a compass movement over a 2D grid to new positions (r, c).
        """
        if movement == 0:    # NORTH
            return (position[0] - 1, position[1])
        elif movement == 1:  # EAST
            return (position[0], position[1] + 1)
        elif movement == 2:  # SOUTH
            return (position[0] + 1, position[1])
        elif movement == 3:  # WEST
            return (position[0], position[1] - 1)

    def _num_nodes_in_subtree(self, levels):
        """
        Number of nodes in a complete 4-ary tree with `levels' levels:
        1 + 4 + ... + 4^(levels-1). Used to size -inf padding for pruned branches.
        """
        return sum(4 ** i for i in range(levels))

    def get(self, handle):
        """
        Computes the current observation for agent `handle' in env.

        The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible
        movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv).
        The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for
        the transitions. The order is:
            [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']

        Each branch data is organized as:
            [root node information] +
            [recursive branch data from 'left'] +
            [... from 'forward'] +
            [... from 'right] +
            [... from 'back']

        Finally, each node information is composed of 5 floating point values:

        #1: reserved; currently always 0 in this implementation.

        #2: 1 if a target of another agent is detected between the previous node and the current one.

        #3: 1 if another agent is detected between the previous node and the current one.

        #4: distance of agent to the current branch node.

        #5: minimum distance from node to the agent's target (when landing to the node following the corresponding
            branch).

        Missing/padding nodes are filled in with -inf (truncated).
        Missing values in present node are filled in with +inf (truncated).

        In case of the root node, the values are [0, 0, 0, 0, distance from agent to target].
        In case the target node is reached, the values are [0, 0, 0, 0, 0].
        """
        # Update local lookup table for all agents' positions
        self.location_has_agent = {}
        for loc in self.env.agents_position:
            self.location_has_agent[(loc[0], loc[1])] = 1

        position = self.env.agents_position[handle]
        orientation = self.env.agents_direction[handle]

        # Root node - current position
        observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]]
        root_observation = observation[:]

        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right, back], relative to the current orientation
        for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]:
            if self.env.rail.get_transition((position[0], position[1], orientation), branch_direction):
                new_cell = self._new_position(position, branch_direction)

                branch_observation = self._explore_branch(handle, new_cell, branch_direction, root_observation, 1)
                observation = observation + branch_observation
            else:
                # Branch not available: pad the whole (pruned) subtree with -inf node features.
                observation = observation + [-np.inf] * 5 * self._num_nodes_in_subtree(self.max_depth)

        return observation

    def _explore_branch(self, handle, position, direction, root_observation, depth):
        """
        Utility function to compute tree-based observations (recursive branch exploration).
        """
        # [Recursive branch opened]
        if depth >= self.max_depth + 1:
            return []

        # Continue along direction until next switch or
        # until no transitions are possible along the current direction (i.e., dead-ends)
        # We treat dead-ends as nodes, instead of going back, to avoid loops
        exploring = True
        last_isSwitch = False
        last_isDeadEnd = False
        last_isTerminal = False  # wrong cell OR cycle;  either way, we don't want the agent to land here
        last_isTarget = False

        visited = set()

        other_agent_encountered = False
        other_target_encountered = False
        num_steps = 1
        while exploring:
            # #############################
            # #############################
            # Modify here to compute any useful data required to build the end node's features. This code is called
            # for each cell visited between the previous branching node and the next switch / target / dead-end.

            if position in self.location_has_agent:
                other_agent_encountered = True

            if position in self.location_has_target:
                other_target_encountered = True

            # #############################
            # #############################

            # Cycle detection: stop if this (cell, direction) was already walked in this branch.
            if (position[0], position[1], direction) in visited:
                last_isTerminal = True
                break
            visited.add((position[0], position[1], direction))

            # If the target node is encountered, pick that as node. Also, no further branching is possible.
            if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]:
                last_isTarget = True
                break

            cell_transitions = self.env.rail.get_transitions((position[0], position[1], direction))
            num_transitions = 0
            for i in range(4):
                if cell_transitions[i]:
                    num_transitions += 1

            exploring = False
            if num_transitions == 1:
                # Check if dead-end, or if we can go forward along direction
                nbits = 0
                tmp = self.env.rail.get_transitions((position[0], position[1]))
                while tmp > 0:
                    nbits += (tmp & 1)
                    tmp = tmp >> 1
                if nbits == 1:
                    # Dead-end!
                    last_isDeadEnd = True

                if not last_isDeadEnd:
                    # Keep walking through the tree along `direction'
                    exploring = True
                    # TODO: Remove below calculation, this is computed already above and could be reused
                    for i in range(4):
                        if cell_transitions[i]:
                            position = self._new_position(position, i)
                            direction = i
                            num_steps += 1
                            break

            elif num_transitions > 0:
                # Switch detected
                last_isSwitch = True
                break

            elif num_transitions == 0:
                # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
                print("WRONG CELL TYPE detected in tree-search (0 transitions possible)")
                last_isTerminal = True
                break

        # `position' is either a terminal node or a switch

        observation = []

        # #############################
        # #############################
        # Modify here to append new / different features for each visited cell!

        if last_isTarget:
            observation = [0,
                           1 if other_target_encountered else 0,
                           1 if other_agent_encountered else 0,
                           root_observation[3] + num_steps,
                           0]
        elif last_isTerminal:
            observation = [0,
                           1 if other_target_encountered else 0,
                           1 if other_agent_encountered else 0,
                           np.inf,
                           np.inf]
        else:
            observation = [0,
                           1 if other_target_encountered else 0,
                           1 if other_agent_encountered else 0,
                           root_observation[3] + num_steps,
                           self.distance_map[handle, position[0], position[1], direction]]

        # #############################
        # #############################

        new_root_observation = observation[:]

        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right, back], relative to the current orientation
        for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]:
            if last_isDeadEnd and self.env.rail.get_transition((position[0], position[1], direction),
                                                               (branch_direction + 2) % 4):
                # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
                # it back
                new_cell = self._new_position(position, (branch_direction + 2) % 4)

                branch_observation = self._explore_branch(handle,
                                                          new_cell,
                                                          (branch_direction + 2) % 4,
                                                          new_root_observation,
                                                          depth + 1)
                observation = observation + branch_observation

            elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction),
                                                                branch_direction):
                new_cell = self._new_position(position, branch_direction)

                branch_observation = self._explore_branch(handle,
                                                          new_cell,
                                                          branch_direction,
                                                          new_root_observation,
                                                          depth + 1)
                observation = observation + branch_observation

            else:
                # Branch not available: pad the subtree rooted here with -inf node features.
                observation = observation + [-np.inf] * 5 * self._num_nodes_in_subtree(self.max_depth - depth)

        return observation

    def util_print_obs_subtree(self, tree, num_features_per_node=5, prompt='', current_depth=0):
        """
        Utility function to pretty-print tree observations returned by this object.
        """
        if len(tree) < num_features_per_node:
            return

        prompt_ = ['L:', 'F:', 'R:', 'B:']

        print("  " * current_depth + prompt, tree[0:num_features_per_node])
        # Split the remainder evenly into the 4 child subtrees and recurse.
        child_size = (len(tree) - num_features_per_node) // 4
        for children in range(4):
            child_tree = tree[(num_features_per_node + children * child_size):
                              (num_features_per_node + (children + 1) * child_size)]
            self.util_print_obs_subtree(child_tree,
                                        num_features_per_node,
                                        prompt=prompt_[children],
                                        current_depth=current_depth + 1)

class GlobalObsForRailEnv(ObservationBuilder):
    """
    Gives a global observation of the entire rail environment.
    The observation is composed of the following elements:

        - transition map array with dimensions (env.height, env.width, 16),
          assuming 16 bits encoding of transitions.

        - A (4, env.height, env.width) array containing respectively:
            channel 0: the position of the given agent,
            channel 1: the position of its target,
            channel 2: the positions of the other agents' targets,
            channel 3: the positions of the other agents.

        - A 4 elements array with one-hot encoding of the agent's direction.
    """
    def __init__(self):
        super(GlobalObsForRailEnv, self).__init__()

    def reset(self):
        """
        Precompute the per-cell transition observation for the whole grid.
        """
        self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
        for i in range(self.rail_obs.shape[0]):
            for j in range(self.rail_obs.shape[1]):
                # Expand the cell's 16-bit transitions bitmask into a 0/1 vector.
                self.rail_obs[i, j] = np.array(
                    list(f'{self.env.rail.get_transitions((i, j)):016b}')).astype(int)

        # self.targets = np.zeros(self.env.height, self.env.width)
        # for target_pos in self.env.agents_target:
        #     self.targets[target_pos] += 1

    def get(self, handle):
        """
        Returns the global observation for agent `handle' as a tuple
        (rail_obs, obs_agents_targets_pos, direction); see the class docstring
        for the meaning of each element.
        """
        obs_agents_targets_pos = np.zeros((4, self.env.height, self.env.width))

        # Channel 0: this agent's position; channel 3: every other agent's position.
        agent_pos = self.env.agents_position[handle]
        obs_agents_targets_pos[0][agent_pos] += 1
        for i in range(len(self.env.agents_position)):
            if i != handle:
                obs_agents_targets_pos[3][self.env.agents_position[i]] += 1

        # Channel 1: this agent's target; channel 2: every other agent's target.
        agent_target_pos = self.env.agents_target[handle]
        obs_agents_targets_pos[1][agent_target_pos] += 1
        for i in range(len(self.env.agents_target)):
            if i != handle:
                obs_agents_targets_pos[2][self.env.agents_target[i]] += 1

        # One-hot encoding of the agent's current direction.
        direction = np.zeros(4)
        direction[self.env.agents_direction[handle]] = 1

        return self.rail_obs, obs_agents_targets_pos, direction