diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 931fabc7c9e9b4f42b659490c8b3cbd43cb11596..5d879d26d87d0499edcf5cf9ba3350af8083a8f1 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -178,6 +178,7 @@ class TreeObsForRailEnv(ObservationBuilder): self.predicted_pos = {} self.predicted_dir = {} self.predictions = self.predictor.get(self.distance_map) + print(self.predictions) for t in range(len(self.predictions[0])): pos_list = [] dir_list = [] diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index f5f559a1fb83f71ac9ca2ac514bc3d9f6a4b024f..cff4e00e14d73981dd623e796554f46a150c8a80 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -114,12 +114,14 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): prediction = np.zeros(shape=(self.max_depth + 1, 5)) prediction[0] = [0, _agent_initial_position[0], _agent_initial_position[1], _agent_initial_direction, 0] for index in range(1, self.max_depth + 1): - action_done = False # if we're at the target, stop moving... if agent.position == agent.target: prediction[index] = [index, agent.target[0], agent.target[1], agent.direction, RailEnvActions.STOP_MOVING] - + continue + if not agent.moving: + prediction[index] = [index, agent.position[0], agent.position[1], agent.direction, + RailEnvActions.STOP_MOVING] continue # Take shortest possible path cell_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))