diff --git a/examples/play_model.py b/examples/play_model.py
index 7d7ed1104e7689e4ff49d993476a6f1fc6d4b6e8..08dadeb1d3d1e64b012c00e99594bcefd2905d1c 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -107,7 +107,7 @@ def max_lt(seq, val):
     return None
 
 
-def main(render=True, delay=0.0, n_trials=3, n_steps=50, sGL="PIL"):
+def main(render=True, delay=0.0, n_trials=3, n_steps=50, sGL="PILSVG"):
     random.seed(1)
     np.random.seed(1)
 
@@ -277,4 +277,4 @@ def main_old(render=True, delay=0.0):
 
 
 if __name__ == "__main__":
-    main(render=True, delay=0)
+    main(render=True, delay=0.5)
diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py
index 949628fcd093c404ffef8d313dbeea4379d86708..25ee27e56420f55308ca9a7047737219337833af 100644
--- a/flatland/utils/graphics_pil.py
+++ b/flatland/utils/graphics_pil.py
@@ -4,6 +4,12 @@ from PIL import Image, ImageDraw, ImageTk   # , ImageFont
 import tkinter as tk
 from numpy import array
 import numpy as np
+# from flatland.utils.svg import Track, Zug
+import time
+import io
+from cairosvg import svg2png
+from flatland.core.transitions import RailEnvTransitions
+# from copy import copy
 
 
 class PILGL(GraphicsLayer):
@@ -11,6 +17,7 @@ class PILGL(GraphicsLayer):
         self.nPixCell = 60
         self.yxBase = (0, 0)
         self.linewidth = 4
+        self.nAgentColors = 1  # overridden in loadAgent
         # self.tile_size = self.nPixCell
 
         self.width = width
@@ -30,6 +37,7 @@ class PILGL(GraphicsLayer):
         self.window_open = False
         # self.bShow = show
         self.firstFrame = True
+        self.create_layers()
         self.beginFrame()
 
     def plot(self, gX, gY, color=None, linewidth=3, layer=0, opacity=255, **kwargs):
@@ -49,6 +57,20 @@ class PILGL(GraphicsLayer):
         for x, y in gPoints:
             self.draws[layer].rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color)
 
+    def drawImageXY(self, pil_img, xyPixLeftTop, layer=0):
+        # self.layers[layer].alpha_composite(pil_img, offset=xyPixLeftTop)
+        if (pil_img.mode == "RGBA"): 
+            pil_mask = pil_img
+        else:
+            pil_mask = None
+            # print(pil_img, pil_img.mode, xyPixLeftTop, layer)
+        
+        self.layers[layer].paste(pil_img, xyPixLeftTop, pil_mask)
+
+    def drawImageRC(self, pil_img, rcTopLeft, layer=0):
+        xyPixLeftTop = tuple((array(rcTopLeft) * self.nPixCell)[[1, 0]])
+        self.drawImageXY(pil_img, xyPixLeftTop, layer=layer)
+
     def open_window(self):
         assert self.window_open is False, "Window is already open!"
         self.window = tk.Tk()
@@ -66,8 +88,8 @@ class PILGL(GraphicsLayer):
         pass
 
     def beginFrame(self):
-        self.create_layer(0)
-        self.create_layer(1)
+        # Create a new agent layer
+        self.create_layer(iLayer=1, clear=True)
 
     def show(self, block=False):
         img = self.alpha_composite_layers()
@@ -78,6 +100,7 @@ class PILGL(GraphicsLayer):
         tkimg = ImageTk.PhotoImage(img)
         
         if self.firstFrame:
+            # Do TK actions for a new panel (not sure what they really do)
             self.panel = tk.Label(self.window, image=tkimg)
             self.panel.pack(side="bottom", fill="both", expand="yes")
         else:
@@ -109,7 +132,8 @@ class PILGL(GraphicsLayer):
         img = Image.new("RGBA", (self.widthPx, self.heightPx), (255, 255, 255, opacity))
         return img
 
-    def create_layer(self, iLayer=0):
+    def create_layer(self, iLayer=0, clear=True):
+        # If we don't have the layers already, create them
         if len(self.layers) <= iLayer:
             for i in range(len(self.layers), iLayer+1):
                 if i == 0:
@@ -120,7 +144,216 @@ class PILGL(GraphicsLayer):
                 self.layers.append(img)
                 self.draws.append(ImageDraw.Draw(img))
         else:
-            opacity = 0 if iLayer > 0 else 255
-            self.layers[iLayer] = img = self.create_image(opacity)
-            self.draws[iLayer] = ImageDraw.Draw(img)
+            # We do already have this iLayer.  Clear it if requested.
+            if clear:
+                opacity = 0 if iLayer > 0 else 255
+                self.layers[iLayer] = img = self.create_image(opacity)
+                # We also need to maintain a Draw object for each layer
+                self.draws[iLayer] = ImageDraw.Draw(img)
+
+    def create_layers(self, clear=True):        
+        self.create_layer(0, clear=clear)
+        self.create_layer(1, clear=clear)
+
+
+class PILSVG(PILGL):
+    def __init__(self, width, height):
+        print(self, type(self))
+        oSuper = super()
+        print(oSuper, type(oSuper))
+        oSuper.__init__(width, height)
+
+        # self.track = self.track = Track()
+        # self.lwTrack = []
+        # self.zug = Zug()
+
+        self.lwAgents = []
+        self.agents_prev = []
+
+        self.loadRailSVGs()
+        self.loadAgentSVGs()
+
+    def is_raster(self):
+        return False
+
+    def processEvents(self):
+        # self.app.processEvents()
+        time.sleep(0.001)
+
+    def clear_rails(self):
+        print("Clear rails")
+        self.create_layers()
+        self.clear_agents()
+
+    def clear_agents(self):
+        # print("Clear Agents: ", len(self.lwAgents))
+        for wAgent in self.lwAgents:
+            self.layout.removeWidget(wAgent)
+        self.lwAgents = []
+        self.agents_prev = []
+
+    def pilFromSvgFile(self, sfPath):
+        with open(sfPath, "r") as fIn:
+            bytesPNG = svg2png(file_obj=fIn, output_height=self.nPixCell, output_width=self.nPixCell)
+        
+        with io.BytesIO(bytesPNG) as fIn:
+            pil_img = Image.open(fIn)
+            pil_img.load()
+            # print(pil_img.mode)
+        
+        return pil_img
+
+    def pilFromSvgBytes(self, bytesSVG):
+        bytesPNG = svg2png(bytesSVG, output_height=self.nPixCell, output_width=self.nPixCell)
+        with io.BytesIO(bytesPNG) as fIn:
+            pil_img = Image.open(fIn)
+            return pil_img
+
+    def loadRailSVGs(self):
+        """ Load the rail SVG images, apply rotations, and store as PIL images.
+        """
+        dFiles = {
+            "": "Background_#91D1DD.svg",
+            "WE": "Gleis_Deadend.svg",
+            "WW EE NN SS": "Gleis_Diamond_Crossing.svg",
+            "WW EE": "Gleis_horizontal.svg",
+            "EN SW": "Gleis_Kurve_oben_links.svg",
+            "WN SE": "Gleis_Kurve_oben_rechts.svg",
+            "ES NW": "Gleis_Kurve_unten_links.svg",
+            "NE WS": "Gleis_Kurve_unten_rechts.svg",
+            "NN SS": "Gleis_vertikal.svg",
+            "NN SS EE WW ES NW SE WN": "Weiche_Double_Slip.svg",
+            "EE WW EN SW": "Weiche_horizontal_oben_links.svg",
+            "EE WW SE WN": "Weiche_horizontal_oben_rechts.svg",
+            "EE WW ES NW": "Weiche_horizontal_unten_links.svg",
+            "EE WW NE WS": "Weiche_horizontal_unten_rechts.svg",
+            "NN SS EE WW NW ES": "Weiche_Single_Slip.svg",
+            "NE NW ES WS": "Weiche_Symetrical.svg",
+            "NN SS EN SW": "Weiche_vertikal_oben_links.svg",
+            "NN SS SE WN": "Weiche_vertikal_oben_rechts.svg",
+            "NN SS NW ES": "Weiche_vertikal_unten_links.svg",
+            "NN SS NE WS": "Weiche_vertikal_unten_rechts.svg"}
+
+        self.dPil = {}
+
+        transitions = RailEnvTransitions()
+
+        lDirs = list("NESW")
+
+        # svgBG = SVG("./svg/Background_#91D1DD.svg")
+
+        for sTrans, sFile in dFiles.items():
+            sPathSvg = "./svg/" + sFile
+
+            # Translate the ascii transition descption in the format  "NE WS" to the 
+            # binary list of transitions as per RailEnv - NESW (in) x NESW (out)
+            lTrans16 = ["0"] * 16
+            for sTran in sTrans.split(" "):
+                if len(sTran) == 2:
+                    iDirIn = lDirs.index(sTran[0])
+                    iDirOut = lDirs.index(sTran[1])
+                    iTrans = 4 * iDirIn + iDirOut
+                    lTrans16[iTrans] = "1"
+            sTrans16 = "".join(lTrans16)
+            binTrans = int(sTrans16, 2)
+            print(sTrans, sTrans16, sFile)
+
+            # Merge the transition svg image with the background colour.
+            # This is a shortcut / hack and will need re-working.
+            # if binTrans > 0:
+            #    svg = svg.merge(svgBG)
+
+            pilRail = self.pilFromSvgFile(sPathSvg)
+            self.dPil[binTrans] = pilRail
+
+            # Rotate both the transition binary and the image and save in the dict
+            for nRot in [90, 180, 270]:
+                binTrans2 = transitions.rotate_transition(binTrans, nRot)
+
+                # PIL rotates anticlockwise for positive theta
+                pilRail2 = pilRail.rotate(-nRot)
+                self.dPil[binTrans2] = pilRail2
+
+    def setRailAt(self, row, col, binTrans):
+        if binTrans in self.dPil:
+            pilTrack = self.dPil[binTrans]
+            self.drawImageRC(pilTrack, (row, col))
+        else:
+            print("Illegal rail:", row, col, format(binTrans, "#018b")[2:])
+
+    def rgb_s2i(self, sRGB):
+        """ convert a hex RGB string like 0091ea to 3-tuple of ints """
+        return tuple(int(sRGB[iRGB * 2:iRGB * 2 + 2], 16) for iRGB in [0, 1, 2])
+
+    def loadAgentSVGs(self):
+
+        # Seed initial train/zug files indexed by tuple(iDirIn, iDirOut):
+        dDirsFile = {
+            (0, 0): "svg/Zug_Gleis_#0091ea.svg",
+            (1, 2): "svg/Zug_1_Weiche_#0091ea.svg",
+            (0, 3): "svg/Zug_2_Weiche_#0091ea.svg"
+            }
+
+        sColors = "d50000#c51162#aa00ff#6200ea#304ffe#2962ff#0091ea#00b8d4#00bfa5#00c853" + \
+            "#64dd17#aeea00#ffd600#ffab00#ff6d00#ff3d00#5d4037#455a64"
+        lColors = sColors.split("#")
+        self.nAgentColors = len(lColors)
+
+        # "paint" color of the train images we load
+        a_base_color = self.rgb_s2i("0091ea")
+
+        self.dPilZug = {}
+
+        for tDirs, sPathSvg in dDirsFile.items():
+            iDirIn, iDirOut = tDirs
+            
+            pilZug = self.pilFromSvgFile(sPathSvg)
+
+            # Rotate both the directions and the image and save in the dict
+            for iDirRot in range(4):
+                nDegRot = iDirRot * 90
+                iDirIn2 = (iDirIn + iDirRot) % 4
+                iDirOut2 = (iDirOut + iDirRot) % 4
+
+                # PIL rotates anticlockwise for positive theta
+                pilZug2 = pilZug.rotate(-nDegRot)
+                rgbaZug2 = array(pilZug2)
+
+                for iColor, sColor in enumerate(lColors):
+                    tnNewColor = self.rgb_s2i(sColor)
+                    xy_color_mask = np.all(rgbaZug2[:, :, 0:3] - a_base_color == 0, axis=2)
+                    rgbaZug3 = np.copy(rgbaZug2)
+                    rgbaZug3[xy_color_mask, 0:3] = tnNewColor
+                    self.dPilZug[(iDirIn2, iDirOut2, iColor)] = Image.fromarray(rgbaZug3)
+
+    def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut, color=None):
+        delta_dir = (iDirOut - iDirIn) % 4
+        iColor = iAgent % self.nAgentColors
+        # when flipping direction at a dead end, use the "iDirOut" direction.
+        if delta_dir == 2:
+            iDirIn = iDirOut
+        pilZug = self.dPilZug[(iDirIn % 4, iDirOut % 4, iColor)]
+        self.drawImageRC(pilZug, (row, col), layer=1)
+
+
+def main2():
+    gl = PILSVG(10, 10)
+    for i in range(10):
+        gl.beginFrame()
+        gl.plot([3 + i, 4], [-4 - i, -5], color="r")
+        gl.endFrame()
+        time.sleep(1)
+
+
+def main():
+    gl = PILSVG(width=10, height=10)
+
+    for i in range(1000):
+        gl.processEvents()
+        time.sleep(0.1)
+    time.sleep(1)
+
+
+if __name__ == "__main__":
+    main()
 
diff --git a/flatland/utils/render_qt.py b/flatland/utils/render_qt.py
index 73b8ca77a33042bf181097d4b1a0a1afcb48b56e..a8ccc780165fe4bcbf5908969f9284906b2571fe 100644
--- a/flatland/utils/render_qt.py
+++ b/flatland/utils/render_qt.py
@@ -16,10 +16,11 @@ def transform_string_svg(sSVG):
     bySVG = bytearray(sSVG, encoding='utf-8')
     return bySVG
 
+
 def create_QtSvgWidget_from_svg_string(sSVG):
     svgWidget = QtSvg.QSvgWidget()
     ret = svgWidget.renderer().load(transform_string_svg(sSVG))
-    if ret == False:
+    if ret is False:
         print("create_QtSvgWidget_from_svg_string : failed to parse:", sSVG)
     return svgWidget
 
@@ -132,26 +133,7 @@ class QTSVG(GraphicsLayer):
         self.lwAgents = []
         self.agents_prev = []
 
-        svgWidget = None
-
-        iArt = 0
-        iCol = 0
-        iRow = 0
-        nCols = 10
-
-        if False:
-            for binTrans in self.track.dSvg.keys():
-                sSVG = self.track.dSvg[binTrans].to_string()
-                self.layout.addWidget(create_QtSvgWidget_from_svg_string(sSVG), iRow, iCol)
-
-                iArt += 1
-                iRow = int(iArt / nCols)
-                iCol = iArt % nCols
-
-            svgWidget2 = QtSvg.QSvgWidget()
-            svgWidget2.renderer().load(bySVG)
-
-            self.layout.addWidget(svgWidget2, 0, 0)
+        # svgWidget = None
 
     def is_raster(self):
         return False
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 1aa1748f6b61129509584c980a73086b87300d4e..9b99d1ecffb7bf0974bc70a63ed89ff687e7434e 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -3,14 +3,12 @@ from collections import deque
 
 # import xarray as xr
 import matplotlib.pyplot as plt
-import numpy as np
-from numpy import array
-from recordtype import recordtype
-
-from flatland.utils.graphics_layer import GraphicsLayer
-from flatland.utils.graphics_pil import PILGL
 from flatland.utils.render_qt import QTGL, QTSVG
-
+from flatland.utils.graphics_pil import PILGL, PILSVG
+from flatland.utils.graphics_layer import GraphicsLayer
+import recordtype
+from numpy import array
+import numpy as np
 
 # TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv!
 
@@ -133,6 +131,8 @@ class RenderTool(object):
             self.gl = QTGL(env.width, env.height)
         elif gl == "PIL":
             self.gl = PILGL(env.width, env.height)
+        elif gl == "PILSVG":
+            self.gl = PILSVG(env.width, env.height)
         elif gl == "QTSVG":
             self.gl = QTSVG(env.width, env.height)
 
@@ -618,11 +618,11 @@ class RenderTool(object):
         """
 
         if not self.gl.is_raster():
-            self.renderEnv2(show, curves, spacing,
-                            arrows, agents, renderobs,sRailColor,
-                            frames, iEpisode, iStep,
-                            iSelectedAgent, action_dict)
-
+            self.renderEnv2(show=show, curves=curves, spacing=spacing,
+                            arrows=arrows, agents=agents, show_observations=show_observations,
+                            sRailColor=sRailColor,
+                            frames=frames, iEpisode=iEpisode, iStep=iStep,
+                            iSelectedAgent=iSelectedAgent, action_dict=action_dict)
             return
 
         if type(self.gl) in (QTGL, PILGL):
@@ -728,9 +728,11 @@ class RenderTool(object):
 
             gP0 = array([gX1, gY1, gZ1])
 
-    def renderEnv2(self, show=False, curves=True, spacing=False, arrows=False, agents=True, renderobs=True,
-                   sRailColor="gray", frames=False, iEpisode=None, iStep=None, iSelectedAgent=None,
-                   action_dict=dict()):
+    def renderEnv2(
+        self, show=False, curves=True, spacing=False, arrows=False, agents=True, 
+            show_observations=True, sRailColor="gray",
+            frames=False, iEpisode=None, iStep=None, iSelectedAgent=None,
+            action_dict=dict()):
         """
         Draw the environment using matplotlib.
         Draw into the figure if provided.
@@ -741,6 +743,8 @@ class RenderTool(object):
 
         env = self.env
 
+        self.gl.beginFrame()
+
         if self.new_rail:
             self.new_rail = False
             self.gl.clear_rails()
@@ -766,7 +770,6 @@ class RenderTool(object):
                 iAction = action_dict[iAgent]
                 new_direction, action_isValid = self.env.check_action(agent, iAction)
 
-
             # ** TODO ***
             # why should we only update if the action is valid ?
             if True:
@@ -779,7 +782,8 @@ class RenderTool(object):
             else:
                 self.gl.setAgentAt(iAgent, *agent.position, agent.direction, new_direction, color=oColor)
 
-        self.gl.show()
+        if show:
+            self.gl.show()
         for i in range(3):
             self.gl.processEvents()
 
diff --git a/flatland/utils/svg.py b/flatland/utils/svg.py
index fb8b987cae83f4b8a3696e0e9496659045709c17..89730af2c41a442e65d8c5976f24b1ea84468330 100644
--- a/flatland/utils/svg.py
+++ b/flatland/utils/svg.py
@@ -104,6 +104,13 @@ class Zug(object):
 
 
 class Track(object):
+    """ Class to load and hold SVG track images.
+        Creates a mapping between
+        - cell entry and exit directions (ie transitions), and
+        - specific images provided by the SBB graphic artist.
+        The directions and images are also rotated by 90, 180 & 270 degrees.
+        (There is some redundancy in this process, given the images provided)
+    """
     def __init__(self):
         dFiles = {
             "": "Background_#9CCB89.svg",
@@ -138,6 +145,8 @@ class Track(object):
         for sTrans, sFile in dFiles.items():
             svg = SVG("./svg/" + sFile)
 
+            # Translate the ascii transition descption in the format  "NE WS" to the 
+            # binary list of transitions as per RailEnv - NESW (in) x NESW (out)
             lTrans16 = ["0"] * 16
             for sTran in sTrans.split(" "):
                 if len(sTran) == 2:
@@ -149,11 +158,14 @@ class Track(object):
             binTrans = int(sTrans16, 2)
             print(sTrans, sTrans16, sFile)
 
+            # Merge the transition svg image with the background colour.
+            # This is a shortcut / hack and will need re-working.
             if binTrans > 0:
                 svg = svg.merge(svgBG)
 
             self.dSvg[binTrans] = svg
 
+            # Rotate both the transition binary and the image and save in the dict
             for nRot in [90, 180, 270]:
                 binTrans2 = transitions.rotate_transition(binTrans, nRot)
                 svg2 = svg.copy()
diff --git a/images/basic-env.npz b/images/basic-env.npz
index 8ffaf023e1116b0c92702212ddb04c71b82f0655..e645113154b9e953575a12dcbadf4f9f3195b4ad 100644
Binary files a/images/basic-env.npz and b/images/basic-env.npz differ
diff --git a/tests/test_player.py b/tests/test_player.py
index a0e580b92ee04f41ec1bab7c4e99da1339767c96..668afb96a502ce5c25fa1e6d6fe9383a1facbf87 100644
--- a/tests/test_player.py
+++ b/tests/test_player.py
@@ -4,5 +4,9 @@ from examples.play_model import main
 
 
 def test_main():
-    main(render=True, n_steps=20, n_trials=2, sGL="PIL")
+    main(render=True, n_steps=20, n_trials=2, sGL="PILSVG")
+    # main(render=True, n_steps=20, n_trials=2, sGL="PIL")
 
+
+if __name__ == "__main__":
+    test_main()
diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py
index 8204a305328df746a772d034f3c763c848cceb93..1bfce0323322638447ed11191e7d5d9bbea565b5 100644
--- a/tests/test_rendertools.py
+++ b/tests/test_rendertools.py
@@ -12,6 +12,7 @@ import numpy as np
 import flatland.utils.rendertools as rt
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv, random_rail_generator
+from flatland.envs.generators import empty_rail_generator
 
 
 def checkFrozenImage(oRT, sFileImage, resave=False):
@@ -39,14 +40,15 @@ def test_render_env(save_new_images=False):
     # random.seed(100)
     np.random.seed(100)
     oEnv = RailEnv(width=10, height=10,
-                   rail_generator=random_rail_generator(),
+                   # rail_generator=random_rail_generator(),
+                   rail_generator=empty_rail_generator(),
                    number_of_agents=0,
                    # obs_builder_object=GlobalObsForRailEnv())
                    obs_builder_object=TreeObsForRailEnv(max_depth=2)
                    )
     sfTestEnv = "env-data/tests/test1.npy"
     oEnv.rail.load_transition_map(sfTestEnv)
-    oRT = rt.RenderTool(oEnv, gl="PIL", show=False)
+    oRT = rt.RenderTool(oEnv, gl="PILSVG", show=False)
     oRT.renderEnv(show=False)
 
     checkFrozenImage(oRT, "basic-env.npz", resave=save_new_images)
@@ -82,6 +84,7 @@ def main():
         test_render_env(save_new_images=True)
     else:
         print("Run 'python test_rendertools.py save' to regenerate images")
+        test_render_env()
 
 
 if __name__ == "__main__":