-class Agent():
+class Agent:
     """Interacts with and learns from the environment."""
     def __init__(self, state_size, action_size, net_type, seed, double_dqn=True, input_channels=5):
         """Initialize an Agent object.
             state_size (int): dimension of each state
     def act(self, state, eps=0.):
         """Returns actions for given state as per current policy.
             state (array_like): current state
             local_model (PyTorch model): weights will be copied from
             target_model (PyTorch model): weights will be copied to
-            tau (float): interpolation parameter 
+            tau (float): interpolation parameter
         for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
             target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)
-import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
 from collections import deque
 # TODO: add docstrings, pylint, etc...
             node = nodes_queue.popleft()
             node_id = (node[0], node[1], node[2])
             if node_id not in visited:
         neighbors = []
+        for direction in range(4):
+            new_cell = self._new_position(position, (direction+2) % 4)
+            if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width:
+                # Check if the two cells are connected by a valid transition
+                transitionValid = False
+                for orientation in range(4):
+                    moves = self.env.rail.get_transitions((new_cell[0], new_cell[1], orientation))
+                    if moves[direction]:
+                        transitionValid = True
+                        break
+                if not transitionValid:
+                    continue
+                # Check if a transition in direction node[2] is possible if an agent
+                # lands in the current cell with orientation `direction'; this only
+                # applies to cells that are not dead-ends!
+                directionMatch = True
+                if enforce_target_direction >= 0:
+                    directionMatch = self.env.rail.get_transition(
+                        (new_cell[0], new_cell[1], direction), enforce_target_direction)
+                # If the transition is found to be invalid, check whether the cell
+                # is a dead-end, in which case the direction of movement is rotated
+                # 180 degrees (moving forward turns the agent around and makes it step into the previous cell)
+                if not directionMatch:
+                    # If cell is a dead-end, append previous node with reversed
+                    # orientation!
+                    nbits = 0
+                    tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1]))
+                    while tmp > 0:
+                        nbits += (tmp & 1)
+                        tmp = tmp >> 1
+                    if nbits == 1:
+                        # Dead-end!
+                        # Check if transition is possible in new_cell
+                        # with orientation (direction+2)%4 in direction `direction'
+                        directionMatch = directionMatch or self.env.rail.get_transition(
+                            (new_cell[0], new_cell[1], (direction+2) % 4), direction)
+                if transitionValid and directionMatch:
+                    new_distance = min(self.distance_map[target_nr,
+                                                         new_cell[0], new_cell[1]], current_distance+1)
+                    neighbors.append((new_cell[0], new_cell[1], direction, new_distance))
+                    self.distance_map[target_nr, new_cell[0], new_cell[1]] = new_distance
         possible_directions = [0, 1, 2, 3]
         if enforce_target_direction >= 0:
             # The agent must land into the current cell with orientation `enforce_target_direction'.
         the transitions. The order is:
             [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']
         Each branch data is organized as:
             [root node information] +
             [recursive branch data from 'left'] +
@@ -410,3 +463,51 @@ class TreeObsForRailEnv(ObservationBuilder):
+class GlobalObsForRailEnv(ObservationBuilder):
+    """
+    Gives a global observation of the entire rail environment.
+
+    The observation returned by `get` is a 3-tuple composed of:
+
+        - transition map array with dimensions (env.height, env.width, 16),
+          holding each cell's transitions as 16 bits, most significant
+          bit first.
+        - an array with dimensions (4, env.height, env.width) whose four 2D
+          channels mark, in this order: the position of the given agent
+          (channel 0), the position of its target (channel 1), the targets
+          of the other agents (channel 2) and the positions of the other
+          agents (channel 3).
+        - A 4 elements array with the one-hot encoding of the direction of
+          the given agent.
+
+    NOTE(review): `self.env` is not assigned in this class; presumably the
+    base class / the environment attaches it before `reset` is called.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def reset(self):
+        """Rebuild the cached (height, width, 16) transition-map array."""
+        self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
+        for i in range(self.rail_obs.shape[0]):
+            for j in range(self.rail_obs.shape[1]):
+                # Format the cell's transitions as a zero-padded 16-character
+                # binary string and store one bit per slot of the last axis.
+                self.rail_obs[i, j] = np.array(
+                    list(f'{self.env.rail.get_transitions((i, j)):016b}')).astype(int)
+
+    def get(self, handle):
+        """Return the global observation for the agent with index `handle`.
+
+        Returns:
+            tuple: (`self.rail_obs`, the (4, height, width) agents/targets
+            array, the 4-element one-hot direction array).
+        """
+        obs_agents_targets_pos = np.zeros((4, self.env.height, self.env.width))
+
+        # Positions and targets are converted to tuples before indexing: a
+        # list such as [row, col] would trigger NumPy fancy indexing and
+        # increment two entire rows instead of the single cell (row, col).
+        agent_pos = tuple(self.env.agents_position[handle])
+        obs_agents_targets_pos[0][agent_pos] += 1
+        for i, other_pos in enumerate(self.env.agents_position):
+            if i != handle:
+                obs_agents_targets_pos[3][tuple(other_pos)] += 1
+
+        agent_target_pos = tuple(self.env.agents_target[handle])
+        obs_agents_targets_pos[1][agent_target_pos] += 1
+        for i, other_target in enumerate(self.env.agents_target):
+            if i != handle:
+                obs_agents_targets_pos[2][tuple(other_target)] += 1
+
+        direction = np.zeros(4)
+        direction[self.env.agents_direction[handle]] = 1
+
+        return self.rail_obs, obs_agents_targets_pos, direction
         # if oFigure is None:
         #    oFigure = plt.figure()
         def drawTrans(oFrom, oTo, sColor="gray"):
                 [oFrom[0], oTo[0]],  # x
@@ -554,8 +555,6 @@ class RenderTool(object):
     def _draw_square(self, center, size, color):
         x0 = center[0]-size/2
         x1 = center[0]+size/2
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+from flatland.core.env_observation_builder import GlobalObsForRailEnv
+from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
+from flatland.envs.rail_env import RailEnv, rail_from_GridTransitionMap_generator
+from flatland.utils.rendertools import *
+"""Tests for `flatland` package."""
+def test_global_obs():
+    """Check that GlobalObsForRailEnv reproduces the rail map and places the agent on a rail cell."""
+    # We instantiate a very simple rail network on a 7x10 grid:
+    #        |
+    #        |
+    #        |
+    # _ _ _ /_\ _ _  _  _ _ _
+    #               \ /
+    #                |
+    #                |
+    #                |
+    cells = [int('0000000000000000', 2),  # empty cell - Case 0
+             int('1000000000100000', 2),  # Case 1 - straight
+             int('1001001000100000', 2),  # Case 2 - simple switch
+             int('1000010000100001', 2),  # Case 3 - diamond crossing
+             int('1001011000100001', 2),  # Case 4 - single slip switch
+             int('1100110000110011', 2),  # Case 5 - double slip switch
+             int('0101001000000010', 2),  # Case 6 - symmetrical switch
+             int('0010000000000000', 2)]  # Case 7 - dead end
+    transitions = Grid4Transitions([])
+    empty = cells[0]
+    dead_end_from_south = cells[7]
+    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
+    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
+    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
+    vertical_straight = cells[1]
+    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
+    double_switch_south_horizontal_straight = horizontal_straight + cells[6]
+    double_switch_north_horizontal_straight = transitions.rotate_transition(
+        double_switch_south_horizontal_straight, 180)
+    rail_map = np.array(
+        [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
+        [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
+        [[dead_end_from_east] + [horizontal_straight] * 2 +
+         [double_switch_north_horizontal_straight] +
+         [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
+         [horizontal_straight] * 2 + [dead_end_from_west]] +
+        [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
+        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
+    rail = GridTransitionMap(width=rail_map.shape[1],
+                             height=rail_map.shape[0], transitions=transitions)
+    rail.grid = rail_map
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_GridTransitionMap_generator(rail),
+                  number_of_agents=1,
+                  obs_builder_object=GlobalObsForRailEnv())
+    global_obs = env.reset()
+    # Uncomment to inspect the generated network visually:
+    # env_renderer = RenderTool(env)
+    # env_renderer.renderEnv(show=True)
+
+    # global_obs[0] is the observation of agent 0; its first element is the
+    # (height, width, 16) bit-encoded transition map.
+    rail_obs = global_obs[0][0]
+    assert rail_obs.shape == rail_map.shape + (16,)
+
+    # Rebuild every cell value from its 16 bits (most significant bit first).
+    rail_map_recons = np.zeros_like(rail_map)
+    for i in range(rail_obs.shape[0]):
+        for j in range(rail_obs.shape[1]):
+            rail_map_recons[i, j] = int(
+                ''.join(rail_obs[i, j].astype(int).astype(str)), 2)
+
+    # Compare element-wise: `a.all() == b.all()` would only compare two
+    # booleans and could never detect a wrong reconstruction.
+    assert np.array_equal(rail_map_recons, rail_map)
+
+    # If this assertion is wrong, it means that the observation returned
+    # places the agent on an empty cell
+    assert np.sum(rail_map * global_obs[0][1][0]) > 0
 from flatland.envs.rail_env import RailEnv, rail_from_GridTransitionMap_generator
 from flatland.core.transitions import Grid4Transitions
 from flatland.core.transition_map import GridTransitionMap
+from flatland.core.env_observation_builder import GlobalObsForRailEnv
 """Tests for `flatland` package."""
     rail_env = RailEnv(width=3,
-                       number_of_agents=1)
+                       number_of_agents=1,
+                       obs_builder_object=GlobalObsForRailEnv())
     for _ in range(200):
         _ = rail_env.reset()
     rail_env = RailEnv(width=rail_map.shape[1],
-                       number_of_agents=1)
+                       number_of_agents=1,
+                       obs_builder_object=GlobalObsForRailEnv())
     def check_consistency(rail_env):
         # We run step to check that trains do not move anymore
     # We try the configuration in the 4 directions:
-    rail_env.agents_target[0] = [0, 0]
-    rail_env.agents_position[0] = [0, 2]
+    rail_env.agents_target[0] = (0, 0)
+    rail_env.agents_position[0] = (0, 2)
     rail_env.agents_direction[0] = 1
-    rail_env.agents_target[0] = [0, 4]
-    rail_env.agents_position[0] = [0, 2]
+    rail_env.agents_target[0] = (0, 4)
+    rail_env.agents_position[0] = (0, 2)
     rail_env.agents_direction[0] = 3
     rail_env = RailEnv(width=rail_map.shape[1],
-                       number_of_agents=1)
+                       number_of_agents=1,
+                       obs_builder_object=GlobalObsForRailEnv())
-    rail_env.agents_target[0] = [0, 0]
-    rail_env.agents_position[0] = [2, 0]
+    rail_env.agents_target[0] = (0, 0)
+    rail_env.agents_position[0] = (2, 0)
     rail_env.agents_direction[0] = 2
-    rail_env.agents_target[0] = [4, 0]
-    rail_env.agents_position[0] = [2, 0]
+    rail_env.agents_target[0] = (4, 0)
+    rail_env.agents_position[0] = (2, 0)
     rail_env.agents_direction[0] = 0
 import matplotlib.pyplot as plt
 import flatland.utils.rendertools as rt
+from flatland.core.env_observation_builder import GlobalObsForRailEnv
 def checkFrozenImage(sFileImage):
 def test_render_env():
     # random.seed(100)
-    oEnv = RailEnv(width=10, height=10, rail_generator=random_rail_generator(), number_of_agents=2)
+    oEnv = RailEnv(width=10, height=10,
+                   rail_generator=random_rail_generator(),
+                   number_of_agents=2,
+                   obs_builder_object=GlobalObsForRailEnv())
     oRT = rt.RenderTool(oEnv)
     plt.figure(figsize=(10, 10))