env_observation_builder.py 20.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
"""
ObservationBuilder objects are objects that can be passed to environments designed for customizability.
The ObservationBuilder-derived custom classes implement 2 functions, reset() and get() or get(handle).

+ Reset() is called after each environment reset, to allow for pre-computing relevant data.

+ Get() is called whenever an observation has to be computed, potentially for each agent independently in
case of multi-agent environments.
"""

11
import numpy as np
12

13
14
from collections import deque

15
16

class ObservationBuilder:
    """
    ObservationBuilder base class.

    Derived classes must implement reset() and get(); see the module
    docstring for the calling contract.
    """

    def __init__(self):
        pass

    def _set_env(self, env):
        # Store a reference to the environment; called when the builder is
        # attached to an environment.
        self.env = env

    def reset(self):
        """
        Called after each environment reset.
        """
        raise NotImplementedError()

    def get(self, handle=0):
        """
        Called whenever an observation has to be computed for the `env' environment, possibly
        for each agent independently (agent id `handle').

        Parameters
        ----------
        handle : int (optional)
            Handle of the agent for which to compute the observation vector.

        Returns
        -------
        observation
            An observation structure, specific to the corresponding environment.
        """
        raise NotImplementedError()


class TreeObsForRailEnv(ObservationBuilder):
    """
    TreeObsForRailEnv object.

    This object returns observation vectors for agents in the RailEnv environment.
    The information is local to each agent and exploits the tree structure of the rail
    network to simplify the representation of the state of the environment for each agent.
    """

    def __init__(self, max_depth):
        # max_depth: maximum recursion depth of the observation tree
        # (number of branching levels explored from the agent's position).
        self.max_depth = max_depth

    def reset(self):
        # Per-agent distance map: distance_map[agent, row, col, orientation] is the
        # number of steps needed to reach that agent's target from (row, col) when
        # facing `orientation'; initialized to +inf (unreachable).
        self.distance_map = np.inf * np.ones(shape=(self.env.number_of_agents,
                                                    self.env.height,
                                                    self.env.width,
                                                    4))
        # max_dist[agent]: distance from the farthest reachable cell to the target.
        self.max_dist = np.zeros(self.env.number_of_agents)

        for i in range(self.env.number_of_agents):
            self.max_dist[i] = self._distance_map_walker(self.env.agents_target[i], i)

        # Update local lookup table for all agents' target locations
        self.location_has_target = {}
        for loc in self.env.agents_target:
            self.location_has_target[(loc[0], loc[1])] = 1

    def _distance_map_walker(self, position, target_nr):
        """
        Utility function to compute distance maps from each cell in the rail network (and each possible
        orientation within it) to each agent's target cell.

        Parameters
        ----------
        position : (row, col) target cell of agent `target_nr'.
        target_nr : int, index of the agent whose distance map is filled.
        """
        # Returns max distance to target, from the farthest away node, while filling in distance_map

        # Distance from the target to itself is 0, for every orientation.
        for ori in range(4):
            self.distance_map[target_nr, position[0], position[1], ori] = 0

        # Fill in the (up to) 4 neighboring nodes
        # nodes_queue = []  # list of tuples (row, col, direction, distance);
        # direction is the direction of movement, meaning that at least a possible orientation of an agent
        # in cell (row,col) allows a movement in direction `direction'
        nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1))

        # BFS from target `position' to all the reachable nodes in the grid
        # Stop the search if the target position is re-visited, in any direction
        visited = set([(position[0], position[1], 0),
                       (position[0], position[1], 1),
                       (position[0], position[1], 2),
                       (position[0], position[1], 3)])

        max_distance = 0

        while nodes_queue:
            node = nodes_queue.popleft()

            node_id = (node[0], node[1], node[2])

            if node_id not in visited:
                visited.add(node_id)

                # From the list of possible neighbors that have at least a path to the current node, only keep those
                # whose new orientation in the current cell would allow a transition to direction node[2]
                valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2])

                for n in valid_neighbors:
                    nodes_queue.append(n)

                if len(valid_neighbors) > 0:
                    # node[3] is this node's distance; its neighbors are one step farther.
                    max_distance = max(max_distance, node[3] + 1)

        return max_distance

    def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1):
        """
        Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the
        minimum distances from each target cell.

        Returns the list of neighbor nodes (row, col, orientation, distance) that can
        reach `position', updating distance_map with any improved distances found.
        """
        neighbors = []

        possible_directions = [0, 1, 2, 3]
        if enforce_target_direction >= 0:
            # The agent must land into the current cell with orientation `enforce_target_direction'.
            # This is only possible if the agent has arrived from the cell in the opposite direction!
            possible_directions = [(enforce_target_direction + 2) % 4]

        for neigh_direction in possible_directions:
            new_cell = self._new_position(position, neigh_direction)

            # Skip neighbors that fall outside the grid.
            if new_cell[0] >= 0 and new_cell[0] < self.env.height and \
                new_cell[1] >= 0 and new_cell[1] < self.env.width:

                # Moving from new_cell back into `position' means moving opposite
                # to the direction we used to reach new_cell.
                desired_movement_from_new_cell = (neigh_direction + 2) % 4

                """
                # Is the next cell a dead-end?
                isNextCellDeadEnd = False
                nbits = 0
                tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1]))
                while tmp > 0:
                    nbits += (tmp & 1)
                    tmp = tmp >> 1
                if nbits == 1:
                    # Dead-end!
                    isNextCellDeadEnd = True
                """

                # Check all possible transitions in new_cell
                for agent_orientation in range(4):
                    # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
                    isValid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
                                                           desired_movement_from_new_cell)

                    if isValid:
                        """
                        # TODO: check that it works with deadends! -- still bugged!
                        movement = desired_movement_from_new_cell
                        if isNextCellDeadEnd:
                            movement = (desired_movement_from_new_cell+2) % 4
                        """
                        # Keep the best (minimum) distance seen so far for this state.
                        new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation],
                                           current_distance + 1)
                        neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance))
                        self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance

        return neighbors

    def _new_position(self, position, movement):
        """
        Utility function that converts a compass movement over a 2D grid to new positions (r, c).
        """
        if movement == 0:  # NORTH
            return (position[0] - 1, position[1])
        elif movement == 1:  # EAST
            return (position[0], position[1] + 1)
        elif movement == 2:  # SOUTH
            return (position[0] + 1, position[1])
        elif movement == 3:  # WEST
            return (position[0], position[1] - 1)

    def get(self, handle):
        """
        Computes the current observation for agent `handle' in env

        The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible
        movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv).
        The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for
        the transitions. The order is:
            [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']

        Each branch data is organized as:
            [root node information] +
            [recursive branch data from 'left'] +
            [... from 'forward'] +
            [... from 'right] +
            [... from 'back']

        Finally, each node information is composed of 5 floating point values:

        #1: 0 (placeholder; always 0 in this implementation)

        #2: 1 if a target of another agent is detected between the previous node and the current one.

        #3: 1 if another agent is detected between the previous node and the current one.

        #4: distance of agent to the current branch node

        #5: minimum distance from node to the agent's target (when landing to the node following the corresponding
            branch).

        Missing/padding nodes are filled in with -inf (truncated).
        Missing values in present node are filled in with +inf (truncated).

        In case of the root node, the values are [0, 0, 0, 0, distance from agent to target].
        In case the target node is reached, the values are [0, 0, 0, 0, 0].
        """
        # Update local lookup table for all agents' positions
        self.location_has_agent = {}
        for loc in self.env.agents_position:
            self.location_has_agent[(loc[0], loc[1])] = 1

        position = self.env.agents_position[handle]
        orientation = self.env.agents_direction[handle]

        # Root node - current position
        observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]]
        root_observation = observation[:]

        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right, back], relative to the current orientation
        for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]:
            if self.env.rail.get_transition((position[0], position[1], orientation), branch_direction):
                new_cell = self._new_position(position, branch_direction)

                branch_observation = self._explore_branch(handle, new_cell, branch_direction, root_observation, 1)
                observation = observation + branch_observation
            else:
                # Branch not available: pad with -inf for the full subtree that would
                # have hung off this direction (sum of 4^i for i in [0, max_depth)).
                num_cells_to_fill_in = 0
                pow4 = 1
                for i in range(self.max_depth):
                    num_cells_to_fill_in += pow4
                    pow4 *= 4
                observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in

        return observation

    def _explore_branch(self, handle, position, direction, root_observation, depth):
        """
        Utility function to compute tree-based observations.

        Walks from `position' along `direction' until the next branching node
        (switch, target, dead-end or terminal cell), emits that node's 5 features,
        then recurses on the (up to) 4 outgoing branches up to self.max_depth.
        """
        # [Recursive branch opened]
        if depth >= self.max_depth + 1:
            return []

        # Continue along direction until next switch or
        # until no transitions are possible along the current direction (i.e., dead-ends)
        # We treat dead-ends as nodes, instead of going back, to avoid loops
        exploring = True
        last_isSwitch = False
        last_isDeadEnd = False
        last_isTerminal = False  # wrong cell OR cycle;  either way, we don't want the agent to land here
        last_isTarget = False

        # (row, col, direction) states seen on this walk; revisiting one means a cycle.
        visited = set()

        other_agent_encountered = False
        other_target_encountered = False
        num_steps = 1
        while exploring:
            # #############################
            # #############################
            # Modify here to compute any useful data required to build the end node's features. This code is called
            # for each cell visited between the previous branching node and the next switch / target / dead-end.

            if position in self.location_has_agent:
                other_agent_encountered = True

            if position in self.location_has_target:
                other_target_encountered = True

            # #############################
            # #############################

            if (position[0], position[1], direction) in visited:
                last_isTerminal = True
                break
            visited.add((position[0], position[1], direction))

            # If the target node is encountered, pick that as node. Also, no further branching is possible.
            if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]:
                last_isTarget = True
                break

            cell_transitions = self.env.rail.get_transitions((position[0], position[1], direction))
            num_transitions = 0
            for i in range(4):
                if cell_transitions[i]:
                    num_transitions += 1

            exploring = False
            if num_transitions == 1:
                # Check if dead-end, or if we can go forward along direction
                # Count set bits in the full (orientation-agnostic) transition word:
                # exactly one set bit means the cell is a dead-end.
                nbits = 0
                tmp = self.env.rail.get_transitions((position[0], position[1]))
                while tmp > 0:
                    nbits += (tmp & 1)
                    tmp = tmp >> 1
                if nbits == 1:
                    # Dead-end!
                    last_isDeadEnd = True

                if not last_isDeadEnd:
                    # Keep walking through the tree along `direction'
                    exploring = True
                    # TODO: Remove below calculation, this is computed already above and could be reused
                    for i in range(4):
                        if cell_transitions[i]:
                            position = self._new_position(position, i)
                            direction = i
                            num_steps += 1
                            break

            elif num_transitions > 0:
                # Switch detected
                last_isSwitch = True
                break

            elif num_transitions == 0:
                # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
                #print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0],
                #      position[1], direction)
                last_isTerminal = True
                break

        # `position' is either a terminal node or a switch

        observation = []

        # #############################
        # #############################
        # Modify here to append new / different features for each visited cell!

        if last_isTarget:
            observation = [0,
                           1 if other_target_encountered else 0,
                           1 if other_agent_encountered else 0,
                           root_observation[3] + num_steps,
                           0]

        elif last_isTerminal:
            observation = [0,
                           1 if other_target_encountered else 0,
                           1 if other_agent_encountered else 0,
                           np.inf,
                           np.inf]
        else:
            observation = [0,
                           1 if other_target_encountered else 0,
                           1 if other_agent_encountered else 0,
                           root_observation[3] + num_steps,
                           self.distance_map[handle, position[0], position[1], direction]]

        # #############################
        # #############################

        new_root_observation = observation[:]

        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right, back], relative to the current orientation
        for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]:
            if last_isDeadEnd and self.env.rail.get_transition((position[0], position[1], direction),
                                                               (branch_direction + 2) % 4):
                # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
                # it back
                new_cell = self._new_position(position, (branch_direction + 2) % 4)

                branch_observation = self._explore_branch(handle,
                                                          new_cell,
                                                          (branch_direction + 2) % 4,
                                                          new_root_observation,
                                                          depth + 1)
                observation = observation + branch_observation

            # NOTE(review): this checks the transition along (branch_direction + 2) % 4 but
            # then moves along branch_direction — asymmetric with the root-level check in
            # get(), which tests branch_direction directly. Verify intended.
            elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction),
                                                                (branch_direction + 2) % 4):
                new_cell = self._new_position(position, branch_direction)

                branch_observation = self._explore_branch(handle,
                                                          new_cell,
                                                          branch_direction,
                                                          new_root_observation,
                                                          depth + 1)
                observation = observation + branch_observation

            else:
                # Unavailable branch: pad with -inf for the subtree that would hang
                # off this direction at the remaining depth.
                num_cells_to_fill_in = 0
                pow4 = 1
                for i in range(self.max_depth - depth):
                    num_cells_to_fill_in += pow4
                    pow4 *= 4
                observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in

        return observation

    def util_print_obs_subtree(self, tree, num_features_per_node=5, prompt='', current_depth=0):
        """
        Utility function to pretty-print tree observations returned by this object.
        """
        if len(tree) < num_features_per_node:
            return

        # NOTE(review): `depth' is computed here but never used below — dead code?
        depth = 0
        tmp = len(tree) / num_features_per_node - 1
        pow4 = 4
        while tmp > 0:
            tmp -= pow4
            depth += 1
            pow4 *= 4

        # Branch labels in the same left/forward/right/back order used by get().
        prompt_ = ['L:', 'F:', 'R:', 'B:']

        print("  " * current_depth + prompt, tree[0:num_features_per_node])
        child_size = (len(tree) - num_features_per_node) // 4
        for children in range(4):
            child_tree = tree[(num_features_per_node + children * child_size):
                              (num_features_per_node + (children + 1) * child_size)]
            self.util_print_obs_subtree(child_tree,
                                        num_features_per_node,
                                        prompt=prompt_[children],
                                        current_depth=current_depth + 1)


class GlobalObsForRailEnv(ObservationBuilder):
    """
    Gives a global observation of the entire rail environment.

    The observation returned by get() is a 3-tuple:

        - transition map array with dimensions (env.height, env.width, 16),
          assuming 16 bits encoding of transitions;

        - a (4, env.height, env.width) array of agent/target layers:
          layer 0 marks the given agent's position, layer 1 its target,
          layer 2 the other agents' targets, layer 3 the other agents' positions;

        - a 4-element one-hot encoding of the agent's direction.
    """

    def __init__(self):
        super(GlobalObsForRailEnv, self).__init__()

    def reset(self):
        """
        Pre-render the rail transition map as a (height, width, 16) array,
        one binary channel per transition bit.
        """
        height = self.env.height
        width = self.env.width
        self.rail_obs = np.zeros((height, width, 16))
        for row in range(height):
            for col in range(width):
                bit_string = f'{self.env.rail.get_transitions((row, col)):016b}'
                self.rail_obs[row, col] = np.array(list(bit_string)).astype(int)

    def get(self, handle):
        """
        Compute the global observation for agent `handle'.

        Returns (rail_obs, layers, direction_one_hot) as described in the
        class docstring.
        """
        layers = np.zeros((4, self.env.height, self.env.width))

        # This agent's position (layer 0) and its target (layer 1).
        layers[0][self.env.agents_position[handle]] += 1
        layers[1][self.env.agents_target[handle]] += 1

        # Everyone else: their targets on layer 2, their positions on layer 3.
        for other, pos in enumerate(self.env.agents_position):
            if other != handle:
                layers[3][pos] += 1
        for other, target in enumerate(self.env.agents_target):
            if other != handle:
                layers[2][target] += 1

        direction_one_hot = np.zeros(4)
        direction_one_hot[self.env.agents_direction[handle]] = 1

        return self.rail_obs, layers, direction_one_hot