diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py index e8df61e5d5ccad6631a1fc17ecfa746a1525cba6..db22a8f4afc19d06efd714cd2e881db0d9c530fe 100755 --- a/utils/fast_tree_obs.py +++ b/utils/fast_tree_obs.py @@ -207,8 +207,10 @@ class FastTreeObs(ObservationBuilder): possible_transitions = self.env.rail.get_transitions(*new_position, new_direction) if agents_on_switch: orientation = new_direction - if fast_count_nonzero(possible_transitions) == 1: + possible_transitions_nonzero = fast_count_nonzero(possible_transitions) + if possible_transitions_nonzero == 1: orientation = fast_argmax(possible_transitions) + for dir_loop, branch_direction in enumerate( [(orientation + dir_loop) % 4 for dir_loop in range(-1, 3)]): # branch the exploration path and aggregate the found information