From c361c45bf229b92071e5202e2283471a7169c7dc Mon Sep 17 00:00:00 2001
From: hagrid67 <jdhwatson@gmail.com>
Date: Wed, 24 Apr 2019 13:21:46 +0100
Subject: [PATCH] plot targets.  play with larger grid 15x15, 5 agents.
 delay=0.

---
 examples/play_model.py        | 12 ++++++------
 flatland/utils/render_qt.py   |  1 +
 flatland/utils/rendertools.py | 23 ++++++++++++++++-------
 3 files changed, 23 insertions(+), 13 deletions(-)

diff --git a/examples/play_model.py b/examples/play_model.py
index 4fe40f9..6a67397 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -10,7 +10,7 @@ import matplotlib.pyplot as plt
 import time
 
 
-def main(render=True, delay=2):
+def main(render=True, delay=0.0):
 
     random.seed(1)
     np.random.seed(1)
@@ -27,10 +27,10 @@ def main(render=True, delay=2):
                             0.0]  # Case 7 - dead end
 
     # Example generate a random rail
-    env = RailEnv(width=7,
-                height=7,
+    env = RailEnv(width=15,
+                height=15,
                 rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
-                number_of_agents=1)
+                number_of_agents=5)
 
     if render:
         env_renderer = RenderTool(env, gl="QT")
@@ -52,7 +52,7 @@ def main(render=True, delay=2):
     dones_list = []
     action_prob = [0]*4
     agent = Agent(state_size, action_size, "FC", 0)
-    agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
+    # agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
 
     def max_lt(seq, val):
         """
@@ -108,7 +108,7 @@ def main(render=True, delay=2):
                 env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step)
                 if delay > 0:
                     time.sleep(delay)
-                    
+
             iFrame += 1
 
 
diff --git a/flatland/utils/render_qt.py b/flatland/utils/render_qt.py
index 94439b2..e2334d3 100644
--- a/flatland/utils/render_qt.py
+++ b/flatland/utils/render_qt.py
@@ -73,6 +73,7 @@ class QTGL(GraphicsLayer):
     def scatter(self, gX, gY, color=None, marker="o", size=5, *args, **kwargs):
         color = self.adaptColor(color)
         self.qtr.setColor(*color)
+        self.qtr.setLineColor(*color)
         r = np.sqrt(size)
         gPoints = np.stack([np.atleast_1d(gX), -np.atleast_1d(gY)]).T * self.cell_pixels
         for x, y in gPoints:
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 278a08f..985f2b6 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -134,7 +134,7 @@ class RenderTool(object):
             gTransRCAg = rt.gTransRC[giTrans]
             self.plotTrans(visit.rc, gTransRCAg, depth=str(visit.iDepth), color=color)
 
-    def plotAgents(self):
+    def plotAgents(self, targets=True):
         cmap = self.gl.get_cmap('hsv', lut=self.env.number_of_agents+1)
         for iAgent in range(self.env.number_of_agents):
             oColor = cmap(iAgent)
@@ -142,7 +142,11 @@ class RenderTool(object):
             rcPos = self.env.agents_position[iAgent]
             iDir = self.env.agents_direction[iAgent]  # agent direction index
 
-            self.plotAgent(rcPos, iDir, oColor)
+            if targets:
+                target = self.env.agents_target[iAgent]
+            else:
+                target = None
+            self.plotAgent(rcPos, iDir, oColor, target=target)
 
             # gTransRCAg = self.getTransRC(rcPos, iDir)
             # self.plotTrans(rcPos, gTransRCAg)
@@ -183,7 +187,7 @@ class RenderTool(object):
         else:
             return gTransRCAg
 
-    def plotAgent(self, rcPos, iDir, sColor="r"):
+    def plotAgent(self, rcPos, iDir, color="r", target=None):
         """
         Plot a simple agent.
         Assumes a working graphics layer context (cf a MPL figure).
@@ -194,16 +198,21 @@ class RenderTool(object):
         xyDir = np.matmul(rcDir, rt.grc2xy)          # agent direction in xy
 
         xyPos = np.matmul(rcPos - rcDir / 2, rt.grc2xy) + rt.xyHalf
-        self.gl.scatter(*xyPos, color=sColor, size=10)            # agent location
+        self.gl.scatter(*xyPos, color=color, size=40)            # agent location
 
         xyDirLine = array([xyPos, xyPos + xyDir/2]).T  # line for agent orient.
-        self.gl.plot(*xyDirLine, color=sColor, lw=5, ms=0, alpha=0.6)
+        self.gl.plot(*xyDirLine, color=color, lw=5, ms=0, alpha=0.6)
+
+        if target is not None:
+            rcTarget = array(target)
+            xyTarget = np.matmul(rcTarget, rt.grc2xy) + rt.xyHalf
+            self._draw_square(xyTarget, 1/3, color)
 
         if False:
             # mark the next cell we're heading into
             rcNext = rcPos + rcDir
             xyNext = np.matmul(rcNext, rt.grc2xy) + rt.xyHalf
-            self.gl.scatter(*xyNext, color=sColor)
+            self.gl.scatter(*xyNext, color=color)
 
     def plotTrans(self, rcPos, gTransRCAg, color="r", depth=None):
         """
@@ -571,7 +580,7 @@ class RenderTool(object):
         # Draw each agent + its orientation + its target
         if agents:
             cmap = self.gl.get_cmap('hsv', lut=env.number_of_agents+1)
-            self.plotAgents()
+            self.plotAgents(targets=True)
 
         if False:
             for i in range(env.number_of_agents):
-- 
GitLab