diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 991399002909144adc9d84beb3a959e2dd011b5b..1280117558fe6dc7b058e5747af2235c981acf8d 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -404,7 +404,8 @@ class TreeObsForRailEnv(ObservationBuilder):
 
             # Check if crossing is found --> Not an unusable switch
             if int(transition_bit, 2) == int('1000010000100001', 2):
-                total_transitions = 1
+                # Treat the crossing as a straight rail cell
+                total_transitions = 2
             num_transitions = np.count_nonzero(cell_transitions)
 
             exploring = False