Skip to content
Snippets Groups Projects
Commit 1a91a738 authored by Erik Nygren's avatar Erik Nygren
Browse files

Minor bugfixes in shortest path predictor

parent 9890abf5
No related branches found
No related tags found
No related merge requests found
...@@ -80,7 +80,7 @@ for trials in range(1, n_trials + 1): ...@@ -80,7 +80,7 @@ for trials in range(1, n_trials + 1):
# Environment step which returns the observations for all agents, their corresponding # Environment step which returns the observations for all agents, their corresponding
# reward and whether they are done # reward and whether they are done
next_obs, all_rewards, done, _ = env.step(action_dict) next_obs, all_rewards, done, _ = env.step(action_dict)
TreeObservation.util_print_obs_subtree(next_obs[0], num_features_per_node=8) # TreeObservation.util_print_obs_subtree(next_obs[0], num_features_per_node=8)
# Update replay buffer and train agent # Update replay buffer and train agent
for a in range(env.get_num_agents()): for a in range(env.get_num_agents()):
agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])) agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
......
...@@ -178,7 +178,6 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -178,7 +178,6 @@ class TreeObsForRailEnv(ObservationBuilder):
self.predicted_pos = {} self.predicted_pos = {}
self.predicted_dir = {} self.predicted_dir = {}
self.predictions = self.predictor.get(self.distance_map) self.predictions = self.predictor.get(self.distance_map)
for t in range(len(self.predictions[0])): for t in range(len(self.predictions[0])):
pos_list = [] pos_list = []
dir_list = [] dir_list = []
......
...@@ -127,14 +127,15 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): ...@@ -127,14 +127,15 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
if np.sum(cell_transitions) == 1: if np.sum(cell_transitions) == 1:
new_direction = np.argmax(cell_transitions) new_direction = np.argmax(cell_transitions)
new_position = self._new_position(agent.position, new_direction) new_position = self._new_position(agent.position, new_direction)
else: elif np.sum(cell_transitions) > 1:
min_dist = np.inf
for direct in range(4): for direct in range(4):
min_dist = np.inf if cell_transitions[direct] == 1:
target_dist = distancemap[agent_idx, agent.position[0], agent.position[1], direct] target_dist = distancemap[agent_idx, agent.position[0], agent.position[1], direct]
if target_dist < min_dist: if target_dist < min_dist:
min_dist = target_dist min_dist = target_dist
new_direction = direct new_direction = direct
new_position = self._new_position(agent.position, new_direction) new_position = self._new_position(agent.position, new_direction)
agent.position = new_position agent.position = new_position
agent.direction = new_direction agent.direction = new_direction
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment