env_observation_builder.py 20.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
"""
ObservationBuilder objects are objects that can be passed to environments designed for customizability.
The ObservationBuilder-derived custom classes implement 2 functions, reset() and get() or get(handle).

+ Reset() is called after each environment reset, to allow for pre-computing relevant data.

+ Get() is called whenever an observation has to be computed, potentially for each agent independently in
case of multi-agent environments.
"""

11
import numpy as np
12

13
14
from collections import deque

15
16

class ObservationBuilder:
    """Abstract base class for observation builders.

    Concrete builders override :meth:`reset` and :meth:`get`; the environment
    injects itself via :meth:`_set_env` before use.
    """

    def __init__(self):
        pass

    def _set_env(self, env):
        # Keep a reference to the environment so subclasses can query it.
        self.env = env

    def reset(self):
        """
        Called after each environment reset.

        Subclasses use this hook to pre-compute any data the observations need.
        """
        raise NotImplementedError()

    def get(self, handle=0):
        """
        Called whenever an observation has to be computed for the `env' environment, possibly
        for each agent independently (agent id `handle').

        Parameters
        -------
        handle : int (optional)
            Handle of the agent for which to compute the observation vector.

        Returns
        -------
        function
            An observation structure, specific to the corresponding environment.
        """
        raise NotImplementedError()


class TreeObsForRailEnv(ObservationBuilder):
    """
    TreeObsForRailEnv object.

    This object returns observation vectors for agents in the RailEnv environment.
    The information is local to each agent and exploits the tree structure of the rail
    network to simplify the representation of the state of the environment for each agent.
    """

    def __init__(self, max_depth):
        # Maximum depth of the observation tree (number of branching levels explored).
        self.max_depth = max_depth

    def reset(self):
        """
        Precompute, for every agent, the distance from each (row, col, orientation)
        grid cell to that agent's target, and cache all target locations.
        """
        # distance_map[a, r, c, o] = minimum number of cells for agent a to reach
        # its target starting from cell (r, c) with orientation o; inf if unreachable.
        self.distance_map = np.inf * np.ones(shape=(self.env.number_of_agents,
                                                    self.env.height,
                                                    self.env.width,
                                                    4))
        self.max_dist = np.zeros(self.env.number_of_agents)

        for i in range(self.env.number_of_agents):
            self.max_dist[i] = self._distance_map_walker(self.env.agents_target[i], i)

        # Update local lookup table for all agents' target locations
        self.location_has_target = {}
        for loc in self.env.agents_target:
            self.location_has_target[(loc[0], loc[1])] = 1

    def _distance_map_walker(self, position, target_nr):
        """
        Utility function to compute distance maps from each cell in the rail network (and each possible
        orientation within it) to each agent's target cell.

        Runs a BFS backwards from the target `position', filling in
        self.distance_map[target_nr] as a side effect, and returns the max
        distance to the target from the farthest-away node.
        """
        # The target cell itself is at distance 0 for every orientation.
        for ori in range(4):
            self.distance_map[target_nr, position[0], position[1], ori] = 0

        # Seed the queue with the (up to) 4 neighboring nodes.
        # Queue entries are tuples (row, col, direction, distance);
        # direction is the direction of movement, meaning that at least one possible orientation of an agent
        # in cell (row,col) allows a movement in direction `direction'
        nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1))

        # BFS from target `position' to all the reachable nodes in the grid.
        # Stop the search if the target position is re-visited, in any direction.
        visited = set([(position[0], position[1], 0),
                       (position[0], position[1], 1),
                       (position[0], position[1], 2),
                       (position[0], position[1], 3)])

        max_distance = 0

        while nodes_queue:
            node = nodes_queue.popleft()  # (row, col, direction, distance)

            node_id = (node[0], node[1], node[2])

            if node_id not in visited:
                visited.add(node_id)

                # From the list of possible neighbors that have at least a path to the current node, only keep those
                # whose new orientation in the current cell would allow a transition to direction node[2]
                valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2])

                for n in valid_neighbors:
                    nodes_queue.append(n)

                if len(valid_neighbors) > 0:
                    max_distance = max(max_distance, node[3] + 1)

        return max_distance

    def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1):
        """
        Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the
        minimum distances from each target cell.

        Returns a list of neighbor nodes (row, col, orientation, distance) from
        which the current cell can be entered, updating self.distance_map for
        each of them along the way.
        """
        neighbors = []

        possible_directions = [0, 1, 2, 3]
        if enforce_target_direction >= 0:
            # The agent must land into the current cell with orientation `enforce_target_direction'.
            # This is only possible if the agent has arrived from the cell in the opposite direction!
            possible_directions = [(enforce_target_direction + 2) % 4]

        for neigh_direction in possible_directions:
            new_cell = self._new_position(position, neigh_direction)

            if new_cell[0] >= 0 and new_cell[0] < self.env.height and \
                new_cell[1] >= 0 and new_cell[1] < self.env.width:

                # Moving from new_cell back into the current cell requires the
                # opposite compass movement.
                desired_movement_from_new_cell = (neigh_direction + 2) % 4

                # TODO: dead-end cells (single-transition cells) are not treated
                # specially here; a previous draft detected them via the popcount of
                # get_transitions((row, col)) but was reported "still bugged".

                # Check all possible transitions in new_cell
                for agent_orientation in range(4):
                    # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
                    isValid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
                                                           desired_movement_from_new_cell)

                    if isValid:
                        new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation],
                                           current_distance + 1)
                        neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance))
                        self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance

        return neighbors

    def _new_position(self, position, movement):
        """
        Utility function that converts a compass movement over a 2D grid to new positions (r, c).

        Movement encoding: 0=NORTH, 1=EAST, 2=SOUTH, 3=WEST.
        """
        if movement == 0:  # NORTH
            return (position[0] - 1, position[1])
        elif movement == 1:  # EAST
            return (position[0], position[1] + 1)
        elif movement == 2:  # SOUTH
            return (position[0] + 1, position[1])
        elif movement == 3:  # WEST
            return (position[0], position[1] - 1)

    def get(self, handle):
        """
        Computes the current observation for agent `handle' in env.

        The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible
        movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv).
        The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for
        the transitions. The order is:
            [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']

        Each branch data is organized as:
            [root node information] +
            [recursive branch data from 'left'] +
            [... from 'forward'] +
            [... from 'right] +
            [... from 'back']

        Finally, each node information is composed of 5 floating point values:

        #1: currently always 0 for every emitted node -- reserved; TODO confirm intended meaning.

        #2: 1 if a target of another agent is detected between the previous node and the current one.

        #3: 1 if another agent is detected between the previous node and the current one.

        #4: distance of agent to the current branch node

        #5: minimum distance from node to the agent's target (when landing to the node following the corresponding
            branch).

        Missing/padding nodes are filled in with -inf (truncated).
        Missing values in present node are filled in with +inf (truncated).

        In case of the root node, the values are [0, 0, 0, 0, distance from agent to target].
        In case the target node is reached, the values are [0, 0, 0, 0, 0].
        """
        # Update local lookup table for all agents' positions
        self.location_has_agent = {}
        for loc in self.env.agents_position:
            self.location_has_agent[(loc[0], loc[1])] = 1

        position = self.env.agents_position[handle]
        orientation = self.env.agents_direction[handle]

        # Root node - current position
        observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]]
        root_observation = observation[:]

        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right, back], relative to the current orientation
        # TODO: Adjust this to the novel movement dynamics --> Only Forward present when one transition is possible.
        for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]:
            if self.env.rail.get_transition((position[0], position[1], orientation), branch_direction):
                new_cell = self._new_position(position, branch_direction)

                branch_observation = self._explore_branch(handle, new_cell, branch_direction, root_observation, 1)
                observation = observation + branch_observation
            else:
                # Branch unavailable: pad the whole would-be subtree with -inf
                # nodes; a subtree explored from depth 1 has sum_{i=0}^{max_depth-1} 4^i nodes.
                num_cells_to_fill_in = 0
                pow4 = 1
                for i in range(self.max_depth):
                    num_cells_to_fill_in += pow4
                    pow4 *= 4
                observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in

        return observation

    def _explore_branch(self, handle, position, direction, root_observation, depth):
        """
        Utility function to compute tree-based observations.

        Walks from `position' along `direction' until the next switch, dead-end,
        target or cycle, emits the 5 features for that node, then recurses on the
        node's (up to 4) outgoing branches.
        """
        # [Recursive branch opened]
        if depth >= self.max_depth + 1:
            return []

        # Continue along direction until next switch or
        # until no transitions are possible along the current direction (i.e., dead-ends)
        # We treat dead-ends as nodes, instead of going back, to avoid loops
        exploring = True
        last_isSwitch = False
        last_isDeadEnd = False
        last_isTerminal = False  # wrong cell OR cycle;  either way, we don't want the agent to land here
        last_isTarget = False

        visited = set()

        other_agent_encountered = False
        other_target_encountered = False
        num_steps = 1
        while exploring:
            # #############################
            # #############################
            # Modify here to compute any useful data required to build the end node's features. This code is called
            # for each cell visited between the previous branching node and the next switch / target / dead-end.

            if position in self.location_has_agent:
                other_agent_encountered = True

            if position in self.location_has_target:
                other_target_encountered = True

            # #############################
            # #############################

            # Re-visiting a (row, col, direction) state means we closed a cycle.
            if (position[0], position[1], direction) in visited:
                last_isTerminal = True
                break
            visited.add((position[0], position[1], direction))

            # If the target node is encountered, pick that as node. Also, no further branching is possible.
            if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]:
                last_isTarget = True
                break

            cell_transitions = self.env.rail.get_transitions((position[0], position[1], direction))
            num_transitions = 0
            for i in range(4):
                if cell_transitions[i]:
                    num_transitions += 1

            exploring = False
            if num_transitions == 1:
                # Check if dead-end, or if we can go forward along direction.
                # A cell whose full 16-bit transition code has a single set bit
                # only allows turning around: a dead-end.
                nbits = 0
                tmp = self.env.rail.get_transitions((position[0], position[1]))
                while tmp > 0:
                    nbits += (tmp & 1)
                    tmp = tmp >> 1
                if nbits == 1:
                    # Dead-end!
                    last_isDeadEnd = True

                if not last_isDeadEnd:
                    # Keep walking through the tree along `direction'
                    exploring = True
                    # TODO: Remove below calculation, this is computed already above and could be reused
                    for i in range(4):
                        if cell_transitions[i]:
                            position = self._new_position(position, i)
                            direction = i
                            num_steps += 1
                            break

            elif num_transitions > 0:
                # Switch detected
                last_isSwitch = True
                break

            elif num_transitions == 0:
                # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
                print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0],
                      position[1], direction)
                last_isTerminal = True
                break

        # `position' is either a terminal node or a switch

        observation = []

        # #############################
        # #############################
        # Modify here to append new / different features for each visited cell!

        if last_isTarget:
            observation = [0,
                           1 if other_target_encountered else 0,
                           1 if other_agent_encountered else 0,
                           root_observation[3] + num_steps,
                           0]

        elif last_isTerminal:
            observation = [0,
                           1 if other_target_encountered else 0,
                           1 if other_agent_encountered else 0,
                           np.inf,
                           np.inf]
        else:
            observation = [0,
                           1 if other_target_encountered else 0,
                           1 if other_agent_encountered else 0,
                           root_observation[3] + num_steps,
                           self.distance_map[handle, position[0], position[1], direction]]

        # #############################
        # #############################

        new_root_observation = observation[:]

        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right, back], relative to the current orientation
        for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]:
            if last_isDeadEnd and self.env.rail.get_transition((position[0], position[1], direction),
                                                               (branch_direction + 2) % 4):
                # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
                # it back
                new_cell = self._new_position(position, (branch_direction + 2) % 4)

                branch_observation = self._explore_branch(handle,
                                                          new_cell,
                                                          (branch_direction + 2) % 4,
                                                          new_root_observation,
                                                          depth + 1)
                observation = observation + branch_observation

            elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction),
                                                                (branch_direction + 2) % 4):
                # NOTE(review): the transition test above uses (branch_direction + 2) % 4
                # while the step below moves along branch_direction; get() tests
                # branch_direction directly. This looks inconsistent -- confirm intended.
                new_cell = self._new_position(position, branch_direction)

                branch_observation = self._explore_branch(handle,
                                                          new_cell,
                                                          branch_direction,
                                                          new_root_observation,
                                                          depth + 1)
                observation = observation + branch_observation

            else:
                # Unavailable branch: pad with -inf nodes for the full remaining
                # subtree, i.e. sum_{i=0}^{max_depth-depth-1} 4^i nodes.
                num_cells_to_fill_in = 0
                pow4 = 1
                for i in range(self.max_depth - depth):
                    num_cells_to_fill_in += pow4
                    pow4 *= 4
                observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in

        return observation

    def util_print_obs_subtree(self, tree, num_features_per_node=5, prompt='', current_depth=0):
        """
        Utility function to pretty-print tree observations returned by this object.
        """
        # A (sub)tree shorter than one node's features carries no data to print.
        if len(tree) < num_features_per_node:
            return

        # NOTE: a previous version computed an (unused) depth estimate here with
        # float division; that dead code has been removed.

        prompt_ = ['L:', 'F:', 'R:', 'B:']

        print("  " * current_depth + prompt, tree[0:num_features_per_node])
        # The remainder of the flat vector splits evenly into the 4 child subtrees.
        child_size = (len(tree) - num_features_per_node) // 4
        for children in range(4):
            child_tree = tree[(num_features_per_node + children * child_size):
                              (num_features_per_node + (children + 1) * child_size)]
            self.util_print_obs_subtree(child_tree,
                                        num_features_per_node,
                                        prompt=prompt_[children],
                                        current_depth=current_depth + 1)


class GlobalObsForRailEnv(ObservationBuilder):
    """
    Gives a global observation of the entire rail environment.

    The observation is a 3-tuple of:

        - a transition map array with dimensions (env.height, env.width, 16),
          assuming 16 bits encoding of transitions;

        - a (4, env.height, env.width) array whose channels mark, in order:
          the given agent's position, its target, the other agents' targets,
          and the other agents' positions;

        - a 4-element one-hot encoding of the agent's direction.
    """

    def __init__(self):
        super(GlobalObsForRailEnv, self).__init__()

    def reset(self):
        # Decode each cell's 16-bit transition code into a 0/1 vector, once per reset.
        height, width = self.env.height, self.env.width
        self.rail_obs = np.zeros((height, width, 16))
        for row in range(height):
            for col in range(width):
                bits = f'{self.env.rail.get_transitions((row, col)):016b}'
                self.rail_obs[row, col] = np.array(list(bits)).astype(int)

    def get(self, handle):
        layers = np.zeros((4, self.env.height, self.env.width))

        # Channel 0: this agent's position; channel 3: all other agents' positions.
        layers[0][self.env.agents_position[handle]] += 1
        for idx, pos in enumerate(self.env.agents_position):
            if idx != handle:
                layers[3][pos] += 1

        # Channel 1: this agent's target; channel 2: the other agents' targets.
        layers[1][self.env.agents_target[handle]] += 1
        for idx, tgt in enumerate(self.env.agents_target):
            if idx != handle:
                layers[2][tgt] += 1

        # One-hot encoding of the observed agent's direction.
        direction = np.zeros(4)
        direction[self.env.agents_direction[handle]] = 1

        return self.rail_obs, layers, direction