"""
Collection of environment-specific ObservationBuilder.
"""
import pprint
from collections import deque

import numpy as np

from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid_utils import coordinate_to_position


class TreeObsForRailEnv(ObservationBuilder):
    """
    TreeObsForRailEnv object.

    This object returns observation vectors for agents in the RailEnv environment.
    The information is local to each agent and exploits the tree structure of the rail
    network to simplify the representation of the state of the environment for each agent.
    """

    observation_dim = 9

    def __init__(self, max_depth, predictor=None):
        super().__init__()
        self.max_depth = max_depth

        # Compute the size of the returned observation vector
        size = 0
        pow4 = 1
        for i in range(self.max_depth + 1):
            size += pow4
            pow4 *= 4
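        # Worked example (a sketch, assuming max_depth = 2): the loop above yields
        # size = 1 + 4 + 16 = 21 tree nodes, so observation_space below becomes
        # [21 * observation_dim] = [189].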
        self.observation_space = [size * self.observation_dim]
        self.location_has_agent = {}
        self.location_has_agent_direction = {}
        self.predictor = predictor
        self.agents_previous_reset = None
        self.tree_explored_actions = [1, 2, 3, 0]
        self.tree_explorted_actions_char = ['L', 'F', 'R', 'B']

    def reset(self):
        agents = self.env.agents
        nb_agents = len(agents)

        compute_distance_map = True
        if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset):
            compute_distance_map = False
            for i in range(nb_agents):
                if agents[i].target != self.agents_previous_reset[i].target:
                    compute_distance_map = True
        self.agents_previous_reset = agents

        if compute_distance_map:
            self._compute_distance_map()

    def _compute_distance_map(self):
        agents = self.env.agents
        nb_agents = len(agents)
        self.distance_map = np.inf * np.ones(shape=(nb_agents,
                                                    self.env.height,
                                                    self.env.width,
                                                    4))
        self.max_dist = np.zeros(nb_agents)
        self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
        # Update local lookup table for all agents' target locations
        self.location_has_target = {tuple(agent.target): 1 for agent in agents}
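        # Layout note (a sketch of how the result is used): once the walkers above finish,
        # self.distance_map[handle, row, col, orientation] holds the minimum number of cells
        # agent `handle' needs to reach its target from cell (row, col) with that orientation;
        # combinations that cannot reach the target stay at np.inf.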

    def _distance_map_walker(self, position, target_nr):
        """
        Utility function to compute distance maps from each cell in the rail network (and each possible
        orientation within it) to each agent's target cell.
        """
        # Returns max distance to target, from the farthest away node, while filling in distance_map

        self.distance_map[target_nr, position[0], position[1], :] = 0

        # Fill in the (up to) 4 neighboring nodes
        # direction is the direction of movement, meaning that at least one possible orientation of an agent
        # in cell (row,col) allows a movement in direction `direction'
        nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1))

        # BFS from target `position' to all the reachable nodes in the grid
        # Stop the search if the target position is re-visited, in any direction
        visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2),
                   (position[0], position[1], 3)}

        max_distance = 0

        while nodes_queue:
            node = nodes_queue.popleft()

            node_id = (node[0], node[1], node[2])

            if node_id not in visited:
                visited.add(node_id)

                # From the list of possible neighbors that have at least a path to the current node, only keep those
                # whose new orientation in the current cell would allow a transition to direction node[2]
                valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2])

                for n in valid_neighbors:
                    nodes_queue.append(n)

                if len(valid_neighbors) > 0:
                    max_distance = max(max_distance, node[3] + 1)

        return max_distance

    def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1):
        """
        Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the
        minimum distances from each target cell.
        """
        neighbors = []

        possible_directions = [0, 1, 2, 3]
        if enforce_target_direction >= 0:
            # The agent must land into the current cell with orientation `enforce_target_direction'.
            # This is only possible if the agent has arrived from the cell in the opposite direction!
            possible_directions = [(enforce_target_direction + 2) % 4]

        for neigh_direction in possible_directions:
            new_cell = self._new_position(position, neigh_direction)

            if 0 <= new_cell[0] < self.env.height and 0 <= new_cell[1] < self.env.width:

                desired_movement_from_new_cell = (neigh_direction + 2) % 4

                # Check all possible transitions in new_cell
                for agent_orientation in range(4):
                    # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
                    is_valid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
                                                            desired_movement_from_new_cell)

                    if is_valid:
                        """
                        # TODO: check that it works with deadends! -- still bugged!
                        movement = desired_movement_from_new_cell
                        if isNextCellDeadEnd:
                            movement = (desired_movement_from_new_cell+2) % 4
                        """
                        new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation],
                                           current_distance + 1)
                        neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance))
                        self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance

        return neighbors

    def _new_position(self, position, movement):
        """
        Utility function that converts a compass movement over a 2D grid to new positions (r, c).
        """
        if movement == Grid4TransitionsEnum.NORTH:
            return (position[0] - 1, position[1])
        elif movement == Grid4TransitionsEnum.EAST:
            return (position[0], position[1] + 1)
        elif movement == Grid4TransitionsEnum.SOUTH:
            return (position[0] + 1, position[1])
        elif movement == Grid4TransitionsEnum.WEST:
            return (position[0], position[1] - 1)

    def get_many(self, handles=None):
        """
        Called whenever an observation has to be computed for the `env' environment, for each agent with handle
        in the `handles' list.
        """

        if handles is None:
            handles = []
        if self.predictor:
            self.predicted_pos = {}
            self.predicted_dir = {}
            self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map})
            for t in range(len(self.predictions[0])):
                pos_list = []
                dir_list = []
                for a in handles:
                    pos_list.append(self.predictions[a][t][1:3])
                    dir_list.append(self.predictions[a][t][3])
                self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
                self.predicted_dir.update({t: dir_list})
            self.max_prediction_depth = len(self.predicted_pos)
        observations = {}
        for h in handles:
            observations[h] = self.get(h)
        return observations

    def get(self, handle):
        """
        Computes the current observation for agent `handle' in env

        The observation vector is composed of 4 sequential parts, corresponding to data from up to 4 possible
        movements in a RailEnv (up to, because only a subset of the possible transitions is allowed in RailEnv).
        The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for
        the transitions. The order is:
            [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']

        The data for each branch is organized as:
            [root node information] +
            [recursive branch data from 'left'] +
            [... from 'forward'] +
            [... from 'right'] +
            [... from 'back']

        Finally, each node's information is composed of 9 floating point values:

        #1: if the agent's own target lies on the explored branch, the current distance from the agent
        (in number of cells) is stored.

        #2: if another agent's target is detected, the distance in number of cells from the agent's
        current location is stored.

        #3: if another agent is detected, the distance in number of cells from the current agent position
        is stored.

        #4: possible conflict detected
            tot_dist = another agent is predicted to pass along this cell at the same time as the agent;
             we store the distance in number of cells from the current agent position
            0 = no other agent reserves the same cell at a similar time

        #5: if a switch that cannot be used by the agent is detected, we store the distance.

        #6: this feature stores the distance in number of cells to the next branching point (the current node)

        #7: minimum distance from the node to the agent's target, given the direction of the agent, if this path is chosen

        #8: agent in the same direction
            n = number of agents present in the same direction
                (possible future use: number of other agents in the same direction in this branch)
            0 = no agent present in the same direction

        #9: agent in the opposite direction
            n = number of agents present in the opposite direction to the agent (i.e. a potential conflict)
                (possible future use: number of other agents in the opposite direction in this branch, i.e. number of conflicts)
            0 = no agent present in the opposite direction


        Missing/padding nodes are filled in with -inf (truncated).
        Missing values in the present node are filled in with +inf (truncated).


        In case of the root node, the values are [0, 0, 0, 0, 0, 0, distance from agent to target, 0, 0].
        In case the target node is reached, the distance to the target (value #7) is set to 0.
        """

        # Update local lookup table for all agents' positions
        self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents}
        self.location_has_agent_direction = {tuple(agent.position): agent.direction for agent in self.env.agents}
        if handle >= len(self.env.agents):
            print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
        agent = self.env.agents[handle]  # TODO: handle being treated as index
        possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
        num_transitions = np.count_nonzero(possible_transitions)

        # Root node - current position
        observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0]

        visited = set()
        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right, back], relative to the current orientation
        # If only one transition is possible, the tree is oriented with this transition as the forward branch.
        orientation = agent.direction

        if num_transitions == 1:
            orientation = np.argmax(possible_transitions)

        for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
            if possible_transitions[branch_direction]:
                new_cell = self._new_position(agent.position, branch_direction)
                branch_observation, branch_visited = \
                    self._explore_branch(handle, new_cell, branch_direction, 1, 1)
                observation = observation + branch_observation
                visited = visited.union(branch_visited)
            else:
                # add cells filled with infinity if no transition is possible
                observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth)
        self.env.dev_obs_dict[handle] = visited
        return observation
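    # A minimal sketch of how the flat vector returned by get() can be split up, assuming
    # `tree_obs` is an instance of this builder attached to an environment and 0 is a valid
    # agent handle (the slicing mirrors unfold_observation_tree further below):
    #
    #     obs = tree_obs.get(0)
    #     dim = TreeObsForRailEnv.observation_dim
    #     root = obs[:dim]                              # 9 values for the current cell
    #     branch_size = (len(obs) - dim) // 4
    #     branches = [obs[dim + i * branch_size: dim + (i + 1) * branch_size]
    #                 for i in range(4)]                # order: left, forward, right, back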

    def _num_cells_to_fill_in(self, remaining_depth):
        """Computes the length of the observation vector: sum_{i=0,depth-1} 4^i * observation_dim."""
        num_observations = 0
        pow4 = 1
        for i in range(remaining_depth):
            num_observations += pow4
            pow4 *= 4
        return num_observations * self.observation_dim
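    # Worked padding example (a sketch): with observation_dim = 9 and remaining_depth = 2,
    # a pruned branch is missing 4**0 + 4**1 = 5 nodes, so _num_cells_to_fill_in returns
    # 5 * 9 = 45 values, which the calling code fills with -inf.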

    def _explore_branch(self, handle, position, direction, tot_dist, depth):
        """
        Utility function to compute tree-based observations.
        We walk along the branch and collect the information documented in the get() function.
        If there is a branching point, a new node is created and each possible branch is explored.
        """
        # [Recursive branch opened]
        if depth >= self.max_depth + 1:
            return [], []

        # Continue along direction until next switch or
        # until no transitions are possible along the current direction (i.e., dead-ends)
        # We treat dead-ends as nodes, instead of going back, to avoid loops
        exploring = True
        last_is_switch = False
        last_is_dead_end = False
        last_is_terminal = False  # wrong cell OR cycle;  either way, we don't want the agent to land here
        last_is_target = False

        visited = set()
        agent = self.env.agents[handle]
        own_target_encountered = np.inf
        other_agent_encountered = np.inf
        other_target_encountered = np.inf
        potential_conflict = np.inf
        unusable_switch = np.inf
        other_agent_same_direction = 0
        other_agent_opposite_direction = 0

        num_steps = 1
        while exploring:
            # #############################
            # #############################
            # Modify here to compute any useful data required to build the end node's features. This code is called
            # for each cell visited between the previous branching node and the next switch / target / dead-end.
            if position in self.location_has_agent:
                if tot_dist < other_agent_encountered:
                    other_agent_encountered = tot_dist

                if self.location_has_agent_direction[position] == direction:
                    # Accumulate the number of agents on the branch moving in the same direction
                    other_agent_same_direction += 1

                if self.location_has_agent_direction[position] != direction:
                    # Accumulate the number of agents on the branch moving in the opposite direction
                    other_agent_opposite_direction += 1

            # Register possible future conflict
            if self.predictor and num_steps < self.max_prediction_depth:
                int_position = coordinate_to_position(self.env.width, [position])
                if tot_dist < self.max_prediction_depth:
                    pre_step = max(0, tot_dist - 1)
                    post_step = min(self.max_prediction_depth - 1, tot_dist + 1)

                    # Look for opposing paths at distance tot_dist
                    if int_position in np.delete(self.predicted_pos[tot_dist], handle, 0):
                        conflicting_agent = np.where(self.predicted_pos[tot_dist] == int_position)
                        for ca in conflicting_agent[0]:

                            if direction != self.predicted_dir[tot_dist][ca] and tot_dist < potential_conflict:
                                potential_conflict = tot_dist
                    # Look for opposing paths at distance tot_dist - 1 (pre_step)
                    elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0):
                        conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
                        for ca in conflicting_agent[0]:
                            if direction != self.predicted_dir[pre_step][ca] and tot_dist < potential_conflict:
                                potential_conflict = tot_dist
                    # Look for opposing paths at distance tot_dist + 1 (post_step)
                    elif int_position in np.delete(self.predicted_pos[post_step], handle, 0):
                        conflicting_agent = np.where(self.predicted_pos[post_step] == int_position)
                        for ca in conflicting_agent[0]:
                            if direction != self.predicted_dir[post_step][ca] and tot_dist < potential_conflict:
                                potential_conflict = tot_dist

            if position in self.location_has_target and position != agent.target:
                if tot_dist < other_target_encountered:
                    other_target_encountered = tot_dist

            if position == agent.target and tot_dist < own_target_encountered:
                own_target_encountered = tot_dist

            # #############################
            # #############################
            if (position[0], position[1], direction) in visited:
                last_is_terminal = True
                break
            visited.add((position[0], position[1], direction))

            # If the target node is encountered, pick that as node. Also, no further branching is possible.
            if np.array_equal(position, self.env.agents[handle].target):
                last_is_target = True
                break

            cell_transitions = self.env.rail.get_transitions((*position, direction))
            total_transitions = bin(self.env.rail.get_transitions(position)).count("1")
            num_transitions = np.count_nonzero(cell_transitions)
            exploring = False
            # Detect Switches that can only be used by other agents.
            if total_transitions > 2 > num_transitions:
                unusable_switch = tot_dist

            if num_transitions == 1:
                # Check if dead-end, or if we can go forward along direction
                nbits = 0
                tmp = self.env.rail.get_transitions(tuple(position))
                while tmp > 0:
                    nbits += (tmp & 1)
                    tmp = tmp >> 1
                if nbits == 1:
                    # Dead-end!
                    last_is_dead_end = True

                if not last_is_dead_end:
                    # Keep walking through the tree along `direction'
                    exploring = True
                    # convert one-hot encoding to 0,1,2,3
                    direction = np.argmax(cell_transitions)
                    position = self._new_position(position, direction)
                    num_steps += 1
                    tot_dist += 1
            elif num_transitions > 0:
                # Switch detected
                last_is_switch = True
                break

            elif num_transitions == 0:
                # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
                print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0],
                      position[1], direction)
                last_is_terminal = True
                break

        # `position' is either a terminal node or a switch

        observation = []

        # #############################
        # #############################
        # Modify here to append new / different features for each visited cell!

        if last_is_target:
            observation = [own_target_encountered,
                           other_target_encountered,
                           other_agent_encountered,
                           potential_conflict,
                           unusable_switch,
                           tot_dist,
                           0,
                           other_agent_same_direction,
                           other_agent_opposite_direction
                           ]

        elif last_is_terminal:
            observation = [own_target_encountered,
                           other_target_encountered,
                           other_agent_encountered,
                           potential_conflict,
                           unusable_switch,
                           np.inf,
                           self.distance_map[handle, position[0], position[1], direction],
                           other_agent_same_direction,
                           other_agent_opposite_direction
                           ]
        else:
            observation = [own_target_encountered,
                           other_target_encountered,
                           other_agent_encountered,
                           potential_conflict,
                           unusable_switch,
                           tot_dist,
                           self.distance_map[handle, position[0], position[1], direction],
                           other_agent_same_direction,
                           other_agent_opposite_direction,
                           ]
        # #############################
        # #############################
        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right, back], relative to the current orientation
        # Get the possible transitions
        possible_transitions = self.env.rail.get_transitions((*position, direction))
        for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]:
            if last_is_dead_end and self.env.rail.get_transition((*position, direction),
                                                                 (branch_direction + 2) % 4):
                # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
                # it back
                new_cell = self._new_position(position, (branch_direction + 2) % 4)
                branch_observation, branch_visited = self._explore_branch(handle,
                                                                          new_cell,
                                                                          (branch_direction + 2) % 4,
                                                                          tot_dist + 1,
                                                                          depth + 1)
                observation = observation + branch_observation
                if len(branch_visited) != 0:
                    visited = visited.union(branch_visited)
            elif last_is_switch and possible_transitions[branch_direction]:
                new_cell = self._new_position(position, branch_direction)
                branch_observation, branch_visited = self._explore_branch(handle,
                                                                          new_cell,
                                                                          branch_direction,
                                                                          tot_dist + 1,
                                                                          depth + 1)
                observation = observation + branch_observation
                if len(branch_visited) != 0:
                    visited = visited.union(branch_visited)
            else:
                # no exploring possible, add just cells with infinity
                observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth - depth)

        return observation, visited

    def util_print_obs_subtree(self, tree):
        """
        Utility function to pretty-print tree observations returned by this object.
        """
        pp = pprint.PrettyPrinter(indent=4)
        pp.pprint(self.unfold_observation_tree(tree))

    def unfold_observation_tree(self, tree, current_depth=0, actions_for_display=True):
        """
        Utility function to pretty-print tree observations returned by this object.
        """
        if len(tree) < self.observation_dim:
            return

        depth = 0
        tmp = len(tree) / self.observation_dim - 1
        pow4 = 4
        while tmp > 0:
            tmp -= pow4
            depth += 1
            pow4 *= 4

        unfolded = {}
        unfolded[''] = tree[0:self.observation_dim]
        child_size = (len(tree) - self.observation_dim) // 4
        for child in range(4):
            child_tree = tree[(self.observation_dim + child * child_size):
                              (self.observation_dim + (child + 1) * child_size)]
            observation_tree = self.unfold_observation_tree(child_tree, current_depth=current_depth + 1)
            if observation_tree is not None:
                if actions_for_display:
                    label = self.tree_explorted_actions_char[child]
                else:
                    label = self.tree_explored_actions[child]
                unfolded[label] = observation_tree
        return unfolded
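    # A minimal usage sketch, assuming `tree_obs` is this builder and `obs` is the flat list
    # returned by tree_obs.get(handle): pretty-print the observation as a nested dict keyed
    # by the actions 'L', 'F', 'R' and 'B'.
    #
    #     tree_obs.util_print_obs_subtree(obs)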

    def _set_env(self, env):
        self.env = env
        if self.predictor:
            self.predictor._set_env(self.env)


class GlobalObsForRailEnv(ObservationBuilder):
    """
    Gives a global observation of the entire rail environment.
    The observation is composed of the following elements:

        - transition map array with dimensions (env.height, env.width, 16),
          assuming 16 bits encoding of transitions.

        - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent's
          target and the positions of the other agents' targets.

        - A 3D array (map_height, map_width, 8) with the first 4 channels containing the one hot encoding
          of the direction of the given agent and the last 4 channels containing the one hot encoding of the
          directions of the other agents at their position coordinates.

        - A 4 elements array with one hot encoding of the direction.
    """

    def __init__(self):
        self.observation_space = ()
        super(GlobalObsForRailEnv, self).__init__()

    def _set_env(self, env):
        super()._set_env(env)

        self.observation_space = [4, self.env.height, self.env.width]

    def reset(self):
        self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
        for i in range(self.rail_obs.shape[0]):
            for j in range(self.rail_obs.shape[1]):
                bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]]
                bitlist = [0] * (16 - len(bitlist)) + bitlist
                self.rail_obs[i, j] = np.array(bitlist)

    def get(self, handle):
        obs_targets = np.zeros((self.env.height, self.env.width, 2))
        obs_agents_state = np.zeros((self.env.height, self.env.width, 8))
        agents = self.env.agents
        agent = agents[handle]

        direction = np.zeros(4)
        direction[agent.direction] = 1
        agent_pos = agents[handle].position
        obs_agents_state[agent_pos][:4] = direction
        obs_targets[agent.target][0] += 1

        for i in range(len(agents)):
            if i != handle:  # TODO: handle used as index...?
                agent2 = agents[i]
                obs_agents_state[agent2.position][4 + agent2.direction] = 1
                obs_targets[agent2.target][1] += 1

        direction = self._get_one_hot_for_agent_direction(agent)

        return self.rail_obs, obs_agents_state, obs_targets, direction
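
# A minimal usage sketch, assuming a RailEnv instance `env` built with
# obs_builder_object=GlobalObsForRailEnv() and exposing the builder as env.obs_builder
# (both names are assumptions about the surrounding environment code), for agent handle 0:
#
#     rail_obs, obs_agents_state, obs_targets, direction = env.obs_builder.get(0)
#     # rail_obs:         (height, width, 16) transition bitmaps of the rail grid
#     # obs_agents_state: (height, width, 8)  own direction (first 4 channels) + other agents' directions
#     # obs_targets:      (height, width, 2)  own target (channel 0) + other agents' targets (channel 1)
#     # direction:        4-element one hot encoding of the agent's direction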


class GlobalObsForRailEnvDirectionDependent(ObservationBuilder):
    """
    Gives a global observation of the entire rail environment.
    The observation is composed of the following elements:

        - transition map array with dimensions (env.height, env.width, 16),
          assuming 16 bits encoding of transitions, flipped in the direction of the agent
          (the agent is always heading north on the flipped view).

        - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent's
          target and the positions of the other agents' targets, also flipped depending on the agent's direction.

        - A 3D array (map_height, map_width, 5) containing the one hot encoding of the direction of the other
          agents at their position coordinates, and the first channel containing the position of the given agent.

        - A 4 elements array with one hot encoding of the direction.
    """

    def __init__(self):
        self.observation_space = ()
        super(GlobalObsForRailEnvDirectionDependent, self).__init__()

    def _set_env(self, env):
        super()._set_env(env)

        self.observation_space = [4, self.env.height, self.env.width]

    def reset(self):
        self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
        for i in range(self.rail_obs.shape[0]):
            for j in range(self.rail_obs.shape[1]):
                bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]]
                bitlist = [0] * (16 - len(bitlist)) + bitlist
                self.rail_obs[i, j] = np.array(bitlist)

    def get(self, handle):
        obs_targets = np.zeros((self.env.height, self.env.width, 2))
        obs_agents_state = np.zeros((self.env.height, self.env.width, 5))
        agents = self.env.agents
        agent = agents[handle]
        direction = agent.direction

        idx = np.tile(np.arange(16), 2)

        rail_obs = self.rail_obs[:, :, idx[direction * 4: direction * 4 + 16]]

        if direction == 1:
            rail_obs = np.flip(rail_obs, axis=1)
        elif direction == 2:
            rail_obs = np.flip(rail_obs)
        elif direction == 3:
            rail_obs = np.flip(rail_obs, axis=0)

        agent_pos = agents[handle].position
        obs_agents_state[agent_pos][0] = 1
        obs_targets[agent.target][0] += 1

        idx = np.tile(np.arange(4), 2)
        for i in range(len(agents)):
            if i != handle:  # TODO: handle used as index...?
                agent2 = agents[i]
                obs_agents_state[agent2.position][1 + idx[4 + (agent2.direction - direction)]] = 1
                obs_targets[agent2.target][1] += 1

        direction = self._get_one_hot_for_agent_direction(agent)

        return rail_obs, obs_agents_state, obs_targets, direction


class LocalObsForRailEnv(ObservationBuilder):
    """
    Gives a local observation of the rail environment around the agent.
    The observation is composed of the following elements:

        - transition map array of the local environment around the given agent,
          with dimensions (2*view_radius + 1, 2*view_radius + 1, 16),
          assuming 16 bits encoding of transitions.

        - Two 2D arrays containing respectively, if they are in the agent's vision range,
          the agent's target position and the positions of the other agents' targets.

        - A 3D array (2 * view_radius + 1, 2 * view_radius + 1, 4) containing the one hot encoding of the directions
          of the other agents at their position coordinates, if they are in the agent's vision range.

        - A 4 elements array with one hot encoding of the direction.
    """

    def __init__(self, view_radius):
        """
        :param view_radius: radius (in number of cells) of the local view window around the agent.
        """
        super(LocalObsForRailEnv, self).__init__()
        self.view_radius = view_radius

    def reset(self):
        # We build the transition map with a view_radius empty cells expansion on each side.
        # This helps to collect the local transition map view when the agent is close to a border.

        self.rail_obs = np.zeros((self.env.height + 2 * self.view_radius,
                                  self.env.width + 2 * self.view_radius, 16))
        for i in range(self.env.height):
            for j in range(self.env.width):
                bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]]
                bitlist = [0] * (16 - len(bitlist)) + bitlist
                self.rail_obs[i + self.view_radius, j + self.view_radius] = np.array(bitlist)

    def get(self, handle):
        agents = self.env.agents
        agent = agents[handle]

        local_rail_obs = self.rail_obs[agent.position[0]: agent.position[0] + 2 * self.view_radius + 1,
                         agent.position[1]:agent.position[1] + 2 * self.view_radius + 1]

        obs_map_state = np.zeros((2 * self.view_radius + 1, 2 * self.view_radius + 1, 2))

        obs_other_agents_state = np.zeros((2 * self.view_radius + 1, 2 * self.view_radius + 1, 4))

        def relative_pos(pos):
            return [agent.position[0] - pos[0], agent.position[1] - pos[1]]

        def is_in(rel_pos):
            return (abs(rel_pos[0]) <= self.view_radius) and (abs(rel_pos[1]) <= self.view_radius)

        target_rel_pos = relative_pos(agent.target)
        if is_in(target_rel_pos):
            obs_map_state[self.view_radius + np.array(target_rel_pos)][0] += 1

        for i in range(len(agents)):
            if i != handle:  # TODO: handle used as index...?
                agent2 = agents[i]

                agent_2_rel_pos = relative_pos(agent2.position)
                if is_in(agent_2_rel_pos):
                    obs_other_agents_state[self.view_radius + agent_2_rel_pos[0],
                                           self.view_radius + agent_2_rel_pos[1]][agent2.direction] += 1

                target_rel_pos_2 = relative_pos(agent2.target)
                if is_in(target_rel_pos_2):
                    obs_map_state[self.view_radius + np.array(target_rel_pos_2)][1] += 1

        direction = self._get_one_hot_for_agent_direction(agent)

        return local_rail_obs, obs_map_state, obs_other_agents_state, direction
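
# A minimal usage sketch, assuming a RailEnv instance `env` built with
# obs_builder_object=LocalObsForRailEnv(view_radius=3) and exposing the builder as
# env.obs_builder (names assumed, not defined in this file); each call to get() returns
# a (2 * view_radius + 1)-sided window centred on the agent:
#
#     local_rail_obs, obs_map_state, obs_other_agents_state, direction = env.obs_builder.get(0)
#     # local_rail_obs:         (7, 7, 16) transition bitmaps around agent 0
#     # obs_map_state:          (7, 7, 2)  own target (channel 0) + other agents' targets (channel 1)
#     # obs_other_agents_state: (7, 7, 4)  one hot directions of other agents in view
#     # direction:              4-element one hot encoding of agent 0's direction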