Skip to content
Snippets Groups Projects
Commit 61b289fe authored by Erik Nygren's avatar Erik Nygren
Browse files

updated single agent navigation to work with new env

parent fcc1ee6a
No related branches found
No related tags found
No related merge requests found
...@@ -38,8 +38,7 @@ def main(argv): ...@@ -38,8 +38,7 @@ def main(argv):
x_dim = 20 x_dim = 20
y_dim = 20 y_dim = 20
n_agents = 1 n_agents = 1
n_goals = 5
min_dist = 5
# Use a the malfunction generator to break agents from time to time # Use a the malfunction generator to break agents from time to time
stochastic_data = {'prop_malfunction': 0.0, # Percentage of defective agents stochastic_data = {'prop_malfunction': 0.0, # Percentage of defective agents
...@@ -149,7 +148,7 @@ def main(argv): ...@@ -149,7 +148,7 @@ def main(argv):
# Build agent specific observations and normalize # Build agent specific observations and normalize
for a in range(env.get_num_agents()): for a in range(env.get_num_agents()):
agent_next_obs[a] = normalize_observation(next_obs[a], observation_radius=10) agent_next_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)
cummulated_reward[a] += all_rewards[a] cummulated_reward[a] += all_rewards[a]
# Update replay buffer and train agent # Update replay buffer and train agent
...@@ -186,7 +185,7 @@ def main(argv): ...@@ -186,7 +185,7 @@ def main(argv):
for _idx in range(env.get_num_agents()): for _idx in range(env.get_num_agents()):
if done[_idx] == 1: if done[_idx] == 1:
tasks_finished += 1 tasks_finished += 1
done_window.append(tasks_finished / env.get_num_agents()) done_window.append(tasks_finished / max(1, env.get_num_agents()))
scores_window.append(score / max_steps) # save most recent score scores_window.append(score / max_steps) # save most recent score
scores.append(np.mean(scores_window)) scores.append(np.mean(scores_window))
dones_list.append((np.mean(done_window))) dones_list.append((np.mean(done_window)))
......
...@@ -89,7 +89,7 @@ def _split_subtree_into_feature_groups(node: TreeObsForRailEnv.Node, current_tre ...@@ -89,7 +89,7 @@ def _split_subtree_into_feature_groups(node: TreeObsForRailEnv.Node, current_tre
if not node.childs: if not node.childs:
return data, distance, agent_data return data, distance, agent_data
for direction in TreeObsForRailEnv.tree_explorted_actions_char: for direction in TreeObsForRailEnv.tree_explored_actions_char:
sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups(node.childs[direction], current_tree_depth + 1, max_tree_depth) sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups(node.childs[direction], current_tree_depth + 1, max_tree_depth)
data = np.concatenate((data, sub_data)) data = np.concatenate((data, sub_data))
distance = np.concatenate((distance, sub_distance)) distance = np.concatenate((distance, sub_distance))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment