diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 8590922131d57cd76ba4a638b10e3e7e169a984a..83584a302f0e4064cef187171db291b637126242 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -295,21 +295,16 @@ class TreeObsForRailEnv(ObservationBuilder):
 
                 if self.location_has_agent_direction[position] == direction:
                     # Cummulate the number of agents on branch with same direction
-                    other_agent_same_direction += self.location_has_agent_direction.get((position, direction), 0)
+                    other_agent_same_direction += 1
 
                     # Check fractional speed of agents
                     current_fractional_speed = self.location_has_agent_speed[position]
                     if current_fractional_speed < min_fractional_speed:
                         min_fractional_speed = current_fractional_speed
 
-                    # Other direction agents
-                    # TODO: Test that this behavior is as expected
-                    other_agent_opposite_direction += \
-                        self.location_has_agent[position] - self.location_has_agent_direction.get((position, direction),
-                                                                                                  0)
-
                 else:
                     # If no agent in the same direction was found all agents in that position are other direction
+                    # Attention this counts to many agents as a few might be going off on a switch.
                     other_agent_opposite_direction += self.location_has_agent[position]
 
                 # Check number of possible transitions for agent and total number of transitions in cell (type)