diff --git a/Makefile b/Makefile
index e9c25bbdfed174ea0ebc4570aed4949c53b31c48..69ad1b42fd51ef9ec9420f5473dc8acef5468572 100644
--- a/Makefile
+++ b/Makefile
@@ -54,13 +54,14 @@ lint: ## check style with flake8
 	flake8 flatland tests examples
 
 test: ## run tests quickly with the default Python
+	@echo "DISPLAY: $$DISPLAY"
 	py.test
 
 test-all: ## run tests on every Python version with tox
 	tox
 
 coverage: ## check code coverage quickly with the default Python
-	coverage run --source flatland -m pytest
+	xvfb-run -a coverage run --source flatland -m pytest
 	coverage report -m
 	coverage html
 	$(BROWSER) htmlcov/index.html
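Note: "make coverage" now wraps pytest in xvfb-run so the rendering tests get a
virtual X display on headless CI, and "make test" echoes DISPLAY for debugging.
A hedged sketch of the same check from Python (illustrative, not part of this patch):

    import os

    # xvfb-run points DISPLAY at its virtual X server (e.g. ":99")
    display = os.environ.get("DISPLAY")
    print("DISPLAY:", display or "<unset - GUI tests would fail>")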
diff --git a/examples/play_model.py b/examples/play_model.py
index 174568177a4a886cfe38e53125d0f73f2dae52de..34c6aadfeefd44771fd335e2957e1fbd0b2f740f 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -1,11 +1,11 @@
+# import torch
 import random
 import time
+# from flatland.baselines.dueling_double_dqn import Agent
 from collections import deque
 
 import numpy as np
-import torch
 
-from flatland.baselines.dueling_double_dqn import Agent
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
@@ -28,10 +28,12 @@ class Player(object):
         self.scores = []
         self.dones_list = []
         self.action_prob = [0] * 4
-        self.agent = Agent(self.state_size, self.action_size, "FC", 0)
+
+        # Removing refs to a real agent for now.
+        # self.agent = Agent(self.state_size, self.action_size, "FC", 0)
         # self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
-        self.agent.qnetwork_local.load_state_dict(torch.load(
-            '../flatland/flatland/baselines/Nets/avoid_checkpoint15000.pth'))
+        # self.agent.qnetwork_local.load_state_dict(torch.load(
+        #    '../flatland/flatland/baselines/Nets/avoid_checkpoint15000.pth'))
 
         self.iFrame = 0
         self.tStart = time.time()
@@ -49,12 +51,21 @@ class Player(object):
         self.score = 0
         self.env_done = 0
 
+    def reset(self):
+        self.obs = self.env.reset()
+        return self.obs
+
     def step(self):
         env = self.env
 
         # Pass the (stored) observation to the agent network and retrieve the action
         for handle in env.get_agent_handles():
-            action = self.agent.act(np.array(self.obs[handle]), eps=self.eps)
+            # Real Agent
+            # action = self.agent.act(np.array(self.obs[handle]), eps=self.eps)
+            # Random actions
+            action = random.randint(0, 3)
+            # Numpy alternative; uses numpy's own random sequence instead:
+            # action = np.random.randint(0, 4, size=1)
             self.action_prob[action] += 1
             self.action_dict.update({handle: action})
 
@@ -67,11 +78,12 @@ class Player(object):
             next_obs[handle] = np.clip(np.array(next_obs[handle]) / norm, -1, 1)
 
         # Update replay buffer and train agent
-        for handle in self.env.get_agent_handles():
-            self.agent.step(self.obs[handle], self.action_dict[handle],
-                            all_rewards[handle], next_obs[handle], done[handle],
-                            train=False)
-            self.score += all_rewards[handle]
+        if False:  # replay-buffer update, only needed with a real (learning) agent
+            for handle in self.env.get_agent_handles():
+                self.agent.step(self.obs[handle], self.action_dict[handle],
+                                all_rewards[handle], next_obs[handle], done[handle],
+                                train=False)
+                self.score += all_rewards[handle]
 
         self.iFrame += 1
 
@@ -94,7 +106,50 @@ def max_lt(seq, val):
     return None
 
 
-def main(render=True, delay=0.0):
+def main(render=True, delay=0.0, n_trials=3, n_steps=50, sGL="QT"):
+    random.seed(1)
+    np.random.seed(1)
+
+    # Example: generate a random rail
+    env = RailEnv(width=15, height=15,
+                  rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12),
+                  number_of_agents=5)
+
+    env_renderer = None
+    if render:
+        env_renderer = RenderTool(env, gl=sGL)
+
+    oPlayer = Player(env)
+
+    for trials in range(1, n_trials + 1):
+
+        # Reset environment
+        oPlayer.reset()
+        if render:
+            env_renderer.set_new_rail()
+
+        # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
+
+        # score = 0
+        # env_done = 0
+
+        # Run episode
+        for step in range(n_steps):
+            oPlayer.step()
+            if render:
+                env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step,
+                                       action_dict=oPlayer.action_dict)
+                if delay > 0:
+                    time.sleep(delay)
+
+
+def main_old(render=True, delay=0.0):
+    ''' DEPRECATED main which drives the agent directly.
+        Please use the new main(), which creates a Player object (also used by the Editor).
+        Fix any bugs in main() and Player rather than here; this function
+        will be deleted shortly.
+    '''
+
     random.seed(1)
     np.random.seed(1)
 
@@ -107,8 +162,6 @@ def main(render=True, delay=0.0):
         env_renderer = RenderTool(env, gl="QTSVG")
         # env_renderer = RenderTool(env, gl="QT")
 
-    state_size = 105
-    action_size = 4
     n_trials = 9999
     eps = 1.
     eps_end = 0.005
@@ -119,8 +172,11 @@ def main(render=True, delay=0.0):
     scores = []
     dones_list = []
     action_prob = [0] * 4
-    agent = Agent(state_size, action_size, "FC", 0)
 
+    # Real Agent
+    # state_size = 105
+    # action_size = 4
+    # agent = Agent(state_size, action_size, "FC", 0)
     # agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
 
     def max_lt(seq, val):
@@ -161,7 +217,7 @@ def main(render=True, delay=0.0):
             # print(step)
             # Action
             for a in range(env.get_num_agents()):
-                action = agent.act(np.array(obs[a]), eps=eps)
+                action = random.randint(0, 3)  # agent.act(np.array(obs[a]), eps=eps)
                 action_prob[action] += 1
                 action_dict.update({a: action})
 
@@ -174,13 +230,16 @@ def main(render=True, delay=0.0):
 
             # Environment step
             next_obs, all_rewards, done, _ = env.step(action_dict)
+
             for a in range(env.get_num_agents()):
                 norm = max(1, max_lt(next_obs[a], np.inf))
                 next_obs[a] = np.clip(np.array(next_obs[a]) / norm, -1, 1)
+
             # Update replay buffer and train agent
-            for a in range(env.get_num_agents()):
-                agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
-                score += all_rewards[a]
+            # only needed for "real" agent
+            # for a in range(env.get_num_agents()):
+            #    agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
+            #    score += all_rewards[a]
 
             obs = next_obs.copy()
             if done['__all__']:
@@ -212,8 +271,8 @@ def main(render=True, delay=0.0):
                 np.mean(scores_window),
                 100 * np.mean(done_window),
                 eps, rFps, action_prob / np.sum(action_prob)))
-            torch.save(agent.qnetwork_local.state_dict(),
-                       '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
+            # torch.save(agent.qnetwork_local.state_dict(),
+            #         '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
             action_prob = [1] * 4
 
 
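Note: a minimal sketch of driving the new Player-based loop by hand, using only
the APIs that appear in this file (sizes, seeds and step counts are illustrative):

    import random

    import numpy as np

    from examples.play_model import Player
    from flatland.envs.generators import complex_rail_generator
    from flatland.envs.rail_env import RailEnv
    from flatland.utils.rendertools import RenderTool

    random.seed(1)
    np.random.seed(1)

    env = RailEnv(width=15, height=15,
                  rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12),
                  number_of_agents=5)
    renderer = RenderTool(env, gl="PIL")  # headless-friendly backend
    player = Player(env)

    for episode in range(1, 3):
        player.reset()
        renderer.set_new_rail()
        for step in range(20):
            player.step()  # random action per agent, then env.step()
            renderer.renderEnv(show=False, frames=True, iEpisode=episode, iStep=step,
                               action_dict=player.action_dict)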
diff --git a/examples/tkplay.py b/examples/tkplay.py
new file mode 100644
index 0000000000000000000000000000000000000000..95842e3b430000169093d27c3c9de02ebe037de9
--- /dev/null
+++ b/examples/tkplay.py
@@ -0,0 +1,60 @@
+import time
+import tkinter as tk
+
+from PIL import ImageTk, Image
+
+from examples.play_model import Player
+from flatland.envs.generators import complex_rail_generator
+from flatland.envs.rail_env import RailEnv
+from flatland.utils.rendertools import RenderTool
+
+
+def tkmain(n_trials=2):
+    # This creates the main window of an application
+    window = tk.Tk()
+    window.title("Join")
+    window.configure(background='grey')
+
+    # Example: generate a random rail
+    env = RailEnv(width=15, height=15,
+                  rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12),
+                  number_of_agents=5)
+
+    env_renderer = RenderTool(env, gl="PIL")
+
+    oPlayer = Player(env)
+    n_steps = 20
+    delay = 0
+    for trials in range(1, n_trials + 1):
+
+        # Reset environment
+        oPlayer.reset()
+        env_renderer.set_new_rail()
+
+        first = True
+
+        for step in range(n_steps):
+            oPlayer.step()
+            env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step,
+                                   action_dict=oPlayer.action_dict)
+            img = env_renderer.getImage()
+            img = Image.fromarray(img)
+            tkimg = ImageTk.PhotoImage(img)
+
+            if first:
+                panel = tk.Label(window, image=tkimg)
+                panel.image = tkimg  # keep a reference so Tk does not garbage-collect the image
+                panel.pack(side="bottom", fill="both", expand="yes")
+            else:
+                # update the image in situ
+                panel.configure(image=tkimg)
+                panel.image = tkimg
+
+            window.update()
+            if delay > 0:
+                time.sleep(delay)
+            first = False
+
+
+if __name__ == "__main__":
+    tkmain()
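Note: the panel.image assignment above is load-bearing. A Tk Label stores only
the Tcl-side name of a PhotoImage, not a Python reference, so the image must be
pinned on the widget or it can be garbage-collected and the label goes blank.
A standalone sketch of the refresh pattern (pure Tkinter/PIL, no flatland):

    import tkinter as tk

    import numpy as np
    from PIL import Image, ImageTk

    window = tk.Tk()
    panel = None
    for i in range(5):
        frame = np.full((100, 100, 3), i * 50, dtype=np.uint8)  # dummy frame data
        tkimg = ImageTk.PhotoImage(Image.fromarray(frame))
        if panel is None:
            panel = tk.Label(window, image=tkimg)
            panel.pack()
        else:
            panel.configure(image=tkimg)
        panel.image = tkimg  # pin a reference so the image is not collected
        window.update()
    window.destroy()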
diff --git a/flatland/utils/graphics_layer.py b/flatland/utils/graphics_layer.py
index 4cfcc64bffb82f91a0f36822188db297bc1ed37e..40bf319da737c8bb49e6c228c98b70604734ddc4 100644
--- a/flatland/utils/graphics_layer.py
+++ b/flatland/utils/graphics_layer.py
@@ -51,7 +51,7 @@ class GraphicsLayer(object):
         elif type(color) is tuple:
             if type(color[0]) is not int:
                 gcolor = array(color)
-                color = tuple((gcolor[:3] * 255).astype(int))
+                color = tuple((gcolor[:4] * 255).astype(int))  # keep alpha if present
         else:
             color = self.tColGrid
 
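Note: slicing [:4] instead of [:3] means a matplotlib-style float RGBA tuple
keeps its alpha channel when converted to 0-255 ints. A worked example of just
that branch (values illustrative):

    from numpy import array

    gcolor = array((0.2, 0.4, 0.8, 0.5))          # float RGBA, matplotlib style
    print(tuple((gcolor[:3] * 255).astype(int)))  # before: (51, 102, 204) - alpha dropped
    print(tuple((gcolor[:4] * 255).astype(int)))  # after:  (51, 102, 204, 127)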
diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py
index 41516fd94737556a4b8abbc7ccfce0fd503a3d6e..b66c8dc55f38c321d038306f933de66493a6e6b3 100644
--- a/flatland/utils/graphics_pil.py
+++ b/flatland/utils/graphics_pil.py
@@ -18,28 +18,32 @@ class PILGL(GraphicsLayer):
         # Total grid size at native scale
         self.widthPx = self.width * self.nPixCell + self.linewidth
         self.heightPx = self.height * self.nPixCell + self.linewidth
-        self.beginFrame()
+        self.layers = []
+        self.draws = []
 
         self.tColBg = (255, 255, 255)     # white background
         # self.tColBg = (220, 120, 40)    # background color
         self.tColRail = (0, 0, 0)         # black rails
         self.tColGrid = (230,) * 3        # light grey for grid
 
-    def plot(self, gX, gY, color=None, linewidth=3, **kwargs):
-        color = self.adaptColor(color)
+        self.beginFrame()
 
-        # print(gX, gY)
+    def plot(self, gX, gY, color=None, linewidth=3, layer=0, opacity=255, **kwargs):
+        color = self.adaptColor(color)
+        if len(color) == 3:
+            color += (opacity,)  # add an alpha channel
+        elif len(color) == 4:
+            color = color[:3] + (opacity,)  # replace alpha with the requested opacity
         gPoints = np.stack([array(gX), -array(gY)]).T * self.nPixCell
         gPoints = list(gPoints.ravel())
-        # print(gPoints, color)
-        self.draw.line(gPoints, fill=color, width=self.linewidth)
+        self.draws[layer].line(gPoints, fill=color, width=self.linewidth)
 
-    def scatter(self, gX, gY, color=None, marker="o", s=50, *args, **kwargs):
+    def scatter(self, gX, gY, color=None, marker="o", s=50, layer=0, opacity=255, *args, **kwargs):
         color = self.adaptColor(color)
         r = np.sqrt(s)
         gPoints = np.stack([np.atleast_1d(gX), -np.atleast_1d(gY)]).T * self.nPixCell
         for x, y in gPoints:
-            self.draw.rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color)
+            self.draws[layer].rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color)
 
     def text(self, *args, **kwargs):
         pass
@@ -51,8 +55,8 @@ class PILGL(GraphicsLayer):
         pass
 
     def beginFrame(self):
-        self.img = Image.new("RGBA", (self.widthPx, self.heightPx), (255, 255, 255, 255))
-        self.draw = ImageDraw.Draw(self.img)
+        self.create_layer(0)  # layer 0: rails (opaque background)
+        self.create_layer(1)  # layer 1: agents / observations (transparent overlay)
 
     def show(self, block=False):
         pass
@@ -62,5 +66,35 @@ class PILGL(GraphicsLayer):
         pass
         # plt.pause(seconds)
 
+    def alpha_composite_layers(self):
+        img = self.layers[0]
+        for img2 in self.layers[1:]:
+            img = Image.alpha_composite(img, img2)
+        return img
+
     def getImage(self):
-        return array(self.img)
+        """ return a blended / alpha composited image composed of all the layers,
+            with layer 0 at the "back".
+        """
+        img = self.alpha_composite_layers()
+        return array(img)
+
+    def create_image(self, opacity=255):
+        img = Image.new("RGBA", (self.widthPx, self.heightPx), (255, 255, 255, opacity))
+        return img
+
+    def create_layer(self, iLayer=0):
+        if len(self.layers) <= iLayer:
+            for i in range(len(self.layers), iLayer+1):
+                if i == 0:
+                    opacity = 255  # "bottom" layer is opaque (for rails)
+                else:
+                    opacity = 0   # subsequent layers are transparent
+                img = self.create_image(opacity)
+                self.layers.append(img)
+                self.draws.append(ImageDraw.Draw(img))
+        else:
+            opacity = 0 if iLayer > 0 else 255
+            self.layers[iLayer] = img = self.create_image(opacity)
+            self.draws[iLayer] = ImageDraw.Draw(img)
+
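Note: the layer machinery above reduces to one opaque base image plus
transparent overlays, merged back-to-front with PIL's alpha_composite.
A minimal standalone sketch of the same technique (sizes and colors illustrative):

    from PIL import Image, ImageDraw

    base = Image.new("RGBA", (100, 100), (255, 255, 255, 255))   # opaque, like layer 0 (rails)
    overlay = Image.new("RGBA", (100, 100), (255, 255, 255, 0))  # transparent, like layer 1

    ImageDraw.Draw(base).line([(0, 50), (100, 50)], fill=(0, 0, 0, 255), width=3)
    ImageDraw.Draw(overlay).rectangle([(40, 40), (60, 60)], fill=(255, 0, 0, 100))

    merged = Image.alpha_composite(base, overlay)  # overlay blended onto the base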
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 34f3e9fa6857e86f4d99d211784d983a2e2a1e75..26ec39d1f426691e7d4fe5a8b4b6aec3bcc7b1fd 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -15,12 +15,14 @@ from flatland.utils.graphics_layer import GraphicsLayer
 
 
 class MPLGL(GraphicsLayer):
-    def __init__(self, width, height):
+    def __init__(self, width, height, show=False):
         self.width = width
         self.height = height
         self.yxBase = array([6, 21])  # pixel offset
         self.nPixCell = 700 / width
         self.img = None
+        if show:
+            plt.figure(figsize=(10, 10))
 
     def plot(self, *args, **kwargs):
         plt.plot(*args, **kwargs)
@@ -70,6 +72,7 @@ class MPLGL(GraphicsLayer):
     def beginFrame(self):
         self.img = None
         plt.figure(figsize=(10, 10))
+        plt.clf()
         pass
 
     def endFrame(self):
@@ -115,7 +118,7 @@ class RenderTool(object):
     gTheta = np.linspace(0, np.pi / 2, 5)
     gArc = array([np.cos(gTheta), np.sin(gTheta)]).T  # from [1,0] to [0,1]
 
-    def __init__(self, env, gl="MPL"):
+    def __init__(self, env, gl="MPL", show=False):
         self.env = env
         self.iFrame = 0
         self.time1 = time.time()
@@ -123,7 +126,7 @@ class RenderTool(object):
         # self.gl = MPLGL()
 
         if gl == "MPL":
-            self.gl = MPLGL(env.width, env.height)
+            self.gl = MPLGL(env.width, env.height, show=show)
         elif gl == "QT":
             self.gl = QTGL(env.width, env.height)
         elif gl == "PIL":
@@ -219,17 +222,19 @@ class RenderTool(object):
         if static:
             color = self.gl.adaptColor(color, lighten=True)
 
+        # agents are drawn into layer 1 so that they composite over the rails
+
         # print("Agent:", rcPos, iDir, rcDir, xyDir, xyPos)
-        self.gl.scatter(*xyPos, color=color, marker="o", s=100)  # agent location
+        self.gl.scatter(*xyPos, color=color, layer=1, marker="o", s=100)  # agent location
         xyDirLine = array([xyPos, xyPos + xyDir / 2]).T  # line for agent orient.
-        self.gl.plot(*xyDirLine, color=color, lw=5, ms=0, alpha=0.6)
+        self.gl.plot(*xyDirLine, color=color, layer=1, lw=5, ms=0, alpha=0.6)
         if selected:
             self._draw_square(xyPos, 1, color)
 
         if target is not None:
             rcTarget = array(target)
             xyTarget = np.matmul(rcTarget, rt.grc2xy) + rt.xyHalf
-            self._draw_square(xyTarget, 1 / 3, color)
+            self._draw_square(xyTarget, 1 / 3, color, layer=1)
 
     def plotTrans(self, rcPos, gTransRCAg, color="r", depth=None):
         """
@@ -397,6 +402,13 @@ class RenderTool(object):
                 visit = visit.prev
                 xyPrev = xy
 
+    def drawTrans(self, oFrom, oTo, sColor="gray"):
+        self.gl.plot(
+            [oFrom[0], oTo[0]],  # x
+            [oFrom[1], oTo[1]],  # y
+            color=sColor
+        )
+
     def drawTrans2(
         self,
             xyLine, xyCentre,
@@ -474,8 +486,8 @@ class RenderTool(object):
 
     def renderObs(self, agent_handles, observation_dict):
         """
-        Render the extent of the observation of each agent. All cells that appear in the agent obsrevation will be
-        highlighted.
+        Render the extent of the observation of each agent. All cells that appear in the agent
+        observation will be highlighted.
         :param agent_handles: List of agent indices to adapt color and get correct observation
         :param observation_dict: dictionary containing sets of cells of the agent observation
 
@@ -489,47 +501,13 @@ class RenderTool(object):
             for visited_cell in observation_dict[agent]:
                 cell_coord = array(visited_cell[:2])
                 cell_coord_trans = np.matmul(cell_coord, rt.grc2xy) + rt.xyHalf
-                self._draw_square(cell_coord_trans, 1 / 3, color)
+                self._draw_square(cell_coord_trans, 1 / (agent+1.1), color, layer=1, opacity=100)
 
-    def renderEnv(
-        self, show=False, curves=True, spacing=False,
-            arrows=False, agents=True, obsrender=True, sRailColor="gray", frames=False, iEpisode=None, iStep=None,
-            iSelectedAgent=None, action_dict=None):
-        """
-        Draw the environment using matplotlib.
-        Draw into the figure if provided.
-
-        Call pyplot.show() if show==True.
-        (Use show=False from a Jupyter notebook with %matplotlib inline)
-        """
-
-        if not self.gl.is_raster():
-            self.renderEnv2(show, curves, spacing,
-                            arrows, agents, sRailColor,
-                            frames, iEpisode, iStep,
-                            iSelectedAgent, action_dict)
-            return
-
-        # cell_size is a bit pointless with matplotlib - it does not relate to pixels,
-        # so for now I've changed it to 1 (from 10)
-        cell_size = 1
-        self.gl.beginFrame()
-
-        # self.gl.clf()
-        # if oFigure is None:
-        #    oFigure = self.gl.figure()
-
-        def drawTrans(oFrom, oTo, sColor="gray"):
-            self.gl.plot(
-                [oFrom[0], oTo[0]],  # x
-                [oFrom[1], oTo[1]],  # y
-                color=sColor
-            )
+    def renderRail(self, spacing=False, sRailColor="gray", curves=True, arrows=False):
 
+        cell_size = 1  # TODO: remove cell_size
         env = self.env
 
-        # t1 = time.time()
-
         # Draw cells grid
         grid_color = [0.95, 0.95, 0.95]
         for r in range(env.height + 1):
@@ -613,7 +591,7 @@ class RenderTool(object):
                                         rotation, spacing=spacing, bArrow=arrows,
                                         sColor=sRailColor)
                                 else:
-                                    drawTrans(from_xy, to_xy, sRailColor)
+                                    self.drawTrans(from_xy, to_xy, sRailColor)
 
                             if False:
                                 print(
@@ -626,6 +604,42 @@ class RenderTool(object):
                                     "rot:", rotation,
                                 )
 
+    def renderEnv(
+        self, show=False, curves=True, spacing=False,
+            arrows=False, agents=True, obsrender=True, sRailColor="gray", frames=False,
+            iEpisode=None, iStep=None,
+            iSelectedAgent=None, action_dict=None):
+        """
+        Draw the environment using the graphics layer selected at
+        construction; non-raster layers are delegated to renderEnv2.
+
+        Call pyplot.show() if show==True.
+        (Use show=False from a Jupyter notebook with %matplotlib inline)
+        """
+
+        if not self.gl.is_raster():
+            self.renderEnv2(show, curves, spacing,
+                            arrows, agents, sRailColor,
+                            frames, iEpisode, iStep,
+                            iSelectedAgent, action_dict)
+            return
+
+        if type(self.gl) in (QTGL, PILGL):
+            self.gl.beginFrame()
+
+        if type(self.gl) is MPLGL:
+            # for MPLGL, beginFrame() creates and clears the matplotlib
+            # figure, so no separate clf() call is needed here
+            self.gl.beginFrame()
+
+        env = self.env
+
+        # Draw the rails first (layer 0) so that agents, targets and
+        # observations can be drawn on top of them (layer 1).
+        self.renderRail(spacing=spacing,
+                        sRailColor=sRailColor,
+                        curves=curves, arrows=arrows)
+
         # Draw each agent + its orientation + its target
         if agents:
             self.plotAgents(targets=True, iSelectedAgent=iSelectedAgent)
@@ -657,23 +671,26 @@ class RenderTool(object):
         # TODO: for MPL, we don't want to call clf (called by endframe)
         # for QT, we need to call endFrame()
         # if not show:
-        self.gl.endFrame()
+        if type(self.gl) is QTGL:
+            self.gl.endFrame()
+            if show:
+                self.gl.show(block=False)
 
-        # t2 = time.time()
-        # print(t2 - t1, "seconds")
+        if type(self.gl) is MPLGL:
+            if show:
+                self.gl.show(block=False)
+            # self.gl.endFrame()
 
-        if show:
-            self.gl.show(block=False)
-            self.gl.pause(0.00001)
+        self.gl.pause(0.00001)
 
         return
 
-    def _draw_square(self, center, size, color):
+    def _draw_square(self, center, size, color, opacity=255, layer=0):
         x0 = center[0] - size / 2
         x1 = center[0] + size / 2
         y0 = center[1] - size / 2
         y1 = center[1] + size / 2
-        self.gl.plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], color=color)
+        self.gl.plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], color=color, layer=layer, opacity=opacity)
 
     def getImage(self):
         return self.gl.getImage()
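Note: with this plumbing, drawing primitives can target a layer and an opacity;
renderEnv uses that to put agents and observations on layer 1, above the rails
on layer 0. A hedged sketch exercising PILGL directly (the positional
width/height constructor arguments are an assumption):

    from flatland.utils.graphics_pil import PILGL

    gl = PILGL(10, 10)
    gl.plot([0, 9], [0, -9], color=(0, 0, 0), layer=0)  # opaque base, like the rails
    gl.plot([2, 7], [-2, -7], color=(1.0, 0.0, 0.0), layer=1, opacity=100)  # translucent overlay
    img = gl.getImage()  # numpy array with all layers alpha-composited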
diff --git a/images/basic-env.npz b/images/basic-env.npz
index 356da5d70146b3b8081dd99c0fe5e6bd70646e53..8ffaf023e1116b0c92702212ddb04c71b82f0655 100644
Binary files a/images/basic-env.npz and b/images/basic-env.npz differ
diff --git a/tests/test_player.py b/tests/test_player.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b2745f2e372ca80cd2fb5cf9dcaa3db96fb910a
--- /dev/null
+++ b/tests/test_player.py
@@ -0,0 +1,8 @@
+
+# from examples.play_model import main
+from examples.tkplay import tkmain
+
+
+def test_main():
+    tkmain(n_trials=2)
+
diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py
index c7841df54022d0c6ea24e209f6442342514153bc..8204a305328df746a772d034f3c763c848cceb93 100644
--- a/tests/test_rendertools.py
+++ b/tests/test_rendertools.py
@@ -46,8 +46,8 @@ def test_render_env(save_new_images=False):
                    )
     sfTestEnv = "env-data/tests/test1.npy"
     oEnv.rail.load_transition_map(sfTestEnv)
-    oRT = rt.RenderTool(oEnv)
-    oRT.renderEnv()
+    oRT = rt.RenderTool(oEnv, gl="PIL", show=False)
+    oRT.renderEnv(show=False)
 
     checkFrozenImage(oRT, "basic-env.npz", resave=save_new_images)
 
diff --git a/tox.ini b/tox.ini
index 71edb7b5fbe2652b00bf48fc63a35d523c791a7a..6dd011aadeb2e7ba802ff692278aa763fb665f10 100644
--- a/tox.ini
+++ b/tox.ini
@@ -8,7 +8,7 @@ python =
 
 [flake8]
 max-line-length = 120
-ignore = E121 E126 E123 E128 E133 E226 E241 E242 E704 W291 W293 W391 W503 W504 W505 
+ignore = E121 E126 E123 E128 E133 E226 E241 E242 E704 W291 W293 W391 W503 W504 W505
 
 [testenv:flake8]
 basepython = python
@@ -23,12 +23,15 @@ commands = make docs
 [testenv:coverage]
 basepython = python
 whitelist_externals = make
-commands = 
+commands =
     pip install -U pip
     pip install -r requirements_dev.txt
     make coverage
 
 [testenv]
+whitelist_externals = xvfb-run
+                      sh
+                      pip
 setenv =
     PYTHONPATH = {toxinidir}
 deps =
@@ -39,6 +42,7 @@ deps =
 commands =
     pip install -U pip
     pip install -r requirements_dev.txt
-    py.test --basetemp={envtmpdir}
+    sh -c 'echo DISPLAY: $DISPLAY'
+    xvfb-run -a py.test --basetemp={envtmpdir}