From 28e339bb2848b38881f08f1accdc2751161f3e6b Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Wed, 3 Jul 2019 16:17:43 -0400
Subject: [PATCH] taking new observation features into account

---
 torch_training/training_navigation.py | 27 +++++++++++++--------------
 utils/observation_utils.py            |  9 ++++-----
 2 files changed, 17 insertions(+), 19 deletions(-)

diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 0e5ad18..dc01f7f 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -9,7 +9,7 @@ from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
-
+from flatland.envs.generators import complex_rail_generator
 from utils.observation_utils import norm_obs_clip, split_tree
 
 random.seed(1)
@@ -40,26 +40,26 @@ env = RailEnv(width=15,
               height=15,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0),
               number_of_agents=1)
-"""
+
 env = RailEnv(width=10,
               height=20, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
 env.load("./railway/complex_scene.pkl")
 file_load = True
 """
 
-env = RailEnv(width=20,
-              height=20,
-              rail_generator=complex_rail_generator(nr_start_goal=20, nr_extra=5, min_dist=10, max_dist=99999, seed=0),
+env = RailEnv(width=10,
+              height=10,
+              rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=5, min_dist=10, max_dist=99999, seed=0),
               obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
-              number_of_agents=15)
+              number_of_agents=3)
 file_load = False
 env.reset(True, True)
-
+"""
 """
 env_renderer = RenderTool(env, gl="PILSVG",)
 handle = env.get_agent_handles()
-
-state_size = 168 * 2
+features_per_node = 9
+state_size = features_per_node*21 * 2
 action_size = 5
 n_trials = 15000
 max_steps = int(3 * (env.height + env.width))
@@ -77,9 +77,9 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))
+#agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))
 
-demo = True
+demo = False
 record_images = False
 
 
@@ -97,8 +97,7 @@ for trials in range(1, n_trials + 1):
     final_obs = obs.copy()
     final_obs_next = obs.copy()
     for a in range(env.get_num_agents()):
-        print(a)
-        data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=8,
+        data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=features_per_node,
                                                 current_depth=0)
         data = norm_obs_clip(data)
         distance = norm_obs_clip(distance)
@@ -136,7 +135,7 @@ for trials in range(1, n_trials + 1):
 
         next_obs, all_rewards, done, _ = env.step(action_dict)
         for a in range(env.get_num_agents()):
-            data, distance, agent_data = split_tree(tree=np.array(next_obs[a]), num_features_per_node=8,
+            data, distance, agent_data = split_tree(tree=np.array(next_obs[a]), num_features_per_node=features_per_node,
                                                     current_depth=0)
             data = norm_obs_clip(data)
             distance = norm_obs_clip(distance)
diff --git a/utils/observation_utils.py b/utils/observation_utils.py
index 63adfff..0c97b18 100644
--- a/utils/observation_utils.py
+++ b/utils/observation_utils.py
@@ -48,7 +48,7 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
     return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)
 
 
-def split_tree(tree, num_features_per_node=8, current_depth=0):
+def split_tree(tree, num_features_per_node=9, current_depth=0):
     """
     Splits the tree observation into different sub groups that need the same normalization.
     This is necessary because the tree observation includes two different distance:
@@ -80,10 +80,9 @@ def split_tree(tree, num_features_per_node=8, current_depth=0):
     Here we split the node features into the different classes of distances and binary values.
     Pay close attention to this part if you modify any of the features in the tree observation.
     """
-    tree_data = tree[:4].tolist()
-    distance_data = [tree[4]]
-    agent_data = tree[5:num_features_per_node].tolist()
-
+    tree_data = tree[:6].tolist()
+    distance_data = [tree[6]]
+    agent_data = tree[7:num_features_per_node].tolist()
     # Split each child of the current node and continue to next depth level
     for children in range(4):
         child_tree = tree[(num_features_per_node + children * child_size):
-- 
GitLab