diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py
index 97e3e5622ff633a35159f1d1fe1b2249e1803463..b72216a0a85b54714b0a10503e077b06c6a019ac 100644
--- a/flatland/utils/graphics_pil.py
+++ b/flatland/utils/graphics_pil.py
@@ -37,6 +37,19 @@ class PILGL(GraphicsLayer):
     # https://stackoverflow.com/questions/26097811/image-pyimage2-doesnt-exist
     window = tk.Tk()
 
+    RAIL_LAYER = 0
+    AGENT_LAYER = 1
+    PREDICTION_PATH_LAYER = 2
+    SELECTED_AGENT_LAYER = 3
+    SELECTED_TARGET_LAYER = 4
+
+    def create_layers(self, clear=True):
+        self.create_layer(PILGL.RAIL_LAYER, clear=clear)  # rail / background (scene)
+        self.create_layer(PILGL.AGENT_LAYER, clear=clear)  # agents
+        self.create_layer(PILGL.PREDICTION_PATH_LAYER, clear=clear)  # drawing layer for agent's prediction path
+        self.create_layer(PILGL.SELECTED_AGENT_LAYER, clear=clear)  # drawing layer for selected agent
+        self.create_layer(PILGL.SELECTED_TARGET_LAYER, clear=clear)  # drawing layer for selected agent's target
+
     def __init__(self, width, height, jupyter=False):
         self.yxBase = (0, 0)
         self.linewidth = 4
@@ -129,7 +142,7 @@ class PILGL(GraphicsLayer):
     def get_agent_color(self, iAgent):
         return self.agent_colors[iAgent % self.n_agent_colors]
 
-    def plot(self, gX, gY, color=None, linewidth=3, layer=0, opacity=255, **kwargs):
+    def plot(self, gX, gY, color=None, linewidth=3, layer=RAIL_LAYER, opacity=255, **kwargs):
         color = self.adapt_color(color)
         if len(color) == 3:
             color += (opacity,)
@@ -139,14 +152,14 @@ class PILGL(GraphicsLayer):
         gPoints = list(gPoints.ravel())
         self.draws[layer].line(gPoints, fill=color, width=self.linewidth)
 
-    def scatter(self, gX, gY, color=None, marker="o", s=50, layer=0, opacity=255, *args, **kwargs):
+    def scatter(self, gX, gY, color=None, marker="o", s=50, layer=RAIL_LAYER, opacity=255, *args, **kwargs):
         color = self.adapt_color(color)
         r = np.sqrt(s)
         gPoints = np.stack([np.atleast_1d(gX), -np.atleast_1d(gY)]).T * self.nPixCell
         for x, y in gPoints:
             self.draws[layer].rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color)
 
-    def draw_image_xy(self, pil_img, xyPixLeftTop, layer=0):
+    def draw_image_xy(self, pil_img, xyPixLeftTop, layer=RAIL_LAYER, ):
         if (pil_img.mode == "RGBA"):
             pil_mask = pil_img
         else:
@@ -154,7 +167,7 @@ class PILGL(GraphicsLayer):
 
         self.layers[layer].paste(pil_img, xyPixLeftTop, pil_mask)
 
-    def draw_image_row_col(self, pil_img, rcTopLeft, layer=0):
+    def draw_image_row_col(self, pil_img, rcTopLeft, layer=RAIL_LAYER, ):
         xyPixLeftTop = tuple((array(rcTopLeft) * self.nPixCell)[[1, 0]])
         self.draw_image_xy(pil_img, xyPixLeftTop, layer=layer)
 
@@ -180,7 +193,7 @@ class PILGL(GraphicsLayer):
 
     def begin_frame(self):
         # Create a new agent layer
-        self.create_layer(iLayer=1, clear=True)
+        self.create_layer(iLayer=PILGL.AGENT_LAYER, clear=True)
 
     def show(self, block=False):
         img = self.alpha_composite_layers()
@@ -254,11 +267,6 @@ class PILGL(GraphicsLayer):
             if clear:
                 self.clear_layer(iLayer)
 
-    def create_layers(self, clear=True):
-        self.create_layer(0, clear=clear)  # rail / background (scene)
-        self.create_layer(1, clear=clear)  # agents
-        self.create_layer(2, clear=clear)  # drawing layer for selected agent
-        self.create_layer(3, clear=clear)  # drawing layer for selected agent's target
 
 
 class PILSVG(PILGL):
@@ -485,23 +493,21 @@ class PILSVG(PILGL):
 
         return pil
 
-    def set_rail_at(self, row, col, binary_trans, target=None, is_selected=False, rail_grid=None,
-                    agent_rail_color=None, blend_factor=0.5):
+
+    def clear_set_predicion_path_layer(self):
+        self.clear_layer(PILGL.PREDICTION_PATH_LAYER,0)
+
+    def set_predicion_path_at(self, row, col, binary_trans, agent_rail_color):
+        colored_rail = self.recolor_image(self.pil_rail_org[binary_trans],
+                                          [61, 61, 61], [agent_rail_color],
+                                          False)[0]
+        # pil_track = Image.blend(pil_track,colored_rail,blend_factor)
+        # pil_track = colored_rail#Image.alpha_composite(pil_track, colored_rail)
+        self.draw_image_row_col(colored_rail, (row, col), layer=PILGL.PREDICTION_PATH_LAYER)
+
+    def set_rail_at(self, row, col, binary_trans, target=None, is_selected=False, rail_grid=None):
         if binary_trans in self.pil_rail:
             pil_track = self.pil_rail[binary_trans]
-            if agent_rail_color is not None:
-                colored_rail = self.recolor_image(self.pil_rail_org[binary_trans],
-                                                  [61, 61, 61], [agent_rail_color],
-                                                  False)[0]
-                rcTopLeft1 = (row, col)
-                rcTopLeft2 = (row + 1, col + 1)
-                rcTopLeft1 = tuple((array(rcTopLeft1) * self.nPixCell)[[1, 0]])
-                rcTopLeft2 = tuple((array(rcTopLeft2) * self.nPixCell)[[1, 0]])
-                pil_track = Image.blend(
-                    self.layers[0].crop((rcTopLeft1[0], rcTopLeft1[1], rcTopLeft2[0], rcTopLeft2[1])),
-                    colored_rail,
-                    blend_factor)
-
             if target is not None:
                 pil_track = Image.alpha_composite(pil_track, self.station_colors[target % len(self.station_colors)])
 
@@ -522,15 +528,15 @@ class PILSVG(PILGL):
                         a = a2
                     pil_track = self.scenery[a % len(self.scenery)]
 
-            self.draw_image_row_col(pil_track, (row, col))
+            self.draw_image_row_col(pil_track, (row, col), layer=PILGL.RAIL_LAYER)
         else:
             print("Illegal rail:", row, col, format(binary_trans, "#018b")[2:], binary_trans)
 
         if target is not None:
             if is_selected:
                 svgBG = self.pil_from_svg_file("svg", "Selected_Target.svg")
-                self.clear_layer(3, 0)
-                self.draw_image_row_col(svgBG, (row, col), layer=3)
+                self.clear_layer(PILGL.SELECTED_TARGET_LAYER, 0)
+                self.draw_image_row_col(svgBG, (row, col), layer=PILGL.SELECTED_TARGET_LAYER)
 
     def recolor_image(self, pil, a3BaseColor, ltColors, invert=False):
         rgbaImg = array(pil)
@@ -591,12 +597,12 @@ class PILSVG(PILGL):
         if delta_dir == 2:
             in_direction = out_direction
         pil_zug = self.pil_zug[(in_direction % 4, out_direction % 4, color_idx)]
-        self.draw_image_row_col(pil_zug, (row, col), layer=1)
+        self.draw_image_row_col(pil_zug, (row, col), layer=PILGL.AGENT_LAYER)
 
         if is_selected:
             bg_svg = self.pil_from_svg_file("svg", "Selected_Agent.svg")
-            self.clear_layer(2, 0)
-            self.draw_image_row_col(bg_svg, (row, col), layer=2)
+            self.clear_layer(PILGL.SELECTED_AGENT_LAYER, 0)
+            self.draw_image_row_col(bg_svg, (row, col), layer=PILGL.SELECTED_AGENT_LAYER)
 
     def set_cell_occupied(self, agent_idx, row, col):
         occupied_im = self.cell_occupied[agent_idx % len(self.cell_occupied)]
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 9f1b59e10aa36987e06e89d5b1e8145e55bfead1..14c49add9fe03ba90f2691b9f00de32cb0863330 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -292,19 +292,21 @@ class RenderTool(object):
 
         """
         rt = self.__class__
+        if type(self.gl) is PILSVG:
+            self.gl.clear_set_predicion_path_layer()
+
         for agent in agent_handles:
             color = self.gl.get_agent_color(agent)
             for visited_cell in prediction_dict[agent]:
                 cell_coord = array(visited_cell[:2])
-                cell_coord_trans = np.matmul(cell_coord, rt.row_col_to_xy) + rt.x_y_half
                 if type(self.gl) is PILSVG:
                     # TODO : Track highlighting (Adrian)
                     r = cell_coord[0]
                     c = cell_coord[1]
                     transitions = self.env.rail.grid[r, c]
-                    self.gl.set_rail_at(r, c, transitions, target=None, is_selected=False, rail_grid=self.env.rail.grid,
-                                        agent_rail_color=color)
+                    self.gl.set_predicion_path_at(r, c, transitions, agent_rail_color=color)
                 else:
+                    cell_coord_trans = np.matmul(cell_coord, rt.row_col_to_xy) + rt.x_y_half
                     self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100)
 
     def render_rail(self, spacing=False, rail_color="gray", curves=True, arrows=False):
diff --git a/notebooks/Simple_Rendering_Demo.ipynb b/notebooks/Simple_Rendering_Demo.ipynb
index ad76f054cd3d83a2cf1c239a6656c8d95f9a2720..2084ee46cbb7de2839ec71bce98c7a5bc46c2469 100644
--- a/notebooks/Simple_Rendering_Demo.ipynb
+++ b/notebooks/Simple_Rendering_Demo.ipynb
@@ -9,18 +9,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "The autoreload extension is already loaded. To reload it, use:\n",
-      "  %reload_ext autoreload\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "%load_ext autoreload\n",
     "%autoreload 2"
@@ -28,7 +19,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -40,7 +31,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -50,34 +41,22 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "import flatland.core.env\n",
     "import flatland.utils.rendertools as rt\n",
     "from flatland.envs.rail_env import RailEnv, random_rail_generator\n",
-    "from flatland.envs.observations import TreeObsForRailEnv"
+    "from flatland.envs.observations import TreeObsForRailEnv\n",
+    "from flatland.envs.predictions import ShortestPathPredictorForRailEnv"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/html": [
-       "<style>.container { width:90% !important; }</style>"
-      ],
-      "text/plain": [
-       "<IPython.core.display.HTML object>"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
     "from IPython.core.display import display, HTML\n",
     "display(HTML(\"<style>.container { width:90% !important; }</style>\"))"
@@ -92,7 +71,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 26,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -102,7 +81,7 @@
     "                  height=10,\n",
     "                  rail_generator=fnMethod,\n",
     "                  number_of_agents=nAgents,\n",
-    "                  obs_builder_object=TreeObsForRailEnv(max_depth=2))"
+    "                  obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()))"
    ]
   },
   {
@@ -114,71 +93,40 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 31,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "oRT = rt.RenderTool(env,gl=\"PILSVG\")\n",
-    "env.dev_pred_dict = env.dev_obs_dict\n",
     "oRT.render_env(show_observations=False,show_predictions=True)\n",
     "img = oRT.get_image()"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 32,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "6a9401eacf31417d97119674cf249bc7",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Canvas()"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
     "jpy_canvas.Canvas(img)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 29,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "oRT = rt.RenderTool(env,gl=\"PIL\")\n",
-    "oRT.render_env()\n",
+    "oRT.render_env(show_observations=False,show_predictions=True)\n",
     "img = oRT.get_image()"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 30,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "cedca754ea334df9b4678c1a29a4788e",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Canvas()"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
     "jpy_canvas.Canvas(img)"
    ]