diff --git a/examples/training_navigation.py b/examples/training_navigation.py
index 7b35cb212f691287e0fff411d0a019fc8acd4059..60dc1adbf473382aad4ed0dde3ab9a13e64e7785 100644
--- a/examples/training_navigation.py
+++ b/examples/training_navigation.py
@@ -8,16 +8,7 @@ import torch,random
 random.seed(1)
 np.random.seed(1)
 
-"""
-transition_probability = [1.0,  # empty cell - Case 0
-                          3.0,  # Case 1 - straight
-                          1.0,  # Case 2 - simple switch
-                          3.0,  # Case 3 - diamond drossing
-                          2.0,  # Case 4 - single slip
-                          1.0,  # Case 5 - double slip
-                          1.0,  # Case 6 - symmetrical
-                          1.0]  # Case 7 - dead end
-"""
+
 # Example generate a rail given a manual specification,
 # a map of tuples (cell_type, rotation)
 transition_probability = [1.0,  # empty cell - Case 0
@@ -63,6 +54,7 @@ for trials in range(1, n_trials + 1):
 
     # Run episode
     for step in range(100):
+        #env_renderer.renderEnv(show=True)
 
         # Action
         for a in range(env.number_of_agents):
@@ -73,7 +65,6 @@ for trials in range(1, n_trials + 1):
         next_obs, all_rewards, done, _ = env.step(action_dict)
 
 
-
         # Update replay buffer and train agent
         for a in range(env.number_of_agents):
             agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
@@ -81,7 +72,7 @@ for trials in range(1, n_trials + 1):
 
         obs = next_obs.copy()
 
-        if all(done):
+        if done['__all__']:
             env_done = 1
             break
     # Epsioln decay
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 5365800b6cc3ae81ab1c5e7936c96b703c933d79..24f365820b06698931b46afb6266deca03d6834b 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -403,10 +403,9 @@ class RenderTool(object):
         # cell_size is a bit pointless with matplotlib - it does not relate to pixels,
         # so for now I've changed it to 1 (from 10)
         cell_size = 1
-
+        plt.clf()
         # if oFigure is None:
         #    oFigure = plt.figure()
-
         def drawTrans(oFrom, oTo, sColor="gray"):
             plt.plot(
                 [oFrom[0], oTo[0]],  # x
@@ -551,7 +550,11 @@ class RenderTool(object):
         plt.xlim([0, env.width * cell_size])
         plt.ylim([-env.height * cell_size, 0])
         if show:
-            plt.show()
+            plt.show(block=False)
+            plt.pause(0.00001)
+            return
+
+
 
     def _draw_square(self, center, size, color):
         x0 = center[0]-size/2