diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 535f44b4d422ef54a012c838c41d1ff432527e7f..51086c0cd065386f0bc35aea573d391800457b63 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -30,7 +30,8 @@ class RenderTool(object):
     def __init__(self, env):
         self.env = env
 
-    def plotTreeOnRail(self, rcPos, iDir, nDepth=10):
+    #def plotTreeOnRail(self, rcPos, iDir, nDepth=10, color="r"):
+    def plotTreeOnRail(self, lVisits, color="r"):
         """
         Derives and plots a tree of transitions starting at position rcPos
         in direction iDir.
@@ -50,7 +51,7 @@ class RenderTool(object):
                 self.plotAgent(rcPos, iDir, sColor)
 
                 gTransRCAg = self.getTransRC(rcPos, iDir)
-                self.plotTrans(rcPos, gTransRCAg)
+                self.plotTrans(rcPos, gTransRCAg, color=color)
 
                 if False:
                     # TODO: this was `rcDir' but it was undefined
@@ -62,13 +63,16 @@ class RenderTool(object):
                     giTrans = np.where(tbTrans)[0]  # RC list of transitions
                     gTransRCAg = self.__class__.gTransRC[giTrans]
 
-        # rcPos=(6,4)
-        # iDir=2
-        gTransRCAg = self.getTransRC(rcPos, iDir)
-        self.plotTrans(rcPos, gTransRCAg)
+        for visit in lVisits:
+            # transition for next cell
+            oTrans = self.env.rail[visit.rc]
+            tbTrans = rt.RETrans.get_transitions(oTrans, visit.iDir)
+            giTrans = np.where(tbTrans)[0]  # RC list of transitions
+            gTransRCAg = rt.gTransRC[giTrans]
+            self.plotTrans(visit.rc, gTransRCAg, depth=str(visit.iDepth), color=color)
 
-        lVisits = self.getTreeFromRail(rcPos, iDir, nDepth=nDepth)
-        return lVisits
+        #lVisits = self.getTreeFromRail(rcPos, iDir, nDepth=nDepth)
+        #return lVisits
 
     def plotAgents(self):
         rt = self.__class__
@@ -149,12 +153,13 @@ class RenderTool(object):
         xyPos = np.matmul(rcPos, rt.grc2xy) + rt.xyHalf
         gxyTrans = xyPos + np.matmul(gTransRCAg, rt.grc2xy/2.4)
         # print(gxyTrans)
+        #print(gxyTrans, color)
         plt.scatter(*gxyTrans.T, color=color, marker="o", s=50, alpha=0.2)
         if depth is not None:
             for x, y in gxyTrans:
                 plt.text(x, y, depth)
 
-    def getTreeFromRail(self, rcPos, iDir, nDepth=10, bBFS=True):
+    def getTreeFromRail(self, rcPos, iDir, nDepth=10, bBFS=True, bPlot=False):
         """
         Generate a tree from the env starting at rcPos, iDir.
         """
@@ -194,7 +199,8 @@ class RenderTool(object):
                     stack.append(visitNext)
 
                 # plot the available transitions from this node
-                self.plotTrans(visit.rc, gTransRCAg, depth=str(visit.iDepth))
+                if bPlot:
+                    self.plotTrans(visit.rc, gTransRCAg, depth=str(visit.iDepth))
 
         return lVisits
 
diff --git a/images/env-tree-spatial.png b/images/env-tree-spatial.png
index 255feb209f35a315179dba8121ba81abe8d93c5c..b04b617f423b726fc0708bf10a78ed09e44c1a02 100644
Binary files a/images/env-tree-spatial.png and b/images/env-tree-spatial.png differ
diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py
index 39430d670297ac04c7fe9b01bb021889d612a7f0..1bd99c69e100a7a01d6d1e502604fa8647da215a 100644
--- a/tests/test_rendertools.py
+++ b/tests/test_rendertools.py
@@ -61,10 +61,11 @@ def test_render_env():
 
     plt.figure(figsize=(10,10))
     oRT.renderEnv()
-    lVisits = oRT.plotTreeOnRail(
+    
+    lVisits = oRT.getTreeFromRail(
         oEnv.agents_position[0], 
         oEnv.agents_direction[0], 
-        nDepth=17)
+        nDepth=17, bPlot=True)
 
     checkFrozenImage("env-tree-spatial.png")