"""
Collection of environment-specific ObservationBuilder.
"""
import collections
from typing import Optional, List, Dict, Tuple

import numpy as np

from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.grid.grid_utils import coordinate_to_position
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.step_utils.states import TrainState
from flatland.utils.ordered_set import OrderedSet


Node = collections.namedtuple('Node', 'dist_own_target_encountered '
                                        'dist_other_target_encountered '
                                        'dist_other_agent_encountered '
                                        'dist_potential_conflict '
                                        'dist_unusable_switch '
                                        'dist_to_next_branch '
                                        'dist_min_to_target '
                                        'num_agents_same_direction '
                                        'num_agents_opposite_direction '
                                        'num_agents_malfunctioning '
                                        'speed_min_fractional '
                                        'num_agents_ready_to_depart '
                                        'childs')

class TreeObsForRailEnv(ObservationBuilder):
    """
    TreeObsForRailEnv object.

    This object returns observation vectors for agents in the RailEnv environment.
    The information is local to each agent and exploits the graph structure of the rail
    network to simplify the representation of the state of the environment for each agent.

    For details about the features in the tree observation see the get() function.
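
    Example (a minimal usage sketch; the constructor arguments shown here are
    illustrative, not prescriptive)::

        from flatland.envs.rail_env import RailEnv
        from flatland.envs.predictions import ShortestPathPredictorForRailEnv

        obs_builder = TreeObsForRailEnv(
            max_depth=2, predictor=ShortestPathPredictorForRailEnv(max_depth=10))
        env = RailEnv(width=30, height=30, obs_builder_object=obs_builder)
        obs, info = env.reset()  # obs maps agent handles to Node trees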
    """

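    # Labels of the four child branches, ordered relative to the agent's heading: Left, Forward, Right, Back.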
    tree_explored_actions_char = ['L', 'F', 'R', 'B']

    def __init__(self, max_depth: int, predictor: PredictionBuilder = None):
        super().__init__()
        self.max_depth = max_depth
        self.observation_dim = 11
        self.location_has_agent = {}
        self.location_has_agent_direction = {}
        self.predictor = predictor
        self.location_has_target = None

    def reset(self):
        self.location_has_target = {tuple(agent.target): 1 for agent in self.env.agents}

    def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Node]:
        """
        Called whenever an observation has to be computed for the `env` environment, for each agent with handle
        in the `handles` list.
        """

        if handles is None:
            handles = []
        if self.predictor:
            self.max_prediction_depth = 0
            self.predicted_pos = {}
            self.predicted_dir = {}
            self.predictions = self.predictor.get()
            if self.predictions:
                for t in range(self.predictor.max_depth + 1):
                    pos_list = []
                    dir_list = []
                    for a in handles:
                        if self.predictions[a] is None:
                            continue
                        pos_list.append(self.predictions[a][t][1:3])
                        dir_list.append(self.predictions[a][t][3])
                    self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
                    self.predicted_dir.update({t: dir_list})
                self.max_prediction_depth = len(self.predicted_pos)
        # Update local lookup table for all agents' positions
        # ignore other agents not in the grid (only status active and done)
        # self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents if
        #                         agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]}

        self.location_has_agent = {}
        self.location_has_agent_direction = {}
        self.location_has_agent_speed = {}
        self.location_has_agent_malfunction = {}
        self.location_has_agent_ready_to_depart = {}

        for _agent in self.env.agents:
            if not TrainState.off_map_state(_agent.state) and \
                    _agent.position:
                self.location_has_agent[tuple(_agent.position)] = 1
                self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction
                self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_data['speed']
                self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data[
                    'malfunction']

            # Count agents that are ready to depart but still off the map,
            # keyed by their initial position.
            if TrainState.off_map_state(_agent.state) and \
                    _agent.initial_position:
                self.location_has_agent_ready_to_depart.setdefault(tuple(_agent.initial_position), 0)
                self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] += 1
                # self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \
                #     self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1

        observations = super().get_many(handles)

        return observations

    def get(self, handle: int = 0) -> Node:
        """
        Computes the current observation for agent `handle` in env

        The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible
        movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv).
        The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for
        the transitions. The order is::

            [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']

        Each branch data is organized as::

            [root node information] +
            [recursive branch data from 'left'] +
            [... from 'forward'] +
            [... from 'right'] +
            [... from 'back']

        Each node information is composed of 12 features:

        #1:
            if own target lies on the explored branch, the current distance from the agent in number of cells is stored.

        #2:
            if another agent's target is detected, the distance in number of cells from the agent's current location \
            is stored

        #3:
            if another agent is detected, the distance in number of cells from the current agent position is stored.

        #4:
            possible conflict detected
            tot_dist = another agent is predicted to pass along this cell at the same time as the agent; we store the \
             distance in number of cells from the current agent position

            0 = no other agent reserves the same cell at a similar time

        #5:
            if a switch that is unusable for the agent is detected, we store the distance.

        #6:
            This feature stores the distance in number of cells to the next branching point (current node)

        #7:
            minimum distance from node to the agent's target given the direction of the agent if this path is chosen

        #8:
            agent in the same direction
            n = number of agents present in the same direction \
                (possible future use: number of other agents in the same direction in this branch)
            0 = no agent present in the same direction

        #9:
            agent in the opposite direction
            n = number of agents present in the opposite direction to myself (i.e. potential conflict) \
                (possible future use: number of other agents in the opposite direction in this branch, i.e. number of conflicts)
            0 = no agent present in the opposite direction

        #10:
            malfunctioning/blocking agents
            n = number of time steps the observed agent remains blocked

        #11:
            slowest observed speed of an agent in the same direction
            1 if no agent is observed

            min_fractional speed otherwise

        #12:
            number of agents ready to depart but not yet active

        Missing/padding nodes are filled in with -inf (truncated).
        Missing values in present node are filled in with +inf (truncated).


        In case of the root node, the values are [0, 0, 0, 0, distance from agent to target, own malfunction, own speed]
        In case the target node is reached, the values are [0, 0, 0, 0, 0].
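
        Example (a sketch of reading the returned tree; ``-np.inf`` marks branches
        that could not be explored)::

            root = obs_builder.get(handle=0)
            forward = root.childs['F']
            if forward != -np.inf:
                print(forward.dist_min_to_target, forward.num_agents_opposite_direction)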
        """

        if handle >= len(self.env.agents):
            print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
        agent = self.env.agents[handle]  # TODO: handle being treated as index

        if TrainState.off_map_state(agent.state):
            agent_virtual_position = agent.initial_position
        elif TrainState.on_map_state(agent.state):
            agent_virtual_position = agent.position
        elif agent.state == TrainState.DONE:
            agent_virtual_position = agent.target
        else:
            return None

        possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
        num_transitions = np.count_nonzero(possible_transitions)

        # Here information about the agent itself is stored
        distance_map = self.env.distance_map.get()

        # was referring to TreeObsForRailEnv.Node
        root_node_observation = Node(dist_own_target_encountered=0, dist_other_target_encountered=0,
                                     dist_other_agent_encountered=0, dist_potential_conflict=0,
                                     dist_unusable_switch=0, dist_to_next_branch=0,
                                     dist_min_to_target=distance_map[
                                         (handle, *agent_virtual_position,
                                          agent.direction)],
                                     num_agents_same_direction=0, num_agents_opposite_direction=0,
                                     num_agents_malfunctioning=agent.malfunction_data['malfunction'],
                                     speed_min_fractional=agent.speed_data['speed'],
                                     num_agents_ready_to_depart=0,
                                     childs={})
        # print("root node type:", type(root_node_observation))

        visited = OrderedSet()

        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right, back], relative to the current orientation
        # If only one transition is possible, the tree is oriented with this transition as the forward branch.
        orientation = agent.direction

        if num_transitions == 1:
            orientation = np.argmax(possible_transitions)

        for i, branch_direction in enumerate([(orientation + i) % 4 for i in range(-1, 3)]):

            if possible_transitions[branch_direction]:
                new_cell = get_new_position(agent_virtual_position, branch_direction)

                branch_observation, branch_visited = \
                    self._explore_branch(handle, new_cell, branch_direction, 1, 1)
                root_node_observation.childs[self.tree_explored_actions_char[i]] = branch_observation

                visited |= branch_visited
            else:
                # add cells filled with infinity if no transition is possible
                root_node_observation.childs[self.tree_explored_actions_char[i]] = -np.inf
        self.env.dev_obs_dict[handle] = visited

        return root_node_observation

    def _explore_branch(self, handle, position, direction, tot_dist, depth):
        """
        Utility function to compute tree-based observations.

        We walk along the branch and collect the information documented in the get() function.
        If there is a branching point, a new node is created and each possible branch is explored.
        """

        # [Recursive branch opened]
        if depth >= self.max_depth + 1:
            return [], []

        # Continue along direction until next switch or
        # until no transitions are possible along the current direction (i.e., dead-ends)
        # We treat dead-ends as nodes, instead of going back, to avoid loops
        exploring = True
        last_is_switch = False
        last_is_dead_end = False
        last_is_terminal = False  # wrong cell OR cycle; either way, we don't want the agent to land here
        last_is_target = False

        visited = OrderedSet()
        agent = self.env.agents[handle]
        time_per_cell = np.reciprocal(agent.speed_data["speed"])
        own_target_encountered = np.inf
        other_agent_encountered = np.inf
        other_target_encountered = np.inf
        potential_conflict = np.inf
        unusable_switch = np.inf
        other_agent_same_direction = 0
        other_agent_opposite_direction = 0
        malfunctioning_agent = 0
        min_fractional_speed = 1.
        num_steps = 1
        other_agent_ready_to_depart_encountered = 0
        while exploring:
            # #############################
            # #############################
            # Modify here to compute any useful data required to build the end node's features. This code is called
            # for each cell visited between the previous branching node and the next switch / target / dead-end.
            if position in self.location_has_agent:
                if tot_dist < other_agent_encountered:
                    other_agent_encountered = tot_dist

                # Check if any of the observed agents is malfunctioning, store agent with longest duration left
                if self.location_has_agent_malfunction[position] > malfunctioning_agent:
                    malfunctioning_agent = self.location_has_agent_malfunction[position]

                other_agent_ready_to_depart_encountered += self.location_has_agent_ready_to_depart.get(position, 0)

                if self.location_has_agent_direction[position] == direction:
                    # Accumulate the number of agents on this branch heading in the same direction
                    other_agent_same_direction += 1

                    # Check fractional speed of agents
                    current_fractional_speed = self.location_has_agent_speed[position]
                    if current_fractional_speed < min_fractional_speed:
                        min_fractional_speed = current_fractional_speed

                else:
                    # If no agent in the same direction was found, all agents at that position
                    # are heading in the opposite direction.
                    # Attention: this may count too many agents, as a few might be branching off at a switch.
                    other_agent_opposite_direction += self.location_has_agent[position]

            # Check number of possible transitions for agent and total number of transitions in cell (type)
            cell_transitions = self.env.rail.get_transitions(*position, direction)
            transition_bit = bin(self.env.rail.get_full_transitions(*position))
            total_transitions = transition_bit.count("1")
            crossing_found = False
            if int(transition_bit, 2) == int('1000010000100001', 2):
                crossing_found = True

            # Register possible future conflict
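            # The predictor supplies each agent's expected cell at every future time step.
            # Below, this branch's distance is converted into an expected arrival time
            # (tot_dist scaled by the agent's time per cell) and compared against other
            # agents' predicted positions at that step and one step before/after it.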
            predicted_time = int(tot_dist * time_per_cell)
            if self.predictor and predicted_time < self.max_prediction_depth:
                int_position = coordinate_to_position(self.env.width, [position])
                if tot_dist < self.max_prediction_depth:

                    pre_step = max(0, predicted_time - 1)
                    post_step = min(self.max_prediction_depth - 1, predicted_time + 1)

                    # Look for conflicting paths at distance tot_dist
                    if int_position in np.delete(self.predicted_pos[predicted_time], handle, 0):
                        conflicting_agent = np.where(self.predicted_pos[predicted_time] == int_position)
                        for ca in conflicting_agent[0]:
                            if direction != self.predicted_dir[predicted_time][ca] \
                                and cell_transitions[self._reverse_dir(self.predicted_dir[predicted_time][ca])] == 1 \
                                and tot_dist < potential_conflict:  # noqa: E125
                                potential_conflict = tot_dist
                            if self.env.agents[ca].state == TrainState.DONE and tot_dist < potential_conflict:
                                potential_conflict = tot_dist

                    # Look for conflicting paths at distance predicted_time - 1
                    elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0):
                        conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
                        for ca in conflicting_agent[0]:
                            if direction != self.predicted_dir[pre_step][ca] \
                                and cell_transitions[self._reverse_dir(self.predicted_dir[pre_step][ca])] == 1 \
                                and tot_dist < potential_conflict:  # noqa: E125
                                potential_conflict = tot_dist
                            if self.env.agents[ca].state == TrainState.DONE and tot_dist < potential_conflict:
                                potential_conflict = tot_dist

                    # Look for conflicting paths at distance predicted_time + 1
                    elif int_position in np.delete(self.predicted_pos[post_step], handle, 0):
                        conflicting_agent = np.where(self.predicted_pos[post_step] == int_position)
                        for ca in conflicting_agent[0]:
                            if direction != self.predicted_dir[post_step][ca] \
                                and cell_transitions[self._reverse_dir(self.predicted_dir[post_step][ca])] == 1 \
                                and tot_dist < potential_conflict:  # noqa: E125
                                potential_conflict = tot_dist
                            if self.env.agents[ca].state == TrainState.DONE and tot_dist < potential_conflict:
                                potential_conflict = tot_dist

            if position in self.location_has_target and position != agent.target:
                if tot_dist < other_target_encountered:
                    other_target_encountered = tot_dist

            if position == agent.target and tot_dist < own_target_encountered:
                own_target_encountered = tot_dist

            # #############################
            # #############################
            if (position[0], position[1], direction) in visited:
                last_is_terminal = True
                break
            visited.add((position[0], position[1], direction))

            # If the target node is encountered, pick that as node. Also, no further branching is possible.
            if np.array_equal(position, self.env.agents[handle].target):
                last_is_target = True
                break

            # Check if crossing is found --> Not an unusable switch
            if crossing_found:
                # Treat the crossing as a straight rail cell
                total_transitions = 2
            num_transitions = np.count_nonzero(cell_transitions)

            exploring = False

            # Detect Switches that can only be used by other agents.
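            # (total_transitions > 2 means the cell is a switch, while num_transitions < 2
            # means the agent itself has no choice of branch here.)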
            if total_transitions > 2 > num_transitions and tot_dist < unusable_switch:
                unusable_switch = tot_dist

            if num_transitions == 1:
                # Check if dead-end, or if we can go forward along direction
                nbits = total_transitions
                if nbits == 1:
                    # Dead-end!
                    last_is_dead_end = True

                if not last_is_dead_end:
                    # Keep walking through the tree along `direction`
                    exploring = True
                    # convert one-hot encoding to 0,1,2,3
                    direction = np.argmax(cell_transitions)
                    position = get_new_position(position, direction)
                    num_steps += 1
                    tot_dist += 1
            elif num_transitions > 0:
                # Switch detected
                last_is_switch = True
                break

            elif num_transitions == 0:
                # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
                print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0],
                      position[1], direction)
                last_is_terminal = True
                break

        # `position` is either a terminal node or a switch

        # #############################
        # #############################
        # Modify here to append new / different features for each visited cell!

        if last_is_target:
            dist_to_next_branch = tot_dist
            dist_min_to_target = 0
        elif last_is_terminal:
            dist_to_next_branch = np.inf
            dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction]
        else:
            dist_to_next_branch = tot_dist
            dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction]

        # TreeObsForRailEnv.Node
        node = Node(dist_own_target_encountered=own_target_encountered,
                    dist_other_target_encountered=other_target_encountered,
                    dist_other_agent_encountered=other_agent_encountered,
                    dist_potential_conflict=potential_conflict,
                    dist_unusable_switch=unusable_switch,
                    dist_to_next_branch=dist_to_next_branch,
                    dist_min_to_target=dist_min_to_target,
                    num_agents_same_direction=other_agent_same_direction,
                    num_agents_opposite_direction=other_agent_opposite_direction,
                    num_agents_malfunctioning=malfunctioning_agent,
                    speed_min_fractional=min_fractional_speed,
                    num_agents_ready_to_depart=other_agent_ready_to_depart_encountered,
                    childs={})

        # #############################
        # #############################
        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right, back], relative to the current orientation
        # Get the possible transitions
        possible_transitions = self.env.rail.get_transitions(*position, direction)
        for i, branch_direction in enumerate([(direction + 4 + i) % 4 for i in range(-1, 3)]):
            if last_is_dead_end and self.env.rail.get_transition((*position, direction),
                                                                 (branch_direction + 2) % 4):
                # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
                # it back
                new_cell = get_new_position(position, (branch_direction + 2) % 4)
                branch_observation, branch_visited = self._explore_branch(handle,
                                                                          new_cell,
                                                                          (branch_direction + 2) % 4,
                                                                          tot_dist + 1,
                                                                          depth + 1)
                node.childs[self.tree_explored_actions_char[i]] = branch_observation
                if len(branch_visited) != 0:
                    visited |= branch_visited
            elif last_is_switch and possible_transitions[branch_direction]:
                new_cell = get_new_position(position, branch_direction)
                branch_observation, branch_visited = self._explore_branch(handle,
                                                                          new_cell,
                                                                          branch_direction,
                                                                          tot_dist + 1,
                                                                          depth + 1)
                node.childs[self.tree_explored_actions_char[i]] = branch_observation
                if len(branch_visited) != 0:
                    visited |= branch_visited
            else:
                # no exploring possible, add just cells with infinity
                node.childs[self.tree_explored_actions_char[i]] = -np.inf

        if depth == self.max_depth:
            node.childs.clear()
        return node, visited

    def util_print_obs_subtree(self, tree: Node):
        """
        Utility function to print tree observations returned by this object.
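
        Example (assuming ``obs`` is a Node returned by ``get()``)::

            obs_builder.util_print_obs_subtree(tree=obs)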
        """
        self.print_node_features(tree, "root", "")
        for direction in self.tree_explored_actions_char:
            self.print_subtree(tree.childs[direction], direction, "\t")

    @staticmethod
    def print_node_features(node: Node, label, indent):
        print(indent, "Direction ", label, ": ", node.dist_own_target_encountered, ", ",
              node.dist_other_target_encountered, ", ", node.dist_other_agent_encountered, ", ",
              node.dist_potential_conflict, ", ", node.dist_unusable_switch, ", ", node.dist_to_next_branch, ", ",
              node.dist_min_to_target, ", ", node.num_agents_same_direction, ", ", node.num_agents_opposite_direction,
              ", ", node.num_agents_malfunctioning, ", ", node.speed_min_fractional, ", ",
              node.num_agents_ready_to_depart)

    def print_subtree(self, node, label, indent):
        if node == -np.inf or not node:
            print(indent, "Direction ", label, ": -np.inf")
            return

        self.print_node_features(node, label, indent)

        if not node.childs:
            return

        for direction in self.tree_explored_actions_char:
            self.print_subtree(node.childs[direction], direction, indent + "\t")

    def set_env(self, env: Environment):
        super().set_env(env)
        if self.predictor:
            self.predictor.set_env(self.env)

    def _reverse_dir(self, direction):
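        """Return the opposite cardinal direction (N <-> S, E <-> W)."""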
        return int((direction + 2) % 4)


class GlobalObsForRailEnv(ObservationBuilder):
    """
    Gives a global observation of the entire rail environment.
    The observation is composed of the following elements:

        - transition map array with dimensions (env.height, env.width, 16),\
          assuming 16 bits encoding of transitions.

        - obs_agents_state: A 3D array (map_height, map_width, 5) with
            - first channel containing the agent's position and direction
            - second channel containing the other agents' positions and directions
            - third channel containing agent/other agent malfunctions
            - fourth channel containing agent/other agent fractional speeds
            - fifth channel containing number of other agents ready to depart

        - obs_targets: A 3D array (map_height, map_width, 2) containing respectively the position of the given \
          agent's target and the positions of the other agents' targets (flag only, no counter!).
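
        Example (a sketch of reading the returned arrays; variable names are illustrative)::

            transition_map, obs_agents_state, obs_targets = obs_builder.get(handle=0)
            own_direction = obs_agents_state[:, :, 0]   # -1 where the agent is absent
            ready_to_depart = obs_agents_state[:, :, 4]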
    """

    def __init__(self):
        super(GlobalObsForRailEnv, self).__init__()

    def set_env(self, env: Environment):
        super().set_env(env)

    def reset(self):
        self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
        for i in range(self.rail_obs.shape[0]):
            for j in range(self.rail_obs.shape[1]):
                bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
                bitlist = [0] * (16 - len(bitlist)) + bitlist
                self.rail_obs[i, j] = np.array(bitlist)

    def get(self, handle: int = 0) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:

        agent = self.env.agents[handle]
        if TrainState.off_map_state(agent.state):
            agent_virtual_position = agent.initial_position
        elif TrainState.on_map_state(agent.state):
            agent_virtual_position = agent.position
        elif agent.state == TrainState.DONE:
            agent_virtual_position = agent.target
        else:
            return None

        obs_targets = np.zeros((self.env.height, self.env.width, 2))
        obs_agents_state = np.zeros((self.env.height, self.env.width, 5)) - 1

        # TODO can we do this more elegantly?
        # for r in range(self.env.height):
        #     for c in range(self.env.width):
        #         obs_agents_state[(r, c)][4] = 0
        obs_agents_state[:, :, 4] = 0

        obs_agents_state[agent_virtual_position][0] = agent.direction
        obs_targets[agent.target][0] = 1

        for i in range(len(self.env.agents)):
            other_agent: EnvAgent = self.env.agents[i]

            # ignore other agents not in the grid any more
            if other_agent.state == TrainState.DONE:
                continue

            obs_targets[other_agent.target][1] = 1

            # second to fourth channel only if in the grid
            if other_agent.position is not None:
                # second channel only for other agents
                if i != handle:
                    obs_agents_state[other_agent.position][1] = other_agent.direction
                obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction']
                obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed']
            # fifth channel: all ready to depart on this position
            if TrainState.off_map_state(other_agent.state):
                obs_agents_state[other_agent.initial_position][4] += 1
        return self.rail_obs, obs_agents_state, obs_targets


class LocalObsForRailEnv(ObservationBuilder):
    """
    !!!!!!WARNING!!! THIS IS DEPRECATED AND NOT UPDATED TO FLATLAND 2.0!!!!!
    Gives a local observation of the rail environment around the agent.
    The observation is composed of the following elements:

        - transition map array of the local environment around the given agent, \
          with dimensions (view_height,2*view_width+1, 16), \
          assuming 16 bits encoding of transitions.

        - Two 2D arrays (view_height,2*view_width+1, 2) containing respectively, \
        if they are in the agent's vision range, its target position, the positions of the other targets.

        - A 2D array (view_height,2*view_width+1, 4) containing the one hot encoding of directions \
          of the other agents at their position coordinates, if they are in the agent's vision range.

        - A 4 elements array with one hot encoding of the direction.

    Use the parameters view_width and view_height to define the rectangular view of the agent.
    The center parameter moves the agent along the height axis of this rectangle. If it is 0, the agent
    only has observations in front of it.

    .. deprecated:: 2.0.0
    """

    def __init__(self, view_width, view_height, center):

        super(LocalObsForRailEnv, self).__init__()
        self.view_width = view_width
        self.view_height = view_height
        self.center = center
        self.max_padding = max(self.view_width, self.view_height - self.center)

    def reset(self):
        # We build the transition map with a view_radius empty cells expansion on each side.
        # This helps to collect the local transition map view when the agent is close to a border.
        self.max_padding = max(self.view_width, self.view_height)
        self.rail_obs = np.zeros((self.env.height,
                                  self.env.width, 16))
        for i in range(self.env.height):
            for j in range(self.env.width):
                bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
                bitlist = [0] * (16 - len(bitlist)) + bitlist
                self.rail_obs[i, j] = np.array(bitlist)

    def get(self, handle: int = 0) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        agents = self.env.agents
        agent = agents[handle]

        # Correct agents position for padding
        # agent_rel_pos[0] = agent.position[0] + self.max_padding
        # agent_rel_pos[1] = agent.position[1] + self.max_padding

        # Collect visible cells as set to be plotted
        visited, rel_coords = self.field_of_view(agent.position, agent.direction)
        local_rail_obs = None

        # Add the visible cells to the observed cells
        self.env.dev_obs_dict[handle] = set(visited)

        # Locate observed agents and their corresponding targets
        local_rail_obs = np.zeros((self.view_height, 2 * self.view_width + 1, 16))
        obs_map_state = np.zeros((self.view_height, 2 * self.view_width + 1, 2))
        obs_other_agents_state = np.zeros((self.view_height, 2 * self.view_width + 1, 4))
        _idx = 0
        for pos in visited:
            curr_rel_coord = rel_coords[_idx]
            local_rail_obs[curr_rel_coord[0], curr_rel_coord[1], :] = self.rail_obs[pos[0], pos[1], :]
            if pos == agent.target:
                obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 0] = 1
            else:
                for tmp_agent in agents:
                    if pos == tmp_agent.target:
                        obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 1] = 1
            if pos != agent.position:
                for tmp_agent in agents:
                    if pos == tmp_agent.position:
                        obs_other_agents_state[curr_rel_coord[0], curr_rel_coord[1], :] = np.identity(4)[
                            tmp_agent.direction]

            _idx += 1

        direction = np.identity(4)[agent.direction]
        return local_rail_obs, obs_map_state, obs_other_agents_state, direction

    def get_many(self, handles: Optional[List[int]] = None) -> Dict[
        int, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
        """
        Called whenever an observation has to be computed for the `env` environment, for each agent with handle
        in the `handles` list.
        """

        return super().get_many(handles)

    def field_of_view(self, position, direction, state=None):
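        """
        Compute the rectangular field of view of an agent at ``position`` facing ``direction``.
        Returns the list of visible grid cells and their coordinates relative to the view;
        if ``state`` is given, the collected per-cell state data is returned instead.
        """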
        # Compute the local field of view for an agent in the environment
        data_collection = False
        if state is not None:
            temp_visible_data = np.zeros(shape=(self.view_height, 2 * self.view_width + 1, 16))
            data_collection = True
        if direction == 0:
            origin = (position[0] + self.center, position[1] - self.view_width)
        elif direction == 1:
            origin = (position[0] - self.view_width, position[1] - self.center)
        elif direction == 2:
            origin = (position[0] - self.center, position[1] + self.view_width)
        else:
            origin = (position[0] + self.view_width, position[1] + self.center)
        visible = list()
        rel_coords = list()
        for h in range(self.view_height):
            for w in range(2 * self.view_width + 1):
                if direction == 0:
                    if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width:
                        visible.append((origin[0] - h, origin[1] + w))
                        rel_coords.append((h, w))
                    # if data_collection:
                    #    temp_visible_data[h, w, :] = state[origin[0] - h, origin[1] + w, :]
                elif direction == 1:
                    if 0 <= origin[0] + w < self.env.height and 0 <= origin[1] + h < self.env.width:
                        visible.append((origin[0] + w, origin[1] + h))
                        rel_coords.append((h, w))
                    # if data_collection:
                    #    temp_visible_data[h, w, :] = state[origin[0] + w, origin[1] + h, :]
                elif direction == 2:
                    if 0 <= origin[0] + h < self.env.height and 0 <= origin[1] - w < self.env.width:
                        visible.append((origin[0] + h, origin[1] - w))
                        rel_coords.append((h, w))
                    # if data_collection:
                    #    temp_visible_data[h, w, :] = state[origin[0] + h, origin[1] - w, :]
                else:
                    if 0 <= origin[0] - w < self.env.height and 0 <= origin[1] - h < self.env.width:
                        visible.append((origin[0] - w, origin[1] - h))
                        rel_coords.append((h, w))
                    # if data_collection:
                    #    temp_visible_data[h, w, :] = state[origin[0] - w, origin[1] - h, :]
        if data_collection:
            return temp_visible_data
        else:
            return visible, rel_coords