diff --git a/examples/training_example.py b/examples/training_example.py
index 67e0c740c2b15e25c65dde8036831f3ef062a831..a05f7c727ac9a56453951de453ea184cab1ea4fc 100644
--- a/examples/training_example.py
+++ b/examples/training_example.py
@@ -14,11 +14,11 @@ np.random.seed(1)
 
 TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
 LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2)
-env = RailEnv(width=20,
-              height=20,
+env = RailEnv(width=50,
+              height=50,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
               obs_builder_object=LocalGridObs,
-              number_of_agents=2)
+              number_of_agents=5)
 
 env_renderer = RenderTool(env, gl="PILSVG", )
 
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index d0afc71b3d09d42cf374ef61deaba6008d7fc420..19eb9ab9ee18e8f5f7a389bfe83368cae8187fd0 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -704,7 +704,7 @@ class LocalObsForRailEnv(ObservationBuilder):
         - Two 2D arrays (view_height,2*view_width+1, 2) containing respectively,
         if they are in the agent's vision range, its target position, the positions of the other targets.
 
-        - A 3D array (view_height,2*view_width+1, 4) containing the one hot encoding of directions
+        - A 2D array (view_height,2*view_width+1, 4) containing the one hot encoding of directions
           of the other agents at their position coordinates, if they are in the agent's vision range.
 
         - A 4 elements array with one hot encoding of the direction.
@@ -724,43 +724,52 @@ class LocalObsForRailEnv(ObservationBuilder):
         # We build the transition map with a view_radius empty cells expansion on each side.
         # This helps to collect the local transition map view when the agent is close to a border.
         self.max_padding = max(self.view_width, self.view_height)
-        self.rail_obs = np.zeros((self.env.height + 2 * self.max_padding,
-                                  self.env.width + 2 * self.max_padding, 16))
+        self.rail_obs = np.zeros((self.env.height,
+                                  self.env.width, 16))
         for i in range(self.env.height):
             for j in range(self.env.width):
                 bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
                 bitlist = [0] * (16 - len(bitlist)) + bitlist
-                self.rail_obs[i + self.view_height, j + self.view_width] = np.array(bitlist)
+                self.rail_obs[i, j] = np.array(bitlist)
 
     def get(self, handle):
         agents = self.env.agents
         agent = agents[handle]
-        agent_rel_pos = [0, 0]
 
         # Correct agents position for padding
-        agent_rel_pos[0] = agent.position[0] + self.max_padding
-        agent_rel_pos[1] = agent.position[1] + self.max_padding
-
-        # Collect the rail information in the local field of view
-        local_rail_obs = self.field_of_view(agent_rel_pos, agent.direction, state=self.rail_obs)
+        # agent_rel_pos[0] = agent.position[0] + self.max_padding
+        # agent_rel_pos[1] = agent.position[1] + self.max_padding
 
         # Collect visible cells as set to be plotted
-        visited, rel_coords = self.field_of_view(agent.position, agent.direction)
+        visited, rel_coords = self.field_of_view(agent.position, agent.direction, )
+        local_rail_obs = None
 
         # Add the visible cells to the observed cells
         self.env.dev_obs_dict[handle] = set(visited)
 
         # Locate observed agents and their coresponding targets
-        obs_map_state = np.zeros((self.view_height + 1, 2 * self.view_width + 1, 2))
-        obs_other_agents_state = np.zeros((self.view_height + 1, 2 * self.view_width + 1, 4))
-
-        for i in range(len(agents)):
-            temp_agent = agents[i]
-            if temp_agent.target in visited:
-                location = np.where(visited == temp_agent.target)
-                print("I see my target", location, handle)
-        direction = self._get_one_hot_for_agent_direction(agent)
-
+        local_rail_obs = np.zeros((self.view_height, 2 * self.view_width + 1, 16))
+        obs_map_state = np.zeros((self.view_height, 2 * self.view_width + 1, 2))
+        obs_other_agents_state = np.zeros((self.view_height, 2 * self.view_width + 1, 4))
+        _idx = 0
+        for pos in visited:
+            curr_rel_coord = rel_coords[_idx]
+            local_rail_obs[curr_rel_coord[0], curr_rel_coord[1], :] = self.rail_obs[pos[0], pos[1], :]
+            if pos == agent.target:
+                obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 0] = 1
+            else:
+                for tmp_agent in agents:
+                    if pos == tmp_agent.target:
+                        obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 1] = 1
+            if pos != agent.position:
+                for tmp_agent in agents:
+                    if pos == tmp_agent.position:
+                        obs_other_agents_state[curr_rel_coord[0], curr_rel_coord[1], :] = np.identity(4)[
+                            tmp_agent.direction]
+
+            _idx += 1
+
+        direction = np.identity(4)[agent.direction]
         return local_rail_obs, obs_map_state, obs_other_agents_state, direction
 
     def get_many(self, handles=None):
@@ -796,26 +805,26 @@ class LocalObsForRailEnv(ObservationBuilder):
                     if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width:
                         visible.append((origin[0] - h, origin[1] + w))
                         rel_coords.append((h, w))
-                    if data_collection:
-                        temp_visible_data[h, w, :] = state[origin[0] - h, origin[1] + w, :]
+                    # if data_collection:
+                    #    temp_visible_data[h, w, :] = state[origin[0] - h, origin[1] + w, :]
                 elif direction == 1:
                     if 0 <= origin[0] + w < self.env.height and 0 <= origin[1] + h < self.env.width:
                         visible.append((origin[0] + w, origin[1] + h))
                         rel_coords.append((h, w))
-                    if data_collection:
-                        temp_visible_data[h, w, :] = state[origin[0] + w, origin[1] + h, :]
+                    # if data_collection:
+                    #    temp_visible_data[h, w, :] = state[origin[0] + w, origin[1] + h, :]
                 elif direction == 2:
-                    if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width:
+                    if 0 <= origin[0] + h < self.env.height and 0 <= origin[1] - w < self.env.width:
                         visible.append((origin[0] + h, origin[1] - w))
                         rel_coords.append((h, w))
-                    if data_collection:
-                        temp_visible_data[h, w, :] = state[origin[0] + h, origin[1] - w, :]
+                    # if data_collection:
+                    #    temp_visible_data[h, w, :] = state[origin[0] + h, origin[1] - w, :]
                 else:
-                    if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width:
+                    if 0 <= origin[0] - w < self.env.height and 0 <= origin[1] - h < self.env.width:
                         visible.append((origin[0] - w, origin[1] - h))
                         rel_coords.append((h, w))
-                    if data_collection:
-                        temp_visible_data[h, w, :] = state[origin[0] - w, origin[1] - h, :]
+                    # if data_collection:
+                    #    temp_visible_data[h, w, :] = state[origin[0] - w, origin[1] - h, :]
         if data_collection:
             return temp_visible_data
         else: