diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index dc5741b95a39442065328667ccd9c53cda4c408e..53630476874072c613bede5daf1ac76bdab33625 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -437,7 +437,7 @@ class TreeObsForRailEnv(ObservationBuilder): for i in range(self.max_depth - depth): num_cells_to_fill_in += pow4 pow4 *= 4 - observation + ([-np.inf] * self.observation_dim) * num_cells_to_fill_in + observation = observation + ([-np.inf] * self.observation_dim) * num_cells_to_fill_in return observation, visited @@ -496,8 +496,8 @@ class TreeObsForRailEnv(ObservationBuilder): child_tree = tree[(num_features_per_node + children * child_size): (num_features_per_node + (children + 1) * child_size)] tmp_tree_data, tmp_distance_data, tmp_agent_data = self.split_tree(child_tree, - num_features_per_node, - current_depth=current_depth + 1) + num_features_per_node, + current_depth=current_depth + 1) if len(tmp_tree_data) > 0: tree_data.extend(tmp_tree_data) distance_data.extend(tmp_distance_data)