From 21dba5dafe233949a4fa972f4626e9c6ab5fe31e Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Wed, 3 Jul 2019 12:56:39 -0400
Subject: [PATCH] added utils repo, moved normalization into utils

---
 torch_training/training_navigation.py | 55 ++---------------
 utils/__init__.py                     |  0
 utils/observation_utils.py            | 85 +++++++++++++++++++++++++++
 3 files changed, 91 insertions(+), 49 deletions(-)
 create mode 100644 utils/__init__.py
 create mode 100644 utils/observation_utils.py

diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index d085a8b..a6ee613 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -5,12 +5,13 @@ import matplotlib.pyplot as plt
 import numpy as np
 import torch
 from dueling_double_dqn import Agent
-from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
 
+from utils.observation_utils import norm_obs_clip, split_tree
+
 random.seed(1)
 np.random.seed(1)
 
@@ -39,7 +40,6 @@ env = RailEnv(width=15,
               height=15,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0),
               number_of_agents=1)
-
 """
 env = RailEnv(width=10,
               height=20, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
@@ -54,6 +54,7 @@ env = RailEnv(width=20,
               number_of_agents=15)
 file_load = False
 env.reset(True, True)
+
 """
 env_renderer = RenderTool(env, gl="PILSVG",)
 handle = env.get_agent_handles()
@@ -81,51 +82,7 @@ agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pt
 demo = True
 record_images = True
 
-def max_lt(seq, val):
-    """
-    Return greatest item in seq for which item < val applies.
-    None is returned if seq was empty or all items in seq were >= val.
-    """
-    max = 0
-    idx = len(seq) - 1
-    while idx >= 0:
-        if seq[idx] < val and seq[idx] >= 0 and seq[idx] > max:
-            max = seq[idx]
-        idx -= 1
-    return max
-
-
-def min_lt(seq, val):
-    """
-    Return smallest item in seq for which item > val applies.
-    None is returned if seq was empty or all items in seq were >= val.
-    """
-    min = np.inf
-    idx = len(seq) - 1
-    while idx >= 0:
-        if seq[idx] >= val and seq[idx] < min:
-            min = seq[idx]
-        idx -= 1
-    return min
-
-
-def norm_obs_clip(obs, clip_min=-1, clip_max=1):
-    """
-    This function returns the difference between min and max value of an observation
-    :param obs: Observation that should be normalized
-    :param clip_min: min value where observation will be clipped
-    :param clip_max: max value where observation will be clipped
-    :return: returnes normalized and clipped observatoin
-    """
-    max_obs = max(1, max_lt(obs, 1000))
-    min_obs = min(max_obs, min_lt(obs, 0))
-
-    if max_obs == min_obs:
-        return np.clip(np.array(obs) / max_obs, clip_min, clip_max)
-    norm = np.abs(max_obs - min_obs)
-    if norm == 0:
-        norm = 1.
-    return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)
+
 
 
 for trials in range(1, n_trials + 1):
@@ -141,8 +98,8 @@ for trials in range(1, n_trials + 1):
     final_obs_next = obs.copy()
     for a in range(env.get_num_agents()):
         print(a)
-        data, distance, agent_data = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=8,
-                                                                current_depth=0)
+        data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=8,
+                                                current_depth=0)
         data = norm_obs_clip(data)
         distance = norm_obs_clip(distance)
         agent_data = np.clip(agent_data, -1, 1)
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/utils/observation_utils.py b/utils/observation_utils.py
new file mode 100644
index 0000000..75e6b0a
--- /dev/null
+++ b/utils/observation_utils.py
@@ -0,0 +1,85 @@
+import numpy as np
+
+
+def max_lt(seq, val):
+    """
+    Return greatest item in seq for which item < val applies.
+    None is returned if seq was empty or all items in seq were >= val.
+    """
+    max = 0
+    idx = len(seq) - 1
+    while idx >= 0:
+        if seq[idx] < val and seq[idx] >= 0 and seq[idx] > max:
+            max = seq[idx]
+        idx -= 1
+    return max
+
+
+def min_lt(seq, val):
+    """
+    Return smallest item in seq for which item > val applies.
+    None is returned if seq was empty or all items in seq were >= val.
+    """
+    min = np.inf
+    idx = len(seq) - 1
+    while idx >= 0:
+        if seq[idx] >= val and seq[idx] < min:
+            min = seq[idx]
+        idx -= 1
+    return min
+
+
+def norm_obs_clip(obs, clip_min=-1, clip_max=1):
+    """
+    This function returns the difference between min and max value of an observation
+    :param obs: Observation that should be normalized
+    :param clip_min: min value where observation will be clipped
+    :param clip_max: max value where observation will be clipped
+    :return: returnes normalized and clipped observatoin
+    """
+    max_obs = max(1, max_lt(obs, 1000))
+    min_obs = min(max_obs, min_lt(obs, 0))
+
+    if max_obs == min_obs:
+        return np.clip(np.array(obs) / max_obs, clip_min, clip_max)
+    norm = np.abs(max_obs - min_obs)
+    if norm == 0:
+        norm = 1.
+    return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)
+
+
+def split_tree(tree, num_features_per_node=8, current_depth=0):
+    """
+
+    :param tree:
+    :param num_features_per_node:
+    :param prompt:
+    :param current_depth:
+    :return:
+    """
+
+    if len(tree) < num_features_per_node:
+        return [], [], []
+
+    depth = 0
+    tmp = len(tree) / num_features_per_node - 1
+    pow4 = 4
+    while tmp > 0:
+        tmp -= pow4
+        depth += 1
+        pow4 *= 4
+    child_size = (len(tree) - num_features_per_node) // 4
+    tree_data = tree[:4].tolist()
+    distance_data = [tree[4]]
+    agent_data = tree[5:num_features_per_node].tolist()
+    for children in range(4):
+        child_tree = tree[(num_features_per_node + children * child_size):
+                          (num_features_per_node + (children + 1) * child_size)]
+        tmp_tree_data, tmp_distance_data, tmp_agent_data = split_tree(child_tree,
+                                                                      num_features_per_node,
+                                                                      current_depth=current_depth + 1)
+        if len(tmp_tree_data) > 0:
+            tree_data.extend(tmp_tree_data)
+            distance_data.extend(tmp_distance_data)
+            agent_data.extend(tmp_agent_data)
+    return tree_data, distance_data, agent_data
-- 
GitLab