From 53c34f901e571f13cb00843e876f2af46f2031bc Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Thu, 16 May 2019 18:25:44 +0200
Subject: [PATCH] fixed bugs in get() funciton in observation_builder

---
 examples/training_navigation.py | 16 ++++++++--------
 flatland/envs/observations.py   |  3 ++-
 flatland/envs/rail_env.py       |  4 ++--
 3 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/examples/training_navigation.py b/examples/training_navigation.py
index 4a21ad11..ec19ff20 100644
--- a/examples/training_navigation.py
+++ b/examples/training_navigation.py
@@ -24,18 +24,17 @@ transition_probability = [15,  # empty cell - Case 0
                           1]  # Case 2b (10) - simple switch mirrored
 
 # Example generate a random rail
-
+"""
 env = RailEnv(width=10,
               height=10,
               rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
-              number_of_agents=1)
+              number_of_agents=5)
 """
 env = RailEnv(width=15,
               height=15,
-              rail_generator=complex_rail_generator(nr_start_goal=10, min_dist=5, max_dist=99999, seed=0),
+              rail_generator=complex_rail_generator(nr_start_goal=3, min_dist=5, max_dist=99999, seed=0),
               number_of_agents=3)
 """
-"""
 env = RailEnv(width=20,
               height=20,
               rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(
@@ -117,12 +116,14 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
 for trials in range(1, n_trials + 1):
 
     # Reset environment
-    obs, dev_obs = env.reset()
-    env.dev_obs_dict = dev_obs
+    obs = env.reset()
+
     final_obs = obs.copy()
     final_obs_next = obs.copy()
+
     for a in range(env.get_num_agents()):
         data, distance = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=5, current_depth=0)
+
         data = norm_obs_clip(data)
         distance = norm_obs_clip(distance)
         obs[a] = np.concatenate((data, distance))
@@ -150,8 +151,7 @@ for trials in range(1, n_trials + 1):
             action_dict.update({a: action})
 
         # Environment step
-        (next_obs, dev_obs), all_rewards, done, _ = env.step(action_dict)
-        env.dev_obs_dict = dev_obs
+        next_obs, all_rewards, done, _ = env.step(action_dict)
         for a in range(env.get_num_agents()):
             data, distance = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=5,
                                                         current_depth=0)
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index dba5faf9..b097b5df 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -229,7 +229,8 @@ class TreeObsForRailEnv(ObservationBuilder):
                     num_cells_to_fill_in += pow4
                     pow4 *= 4
                 observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
-        return observation, visited
+        self.env.dev_obs_dict[handle] = visited
+        return observation
 
     def _explore_branch(self, handle, position, direction, root_observation, depth):
         """
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 640eb5c5..118ebf4d 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -318,8 +318,8 @@ class RailEnv(Environment):
         self.debug_obs_dict = {}
         # for handle in self.agents_handles:
         for iAgent in range(self.get_num_agents()):
-            self.obs_dict[iAgent], self.debug_obs_dict[iAgent] = self.obs_builder.get(iAgent)
-        return self.obs_dict, self.debug_obs_dict
+            self.obs_dict[iAgent] = self.obs_builder.get(iAgent)
+        return self.obs_dict
 
     def render(self):
         # TODO:
-- 
GitLab