diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index bd095dba527dae03e7ccc35b09c19039119ed6c9..4476bef5ac6ed271a860a6e9071f5b803448d75a 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -491,7 +491,7 @@ class TreeObsForRailEnv(ObservationBuilder): child_size = (len(tree) - num_features_per_node) // 4 tree_data = tree[0:num_features_per_node - 1].tolist() distance_data = [tree[num_features_per_node - 3]] - agent_data = tree[-2:] + agent_data = tree[-2:].tolist() for children in range(4): child_tree = tree[(num_features_per_node + children * child_size): (num_features_per_node + (children + 1) * child_size)] @@ -501,7 +501,7 @@ class TreeObsForRailEnv(ObservationBuilder): if len(tmp_tree_data) > 0: tree_data.extend(tmp_tree_data) distance_data.extend(tmp_distance_data) - agent_data.extrend(tmp_agent_data) + agent_data.extend(tmp_agent_data) return tree_data, distance_data, agent_data