diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index edd9cae7d9c632a0564d31607c1d661bdb42641c..a6fbae6d0d271f47e98d08262c7fbc2801b7142d 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -82,8 +82,7 @@ class TreeObsForRailEnv(ObservationBuilder): """ # Returns max distance to target, from the farthest away node, while filling in distance_map - for ori in range(4): - self.distance_map[target_nr, position[0], position[1], ori] = 0 + self.distance_map[target_nr, position[0], position[1], :] = 0 # Fill in the (up to) 4 neighboring nodes # nodes_queue = [] # list of tuples (row, col, direction, distance); @@ -237,14 +236,18 @@ class TreeObsForRailEnv(ObservationBuilder): position = self.env.agents_position[handle] orientation = self.env.agents_direction[handle] possible_transitions = self.env.rail.get_transitions((position[0], position[1], orientation)) - + num_transitions = np.count_nonzero(possible_transitions) # Root node - current position observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]] root_observation = observation[:] # Start from the current orientation, and see which transitions are available; # organize them as [left, forward, right, back], relative to the current orientation - # TODO: Adjust this to the novel movement dynamics --> Only Forward present when one transition is possible. + # If only one transition is possible, the tree is oriented with this transition as the forward branch. + # TODO: Test if this works as desired! + if num_transitions == 1: + orientation == np.argmax(possible_transitions) + for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]: if possible_transitions[branch_direction]: new_cell = self._new_position(position, branch_direction)