diff --git a/examples/play_model.py b/examples/play_model.py
index 6a67397ea4ba8d8906ce62b1a8d21327c247a3e0..68530e6f694a3a43aa2f76307f83300df9ac7f6e 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -10,6 +10,82 @@ import matplotlib.pyplot as plt
 import time
 
 
+class Player(object):
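+    """Wrap a RailEnv plus a DQN agent loaded from a checkpoint so the environment can be stepped one action round at a time."""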
+    def __init__(self, env):
+        self.env = env
+        self.handle = env.get_agent_handles()
+
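+        # Input and output sizes expected by the pre-trained Q-network checkpoint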
+        self.state_size = 105
+        self.action_size = 4
+        self.n_trials = 9999
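+        # Epsilon-greedy exploration parameters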
+        self.eps = 1.
+        self.eps_end = 0.005
+        self.eps_decay = 0.998
+        self.action_dict = dict()
+        self.scores_window = deque(maxlen=100)
+        self.done_window = deque(maxlen=100)
+        self.scores = []
+        self.dones_list = []
+        self.action_prob = [0]*4
+        self.agent = Agent(self.state_size, self.action_size, "FC", 0)
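+        # Load the pre-trained weights into the local Q-network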
+        self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
+
+        self.iFrame = 0
+        self.tStart = time.time()
+
+        # Reset environment
+        self.obs = self.env.reset()
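+        # Normalise each observation by its largest finite value and clip to [-1, 1]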
+        for a in range(self.env.number_of_agents):
+            norm = max(1, max_lt(self.obs[a], np.inf))
+            self.obs[a] = np.clip(np.array(self.obs[a]) / norm, -1, 1)
+
+        # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
+
+        self.score = 0
+        self.env_done = 0
+
+    def step(self):
+        env = self.env
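+        # Choose an epsilon-greedy action for every agent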
+        for a in range(env.number_of_agents):
+            action = self.agent.act(np.array(self.obs[a]), eps=self.eps)
+            self.action_prob[action] += 1
+            self.action_dict.update({a: action})
+
+        # Environment step
+        next_obs, all_rewards, done, _ = self.env.step(self.action_dict)
+
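+        # Normalise the new observations as in the constructor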
+        for a in range(env.number_of_agents):
+            norm = max(1, max_lt(next_obs[a], np.inf))
+            next_obs[a] = np.clip(np.array(next_obs[a]) / norm, -1, 1)
+
+        # Update replay buffer and train agent
+        for a in range(self.env.number_of_agents):
+            self.agent.step(self.obs[a], self.action_dict[a], all_rewards[a], next_obs[a], done[a])
+            self.score += all_rewards[a]
+
+        self.iFrame += 1
+
+        self.obs = next_obs.copy()
+        if done['__all__']:
+            self.env_done = 1
+
+
+def max_lt(seq, val):
+    """
+    Return greatest item in seq for which item < val applies.
+    None is returned if seq was empty or all items in seq were >= val.
+    """
+    max_item = None
+    idx = len(seq) - 1
+    while idx >= 0:
+        if seq[idx] < val and seq[idx] >= 0 and (max_item is None or seq[idx] > max_item):
+            max_item = seq[idx]
+        idx -= 1
+    return max_item
+
+
 def main(render=True, delay=0.0):
 
     random.seed(1)
diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py
index 0a0bbc4b535e9921127ec1d1b37bcbeb767e1798..935e40d02cbc8011de57e181af93018623b86466 100644
--- a/flatland/utils/editor.py
+++ b/flatland/utils/editor.py
@@ -15,6 +15,19 @@ from flatland.envs.rail_env import RailEnv, random_rail_generator
 # from flatland.core.transitions import RailEnvTransitions
 from flatland.core.env_observation_builder import TreeObsForRailEnv
 import flatland.utils.rendertools as rt
+from examples.play_model import Player
+
+
+class View(object):
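+    """Render the editor's environment into a jpy_canvas widget for display in Jupyter."""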
+    def __init__(self, editor):
+        self.editor = editor
+        self.oRT = rt.RenderTool(self.editor.env)
+        plt.figure(figsize=(10, 10))
+        self.oRT.renderEnv(spacing=False, arrows=False, sRailColor="gray", show=False)
+        img = self.oRT.getImage()
+        plt.clf()
+        import jpy_canvas
+        self.wid_img = jpy_canvas.Canvas(img)
 
 
 class JupEditor(object):
@@ -39,6 +52,8 @@ class JupEditor(object):
         self.drawMode = "Draw"
         self.env_filename = "temp.npy"
         self.set_env(env)
+        self.iAgent = None
+        self.player = None
 
     def set_env(self, env):
         self.env = env
@@ -56,6 +71,28 @@ class JupEditor(object):
     def setDrawMode(self, dEvent):
         self.drawMode = dEvent["new"]
 
+    def on_click(self, wid, event):
+        x = event['canvasX']
+        y = event['canvasY']
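+        # Convert the canvas pixel coordinates into a (row, col) grid cell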
+        rcCell = ((array([y, x]) - self.yxBase) / self.nPixCell).astype(int)
+
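+        # Clicks alternate between placing an agent's origin and its target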
+        if self.drawMode == "Origin":
+            self.iAgent = len(self.env.agents_position)
+            self.env.agents_position.append(rcCell)
+            self.env.agents_handles.append(max(self.env.agents_handles + [-1]) + 1)
+            self.env.agents_direction.append(0)
+            self.env.agents_target.append(rcCell)  # set the target to the origin initially
+            self.env.number_of_agents = self.iAgent + 1
+            self.drawMode = "Destination"
+
+        elif self.drawMode == "Destination" and self.iAgent is not None:
+            self.env.agents_target[self.iAgent] = rcCell
+            self.drawMode = "Origin"
+
+        self.log("agent", self.drawMode, self.iAgent, rcCell)
+
+        self.redraw()
+
     def event_handler(self, wid, event):
         """Mouse motion event handler
         """
@@ -150,9 +187,6 @@ class JupEditor(object):
             # This updates the image in the browser to be the new edited version
             self.wid_img.data = writableData
     
-    def on_click(self, event):
-        pass
-
     def redraw(self, hide_stdout=True, update=True):
 
         if hide_stdout:
@@ -161,7 +195,8 @@ class JupEditor(object):
             stdout_dest = sys.stdout
 
         # TODO: bit of a hack - can we suppress the console messages from MPL at source?
-        with redirect_stdout(stdout_dest):
+        #with redirect_stdout(stdout_dest):
+        with self.wid_output:
             plt.figure(figsize=(10, 10))
             self.oRT.renderEnv(spacing=False, arrows=False, sRailColor="gray", show=False)
             img = self.oRT.getImage()
@@ -178,6 +213,13 @@ class JupEditor(object):
     
     def clear(self, event):
         self.env.rail.grid[:, :] = 0
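+        # Also remove any agents and the attached player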
+        self.env.number_of_agents = 0
+        self.env.agents_position = []
+        self.env.agents_direction = []
+        self.env.agents_handles = []
+        self.env.agents_target = []
+        self.player = None
+
         self.redraw_event(event)
 
     def setFilename(self, filename):
@@ -201,15 +243,22 @@ class JupEditor(object):
         self.env = RailEnv(width=self.regen_size,
               height=self.regen_size,
               rail_generator=random_rail_generator(cell_type_relative_proportion=[1, 1] + [0.5] * 6),
-              number_of_agents=0,
+              number_of_agents=self.env.number_of_agents,
               obs_builder_object=TreeObsForRailEnv(max_depth=2))
         self.env.reset()
         self.set_env(self.env)
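+        # Attach a fresh Player to the regenerated environment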
+        self.player = Player(self.env)
         self.redraw()
         
     def setRegenSize_event(self, event):
         self.regen_size = event["new"]
-    
+
+    def step_event(self, event=None):
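+        # Create the Player lazily on the first step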
+        if self.player is None:
+            self.player = Player(self.env)
+        self.player.step()
+        self.redraw()
+
     def fix_env(self):
         self.env.width = self.env.rail.width
         self.env.height = self.env.rail.height
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 1444645e8cc03ff587babc6ad66d08f8e9bd9504..6c5a17555beb95ce80d023d0c10376650f06ec3b 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -21,7 +21,6 @@ class MPLGL(GraphicsLayer):
         plt.plot(*args, **kwargs)
 
     def scatter(self, *args, **kwargs):
-        print(args, kwargs)
         plt.scatter(*args, **kwargs)
 
     def text(self, *args, **kwargs):
@@ -209,7 +208,7 @@ class RenderTool(object):
         xyDir = np.matmul(rcDir, rt.grc2xy)          # agent direction in xy
 
         xyPos = np.matmul(rcPos - rcDir / 2, rt.grc2xy) + rt.xyHalf
-        self.gl.scatter(*xyPos, color=color, s=40)            # agent location
+        self.gl.scatter(*xyPos, color=color, marker="o", s=100)            # agent location
 
         xyDirLine = array([xyPos, xyPos + xyDir/2]).T  # line for agent orient.
         self.gl.plot(*xyDirLine, color=color, lw=5, ms=0, alpha=0.6)
@@ -219,12 +218,6 @@ class RenderTool(object):
             xyTarget = np.matmul(rcTarget, rt.grc2xy) + rt.xyHalf
             self._draw_square(xyTarget, 1/3, color)
 
-        if False:
-            # mark the next cell we're heading into
-            rcNext = rcPos + rcDir
-            xyNext = np.matmul(rcNext, rt.grc2xy) + rt.xyHalf
-            self.gl.scatter(*xyNext, color=color)
-
     def plotTrans(self, rcPos, gTransRCAg, color="r", depth=None):
         """
         plot the transitions in gTransRCAg at position rcPos.