observations.py 32.6 KB
Newer Older
1
2
3
"""
Collection of environment-specific ObservationBuilder.
"""
4
import pprint
5

u214892's avatar
u214892 committed
6
7
import numpy as np

8
from flatland.core.env_observation_builder import ObservationBuilder
u214892's avatar
u214892 committed
9
from flatland.core.grid.grid4 import Grid4TransitionsEnum
10
from flatland.core.grid.grid4_utils import get_new_position
u214892's avatar
u214892 committed
11
from flatland.core.grid.grid_utils import coordinate_to_position
12
13
14
15
16
17
18


class TreeObsForRailEnv(ObservationBuilder):
    """
    TreeObsForRailEnv object.

    This object returns observation vectors for agents in the RailEnv environment.
    The information is local to each agent and exploits the graph structure of the rail
    network to simplify the representation of the state of the environment for each agent.

    For details about the features in the tree observation see the get() function.
    """

    def __init__(self, max_depth, predictor=None):
        """Build a tree observation generator.

        Parameters
        ----------
        max_depth : int
            Maximum depth of the observation tree (the root node is depth 0).
        predictor : object, optional
            Prediction builder used to flag potential future conflicts;
            when None, the conflict feature is never filled in.
        """
        super().__init__()
        self.max_depth = max_depth
        # Number of features stored per tree node (see get() for the list).
        self.observation_dim = 11

        # Flattened observation size: the tree has sum_{i=0..max_depth} 4^i
        # nodes (each node branches into up to four children).
        total_nodes = sum(4 ** level for level in range(self.max_depth + 1))
        self.observation_space = [total_nodes * self.observation_dim]

        # Per-call lookup tables, refreshed in get().
        self.location_has_agent = {}
        self.location_has_agent_direction = {}
        self.predictor = predictor
        self.agents_previous_reset = None
        # Branches are explored in the order left, forward, right, back.
        self.tree_explored_actions = [1, 2, 3, 0]
        self.tree_explorted_actions_char = ['L', 'F', 'R', 'B']
43
44
    def reset(self):
        """Re-initialise cached data after an environment reset.

        The distance map is recomputed only when needed: it is skipped when
        the agent set (and every agent's target) is unchanged since the last
        reset, or when a pre-computed distance map was loaded with the env.
        """
        agents = self.env.agents
        nb_agents = len(agents)

        # Recompute by default; skip when no target changed since last reset.
        compute_distance_map = True
        if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset):
            compute_distance_map = any(
                agents[i].target != self.agents_previous_reset[i].target
                for i in range(nb_agents))

        # Don't compute the distance map if it was loaded
        if self.agents_previous_reset is None and self.env.distance_map.get() is not None:
            self.location_has_target = {tuple(agent.target): 1 for agent in agents}
            compute_distance_map = False

        if compute_distance_map:
            # NOTE(review): on this path location_has_target is presumably
            # populated by env.compute_distance_map() — verify, since
            # _explore_branch reads self.location_has_target.
            self.env.compute_distance_map()

        self.agents_previous_reset = agents

u214892's avatar
u214892 committed
62
    def get_many(self, handles=None):
        """
        Called whenever an observation has to be computed for the `env' environment, for each agent with handle
        in the `handles' list.
        """
        if handles is None:
            handles = []

        if self.predictor:
            # Cache per-timestep predicted positions/directions of all agents
            # so that _explore_branch can check for conflicts cheaply.
            self.max_prediction_depth = 0
            self.predicted_pos = {}
            self.predicted_dir = {}
            self.predictions = self.predictor.get()
            if self.predictions:
                for t in range(len(self.predictions[0])):
                    pos_list = [self.predictions[a][t][1:3] for a in handles]
                    dir_list = [self.predictions[a][t][3] for a in handles]
                    self.predicted_pos[t] = coordinate_to_position(self.env.width, pos_list)
                    self.predicted_dir[t] = dir_list
                self.max_prediction_depth = len(self.predicted_pos)

        # One tree observation per requested handle.
        return {h: self.get(h) for h in handles}

91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
    def get(self, handle):
        """
        Computes the current observation for agent `handle' in env

        The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible
        movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv).
        The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for
        the transitions. The order is:
            [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']

        Each branch data is organized as:
            [root node information] +
            [recursive branch data from 'left'] +
            [... from 'forward'] +
            [... from 'right] +
            [... from 'back']

        Each node information is composed of 11 features:

        #1: if own target lies on the explored branch the current distance from the agent in number of cells is stored.

        #2: if another agents target is detected the distance in number of cells from the agents current location
            is stored

        #3: if another agent is detected the distance in number of cells from current agent position is stored.

        #4: possible conflict detected
            tot_dist = Other agent predicts to pass along this cell at the same time as the agent, we store the
             distance in number of cells from current agent position
            0 = No other agent reserve the same cell at similar time

        #5: if a not usable switch (for this agent) is detected we store the distance.

        #6: This feature stores the distance in number of cells to the next branching (current node)

        #7: minimum distance from node to the agent's target given the direction of the agent if this path is chosen

        #8: agent in the same direction
            n = number of agents present same direction
                (possible future use: number of other agents in the same direction in this branch)
            0 = no agent present same direction

        #9: agent in the opposite direction
            n = number of agents present other direction than myself (so conflict)
                (possible future use: number of other agents in other direction in this branch, ie. number of conflicts)
            0 = no agent present other direction than myself

        #10: malfunctioning/blocking agents
            n = number of time steps the observed agent remains blocked

        #11: slowest observed speed of an agent in same direction
            1 if no agent is observed
            min_fractional speed otherwise

        Missing/padding nodes are filled in with -inf (truncated).
        Missing values in present node are filled in with +inf (truncated).

        In case of the root node, the values are
            [0, 0, 0, 0, 0, 0, distance from agent to target, 0, 0, own malfunction, own speed]
        In case the target node is reached, the values are [0, 0, 0, 0, 0].
        """
        # Update local lookup tables for all agents' positions, directions,
        # speeds and malfunction state; _explore_branch reads these.
        self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents}
        self.location_has_agent_direction = {tuple(agent.position): agent.direction for agent in self.env.agents}
        self.location_has_agent_speed = {tuple(agent.position): agent.speed_data['speed'] for agent in self.env.agents}
        self.location_has_agent_malfunction = {tuple(agent.position): agent.malfunction_data['malfunction'] for agent in
                                               self.env.agents}

        # Bug fix: the guard used `>`, so handle == len(agents) (already out of
        # range) slipped through without the warning before the IndexError below.
        if handle >= len(self.env.agents):
            print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
        agent = self.env.agents[handle]  # TODO: handle being treated as index
        possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
        num_transitions = np.count_nonzero(possible_transitions)

        # Root node - current position
        # Here information about the agent itself is stored
        observation = [0, 0, 0, 0, 0, 0, self.env.distance_map.get()[(handle, *agent.position, agent.direction)], 0, 0,
                       agent.malfunction_data['malfunction'], agent.speed_data['speed']]

        visited = set()

        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right, back], relative to the current orientation
        # If only one transition is possible, the tree is oriented with this transition as the forward branch.
        orientation = agent.direction
        if num_transitions == 1:
            orientation = np.argmax(possible_transitions)

        for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
            if possible_transitions[branch_direction]:
                new_cell = get_new_position(agent.position, branch_direction)
                branch_observation, branch_visited = \
                    self._explore_branch(handle, new_cell, branch_direction, 1, 1)
                observation = observation + branch_observation
                visited = visited.union(branch_visited)
            else:
                # add cells filled with infinity if no transition is possible
                observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth)
        self.env.dev_obs_dict[handle] = visited

        return observation
196

197
198
199
200
201
202
203
204
205
    def _num_cells_to_fill_in(self, remaining_depth):
        """Computes the length of observation vector: sum_{i=0,depth-1} 2^i * observation_dim."""
        num_observations = 0
        pow4 = 1
        for i in range(remaining_depth):
            num_observations += pow4
            pow4 *= 4
        return num_observations * self.observation_dim

u214892's avatar
u214892 committed
206
    def _explore_branch(self, handle, position, direction, tot_dist, depth):
        """
        Utility function to compute tree-based observations.

        We walk along the branch and collect the information documented in the get() function.
        If there is a branching point a new node is created and each possible branch is explored.

        Returns a pair (observation, visited): the flattened feature list for
        this subtree and the set of (row, col, direction) cells walked.
        """

        # [Recursive branch opened]
        # Depth limit reached: contribute nothing (caller pads with -inf).
        if depth >= self.max_depth + 1:
            return [], []

        # Continue along direction until next switch or
        # until no transitions are possible along the current direction (i.e., dead-ends)
        # We treat dead-ends as nodes, instead of going back, to avoid loops
        exploring = True
        last_is_switch = False
        last_is_dead_end = False
        last_is_terminal = False  # wrong cell OR cycle;  either way, we don't want the agent to land here
        last_is_target = False

        visited = set()
        agent = self.env.agents[handle]
        # Steps the agent needs per cell (inverse of fractional speed);
        # used to map walked distance onto prediction time steps.
        time_per_cell = np.reciprocal(agent.speed_data["speed"])
        # Feature accumulators; inf / 0 mean "nothing observed yet".
        own_target_encountered = np.inf
        other_agent_encountered = np.inf
        other_target_encountered = np.inf
        potential_conflict = np.inf
        unusable_switch = np.inf
        other_agent_same_direction = 0
        other_agent_opposite_direction = 0
        malfunctioning_agent = 0
        min_fractional_speed = 1.

        num_steps = 1
        while exploring:
            # #############################
            # #############################
            # Modify here to compute any useful data required to build the end node's features. This code is called
            # for each cell visited between the previous branching node and the next switch / target / dead-end.
            if position in self.location_has_agent:
                if tot_dist < other_agent_encountered:
                    other_agent_encountered = tot_dist

                # Check if any of the observed agents is malfunctioning, store agent with longest duration left
                if self.location_has_agent_malfunction[position] > malfunctioning_agent:
                    malfunctioning_agent = self.location_has_agent_malfunction[position]

                if self.location_has_agent_direction[position] == direction:
                    # Cumulate the number of agents on branch with same direction
                    other_agent_same_direction += 1

                    # Check fractional speed of agents; keep the slowest seen.
                    current_fractional_speed = self.location_has_agent_speed[position]
                    if current_fractional_speed < min_fractional_speed:
                        min_fractional_speed = current_fractional_speed

                if self.location_has_agent_direction[position] != direction:
                    # Cumulate the number of agents on branch with other direction
                    other_agent_opposite_direction += 1

            # Check number of possible transitions for agent and total number of transitions in cell (type)
            cell_transitions = self.env.rail.get_transitions(*position, direction)
            transition_bit = bin(self.env.rail.get_full_transitions(*position))
            total_transitions = transition_bit.count("1")
            crossing_found = False
            # This bit pattern is the diamond crossing cell (N-S and E-W
            # straight tracks crossing without switching).
            if int(transition_bit, 2) == int('1000010000100001', 2):
                crossing_found = True

            # Register possible future conflict
            predicted_time = int(tot_dist * time_per_cell)
            if self.predictor and predicted_time < self.max_prediction_depth:
                int_position = coordinate_to_position(self.env.width, [position])
                if tot_dist < self.max_prediction_depth:
                    # Also check one prediction step before and after, to catch
                    # agents arriving slightly earlier or later.
                    pre_step = max(0, predicted_time - 1)
                    post_step = min(self.max_prediction_depth - 1, predicted_time + 1)

                    # Look for conflicting paths at distance tot_dist
                    # (own row removed so the agent never conflicts with itself).
                    if int_position in np.delete(self.predicted_pos[predicted_time], handle, 0):
                        conflicting_agent = np.where(self.predicted_pos[predicted_time] == int_position)
                        for ca in conflicting_agent[0]:
                            # Conflict only when the other agent comes from a
                            # direction this cell actually connects to.
                            if direction != self.predicted_dir[predicted_time][ca] and cell_transitions[
                                self._reverse_dir(
                                    self.predicted_dir[predicted_time][ca])] == 1 and tot_dist < potential_conflict:
                                potential_conflict = tot_dist
                            if self.env.dones[ca] and tot_dist < potential_conflict:
                                potential_conflict = tot_dist

                    # Look for conflicting paths at distance num_step-1
                    elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0):
                        conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
                        for ca in conflicting_agent[0]:
                            if direction != self.predicted_dir[pre_step][ca] \
                                and cell_transitions[self._reverse_dir(self.predicted_dir[pre_step][ca])] == 1 \
                                and tot_dist < potential_conflict:  # noqa: E125
                                potential_conflict = tot_dist
                            if self.env.dones[ca] and tot_dist < potential_conflict:
                                potential_conflict = tot_dist

                    # Look for conflicting paths at distance num_step+1
                    elif int_position in np.delete(self.predicted_pos[post_step], handle, 0):
                        conflicting_agent = np.where(self.predicted_pos[post_step] == int_position)
                        for ca in conflicting_agent[0]:
                            if direction != self.predicted_dir[post_step][ca] and cell_transitions[self._reverse_dir(
                                self.predicted_dir[post_step][ca])] == 1 \
                                and tot_dist < potential_conflict:  # noqa: E125
                                potential_conflict = tot_dist
                            if self.env.dones[ca] and tot_dist < potential_conflict:
                                potential_conflict = tot_dist

            # NOTE(review): self.location_has_target is only assigned on the
            # loaded-distance-map path of reset() as far as this file shows —
            # confirm it is populated elsewhere for the computed path.
            if position in self.location_has_target and position != agent.target:
                if tot_dist < other_target_encountered:
                    other_target_encountered = tot_dist

            if position == agent.target and tot_dist < own_target_encountered:
                own_target_encountered = tot_dist

            # #############################
            # #############################
            # Revisiting a (row, col, direction) triple means we are in a cycle.
            if (position[0], position[1], direction) in visited:
                last_is_terminal = True
                break
            visited.add((position[0], position[1], direction))

            # If the target node is encountered, pick that as node. Also, no further branching is possible.
            if np.array_equal(position, self.env.agents[handle].target):
                last_is_target = True
                break

            # Check if crossing is found --> Not an unusable switch
            if crossing_found:
                # Treat the crossing as a straight rail cell
                total_transitions = 2
            num_transitions = np.count_nonzero(cell_transitions)

            exploring = False

            # Detect Switches that can only be used by other agents.
            if total_transitions > 2 > num_transitions and tot_dist < unusable_switch:
                unusable_switch = tot_dist

            if num_transitions == 1:
                # Check if dead-end, or if we can go forward along direction
                nbits = total_transitions
                if nbits == 1:
                    # Dead-end!
                    last_is_dead_end = True

                if not last_is_dead_end:
                    # Keep walking through the tree along `direction'
                    exploring = True
                    # convert one-hot encoding to 0,1,2,3
                    direction = np.argmax(cell_transitions)
                    position = get_new_position(position, direction)
                    num_steps += 1
                    tot_dist += 1
            elif num_transitions > 0:
                # Switch detected
                last_is_switch = True
                break

            elif num_transitions == 0:
                # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
                print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0],
                      position[1], direction)
                last_is_terminal = True
                break

        # `position' is either a terminal node or a switch

        # #############################
        # #############################
        # Modify here to append new / different features for each visited cell!

        if last_is_target:
            # Reached own target: distance-to-target feature is 0 by definition.
            observation = [own_target_encountered,
                           other_target_encountered,
                           other_agent_encountered,
                           potential_conflict,
                           unusable_switch,
                           tot_dist,
                           0,
                           other_agent_same_direction,
                           other_agent_opposite_direction,
                           malfunctioning_agent,
                           min_fractional_speed
                           ]

        elif last_is_terminal:
            # Cycle or broken cell: branch length is unbounded (inf).
            observation = [own_target_encountered,
                           other_target_encountered,
                           other_agent_encountered,
                           potential_conflict,
                           unusable_switch,
                           np.inf,
                           self.env.distance_map.get()[handle, position[0], position[1], direction],
                           other_agent_same_direction,
                           other_agent_opposite_direction,
                           malfunctioning_agent,
                           min_fractional_speed
                           ]

        else:
            # Regular node at a switch or dead-end.
            observation = [own_target_encountered,
                           other_target_encountered,
                           other_agent_encountered,
                           potential_conflict,
                           unusable_switch,
                           tot_dist,
                           self.env.distance_map.get()[handle, position[0], position[1], direction],
                           other_agent_same_direction,
                           other_agent_opposite_direction,
                           malfunctioning_agent,
                           min_fractional_speed
                           ]

        # #############################
        # #############################
        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right, back], relative to the current orientation
        # Get the possible transitions
        possible_transitions = self.env.rail.get_transitions(*position, direction)
        for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]:
            if last_is_dead_end and self.env.rail.get_transition((*position, direction),
                                                                 (branch_direction + 2) % 4):
                # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
                # it back
                new_cell = get_new_position(position, (branch_direction + 2) % 4)
                branch_observation, branch_visited = self._explore_branch(handle,
                                                                          new_cell,
                                                                          (branch_direction + 2) % 4,
                                                                          tot_dist + 1,
                                                                          depth + 1)
                observation = observation + branch_observation
                if len(branch_visited) != 0:
                    visited = visited.union(branch_visited)
            elif last_is_switch and possible_transitions[branch_direction]:
                new_cell = get_new_position(position, branch_direction)
                branch_observation, branch_visited = self._explore_branch(handle,
                                                                          new_cell,
                                                                          branch_direction,
                                                                          tot_dist + 1,
                                                                          depth + 1)
                observation = observation + branch_observation
                if len(branch_visited) != 0:
                    visited = visited.union(branch_visited)
            else:
                # no exploring possible, add just cells with infinity
                observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth - depth)

        return observation, visited
455

456
    def util_print_obs_subtree(self, tree):
        """Pretty-print a flattened tree observation as a nested dict."""
        pprint.pprint(self.unfold_observation_tree(tree), indent=4)

    def unfold_observation_tree(self, tree, current_depth=0, actions_for_display=True):
        """
        Utility function to pretty-print tree observations returned by this object.
        """
        if len(tree) < self.observation_dim:
468
469
470
            return

        depth = 0
471
        tmp = len(tree) / self.observation_dim - 1
472
473
474
475
476
477
        pow4 = 4
        while tmp > 0:
            tmp -= pow4
            depth += 1
            pow4 *= 4

478
479
480
481
482
483
484
485
486
487
488
489
490
491
        unfolded = {}
        unfolded[''] = tree[0:self.observation_dim]
        child_size = (len(tree) - self.observation_dim) // 4
        for child in range(4):
            child_tree = tree[(self.observation_dim + child * child_size):
                              (self.observation_dim + (child + 1) * child_size)]
            observation_tree = self.unfold_observation_tree(child_tree, current_depth=current_depth + 1)
            if observation_tree is not None:
                if actions_for_display:
                    label = self.tree_explorted_actions_char[child]
                else:
                    label = self.tree_explored_actions[child]
                unfolded[label] = observation_tree
        return unfolded
492

493
494
495
496
497
    def _set_env(self, env):
        self.env = env
        if self.predictor:
            self.predictor._set_env(self.env)

498
499
500
    def _reverse_dir(self, direction):
        return int((direction + 2) % 4)

501
502
503
504
505
506
507
508
509

class GlobalObsForRailEnv(ObservationBuilder):
    """
    Gives a global observation of the entire rail environment.
    The observation is composed of the following elements:

        - transition map array with dimensions (env.height, env.width, 16),
          assuming 16 bits encoding of transitions.

        - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent
          target and the positions of the other agents targets.

        - A 3D array (map_height, map_width, 4) with
            - first channel containing the agents position and direction
            - second channel containing the other agents positions and directions
            - third channel containing agent malfunctions
            - fourth channel containing agent fractional speeds
    """

    def __init__(self):
        self.observation_space = ()
        super(GlobalObsForRailEnv, self).__init__()

    def _set_env(self, env):
        super()._set_env(env)
        self.observation_space = [4, self.env.height, self.env.width]

    def reset(self):
        """Encode the full transition map once per episode as a (h, w, 16) bit array."""
        self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
        for row in range(self.rail_obs.shape[0]):
            for col in range(self.rail_obs.shape[1]):
                bits = [int(digit) for digit in bin(self.env.rail.get_full_transitions(row, col))[2:]]
                # Left-pad to a fixed 16-bit encoding.
                self.rail_obs[row, col] = np.array([0] * (16 - len(bits)) + bits)

    def get(self, handle):
        """Return (rail_obs, obs_agents_state, obs_targets) for agent `handle`."""
        obs_targets = np.zeros((self.env.height, self.env.width, 2))
        obs_agents_state = np.zeros((self.env.height, self.env.width, 4))
        agents = self.env.agents
        agent = agents[handle]

        # Own position/direction in channel 0, own target in target-channel 0.
        obs_agents_state[agent.position][0] = agent.direction
        obs_targets[agent.target][0] = 1

        for idx, other in enumerate(agents):
            if idx != handle:  # TODO: handle used as index...?
                # Other agents go to channel 1 of both state and target maps.
                obs_agents_state[other.position][1] = other.direction
                obs_targets[other.target][1] = 1
            # Malfunction and speed channels include every agent, self included.
            obs_agents_state[other.position][2] = other.malfunction_data['malfunction']
            obs_agents_state[other.position][3] = other.speed_data['speed']

        return self.rail_obs, obs_agents_state, obs_targets
556

557
558
559

class LocalObsForRailEnv(ObservationBuilder):
    """
    !!!!!!WARNING!!! THIS IS DEPRECATED AND NOT UPDATED TO FLATLAND 2.0!!!!!

    Gives a local observation of the rail environment around the agent.

    The observation is composed of the following elements:

        - transition map array of the local environment around the given agent,
          with dimensions (view_height, 2 * view_width + 1, 16),
          assuming 16 bits encoding of transitions.

        - Two 2D arrays (view_height, 2 * view_width + 1, 2) containing respectively,
          if they are in the agent's vision range, its target position, the positions of the other targets.

        - A 2D array (view_height, 2 * view_width + 1, 4) containing the one hot encoding of directions
          of the other agents at their position coordinates, if they are in the agent's vision range.

        - A 4 elements array with one hot encoding of the direction.

    Use the parameters view_width and view_height to define the rectangular view of the agent.
    The center parameter moves the agent along the height axis of this rectangle. If it is 0 the agent only has
    observation in front of it.
    """

    def __init__(self, view_width, view_height, center):
        """
        :param view_width: half-width of the view; the view is 2 * view_width + 1 cells wide.
        :param view_height: depth of the view, in cells, along the agent's heading.
        :param center: offset of the agent along the height axis of the view rectangle;
            0 means the agent only observes cells in front of it.
        """
        super(LocalObsForRailEnv, self).__init__()
        self.view_width = view_width
        self.view_height = view_height
        self.center = center
        self.max_padding = max(self.view_width, self.view_height - self.center)

    def reset(self):
        # We build the transition map with a view_radius empty cells expansion on each side.
        # This helps to collect the local transition map view when the agent is close to a border.
        # NOTE(review): this overwrites the padding computed in __init__ without the
        # `- self.center` term -- confirm which formula is intended.
        self.max_padding = max(self.view_width, self.view_height)
        self.rail_obs = np.zeros((self.env.height,
                                  self.env.width, 16))
        for i in range(self.env.height):
            for j in range(self.env.width):
                # 16-bit, zero-left-padded binary expansion of the cell's transition word.
                bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
                bitlist = [0] * (16 - len(bitlist)) + bitlist
                self.rail_obs[i, j] = np.array(bitlist)

    def get(self, handle):
        """
        Compute the local observation for agent `handle`.

        Returns (local_rail_obs, obs_map_state, obs_other_agents_state, direction):
            - local_rail_obs (view_height, 2 * view_width + 1, 16): transition bits
              of the visible cells,
            - obs_map_state (..., 2): channel 0 marks the agent's own target,
              channel 1 the other agents' targets (when visible),
            - obs_other_agents_state (..., 4): one-hot direction of other agents
              at their visible positions,
            - direction: one-hot encoding of this agent's own direction.
        """
        agents = self.env.agents
        agent = agents[handle]

        # Collect visible cells as set to be plotted.
        visited, rel_coords = self.field_of_view(agent.position, agent.direction)

        # Add the visible cells to the observed cells (for rendering/debugging).
        self.env.dev_obs_dict[handle] = set(visited)

        # Locate observed agents and their corresponding targets.
        local_rail_obs = np.zeros((self.view_height, 2 * self.view_width + 1, 16))
        obs_map_state = np.zeros((self.view_height, 2 * self.view_width + 1, 2))
        obs_other_agents_state = np.zeros((self.view_height, 2 * self.view_width + 1, 4))
        _idx = 0
        for pos in visited:
            curr_rel_coord = rel_coords[_idx]
            local_rail_obs[curr_rel_coord[0], curr_rel_coord[1], :] = self.rail_obs[pos[0], pos[1], :]
            if pos == agent.target:
                obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 0] = 1
            else:
                for tmp_agent in agents:
                    if pos == tmp_agent.target:
                        obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 1] = 1
            if pos != agent.position:
                for tmp_agent in agents:
                    if pos == tmp_agent.position:
                        obs_other_agents_state[curr_rel_coord[0], curr_rel_coord[1], :] = np.identity(4)[
                            tmp_agent.direction]

            _idx += 1

        direction = np.identity(4)[agent.direction]
        return local_rail_obs, obs_map_state, obs_other_agents_state, direction

    def get_many(self, handles=None):
        """
        Called whenever an observation has to be computed for the `env' environment, for each agent with handle
        in the `handles' list.
        """
        # FIX: iterating over the default `handles=None` raised TypeError;
        # treat None as "no handles requested".
        if handles is None:
            handles = []
        observations = {}
        for h in handles:
            observations[h] = self.get(h)
        return observations

    def field_of_view(self, position, direction, state=None):
        """Compute the rectangular field of view for an agent at `position`
        heading `direction`.

        Returns (visible, rel_coords): the in-bounds absolute grid coordinates
        and their matching (h, w) coordinates relative to the view rectangle.
        When `state` is given, returns instead a (view_height, 2*view_width+1, 16)
        array (data-collection mode; currently left all-zero).
        """
        data_collection = False
        if state is not None:
            temp_visible_data = np.zeros(shape=(self.view_height, 2 * self.view_width + 1, 16))
            data_collection = True
        # Origin (corner) of the view rectangle depends on the heading;
        # presumably 0/1/2/3 follow the Grid4 direction convention -- TODO confirm.
        if direction == 0:
            origin = (position[0] + self.center, position[1] - self.view_width)
        elif direction == 1:
            origin = (position[0] - self.view_width, position[1] - self.center)
        elif direction == 2:
            origin = (position[0] - self.center, position[1] + self.view_width)
        else:
            origin = (position[0] + self.view_width, position[1] + self.center)
        visible = list()
        rel_coords = list()
        # Sweep the rectangle; each heading scans the grid in a different
        # orientation, keeping only cells that fall inside the environment.
        for h in range(self.view_height):
            for w in range(2 * self.view_width + 1):
                if direction == 0:
                    if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width:
                        visible.append((origin[0] - h, origin[1] + w))
                        rel_coords.append((h, w))
                elif direction == 1:
                    if 0 <= origin[0] + w < self.env.height and 0 <= origin[1] + h < self.env.width:
                        visible.append((origin[0] + w, origin[1] + h))
                        rel_coords.append((h, w))
                elif direction == 2:
                    if 0 <= origin[0] + h < self.env.height and 0 <= origin[1] - w < self.env.width:
                        visible.append((origin[0] + h, origin[1] - w))
                        rel_coords.append((h, w))
                else:
                    if 0 <= origin[0] - w < self.env.height and 0 <= origin[1] - h < self.env.width:
                        visible.append((origin[0] - w, origin[1] - h))
                        rel_coords.append((h, w))
        if data_collection:
            return temp_visible_data
        else:
            return visible, rel_coords