diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 1bd67c22ad4ba05bd6badfc991e8839a85fe9e30..d64601c7ee32d0f611aa98d1577eb785eea870f3 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -10,6 +10,7 @@ from flatland.utils.render_qt import QTGL, QTSVG
 from flatland.utils.graphics_pil import PILGL
 from flatland.utils.graphics_layer import GraphicsLayer
 
+
 # TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv!
 
 
@@ -100,12 +101,12 @@ class RenderTool(object):
     lColors = list("brgcmyk")
     # \delta RC for NESW
     gTransRC = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
-    nPixCell = 1   # misnomer...
+    nPixCell = 1  # misnomer...
     nPixHalf = nPixCell / 2
     xyHalf = array([nPixHalf, -nPixHalf])
     grc2xy = array([[0, -nPixCell], [nPixCell, 0]])
     gGrid = array(np.meshgrid(np.arange(10), -np.arange(10))) * \
-        array([[[nPixCell]], [[nPixCell]]])
+            array([[[nPixCell]], [[nPixCell]]])
     # xyPixHalf = xr.DataArray([nPixHalf, -nPixHalf],
     #                         dims="xy",
     #                         coords={"xy": ["x", "y"]})
@@ -130,7 +131,7 @@ class RenderTool(object):
             self.gl = PILGL(env.width, env.height)
         elif gl == "QTSVG":
             self.gl = QTSVG(env.width, env.height)
-        
+
         self.new_rail = True
 
     def set_new_rail(self):
@@ -153,14 +154,14 @@ class RenderTool(object):
 
     def plotAgents(self, targets=True, iSelectedAgent=None):
         cmap = self.gl.get_cmap('hsv',
-            lut=max(len(self.env.agents), len(self.env.agents_static) + 1))
+                                lut=max(len(self.env.agents), len(self.env.agents_static) + 1))
 
         for iAgent, agent in enumerate(self.env.agents_static):
             if agent is None:
                 continue
             oColor = cmap(iAgent)
             self.plotAgent(agent.position, agent.direction, oColor, target=agent.target if targets else None,
-                static=True, selected=iAgent == iSelectedAgent)
+                           static=True, selected=iAgent == iSelectedAgent)
 
         for iAgent, agent in enumerate(self.env.agents):
             if agent is None:
@@ -211,8 +212,8 @@ class RenderTool(object):
         """
         rt = self.__class__
 
-        rcDir = rt.gTransRC[iDir]                    # agent direction in RC
-        xyDir = np.matmul(rcDir, rt.grc2xy)          # agent direction in xy
+        rcDir = rt.gTransRC[iDir]  # agent direction in RC
+        xyDir = np.matmul(rcDir, rt.grc2xy)  # agent direction in xy
 
         xyPos = np.matmul(rcPos - rcDir / 2, rt.grc2xy) + rt.xyHalf
 
@@ -220,7 +221,7 @@ class RenderTool(object):
             color = self.gl.adaptColor(color, lighten=True)
 
         # print("Agent:", rcPos, iDir, rcDir, xyDir, xyPos)
-        self.gl.scatter(*xyPos, color=color, marker="o", s=100)            # agent location
+        self.gl.scatter(*xyPos, color=color, marker="o", s=100)  # agent location
         xyDirLine = array([xyPos, xyPos + xyDir / 2]).T  # line for agent orient.
         self.gl.plot(*xyDirLine, color=color, lw=5, ms=0, alpha=0.6)
         if selected:
@@ -398,12 +399,12 @@ class RenderTool(object):
                 xyPrev = xy
 
     def drawTrans2(
-            self,
-            xyLine, xyCentre,
-            rotation, bDeadEnd=False,
-            sColor="gray",
-            bArrow=True,
-            spacing=0.1):
+        self,
+        xyLine, xyCentre,
+        rotation, bDeadEnd=False,
+        sColor="gray",
+        bArrow=True,
+        spacing=0.1):
         """
         gLine is a numpy 2d array of points,
         in the plotting space / coords.
@@ -420,9 +421,9 @@ class RenderTool(object):
 
             if sColor == "auto":
                 if dx > 0 or dy > 0:
-                    sColor = "C1"   # N or E
+                    sColor = "C1"  # N or E
                 else:
-                    sColor = "C2"   # S or W
+                    sColor = "C2"  # S or W
 
             if bDeadEnd:
                 xyLine2 = array([
@@ -471,6 +472,7 @@ class RenderTool(object):
                     xyMid,
                     xyMid + [-dx + dy, -dx - dy]])
                 self.gl.plot(*xyArrow.T, color=sColor)
+
     def renderObs(self, agent_handles, observation_list):
         """
 
@@ -480,21 +482,21 @@ class RenderTool(object):
         """
         rt = self.__class__
 
-        cmap = self.gl.get_cmap('hsv',lut=max(len(self.env.agents),len(self.env.agents_static)+1))
+        cmap = self.gl.get_cmap('hsv', lut=max(len(self.env.agents), len(self.env.agents_static) + 1))
 
         for agent in agent_handles:
             color = cmap(agent)
             for visited_cell in observation_list[agent]:
                 cell_coord = array(visited_cell[:2])
-                cell_coord_trans = np.matmul(cell_coord,rt.grc2xy)+rt.xyHalf
-                self._draw_square(cell_coord_trans,1 / 3, color)
+                cell_coord_trans = np.matmul(cell_coord, rt.grc2xy) + rt.xyHalf
+                self._draw_square(cell_coord_trans, 1 / 3, color)
 
     def renderEnv(
-            self, show=False, curves=True, spacing=False,
-            arrows=False, agents=True, sRailColor="gray",
-            frames=False, iEpisode=None, iStep=None,
-            iSelectedAgent=None,
-            action_dict=None):
+        self, show=False, curves=True, spacing=False,
+        arrows=False, agents=True, obsrender=True, sRailColor="gray",
+        frames=False, iEpisode=None, iStep=None,
+        iSelectedAgent=None,
+        action_dict=None):
         """
         Draw the environment using matplotlib.
         Draw into the figure if provided.
@@ -505,15 +507,16 @@ class RenderTool(object):
 
         if not self.gl.is_raster():
             self.renderEnv2(show, curves, spacing,
-            arrows, agents, sRailColor,
-            frames, iEpisode, iStep,
-            iSelectedAgent, action_dict)
+                            arrows, agents, sRailColor,
+                            frames, iEpisode, iStep,
+                            iSelectedAgent, action_dict)
             return
 
         # cell_size is a bit pointless with matplotlib - it does not relate to pixels,
         # so for now I've changed it to 1 (from 10)
         cell_size = 1
         self.gl.beginFrame()
+
         # self.gl.clf()
         # if oFigure is None:
         #    oFigure = self.gl.figure()
@@ -545,9 +548,9 @@ class RenderTool(object):
             for c in range(env.width):
 
                 # bounding box of the grid cell
-                x0 = cell_size * c       # left
-                x1 = cell_size * (c + 1)   # right
-                y0 = cell_size * -r      # top
+                x0 = cell_size * c  # left
+                x1 = cell_size * (c + 1)  # right
+                y0 = cell_size * -r  # top
                 y1 = cell_size * -(r + 1)  # bottom
 
                 # centres of cell edges
@@ -555,7 +558,7 @@ class RenderTool(object):
                     ((x0 + x1) / 2.0, y0),  # N middle top
                     (x1, (y0 + y1) / 2.0),  # E middle right
                     ((x0 + x1) / 2.0, y1),  # S middle bottom
-                    (x0, (y0 + y1) / 2.0)   # W middle left
+                    (x0, (y0 + y1) / 2.0)  # W middle left
                 ]
 
                 # cell centre
@@ -628,8 +631,8 @@ class RenderTool(object):
         # Draw each agent + its orientation + its target
         if agents:
             self.plotAgents(targets=True, iSelectedAgent=iSelectedAgent)
-
-        self.renderObs(range(env.get_num_agents()), env.dev_obs_dict)
+        if obsrender:
+            self.renderObs(range(env.get_num_agents()), env.dev_obs_dict)
         # Draw some textual information like fps
         yText = [-0.3, -0.6, -0.9]
         if frames:
@@ -683,18 +686,18 @@ class RenderTool(object):
         gP0 = array([[0, 0, 0]]).T
         nDepth = 2
         for i in range(nDepth):
-            nDepthNodes = nBranchFactor**i
+            nDepthNodes = nBranchFactor ** i
             # rScale = nBranchFactor ** (nDepth - i)
-            rShrinkDepth = 1/(i+1)
+            rShrinkDepth = 1 / (i + 1)
             # gX1 = np.linspace(-nDepthNodes / 2, nDepthNodes / 2, nDepthNodes) * rShrinkDepth
-            
-            gX1 = np.linspace(-(nDepthNodes-1), (nDepthNodes-1), nDepthNodes) * rShrinkDepth
+
+            gX1 = np.linspace(-(nDepthNodes - 1), (nDepthNodes - 1), nDepthNodes) * rShrinkDepth
             gY1 = np.ones((nDepthNodes)) * i
             gZ1 = np.zeros((nDepthNodes))
-            
+
             gP1 = array([gX1, gY1, gZ1])
             gP01 = np.append(gP0, gP1, axis=1)
-            
+
             if nDepthNodes > 1:
                 nDepthNodesPrev = nDepthNodes / nBranchFactor
                 giP0 = np.repeat(np.arange(nDepthNodesPrev), nBranchFactor)
@@ -705,13 +708,13 @@ class RenderTool(object):
                 self.gl.plot(gP01[0], -gP01[1], lines=giLinePoints, color="gray")
 
             gP0 = array([gX1, gY1, gZ1])
-    
+
     def renderEnv2(
-            self, show=False, curves=True, spacing=False,
-            arrows=False, agents=True, sRailColor="gray",
-            frames=False, iEpisode=None, iStep=None,
-            iSelectedAgent=None,
-            action_dict=dict()):
+        self, show=False, curves=True, spacing=False,
+        arrows=False, agents=True, sRailColor="gray",
+        frames=False, iEpisode=None, iStep=None,
+        iSelectedAgent=None,
+        action_dict=dict()):
         """
         Draw the environment using matplotlib.
         Draw into the figure if provided.
@@ -728,12 +731,11 @@ class RenderTool(object):
             # Draw each cell independently
             for r in range(env.height):
                 for c in range(env.width):
-
                     binTrans = env.rail.grid[r, c]
                     self.gl.setRailAt(r, c, binTrans)
 
         cmap = self.gl.get_cmap('hsv',
-            lut=max(len(self.env.agents), len(self.env.agents_static) + 1))
+                                lut=max(len(self.env.agents), len(self.env.agents_static) + 1))
 
         for iAgent, agent in enumerate(self.env.agents):
             if agent is None:
@@ -747,14 +749,14 @@ class RenderTool(object):
             if iAgent in action_dict:
                 iAction = action_dict[iAgent]
                 new_direction, action_isValid = self.env.check_action(agent, iAction)
-            
+
             if action_isValid:
                 self.gl.setAgentAt(iAgent, *agent.position, agent.direction, new_direction, color=oColor)
             else:
                 pass
                 # print("invalid action - agent ", iAgent, " bend ", agent.direction, new_direction)
                 # self.gl.setAgentAt(iAgent, *agent.position, agent.direction, new_direction)
-                
+
         self.gl.show()
         for i in range(3):
             self.gl.processEvents()