From ed720ea23ae4edafc523b10221b9912f4cba89b8 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Thu, 16 May 2019 19:06:20 +0200
Subject: [PATCH] minor changes to rendering of observation

---
 examples/training_navigation.py | 14 ++++++++------
 flatland/envs/generators.py     |  2 +-
 flatland/utils/rendertools.py   | 11 ++++++-----
 3 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/examples/training_navigation.py b/examples/training_navigation.py
index ec19ff2..cabb655 100644
--- a/examples/training_navigation.py
+++ b/examples/training_navigation.py
@@ -5,7 +5,7 @@ from flatland.utils.rendertools import *
 from flatland.baselines.dueling_double_dqn import Agent
 from collections import deque
 import torch, random
-
+import time
 random.seed(1)
 np.random.seed(1)
 
@@ -25,15 +25,16 @@ transition_probability = [15,  # empty cell - Case 0
 
 # Example generate a random rail
 """
-env = RailEnv(width=10,
-              height=10,
+env = RailEnv(width=20,
+              height=20,
               rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
-              number_of_agents=5)
+              number_of_agents=1)
 """
 env = RailEnv(width=15,
               height=15,
-              rail_generator=complex_rail_generator(nr_start_goal=3, min_dist=5, max_dist=99999, seed=0),
+              rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=30, min_dist=5, max_dist=99999, seed=0),
               number_of_agents=3)
+
 """
 env = RailEnv(width=20,
               height=20,
@@ -139,7 +140,8 @@ for trials in range(1, n_trials + 1):
     # Run episode
     for step in range(100):
         if demo:
-            env_renderer.renderEnv(show=True)
+            env_renderer.renderEnv(show=True, obsrender=True)
+            time.sleep(2)
         # print(step)
         # Action
         for a in range(env.get_num_agents()):
diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index 04e9a8f..c1578a8 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -9,7 +9,7 @@ from flatland.envs.env_utils import distance_on_rail, connect_rail, get_directio
 from flatland.envs.env_utils import get_rnd_agents_pos_tgt_dir_on_rail
 
 
-def complex_rail_generator(nr_start_goal=1, nr_extra=10, min_dist=2, max_dist=99999, seed=0):
+def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=99999, seed=0):
     """
     Parameters
     -------
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 09a9f10..34f3e9f 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -472,12 +472,13 @@ class RenderTool(object):
                     xyMid + [-dx + dy, -dx - dy]])
                 self.gl.plot(*xyArrow.T, color=sColor)
 
-    def renderObs(self, agent_handles, observation_list):
+    def renderObs(self, agent_handles, observation_dict):
         """
+        Render the extent of the observation of each agent. All cells that appear in the agent obsrevation will be
+        highlighted.
+        :param agent_handles: List of agent indices to adapt color and get correct observation
+        :param observation_dict: dictionary containing sets of cells of the agent observation
 
-        :param agent_handles:
-        :param observation_list:
-        :return:
         """
         rt = self.__class__
 
@@ -485,7 +486,7 @@ class RenderTool(object):
 
         for agent in agent_handles:
             color = cmap(agent)
-            for visited_cell in observation_list[agent]:
+            for visited_cell in observation_dict[agent]:
                 cell_coord = array(visited_cell[:2])
                 cell_coord_trans = np.matmul(cell_coord, rt.grc2xy) + rt.xyHalf
                 self._draw_square(cell_coord_trans, 1 / 3, color)
-- 
GitLab