diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 4aabd8b8b7363087f79f1da643815c6067afb340..671e9263edbfd0efbac895e42ed3457d99508618 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -9,5 +9,7 @@ before_script:
 
 tests:
     script:
+        - apt update
+        - apt install -y libgl1-mesa-glx xvfb
         - pip install tox
-        - tox
+        - xvfb-run -s "-screen 0 800x600x24" tox
diff --git a/examples/play_model.py b/examples/play_model.py
index 713d831e4936b0803f7d69f4eae8253db0685ef3..d54decd8950a7e0ef8fa67f987f458b6b0fed005 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -1,12 +1,14 @@
 from flatland.envs.rail_env import RailEnv, random_rail_generator
 # from flatland.core.env_observation_builder import TreeObsForRailEnv
 from flatland.utils.rendertools import RenderTool
+from flatland.utils.render_qt import QtRailRender
 from flatland.baselines.dueling_double_dqn import Agent
 from collections import deque
 import torch
 import random
 import numpy as np
 import matplotlib.pyplot as plt
+import redis
 
 
 def main():
@@ -30,8 +32,11 @@ def main():
                 height=7,
                 rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
                 number_of_agents=1)
-    env_renderer = RenderTool(env)
+    env_renderer = RenderTool(env, gl="QT")
+    #env_renderer = QtRailRender(env)
     plt.figure(figsize=(5,5))
+    # fRedis = redis.Redis()
+
     handle = env.get_agent_handles()
 
     state_size = 105
@@ -115,10 +120,8 @@ def main():
             '\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
                 env.number_of_agents,
                 trials,
-                np.mean(
-                    scores_window),
-                100 * np.mean(
-                    done_window),
+                np.mean(scores_window),
+                100 * np.mean(done_window),
                 eps, action_prob/np.sum(action_prob)),
             end=" ")
         if trials % 100 == 0:
diff --git a/flatland/utils/graphics_layer.py b/flatland/utils/graphics_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b199a3b2e4d71261a402254ccb60ec453e30469c
--- /dev/null
+++ b/flatland/utils/graphics_layer.py
@@ -0,0 +1,33 @@
+
+
+class GraphicsLayer(object):
+    def __init__(self):
+        pass
+
+    def plot(self, *args, **kwargs):
+        pass
+
+    def scatter(self, *args, **kwargs):
+        pass
+
+    def text(self, *args, **kwargs):
+        pass
+
+    def prettify(self, *args, **kwargs):
+        pass
+
+    def show(self, block=False):
+        pass
+    
+    def pause(self, seconds=0.00001):
+        pass
+
+    def clf(self):
+        pass
+    
+    def beginFrame(self):
+        pass
+    
+    def endFrame(self):
+        pass
+
diff --git a/flatland/utils/graphics_qt.py b/flatland/utils/graphics_qt.py
new file mode 100644
index 0000000000000000000000000000000000000000..6571f4d7ab03aced3e7846c73ae8439c36e6af42
--- /dev/null
+++ b/flatland/utils/graphics_qt.py
@@ -0,0 +1,226 @@
+
+import numpy as np
+from PyQt5.QtCore import Qt
+from PyQt5.QtGui import QImage, QPixmap, QPainter, QColor, QPolygon
+from PyQt5.QtCore import QPoint, QRect  # QSize
+from PyQt5.QtWidgets import QApplication, QMainWindow, QWidget, QTextEdit
+from PyQt5.QtWidgets import QHBoxLayout, QVBoxLayout, QLabel, QFrame
+import os
+
+
+class Window(QMainWindow):
+    """
+    Simple application window to render the environment into
+    """
+
+    def __init__(self):
+        super().__init__()
+
+        self.setWindowTitle('MiniGrid Gym Environment')
+
+        # Image label to display the rendering
+        self.imgLabel = QLabel()
+        self.imgLabel.setFrameStyle(QFrame.Panel | QFrame.Sunken)
+
+        if False:
+            # Text box for the mission
+            self.missionBox = QTextEdit()
+            self.missionBox.setReadOnly(True)
+            self.missionBox.setMinimumSize(400, 100)
+
+        # Center the image
+        hbox = QHBoxLayout()
+        hbox.addStretch(1)
+        hbox.addWidget(self.imgLabel)
+        hbox.addStretch(1)
+
+        # Arrange widgets vertically
+        vbox = QVBoxLayout()
+        vbox.addLayout(hbox)
+        # vbox.addWidget(self.missionBox)
+
+        # Create a main widget for the window
+        self.mainWidget = QWidget(self)
+        self.setCentralWidget(self.mainWidget)
+        self.mainWidget.setLayout(vbox)
+
+        # Show the application window
+        self.show()
+        self.setFocus()
+
+        self.closed = False
+
+        # Callback for keyboard events
+        self.keyDownCb = None
+
+    def closeEvent(self, event):
+        self.closed = True
+
+    def setPixmap(self, pixmap):
+        self.imgLabel.setPixmap(pixmap)
+
+    def setText(self, text):
+        # self.missionBox.setPlainText(text)
+        pass
+
+    def setKeyDownCb(self, callback):
+        self.keyDownCb = callback
+
+    def keyPressEvent(self, e):
+        if self.keyDownCb is None:
+            return
+
+        keyName = None
+        if e.key() == Qt.Key_Left:
+            keyName = 'LEFT'
+        elif e.key() == Qt.Key_Right:
+            keyName = 'RIGHT'
+        elif e.key() == Qt.Key_Up:
+            keyName = 'UP'
+        elif e.key() == Qt.Key_Down:
+            keyName = 'DOWN'
+        elif e.key() == Qt.Key_Space:
+            keyName = 'SPACE'
+        elif e.key() == Qt.Key_Return:
+            keyName = 'RETURN'
+        elif e.key() == Qt.Key_Alt:
+            keyName = 'ALT'
+        elif e.key() == Qt.Key_Control:
+            keyName = 'CTRL'
+        elif e.key() == Qt.Key_PageUp:
+            keyName = 'PAGE_UP'
+        elif e.key() == Qt.Key_PageDown:
+            keyName = 'PAGE_DOWN'
+        elif e.key() == Qt.Key_Backspace:
+            keyName = 'BACKSPACE'
+        elif e.key() == Qt.Key_Escape:
+            keyName = 'ESCAPE'
+
+        if keyName is None:
+            return
+        self.keyDownCb(keyName)
+
+
+class QtRenderer(object):
+    def __init__(self, width, height, ownWindow=False):
+        self.width = width
+        self.height = height
+
+        self.img = QImage(width, height, QImage.Format_RGB888)
+        self.painter = QPainter()
+
+        self.window = None
+        if ownWindow:
+            self.app = QApplication([])
+            self.window = Window()
+        self.iFrame = 0  # for movie capture
+
+    def close(self):
+        """
+        Deallocate resources used
+        """
+        pass
+
+    def beginFrame(self):
+        self.painter.begin(self.img)
+        self.painter.setRenderHint(QPainter.Antialiasing, False)
+
+        # Clear the background
+        self.painter.setBrush(QColor(0, 0, 0))
+        self.painter.drawRect(0, 0, self.width - 1, self.height - 1)
+
+    def endFrame(self):
+        self.painter.end()
+
+        if self.window:
+            if self.window.closed:
+                self.window = None
+            else:
+                self.window.setPixmap(self.getPixmap())
+                self.app.processEvents()
+
+    def getPixmap(self):
+        return QPixmap.fromImage(self.img)
+
+    def getArray(self):
+        """
+        Get a numpy array of RGB pixel values.
+        The size argument should be (3,w,h)
+        """
+
+        width = self.width
+        height = self.height
+        shape = (width, height, 3)
+
+        numBytes = self.width * self.height * 3
+        buf = self.img.bits().asstring(numBytes)
+        output = np.frombuffer(buf, dtype='uint8')
+        output = output.reshape(shape)
+
+        return output
+
+    def push(self):
+        self.painter.save()
+
+    def pop(self):
+        self.painter.restore()
+
+    def rotate(self, degrees):
+        self.painter.rotate(degrees)
+
+    def translate(self, x, y):
+        self.painter.translate(x, y)
+
+    def scale(self, x, y):
+        self.painter.scale(x, y)
+
+    def setLineColor(self, r, g, b, a=255):
+        self.painter.setPen(QColor(r, g, b, a))
+
+    def setColor(self, r, g, b, a=255):
+        self.painter.setBrush(QColor(r, g, b, a))
+
+    def setLineWidth(self, width):
+        pen = self.painter.pen()
+        pen.setWidthF(width)
+        self.painter.setPen(pen)
+
+    def drawLine(self, x0, y0, x1, y1):
+        self.painter.drawLine(x0, y0, x1, y1)
+
+    def drawCircle(self, x, y, r):
+        center = QPoint(x, y)
+        self.painter.drawEllipse(center, r, r)
+
+    def drawPolygon(self, points):
+        """Takes a list of points (tuples) as input"""
+        points = map(lambda p: QPoint(p[0], p[1]), points)
+        self.painter.drawPolygon(QPolygon(points))
+
+    def drawRect(self, x, y, w, h):
+        self.painter.drawRect(x, y, w, h)
+
+    def drawPolyline(self, points):
+        """Takes a list of points (tuples) as input"""
+        points = map(lambda p: QPoint(p[0], p[1]), points)
+        self.painter.drawPolyline(QPolygon(points))
+
+    def fillRect(self, x, y, width, height, r, g, b, a=255):
+        self.painter.fillRect(QRect(x, y, width, height), QColor(r, g, b, a))
+
+    def drawText(self, x, y, sText):
+        self.painter.drawText(x, y, sText)
+
+    def takeSnapshot(self, sDir="./movie"):
+        oWidget = self.window.mainWidget
+        oPixmap = oWidget.grab()
+        
+        if not os.path.isdir(sDir):
+            os.mkdir(sDir)
+        
+        nRunIn = 30
+        if self.iFrame > nRunIn:
+            sfImage = "%s/frame%05d.jpg" % (sDir, self.iFrame - nRunIn)
+            oPixmap.save(sfImage, "jpg")
+        self.iFrame += 1
+        
diff --git a/flatland/utils/render_qt.py b/flatland/utils/render_qt.py
new file mode 100644
index 0000000000000000000000000000000000000000..60b8d2952d23289cafe7648f4adc1ab5527bac5f
--- /dev/null
+++ b/flatland/utils/render_qt.py
@@ -0,0 +1,96 @@
+from flatland.utils.graphics_qt import QtRenderer
+from numpy import array
+from flatland.utils.graphics_layer import GraphicsLayer
+from matplotlib import pyplot as plt
+
+
+class QTGL(GraphicsLayer):
+    def __init__(self, width, height):
+        self.cell_pixels = 50
+        self.tile_size = self.cell_pixels
+
+        self.width = width
+        self.height = height
+
+        # Total grid size at native scale
+        self.widthPx = self.width * self.cell_pixels
+        self.heightPx = self.height * self.cell_pixels
+        self.qtr = QtRenderer(self.widthPx, self.heightPx, ownWindow=True)
+
+        self.qtr.beginFrame()
+        self.qtr.push()
+
+        # This comment comes from minigrid.  Not sure if it's still true. Jeremy.
+        # Internally, we draw at the "large" full-grid resolution, but we
+        # use the renderer to scale back to the desired size
+        self.qtr.scale(self.tile_size / self.cell_pixels, self.tile_size / self.cell_pixels)
+
+        self.tColBg = (255, 255, 255)     # white background
+        # self.tColBg = (220, 120, 40)    # background color
+        self.tColRail = (0, 0, 0)         # black rails
+        self.tColGrid = (230,) * 3        # light grey for grid
+
+        # Draw the background of the in-world cells
+        self.qtr.fillRect(0, 0, self.widthPx, self.heightPx, *self.tColBg)
+        self.qtr.pop()
+        self.qtr.endFrame()
+
+    def plot(self, gX, gY, color=None, linewidth=2, **kwargs):
+
+        if color == "red" or color == "r":
+            color = (255, 0, 0)
+        elif color == "gray":
+            color = (128, 128, 128)
+        elif type(color) is list:
+            color = array(color) * 255
+        elif type(color) is tuple:
+            gcolor = array(color)
+            color = gcolor[:3] * 255
+        else:
+            color = self.tColGrid
+
+        self.qtr.setLineColor(*color)
+        lastx = lasty = None
+        for x, y in zip(gX, gY):
+            if lastx is not None:
+                # print("line", lastx, lasty, x, y)
+                self.qtr.drawLine(
+                    lastx*self.cell_pixels, -lasty*self.cell_pixels,
+                    x*self.cell_pixels, -y*self.cell_pixels)
+            lastx = x
+            lasty = y
+
+    def scatter(self, *args, **kwargs):
+        print("scatter not yet implemented in ", self.__class__)
+
+    def text(self, x, y, sText):
+        self.qtr.drawText(x*self.cell_pixels, -y*self.cell_pixels, sText)
+    
+    def prettify(self, *args, **kwargs):
+        pass
+
+    def prettify2(self, width, height, cell_size):
+        pass
+    
+    def show(self, block=False):
+        pass
+
+    def pause(self, seconds=0.00001):
+        pass
+    
+    def clf(self):
+        pass
+
+    def get_cmap(self, *args, **kwargs):
+        return plt.get_cmap(*args, **kwargs)
+
+    def beginFrame(self):
+        self.qtr.beginFrame()
+        self.qtr.push()
+        self.qtr.fillRect(0, 0, self.widthPx, self.heightPx, *self.tColBg)
+    
+    def endFrame(self):
+        self.qtr.pop()
+        self.qtr.endFrame()
+
+
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index d97f71637ca769579f1e2bdfe12aeafd8d2db5cd..f9ebf5328556d40fbe739557c3ec7b73907bf90d 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -6,10 +6,67 @@ import xarray as xr
 import matplotlib.pyplot as plt
 import time
 from collections import deque
+from flatland.utils.render_qt import QTGL
+from flatland.utils.graphics_layer import GraphicsLayer
 
 # TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv!
 
 
+class MPLGL(GraphicsLayer):
+    def __init__(self):
+        pass
+
+    def plot(self, *args, **kwargs):
+        plt.plot(*args, **kwargs)
+
+    def scatter(self, *args, **kwargs):
+        plt.scatter(*args, **kwargs)
+
+    def text(self, *args, **kwargs):
+        plt.text(*args, **kwargs)
+    
+    def prettify(self, *args, **kwargs):
+        ax = plt.gca()
+        plt.xticks(range(int(ax.get_xlim()[1])+1))
+        plt.yticks(range(int(ax.get_ylim()[1])+1))
+        plt.grid()
+        plt.xlabel("Euclidean distance")
+        plt.ylabel("Tree / Transition Depth")
+
+    def prettify2(self, width, height, cell_size):
+        plt.xlim([0, width * cell_size])
+        plt.ylim([-height * cell_size, 0])
+
+        gTicks = (np.arange(0, height) + 0.5) * cell_size
+        gLabels = np.arange(0, height)
+        plt.xticks(gTicks, gLabels)
+
+        gTicks = np.arange(-height * cell_size, 0) + cell_size/2
+        gLabels = np.arange(height-1, -1, -1)
+        plt.yticks(gTicks, gLabels)
+
+        plt.xlim([0, width * cell_size])
+        plt.ylim([-height * cell_size, 0])
+    
+    def show(self, block=False):
+        plt.show(block=block)
+
+    def pause(self, seconds=0.00001):
+        plt.pause(seconds)
+    
+    def clf(self):
+        plt.clf()
+    
+    def get_cmap(self, *args, **kwargs):
+        return plt.get_cmap(*args, **kwargs)
+
+    def beginFrame(self):
+        pass
+    
+    def endFrame(self):
+        pass
+
+
 class RenderTool(object):
     Visit = recordtype("Visit", ["rc", "iDir", "iDepth", "prev"])
 
@@ -31,11 +88,14 @@ class RenderTool(object):
     gTheta = np.linspace(0, np.pi/2, 10)
     gArc = array([np.cos(gTheta), np.sin(gTheta)]).T  # from [1,0] to [0,1]
 
-    def __init__(self, env):
+    def __init__(self, env, gl="MPL"):
         self.env = env
         self.iFrame = 0
         self.time1 = time.time()
         self.lTimes = deque()
+        # self.gl = MPLGL()
+
+        self.gl = MPLGL() if gl == "MPL" else QTGL(env.width, env.height)
 
     def plotTreeOnRail(self, lVisits, color="r"):
         """
@@ -133,17 +193,17 @@ class RenderTool(object):
         """
         rt = self.__class__
         xyPos = np.matmul(rcPos, rt.grc2xy) + rt.xyHalf
-        plt.scatter(*xyPos, color=sColor)            # agent location
+        self.gl.scatter(*xyPos, color=sColor)            # agent location
 
         rcDir = rt.gTransRC[iDir]                    # agent direction in RC
         xyDir = np.matmul(rcDir, rt.grc2xy)          # agent direction in xy
         xyDirLine = array([xyPos, xyPos+xyDir/2]).T  # line for agent orient.
-        plt.plot(*xyDirLine, color=sColor, lw=5, ms=0, alpha=0.6)
+        self.gl.plot(*xyDirLine, color=sColor, lw=5, ms=0, alpha=0.6)
 
         # just mark the next cell we're heading into
         rcNext = rcPos + rcDir
         xyNext = np.matmul(rcNext, rt.grc2xy) + rt.xyHalf
-        plt.scatter(*xyNext, color=sColor)
+        self.gl.scatter(*xyNext, color=sColor)
 
     def plotTrans(self, rcPos, gTransRCAg, color="r", depth=None):
         """
@@ -156,10 +216,10 @@ class RenderTool(object):
         rt = self.__class__
         xyPos = np.matmul(rcPos, rt.grc2xy) + rt.xyHalf
         gxyTrans = xyPos + np.matmul(gTransRCAg, rt.grc2xy/2.4)
-        plt.scatter(*gxyTrans.T, color=color, marker="o", s=50, alpha=0.2)
+        self.gl.scatter(*gxyTrans.T, color=color, marker="o", s=50, alpha=0.2)
         if depth is not None:
             for x, y in gxyTrans:
-                plt.text(x, y, depth)
+                self.gl.text(x, y, depth)
 
     def getTreeFromRail(self, rcPos, iDir, nDepth=10, bBFS=True, bPlot=False):
         """
@@ -234,9 +294,9 @@ class RenderTool(object):
             xLoc = rDist + visit.iDir / 4
 
             # point labelled with distance
-            plt.scatter(xLoc, visit.iDepth,  color="k", s=2)
+            self.gl.scatter(xLoc, visit.iDepth,  color="k", s=2)
             # plt.text(xLoc, visit.iDepth, sDist, color="k", rotation=45)
-            plt.text(xLoc, visit.iDepth, visit.rc, color="k", rotation=45)
+            self.gl.text(xLoc, visit.iDepth, visit.rc, color="k", rotation=45)
 
             # if len(dPos)>1:
             if visit.prev:
@@ -251,7 +311,7 @@ class RenderTool(object):
                 xLocPrev = rDistPrev + visit.prev.iDir / 4
 
                 # line from prev node
-                plt.plot([xLocPrev, xLoc],
+                self.gl.plot([xLocPrev, xLoc],
                          [visit.iDepth-1, visit.iDepth],
                          color="k", alpha=0.5, lw=1)
 
@@ -266,19 +326,14 @@ class RenderTool(object):
                 rDist = np.linalg.norm(array(visit.rc) - array(xyTarg))
                 xLoc = rDist + visit.iDir / 4
                 if xLocPrev is not None:
-                    plt.plot([xLoc, xLocPrev], [visit.iDepth, visit.iDepth+1],
+                    self.gl.plot([xLoc, xLocPrev], [visit.iDepth, visit.iDepth+1],
                              color="r", alpha=0.5, lw=2)
                 xLocPrev = xLoc
                 visit = visit.prev
             # prev = prev.prev
 
-        # plt.xticks(range(7)); plt.yticks(range(11))
-        ax = plt.gca()
-        plt.xticks(range(int(ax.get_xlim()[1])+1))
-        plt.yticks(range(int(ax.get_ylim()[1])+1))
-        plt.grid()
-        plt.xlabel("Euclidean distance")
-        plt.ylabel("Tree / Transition Depth")
+        # self.gl.xticks(range(7)); self.gl.yticks(range(11))
+        self.gl.prettify()
         return visitDest
 
     def plotPath(self, visitDest):
@@ -303,7 +358,7 @@ class RenderTool(object):
                     dx, dy = (xyPrev - xy) / 20
                     xyLine = array([xy, xyPrev]) + array([dy, dx])
 
-                    plt.plot(*xyLine.T, color="r", alpha=0.5, lw=1)
+                    self.gl.plot(*xyLine.T, color="r", alpha=0.5, lw=1)
 
                     xyMid = np.sum(xyLine * [[1/4], [3/4]], axis=0)
 
@@ -312,7 +367,7 @@ class RenderTool(object):
                         xyMid,
                         xyMid + [-dx+dy, -dx-dy]
                         ])
-                    plt.plot(*xyArrow.T, color="r")
+                    self.gl.plot(*xyArrow.T, color="r")
 
                 visit = visit.prev
                 xyPrev = xy
@@ -350,10 +405,10 @@ class RenderTool(object):
                     xyCentre,
                     xyLine[1] - [dy, dx],
                 ])
-                plt.plot(*xyLine2.T, color=sColor)
+                self.gl.plot(*xyLine2.T, color=sColor)
             else:
                 xyLine2 = xyLine + [-dy, dx]
-                plt.plot(*xyLine2.T, color=sColor)
+                self.gl.plot(*xyLine2.T, color=sColor)
 
                 if bArrow:
                     xyMid = np.sum(xyLine2 * [[1/4], [3/4]], axis=0)
@@ -363,7 +418,7 @@ class RenderTool(object):
                         xyMid,
                         xyMid + [-dx+dy, -dx-dy]
                         ])
-                    plt.plot(*xyArrow.T, color=sColor)
+                    self.gl.plot(*xyArrow.T, color=sColor)
 
         else:
 
@@ -381,7 +436,7 @@ class RenderTool(object):
             if sColor == "auto":
                 sColor = sColorAuto
 
-            plt.plot(*(rt.gArc * dxy2 + xyCorner).T, color=sColor)
+            self.gl.plot(*(rt.gArc * dxy2 + xyCorner).T, color=sColor)
 
             if bArrow:
                 dx, dy = np.squeeze(np.diff(xyLine, axis=0)) / 20
@@ -392,7 +447,7 @@ class RenderTool(object):
                     xyMid,
                     xyMid + [-dx+dy, -dx-dy]
                     ])
-                plt.plot(*xyArrow.T, color=sColor)
+                self.gl.plot(*xyArrow.T, color=sColor)
 
     def renderEnv(
             self, show=False, curves=True, spacing=False,
@@ -409,12 +464,13 @@ class RenderTool(object):
         # cell_size is a bit pointless with matplotlib - it does not relate to pixels,
         # so for now I've changed it to 1 (from 10)
         cell_size = 1
-        plt.clf()
+        self.gl.beginFrame()
+        self.gl.clf()
         # if oFigure is None:
-        #    oFigure = plt.figure()
+        #    oFigure = self.gl.figure()
 
         def drawTrans(oFrom, oTo, sColor="gray"):
-            plt.plot(
+            self.gl.plot(
                 [oFrom[0], oTo[0]],  # x
                 [oFrom[1], oTo[1]],  # y
                 color=sColor
@@ -425,11 +481,11 @@ class RenderTool(object):
         # Draw cells grid
         grid_color = [0.95, 0.95, 0.95]
         for r in range(env.height+1):
-            plt.plot([0, (env.width+1)*cell_size],
+            self.gl.plot([0, (env.width+1)*cell_size],
                      [-r*cell_size, -r*cell_size],
                      color=grid_color)
         for c in range(env.width+1):
-            plt.plot([c*cell_size, c*cell_size],
+            self.gl.plot([c*cell_size, c*cell_size],
                      [0, -(env.height+1)*cell_size],
                      color=grid_color)
 
@@ -514,7 +570,7 @@ class RenderTool(object):
 
         # Draw each agent + its orientation + its target
         if agents:
-            cmap = plt.get_cmap('hsv', lut=env.number_of_agents+1)
+            cmap = self.gl.get_cmap('hsv', lut=env.number_of_agents+1)
             for i in range(env.number_of_agents):
                 self._draw_square((
                                 env.agents_position[i][1] *
@@ -537,54 +593,46 @@ class RenderTool(object):
                     new_position[0] + env.agents_position[i][0]) / 2 * cell_size,
                     (new_position[1] + env.agents_position[i][1]) / 2 * cell_size)
 
-                plt.plot(
+                self.gl.plot(
                     [env.agents_position[i][1] * cell_size+cell_size/2, new_position[1]+cell_size/2],
                     [-env.agents_position[i][0] * cell_size-cell_size/2, -new_position[0]-cell_size/2],
                     color=cmap(i),
                     linewidth=2.0)
 
         # Draw some textual information like fps
-        yText = [0.1, 0.4, 0.7]
+        yText = [-0.3, -0.6, -0.9]
         if frames:
-            plt.text(0.1, yText[2], "Frame:{:}".format(self.iFrame))
+            self.gl.text(0.1, yText[2], "Frame:{:}".format(self.iFrame))
         self.iFrame += 1
         
         if iEpisode is not None:
-            plt.text(0.1, yText[1], "Ep:{}".format(iEpisode))
+            self.gl.text(0.1, yText[1], "Ep:{}".format(iEpisode))
 
         if iStep is not None:
-            plt.text(0.1, yText[0], "Step:{}".format(iStep))
+            self.gl.text(0.1, yText[0], "Step:{}".format(iStep))
 
         tNow = time.time()
-        plt.text(2, yText[2], "elapsed:{:.2f}s".format(tNow - self.time1))
+        self.gl.text(2, yText[2], "elapsed:{:.2f}s".format(tNow - self.time1))
         self.lTimes.append(tNow)
         if len(self.lTimes) > 20:
             self.lTimes.popleft()
         if len(self.lTimes) > 1:
             rFps = (len(self.lTimes) - 1) / (self.lTimes[-1] - self.lTimes[0])
-            plt.text(2, yText[1], "fps:{:.2f}".format(rFps))
+            self.gl.text(2, yText[1], "fps:{:.2f}".format(rFps))
 
-        plt.xlim([0, env.width * cell_size])
-        plt.ylim([-env.height * cell_size, 0])
+        self.gl.prettify2(env.width, env.height, self.nPixCell)
 
-        gTicks = (np.arange(0, env.height) + 0.5) * cell_size
-        gLabels = np.arange(0, env.height)
-        plt.xticks(gTicks, gLabels)
-
-        gTicks = np.arange(-env.height * cell_size, 0) + cell_size/2
-        gLabels = np.arange(env.height-1, -1, -1)
-        plt.yticks(gTicks, gLabels)
+        self.gl.endFrame()
 
-        plt.xlim([0, env.width * cell_size])
-        plt.ylim([-env.height * cell_size, 0])
         if show:
-            plt.show(block=False)
-            plt.pause(0.00001)
-            return
+            self.gl.show(block=False)
+            self.gl.pause(0.00001)
+
+        return
 
     def _draw_square(self, center, size, color):
         x0 = center[0]-size/2
         x1 = center[0]+size/2
         y0 = center[1]-size/2
         y1 = center[1]+size/2
-        plt.plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], color=color)
+        self.gl.plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], color=color)
diff --git a/requirements_dev.txt b/requirements_dev.txt
index 4cb4edd4f4dd0f41a32b434e4e8ca13d5d6199c8..0bc267e913ae035b96984a0688e169daacc5ddd3 100644
--- a/requirements_dev.txt
+++ b/requirements_dev.txt
@@ -15,4 +15,5 @@ numpy==1.16.2
 recordtype==1.3
 xarray==0.11.3
 matplotlib==3.0.2
+PyQt5==5.12