diff --git a/examples/custom_observation_example_02_SingleAgentNavigationObs.py b/examples/custom_observation_example_02_SingleAgentNavigationObs.py index aa9da8494b4a7376a3468c16618e5459344a37f3..1403ac7aedf4fceb76abc3eeef34d950adb1f420 100644 --- a/examples/custom_observation_example_02_SingleAgentNavigationObs.py +++ b/examples/custom_observation_example_02_SingleAgentNavigationObs.py @@ -34,8 +34,11 @@ class SingleAgentNavigationObs(ObservationBuilder): def get(self, handle: int = 0) -> List[int]: agent = self.env.agents[handle] + if agent.position: + possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) + else: + possible_transitions = self.env.rail.get_transitions(*agent.initial_position, agent.direction) - possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) num_transitions = np.count_nonzero(possible_transitions) # Start from the current orientation, and see which transitions are available;