diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 00fde0246ed937dcf97f44dd63067060b7c54350..5ec756f98c2229110e2bfc4cc146f83522eee3bc 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -145,6 +145,13 @@ class RenderTool(object):
         plt.scatter(*xyNext, color=sColor)
 
     def plotTrans(self, rcPos, gTransRCAg, color="r", depth=None):
+        """
+        plot the transitions in gTransRCAg at position rcPos.
+        gTransRCAg is a 2d numpy array containing a list of RC transitions,
+        eg [[-1,0], [0,1]] means N, E.
+
+        """
+
         rt = self.__class__
         xyPos = np.matmul(rcPos, rt.grc2xy) + rt.xyHalf
         gxyTrans = xyPos + np.matmul(gTransRCAg, rt.grc2xy/2.4)
@@ -299,7 +306,7 @@ class RenderTool(object):
                 visit = visit.prev
                 xyPrev = xy
 
-    def renderEnv(self, show=False):
+    def renderEnv(self, show=False, curves=True):
         """
         Draw the environment using matplotlib.
         Draw into the figure if provided.
@@ -307,11 +314,17 @@ class RenderTool(object):
         Call pyplot.show() if show==True.
         (Use show=False from a Jupyter notebook with %matplotlib inline)
         """
+
+        # cell_size is a bit pointless with matplotlib - it does not relate to pixels,
+        # so for now I've changed it to 1 (from 10)
         cell_size = 1
 
         # if oFigure is None:
         #    oFigure = plt.figure()
 
+        gTheta = np.linspace(0, np.pi/2, 10)
+        gArc = array([np.cos(gTheta), np.sin(gTheta)]).T  # from [1,0] to [0,1]
+
         def drawTrans(oFrom, oTo, sColor="gray"):
             plt.plot(
                 [oFrom[0], oTo[0]],  # x
@@ -319,6 +332,26 @@ class RenderTool(object):
                 color=sColor
             )
 
+        def drawTrans2(xyLine, xyCentre, rotation, sColor="gray"):
+            """
+            gLine is a numpy 2d array of points,
+            in the plotting space / coords.
+            eg:
+            [[0,.5],[1,0.2]] means a line
+            from x=0, y=0.5
+            to   x=1, y=0.2
+            """
+            xyMid = np.mean(xyLine, axis=0)
+            dxy = xyMid - xyCentre
+            xyCorner = xyMid + dxy
+            dxy2 = xyCentre - xyCorner
+
+            bStraight = rotation in [0, 2]
+            if bStraight:
+                plt.plot(*xyLine.T, color=sColor)
+            else:
+                plt.plot(*(gArc * dxy2 + xyCorner).T, color=sColor)
+
         RETrans = RailEnvTransitions()
         env = self.env
 
@@ -338,10 +371,10 @@ class RenderTool(object):
             for c in range(env.width):
                 trans_ = env.rail[r][c]
 
-                x0 = cell_size * c
-                x1 = cell_size * (c+1)
-                y0 = cell_size * -r
-                y1 = cell_size * -(r+1)
+                x0 = cell_size * c       # left
+                x1 = cell_size * (c+1)   # right
+                y0 = cell_size * -r      # top
+                y1 = cell_size * -(r+1)  # bottom
 
                 coords = [
                     ((x0+x1)/2.0, y0),  # N middle top
@@ -350,6 +383,8 @@ class RenderTool(object):
                     (x0, (y0+y1)/2.0)   # W middle left
                 ]
 
+                xyCentre = array([x0, y1]) + cell_size / 2
+
                 oCell = env.rail[r, c]
 
                 for orientation in range(4):  # ori is where we're heading
@@ -378,15 +413,23 @@ class RenderTool(object):
                         # to_ori = (orientation + 2) % 4
                         for to_ori in range(4):
                             to_xy = coords[to_ori]
+                            rotation = (to_ori - from_ori) % 4
 
-                            if False:
-                                print("r,c,ori: ", r, c, orientation,
-                                      "cell:", "{0:b}".format(oCell),
-                                      "moves:", tMoves,
-                                      "from:", from_ori, from_xy,
-                                      "to: ", to_ori, to_xy)
                             if (tMoves[to_ori]):
-                                drawTrans(from_xy, to_xy)
+                                if curves:
+                                    drawTrans2(array([from_xy, to_xy]), xyCentre, rotation)
+                                else:
+                                    drawTrans(from_xy, to_xy)
+                                if False:
+                                    print(
+                                        "r,c,ori: ", r, c, orientation,
+                                        "cell:", "{0:b}".format(oCell),
+                                        "moves:", tMoves,
+                                        "from:", from_ori, from_xy,
+                                        "to: ", to_ori, to_xy,
+                                        "cen:", *xyCentre,
+                                        "rot:", rotation,
+                                    )
 
         # Draw each agent + its orientation + its target
         cmap = plt.get_cmap('hsv', lut=env.number_of_agents+1)
diff --git a/images/basic-env.png b/images/basic-env.png
index 75a145fef094ea38e8148ed1075ae18fffb8f858..0b21b26887236b4e7d7e82b34bfa074ec9d05c38 100644
Binary files a/images/basic-env.png and b/images/basic-env.png differ
diff --git a/images/env-path.png b/images/env-path.png
index d748e12c4c9f548553db70a61d2147db9d6c6da6..5f49e744754237889dd7331213d3084cb19b1555 100644
Binary files a/images/env-path.png and b/images/env-path.png differ
diff --git a/images/env-tree-spatial.png b/images/env-tree-spatial.png
index b04b617f423b726fc0708bf10a78ed09e44c1a02..06f2054027c5da517c8eff0b2c142dd183a7e3fb 100644
Binary files a/images/env-tree-spatial.png and b/images/env-tree-spatial.png differ