From a7ddd74b3336b6d934fa4fb5ab76b965558c5268 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Thu, 18 Jul 2019 12:07:41 -0400
Subject: [PATCH] Add utility function to normalize tree observations

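Normalizing a tree observation (splitting it with split_tree, normalizing
the tree and distance data with norm_obs_clip and clipping the agent data)
was duplicated across the multi-agent training and inference scripts. This
patch factors the sequence into a single helper, normalize_observation, in
utils/observation_utils.py and replaces the inlined copies with calls to it.

A minimal usage sketch (obs[a] is a TreeObsForRailEnv observation and
observation_radius the radius already defined in the scripts):

    from utils.observation_utils import normalize_observation

    for a in range(env.get_num_agents()):
        agent_obs[a] = normalize_observation(obs[a], observation_radius=observation_radius)

The remaining changes adjust the example scripts' run settings (trial
counts, image recording, generated environment size) and switch the
inference script from the pickled complex scene to a small generated
environment.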
---
 sequential_agent/run_test.py            |  4 ++--
 torch_training/multi_agent_inference.py | 31 ++++++++-----------------
 torch_training/multi_agent_training.py  | 17 +++-----------
 utils/observation_utils.py              | 14 ++++++++++++++
 4 files changed, 29 insertions(+), 37 deletions(-)

diff --git a/sequential_agent/run_test.py b/sequential_agent/run_test.py
index 6e9f7c2..970d6aa 100644
--- a/sequential_agent/run_test.py
+++ b/sequential_agent/run_test.py
@@ -37,9 +37,9 @@ tree_depth = 1
 observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv())
 env_renderer = RenderTool(env, gl="PILSVG", )
 handle = env.get_agent_handles()
-n_trials = 10
+n_trials = 1
 max_steps = 3 * (env.height + env.width)
-record_images = False
+record_images = True
 agent = OrderedAgent()
 action_dict = dict()
 
diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py
index 003b18a..e399126 100644
--- a/torch_training/multi_agent_inference.py
+++ b/torch_training/multi_agent_inference.py
@@ -12,11 +12,11 @@ from importlib_resources import path
 
 import torch_training.Nets
 from torch_training.dueling_double_dqn import Agent
-from utils.observation_utils import norm_obs_clip, split_tree
+from utils.observation_utils import normalize_observation
 
 random.seed(3)
 np.random.seed(2)
-
+"""
 file_name = "./railway/complex_scene.pkl"
 env = RailEnv(width=10,
               height=20,
@@ -27,9 +27,9 @@ y_dim = env.height
 
 """
 
-x_dim = 50 #np.random.randint(8, 20)
-y_dim = 50 #np.random.randint(8, 20)
-n_agents = 20  # np.random.randint(3, 8)
+x_dim = 10  # np.random.randint(8, 20)
+y_dim = 10  # np.random.randint(8, 20)
+n_agents = 5  # np.random.randint(3, 8)
 n_goals = n_agents + np.random.randint(0, 3)
 min_dist = int(0.75 * min(x_dim, y_dim))
 
@@ -41,7 +41,7 @@ env = RailEnv(width=x_dim,
               obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()),
               number_of_agents=n_agents)
 env.reset(True, True)
-"""
+
 tree_depth = 3
 observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv())
 env_renderer = RenderTool(env, gl="PILSVG", )
@@ -53,7 +53,7 @@ for i in range(tree_depth + 1):
 state_size = num_features_per_node * nr_nodes
 action_size = 5
 
-n_trials = 1
+n_trials = 10
 observation_radius = 10
 max_steps = int(3 * (env.height + env.width))
 eps = 1.
@@ -73,7 +73,7 @@ agent = Agent(state_size, action_size, "FC", 0)
 with path(torch_training.Nets, "avoid_checkpoint52800.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))
 
-record_images = True
+record_images = False
 frame_step = 0
 
 for trials in range(1, n_trials + 1):
@@ -84,12 +84,7 @@ for trials in range(1, n_trials + 1):
     env_renderer.reset()
 
     for a in range(env.get_num_agents()):
-        data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=num_features_per_node,
-                                                current_depth=0)
-        data = norm_obs_clip(data, fixed_radius=observation_radius)
-        distance = norm_obs_clip(distance)
-        agent_data = np.clip(agent_data, -1, 1)
-        agent_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
+        agent_obs[a] = normalize_observation(obs[a], observation_radius=observation_radius)
 
     # Run episode
     for step in range(max_steps):
@@ -108,13 +103,7 @@ for trials in range(1, n_trials + 1):
 
         next_obs, all_rewards, done, _ = env.step(action_dict)
         for a in range(env.get_num_agents()):
-            data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
-                                                    num_features_per_node=num_features_per_node,
-                                                    current_depth=0)
-            data = norm_obs_clip(data, fixed_radius=observation_radius)
-            distance = norm_obs_clip(distance)
-            agent_data = np.clip(agent_data, -1, 1)
-            agent_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
+            agent_obs[a] = normalize_observation(next_obs[a], observation_radius=observation_radius)
 
         if done['__all__']:
             break
diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
index 476066a..7659b2d 100644
--- a/torch_training/multi_agent_training.py
+++ b/torch_training/multi_agent_training.py
@@ -17,7 +17,7 @@ from importlib_resources import path
 # Import Torch and utility functions to normalize observation
 import torch_training.Nets
 from torch_training.dueling_double_dqn import Agent
-from utils.observation_utils import norm_obs_clip, split_tree
+from utils.observation_utils import normalize_observation
 
 
 def main(argv):
@@ -131,13 +131,7 @@ def main(argv):
 
         # Build agent specific observations
         for a in range(env.get_num_agents()):
-            data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=num_features_per_node,
-                                                    current_depth=0)
-            data = norm_obs_clip(data, fixed_radius=observation_radius)
-            distance = norm_obs_clip(distance)
-            agent_data = np.clip(agent_data, -1, 1)
-            agent_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
-
+            agent_obs[a] = normalize_observation(obs[a], observation_radius=observation_radius)
         score = 0
         env_done = 0
 
@@ -155,12 +149,7 @@ def main(argv):
 
             # Build agent specific observations and normalize
             for a in range(env.get_num_agents()):
-                data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
-                                                        num_features_per_node=num_features_per_node, current_depth=0)
-                data = norm_obs_clip(data, fixed_radius=observation_radius)
-                distance = norm_obs_clip(distance)
-                agent_data = np.clip(agent_data, -1, 1)
-                agent_next_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
+                agent_next_obs[a] = normalize_observation(next_obs[a], observation_radius=observation_radius)
 
             # Update replay buffer and train agent
             for a in range(env.get_num_agents()):
diff --git a/utils/observation_utils.py b/utils/observation_utils.py
index c5f0d5d..b3dd5ae 100644
--- a/utils/observation_utils.py
+++ b/utils/observation_utils.py
@@ -97,3 +97,17 @@ def split_tree(tree, num_features_per_node, current_depth=0):
             distance_data.extend(tmp_distance_data)
             agent_data.extend(tmp_agent_data)
     return tree_data, distance_data, agent_data
+
+
+def normalize_observation(observation, num_features_per_node=9, observation_radius=0):
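+    """
+    Split a tree observation into tree data, distance data and agent data,
+    normalize/clip each part and return them concatenated as one vector.
+    """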
+    data, distance, agent_data = split_tree(tree=np.array(observation), num_features_per_node=num_features_per_node,
+                                            current_depth=0)
+    data = norm_obs_clip(data, fixed_radius=observation_radius)
+    distance = norm_obs_clip(distance)
+    agent_data = np.clip(agent_data, -1, 1)
+    normalized_obs = np.concatenate((np.concatenate((data, distance)), agent_data))
+    return normalized_obs
-- 
GitLab