From da35282adcd7fdb8dcc735f41d50c5cd935c9e30 Mon Sep 17 00:00:00 2001
From: hagrid67 <jdhwatson@gmail.com>
Date: Wed, 8 May 2019 08:27:29 +0100
Subject: [PATCH] Added Pillow to requirements; added PIL graphics (5-10 times
 faster than MPL); moved adaptColor into GraphicsLayer; changed scatter size
 arg to "s" like MPL; moved clf, get_cmap up into GraphicsLayer; added boolean
 "train" arg to ddqn step() to make running faster in player / editor; added
 plotTreeObs (not finished); improved pixel scaling in editor (I hope); made
 the agent background steps update a progress bar.

---
 examples/play_model.py                   | 10 +++-
 flatland/baselines/dueling_double_dqn.py |  5 +-
 flatland/utils/editor.py                 | 25 ++++----
 flatland/utils/graphics_layer.py         | 20 +++++++
 flatland/utils/graphics_pil.py           | 66 +++++++++++++++++++++
 flatland/utils/render_qt.py              | 26 +-------
 flatland/utils/rendertools.py            | 75 +++++++++++++++++++-----
 requirements_dev.txt                     |  1 +
 8 files changed, 175 insertions(+), 53 deletions(-)
 create mode 100644 flatland/utils/graphics_pil.py

diff --git a/examples/play_model.py b/examples/play_model.py
index e6e81c9..a61954e 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -28,7 +28,8 @@ class Player(object):
         self.dones_list = []
         self.action_prob = [0]*4
         self.agent = Agent(self.state_size, self.action_size, "FC", 0)
-        self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
+        # self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
+        self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/flatland/baselines/Nets/avoid_checkpoint15000.pth'))
 
         self.iFrame = 0
         self.tStart = time.time()
@@ -48,12 +49,15 @@ class Player(object):
 
     def step(self):
         env = self.env
+
+        # Pass the (stored) observation to the agent network and retrieve the action
         for a in range(env.number_of_agents):
             action = self.agent.act(np.array(self.obs[a]), eps=self.eps)
             self.action_prob[action] += 1
             self.action_dict.update({a: action})
 
-        # Environment step
+        # Environment step - pass the agent actions to the environment,
+        # retrieve the response - observations, rewards, dones
         next_obs, all_rewards, done, _ = self.env.step(self.action_dict)
 
         for a in range(env.number_of_agents):
@@ -62,7 +66,7 @@ class Player(object):
 
         # Update replay buffer and train agent
         for a in range(self.env.number_of_agents):
-            self.agent.step(self.obs[a], self.action_dict[a], all_rewards[a], next_obs[a], done[a])
+            self.agent.step(self.obs[a], self.action_dict[a], all_rewards[a], next_obs[a], done[a], train=False)
             self.score += all_rewards[a]
 
         self.iFrame += 1
diff --git a/flatland/baselines/dueling_double_dqn.py b/flatland/baselines/dueling_double_dqn.py
index ee75a61..66fe3a3 100644
--- a/flatland/baselines/dueling_double_dqn.py
+++ b/flatland/baselines/dueling_double_dqn.py
@@ -64,7 +64,7 @@ class Agent:
         if os.path.exists(filename + ".target"):
             self.qnetwork_target.load_state_dict(torch.load(filename + ".target"))
 
-    def step(self, state, action, reward, next_state, done):
+    def step(self, state, action, reward, next_state, done, train=True):
         # Save experience in replay memory
         self.memory.add(state, action, reward, next_state, done)
 
@@ -74,7 +74,8 @@ class Agent:
             # If enough samples are available in memory, get random subset and learn
             if len(self.memory) > BATCH_SIZE:
                 experiences = self.memory.sample()
-                self.learn(experiences, GAMMA)
+                if train:
+                    self.learn(experiences, GAMMA)
 
     def act(self, state, eps=0.):
         """Returns actions for given state as per current policy.
diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py
index 543b793..f97cb6d 100644
--- a/flatland/utils/editor.py
+++ b/flatland/utils/editor.py
@@ -27,34 +27,34 @@ import jpy_canvas
 
 
 class EditorMVC(object):
-    def __init__(self, env=None):
-
+    def __init__(self, env=None, sGL="MPL"):
         if env is None:
             env = RailEnv(width=10,
                           height=10,
-                          rail_generator=random_rail_generator(cell_type_relative_proportion=[1, 1] + [0.5] * 6),
+                          rail_generator=random_rail_generator(),
                           number_of_agents=0,
                           obs_builder_object=TreeObsForRailEnv(max_depth=2))
 
         env.reset()
 
         self.editor = EditorModel(env)
-        self.editor.view = self.view = View(self.editor)
+        self.editor.view = self.view = View(self.editor, sGL=sGL)
         self.view.controller = self.editor.controller = self.controller = Controller(self.editor, self.view)
         self.view.init_canvas()
         self.view.init_widgets()   # has to be done after controller
 
 
 class View(object):
-    def __init__(self, editor):
+    def __init__(self, editor, sGL="MPL"):
         self.editor = self.model = editor
+        self.sGL = sGL
 
     def display(self):
         self.wOutput.clear_output()
         return self.wMain
 
     def init_canvas(self):
-        self.oRT = rt.RenderTool(self.editor.env)
+        self.oRT = rt.RenderTool(self.editor.env, gl=self.sGL)
         plt.figure(figsize=(10, 10))
         self.oRT.renderEnv(spacing=False, arrows=False, sRailColor="gray", show=False)
         img = self.oRT.getImage()
@@ -66,8 +66,10 @@ class View(object):
         self.wImage.register_click(self.controller.on_click)
 
         # TODO: These are currently estimated values
-        self.yxBase = array([6, 21])  # pixel offset
-        self.nPixCell = 700 / self.model.env.rail.width  # 35
+        # self.yxBase = array([6, 21])  # pixel offset
+        # self.nPixCell = 700 / self.model.env.rail.width  # 35
+        self.yxBase = self.oRT.gl.yxBase
+        self.nPixCell = self.oRT.gl.nPixCell
 
     def init_widgets(self):
         # Radiobutton for drawmode - TODO: replace with shift/ctrl/alt keys
@@ -151,7 +153,7 @@ class View(object):
     def drag_path_element(self, x, y):
         # Draw a black square on the in-memory copy of the image
         if x > 10 and x < self.yxSize[1] and y > 10 and y < self.yxSize[0]:
-            self.writableData[(y - 2):(y + 2), (x - 2):(x + 2), :] = 0
+            self.writableData[(y - 2):(y + 2), (x - 2):(x + 2), :3] = 0
 
     def xy_to_rc(self, x, y):
         rcCell = ((array([y, x]) - self.yxBase) / self.nPixCell).astype(int)
@@ -549,17 +551,18 @@ class EditorModel(object):
 
     def start_run(self):
         if self.thread is None:
-            self.thread = threading.Thread(target=self.bg_updater, args=())
+            self.thread = threading.Thread(target=self.bg_updater, args=(self.view.wProg_steps,))
             self.thread.start()
         else:
             self.log("thread already present")
 
-    def bg_updater(self):
+    def bg_updater(self, wProg_steps):
         try:
             for i in range(20):
                 # self.log("step ", i)
                 self.step()
                 time.sleep(0.2)
+                wProg_steps.value = i+1   # indicate progress on bar
         finally:
             self.thread = None
 
diff --git a/flatland/utils/graphics_layer.py b/flatland/utils/graphics_layer.py
index aa9257b..6268e84 100644
--- a/flatland/utils/graphics_layer.py
+++ b/flatland/utils/graphics_layer.py
@@ -1,4 +1,7 @@
 
+import matplotlib.pyplot as plt
+from numpy import array
+
 
 class GraphicsLayer(object):
     def __init__(self):
@@ -33,3 +36,20 @@ class GraphicsLayer(object):
 
     def getImage(self):
         pass
+
+    def adaptColor(self, color):
+        if color == "red" or color == "r":  # named colours map to fixed RGB
+            color = (255, 0, 0)
+        elif color == "gray":
+            color = (128, 128, 128)
+        elif type(color) is list:  # every component is scaled by 255
+            color = tuple((array(color) * 255).astype(int))
+        elif type(color) is tuple:  # only the first three components are kept
+            gcolor = array(color)
+            color = tuple((gcolor[:3] * 255).astype(int))
+        else:  # anything else (including None) falls back to the grid colour
+            color = self.tColGrid
+        return color
+
+    def get_cmap(self, *args, **kwargs):
+        return plt.get_cmap(*args, **kwargs)  # delegates unchanged to matplotlib.pyplot
diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py
new file mode 100644
index 0000000..01cc5f0
--- /dev/null
+++ b/flatland/utils/graphics_pil.py
@@ -0,0 +1,66 @@
+
+from flatland.utils.graphics_layer import GraphicsLayer
+from PIL import Image, ImageDraw   # , ImageFont
+from numpy import array
+import numpy as np
+
+
+class PILGL(GraphicsLayer):
+    """Graphics layer that draws into an in-memory PIL RGBA image."""
+    def __init__(self, width, height, nPixCell=60):
+        # Honour the requested cell size (callers such as the editor read it back).
+        self.nPixCell = nPixCell
+        self.yxBase = (0, 0)
+        self.linewidth = 4                # px width of every line drawn by plot()
+        # Grid size in cells
+        self.width = width
+        self.height = height
+
+        # Total grid size at native scale
+        self.widthPx = self.width * self.nPixCell + self.linewidth
+        self.heightPx = self.height * self.nPixCell + self.linewidth
+        self.beginFrame()
+
+        self.tColBg = (255, 255, 255)     # white background
+        self.tColRail = (0, 0, 0)         # black rails
+        self.tColGrid = (230,) * 3        # light grey for grid
+
+    def plot(self, gX, gY, color=None, linewidth=3, **kwargs):
+        """Draw a polyline through the points (gX, gY), given in cell units."""
+        # NOTE: the linewidth argument is ignored; self.linewidth is always used.
+        color = self.adaptColor(color)
+        # Negate y, scale cells to pixels, flatten to [x0, y0, x1, y1, ...]
+        gPoints = np.stack([array(gX), -array(gY)]).T * self.nPixCell
+        gPoints = list(gPoints.ravel())
+        self.draw.line(gPoints, fill=color, width=self.linewidth)
+
+    def scatter(self, gX, gY, color=None, marker="o", s=50, *args, **kwargs):
+        color = self.adaptColor(color)
+        r = np.sqrt(s)                    # half the side of the square marker
+        gPoints = np.stack([np.atleast_1d(gX), -np.atleast_1d(gY)]).T * self.nPixCell
+        for x, y in gPoints:
+            self.draw.rectangle([(x-r, y-r), (x+r, y+r)], fill=color, outline=color)
+
+    def text(self, *args, **kwargs):
+        pass  # no-op in this layer
+
+    def prettify(self, *args, **kwargs):
+        pass  # no-op in this layer
+
+    def prettify2(self, width, height, cell_size):
+        pass  # no-op in this layer
+
+    def beginFrame(self):
+        self.img = Image.new("RGBA", (self.widthPx, self.heightPx), (255, 255, 255, 255))
+        self.draw = ImageDraw.Draw(self.img)
+
+    def show(self, block=False):
+        """No-op in this layer."""
+        pass
+
+    def pause(self, seconds=0.00001):
+        """No-op in this layer."""
+        pass
+
+    def getImage(self):
+        return array(self.img)
diff --git a/flatland/utils/render_qt.py b/flatland/utils/render_qt.py
index 0804c9d..29fdec5 100644
--- a/flatland/utils/render_qt.py
+++ b/flatland/utils/render_qt.py
@@ -1,7 +1,7 @@
 from flatland.utils.graphics_qt import QtRenderer
 from numpy import array
 from flatland.utils.graphics_layer import GraphicsLayer
-from matplotlib import pyplot as plt
+# from matplotlib import pyplot as plt
 import numpy as np
 
 
@@ -36,20 +36,6 @@ class QTGL(GraphicsLayer):
         self.qtr.pop()
         self.qtr.endFrame()
 
-    def adaptColor(self, color):
-        if color == "red" or color == "r":
-            color = (255, 0, 0)
-        elif color == "gray":
-            color = (128, 128, 128)
-        elif type(color) is list:
-            color = array(color) * 255
-        elif type(color) is tuple:
-            gcolor = array(color)
-            color = gcolor[:3] * 255
-        else:
-            color = self.tColGrid
-        return color
-
     def plot(self, gX, gY, color=None, linewidth=2, **kwargs):
         color = self.adaptColor(color)
 
@@ -70,11 +56,11 @@ class QTGL(GraphicsLayer):
             gPoints = np.stack([array(gX), -array(gY)]).T * self.cell_pixels
             self.qtr.drawPolyline(gPoints)
 
-    def scatter(self, gX, gY, color=None, marker="o", size=5, *args, **kwargs):
+    def scatter(self, gX, gY, color=None, marker="o", s=50, *args, **kwargs):
         color = self.adaptColor(color)
         self.qtr.setColor(*color)
         self.qtr.setLineColor(*color)
-        r = np.sqrt(size)
+        r = np.sqrt(s)
         gPoints = np.stack([np.atleast_1d(gX), -np.atleast_1d(gY)]).T * self.cell_pixels
         for x, y in gPoints:
             self.qtr.drawCircle(x, y, r)
@@ -94,12 +80,6 @@ class QTGL(GraphicsLayer):
     def pause(self, seconds=0.00001):
         pass
 
-    def clf(self):
-        pass
-
-    def get_cmap(self, *args, **kwargs):
-        return plt.get_cmap(*args, **kwargs)
-
     def beginFrame(self):
         self.qtr.beginFrame()
         self.qtr.push()
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index e84ff46..c54b2fe 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -2,19 +2,23 @@ from recordtype import recordtype
 
 import numpy as np
 from numpy import array
-import xarray as xr
+# import xarray as xr
 import matplotlib.pyplot as plt
 import time
 from collections import deque
 from flatland.utils.render_qt import QTGL
+from flatland.utils.graphics_pil import PILGL
 from flatland.utils.graphics_layer import GraphicsLayer
 
-
 # TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv!
 
 
 class MPLGL(GraphicsLayer):
-    def __init__(self):
+    def __init__(self, width, height):
+        self.width = width
+        self.height = height
+        self.yxBase = array([6, 21])  # pixel offset
+        self.nPixCell = 700 / width
         pass
 
     def plot(self, *args, **kwargs):
@@ -62,9 +66,12 @@ class MPLGL(GraphicsLayer):
         return plt.get_cmap(*args, **kwargs)
 
     def beginFrame(self):
+        # plt.figure(figsize=(10, 10))
         pass
 
     def endFrame(self):
+        # plt.clf()
+        # plt.close()
         pass
 
     def getImage(self):
@@ -83,19 +90,19 @@ class RenderTool(object):
     lColors = list("brgcmyk")
     # \delta RC for NESW
     gTransRC = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
-    nPixCell = 1
+    nPixCell = 1   # misnomer...
     nPixHalf = nPixCell / 2
     xyHalf = array([nPixHalf, -nPixHalf])
     grc2xy = array([[0, -nPixCell], [nPixCell, 0]])
     gGrid = array(np.meshgrid(np.arange(10), -np.arange(10))) * \
         array([[[nPixCell]], [[nPixCell]]])
-    xyPixHalf = xr.DataArray([nPixHalf, -nPixHalf],
-                             dims="xy",
-                             coords={"xy": ["x", "y"]})
-    gCentres = xr.DataArray(gGrid,
-                            dims=["xy", "p1", "p2"],
-                            coords={"xy": ["x", "y"]}) + xyPixHalf
-    gTheta = np.linspace(0, np.pi / 2, 10)
+    # xyPixHalf = xr.DataArray([nPixHalf, -nPixHalf],
+    #                         dims="xy",
+    #                         coords={"xy": ["x", "y"]})
+    # gCentres = xr.DataArray(gGrid,
+    #                        dims=["xy", "p1", "p2"],
+    #                        coords={"xy": ["x", "y"]}) + xyPixHalf
+    gTheta = np.linspace(0, np.pi / 2, 5)
     gArc = array([np.cos(gTheta), np.sin(gTheta)]).T  # from [1,0] to [0,1]
 
     def __init__(self, env, gl="MPL"):
@@ -105,7 +112,12 @@ class RenderTool(object):
         self.lTimes = deque()
         # self.gl = MPLGL()
 
-        self.gl = MPLGL() if gl == "MPL" else QTGL(env.width, env.height)
+        if gl == "MPL":
+            self.gl = MPLGL(env.width, env.height)
+        elif gl == "QT":
+            self.gl = QTGL(env.width, env.height)
+        elif gl == "PIL":
+            self.gl = PILGL(env.width, env.height)
 
     def plotTreeOnRail(self, lVisits, color="r"):
         """
@@ -489,16 +501,18 @@ class RenderTool(object):
 
         env = self.env
 
+        t1 = time.time()
+
         # Draw cells grid
         grid_color = [0.95, 0.95, 0.95]
         for r in range(env.height + 1):
             self.gl.plot([0, (env.width + 1) * cell_size],
                          [-r * cell_size, -r * cell_size],
-                         color=grid_color)
+                         color=grid_color, linewidth=2)
         for c in range(env.width + 1):
             self.gl.plot([c * cell_size, c * cell_size],
                          [0, -(env.height + 1) * cell_size],
-                         color=grid_color)
+                         color=grid_color, linewidth=2)
 
         # Draw each cell independently
         for r in range(env.height):
@@ -644,6 +658,9 @@ class RenderTool(object):
 
         self.gl.endFrame()
 
+        t2 = time.time()
+        print(t2 - t1, "seconds")
+
         if show:
             self.gl.show(block=False)
             self.gl.pause(0.00001)
@@ -659,3 +676,33 @@ class RenderTool(object):
 
     def getImage(self):
         return self.gl.getImage()
+
+    def plotTreeObs(self, gObs):
+        """Unfinished: plot edges of a fixed 4-ary tree layout; gObs is unused."""
+        nBranchFactor = 4
+        gP0 = array([[0, 0, 0]]).T
+        nDepth = 2
+        for i in range(nDepth):
+            nDepthNodes = nBranchFactor**i
+            rShrinkDepth = 1/(i+1)
+
+            # Node x positions are spread symmetrically about 0; y is the depth.
+            gX1 = np.linspace(-(nDepthNodes-1), (nDepthNodes-1), nDepthNodes) * rShrinkDepth
+            gY1 = np.ones((nDepthNodes)) * i
+            gZ1 = np.zeros((nDepthNodes))
+
+            gP1 = array([gX1, gY1, gZ1])
+            gP01 = np.append(gP0, gP1, axis=1)
+
+            if nDepthNodes > 1:
+                # Integer division keeps the parent/child indices integral.
+                nDepthNodesPrev = nDepthNodes // nBranchFactor
+                giP0 = np.repeat(np.arange(nDepthNodesPrev), nBranchFactor)
+                giP1 = np.arange(0, nDepthNodes) + nDepthNodesPrev
+                # Interleave as [parent0, child0, parent1, child1, ...]
+                giLinePoints = np.stack([giP0, giP1]).ravel("F")
+                print(giLinePoints)
+                self.gl.plot(gP01[0], -gP01[1], lines=giLinePoints, color="gray")
+
+            gP0 = array([gX1, gY1, gZ1])
+            
\ No newline at end of file
diff --git a/requirements_dev.txt b/requirements_dev.txt
index 0bc267e..40a6b7f 100644
--- a/requirements_dev.txt
+++ b/requirements_dev.txt
@@ -16,4 +16,5 @@ recordtype==1.3
 xarray==0.11.3
 matplotlib==3.0.2
 PyQt5==5.12
+Pillow==5.4.1
 
-- 
GitLab