diff --git a/flatland/baselines/dueling_double_dqn.py b/flatland/baselines/dueling_double_dqn.py
index 084d0a22b8c3ae39a2503e849ed346f50a7c8aa9..ee75a615cda4d26a06810d2b7d109fe5691d5ac4 100644
--- a/flatland/baselines/dueling_double_dqn.py
+++ b/flatland/baselines/dueling_double_dqn.py
@@ -22,12 +22,12 @@ device = torch.device("cpu")
 print(device)
 
 
-class Agent():
+class Agent:
     """Interacts with and learns from the environment."""
 
     def __init__(self, state_size, action_size, net_type, seed, double_dqn=True, input_channels=5):
         """Initialize an Agent object.
-        
+
         Params
         ======
             state_size (int): dimension of each state
@@ -78,7 +78,7 @@ class Agent():
 
     def act(self, state, eps=0.):
         """Returns actions for given state as per current policy.
-        
+
         Params
         ======
             state (array_like): current state
@@ -140,7 +140,7 @@ class Agent():
         ======
             local_model (PyTorch model): weights will be copied from
             target_model (PyTorch model): weights will be copied to
-            tau (float): interpolation parameter 
+            tau (float): interpolation parameter
         """
         for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
             target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)
diff --git a/flatland/baselines/model.py b/flatland/baselines/model.py
index 3b52e9f5ed691aeb3be01e69f04f92d615829b0b..7a5b3d613342a4fba8e2c8f1f45df21381e12684 100644
--- a/flatland/baselines/model.py
+++ b/flatland/baselines/model.py
@@ -1,4 +1,3 @@
-import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index ce0a28ae9f984249233e9086f78debd3c1cbc54e..b0730d8a4cd5e7d8567e243765b258a111070296 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -9,6 +9,7 @@ case of multi-agent environments.
 """
 
 import numpy as np
+
 from collections import deque
 
 # TODO: add docstrings, pylint, etc...
@@ -103,6 +104,7 @@ class TreeObsForRailEnv(ObservationBuilder):
             node = nodes_queue.popleft()
 
             node_id = (node[0], node[1], node[2])
+
             if node_id not in visited:
                 visited.add(node_id)
 
@@ -125,6 +127,53 @@ class TreeObsForRailEnv(ObservationBuilder):
         """
         neighbors = []
 
+        for direction in range(4):
+            new_cell = self._new_position(position, (direction+2) % 4)
+
+            if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width:
+                # Check if the two cells are connected by a valid transition
+                transitionValid = False
+                for orientation in range(4):
+                    moves = self.env.rail.get_transitions((new_cell[0], new_cell[1], orientation))
+                    if moves[direction]:
+                        transitionValid = True
+                        break
+
+                if not transitionValid:
+                    continue
+
+                # Check if a transition in direction node[2] is possible if an agent
+                # lands in the current cell with orientation `direction'; this only
+                # applies to cells that are not dead-ends!
+                directionMatch = True
+                if enforce_target_direction >= 0:
+                    directionMatch = self.env.rail.get_transition(
+                        (new_cell[0], new_cell[1], direction), enforce_target_direction)
+
+                # If transition is found to invalid, check if perhaps it
+                # is a dead-end, in which case the direction of movement is rotated
+                # 180 degrees (moving forward turns the agents and makes it step in the previous cell)
+                if not directionMatch:
+                    # If cell is a dead-end, append previous node with reversed
+                    # orientation!
+                    nbits = 0
+                    tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1]))
+                    while tmp > 0:
+                        nbits += (tmp & 1)
+                        tmp = tmp >> 1
+                    if nbits == 1:
+                        # Dead-end!
+                        # Check if transition is possible in new_cell
+                        # with orientation (direction+2)%4 in direction `direction'
+                        directionMatch = directionMatch or self.env.rail.get_transition(
+                            (new_cell[0], new_cell[1], (direction+2) % 4), direction)
+
+                if transitionValid and directionMatch:
+                    new_distance = min(self.distance_map[target_nr,
+                                                         new_cell[0], new_cell[1]], current_distance+1)
+                    neighbors.append((new_cell[0], new_cell[1], direction, new_distance))
+                    self.distance_map[target_nr, new_cell[0], new_cell[1]] = new_distance
+
         possible_directions = [0, 1, 2, 3]
         if enforce_target_direction >= 0:
             # The agent must land into the current cell with orientation `enforce_target_direction'.
@@ -195,6 +244,10 @@ class TreeObsForRailEnv(ObservationBuilder):
         the transitions. The order is:
             [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']
 
+
+
+
+
         Each branch data is organized as:
             [root node information] +
             [recursive branch data from 'left'] +
@@ -410,3 +463,51 @@ class TreeObsForRailEnv(ObservationBuilder):
                                         num_features_per_node,
                                         prompt=prompt_[children],
                                         current_depth=current_depth+1)
+
+
+class GlobalObsForRailEnv(ObservationBuilder):
+    """
+    Gives a global observation of the entire rail environment.
+    The observation is composed of the following elements:
+
+        - transition map array with dimensions (env.height, env.width, 16),
+          assuming 16 bits encoding of transitions.
+
+        - Four 2D arrays containing respectively the position of the given agent,
+          the position of its target, the positions of the other agents and of
+          their target.
+
+        - A 4 elements array with one of encoding of the direction.
+    """
+    def __init__(self):
+        super(GlobalObsForRailEnv, self).__init__()
+
+    def reset(self):
+        self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
+        for i in range(self.rail_obs.shape[0]):
+            for j in range(self.rail_obs.shape[1]):
+                self.rail_obs[i, j] = np.array(
+                    list(f'{self.env.rail.get_transitions((i, j)):016b}')).astype(int)
+
+        # self.targets = np.zeros(self.env.height, self.env.width)
+        # for target_pos in self.env.agents_target:
+        #     self.targets[target_pos] += 1
+
+    def get(self, handle):
+        obs_agents_targets_pos = np.zeros((4, self.env.height, self.env.width))
+        agent_pos = self.env.agents_position[handle]
+        obs_agents_targets_pos[0][agent_pos] += 1
+        for i in range(len(self.env.agents_position)):
+            if i != handle:
+                obs_agents_targets_pos[3][self.env.agents_position[i]] += 1
+
+        agent_target_pos = self.env.agents_target[handle]
+        obs_agents_targets_pos[1][agent_target_pos] += 1
+        for i in range(len(self.env.agents_target)):
+            if i != handle:
+                obs_agents_targets_pos[2][self.env.agents_target[i]] += 1
+
+        direction = np.zeros(4)
+        direction[self.env.agents_direction[handle]] = 1
+
+        return self.rail_obs, obs_agents_targets_pos, direction
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 24f365820b06698931b46afb6266deca03d6834b..9f008f00c261f76e71771732aa02c3c3071f9542 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -406,6 +406,7 @@ class RenderTool(object):
         plt.clf()
         # if oFigure is None:
         #    oFigure = plt.figure()
+
         def drawTrans(oFrom, oTo, sColor="gray"):
             plt.plot(
                 [oFrom[0], oTo[0]],  # x
@@ -554,8 +555,6 @@ class RenderTool(object):
             plt.pause(0.00001)
             return
 
-
-
     def _draw_square(self, center, size, color):
         x0 = center[0]-size/2
         x1 = center[0]+size/2
diff --git a/tests/test_env_observation_builder.py b/tests/test_env_observation_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..44dbfc6d8f14a7293e148f125b47eed66c9ca08d
--- /dev/null
+++ b/tests/test_env_observation_builder.py
@@ -0,0 +1,97 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from flatland.core.env_observation_builder import GlobalObsForRailEnv
+from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
+from flatland.envs.rail_env import RailEnv, rail_from_GridTransitionMap_generator
+from flatland.utils.rendertools import *
+
+"""Tests for `flatland` package."""
+
+
+def test_global_obs():
+    # We instantiate a very simple rail network on a 7x10 grid:
+    #        |
+    #        |
+    #        |
+    # _ _ _ /_\ _ _  _  _ _ _
+    #               \ /
+    #                |
+    #                |
+    #                |
+
+    cells = [int('0000000000000000', 2),  # empty cell - Case 0
+             int('1000000000100000', 2),  # Case 1 - straight
+             int('1001001000100000', 2),  # Case 2 - simple switch
+             int('1000010000100001', 2),  # Case 3 - diamond drossing
+             int('1001011000100001', 2),  # Case 4 - single slip switch
+             int('1100110000110011', 2),  # Case 5 - double slip switch
+             int('0101001000000010', 2),  # Case 6 - symmetrical switch
+             int('0010000000000000', 2)]  # Case 7 - dead end
+
+    transitions = Grid4Transitions([])
+    empty = cells[0]
+
+    dead_end_from_south = cells[7]
+    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
+    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
+    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
+
+    vertical_straight = cells[1]
+    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
+
+    double_switch_south_horizontal_straight = horizontal_straight + cells[6]
+    double_switch_north_horizontal_straight = transitions.rotate_transition(
+        double_switch_south_horizontal_straight, 180)
+
+    rail_map = np.array(
+        [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
+        [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
+        [[dead_end_from_east] + [horizontal_straight] * 2 +
+         [double_switch_north_horizontal_straight] +
+        [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
+        [horizontal_straight] * 2 + [dead_end_from_west]] +
+        [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
+        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
+
+    rail = GridTransitionMap(width=rail_map.shape[1],
+                             height=rail_map.shape[0], transitions=transitions)
+    rail.grid = rail_map
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_GridTransitionMap_generator(rail),
+                  number_of_agents=1,
+                  obs_builder_object=GlobalObsForRailEnv())
+
+    global_obs = env.reset()
+    # env_renderer = RenderTool(env)
+    # env_renderer.renderEnv(show=True)
+
+    # global_obs.reset()
+    assert(global_obs[0][0].shape == rail_map.shape + (16,))
+
+    rail_map_recons = np.zeros_like(rail_map)
+    for i in range(global_obs[0][0].shape[0]):
+        for j in range(global_obs[0][0].shape[1]):
+            rail_map_recons[i, j] = int(
+                ''.join(global_obs[0][0][i, j].astype(int).astype(str)), 2)
+
+    assert(rail_map_recons.all() == rail_map.all())
+
+    # If this assertion is wrong, it means that the observation returned
+    # places the agent on an empty cell
+    assert(np.sum(rail_map * global_obs[0][1][0]) > 0)
+
+
+
+test_global_obs()
+
+
+
+
+
+
+
+
+
+
diff --git a/tests/test_environments.py b/tests/test_environments.py
index 4f31fe6196846ff843aee3a445d1347d8975799d..ea8748b8aa4b50a1371a013be98f3b42d0d01228 100644
--- a/tests/test_environments.py
+++ b/tests/test_environments.py
@@ -5,6 +5,7 @@ import numpy as np
 from flatland.envs.rail_env import RailEnv, rail_from_GridTransitionMap_generator
 from flatland.core.transitions import Grid4Transitions
 from flatland.core.transition_map import GridTransitionMap
+from flatland.core.env_observation_builder import GlobalObsForRailEnv
 
 """Tests for `flatland` package."""
 
@@ -49,7 +50,9 @@ def test_rail_environment_single_agent():
     rail_env = RailEnv(width=3,
                        height=3,
                        rail_generator=rail_from_GridTransitionMap_generator(rail),
-                       number_of_agents=1)
+                       number_of_agents=1,
+                       obs_builder_object=GlobalObsForRailEnv())
+
     for _ in range(200):
         _ = rail_env.reset()
 
@@ -124,7 +127,8 @@ def test_dead_end():
     rail_env = RailEnv(width=rail_map.shape[1],
                        height=rail_map.shape[0],
                        rail_generator=rail_from_GridTransitionMap_generator(rail),
-                       number_of_agents=1)
+                       number_of_agents=1,
+                       obs_builder_object=GlobalObsForRailEnv())
 
     def check_consistency(rail_env):
         # We run step to check that trains do not move anymore
@@ -147,14 +151,14 @@ def test_dead_end():
 
     # We try the configuration in the 4 directions:
     rail_env.reset()
-    rail_env.agents_target[0] = [0, 0]
-    rail_env.agents_position[0] = [0, 2]
+    rail_env.agents_target[0] = (0, 0)
+    rail_env.agents_position[0] = (0, 2)
     rail_env.agents_direction[0] = 1
     check_consistency(rail_env)
 
     rail_env.reset()
-    rail_env.agents_target[0] = [0, 4]
-    rail_env.agents_position[0] = [0, 2]
+    rail_env.agents_target[0] = (0, 4)
+    rail_env.agents_position[0] = (0, 2)
     rail_env.agents_direction[0] = 3
     check_consistency(rail_env)
 
@@ -173,16 +177,17 @@ def test_dead_end():
     rail_env = RailEnv(width=rail_map.shape[1],
                        height=rail_map.shape[0],
                        rail_generator=rail_from_GridTransitionMap_generator(rail),
-                       number_of_agents=1)
+                       number_of_agents=1,
+                       obs_builder_object=GlobalObsForRailEnv())
 
     rail_env.reset()
-    rail_env.agents_target[0] = [0, 0]
-    rail_env.agents_position[0] = [2, 0]
+    rail_env.agents_target[0] = (0, 0)
+    rail_env.agents_position[0] = (2, 0)
     rail_env.agents_direction[0] = 2
     check_consistency(rail_env)
 
     rail_env.reset()
-    rail_env.agents_target[0] = [4, 0]
-    rail_env.agents_position[0] = [2, 0]
+    rail_env.agents_target[0] = (4, 0)
+    rail_env.agents_position[0] = (2, 0)
     rail_env.agents_direction[0] = 0
     check_consistency(rail_env)
diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py
index 0bdc47bace289656181b01b1f44344e4322363a0..e45b7d1815365afda98f699d628a7e6f51c92395 100644
--- a/tests/test_rendertools.py
+++ b/tests/test_rendertools.py
@@ -11,6 +11,7 @@ import os
 import matplotlib.pyplot as plt
 
 import flatland.utils.rendertools as rt
+from flatland.core.env_observation_builder import GlobalObsForRailEnv
 
 
 def checkFrozenImage(sFileImage):
@@ -36,7 +37,10 @@ def checkFrozenImage(sFileImage):
 def test_render_env():
     # random.seed(100)
     np.random.seed(100)
-    oEnv = RailEnv(width=10, height=10, rail_generator=random_rail_generator(), number_of_agents=2)
+    oEnv = RailEnv(width=10, height=10,
+                   rail_generator=random_rail_generator(),
+                   number_of_agents=2,
+                   obs_builder_object=GlobalObsForRailEnv())
     oEnv.reset()
     oRT = rt.RenderTool(oEnv)
     plt.figure(figsize=(10, 10))