From 1a91a7383d8b7358b2ef0d080d7f62ca9a51af13 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Fri, 14 Jun 2019 00:04:06 +0200
Subject: [PATCH] Minor bugfixes in shortest path predictor

---
 examples/training_example.py  |  2 +-
 flatland/envs/observations.py |  1 -
 flatland/envs/predictions.py  | 15 ++++++++-------
 3 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/examples/training_example.py b/examples/training_example.py
index 78336b9..218f133 100644
--- a/examples/training_example.py
+++ b/examples/training_example.py
@@ -80,7 +80,7 @@ for trials in range(1, n_trials + 1):
         # Environment step which returns the observations for all agents, their corresponding
         # reward and whether their are done
         next_obs, all_rewards, done, _ = env.step(action_dict)
-        TreeObservation.util_print_obs_subtree(next_obs[0], num_features_per_node=8)
+        # TreeObservation.util_print_obs_subtree(next_obs[0], num_features_per_node=8)
         # Update replay buffer and train agent
         for a in range(env.get_num_agents()):
             agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 885a966..b5d9c2e 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -178,7 +178,6 @@ class TreeObsForRailEnv(ObservationBuilder):
         self.predicted_pos = {}
         self.predicted_dir = {}
         self.predictions = self.predictor.get(self.distance_map)
-
         for t in range(len(self.predictions[0])):
             pos_list = []
             dir_list = []
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index 988594e..366acfe 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -127,14 +127,15 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
                 if np.sum(cell_transitions) == 1:
                     new_direction = np.argmax(cell_transitions)
                     new_position = self._new_position(agent.position, new_direction)
-                else:
+                elif np.sum(cell_transitions) > 1:
+                    min_dist = np.inf
                     for direct in range(4):
-                        min_dist = np.inf
-                        target_dist = distancemap[agent_idx, agent.position[0], agent.position[1], direct]
-                        if target_dist < min_dist:
-                            min_dist = target_dist
-                            new_direction = direct
-                            new_position = self._new_position(agent.position, new_direction)
+                        if cell_transitions[direct] == 1:
+                            target_dist = distancemap[agent_idx, agent.position[0], agent.position[1], direct]
+                            if target_dist < min_dist:
+                                min_dist = target_dist
+                                new_direction = direct
+                                new_position = self._new_position(agent.position, new_direction)

                 agent.position = new_position
                 agent.direction = new_direction
--
GitLab
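
For reference, a minimal standalone sketch of the direction-selection rule that the flatland/envs/predictions.py hunk introduces: initialise min_dist once outside the loop (the pre-fix code reset it on every iteration, so the last direction always won) and only consider directions the cell actually allows. The function choose_direction and the distance_for callback below are hypothetical illustrations, not part of the Flatland API.

import numpy as np

def choose_direction(cell_transitions, distance_for):
    """Pick the outgoing direction with the smallest remaining distance.

    cell_transitions: length-4 sequence of 0/1 flags, one per heading.
    distance_for: callable mapping a direction index to the remaining
                  distance to the agent's target (np.inf if unreachable).
    """
    cell_transitions = np.asarray(cell_transitions)
    if cell_transitions.sum() == 1:
        # Only one way out of the cell: take it.
        return int(np.argmax(cell_transitions))
    # Several ways out: track the minimum across the whole loop and
    # skip directions that are not allowed by the cell's transitions.
    min_dist = np.inf
    best_direction = None
    for direction in range(4):
        if cell_transitions[direction] == 1:
            target_dist = distance_for(direction)
            if target_dist < min_dist:
                min_dist = target_dist
                best_direction = direction
    return best_direction

# Example: two allowed transitions; the one with the shorter remaining
# distance to the target is chosen.
print(choose_direction([0, 1, 0, 1], lambda d: {1: 12.0, 3: 7.0}.get(d, np.inf)))  # -> 3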