diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index f85afee4b625e59374c6cce266bf55b21e7fdb84..b30c2b1f5ddab079c9b6c41e35f03c69ed4162c3 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -7,13 +7,14 @@ The ObservationBuilder-derived custom classes implement 2 functions, reset() and
 + Get() is called whenever an observation has to be computed, potentially for each agent independently in
 case of multi-agent environments.
 """
+import numpy as np
 
 
 class ObservationBuilder:
     """
     ObservationBuilder base class.
 
-    Derived objects must implement and `observation_space' attribute as a tuple with the dimensuions of the returned
+    Derived objects must implement and `observation_space' attribute as a tuple with the dimensions of the returned
     observations.
     """
 
@@ -45,3 +46,9 @@ class ObservationBuilder:
             An observation structure, specific to the corresponding environment.
         """
         raise NotImplementedError()
+
+    def _get_one_hot_for_agent_direction(self, agent):
+        """Retuns the agent's direction to one-hot encoding."""
+        direction = np.zeros(4)
+        direction[agent.direction] = 1
+        return direction
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index f5f70aa1c54484560f5e58d7a9b901f2607fe790..22ebf264183528762fb9b4c0c6bea07bfc726423 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -1,9 +1,10 @@
 """
 Collection of environment-specific ObservationBuilder.
 """
-import numpy as np
 from collections import deque
 
+import numpy as np
+
 from flatland.core.env_observation_builder import ObservationBuilder
 
 
@@ -22,10 +23,10 @@ class TreeObsForRailEnv(ObservationBuilder):
         # Compute the size of the returned observation vector
         size = 0
         pow4 = 1
-        for i in range(self.max_depth+1):
+        for i in range(self.max_depth + 1):
             size += pow4
             pow4 *= 4
-        self.observation_space = [size * 5]
+        self.observation_space = [size * 6]
 
     def reset(self):
         agents = self.env.agents
@@ -186,6 +187,10 @@ class TreeObsForRailEnv(ObservationBuilder):
         #5: minimum distance from node to the agent's target (when landing to the node following the corresponding
             branch.
 
+        #6: agent direction
+
+
+
         Missing/padding nodes are filled in with -inf (truncated).
         Missing values in present node are filled in with +inf (truncated).
 
@@ -202,13 +207,10 @@ class TreeObsForRailEnv(ObservationBuilder):
         if handle > len(self.env.agents):
             print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
         agent = self.env.agents[handle]  # TODO: handle being treated as index
-        # position = self.env.agents_position[handle]
-        # orientation = self.env.agents_direction[handle]
         possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
         num_transitions = np.count_nonzero(possible_transitions)
         # Root node - current position
-        # observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]]
-        observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)]]
+        observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], agent.direction]
         root_observation = observation[:]
         visited = set()
         # Start from the current orientation, and see which transitions are available;
@@ -337,40 +339,49 @@ class TreeObsForRailEnv(ObservationBuilder):
                            1 if other_target_encountered else 0,
                            1 if other_agent_encountered else 0,
                            root_observation[3] + num_steps,
-                           0]
+                           0,
+                           direction]
 
         elif last_isTerminal:
             observation = [0,
                            1 if other_target_encountered else 0,
                            1 if other_agent_encountered else 0,
                            np.inf,
-                           np.inf]
+                           np.inf,
+                           direction]
         else:
             observation = [0,
                            1 if other_target_encountered else 0,
                            1 if other_agent_encountered else 0,
                            root_observation[3] + num_steps,
-                           self.distance_map[handle, position[0], position[1], direction]]
+                           self.distance_map[handle, position[0], position[1], direction],
+                           direction]
         """
         if last_isTarget:
             observation = [0,
                            other_target_encountered,
                            other_agent_encountered,
                            root_observation[3] + num_steps,
-                           0]
+                           0,
+                           direction
+                           ]
 
         elif last_isTerminal:
             observation = [0,
                            other_target_encountered,
                            other_agent_encountered,
                            np.inf,
-                           np.inf]
+                           np.inf,
+                           direction
+                           ]
         else:
             observation = [0,
                            other_target_encountered,
                            other_agent_encountered,
                            root_observation[3] + num_steps,
-                           self.distance_map[handle, position[0], position[1], direction]]
+                           self.distance_map[handle, position[0], position[1], direction],
+                           direction
+                           ]
         # #############################
         # #############################
 
@@ -409,7 +420,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                 for i in range(self.max_depth - depth):
                     num_cells_to_fill_in += pow4
                     pow4 *= 4
-                observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
+                observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
 
         return observation, visited
 
@@ -532,7 +543,9 @@ class GlobalObsForRailEnv(ObservationBuilder):
                 obs_agents_state[agent2.position][4 + agent2.direction] = 1
                 obs_targets[agent2.target][1] += 1
 
-        return self.rail_obs, obs_agents_state, obs_targets
+        direction = self._get_one_hot_for_agent_direction(agent)
+
+        return self.rail_obs, obs_agents_state, obs_targets, direction
 
 
 class GlobalObsForRailEnvDirectionDependent(ObservationBuilder):
@@ -542,13 +555,15 @@ class GlobalObsForRailEnvDirectionDependent(ObservationBuilder):
 
         - transition map array with dimensions (env.height, env.width, 16),
           assuming 16 bits encoding of transitions, flipped in the direction of the agent
-          (the agent is always heding north on the flipped view).
+          (the agent is always heading north on the flipped view).
 
         - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent
          target and the positions of the other agents targets, also flipped depending on the agent's direction.
 
         - A 3D array (map_height, map_width, 5) containing the one hot encoding of the direction of the other
           agents at their position coordinates, and the last channel containing the position of the given agent.
+
+        - A 4 elements array with one hot encoding of the direction.
     """
 
     def __init__(self):
@@ -603,7 +618,9 @@ class GlobalObsForRailEnvDirectionDependent(ObservationBuilder):
                 obs_agents_state[agent2.position][1 + idx[4 + (agent2.direction - direction)]] = 1
                 obs_targets[agent2.target][1] += 1
 
-        return rail_obs, obs_agents_state, obs_targets
+        direction = self._get_one_hot_for_agent_direction(agent)
+
+        return rail_obs, obs_agents_state, obs_targets, direction
 
 
 class LocalObsForRailEnv(ObservationBuilder):
@@ -635,8 +652,8 @@ class LocalObsForRailEnv(ObservationBuilder):
         # We build the transition map with a view_radius empty cells expansion on each side.
         # This helps to collect the local transition map view when the agent is close to a border.
 
-        self.rail_obs = np.zeros((self.env.height + 2*self.view_radius,
-                                  self.env.width + 2*self.view_radius, 16))
+        self.rail_obs = np.zeros((self.env.height + 2 * self.view_radius,
+                                  self.env.width + 2 * self.view_radius, 16))
         for i in range(self.env.height):
             for j in range(self.env.width):
                 bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]]
@@ -654,12 +671,12 @@ class LocalObsForRailEnv(ObservationBuilder):
         # top_offset = max(0, agent.position[0] - 1 - self.view_radius)
         # bottom_offset = min(0, agent.position[0] + 1 + self.view_radius)
 
-        local_rail_obs = self.rail_obs[agent.position[0]: agent.position[0]+2*self.view_radius + 1,
-                                       agent.position[1]:agent.position[1]+2*self.view_radius + 1]
+        local_rail_obs = self.rail_obs[agent.position[0]: agent.position[0] + 2 * self.view_radius + 1,
+                         agent.position[1]:agent.position[1] + 2 * self.view_radius + 1]
 
-        obs_map_state = np.zeros((2*self.view_radius + 1, 2*self.view_radius + 1, 2))
+        obs_map_state = np.zeros((2 * self.view_radius + 1, 2 * self.view_radius + 1, 2))
 
-        obs_other_agents_state = np.zeros((2*self.view_radius + 1, 2*self.view_radius + 1, 4))
+        obs_other_agents_state = np.zeros((2 * self.view_radius + 1, 2 * self.view_radius + 1, 4))
 
         def relative_pos(pos):
             return [agent.position[0] - pos[0], agent.position[1] - pos[1]]
@@ -684,15 +701,11 @@ class LocalObsForRailEnv(ObservationBuilder):
                 if is_in(target_rel_pos_2):
                     obs_map_state[self.view_radius + np.array(target_rel_pos_2)][1] += 1
 
-        direction = np.zeros(4)
-        direction[agent.direction] = 1
+        direction = self._get_one_hot_for_agent_direction(agent)
 
         return local_rail_obs, obs_map_state, obs_other_agents_state, direction
 
-
 # class LocalObsForRailEnvImproved(ObservationBuilder):
 #     """
 #     Returns a local observation around the given agent
 #     """
-
-