diff --git a/examples/demo.py b/examples/demo.py
index 86e986f6fd0587b3d5781175c6c59565db3d4c71..6e9f05389ce116c443090e7ee2e11867d0888fe3 100644
--- a/examples/demo.py
+++ b/examples/demo.py
@@ -137,6 +137,8 @@ class Demo:
         # Reset environment
         _ = self.env.reset(False, False)
 
+        time.sleep(0.0001)  # to satisfy lint...
+
         for step in range(max_nbr_of_steps):
 
             # time.sleep(.1)
@@ -171,6 +173,7 @@ class Demo:
 
             if done['__all__']:
                 break
+        self.renderer.close_window()
 
 
 if True:
diff --git a/examples/tkplay.py b/examples/tkplay.py
index a46dcbc6f3aca2b84a9b35f33c339c3f03291c32..c17ea519014ff304f654036f111ac953f328a118 100644
--- a/examples/tkplay.py
+++ b/examples/tkplay.py
@@ -26,7 +26,7 @@ def tkmain(n_trials=2, n_steps=50, sGL="PIL"):
             env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step,
                                    action_dict=oPlayer.action_dict)
 
-    env_renderer.gl.close_window()            
+    env_renderer.close_window()            
 
 
 if __name__ == "__main__":
diff --git a/flatland/utils/graphics_layer.py b/flatland/utils/graphics_layer.py
index a1de818fe57ed2f91a3b025f652234c578dc11e8..1b2ff7ea0c01240b4293e7386e8f0c358c4e2c96 100644
--- a/flatland/utils/graphics_layer.py
+++ b/flatland/utils/graphics_layer.py
@@ -66,7 +66,12 @@ class GraphicsLayer(object):
     def get_cmap(self, *args, **kwargs):
         return plt.get_cmap(*args, **kwargs)
 
-    def setRailAt(self, row, col, binTrans):
+    def setRailAt(self, row, col, binTrans, target=None):
+        """ Set the rail at cell (row, col) to have transitions binTrans.
+            The target argument can contain the index of the agent to indicate
+            that agent's target is at that cell, so that a station can be
+            rendered in the static rail layer.
+        """
         pass
 
     def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut):
diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py
index 63a4be28fa8f5fc1bc1f19e29d14ae74e766f587..65235bf410f41770b3fbcb93265cc7b02ed92fa2 100644
--- a/flatland/utils/graphics_pil.py
+++ b/flatland/utils/graphics_pil.py
@@ -33,12 +33,22 @@ class PILGL(GraphicsLayer):
         self.tColRail = (0, 0, 0)         # black rails
         self.tColGrid = (230,) * 3        # light grey for grid
 
+        sColors = "d50000#c51162#aa00ff#6200ea#304ffe#2962ff#0091ea#00b8d4#00bfa5#00c853" + \
+            "#64dd17#aeea00#ffd600#ffab00#ff6d00#ff3d00#5d4037#455a64"
+            
+        self.ltAgentColors = [self.rgb_s2i(sColor) for sColor in sColors.split("#")]
+        self.nAgentColors = len(self.ltAgentColors)
+
         self.window_open = False
         # self.bShow = show
         self.firstFrame = True
         self.create_layers()
         # self.beginFrame()
 
+    def rgb_s2i(self, sRGB):
+        """ convert a hex RGB string like 0091ea to 3-tuple of ints """
+        return tuple(int(sRGB[iRGB * 2:iRGB * 2 + 2], 16) for iRGB in [0, 1, 2])
+
     def plot(self, gX, gY, color=None, linewidth=3, layer=0, opacity=255, **kwargs):
         color = self.adaptColor(color)
         if len(color) == 3:
@@ -246,10 +256,17 @@ class PILSVG(PILGL):
             "EE WW": "Bahnhof_#d50000_Gleis_horizontal.svg",
             "NN SS": "Bahnhof_#d50000_Gleis_vertikal.svg"}
 
+        # Dict of rail cell images indexed by binary transitions
         self.dPilRail = self.loadSVGs(dRailFiles, rotate=True)
-        self.dPilTarget = self.loadSVGs(dTargetFiles, rotate=False)
 
-    def loadSVGs(self, dDirFile, rotate=False):
+        # Load the target files (which have rails and transitions of their own)
+        # They are indexed by (binTrans, iAgent), ie a tuple of the binary transition and the agent index
+        dPilRail2 = self.loadSVGs(dTargetFiles, rotate=False, agent_colors=self.ltAgentColors)
+        # Merge them with the regular rails.
+        # https://stackoverflow.com/questions/38987/how-to-merge-two-dictionaries-in-a-single-expression
+        self.dPilRail = {**self.dPilRail, **dPilRail2}
+        
+    def loadSVGs(self, dDirFile, rotate=False, agent_colors=False):
         dPil = {}
 
         transitions = RailEnvTransitions()
@@ -280,9 +297,10 @@ class PILSVG(PILGL):
             #    svg = svg.merge(svgBG)
 
             pilRail = self.pilFromSvgFile(sPathSvg)
-            dPil[binTrans] = pilRail
-
+            
             if rotate:
+                # For rotations, we also store the base image
+                dPil[binTrans] = pilRail
                 # Rotate both the transition binary and the image and save in the dict
                 for nRot in [90, 180, 270]:
                     binTrans2 = transitions.rotate_transition(binTrans, nRot)
@@ -290,25 +308,44 @@ class PILSVG(PILGL):
                     # PIL rotates anticlockwise for positive theta
                     pilRail2 = pilRail.rotate(-nRot)
                     dPil[binTrans2] = pilRail2
+            
+            if agent_colors:
+                # For recoloring, we don't store the base image.
+                a3BaseColor = self.rgb_s2i("d50000")
+                lPils = self.recolorImage(pilRail, a3BaseColor, self.ltAgentColors)
+                for iColor, pilRail2 in enumerate(lPils):
+                    dPil[(binTrans, iColor)] = lPils[iColor]
+
         return dPil
 
-    def setRailAt(self, row, col, binTrans, target=None):
-        if target is None:
+    def setRailAt(self, row, col, binTrans, iTarget=None):
+        if iTarget is None:
             if binTrans in self.dPilRail:
                 pilTrack = self.dPilRail[binTrans]
                 self.drawImageRC(pilTrack, (row, col))
             else:
                 print("Illegal rail:", row, col, format(binTrans, "#018b")[2:])
         else:
-            if binTrans in self.dPilTarget:
-                pilTrack = self.dPilTarget[binTrans]
+            if (binTrans, iTarget) in self.dPilRail:
+                pilTrack = self.dPilRail[(binTrans, iTarget)]
                 self.drawImageRC(pilTrack, (row, col))
             else:
                 print("Illegal target rail:", row, col, format(binTrans, "#018b")[2:])
 
-    def rgb_s2i(self, sRGB):
-        """ convert a hex RGB string like 0091ea to 3-tuple of ints """
-        return tuple(int(sRGB[iRGB * 2:iRGB * 2 + 2], 16) for iRGB in [0, 1, 2])
+    def recolorImage(self, pil, a3BaseColor, ltColors):
+        rgbaImg = array(pil)
+        lPils = []
+
+        for iColor, tnColor in enumerate(ltColors):
+            # find the pixels which match the base paint color
+            xy_color_mask = np.all(rgbaImg[:, :, 0:3] - a3BaseColor == 0, axis=2)
+            rgbaImg2 = np.copy(rgbaImg)
+
+            # Repaint the base color with the new color
+            rgbaImg2[xy_color_mask, 0:3] = tnColor
+            pil2 = Image.fromarray(rgbaImg2)
+            lPils.append(pil2)
+        return lPils
 
     def loadAgentSVGs(self):
 
@@ -319,13 +356,8 @@ class PILSVG(PILGL):
             (0, 3): "svg/Zug_2_Weiche_#0091ea.svg"
             }
 
-        sColors = "d50000#c51162#aa00ff#6200ea#304ffe#2962ff#0091ea#00b8d4#00bfa5#00c853" + \
-            "#64dd17#aeea00#ffd600#ffab00#ff6d00#ff3d00#5d4037#455a64"
-        lColors = sColors.split("#")
-        self.nAgentColors = len(lColors)
-
         # "paint" color of the train images we load
-        a_base_color = self.rgb_s2i("0091ea")
+        a3BaseColor = self.rgb_s2i("0091ea")
 
         self.dPilZug = {}
 
@@ -342,16 +374,13 @@ class PILSVG(PILGL):
 
                 # PIL rotates anticlockwise for positive theta
                 pilZug2 = pilZug.rotate(-nDegRot)
-                rgbaZug2 = array(pilZug2)
 
-                for iColor, sColor in enumerate(lColors):
-                    tnNewColor = self.rgb_s2i(sColor)
-                    xy_color_mask = np.all(rgbaZug2[:, :, 0:3] - a_base_color == 0, axis=2)
-                    rgbaZug3 = np.copy(rgbaZug2)
-                    rgbaZug3[xy_color_mask, 0:3] = tnNewColor
-                    self.dPilZug[(iDirIn2, iDirOut2, iColor)] = Image.fromarray(rgbaZug3)
+                # Save colored versions of each rotation / variant
+                lPils = self.recolorImage(pilZug2, a3BaseColor, self.ltAgentColors)
+                for iColor, pilZug3 in enumerate(lPils):
+                    self.dPilZug[(iDirIn2, iDirOut2, iColor)] = lPils[iColor]
 
-    def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut, color=None):
+    def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut):
         delta_dir = (iDirOut - iDirIn) % 4
         iColor = iAgent % self.nAgentColors
         # when flipping direction at a dead end, use the "iDirOut" direction.
diff --git a/flatland/utils/render_qt.py b/flatland/utils/render_qt.py
index 233f07bca474204ea31aec5d75b530911071bfad..ff8ebd12a966c76342dec8298fab2755b3734223 100644
--- a/flatland/utils/render_qt.py
+++ b/flatland/utils/render_qt.py
@@ -157,7 +157,7 @@ class QTSVG(GraphicsLayer):
         self.lwAgents = []
         self.agents_prev = []
 
-    def setRailAt(self, row, col, binTrans):
+    def setRailAt(self, row, col, binTrans, target=None):
         if binTrans in self.track.dSvg:
             sSVG = self.track.dSvg[binTrans].to_string()
             svgWidget = create_QtSvgWidget_from_svg_string(sSVG)
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index f2200b97b03577e4dcdb2f34575b0330ff0693fe..37d81519a104856ac3081a2f6d5ae9ebd03fd44a 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -767,7 +767,7 @@ class RenderTool(object):
                         target = dTargets[(r, c)]
                     else:
                         target = None
-                    self.gl.setRailAt(r, c, binTrans)
+                    self.gl.setRailAt(r, c, binTrans, iTarget=target)
 
         for iAgent, agent in enumerate(self.env.agents):
             if agent is None:
@@ -782,8 +782,9 @@ class RenderTool(object):
                 direction = agent.direction
                 old_direction = agent.direction
 
-            cmap = self.gl.get_cmap('hsv', lut=max(len(self.env.agents), len(self.env.agents_static) + 1))
-            self.gl.setAgentAt(iAgent, *position, old_direction, direction,color=cmap(iAgent))
+            # setAgentAt uses the agent index for the color
+            # cmap = self.gl.get_cmap('hsv', lut=max(len(self.env.agents), len(self.env.agents_static) + 1))
+            self.gl.setAgentAt(iAgent, *position, old_direction, direction)  # ,color=cmap(iAgent))
 
         if show:
             self.gl.show()
@@ -792,3 +793,6 @@ class RenderTool(object):
 
         self.iFrame += 1
         return
+
+    def close_window(self):
+        self.gl.close_window()
diff --git a/flatland/utils/svg.py b/flatland/utils/svg.py
index 0a6b895dcdaa5e9cd4d99cfa5861ecc187dad905..e7b2cebb606089f5937d37c8e4d6381ada9fc211 100644
--- a/flatland/utils/svg.py
+++ b/flatland/utils/svg.py
@@ -60,7 +60,7 @@ class SVG(object):
         sNewStyles = "\n"
         for sKey, sValue in self.dStyles.items():
             if sKey == style_name:
-                sValue = "fill:#" + "".join([ ('{:#04x}'.format(int(255.0*col))[2:4]) for col in color[0:3]]) + ";"
+                sValue = "fill:#" + "".join([('{:#04x}'.format(int(255.0*col))[2:4]) for col in color[0:3]]) + ";"
             sNewStyle = "\t.st" + sKey + "{" + sValue + "}\n"
             sNewStyles += sNewStyle