diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py index 821384d492545127b0c5dfcb3de8e8efb3ddbb34..4d3501784b3bc07cb88ccbe057e388b135183a28 100644 --- a/flatland/envs/env_utils.py +++ b/flatland/envs/env_utils.py @@ -69,6 +69,31 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p return rail_trans.is_valid(new_trans) +def position_to_coordinate(width, position): + """ + + :param width: + :param position: + :return: + """ + coords = () + for p in position: + coords = coords + ((int(p) % width, int(p) // width),) # changed x_dim to y_dim + return coords + + +def coordinate_to_position(width, coords): + """ + + :param width: + :param coords: + :return: + """ + position = [] + for t in coords: + position.append((t[1] * width + t[0])) + return np.array(position) + class AStarNode(): """A node class for A* Pathfinding""" diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 541f8ad592d1481afb8eb6da2eb7b887aacae419..8775b9d899523e4305e42822e2d86cfa027d7d11 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -173,10 +173,13 @@ class TreeObsForRailEnv(ObservationBuilder): in the `handles' list. """ - self.predictions = [] + if self.predictor: - for a in range(len(handles)): - self.predictions.append(self.predictor.get(a)) + self.predictions = self.predictor.get() + pred_pos = np.concatenate([[x[:, 1:3]] for x in list(self.predictions.values())], axis=0) + pred_pos = list(map(list, zip(*pred_pos))) + pred_dir = [x[:, 2] for x in list(self.predictions.values())] + observations = {} for h in handles: observations[h] = self.get(h)