From 61b289feefba9f0c8a4aa1aed57d3ac1bfadc9a1 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Sat, 5 Oct 2019 09:31:00 -0400 Subject: [PATCH] updated single agent navigation to work with new env --- torch_training/training_navigation.py | 7 +++---- utils/observation_utils.py | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index ad512c6..3a61d1f 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -38,8 +38,7 @@ def main(argv): x_dim = 20 y_dim = 20 n_agents = 1 - n_goals = 5 - min_dist = 5 + # Use a the malfunction generator to break agents from time to time stochastic_data = {'prop_malfunction': 0.0, # Percentage of defective agents @@ -149,7 +148,7 @@ def main(argv): # Build agent specific observations and normalize for a in range(env.get_num_agents()): - agent_next_obs[a] = normalize_observation(next_obs[a], observation_radius=10) + agent_next_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10) cummulated_reward[a] += all_rewards[a] # Update replay buffer and train agent @@ -186,7 +185,7 @@ def main(argv): for _idx in range(env.get_num_agents()): if done[_idx] == 1: tasks_finished += 1 - done_window.append(tasks_finished / env.get_num_agents()) + done_window.append(tasks_finished / max(1, env.get_num_agents())) scores_window.append(score / max_steps) # save most recent score scores.append(np.mean(scores_window)) dones_list.append((np.mean(done_window))) diff --git a/utils/observation_utils.py b/utils/observation_utils.py index e9eb3ed..ddb0374 100644 --- a/utils/observation_utils.py +++ b/utils/observation_utils.py @@ -89,7 +89,7 @@ def _split_subtree_into_feature_groups(node: TreeObsForRailEnv.Node, current_tre if not node.childs: return data, distance, agent_data - for direction in TreeObsForRailEnv.tree_explorted_actions_char: + for direction in TreeObsForRailEnv.tree_explored_actions_char: sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups(node.childs[direction], current_tree_depth + 1, max_tree_depth) data = np.concatenate((data, sub_data)) distance = np.concatenate((distance, sub_distance)) -- GitLab