observations.py 36.4 KB
Newer Older
1 2 3
"""
Collection of environment-specific ObservationBuilder.
"""
4 5
import collections
from typing import Optional, List, Dict, Tuple
6

u214892's avatar
u214892 committed
7 8
import numpy as np

9
from flatland.core.env import Environment
10
from flatland.core.env_observation_builder import ObservationBuilder
11
from flatland.core.env_prediction_builder import PredictionBuilder
12
from flatland.core.grid.grid4_utils import get_new_position
u214892's avatar
u214892 committed
13
from flatland.core.grid.grid_utils import coordinate_to_position
14 15
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.step_utils.states import TrainState
u214892's avatar
u214892 committed
16
from flatland.utils.ordered_set import OrderedSet
17 18


19

20 21 22 23 24 25 26 27 28 29 30 31 32 33
# A single node of the observation tree: 12 scalar features plus the dict of
# child nodes ('childs'), keyed by the action characters in
# TreeObsForRailEnv.tree_explored_actions_char.
Node = collections.namedtuple(
    'Node',
    [
        'dist_own_target_encountered',
        'dist_other_target_encountered',
        'dist_other_agent_encountered',
        'dist_potential_conflict',
        'dist_unusable_switch',
        'dist_to_next_branch',
        'dist_min_to_target',
        'num_agents_same_direction',
        'num_agents_opposite_direction',
        'num_agents_malfunctioning',
        'speed_min_fractional',
        'num_agents_ready_to_depart',
        'childs',
    ],
)

34
class TreeObsForRailEnv(ObservationBuilder):
    """
    TreeObsForRailEnv object.

    This object returns observation vectors for agents in the RailEnv environment.
    The information is local to each agent and exploits the graph structure of the rail
    network to simplify the representation of the state of the environment for each agent.

    For details about the features in the tree observation see the get() function.
    """

    # Action characters used as keys for the children of every tree node,
    # ordered relative to the agent's heading: Left, Forward, Right, Back.
    tree_explored_actions_char = ['L', 'F', 'R', 'B']

    def __init__(self, max_depth: int, predictor: Optional[PredictionBuilder] = None):
        """
        Parameters
        ----------
        max_depth : int
            Maximum depth of the observation tree (number of branching levels explored).
        predictor : PredictionBuilder, optional
            Path predictor used to flag potential future conflicts with other agents.
            If None, the conflict feature is never filled in.
        """
        super().__init__()
        self.max_depth = max_depth
        # Number of scalar features stored per tree node (see get()).
        self.observation_dim = 11
        # Per-cell lookup tables; rebuilt on every get_many() call.
        self.location_has_agent = {}
        self.location_has_agent_direction = {}
        self.predictor = predictor
        # Maps target cells to 1; filled in by reset().
        self.location_has_target = None
spiglerg's avatar
spiglerg committed
56

57
    def reset(self):
58
        self.location_has_target = {tuple(agent.target): 1 for agent in self.env.agents}
59

60
    def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Node]:
        """
        Called whenever an observation has to be computed for the `env` environment, for each agent with handle
        in the `handles` list.

        First refreshes the prediction tables (if a predictor is configured) and the per-cell
        agent lookup tables, then delegates to the base class, which calls get() per handle.
        """

        if handles is None:
            handles = []
        if self.predictor:
            # Refresh predicted positions/directions for every timestep, indexed by t.
            self.max_prediction_depth = 0
            self.predicted_pos = {}
            self.predicted_dir = {}
            self.predictions = self.predictor.get()
            if self.predictions:
                for t in range(self.predictor.max_depth + 1):
                    pos_list = []
                    dir_list = []
                    for a in handles:
                        # Agents without a prediction (e.g. already done) are skipped.
                        if self.predictions[a] is None:
                            continue
                        # Prediction rows are [t, row, col, direction, ...].
                        pos_list.append(self.predictions[a][t][1:3])
                        dir_list.append(self.predictions[a][t][3])
                    self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
                    self.predicted_dir.update({t: dir_list})
                self.max_prediction_depth = len(self.predicted_pos)
        # Update local lookup tables for all agents' positions; agents that are
        # off the map are tracked separately in the ready-to-depart table below.

        self.location_has_agent = {}
        self.location_has_agent_direction = {}
        self.location_has_agent_speed = {}
        self.location_has_agent_malfunction = {}
        self.location_has_agent_ready_to_depart = {}

        for _agent in self.env.agents:
            # On-map agents: record presence, heading, speed, and remaining malfunction time.
            if not _agent.state.is_off_map_state() and \
                _agent.position:
                self.location_has_agent[tuple(_agent.position)] = 1
                self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction
                self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_counter.speed
                self.location_has_agent_malfunction[tuple(_agent.position)] = \
                        _agent.malfunction_handler.malfunction_down_counter

            # Off-map agents waiting to depart: count them per initial position so that
            # feature #12 (num_agents_ready_to_depart) can be accumulated along a branch.
            if _agent.state.is_off_map_state() and \
                _agent.initial_position:
                    self.location_has_agent_ready_to_depart.setdefault(tuple(_agent.initial_position), 0)
                    self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] += 1

        observations = super().get_many(handles)

        return observations

117
    def get(self, handle: int = 0) -> Node:
        """
        Computes the current observation for agent `handle` in env.

        The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible
        movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv).
        The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for
        the transitions. The order is::

            [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']

        Each branch data is organized as::

            [root node information] +
            [recursive branch data from 'left'] +
            [... from 'forward'] +
            [... from 'right] +
            [... from 'back']

        Each node information is composed of 12 features:

        #1:
            if own target lies on the explored branch the current distance from the agent in number of cells is stored.

        #2:
            if another agent's target is detected the distance in number of cells from the agent's current location\
            is stored

        #3:
            if another agent is detected the distance in number of cells from current agent position is stored.

        #4:
            possible conflict detected
            tot_dist = Other agent predicts to pass along this cell at the same time as the agent, we store the \
             distance in number of cells from current agent position

            0 = No other agent reserves the same cell at similar time

        #5:
            if a switch that is not usable (for the agent) is detected we store the distance.

        #6:
            This feature stores the distance in number of cells to the next branching (current node)

        #7:
            minimum distance from node to the agent's target given the direction of the agent if this path is chosen

        #8:
            agent in the same direction
            n = number of agents present same direction \
                (possible future use: number of other agents in the same direction in this branch)
            0 = no agent present same direction

        #9:
            agent in the opposite direction
            n = number of agents present other direction than myself (so conflict) \
                (possible future use: number of other agents in other direction in this branch, ie. number of conflicts)
            0 = no agent present other direction than myself

        #10:
            malfunctioning/blocking agents
            n = number of time steps the observed agent remains blocked

        #11:
            slowest observed speed of an agent in same direction
            1 if no agent is observed

            min_fractional speed otherwise

        #12:
            number of agents ready to depart but not yet active

        Missing/padding nodes are filled in with -inf (truncated).
        Missing values in present node are filled in with +inf (truncated).


        In case of the root node, the values are [0, 0, 0, 0, distance from agent to target, own malfunction, own speed]
        In case the target node is reached, the values are [0, 0, 0, 0, 0].
        """

        # handle == len(agents) is just as out-of-range as handle > len(agents);
        # the guard previously used '>' and missed that case.
        if handle >= len(self.env.agents):
            print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
        agent = self.env.agents[handle]  # TODO: handle being treated as index

        # Pick the cell the observation is anchored at, depending on the agent's lifecycle state.
        if agent.state.is_off_map_state():
            agent_virtual_position = agent.initial_position
        elif agent.state.is_on_map_state():
            agent_virtual_position = agent.position
        elif agent.state == TrainState.DONE:
            agent_virtual_position = agent.target
        else:
            return None

        possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
        num_transitions = np.count_nonzero(possible_transitions)

        # Here information about the agent itself is stored
        distance_map = self.env.distance_map.get()

        root_node_observation = Node(
            dist_own_target_encountered=0,
            dist_other_target_encountered=0,
            dist_other_agent_encountered=0,
            dist_potential_conflict=0,
            dist_unusable_switch=0,
            dist_to_next_branch=0,
            dist_min_to_target=distance_map[(handle, *agent_virtual_position, agent.direction)],
            num_agents_same_direction=0,
            num_agents_opposite_direction=0,
            num_agents_malfunctioning=agent.malfunction_handler.malfunction_down_counter,
            speed_min_fractional=agent.speed_counter.speed,
            num_agents_ready_to_depart=0,
            childs={})

        visited = OrderedSet()

        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right, back], relative to the current orientation
        # If only one transition is possible, the tree is oriented with this transition as the forward branch.
        orientation = agent.direction

        if num_transitions == 1:
            orientation = np.argmax(possible_transitions)

        for i, branch_direction in enumerate([(orientation + i) % 4 for i in range(-1, 3)]):

            if possible_transitions[branch_direction]:
                new_cell = get_new_position(agent_virtual_position, branch_direction)

                branch_observation, branch_visited = \
                    self._explore_branch(handle, new_cell, branch_direction, 1, 1)
                root_node_observation.childs[self.tree_explored_actions_char[i]] = branch_observation

                visited |= branch_visited
            else:
                # add cells filled with infinity if no transition is possible
                root_node_observation.childs[self.tree_explored_actions_char[i]] = -np.inf
        # Expose the explored cells for rendering/debugging.
        self.env.dev_obs_dict[handle] = visited

        return root_node_observation
255

u214892's avatar
u214892 committed
256
    def _explore_branch(self, handle: int, position: Tuple[int, int], direction: int, tot_dist: int, depth: int):
        """
        Utility function to compute tree-based observations.
        We walk along the branch and collect the information documented in the get() function.
        If there is a branching point a new node is created and each possible branch is explored.

        Returns a (node, visited) pair: `node` summarizes the walked branch and `visited` is an
        OrderedSet of (row, col, direction) triples seen along it. Beyond the maximum depth a
        pair of empty lists is returned instead.
        """

        # [Recursive branch opened]
        if depth >= self.max_depth + 1:
            return [], []

        # Continue along direction until next switch or
        # until no transitions are possible along the current direction (i.e., dead-ends)
        # We treat dead-ends as nodes, instead of going back, to avoid loops
        exploring = True
        last_is_switch = False
        last_is_dead_end = False
        last_is_terminal = False  # wrong cell OR cycle;  either way, we don't want the agent to land here
        last_is_target = False

        visited = OrderedSet()
        agent = self.env.agents[handle]
        # Number of timesteps the agent needs per cell (inverse of fractional speed).
        time_per_cell = np.reciprocal(agent.speed_counter.speed)
        # Feature accumulators; distances default to +inf ("not encountered"),
        # counters to 0, and the minimum observed speed to 1 (full speed).
        own_target_encountered = np.inf
        other_agent_encountered = np.inf
        other_target_encountered = np.inf
        potential_conflict = np.inf
        unusable_switch = np.inf
        other_agent_same_direction = 0
        other_agent_opposite_direction = 0
        malfunctioning_agent = 0
        min_fractional_speed = 1.
        num_steps = 1
        other_agent_ready_to_depart_encountered = 0
        while exploring:
            # #############################
            # #############################
            # Modify here to compute any useful data required to build the end node's features. This code is called
            # for each cell visited between the previous branching node and the next switch / target / dead-end.
            if position in self.location_has_agent:
                if tot_dist < other_agent_encountered:
                    other_agent_encountered = tot_dist

                # Check if any of the observed agents is malfunctioning, store agent with longest duration left
                if self.location_has_agent_malfunction[position] > malfunctioning_agent:
                    malfunctioning_agent = self.location_has_agent_malfunction[position]

                other_agent_ready_to_depart_encountered += self.location_has_agent_ready_to_depart.get(position, 0)

                if self.location_has_agent_direction[position] == direction:
                    # Cumulate the number of agents on branch with same direction
                    other_agent_same_direction += 1

                    # Check fractional speed of agents
                    current_fractional_speed = self.location_has_agent_speed[position]
                    if current_fractional_speed < min_fractional_speed:
                        min_fractional_speed = current_fractional_speed

                else:
                    # If no agent in the same direction was found all agents in that position are other direction
                    # Attention this counts too many agents as a few might be going off on a switch.
                    other_agent_opposite_direction += self.location_has_agent[position]

            # Check number of possible transitions for agent and total number of transitions in cell (type)
            cell_transitions = self.env.rail.get_transitions(*position, direction)
            transition_bit = bin(self.env.rail.get_full_transitions(*position))
            total_transitions = transition_bit.count("1")
            crossing_found = False
            # Bit pattern of a diamond crossing (straight N-S over straight E-W).
            if int(transition_bit, 2) == int('1000010000100001', 2):
                crossing_found = True

            # Register possible future conflict
            predicted_time = int(tot_dist * time_per_cell)
            if self.predictor and predicted_time < self.max_prediction_depth:
                int_position = coordinate_to_position(self.env.width, [position])
                if tot_dist < self.max_prediction_depth:

                    # Also check one predicted timestep before and after, to catch near-misses.
                    pre_step = max(0, predicted_time - 1)
                    post_step = min(self.max_prediction_depth - 1, predicted_time + 1)

                    # Look for conflicting paths at distance tot_dist
                    if int_position in np.delete(self.predicted_pos[predicted_time], handle, 0):
                        conflicting_agent = np.where(self.predicted_pos[predicted_time] == int_position)
                        for ca in conflicting_agent[0]:
                            if direction != self.predicted_dir[predicted_time][ca] and cell_transitions[
                                self._reverse_dir(
                                    self.predicted_dir[predicted_time][ca])] == 1 and tot_dist < potential_conflict:
                                potential_conflict = tot_dist
                            if self.env.agents[ca].state == TrainState.DONE and tot_dist < potential_conflict:
                                potential_conflict = tot_dist

                    # Look for conflicting paths at distance num_step-1
                    elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0):
                        conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
                        for ca in conflicting_agent[0]:
                            if direction != self.predicted_dir[pre_step][ca] \
                                and cell_transitions[self._reverse_dir(self.predicted_dir[pre_step][ca])] == 1 \
                                and tot_dist < potential_conflict:  # noqa: E125
                                potential_conflict = tot_dist
                            if self.env.agents[ca].state == TrainState.DONE and tot_dist < potential_conflict:
                                potential_conflict = tot_dist

                    # Look for conflicting paths at distance num_step+1
                    elif int_position in np.delete(self.predicted_pos[post_step], handle, 0):
                        conflicting_agent = np.where(self.predicted_pos[post_step] == int_position)
                        for ca in conflicting_agent[0]:
                            if direction != self.predicted_dir[post_step][ca] and cell_transitions[self._reverse_dir(
                                self.predicted_dir[post_step][ca])] == 1 \
                                and tot_dist < potential_conflict:  # noqa: E125
                                potential_conflict = tot_dist
                            if self.env.agents[ca].state == TrainState.DONE and tot_dist < potential_conflict:
                                potential_conflict = tot_dist

            if position in self.location_has_target and position != agent.target:
                if tot_dist < other_target_encountered:
                    other_target_encountered = tot_dist

            if position == agent.target and tot_dist < own_target_encountered:
                own_target_encountered = tot_dist

            # #############################
            # #############################
            # Cycle detection: revisiting a (cell, direction) state terminates the walk.
            if (position[0], position[1], direction) in visited:
                last_is_terminal = True
                break
            visited.add((position[0], position[1], direction))

            # If the target node is encountered, pick that as node. Also, no further branching is possible.
            if np.array_equal(position, self.env.agents[handle].target):
                last_is_target = True
                break

            # Check if crossing is found --> Not an unusable switch
            if crossing_found:
                # Treat the crossing as a straight rail cell
                total_transitions = 2
            num_transitions = np.count_nonzero(cell_transitions)

            exploring = False

            # Detect switches that can only be used by other agents.
            if total_transitions > 2 > num_transitions and tot_dist < unusable_switch:
                unusable_switch = tot_dist

            if num_transitions == 1:
                # Check if dead-end, or if we can go forward along direction
                nbits = total_transitions
                if nbits == 1:
                    # Dead-end!
                    last_is_dead_end = True

                if not last_is_dead_end:
                    # Keep walking through the tree along `direction`
                    exploring = True
                    # convert one-hot encoding to 0,1,2,3
                    direction = np.argmax(cell_transitions)
                    position = get_new_position(position, direction)
                    num_steps += 1
                    tot_dist += 1
            elif num_transitions > 0:
                # Switch detected
                last_is_switch = True
                break

            elif num_transitions == 0:
                # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
                print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0],
                      position[1], direction)
                last_is_terminal = True
                break

        # `position` is either a terminal node or a switch

        # #############################
        # #############################
        # Modify here to append new / different features for each visited cell!

        if last_is_target:
            dist_to_next_branch = tot_dist
            dist_min_to_target = 0
        elif last_is_terminal:
            dist_to_next_branch = np.inf
            dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction]
        else:
            dist_to_next_branch = tot_dist
            dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction]

        # Build the Node summarizing everything collected along this branch segment.
        node = Node(dist_own_target_encountered=own_target_encountered,
                    dist_other_target_encountered=other_target_encountered,
                    dist_other_agent_encountered=other_agent_encountered,
                    dist_potential_conflict=potential_conflict,
                    dist_unusable_switch=unusable_switch,
                    dist_to_next_branch=dist_to_next_branch,
                    dist_min_to_target=dist_min_to_target,
                    num_agents_same_direction=other_agent_same_direction,
                    num_agents_opposite_direction=other_agent_opposite_direction,
                    num_agents_malfunctioning=malfunctioning_agent,
                    speed_min_fractional=min_fractional_speed,
                    num_agents_ready_to_depart=other_agent_ready_to_depart_encountered,
                    childs={})

        # #############################
        # #############################
        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right, back], relative to the current orientation
        # Get the possible transitions
        possible_transitions = self.env.rail.get_transitions(*position, direction)
        for i, branch_direction in enumerate([(direction + 4 + i) % 4 for i in range(-1, 3)]):
            if last_is_dead_end and self.env.rail.get_transition((*position, direction),
                                                                 (branch_direction + 2) % 4):
                # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
                # it back
                new_cell = get_new_position(position, (branch_direction + 2) % 4)
                branch_observation, branch_visited = self._explore_branch(handle,
                                                                          new_cell,
                                                                          (branch_direction + 2) % 4,
                                                                          tot_dist + 1,
                                                                          depth + 1)
                node.childs[self.tree_explored_actions_char[i]] = branch_observation
                if len(branch_visited) != 0:
                    visited |= branch_visited
            elif last_is_switch and possible_transitions[branch_direction]:
                new_cell = get_new_position(position, branch_direction)
                branch_observation, branch_visited = self._explore_branch(handle,
                                                                          new_cell,
                                                                          branch_direction,
                                                                          tot_dist + 1,
                                                                          depth + 1)
                node.childs[self.tree_explored_actions_char[i]] = branch_observation
                if len(branch_visited) != 0:
                    visited |= branch_visited
            else:
                # no exploring possible, add just cells with infinity
                node.childs[self.tree_explored_actions_char[i]] = -np.inf

        # Leaves carry no children.
        if depth == self.max_depth:
            node.childs.clear()
        return node, visited
495

496
    def util_print_obs_subtree(self, tree: Node):
497
        """
498
        Utility function to print tree observations returned by this object.
499
        """
500
        self.print_node_features(tree, "root", "")
u214892's avatar
u214892 committed
501
        for direction in self.tree_explored_actions_char:
502 503
            self.print_subtree(tree.childs[direction], direction, "\t")

504 505 506 507 508 509
    @staticmethod
    def print_node_features(node: Node, label, indent):
        print(indent, "Direction ", label, ": ", node.dist_own_target_encountered, ", ",
              node.dist_other_target_encountered, ", ", node.dist_other_agent_encountered, ", ",
              node.dist_potential_conflict, ", ", node.dist_unusable_switch, ", ", node.dist_to_next_branch, ", ",
              node.dist_min_to_target, ", ", node.num_agents_same_direction, ", ", node.num_agents_opposite_direction,
u214892's avatar
u214892 committed
510 511
              ", ", node.num_agents_malfunctioning, ", ", node.speed_min_fractional, ", ",
              node.num_agents_ready_to_depart)
512 513 514 515 516

    def print_subtree(self, node, label, indent):
        if node == -np.inf or not node:
            print(indent, "Direction ", label, ": -np.inf")
            return
517

518 519 520
        self.print_node_features(node, label, indent)

        if not node.childs:
521 522
            return

u214892's avatar
u214892 committed
523
        for direction in self.tree_explored_actions_char:
524
            self.print_subtree(node.childs[direction], direction, indent + "\t")
525

u229589's avatar
u229589 committed
526 527
    def set_env(self, env: Environment):
        """Attach `env` to this builder and propagate it to the predictor, if any."""
        super().set_env(env)
        predictor = self.predictor
        if predictor:
            predictor.set_env(self.env)
530

531 532 533
    def _reverse_dir(self, direction):
        return int((direction + 2) % 4)

534 535 536 537 538 539

class GlobalObsForRailEnv(ObservationBuilder):
    """
    Gives a global observation of the entire rail environment.
    The observation is composed of the following elements:

        - transition map array with dimensions (env.height, env.width, 16),\
          assuming 16 bits encoding of transitions.

        - obs_agents_state: A 3D array (map_height, map_width, 5) with
            - first channel containing the agents position and direction
            - second channel containing the other agents positions and direction
            - third channel containing agent/other agent malfunctions
            - fourth channel containing agent/other agent fractional speeds
            - fifth channel containing number of other agents ready to depart

        - obs_targets: Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent\
         target and the positions of the other agents targets (flag only, no counter!).
    """

    def __init__(self):
        super(GlobalObsForRailEnv, self).__init__()

    def set_env(self, env: Environment):
        super().set_env(env)

    def reset(self):
        """Precompute one 16-bit transition vector per grid cell."""
        self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
        for row in range(self.rail_obs.shape[0]):
            for col in range(self.rail_obs.shape[1]):
                # Left-pad the binary representation to the full 16 transition bits.
                bits = bin(self.env.rail.get_full_transitions(row, col))[2:].zfill(16)
                self.rail_obs[row, col] = np.array([int(bit) for bit in bits])

    def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray):
        """Return (rail_obs, obs_agents_state, obs_targets) for agent ``handle``."""
        agent = self.env.agents[handle]
        # Off-map agents are observed at their spawn cell, done agents at their target.
        if agent.state.is_off_map_state():
            agent_virtual_position = agent.initial_position
        elif agent.state.is_on_map_state():
            agent_virtual_position = agent.position
        elif agent.state == TrainState.DONE:
            agent_virtual_position = agent.target
        else:
            return None

        obs_targets = np.zeros((self.env.height, self.env.width, 2))
        # -1 marks "no information" in the first four channels.
        obs_agents_state = np.zeros((self.env.height, self.env.width, 5)) - 1

        # Fifth channel counts agents ready to depart, so it starts at zero.
        obs_agents_state[:, :, 4] = 0

        obs_agents_state[agent_virtual_position][0] = agent.direction
        obs_targets[agent.target][0] = 1

        for other_handle, other_agent in enumerate(self.env.agents):
            # ignore other agents not in the grid any more
            if other_agent.state == TrainState.DONE:
                continue

            obs_targets[other_agent.target][1] = 1

            # second to fourth channel only if in the grid
            if other_agent.position is not None:
                # second channel only for other agents
                if other_handle != handle:
                    obs_agents_state[other_agent.position][1] = other_agent.direction
                obs_agents_state[other_agent.position][2] = other_agent.malfunction_handler.malfunction_down_counter
                obs_agents_state[other_agent.position][3] = other_agent.speed_counter.speed
            # fifth channel: all ready to depart on this position
            if other_agent.state.is_off_map_state():
                obs_agents_state[other_agent.initial_position][4] += 1
        return self.rail_obs, obs_agents_state, obs_targets


class LocalObsForRailEnv(ObservationBuilder):
    """
616
    !!!!!!WARNING!!! THIS IS DEPRACTED AND NOT UPDATED TO FLATLAND 2.0!!!!!
617
    Gives a local observation of the rail environment around the agent.
618 619
    The observation is composed of the following elements:

u214892's avatar
u214892 committed
620 621
        - transition map array of the local environment around the given agent, \
          with dimensions (view_height,2*view_width+1, 16), \
622 623
          assuming 16 bits encoding of transitions.

u214892's avatar
u214892 committed
624
        - Two 2D arrays (view_height,2*view_width+1, 2) containing respectively, \
Erik Nygren's avatar
Erik Nygren committed
625
        if they are in the agent's vision range, its target position, the positions of the other targets.
626

u214892's avatar
u214892 committed
627
        - A 2D array (view_height,2*view_width+1, 4) containing the one hot encoding of directions \
628
          of the other agents at their position coordinates, if they are in the agent's vision range.
629 630

        - A 4 elements array with one hot encoding of the direction.
631 632 633 634

    Use the parameters view_width and view_height to define the rectangular view of the agent.
    The center parameters moves the agent along the height axis of this rectangle. If it is 0 the agent only has
    observation in front of it.
u214892's avatar
u214892 committed
635 636

    .. deprecated:: 2.0.0
637 638
    """

639
    def __init__(self, view_width, view_height, center):
        """Store the view geometry; ``center`` shifts the agent along the view height axis."""
        super(LocalObsForRailEnv, self).__init__()
        self.view_width = view_width
        self.view_height = view_height
        self.center = center
        # Largest reach of the view rectangle in any direction.
        self.max_padding = max(view_width, view_height - center)

    def reset(self):
        """Precompute the per-cell transition bit vectors for the whole grid."""
        # We build the transition map with a view_radius empty cells expansion on each side.
        # This helps to collect the local transition map view when the agent is close to a border.
        self.max_padding = max(self.view_width, self.view_height)
        self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
        for row in range(self.env.height):
            for col in range(self.env.width):
                # Left-pad the binary transition code to the full 16 bits.
                bits = bin(self.env.rail.get_full_transitions(row, col))[2:].zfill(16)
                self.rail_obs[row, col] = np.array([int(bit) for bit in bits])

    def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray, np.ndarray):
        """
        Build the local observation for agent ``handle``.

        Returns a tuple of:
          - local transition map, shape (view_height, 2*view_width+1, 16)
          - target map, shape (..., 2): own target in channel 0, other targets in channel 1
          - one-hot directions of other visible agents, shape (..., 4)
          - one-hot encoding (4,) of the agent's own heading
        """
        agents = self.env.agents
        agent = agents[handle]

        # Collect visible cells (grid coordinates) together with their coordinates
        # relative to the agent's view rectangle.
        visited, rel_coords = self.field_of_view(agent.position, agent.direction)

        # Expose the visible cells for rendering/debugging.
        self.env.dev_obs_dict[handle] = set(visited)

        # Locate observed agents and their corresponding targets.
        local_rail_obs = np.zeros((self.view_height, 2 * self.view_width + 1, 16))
        obs_map_state = np.zeros((self.view_height, 2 * self.view_width + 1, 2))
        obs_other_agents_state = np.zeros((self.view_height, 2 * self.view_width + 1, 4))
        # enumerate replaces the original hand-maintained `_idx` counter.
        for idx, pos in enumerate(visited):
            rel = rel_coords[idx]
            local_rail_obs[rel[0], rel[1], :] = self.rail_obs[pos[0], pos[1], :]
            if pos == agent.target:
                obs_map_state[rel[0], rel[1], 0] = 1
            else:
                for tmp_agent in agents:
                    if pos == tmp_agent.target:
                        obs_map_state[rel[0], rel[1], 1] = 1
            if pos != agent.position:
                for tmp_agent in agents:
                    if pos == tmp_agent.position:
                        obs_other_agents_state[rel[0], rel[1], :] = np.identity(4)[tmp_agent.direction]

        direction = np.identity(4)[agent.direction]
        return local_rail_obs, obs_map_state, obs_other_agents_state, direction

    def get_many(self, handles: Optional[List[int]] = None) -> Dict[
        int, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
        """
        Compute an observation for every agent whose handle is in ``handles``.

        Delegates to the base-class implementation, which calls ``get`` per handle.
        """
        return super().get_many(handles)

    def field_of_view(self, position, direction, state=None):