From da1b4bad334596b7fbb9d2f10db27e9b298fdb36 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Thu, 16 May 2019 17:49:38 +0200
Subject: [PATCH] added initial tree rendering updated tests after failing

---
 examples/training_navigation.py | 10 ++++++----
 flatland/envs/rail_env.py       |  2 +-
 flatland/utils/rendertools.py   | 18 ++++++++++++++++++
 tests/test_rendertools.py       |  6 +++---
 4 files changed, 28 insertions(+), 8 deletions(-)

diff --git a/examples/training_navigation.py b/examples/training_navigation.py
index a7920c7d..4a21ad11 100644
--- a/examples/training_navigation.py
+++ b/examples/training_navigation.py
@@ -24,7 +24,7 @@ transition_probability = [15,  # empty cell - Case 0
                           1]  # Case 2b (10) - simple switch mirrored
 
 # Example generate a random rail
-"""
+
 env = RailEnv(width=10,
               height=10,
               rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
@@ -35,6 +35,7 @@ env = RailEnv(width=15,
               rail_generator=complex_rail_generator(nr_start_goal=10, min_dist=5, max_dist=99999, seed=0),
               number_of_agents=3)
 """
+"""
 env = RailEnv(width=20,
               height=20,
               rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(
@@ -116,7 +117,8 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
 for trials in range(1, n_trials + 1):
 
     # Reset environment
-    obs, _ = env.reset()
+    obs, dev_obs = env.reset()
+    env.dev_obs_dict = dev_obs
     final_obs = obs.copy()
     final_obs_next = obs.copy()
     for a in range(env.get_num_agents()):
@@ -148,8 +150,8 @@ for trials in range(1, n_trials + 1):
             action_dict.update({a: action})
 
         # Environment step
-        (next_obs,_), all_rewards, done, _ = env.step(action_dict)
-
+        (next_obs, dev_obs), all_rewards, done, _ = env.step(action_dict)
+        env.dev_obs_dict = dev_obs
         for a in range(env.get_num_agents()):
             data, distance = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=5,
                                                         current_depth=0)
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 749e5e01..640eb5c5 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -98,7 +98,7 @@ class RailEnv(Environment):
 
         self.obs_dict = {}
         self.rewards_dict = {}
-
+        self.dev_obs_dict = {}
         # self.agents_handles = list(range(self.number_of_agents))
 
         # self.agents_position = []
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 4921def4..1bd67c22 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -471,6 +471,23 @@ class RenderTool(object):
                     xyMid,
                     xyMid + [-dx + dy, -dx - dy]])
                 self.gl.plot(*xyArrow.T, color=sColor)
+    def renderObs(self, agent_handles, observation_list):
+        """
+
+        :param agent_handles:
+        :param observation_list:
+        :return:
+        """
+        rt = self.__class__
+
+        cmap = self.gl.get_cmap('hsv',lut=max(len(self.env.agents),len(self.env.agents_static)+1))
+
+        for agent in agent_handles:
+            color = cmap(agent)
+            for visited_cell in observation_list[agent]:
+                cell_coord = array(visited_cell[:2])
+                cell_coord_trans = np.matmul(cell_coord,rt.grc2xy)+rt.xyHalf
+                self._draw_square(cell_coord_trans,1 / 3, color)
 
     def renderEnv(
             self, show=False, curves=True, spacing=False,
@@ -612,6 +629,7 @@ class RenderTool(object):
         if agents:
             self.plotAgents(targets=True, iSelectedAgent=iSelectedAgent)
 
+        self.renderObs(range(env.get_num_agents()), env.dev_obs_dict)
         # Draw some textual information like fps
         yText = [-0.3, -0.6, -0.9]
         if frames:
diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py
index f6defb2d..245f2f32 100644
--- a/tests/test_rendertools.py
+++ b/tests/test_rendertools.py
@@ -6,10 +6,10 @@ Tests for `flatland` package.
 
 from flatland.envs.rail_env import RailEnv, random_rail_generator
 import numpy as np
-<<<<<<< HEAD
-=======
+#<<<<<<< HEAD
+#=======
 # import os
->>>>>>> dc2fa1ee0244b15c76d89ab768c5e1bbd2716147
+#>>>>>>> dc2fa1ee0244b15c76d89ab768c5e1bbd2716147
 import sys
 
 import matplotlib.pyplot as plt
-- 
GitLab