From 61b289feefba9f0c8a4aa1aed57d3ac1bfadc9a1 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Sat, 5 Oct 2019 09:31:00 -0400
Subject: [PATCH] Update single-agent navigation training to work with the new env

---
 torch_training/training_navigation.py | 7 +++----
 utils/observation_utils.py            | 2 +-
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index ad512c6..3a61d1f 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -38,8 +38,7 @@ def main(argv):
     x_dim = 20
     y_dim = 20
     n_agents = 1
-    n_goals = 5
-    min_dist = 5
+
 
     # Use the malfunction generator to break agents from time to time
     stochastic_data = {'prop_malfunction': 0.0,  # Percentage of defective agents
@@ -149,7 +148,7 @@ def main(argv):
 
             # Build agent specific observations and normalize
             for a in range(env.get_num_agents()):
-                agent_next_obs[a] = normalize_observation(next_obs[a], observation_radius=10)
+                agent_next_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)
                 cummulated_reward[a] += all_rewards[a]
 
             # Update replay buffer and train agent
@@ -186,7 +185,7 @@ def main(argv):
         for _idx in range(env.get_num_agents()):
             if done[_idx] == 1:
                 tasks_finished += 1
-        done_window.append(tasks_finished / env.get_num_agents())
+        done_window.append(tasks_finished / max(1, env.get_num_agents()))
         scores_window.append(score / max_steps)  # save most recent score
         scores.append(np.mean(scores_window))
         dones_list.append((np.mean(done_window)))
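
For context, a minimal self-contained sketch of the two call patterns introduced above: normalize_observation now takes the tree depth explicitly, and the completion rate is guarded against a zero-agent division. FakeEnv and fake_normalize are hypothetical stand-ins for the real RailEnv and normalize_observation; only the call shapes mirror the patch.

    # Hypothetical stand-ins: only the call shapes below mirror the patch.
    from collections import deque
    import numpy as np

    def fake_normalize(obs, tree_depth, observation_radius=10):
        # Stand-in for normalize_observation: the real helper flattens and
        # scales a tree observation built with the given tree_depth.
        return np.asarray(obs, dtype=float) / observation_radius

    class FakeEnv:
        def get_num_agents(self):
            return 1

    env = FakeEnv()
    tree_depth = 2
    next_obs = {0: [1.0, 2.0, 3.0]}
    agent_next_obs = {}

    for a in range(env.get_num_agents()):
        # tree_depth is now passed explicitly, matching the hunk above
        agent_next_obs[a] = fake_normalize(next_obs[a], tree_depth, observation_radius=10)

    # max(1, ...) keeps the done rate well-defined even with zero agents
    done_window = deque(maxlen=100)
    tasks_finished = 1
    done_window.append(tasks_finished / max(1, env.get_num_agents()))
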
diff --git a/utils/observation_utils.py b/utils/observation_utils.py
index e9eb3ed..ddb0374 100644
--- a/utils/observation_utils.py
+++ b/utils/observation_utils.py
@@ -89,7 +89,7 @@ def _split_subtree_into_feature_groups(node: TreeObsForRailEnv.Node, current_tre
     if not node.childs:
         return data, distance, agent_data
 
-    for direction in TreeObsForRailEnv.tree_explorted_actions_char:
+    for direction in TreeObsForRailEnv.tree_explored_actions_char:
         sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups(node.childs[direction], current_tree_depth + 1, max_tree_depth)
         data = np.concatenate((data, sub_data))
         distance = np.concatenate((distance, sub_distance))
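
The rename above fixes a misspelled attribute (tree_explorted_actions_char). Below is a minimal sketch of the traversal pattern that attribute drives, assuming it is the list of branch characters ['L', 'F', 'R', 'B']; the dict-based node is a simplified stand-in for TreeObsForRailEnv.Node, not the real class.

    import numpy as np

    # Assumed branch labels; the real constant lives on TreeObsForRailEnv.
    TREE_EXPLORED_ACTIONS_CHAR = ['L', 'F', 'R', 'B']

    def split_subtree(node, data=None):
        # node is a plain dict with a 'value' and a 'childs' mapping keyed by branch char
        data = np.array([]) if data is None else data
        data = np.concatenate((data, [node['value']]))
        if not node['childs']:
            return data
        for direction in TREE_EXPLORED_ACTIONS_CHAR:
            child = node['childs'].get(direction)
            if child is not None:
                data = split_subtree(child, data)
        return data

    tree = {'value': 1.0,
            'childs': {'L': {'value': 2.0, 'childs': {}},
                       'F': {'value': 3.0, 'childs': {}}}}
    print(split_subtree(tree))  # -> [1. 2. 3.]
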
-- 
GitLab