From fcc1ee6afcca9fbfc5c430fc40b2124621703222 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Sat, 5 Oct 2019 09:13:16 -0400
Subject: [PATCH] minor bugfixes

---
 torch_training/training_navigation.py | 19 ++++++-------------
 utils/observation_utils.py            |  2 +-
 2 files changed, 7 insertions(+), 14 deletions(-)

diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index bd221ae..ad512c6 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -59,19 +59,12 @@ def main(argv):
 
     env = RailEnv(width=x_dim,
                   height=y_dim,
-                  rail_generator=sparse_rail_generator(num_cities=5,
+                  rail_generator=sparse_rail_generator(max_num_cities=3,
                                                        # Number of cities in map (where train stations are)
-                                                       num_intersections=4,
-                                                       # Number of intersections (no start / target)
-                                                       num_trainstations=10,  # Number of possible start/targets on map
-                                                       min_node_dist=3,  # Minimal distance of nodes
-                                                       node_radius=2,  # Proximity of stations to city center
-                                                       num_neighb=3,
-                                                       # Number of connections to other cities/intersections
-                                                       seed=15,  # Random seed
-                                                       grid_mode=True,
-                                                       enhance_intersection=False
-                                                       ),
+                                                       seed=1,  # Random seed
+                                                       grid_mode=False,
+                                                       max_rails_between_cities=2,
+                                                       max_rails_in_city=2),
                   schedule_generator=sparse_schedule_generator(speed_ration_map),
                   number_of_agents=n_agents,
                   stochastic_data=stochastic_data,  # Malfunction data generator
@@ -129,7 +122,7 @@ def main(argv):
 
         # Build agent specific observations
         for a in range(env.get_num_agents()):
-            agent_obs[a] = normalize_observation(obs[a], observation_radius=10)
+            agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
             agent_obs_buffer[a] = agent_obs[a].copy()
 
         # Reset score and done
diff --git a/utils/observation_utils.py b/utils/observation_utils.py
index f9661f5..e9eb3ed 100644
--- a/utils/observation_utils.py
+++ b/utils/observation_utils.py
@@ -104,7 +104,7 @@ def split_tree_into_feature_groups(tree: TreeObsForRailEnv.Node, max_tree_depth:
     """
     data, distance, agent_data = _split_node_into_feature_groups(tree)
 
-    for direction in TreeObsForRailEnv.tree_explorted_actions_char:
+    for direction in TreeObsForRailEnv.tree_explored_actions_char:
         sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups(tree.childs[direction], 1, max_tree_depth)
         data = np.concatenate((data, sub_data))
         distance = np.concatenate((distance, sub_distance))
-- 
GitLab