From 9c4beda08efeae698461d00559b9de2ac857761c Mon Sep 17 00:00:00 2001
From: hagrid67 <jdhwatson@gmail.com>
Date: Wed, 24 Apr 2019 11:53:14 +0100
Subject: [PATCH] merged; changed to plotAgent with circle  line; added delay
 and render flag to play

---
 examples/play_model.py        | 39 +++++++++++++++++++--------------
 flatland/utils/graphics_qt.py |  2 +-
 flatland/utils/render_qt.py   | 41 ++++++++++++++++++++++++-----------
 flatland/utils/rendertools.py | 39 ++++++++++++++++++---------------
 4 files changed, 73 insertions(+), 48 deletions(-)

diff --git a/examples/play_model.py b/examples/play_model.py
index d54decd8..4fe40f9a 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -1,17 +1,16 @@
 from flatland.envs.rail_env import RailEnv, random_rail_generator
 # from flatland.core.env_observation_builder import TreeObsForRailEnv
 from flatland.utils.rendertools import RenderTool
-from flatland.utils.render_qt import QtRailRender
 from flatland.baselines.dueling_double_dqn import Agent
 from collections import deque
 import torch
 import random
 import numpy as np
 import matplotlib.pyplot as plt
-import redis
+import time
 
 
-def main():
+def main(render=True, delay=2):
 
     random.seed(1)
     np.random.seed(1)
@@ -32,8 +31,9 @@ def main():
                 height=7,
                 rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
                 number_of_agents=1)
-    env_renderer = RenderTool(env, gl="QT")
-    #env_renderer = QtRailRender(env)
+
+    if render:
+        env_renderer = RenderTool(env, gl="QT")
     plt.figure(figsize=(5,5))
     # fRedis = redis.Redis()
 
@@ -67,6 +67,8 @@ def main():
             idx -= 1
         return None
 
+    iFrame = 0
+    tStart = time.time()
     for trials in range(1, n_trials + 1):
 
         # Reset environment
@@ -102,7 +104,13 @@ def main():
                 agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
                 score += all_rewards[a]
 
-            env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step)
+            if render:
+                env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step)
+                if delay > 0:
+                    time.sleep(delay)
+                    
+            iFrame += 1
+
 
             obs = next_obs.copy()
             if done['__all__']:
@@ -116,8 +124,8 @@ def main():
         scores.append(np.mean(scores_window))
         dones_list.append((np.mean(done_window)))
 
-        print('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' +
-            '\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
+        print(('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' +
+                '\tEpsilon: {:.2f} \t Action Probabilities: \t {}').format(
                 env.number_of_agents,
                 trials,
                 np.mean(scores_window),
@@ -125,16 +133,15 @@ def main():
                 eps, action_prob/np.sum(action_prob)),
             end=" ")
         if trials % 100 == 0:
-
-            print(
-                '\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
+            tNow = time.time()
+            rFps = iFrame / (tNow - tStart)
+            print(('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' + 
+                    '\tEpsilon: {:.2f} fps: {:.2f} \t Action Probabilities: \t {}').format(
                     env.number_of_agents,
                     trials,
-                    np.mean(
-                        scores_window),
-                    100 * np.mean(
-                        done_window),
-                    eps, action_prob / np.sum(action_prob)))
+                    np.mean(scores_window),
+                    100 * np.mean(done_window),
+                    eps, rFps, action_prob / np.sum(action_prob)))
             torch.save(agent.qnetwork_local.state_dict(),
                     '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
             action_prob = [1]*4
diff --git a/flatland/utils/graphics_qt.py b/flatland/utils/graphics_qt.py
index 6571f4d7..a4abb578 100644
--- a/flatland/utils/graphics_qt.py
+++ b/flatland/utils/graphics_qt.py
@@ -123,7 +123,7 @@ class QtRenderer(object):
 
     def beginFrame(self):
         self.painter.begin(self.img)
-        self.painter.setRenderHint(QPainter.Antialiasing, False)
+        # self.painter.setRenderHint(QPainter.Antialiasing, False)
 
         # Clear the background
         self.painter.setBrush(QColor(0, 0, 0))
diff --git a/flatland/utils/render_qt.py b/flatland/utils/render_qt.py
index 60b8d295..94439b2b 100644
--- a/flatland/utils/render_qt.py
+++ b/flatland/utils/render_qt.py
@@ -2,6 +2,7 @@ from flatland.utils.graphics_qt import QtRenderer
 from numpy import array
 from flatland.utils.graphics_layer import GraphicsLayer
 from matplotlib import pyplot as plt
+import numpy as np
 
 
 class QTGL(GraphicsLayer):
@@ -35,8 +36,7 @@ class QTGL(GraphicsLayer):
         self.qtr.pop()
         self.qtr.endFrame()
 
-    def plot(self, gX, gY, color=None, linewidth=2, **kwargs):
-
+    def adaptColor(self, color):
         if color == "red" or color == "r":
             color = (255, 0, 0)
         elif color == "gray":
@@ -48,20 +48,35 @@ class QTGL(GraphicsLayer):
             color = gcolor[:3] * 255
         else:
             color = self.tColGrid
+        return color
+
+    def plot(self, gX, gY, color=None, linewidth=2, **kwargs):
+        color = self.adaptColor(color)
 
         self.qtr.setLineColor(*color)
         lastx = lasty = None
-        for x, y in zip(gX, gY):
-            if lastx is not None:
-                # print("line", lastx, lasty, x, y)
-                self.qtr.drawLine(
-                    lastx*self.cell_pixels, -lasty*self.cell_pixels,
-                    x*self.cell_pixels, -y*self.cell_pixels)
-            lastx = x
-            lasty = y
-
-    def scatter(self, *args, **kwargs):
-        print("scatter not yet implemented in ", self.__class__)
+
+        if False:
+            for x, y in zip(gX, gY):
+                if lastx is not None:
+                    # print("line", lastx, lasty, x, y)
+                    self.qtr.drawLine(
+                        lastx*self.cell_pixels, -lasty*self.cell_pixels,
+                        x*self.cell_pixels, -y*self.cell_pixels)
+                lastx = x
+                lasty = y
+        else:
+            # print(gX, gY)
+            gPoints = np.stack([array(gX), -array(gY)]).T * self.cell_pixels
+            self.qtr.drawPolyline(gPoints)
+
+    def scatter(self, gX, gY, color=None, marker="o", size=5, *args, **kwargs):
+        color = self.adaptColor(color)
+        self.qtr.setColor(*color)
+        r = np.sqrt(size)
+        gPoints = np.stack([np.atleast_1d(gX), -np.atleast_1d(gY)]).T * self.cell_pixels
+        for x, y in gPoints:
+            self.qtr.drawCircle(x, y, r)
 
     def text(self, x, y, sText):
         self.qtr.drawText(x*self.cell_pixels, -y*self.cell_pixels, sText)
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index f9ebf532..278a08f6 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -116,8 +116,8 @@ class RenderTool(object):
 
                 self.plotAgent(rcPos, iDir, sColor)
 
-                gTransRCAg = self.getTransRC(rcPos, iDir)
-                self.plotTrans(rcPos, gTransRCAg, color=color)
+                # gTransRCAg = self.getTransRC(rcPos, iDir)
+                # self.plotTrans(rcPos, gTransRCAg, color=color)
 
                 if False:
                     # TODO: this was `rcDir' but it was undefined
@@ -135,20 +135,17 @@ class RenderTool(object):
             self.plotTrans(visit.rc, gTransRCAg, depth=str(visit.iDepth), color=color)
 
     def plotAgents(self):
-        rt = self.__class__
-
-        # plt.scatter(*rt.gCentres, s=5, color="r")
-
+        cmap = self.gl.get_cmap('hsv', lut=self.env.number_of_agents+1)
         for iAgent in range(self.env.number_of_agents):
-            sColor = rt.lColors[iAgent]
+            oColor = cmap(iAgent)
 
             rcPos = self.env.agents_position[iAgent]
             iDir = self.env.agents_direction[iAgent]  # agent direction index
 
-            self.plotAgent(rcPos, iDir, sColor)
+            self.plotAgent(rcPos, iDir, oColor)
 
-            gTransRCAg = self.getTransRC(rcPos, iDir)
-            self.plotTrans(rcPos, gTransRCAg)
+            # gTransRCAg = self.getTransRC(rcPos, iDir)
+            # self.plotTrans(rcPos, gTransRCAg)
 
     def getTransRC(self, rcPos, iDir, bgiTrans=False):
         """
@@ -189,21 +186,24 @@ class RenderTool(object):
     def plotAgent(self, rcPos, iDir, sColor="r"):
         """
         Plot a simple agent.
-        Assumes a working matplotlib context.
+        Assumes a working graphics layer context (cf a MPL figure).
         """
         rt = self.__class__
-        xyPos = np.matmul(rcPos, rt.grc2xy) + rt.xyHalf
-        self.gl.scatter(*xyPos, color=sColor)            # agent location
 
         rcDir = rt.gTransRC[iDir]                    # agent direction in RC
         xyDir = np.matmul(rcDir, rt.grc2xy)          # agent direction in xy
-        xyDirLine = array([xyPos, xyPos+xyDir/2]).T  # line for agent orient.
+
+        xyPos = np.matmul(rcPos - rcDir / 2, rt.grc2xy) + rt.xyHalf
+        self.gl.scatter(*xyPos, color=sColor, size=10)            # agent location
+
+        xyDirLine = array([xyPos, xyPos + xyDir/2]).T  # line for agent orient.
         self.gl.plot(*xyDirLine, color=sColor, lw=5, ms=0, alpha=0.6)
 
-        # just mark the next cell we're heading into
-        rcNext = rcPos + rcDir
-        xyNext = np.matmul(rcNext, rt.grc2xy) + rt.xyHalf
-        self.gl.scatter(*xyNext, color=sColor)
+        if False:
+            # mark the next cell we're heading into
+            rcNext = rcPos + rcDir
+            xyNext = np.matmul(rcNext, rt.grc2xy) + rt.xyHalf
+            self.gl.scatter(*xyNext, color=sColor)
 
     def plotTrans(self, rcPos, gTransRCAg, color="r", depth=None):
         """
@@ -571,6 +571,9 @@ class RenderTool(object):
         # Draw each agent + its orientation + its target
         if agents:
             cmap = self.gl.get_cmap('hsv', lut=env.number_of_agents+1)
+            self.plotAgents()
+
+        if False:
             for i in range(env.number_of_agents):
                 self._draw_square((
                                 env.agents_position[i][1] *
-- 
GitLab