diff --git a/examples/training_example.py b/examples/training_example.py
index 78336b96e2b4e9ed7da1be3253d83c769a9a21f4..218f133efe4e79064ae94b24c61aa2d6f2ce2d09 100644
--- a/examples/training_example.py
+++ b/examples/training_example.py
@@ -80,7 +80,7 @@ for trials in range(1, n_trials + 1):
         # Environment step which returns the observations for all agents, their corresponding
         # reward and whether their are done
         next_obs, all_rewards, done, _ = env.step(action_dict)
-        TreeObservation.util_print_obs_subtree(next_obs[0], num_features_per_node=8)
+        # TreeObservation.util_print_obs_subtree(next_obs[0], num_features_per_node=8)
         # Update replay buffer and train agent
         for a in range(env.get_num_agents()):
             agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 885a966b87e23e6294b090fa7c9338ff3893cb60..b5d9c2e0ba6db708c394aa99ca73f1e42bb4cb43 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -178,7 +178,6 @@ class TreeObsForRailEnv(ObservationBuilder):
             self.predicted_pos = {}
             self.predicted_dir = {}
             self.predictions = self.predictor.get(self.distance_map)
-
             for t in range(len(self.predictions[0])):
                 pos_list = []
                 dir_list = []
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index 988594e63ad72b895127934258be19c1de85d0b6..366acfebeb8c5ba4f56138c9bab0bcbb904a7b11 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -127,14 +127,15 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
                 if np.sum(cell_transitions) == 1:
                     new_direction = np.argmax(cell_transitions)
                     new_position = self._new_position(agent.position, new_direction)
-                else:
+                elif np.sum(cell_transitions) > 1:
+                    min_dist = np.inf
                     for direct in range(4):
-                        min_dist = np.inf
-                        target_dist = distancemap[agent_idx, agent.position[0], agent.position[1], direct]
-                        if target_dist < min_dist:
-                            min_dist = target_dist
-                            new_direction = direct
-                            new_position = self._new_position(agent.position, new_direction)
+                        if cell_transitions[direct] == 1:
+                            target_dist = distancemap[agent_idx, agent.position[0], agent.position[1], direct]
+                            if target_dist < min_dist:
+                                min_dist = target_dist
+                                new_direction = direct
+                                new_position = self._new_position(agent.position, new_direction)
 
                 agent.position = new_position
                 agent.direction = new_direction
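
Note on the flatland/envs/predictions.py hunk: the old code reset min_dist to np.inf on every iteration of the direction loop, so the comparison always succeeded and the last direction examined won; it also considered directions with no valid transition. The new code initialises min_dist once before the loop and skips directions where cell_transitions[direct] is 0. Below is a minimal, self-contained sketch of that selection logic; the names pick_shortest_direction and dist_by_direction are illustrative stand-ins, not part of the library.

    import numpy as np

    def pick_shortest_direction(cell_transitions, dist_by_direction):
        # cell_transitions: length-4 binary array of allowed exit directions.
        # dist_by_direction: length-4 array of distances to the target, standing
        # in for the predictor's distance-map lookup.
        min_dist = np.inf  # initialised once, outside the loop
        best_direction = None
        for direction in range(4):
            if cell_transitions[direction] == 1:  # only consider feasible transitions
                if dist_by_direction[direction] < min_dist:
                    min_dist = dist_by_direction[direction]
                    best_direction = direction
        return best_direction

    # Transitions allowed to the north (0) and east (1); east is closer, so 1 is returned.
    print(pick_shortest_direction(np.array([1, 1, 0, 0]), np.array([7.0, 3.0, np.inf, np.inf])))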