diff --git a/examples/play_model.py b/examples/play_model.py
index 5ed6feb7b9b3cc5a5ab6251657864fad7ee96a02..777f2d34254b9dba4c7d0a82e207397a61885d6a 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -100,7 +100,7 @@ def max_lt(seq, val):
     return None
 
 
-def main(render=True, delay=0.0, n_trials=3, n_steps=50):
+def main(render=True, delay=0.0, n_trials=3, n_steps=50, sGL="QT"):
     random.seed(1)
     np.random.seed(1)
 
@@ -111,7 +111,7 @@ def main(render=True, delay=0.0, n_trials=3, n_steps=50):
 
     if render:
         # env_renderer = RenderTool(env, gl="QTSVG")
-        env_renderer = RenderTool(env, gl="QT")
+        env_renderer = RenderTool(env, gl=sGL)
 
     oPlayer = Player(env)
 
diff --git a/examples/tkplay.py b/examples/tkplay.py
new file mode 100644
index 0000000000000000000000000000000000000000..b337253814ebb55994ac5981cb09a5bae008317b
--- /dev/null
+++ b/examples/tkplay.py
@@ -0,0 +1,59 @@
+
+import tkinter as tk
+from PIL import ImageTk, Image
+from examples.play_model import Player
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.generators import complex_rail_generator
+from flatland.utils.rendertools import RenderTool
+import time
+
+
+def tkmain(n_trials=2):
+    # This creates the main window of an application
+    window = tk.Tk()
+    window.title("Join")
+    window.configure(background='grey')
+
+    # Example generate a random rail
+    env = RailEnv(width=15, height=15,
+                  rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12),
+                  number_of_agents=5)
+
+    env_renderer = RenderTool(env, gl="PIL")
+
+    oPlayer = Player(env)
+    n_trials = 1
+    n_steps = 20
+    delay = 0
+    for trials in range(1, n_trials + 1):
+
+        # Reset environment8
+        oPlayer.reset()
+        env_renderer.set_new_rail()
+
+        first = True
+
+        for step in range(n_steps):
+            oPlayer.step()
+            env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step,
+                action_dict=oPlayer.action_dict)
+            img = env_renderer.getImage()
+            img = Image.fromarray(img)
+            tkimg = ImageTk.PhotoImage(img)
+
+            if first:
+                panel = tk.Label(window, image = tkimg)
+                panel.pack(side = "bottom", fill = "both", expand = "yes")
+            else:
+                # update the image in situ
+                panel.configure(image=tkimg)
+                panel.image = tkimg
+
+            window.update()
+            if delay > 0:
+                time.sleep(delay)
+            first = False
+
+
+if __name__ == "__main__":
+    tkmain()
\ No newline at end of file
diff --git a/flatland/utils/graphics_layer.py b/flatland/utils/graphics_layer.py
index 4cfcc64bffb82f91a0f36822188db297bc1ed37e..40bf319da737c8bb49e6c228c98b70604734ddc4 100644
--- a/flatland/utils/graphics_layer.py
+++ b/flatland/utils/graphics_layer.py
@@ -51,7 +51,7 @@ class GraphicsLayer(object):
         elif type(color) is tuple:
             if type(color[0]) is not int:
                 gcolor = array(color)
-                color = tuple((gcolor[:3] * 255).astype(int))
+                color = tuple((gcolor[:4] * 255).astype(int))
         else:
             color = self.tColGrid
 
diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py
index 41516fd94737556a4b8abbc7ccfce0fd503a3d6e..ceda3fd9f4c7d16b8a516dca8af43ea3122d51c7 100644
--- a/flatland/utils/graphics_pil.py
+++ b/flatland/utils/graphics_pil.py
@@ -18,28 +18,32 @@ class PILGL(GraphicsLayer):
         # Total grid size at native scale
         self.widthPx = self.width * self.nPixCell + self.linewidth
         self.heightPx = self.height * self.nPixCell + self.linewidth
-        self.beginFrame()
+        self.layers = []
+        self.draws = []
 
         self.tColBg = (255, 255, 255)     # white background
         # self.tColBg = (220, 120, 40)    # background color
         self.tColRail = (0, 0, 0)         # black rails
         self.tColGrid = (230,) * 3        # light grey for grid
 
-    def plot(self, gX, gY, color=None, linewidth=3, **kwargs):
-        color = self.adaptColor(color)
+        self.beginFrame()
 
-        # print(gX, gY)
+    def plot(self, gX, gY, color=None, linewidth=3, layer=0, opacity=255, **kwargs):
+        color = self.adaptColor(color)
+        if len(color) == 3:
+            color += (opacity,)
+        elif len(color) == 4:
+            color = color[:3] + (opacity,)
         gPoints = np.stack([array(gX), -array(gY)]).T * self.nPixCell
         gPoints = list(gPoints.ravel())
-        # print(gPoints, color)
-        self.draw.line(gPoints, fill=color, width=self.linewidth)
+        self.draws[layer].line(gPoints, fill=color, width=self.linewidth)
 
-    def scatter(self, gX, gY, color=None, marker="o", s=50, *args, **kwargs):
+    def scatter(self, gX, gY, color=None, marker="o", s=50, layer=0, opacity=255, *args, **kwargs):
         color = self.adaptColor(color)
         r = np.sqrt(s)
         gPoints = np.stack([np.atleast_1d(gX), -np.atleast_1d(gY)]).T * self.nPixCell
         for x, y in gPoints:
-            self.draw.rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color)
+            self.draws[layer].rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color)
 
     def text(self, *args, **kwargs):
         pass
@@ -51,8 +55,8 @@ class PILGL(GraphicsLayer):
         pass
 
     def beginFrame(self):
-        self.img = Image.new("RGBA", (self.widthPx, self.heightPx), (255, 255, 255, 255))
-        self.draw = ImageDraw.Draw(self.img)
+        self.create_layer(0)
+        self.create_layer(1)
 
     def show(self, block=False):
         pass
@@ -62,5 +66,35 @@ class PILGL(GraphicsLayer):
         pass
         # plt.pause(seconds)
 
+    def alpha_composite_layers(self):
+        img = self.layers[0]
+        for img2 in self.layers[1:]:
+            img = Image.alpha_composite(img, img2)
+        return img
+
     def getImage(self):
-        return array(self.img)
+        """ return a blended / alpha composited image composed of all the layers,
+            with layer 0 at the "back".
+        """
+        img = self.alpha_composite_layers()
+        return array(img)
+
+    def create_image(self, opacity=255):
+        img = Image.new("RGBA", (self.widthPx, self.heightPx), (255, 255, 255, opacity))
+        return img
+
+    def create_layer(self, iLayer=0):
+        if len(self.layers) <= iLayer:
+            for i in range(len(self.layers), iLayer+1):
+                if i==0:
+                    opacity = 255  # "bottom" layer is opaque (for rails)
+                else:
+                    opacity = 0   # subsequent layers are transparent
+                img = self.create_image(opacity)
+                self.layers.append(img)
+                self.draws.append(ImageDraw.Draw(img))
+        else:
+            opacity = 0 if iLayer > 0 else 255
+            self.layers[iLayer] = img = self.create_image(opacity)
+            self.draws[iLayer] = ImageDraw.Draw(img)
+
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index e9a94ec1fd18635b69d5f4bc47ce1d4f43daffcd..4def953a02b1154ff36827d0d4be5e023b6d1146 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -15,12 +15,14 @@ from flatland.utils.graphics_layer import GraphicsLayer
 
 
 class MPLGL(GraphicsLayer):
-    def __init__(self, width, height):
+    def __init__(self, width, height, show=False):
         self.width = width
         self.height = height
         self.yxBase = array([6, 21])  # pixel offset
         self.nPixCell = 700 / width
         self.img = None
+        if show:
+            plt.figure(figsize=(10, 10))
 
     def plot(self, *args, **kwargs):
         plt.plot(*args, **kwargs)
@@ -70,6 +72,7 @@ class MPLGL(GraphicsLayer):
     def beginFrame(self):
         self.img = None
         plt.figure(figsize=(10, 10))
+        plt.clf()
         pass
 
     def endFrame(self):
@@ -115,7 +118,7 @@ class RenderTool(object):
     gTheta = np.linspace(0, np.pi / 2, 5)
     gArc = array([np.cos(gTheta), np.sin(gTheta)]).T  # from [1,0] to [0,1]
 
-    def __init__(self, env, gl="MPL"):
+    def __init__(self, env, gl="MPL", show=False):
         self.env = env
         self.iFrame = 0
         self.time1 = time.time()
@@ -123,7 +126,7 @@ class RenderTool(object):
         # self.gl = MPLGL()
 
         if gl == "MPL":
-            self.gl = MPLGL(env.width, env.height)
+            self.gl = MPLGL(env.width, env.height, show=show)
         elif gl == "QT":
             self.gl = QTGL(env.width, env.height)
         elif gl == "PIL":
@@ -219,17 +222,19 @@ class RenderTool(object):
         if static:
             color = self.gl.adaptColor(color, lighten=True)
 
+        color = color
+
         # print("Agent:", rcPos, iDir, rcDir, xyDir, xyPos)
-        self.gl.scatter(*xyPos, color=color, marker="o", s=100)  # agent location
+        self.gl.scatter(*xyPos, color=color, layer=1, marker="o", s=100)  # agent location
         xyDirLine = array([xyPos, xyPos + xyDir / 2]).T  # line for agent orient.
-        self.gl.plot(*xyDirLine, color=color, lw=5, ms=0, alpha=0.6)
+        self.gl.plot(*xyDirLine, color=color, layer=1, lw=5, ms=0, alpha=0.6)
         if selected:
             self._draw_square(xyPos, 1, color)
 
         if target is not None:
             rcTarget = array(target)
             xyTarget = np.matmul(rcTarget, rt.grc2xy) + rt.xyHalf
-            self._draw_square(xyTarget, 1 / 3, color)
+            self._draw_square(xyTarget, 1 / 3, color, layer=1)
 
     def plotTrans(self, rcPos, gTransRCAg, color="r", depth=None):
         """
@@ -397,6 +402,13 @@ class RenderTool(object):
                 visit = visit.prev
                 xyPrev = xy
 
+    def drawTrans(self, oFrom, oTo, sColor="gray"):
+        self.gl.plot(
+            [oFrom[0], oTo[0]],  # x
+            [oFrom[1], oTo[1]],  # y
+            color=sColor
+        )
+
     def drawTrans2(
         self,
             xyLine, xyCentre,
@@ -489,47 +501,14 @@ class RenderTool(object):
             for visited_cell in observation_dict[agent]:
                 cell_coord = array(visited_cell[:2])
                 cell_coord_trans = np.matmul(cell_coord, rt.grc2xy) + rt.xyHalf
-                self._draw_square(cell_coord_trans, 1 / 3, color)
-
-    def renderEnv(
-        self, show=False, curves=True, spacing=False,
-            arrows=False, agents=True, obsrender=True, sRailColor="gray", frames=False, iEpisode=None, iStep=None,
-            iSelectedAgent=None, action_dict=None):
-        """
-        Draw the environment using matplotlib.
-        Draw into the figure if provided.
-
-        Call pyplot.show() if show==True.
-        (Use show=False from a Jupyter notebook with %matplotlib inline)
-        """
-
-        if not self.gl.is_raster():
-            self.renderEnv2(show, curves, spacing,
-                            arrows, agents, sRailColor,
-                            frames, iEpisode, iStep,
-                            iSelectedAgent, action_dict)
-            return
-
-        # cell_size is a bit pointless with matplotlib - it does not relate to pixels,
-        # so for now I've changed it to 1 (from 10)
-        cell_size = 1
-        self.gl.beginFrame()
+                self._draw_square(cell_coord_trans, 1 / (agent+1.1), color, layer=1, opacity=100)
 
-        # self.gl.clf()
-        # if oFigure is None:
-        #    oFigure = self.gl.figure()
 
-        def drawTrans(oFrom, oTo, sColor="gray"):
-            self.gl.plot(
-                [oFrom[0], oTo[0]],  # x
-                [oFrom[1], oTo[1]],  # y
-                color=sColor
-            )
+    def renderRail(self, spacing=False, sRailColor="gray", curves=True, arrows=False):
 
+        cell_size = 1  # TODO: remove cell_size
         env = self.env
 
-        # t1 = time.time()
-
         # Draw cells grid
         grid_color = [0.95, 0.95, 0.95]
         for r in range(env.height + 1):
@@ -613,7 +592,7 @@ class RenderTool(object):
                                         rotation, spacing=spacing, bArrow=arrows,
                                         sColor=sRailColor)
                                 else:
-                                    drawTrans(from_xy, to_xy, sRailColor)
+                                    self.drawTrans(self, from_xy, to_xy, sRailColor)
 
                             if False:
                                 print(
@@ -626,6 +605,54 @@ class RenderTool(object):
                                     "rot:", rotation,
                                 )
 
+
+
+    def renderEnv(
+        self, show=False, curves=True, spacing=False,
+            arrows=False, agents=True, obsrender=True, sRailColor="gray", frames=False,
+            iEpisode=None, iStep=None,
+            iSelectedAgent=None, action_dict=None):
+        """
+        Draw the environment using matplotlib.
+        Draw into the figure if provided.
+
+        Call pyplot.show() if show==True.
+        (Use show=False from a Jupyter notebook with %matplotlib inline)
+        """
+
+        if not self.gl.is_raster():
+            self.renderEnv2(show, curves, spacing,
+                            arrows, agents, sRailColor,
+                            frames, iEpisode, iStep,
+                            iSelectedAgent, action_dict)
+            return
+
+        # cell_size is a bit pointless with matplotlib - it does not relate to pixels,
+        # so for now I've changed it to 1 (from 10)
+        cell_size = 1
+
+        if type(self.gl) in (QTGL, PILGL):
+            self.gl.beginFrame()
+
+        if type(self.gl) is MPLGL:
+            #self.gl.clf()
+            # plt.clf()
+            self.gl.beginFrame()
+            pass
+
+        # self.gl.clf()
+        # if oFigure is None:
+        #    oFigure = self.gl.figure()
+
+
+
+        env = self.env
+
+        # t1 = time.time()
+
+
+        self.renderRail()
+
         # Draw each agent + its orientation + its target
         if agents:
             self.plotAgents(targets=True, iSelectedAgent=iSelectedAgent)
@@ -657,23 +684,26 @@ class RenderTool(object):
         # TODO: for MPL, we don't want to call clf (called by endframe)
         # for QT, we need to call endFrame()
         # if not show:
-        self.gl.endFrame()
+        if type(self.gl) is QTGL:
+            self.gl.endFrame()
+            if show:
+                self.gl.show(block=False)
 
-        # t2 = time.time()
-        # print(t2 - t1, "seconds")
+        if type(self.gl) is MPLGL:
+            if show:
+                self.gl.show(block=False)
+            #self.gl.endFrame()
 
-        if show:
-            self.gl.show(block=False)
-            self.gl.pause(0.00001)
+        self.gl.pause(0.00001)
 
         return
 
-    def _draw_square(self, center, size, color):
+    def _draw_square(self, center, size, color, opacity=255, layer=0):
         x0 = center[0] - size / 2
         x1 = center[0] + size / 2
         y0 = center[1] - size / 2
         y1 = center[1] + size / 2
-        self.gl.plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], color=color)
+        self.gl.plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], color=color, layer=layer, opacity=opacity)
 
     def getImage(self):
         return self.gl.getImage()
diff --git a/images/basic-env.npz b/images/basic-env.npz
index 356da5d70146b3b8081dd99c0fe5e6bd70646e53..8ffaf023e1116b0c92702212ddb04c71b82f0655 100644
Binary files a/images/basic-env.npz and b/images/basic-env.npz differ
diff --git a/tests/test_player.py b/tests/test_player.py
index 82f6f67f0fe29af9f7abad433e7c5ddb9df1012f..7b2745f2e372ca80cd2fb5cf9dcaa3db96fb910a 100644
--- a/tests/test_player.py
+++ b/tests/test_player.py
@@ -1,7 +1,8 @@
 
-from examples.play_model import main
+# from examples.play_model import main
+from examples.tkplay import tkmain
 
 
 def test_main():
-    main(n_trials=2)
+    tkmain(n_trials=2)
 
diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py
index 3259ed387f1a09b7a6a5e73fe7976cc78f02f32f..7d3ddf63226b55217edd28bf25aed0307377433b 100644
--- a/tests/test_rendertools.py
+++ b/tests/test_rendertools.py
@@ -46,8 +46,8 @@ def test_render_env(save_new_images=False):
                    )
     sfTestEnv = "env-data/tests/test1.npy"
     oEnv.rail.load_transition_map(sfTestEnv)
-    oRT = rt.RenderTool(oEnv)
-    oRT.renderEnv()
+    oRT = rt.RenderTool(oEnv, gl="PIL", show=False)
+    oRT.renderEnv(show=False)
 
     checkFrozenImage(oRT, "basic-env.npz", resave=save_new_images)