diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 2887610b06d5455bd7612d3c6f0ad386e086c6be..991399002909144adc9d84beb3a959e2dd011b5b 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -399,7 +399,12 @@ class TreeObsForRailEnv(ObservationBuilder): break cell_transitions = self.env.rail.get_transitions(*position, direction) - total_transitions = bin(self.env.rail.get_full_transitions(*position)).count("1") + transition_bit = bin(self.env.rail.get_full_transitions(*position)) + total_transitions = transition_bit.count("1") + + # Check if crossing is found --> Not an unusable switch + if int(transition_bit, 2) == int('1000010000100001', 2): + total_transitions = 1 num_transitions = np.count_nonzero(cell_transitions) exploring = False