diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index ac62946dc4546ab1333340ee06d5472f78aab547..6e5dbbb4685c8a10e39af63adc90fbb506fcc653 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -1,5 +1,7 @@
 import numpy as np
 
+from collections import deque
+
 # TODO: add docstrings, pylint, etc...
 
 
@@ -15,15 +17,131 @@ class ObservationBuilder:
 
 
 class TreeObsForRailEnv(ObservationBuilder):
+    def __init__(self, env):
+        self.env = env
+
     def reset(self):
-        # TODO: precompute distances, etc...
-        # raise NotImplementedError()
-        pass
+        self.distance_map = np.inf * np.ones(shape=(self.env.number_of_agents,
+                                                    self.env.height,
+                                                    self.env.width))
+        self.max_dist = np.zeros(self.env.number_of_agents)
+
+        for i in range(self.env.number_of_agents):
+            self.max_dist[i] = self._distance_map_walker(self.env.agents_target[i], i)
+
+
+    def _distance_map_walker(self, position, target_nr):
+        # Returns max distance to target, from the farthest away node, while filling in distance_map
+
+        for ori in range(4):
+            self.distance_map[target_nr, position[0], position[1]] = 0
+
+        # Fill in the (up to) 4 neighboring nodes
+        # nodes_queue = []  # list of tuples (row, col, direction, distance);
+        # direction is the direction of movement, meaning that at least a possible orientation
+        # of an agent in cell (row,col) allows a movement in direction `direction'
+        nodes_queue = deque(self._get_and_update_neighbors(position,
+                                                           target_nr, 0, enforce_target_direction=-1))
+
+        # BFS from target `position' to all the reachable nodes in the grid
+        # Stop the search if the target position is re-visited, in any direction
+        visited = set([(position[0], position[1], 0), (position[0], position[1], 1),
+                       (position[0], position[1], 2), (position[0], position[1], 3)])
+
+        max_distance = 0
+
+        while nodes_queue:
+            node = nodes_queue.popleft()
+
+            node_id = (node[0], node[1], node[2])
+
+            #print(node_id, visited, (node_id in visited))
+            #print(nodes_queue)
+
+            if node_id not in visited:
+                visited.add(node_id)
+
+                # From the list of possible neighbors that have at least a path to the
+                # current node, only keep those whose new orientation in the current cell
+                # would allow a transition to direction node[2]
+                valid_neighbors = self._get_and_update_neighbors(
+                    (node[0], node[1]), target_nr, node[3], node[2])
+
+                for n in valid_neighbors:
+                    nodes_queue.append(n)
+
+                if len(valid_neighbors)>0:
+                    max_distance = max(max_distance, node[3]+1)
+
+        return max_distance
+
+
+    def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1):
+        neighbors = []
+
+        for direction in range(4):
+            new_cell = self._new_position(position, (direction+2)%4)
+
+            if new_cell[0]>=0 and new_cell[0]<self.env.height and\
+                new_cell[1]>=0 and new_cell[1]<self.env.width:
+                # Check if the two cells are connected by a valid transition
+                transitionValid = False
+                for orientation in range(4):
+                    moves = self.env.rail.get_transitions((new_cell[0], new_cell[1], orientation))
+                    if moves[direction]:
+                        transitionValid = True
+                        break
+
+                if not transitionValid:
+                    continue
+
+                # Check if a transition in direction node[2] is possible if an agent
+                # lands in the current cell with orientation `direction'; this only
+                # applies to cells that are not dead-ends!
+                directionMatch = True
+                if enforce_target_direction>=0:
+                    directionMatch = self.env.rail.get_transition(
+                        (new_cell[0], new_cell[1], direction), enforce_target_direction)
+
+                # If transition is found to invalid, check if perhaps it
+                # is a dead-end, in which case the direction of movement is rotated
+                # 180 degrees (moving forward turns the agents and makes it step in the previous cell)
+                if not directionMatch:
+                    # If cell is a dead-end, append previous node with reversed
+                    # orientation!
+                    nbits = 0
+                    tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1]))
+                    while tmp > 0:
+                        nbits += (tmp & 1)
+                        tmp = tmp >> 1
+                    if nbits == 1:
+                        # Dead-end!
+                        # Check if transition is possible in new_cell
+                        # with orientation (direction+2)%4 in direction `direction'
+                        directionMatch = directionMatch or self.env.rail.get_transition(
+                            (new_cell[0], new_cell[1], (direction+2)%4), direction)
+
+                if transitionValid and directionMatch:
+                    new_distance = min(self.distance_map[target_nr,
+                                                         new_cell[0], new_cell[1]], current_distance+1)
+                    neighbors.append((new_cell[0], new_cell[1], direction, new_distance))
+                    self.distance_map[target_nr, new_cell[0], new_cell[1]] = new_distance
+
+        return neighbors
+
+    def _new_position(self, position, movement):
+        if movement == 0:    # NORTH
+            return (position[0]-1, position[1])
+        elif movement == 1:  # EAST
+            return (position[0], position[1] + 1)
+        elif movement == 2:  # SOUTH
+            return (position[0]+1, position[1])
+        elif movement == 3:  # WEST
+            return (position[0], position[1] - 1)
+
 
     def get(self, handle):
         # TODO: compute the observation for agent `handle'
-
-        # raise NotImplementedError()
         return []
 
 
@@ -38,12 +156,235 @@ class GlobalObsForRailEnv(ObservationBuilder):
         - Four 2D arrays containing respectively the position of the given agent,
           the position of its target, the positions of the other agents and of
           their target.
+
+        - A 4 elements array with one of encoding of the direction.
     """
     def __init__(self, env):
         super(GlobalObsForRailEnv, self).__init__(env)
+
+    def reset(self):
         self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
         for i in range(self.rail_obs.shape[0]):
             for j in range(self.rail_obs.shape[1]):
-                self.rail_obs[i, j] = self.env.rail.get_transitions((i, j))
+                self.rail_obs[i, j] = np.array(
+                    list(f'{self.env.rail.get_transitions((i, j)):016b}')).astype(int)
+
+        # self.targets = np.zeros(self.env.height, self.env.width)
+        # for target_pos in self.env.agents_target:
+        #     self.targets[target_pos] += 1
+
+    def get(self, handle):
+        obs_agents_targets_pos = np.zeros((4, self.env.height, self.env.width))
+        agent_pos = self.env.agents_position[handle]
+        obs_agents_targets_pos[0][agent_pos] += 1
+        for i in range(len(self.env.agents_position)):
+            if i != handle:
+                obs_agents_targets_pos[3][self.env.agents_position[i]] += 1
+
+        agent_target_pos = self.env.agents_target[handle]
+        obs_agents_targets_pos[1][agent_target_pos] += 1
+        for i in range(len(self.env.agents_target)):
+            if i != handle:
+                obs_agents_targets_pos[2][self.env.agents_target[i]] += 1
+
+        direction = np.zeros(4)
+        direction[self.env.agents_direction[handle]] = 1
+
+        return self.rail_obs, obs_agents_targets_pos, direction
+
+
+
+
+
+"""
+
+    def get_observation(self, agent):
+        # Get the current observation for an agent
+        current_position = self.internal_position[agent]
+        #target_heading = self._compass(agent).tolist()
+        coordinate = tuple(np.transpose(self._position_to_coordinate([current_position])))
+        agent_distance = self.distance_map[agent][coordinate][0]
+        # Start tree search
+        if current_position == self.target[agent]:
+            agent_tree = Node(current_position, [-np.inf, -np.inf, -np.inf, -np.inf, -1])
+        else:
+            agent_tree = Node(current_position, [0, 0, 0, 0, agent_distance])
+
+        initial_tree_state = Tree_State(agent, current_position, -1, 0, 0)
+        self._tree_search(initial_tree_state, agent_tree, agent)
+        observation = []
+        distance_data = []
+
+        self._flatten_tree(agent_tree, observation, distance_data,  self.max_depth+1)
+        # This is probably very slow!!!!
+        #max_obs = np.max([i for i in observation if i < np.inf])
+        #if max_obs != 0:
+        #    observation = np.array(observation)/ max_obs
+
+        #print([i for i in distance_data if i >= 0])
+        observation = np.concatenate((observation, distance_data))
+        #observation = np.concatenate((observation, np.identity(5)[int(self.last_action[agent])]))
+        #return np.clip(observation / self.max_dist[agent], -1, 1)
+        return np.clip(observation / 15., -1, 1)
+
+
+
+
+    def _tree_search(self, in_tree_state, parent_node, agent):
+        if in_tree_state.depth >= self.max_depth:
+            return
+        target_distance = np.inf
+        other_target = np.inf
+        other_agent = np.inf
+        coordinate = tuple(np.transpose(self._position_to_coordinate([in_tree_state.position])))
+        curr_target_dist = self.distance_map[agent][coordinate][0]
+        forbidden_action = (in_tree_state.direction + 2) % 4
+        # Update the position
+        failed_move = 0
+        leaf_distance = in_tree_state.distance
+        for child_idx in range(4):
+            if child_idx != forbidden_action or in_tree_state.direction == -1:
+                tree_state = copy.deepcopy(in_tree_state)
+                tree_state.direction = child_idx
+                current_position, invalid_move = self._detect_path(
+                tree_state.position, tree_state.direction)
+                if tree_state.initial_direction == None:
+                    tree_state.initial_direction = child_idx
+                if not invalid_move:
+                    coordinate = tuple(np.transpose(self._position_to_coordinate([current_position])))
+                    curr_target_dist = self.distance_map[agent][coordinate][0]
+                    #if tree_state.initial_direction == None:
+                    #    tree_state.initial_direction = child_idx
+                    tree_state.position = current_position
+                    tree_state.distance += 1
+
+
+                    # Collect information at the current position
+                    detection_distance = tree_state.distance
+                    if current_position == self.target[tree_state.agent]:
+                        target_distance = detection_distance
+
+                    elif current_position in self.target:
+                        other_target = detection_distance
+
+                    if current_position in self.internal_position:
+                        other_agent = detection_distance
+
+                    tree_state.data[0] = self._min_greater_zero(target_distance, tree_state.data[0])
+                    tree_state.data[1] = self._min_greater_zero(other_target, tree_state.data[1])
+                    tree_state.data[2] = self._min_greater_zero(other_agent, tree_state.data[2])
+                    tree_state.data[3] = tree_state.distance
+                    tree_state.data[4] = self._min_greater_zero(curr_target_dist, tree_state.data[4])
+
+                    if self._switch_detection(tree_state.position):
+                        tree_state.depth += 1
+                        new_tree_state = copy.deepcopy(tree_state)
+                        new_node = parent_node.insert(tree_state.position,
+                         tree_state.data, tree_state.initial_direction)
+                        new_tree_state.initial_direction = None
+                        new_tree_state.data = [np.inf, np.inf, np.inf, np.inf, np.inf]
+                        self._tree_search(new_tree_state, new_node, agent)
+                    else:
+                        self._tree_search(tree_state, parent_node, agent)
+                else:
+                    failed_move += 1
+            if failed_move == 3 and in_tree_state.direction != -1:
+                tree_state.data[4] = self._min_greater_zero(curr_target_dist, tree_state.data[4])
+                parent_node.insert(tree_state.position, tree_state.data, tree_state.initial_direction)
+                return
+        return
+
+    def _flatten_tree(self, node, observation_vector, distance_sensor, depth):
+        if depth <= 0:
+            return
+        if node != None:
+            observation_vector.extend(node.data[:-1])
+            distance_sensor.extend([node.data[-1]])
+        else:
+            observation_vector.extend([-np.inf, -np.inf, -np.inf, -np.inf])
+            distance_sensor.extend([-np.inf])
+        for child_idx in range(4):
+            if node != None:
+                child = node.children[child_idx]
+            else:
+                child = None
+            self._flatten_tree(child, observation_vector, distance_sensor,  depth -1)
+
+
+
+    def _switch_detection(self, position):
+        # Hack to detect switches
+        # This can later directly be derived from the transition matrix
+        paths = 0
+        for i in range(4):
+            _, invalid_move = self._detect_path(position, i)
+            if not invalid_move:
+                paths +=1
+            if paths >= 3:
+                return True
+        return False
+
+
+
+
+    def _min_greater_zero(self, x, y):
+        if x <= 0 and y <= 0:
+            return 0
+        if x < 0:
+            return y
+        if y < 0:
+            return x
+        return min(x, y)
+
+
+
+"""
+
+
+class Tree_State:
+    """
+    Keep track of the current state while building the tree
+    """
+    def __init__(self, agent, position, direction, depth, distance):
+        self.agent = agent
+        self.position = position
+        self.direction = direction
+        self.depth = depth
+        self.initial_direction = None
+        self.distance = distance
+        self.data = [np.inf, np.inf, np.inf, np.inf, np.inf]
+
+class Node():
+    """
+    Define a tree node to get populated during search
+    """
+    def __init__(self, position, data):
+        self.n_children = 4
+        self.children = [None]*self.n_children
+        self.data = data
+        self.position = position
+
+    def insert(self, position, data, child_idx):
+        """
+        Insert new node with data
+
+        @param data node data object to insert
+        """
+        new_node = Node(position, data)
+        self.children[child_idx] = new_node
+        return new_node
+
+    def print_tree(self, i=0, depth = 0):
+        """
+        Print tree content inorder
+        """
+        current_i = i
+        curr_depth = depth+1
+        if i < self.n_children:
+            if self.children[i] != None:
+                self.children[i].print_tree(depth=curr_depth)
+            current_i += 1
+            if self.children[i] != None:
+                self.children[i].print_tree(i, depth=curr_depth)
 
 
diff --git a/tests/test_env_observation_builder.py b/tests/test_env_observation_builder.py
index c3f5dfd56aa62069074753f367535c98831477f6..a89df6cd1faf7aa293531b4689e0d89e5d8e4946 100644
--- a/tests/test_env_observation_builder.py
+++ b/tests/test_env_observation_builder.py
@@ -2,8 +2,11 @@
 # -*- coding: utf-8 -*-
 
 from flatland.core.env_observation_builder import GlobalObsForRailEnv
-from flatland.core.transitions import Grid4Transitions
+# from flatland.core.transitions import Grid4Transitions
+from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
+from flatland.core.env import RailEnv
 import numpy as np
+from flatland.utils.rendertools import *
 
 """Tests for `flatland` package."""
 
@@ -43,18 +46,44 @@ def test_global_obs():
     double_switch_north_horizontal_straight = transitions.rotate_transition(
         double_switch_south_horizontal_straight, 180)
 
-
-
     rail_map = np.array(
         [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
-        [[empty] * 3 + [vertical_straight] + [empty] * 6]*2 +
-        [[horizontal_straight] * 3 + [double_switch_north_horizontal_straight] +
-        [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
-        [horizontal_straight] * 3] +
         [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
-        [[empty] * 3 + [dead_end_from_south] + [empty] * 6], dtype=np.uint16)
+        [[dead_end_from_east] + [horizontal_straight] * 2 +
+         [double_switch_north_horizontal_straight] +
+        [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
+        [horizontal_straight] * 2 + [dead_end_from_west]] +
+        [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
+        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
+
+    rail = GridTransitionMap(width=rail_map.shape[1],
+                             height=rail_map.shape[0], transitions=transitions)
+    rail.grid = rail_map
+    env = RailEnv(rail, number_of_agents=1)
+
+    env.reset()
+    # env_renderer = RenderTool(env)
+    # env_renderer.renderEnv(show=True)
+
+    global_obs = GlobalObsForRailEnv(env)
+    global_obs.reset()
+    assert(global_obs.rail_obs.shape == rail_map.shape + (16,))
+
+    rail_map_recons = np.zeros_like(rail_map)
+    for i in range(global_obs.rail_obs.shape[0]):
+        for j in range(global_obs.rail_obs.shape[1]):
+            rail_map_recons[i,j] = int(
+                ''.join(global_obs.rail_obs[i, j].astype(int).astype(str)), 2)
+
+    assert(rail_map_recons.all() == rail_map.all())
+
+    obs = global_obs.get(0)
+
+    # If this assertion is wrong, it means that the observation returned
+    # places the agent on an empty cell
+    assert(np.sum(rail_map * obs[1][0]) > 0)
+
 
-    print(rail_map.shape)
 
 test_global_obs()