Commit 1a91a738 authored by Erik Nygren's avatar Erik Nygren
Browse files

Minor bugfixes in shortes path predictor

parent 9890abf5
Pipeline #1088 failed with stages
in 9 minutes and 21 seconds
......@@ -80,7 +80,7 @@ for trials in range(1, n_trials + 1):
# Environment step which returns the observations for all agents, their corresponding
# reward and whether their are done
next_obs, all_rewards, done, _ = env.step(action_dict)
TreeObservation.util_print_obs_subtree(next_obs[0], num_features_per_node=8)
# TreeObservation.util_print_obs_subtree(next_obs[0], num_features_per_node=8)
# Update replay buffer and train agent
for a in range(env.get_num_agents()):
agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
......
......@@ -178,7 +178,6 @@ class TreeObsForRailEnv(ObservationBuilder):
self.predicted_pos = {}
self.predicted_dir = {}
self.predictions = self.predictor.get(self.distance_map)
for t in range(len(self.predictions[0])):
pos_list = []
dir_list = []
......
......@@ -127,14 +127,15 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
if np.sum(cell_transitions) == 1:
new_direction = np.argmax(cell_transitions)
new_position = self._new_position(agent.position, new_direction)
else:
elif np.sum(cell_transitions) > 1:
min_dist = np.inf
for direct in range(4):
min_dist = np.inf
target_dist = distancemap[agent_idx, agent.position[0], agent.position[1], direct]
if target_dist < min_dist:
min_dist = target_dist
new_direction = direct
new_position = self._new_position(agent.position, new_direction)
if cell_transitions[direct] == 1:
target_dist = distancemap[agent_idx, agent.position[0], agent.position[1], direct]
if target_dist < min_dist:
min_dist = target_dist
new_direction = direct
new_position = self._new_position(agent.position, new_direction)
agent.position = new_position
agent.direction = new_direction
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment