From 58b93bca389f1346482b3b9c5a39fd53a5454189 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Fri, 12 Jul 2019 16:00:42 -0400
Subject: [PATCH] huge refactoring of render code also added new prediction
 rendering editor still broken with this update

---
 README.rst                               |   2 +-
 docs/FAQ.rst                             |   2 +-
 docs/gettingstarted.rst                  |   2 +-
 examples/custom_railmap_example.py       |   2 +-
 examples/demo.py                         |   4 +-
 examples/play_model.py                   |   2 +-
 examples/simple_example_1.py             |   4 +-
 examples/simple_example_2.py             |   4 +-
 examples/simple_example_3.py             |   6 +-
 examples/tkplay.py                       |   4 +-
 flatland/envs/predictions.py             |   5 +
 flatland/envs/rail_env.py                |   1 +
 flatland/utils/editor.py                 | 180 +++----
 flatland/utils/graphics_layer.py         |  14 +-
 flatland/utils/graphics_pil.py           | 269 +++++-----
 flatland/utils/rendertools.py            | 657 +++++++++--------------
 notebooks/Scene_Editor.ipynb             |  48 +-
 tests/test_flatland_envs_observations.py |   8 +-
 tests/test_flatland_envs_predictions.py  |   6 +-
 tests/test_flatland_utils_rendertools.py |  28 +-
 20 files changed, 560 insertions(+), 688 deletions(-)

diff --git a/README.rst b/README.rst
index 23d96aa7..2b817fb0 100644
--- a/README.rst
+++ b/README.rst
@@ -99,7 +99,7 @@ Basic usage of the RailEnv environment used by the Flatland Challenge
         _action = my_controller()
         obs, all_rewards, done, _ = env.step(_action)
         print("Rewards: {}, [done={}]".format( all_rewards, done))
-        env_renderer.renderEnv(show=True, frames=False, show_observations=False)
+        env_renderer.render_env(show=True, frames=False, show_observations=False)
         time.sleep(0.3)
 
 and **ideally** you should see something along the lines of 
diff --git a/docs/FAQ.rst b/docs/FAQ.rst
index 5304e795..45f57727 100644
--- a/docs/FAQ.rst
+++ b/docs/FAQ.rst
@@ -43,5 +43,5 @@ Frequently Asked Questions (FAQs)
     Renders the scene into a image (screenshot)
     .. code-block:: python
 
-    renderer.gl.saveImage("filename.bmp")
+    renderer.gl.save_image("filename.bmp")
 
diff --git a/docs/gettingstarted.rst b/docs/gettingstarted.rst
index 9149d83d..22b9be7d 100644
--- a/docs/gettingstarted.rst
+++ b/docs/gettingstarted.rst
@@ -79,7 +79,7 @@ Environments can be rendered using the utils.rendertools utilities, for example:
 .. code-block:: python
 
     env_renderer = RenderTool(env)
-    env_renderer.renderEnv(show=True)
+    env_renderer.render_env(show=True)
 
 
 Finally, the environment can be run by supplying the environment step function 
diff --git a/examples/custom_railmap_example.py b/examples/custom_railmap_example.py
index a76b9c97..515d6c1b 100644
--- a/examples/custom_railmap_example.py
+++ b/examples/custom_railmap_example.py
@@ -37,6 +37,6 @@ env = RailEnv(width=6,
 env.reset()
 
 env_renderer = RenderTool(env)
-env_renderer.renderEnv(show=True)
+env_renderer.render_env(show=True)
 
 input("Press Enter to continue...")
diff --git a/examples/demo.py b/examples/demo.py
index 768f44d3..f150ef85 100644
--- a/examples/demo.py
+++ b/examples/demo.py
@@ -102,7 +102,7 @@ class Demo:
                 action_dict.update({iAgent: action})
 
             # render
-            self.renderer.renderEnv(show=True, show_observations=False)
+            self.renderer.render_env(show=True, show_observations=False)
 
             # environment step (apply the actions to all agents)
             next_obs, all_rewards, done, _ = self.env.step(action_dict)
@@ -111,7 +111,7 @@ class Demo:
                 break
 
             if self.record_frames is not None:
-                self.renderer.gl.saveImage(self.record_frames.format(step))
+                self.renderer.gl.save_image(self.record_frames.format(step))
 
         self.renderer.close_window()
 
diff --git a/examples/play_model.py b/examples/play_model.py
index c44118c2..fdd2bd8d 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -117,7 +117,7 @@ def main(render=True, delay=0.0, n_trials=3, n_steps=50):
         for step in range(n_steps):
             oPlayer.step()
             if render:
-                env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step)
+                env_renderer.render_env(show=True, frames=True, episode=trials, step=step)
                 if delay > 0:
                     time.sleep(delay)
 
diff --git a/examples/simple_example_1.py b/examples/simple_example_1.py
index 633c87b5..daf56a9b 100644
--- a/examples/simple_example_1.py
+++ b/examples/simple_example_1.py
@@ -19,7 +19,7 @@ env = RailEnv(width=6,
 env.reset()
 
 env_renderer = RenderTool(env)
-env_renderer.renderEnv(show=True)
-env_renderer.renderEnv(show=True)
+env_renderer.render_env(show=True)
+env_renderer.render_env(show=True)
 
 input("Press Enter to continue...")
diff --git a/examples/simple_example_2.py b/examples/simple_example_2.py
index 9b5f1938..be6fdef3 100644
--- a/examples/simple_example_2.py
+++ b/examples/simple_example_2.py
@@ -33,7 +33,7 @@ env = RailEnv(width=10,
 env.reset()
 
 env_renderer = RenderTool(env, gl="PIL")
-env_renderer.renderEnv(show=True)
-env_renderer.renderEnv(show=True)
+env_renderer.render_env(show=True)
+env_renderer.render_env(show=True)
 
 input("Press Enter to continue...")
diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py
index 853d5f5e..3a1c583f 100644
--- a/examples/simple_example_3.py
+++ b/examples/simple_example_3.py
@@ -22,8 +22,8 @@ for i in range(env.get_num_agents()):
     env.obs_builder.util_print_obs_subtree(tree=obs[i])
 
 env_renderer = RenderTool(env)
-env_renderer.renderEnv(show=True, frames=True)
-env_renderer.renderEnv(show=True, frames=True)
+env_renderer.render_env(show=True, frames=True)
+env_renderer.render_env(show=True, frames=True)
 
 print("Manual control: s=perform step, q=quit, [agent id] [1-2-3 action] \
        (turnleft+move, move to front, turnright+move)")
@@ -50,4 +50,4 @@ for step in range(100):
             i = i + 1
         i += 1
 
-    env_renderer.renderEnv(show=True, frames=True)
+    env_renderer.render_env(show=True, frames=True)
diff --git a/examples/tkplay.py b/examples/tkplay.py
index 00a4fc70..225e1138 100644
--- a/examples/tkplay.py
+++ b/examples/tkplay.py
@@ -26,8 +26,8 @@ def tkmain(n_trials=2, n_steps=50, sGL="PIL"):
 
         for step in range(n_steps):
             oPlayer.step()
-            env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step,
-                                   action_dict=oPlayer.action_dict)
+            env_renderer.render_env(show=True, frames=True, episode=trials, step=step,
+                                    action_dict=oPlayer.action_dict)
 
     env_renderer.close_window()
 
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index ca675ce2..1d825ff1 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -122,13 +122,16 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
             _agent_initial_direction = agent.direction
             prediction = np.zeros(shape=(self.max_depth + 1, 5))
             prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
+            visited = set()
             for index in range(1, self.max_depth + 1):
                 # if we're at the target, stop moving...
                 if agent.position == agent.target:
                     prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
+                    visited.add((agent.position[0], agent.position[1], agent.direction))
                     continue
                 if not agent.moving:
                     prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING]
+                    visited.add((agent.position[0], agent.position[1], agent.direction))
                     continue
                 # Take shortest possible path
                 cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
@@ -159,6 +162,8 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
 
                 # prediction is ready
                 prediction[index] = [index, *new_position, new_direction, 0]
+                visited.add((new_position[0], new_position[1], new_direction))
+            self.env.dev_pred_dict[agent.handle] = visited
             prediction_dict[agent.handle] = prediction
 
             # cleanup: reset initial position
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 996301a8..a38d1008 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -128,6 +128,7 @@ class RailEnv(Environment):
         self.obs_dict = {}
         self.rewards_dict = {}
         self.dev_obs_dict = {}
+        self.dev_pred_dict = {}
 
         self.agents = [None] * number_of_agents  # live agents
         self.agents_static = [None] * number_of_agents  # static agent information
diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py
index ea4056fa..00e9d85b 100644
--- a/flatland/utils/editor.py
+++ b/flatland/utils/editor.py
@@ -48,14 +48,14 @@ class View(object):
         self.sGL = sGL
 
     def display(self):
-        self.wOutput.clear_output()
+        self.output.clear_output()
         return self.wMain
 
     def init_canvas(self):
         # update the rendertool with the env
         self.new_env()
-        self.oRT.renderEnv(spacing=False, arrows=False, sRailColor="gray", show=False)
-        img = self.oRT.getImage()
+        self.oRT.render_env(spacing=False, arrows=False, sRailColor="gray", show=False)
+        img = self.oRT.get_image()
         self.wImage = jpy_canvas.Canvas(img)
         self.yxSize = self.wImage.data.shape[:2]
         self.writableData = np.copy(self.wImage.data)  # writable copy of image - wid_img.data is somehow readonly
@@ -67,51 +67,51 @@ class View(object):
 
     def init_widgets(self):
         # Debug checkbox - enable logging in the Output widget
-        self.wDebug = ipywidgets.Checkbox(description="Debug")
-        self.wDebug.observe(self.controller.setDebug, names="value")
+        self.debug = ipywidgets.Checkbox(description="Debug")
+        self.debug.observe(self.controller.set_debug, names="value")
 
         # Separate checkbox for mouse move events - they are very verbose
-        self.wDebug_move = Checkbox(description="Debug mouse move")
-        self.wDebug_move.observe(self.controller.setDebugMove, names="value")
+        self.debug_move = Checkbox(description="Debug mouse move")
+        self.debug_move.observe(self.controller.set_debug_move, names="value")
 
         # This is like a cell widget where loggin goes
-        self.wOutput = Output()
+        self.output = Output()
 
         # Filename textbox
-        self.wFilename = Text(description="Filename")
-        self.wFilename.value = self.model.env_filename
-        self.wFilename.observe(self.controller.setFilename, names="value")
+        self.filename = Text(description="Filename")
+        self.filename.value = self.model.env_filename
+        self.filename.observe(self.controller.set_filename, names="value")
 
         # Size of environment when regenerating
 
-        self.wRegenSizeWidth = IntSlider(value=10, min=5, max=100, step=5, description="Regen Size (Width)",
-                                         tip="Click Regenerate after changing this")
-        self.wRegenSizeWidth.observe(self.controller.setRegenSizeWidth, names="value")
+        self.regen_width = IntSlider(value=10, min=5, max=100, step=5, description="Regen Size (Width)",
+                                     tip="Click Regenerate after changing this")
+        self.regen_width.observe(self.controller.set_regen_width, names="value")
 
-        self.wRegenSizeHeight = IntSlider(value=10, min=5, max=100, step=5, description="Regen Size (Height)",
-                                          tip="Click Regenerate after changing this")
-        self.wRegenSizeHeight.observe(self.controller.setRegenSizeHeight, names="value")
+        self.regen_height = IntSlider(value=10, min=5, max=100, step=5, description="Regen Size (Height)",
+                                      tip="Click Regenerate after changing this")
+        self.regen_height.observe(self.controller.set_regen_height, names="value")
 
         # Number of Agents when regenerating
-        self.wRegenNAgents = IntSlider(value=1, min=0, max=5, step=1, description="# Agents",
-                                       tip="Click regenerate or reset after changing this")
-        self.wRegenMethod = RadioButtons(description="Regen\nMethod", options=["Empty", "Random Cell"])
+        self.regen_n_agents = IntSlider(value=1, min=0, max=5, step=1, description="# Agents",
+                                        tip="Click regenerate or reset after changing this")
+        self.regen_method = RadioButtons(description="Regen\nMethod", options=["Empty", "Random Cell"])
 
-        self.wReplaceAgents = Checkbox(value=True, description="Replace Agents")
+        self.replace_agents = Checkbox(value=True, description="Replace Agents")
 
         self.wTab = Tab()
         tab_contents = ["Regen", "Observation"]
         for i, title in enumerate(tab_contents):
             self.wTab.set_title(i, title)
         self.wTab.children = [
-            VBox([self.wRegenSizeWidth, self.wRegenSizeHeight, self.wRegenNAgents, self.wRegenMethod])
+            VBox([self.regen_width, self.regen_height, self.regen_n_agents, self.regen_method])
         ]
 
         # abbreviated description of buttons and the methods they call
         ldButtons = [
             dict(name="Refresh", method=self.controller.refresh, tip="Redraw only"),
             dict(name="Rotate Agent", method=self.controller.rotate_agent, tip="Rotate selected agent"),
-            dict(name="Restart Agents", method=self.controller.restartAgents,
+            dict(name="Restart Agents", method=self.controller.restart_agents,
                  tip="Move agents back to start positions"),
             dict(name="Random", method=self.controller.reset,
                  tip="Generate a randomized scene, including regen rail + agents"),
@@ -119,7 +119,7 @@ class View(object):
                  tip="Regenerate the rails using the method selected below"),
             dict(name="Load", method=self.controller.load),
             dict(name="Save", method=self.controller.save),
-            dict(name="Save as image", method=self.controller.saveImage)
+            dict(name="Save as image", method=self.controller.save_image)
         ]
 
         self.lwButtons = []
@@ -130,13 +130,13 @@ class View(object):
             self.lwButtons.append(wButton)
 
         self.wVbox_controls = VBox([
-            self.wFilename,
+            self.filename,
             *self.lwButtons,
             self.wTab])
 
         self.wMain = HBox([self.wImage, self.wVbox_controls])
 
-    def drawStroke(self):
+    def draw_stroke(self):
         pass
 
     def new_env(self):
@@ -145,7 +145,7 @@ class View(object):
         self.oRT = rt.RenderTool(self.editor.env, gl=self.sGL)
 
     def redraw(self):
-        with self.wOutput:
+        with self.output:
             self.oRT.set_new_rail()
 
             self.model.env.agents = self.model.env.agents_static
@@ -155,10 +155,10 @@ class View(object):
                 if hasattr(a, 'old_direction') is False:
                     a.old_direction = a.direction
 
-            self.oRT.renderEnv(spacing=False, arrows=False, sRailColor="gray", agents=True,
-                               show=False, iSelectedAgent=self.model.iSelectedAgent,
-                               show_observations=False)
-            img = self.oRT.getImage()
+            self.oRT.render_env(spacing=False, arrows=False, sRailColor="gray", agents=True,
+                                show=False, selected_agent=self.model.selected_agent,
+                                show_observations=False)
+            img = self.oRT.get_image()
 
             self.wImage.data = img
             self.writableData = np.copy(self.wImage.data)
@@ -167,7 +167,7 @@ class View(object):
             self.yxSize = self.wImage.data.shape[:2]
             return img
 
-    def redisplayImage(self):
+    def redisplay_image(self):
         if self.writableData is not None:
             # This updates the image in the browser to be the new edited version
             self.wImage.data = self.writableData
@@ -186,8 +186,8 @@ class View(object):
         return rcCell
 
     def log(self, *args, **kwargs):
-        if self.wOutput:
-            with self.wOutput:
+        if self.output:
+            with self.output:
                 print(*args, **kwargs)
         else:
             print(*args, **kwargs)
@@ -207,7 +207,7 @@ class Controller(object):
         self.qEvents = deque()
         self.drawMode = "Draw"
 
-    def setModel(self, model):
+    def set_model(self, model):
         self.model = model
 
     def on_click(self, wid, event):
@@ -227,26 +227,26 @@ class Controller(object):
             self.model.add_target(rcCell)
             self.lrcStroke = []
         elif bAlt and not bShift and not bCtrl:
-            self.model.clearCell(rcCell)
+            self.model.clear_cell(rcCell)
             self.lrcStroke = []
 
         self.debug("click in cell", rcCell)
         self.model.debug_cell(rcCell)
 
-        if self.model.iSelectedAgent is not None:
+        if self.model.selected_agent is not None:
             self.lrcStroke = []
 
-    def setDebug(self, dEvent):
-        self.model.setDebug(dEvent["new"])
+    def set_debug(self, dEvent):
+        self.model.set_debug(dEvent["new"])
 
-    def setDebugMove(self, dEvent):
+    def set_debug_move(self, dEvent):
         self.model.setDebug_move(dEvent["new"])
 
-    def setDrawMode(self, dEvent):
+    def set_draw_mode(self, dEvent):
         self.drawMode = dEvent["new"]
 
-    def setFilename(self, event):
-        self.model.setFilename(event["new"])
+    def set_filename(self, event):
+        self.model.set_filename(event["new"])
 
     def on_mouse_move(self, wid, event):
         """Mouse motion event handler for drawing.
@@ -285,7 +285,7 @@ class Controller(object):
         else:
             self.lrcStroke = []
 
-        if self.model.iSelectedAgent is not None:
+        if self.model.selected_agent is not None:
             self.lrcStroke = []
             while len(qEvents) > 0:
                 t, x, y = qEvents.popleft()
@@ -307,7 +307,7 @@ class Controller(object):
                     rcCell = self.view.xy_to_rc(x, y)
                     self.editor.drag_path_element(rcCell)
 
-                self.view.redisplayImage()
+                self.view.redisplay_image()
 
         else:
             self.model.mod_path(not event["shiftKey"])
@@ -320,25 +320,25 @@ class Controller(object):
         self.model.clear()
 
     def reset(self, event):
-        self.log("Reset - nAgents:", self.view.wRegenNAgents.value)
+        self.log("Reset - nAgents:", self.view.regen_n_agents.value)
         self.log("Reset - size:", self.model.regen_size_width)
         self.log("Reset - size:", self.model.regen_size_height)
-        self.model.reset(replace_agents=self.view.wReplaceAgents.value,
-                         nAgents=self.view.wRegenNAgents.value)
+        self.model.reset(replace_agents=self.view.replace_agents.value,
+                         nAgents=self.view.regen_n_agents.value)
 
     def rotate_agent(self, event):
-        self.log("Rotate Agent:", self.model.iSelectedAgent)
-        if self.model.iSelectedAgent is not None:
+        self.log("Rotate Agent:", self.model.selected_agent)
+        if self.model.selected_agent is not None:
             for iAgent, agent in enumerate(self.model.env.agents_static):
                 if agent is None:
                     continue
-                if iAgent == self.model.iSelectedAgent:
+                if iAgent == self.model.selected_agent:
                     agent.direction = (agent.direction + 1) % 4
                     agent.old_direction = agent.direction
         self.model.redraw()
 
-    def restartAgents(self, event):
-        self.log("Restart Agents - nAgents:", self.view.wRegenNAgents.value)
+    def restart_agents(self, event):
+        self.log("Restart Agents - nAgents:", self.view.regen_n_agents.value)
         if self.model.init_agents_static is not None:
             self.model.env.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in
                                             self.model.init_agents_static]
@@ -349,15 +349,15 @@ class Controller(object):
         self.refresh(event)
 
     def regenerate(self, event):
-        method = self.view.wRegenMethod.value
-        nAgents = self.view.wRegenNAgents.value
-        self.model.regenerate(method, nAgents)
+        method = self.view.regen_method.value
+        n_agents = self.view.regen_n_agents.value
+        self.model.regenerate(method, n_agents)
 
-    def setRegenSizeWidth(self, event):
-        self.model.setRegenSizeWidth(event["new"])
+    def set_regen_width(self, event):
+        self.model.set_regen_width(event["new"])
 
-    def setRegenSizeHeight(self, event):
-        self.model.setRegenSizeHeight(event["new"])
+    def set_regen_height(self, event):
+        self.model.set_regen_height(event["new"])
 
     def load(self, event):
         self.model.load()
@@ -365,8 +365,8 @@ class Controller(object):
     def save(self, event):
         self.model.save()
 
-    def saveImage(self, event):
-        self.model.saveImage()
+    def save_image(self, event):
+        self.model.save_image()
 
     def step(self, event):
         self.model.step()
@@ -392,13 +392,13 @@ class EditorModel(object):
         self.iTransLast = -1
         self.gRCTrans = array([[-1, 0], [0, 1], [1, 0], [0, -1]])  # NESW in RC
 
-        self.bDebug = False
-        self.bDebug_move = False
+        self.debug_bool = False
+        self.debug_move_bool = False
         self.wid_output = None
         self.drawMode = "Draw"
         self.env_filename = "temp.pkl"
         self.set_env(env)
-        self.iSelectedAgent = None
+        self.selected_agent = None
         self.init_agents_static = None
         self.thread = None
         self.saveImageCnt = 0
@@ -409,15 +409,15 @@ class EditorModel(object):
         """
         self.env = env
 
-    def setDebug(self, bDebug):
-        self.bDebug = bDebug
-        self.log("Set Debug:", self.bDebug)
+    def set_debug(self, bDebug):
+        self.debug_bool = bDebug
+        self.log("Set Debug:", self.debug_bool)
 
-    def setDebugMove(self, bDebug):
-        self.bDebug_move = bDebug
-        self.log("Set DebugMove:", self.bDebug_move)
+    def set_debug_move(self, bDebug):
+        self.debug_move_bool = bDebug
+        self.log("Set DebugMove:", self.debug_move_bool)
 
-    def setDrawMode(self, sDrawMode):
+    def set_draw_mode(self, sDrawMode):
         self.drawMode = sDrawMode
 
     def interpolate_path(self, rcLast, rcCell):
@@ -605,7 +605,7 @@ class EditorModel(object):
 
         self.redraw()
 
-    def clearCell(self, rcCell):
+    def clear_cell(self, rcCell):
         self.debug_cell(rcCell)
         self.env.rail.grid[rcCell[0], rcCell[1]] = 0
         self.redraw()
@@ -614,11 +614,11 @@ class EditorModel(object):
         self.regenerate("complex", nAgents=nAgents)
         self.redraw()
 
-    def restartAgents(self):
+    def restart_agents(self):
         self.env.agents = EnvAgent.list_from_static(self.env.agents_static)
         self.redraw()
 
-    def setFilename(self, filename):
+    def set_filename(self, filename):
         self.env_filename = filename
 
     def load(self):
@@ -650,8 +650,8 @@ class EditorModel(object):
         # reset agents current (current position)
         self.env.agents = temp_store
 
-    def saveImage(self):
-        self.view.oRT.gl.saveImage('frame_{:04d}.bmp'.format(self.saveImageCnt))
+    def save_image(self):
+        self.view.oRT.gl.save_image('frame_{:04d}.bmp'.format(self.saveImageCnt))
         self.saveImageCnt += 1
         self.view.redraw()
 
@@ -681,10 +681,10 @@ class EditorModel(object):
         self.view.new_env()
         self.redraw()
 
-    def setRegenSizeWidth(self, size):
+    def set_regen_width(self, size):
         self.regen_size_width = size
 
-    def setRegenSizeHeight(self, size):
+    def set_regen_height(self, size):
         self.regen_size_height = size
 
     def find_agent_at(self, rcCell):
@@ -702,37 +702,37 @@ class EditorModel(object):
         """
 
         # Has the user clicked on an existing agent?
-        iAgent = self.find_agent_at(rcCell)
+        agent_idx = self.find_agent_at(rcCell)
 
-        if iAgent is None:
+        if agent_idx is None:
             # No
-            if self.iSelectedAgent is None:
+            if self.selected_agent is None:
                 # Create a new agent and select it.
                 agent_static = EnvAgentStatic(position=rcCell, direction=0, target=rcCell, moving=False)
-                self.iSelectedAgent = self.env.add_agent_static(agent_static)
+                self.selected_agent = self.env.add_agent_static(agent_static)
                 self.view.oRT.update_background()
             else:
                 # Move the selected agent to this cell
-                agent_static = self.env.agents_static[self.iSelectedAgent]
+                agent_static = self.env.agents_static[self.selected_agent]
                 agent_static.position = rcCell
                 agent_static.old_position = rcCell
                 self.env.agents = []
         else:
             # Yes
             # Have they clicked on the agent already selected?
-            if self.iSelectedAgent is not None and iAgent == self.iSelectedAgent:
+            if self.selected_agent is not None and agent_idx == self.selected_agent:
                 # Yes - deselect the agent
-                self.iSelectedAgent = None
+                self.selected_agent = None
             else:
                 # No - select the agent
-                self.iSelectedAgent = iAgent
+                self.selected_agent = agent_idx
 
         self.init_agents_static = None
         self.redraw()
 
     def add_target(self, rcCell):
-        if self.iSelectedAgent is not None:
-            self.env.agents_static[self.iSelectedAgent].target = rcCell
+        if self.selected_agent is not None:
+            self.env.agents_static[self.selected_agent].target = rcCell
             self.init_agents_static = None
             self.view.oRT.update_background()
             self.redraw()
@@ -748,7 +748,7 @@ class EditorModel(object):
             self.view.log(*args, **kwargs)
 
     def debug(self, *args, **kwargs):
-        if self.bDebug:
+        if self.debug_bool:
             self.log(*args, **kwargs)
 
     def debug_cell(self, rcCell):
diff --git a/flatland/utils/graphics_layer.py b/flatland/utils/graphics_layer.py
index 132c6899..de2e8618 100644
--- a/flatland/utils/graphics_layer.py
+++ b/flatland/utils/graphics_layer.py
@@ -33,19 +33,19 @@ class GraphicsLayer(object):
     def clf(self):
         pass
 
-    def beginFrame(self):
+    def begin_frame(self):
         pass
 
     def endFrame(self):
         pass
 
-    def getImage(self):
+    def get_image(self):
         pass
 
-    def saveImage(self, filename):
+    def save_image(self, filename):
         pass
 
-    def adaptColor(self, color, lighten=False):
+    def adapt_color(self, color, lighten=False):
         if type(color) is str:
             if color == "red" or color == "r":
                 color = (255, 0, 0)
@@ -68,7 +68,7 @@ class GraphicsLayer(object):
     def get_cmap(self, *args, **kwargs):
         return plt.get_cmap(*args, **kwargs)
 
-    def setRailAt(self, row, col, binTrans, iTarget=None, isSelected=False, rail_grid=None):
+    def set_rail_at(self, row, col, binTrans, iTarget=None, isSelected=False, rail_grid=None):
         """ Set the rail at cell (row, col) to have transitions binTrans.
             The target argument can contain the index of the agent to indicate
             that agent's target is at that cell, so that a station can be
@@ -76,10 +76,10 @@ class GraphicsLayer(object):
         """
         pass
 
-    def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut, isSelected=False):
+    def set_agent_at(self, iAgent, row, col, iDirIn, iDirOut, isSelected=False):
         pass
 
-    def setCellOccupied(self, iAgent, row, col):
+    def set_cell_occupied(self, iAgent, row, col):
         pass
 
     def resize(self, env):
diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py
index 40ecab96..21b272e1 100644
--- a/flatland/utils/graphics_pil.py
+++ b/flatland/utils/graphics_pil.py
@@ -40,7 +40,7 @@ class PILGL(GraphicsLayer):
     def __init__(self, width, height, jupyter=False):
         self.yxBase = (0, 0)
         self.linewidth = 4
-        self.nAgentColors = 1  # overridden in loadAgent
+        self.n_agent_colors = 1  # overridden in loadAgent
 
         self.width = width
         self.height = height
@@ -84,8 +84,8 @@ class PILGL(GraphicsLayer):
         sColors = "d50000#c51162#aa00ff#6200ea#304ffe#2962ff#0091ea#00b8d4#00bfa5#00c853" + \
                   "#64dd17#aeea00#ffd600#ffab00#ff6d00#ff3d00#5d4037#455a64"
 
-        self.ltAgentColors = [self.rgb_s2i(sColor) for sColor in sColors.split("#")]
-        self.nAgentColors = len(self.ltAgentColors)
+        self.agent_colors = [self.rgb_s2i(sColor) for sColor in sColors.split("#")]
+        self.n_agent_colors = len(self.agent_colors)
 
         self.window_open = False
         self.firstFrame = True
@@ -126,11 +126,11 @@ class PILGL(GraphicsLayer):
         """ convert a hex RGB string like 0091ea to 3-tuple of ints """
         return tuple(int(sRGB[iRGB * 2:iRGB * 2 + 2], 16) for iRGB in [0, 1, 2])
 
-    def getAgentColor(self, iAgent):
-        return self.ltAgentColors[iAgent % self.nAgentColors]
+    def get_agent_color(self, iAgent):
+        return self.agent_colors[iAgent % self.n_agent_colors]
 
     def plot(self, gX, gY, color=None, linewidth=3, layer=0, opacity=255, **kwargs):
-        color = self.adaptColor(color)
+        color = self.adapt_color(color)
         if len(color) == 3:
             color += (opacity,)
         elif len(color) == 4:
@@ -140,13 +140,13 @@ class PILGL(GraphicsLayer):
         self.draws[layer].line(gPoints, fill=color, width=self.linewidth)
 
     def scatter(self, gX, gY, color=None, marker="o", s=50, layer=0, opacity=255, *args, **kwargs):
-        color = self.adaptColor(color)
+        color = self.adapt_color(color)
         r = np.sqrt(s)
         gPoints = np.stack([np.atleast_1d(gX), -np.atleast_1d(gY)]).T * self.nPixCell
         for x, y in gPoints:
             self.draws[layer].rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color)
 
-    def drawImageXY(self, pil_img, xyPixLeftTop, layer=0):
+    def draw_image_xy(self, pil_img, xyPixLeftTop, layer=0):
         if (pil_img.mode == "RGBA"):
             pil_mask = pil_img
         else:
@@ -154,9 +154,9 @@ class PILGL(GraphicsLayer):
 
         self.layers[layer].paste(pil_img, xyPixLeftTop, pil_mask)
 
-    def drawImageRC(self, pil_img, rcTopLeft, layer=0):
+    def draw_image_row_col(self, pil_img, rcTopLeft, layer=0):
         xyPixLeftTop = tuple((array(rcTopLeft) * self.nPixCell)[[1, 0]])
-        self.drawImageXY(pil_img, xyPixLeftTop, layer=layer)
+        self.draw_image_xy(pil_img, xyPixLeftTop, layer=layer)
 
     def open_window(self):
         assert self.window_open is False, "Window is already open!"
@@ -178,7 +178,7 @@ class PILGL(GraphicsLayer):
     def prettify2(self, width, height, cell_size):
         pass
 
-    def beginFrame(self):
+    def begin_frame(self):
         # Create a new agent layer
         self.create_layer(iLayer=1, clear=True)
 
@@ -211,14 +211,14 @@ class PILGL(GraphicsLayer):
             img = Image.alpha_composite(img, img2)
         return img
 
-    def getImage(self):
+    def get_image(self):
         """ return a blended / alpha composited image composed of all the layers,
             with layer 0 at the "back".
         """
         img = self.alpha_composite_layers()
         return array(img)
 
-    def saveImage(self, filename):
+    def save_image(self, filename):
         """
         Renders the current scene into a image file
         :param filename: filename where to store the rendering output (supported image format *.bmp , .. , *.png)
@@ -268,15 +268,15 @@ class PILSVG(PILGL):
         self.lwAgents = []
         self.agents_prev = []
 
-        self.loadBuildingSVGs()
-        self.loadScenerySVGs()
-        self.loadRailSVGs()
-        self.loadAgentSVGs()
+        self.load_buildings()
+        self.load_scenery()
+        self.load_rail()
+        self.load_agent()
 
     def is_raster(self):
         return False
 
-    def processEvents(self):
+    def process_events(self):
         time.sleep(0.001)
 
     def clear_rails(self):
@@ -289,7 +289,7 @@ class PILSVG(PILGL):
         self.lwAgents = []
         self.agents_prev = []
 
-    def pilFromSvgFile(self, package, resource):
+    def pil_from_svg_file(self, package, resource):
         bytestring = resource_bytes(package, resource)
         bytesPNG = svg2png(bytestring=bytestring, output_height=self.nPixCell, output_width=self.nPixCell)
         with io.BytesIO(bytesPNG) as fIn:
@@ -298,13 +298,13 @@ class PILSVG(PILGL):
 
         return pil_img
 
-    def pilFromSvgBytes(self, bytesSVG):
+    def pil_from_svg_bytes(self, bytesSVG):
         bytesPNG = svg2png(bytesSVG, output_height=self.nPixCell, output_width=self.nPixCell)
         with io.BytesIO(bytesPNG) as fIn:
             pil_img = Image.open(fIn)
             return pil_img
 
-    def loadBuildingSVGs(self):
+    def load_buildings(self):
         dBuildingFiles = [
             "Buildings/Bank.svg",
             "Buildings/Bar.svg",
@@ -327,16 +327,16 @@ class PILSVG(PILGL):
             "Buildings/Fabrik_I.svg",
         ]
 
-        imgBg = self.pilFromSvgFile('svg', "Background_city.svg")
+        imgBg = self.pil_from_svg_file('svg', "Background_city.svg")
 
         self.dBuildings = []
         for sFile in dBuildingFiles:
-            img = self.pilFromSvgFile('svg', sFile)
+            img = self.pil_from_svg_file('svg', sFile)
             img = Image.alpha_composite(imgBg, img)
             self.dBuildings.append(img)
 
-    def loadScenerySVGs(self):
-        dSceneryFiles = [
+    def load_scenery(self):
+        scenery_files = [
             "Scenery/Laubbaume_A.svg",
             "Scenery/Laubbaume_B.svg",
             "Scenery/Laubbaume_C.svg",
@@ -345,41 +345,41 @@ class PILSVG(PILGL):
             "Scenery/Bergwelt_B.svg"
         ]
 
-        dSceneryFilesDim2 = [
+        scenery_files_d2 = [
             "Scenery/Bergwelt_C_Teil_1_links.svg",
             "Scenery/Bergwelt_C_Teil_2_rechts.svg"
         ]
 
-        dSceneryFilesDim3 = [
+        scenery_files_d3 = [
             "Scenery/Bergwelt_A_Teil_3_rechts.svg",
             "Scenery/Bergwelt_A_Teil_2_mitte.svg",
             "Scenery/Bergwelt_A_Teil_1_links.svg"
         ]
 
-        imgBg = self.pilFromSvgFile('svg', "Background_Light_green.svg")
+        img_back_ground = self.pil_from_svg_file('svg', "Background_Light_green.svg")
 
-        self.dScenery = []
-        for sFile in dSceneryFiles:
-            img = self.pilFromSvgFile('svg', sFile)
-            img = Image.alpha_composite(imgBg, img)
-            self.dScenery.append(img)
+        self.scenery = []
+        for file in scenery_files:
+            img = self.pil_from_svg_file('svg', file)
+            img = Image.alpha_composite(img_back_ground, img)
+            self.scenery.append(img)
 
-        self.dSceneryDim2 = []
-        for sFile in dSceneryFilesDim2:
-            img = self.pilFromSvgFile('svg', sFile)
-            img = Image.alpha_composite(imgBg, img)
-            self.dSceneryDim2.append(img)
+        self.scenery_d2 = []
+        for file in scenery_files_d2:
+            img = self.pil_from_svg_file('svg', file)
+            img = Image.alpha_composite(img_back_ground, img)
+            self.scenery_d2.append(img)
 
-        self.dSceneryDim3 = []
-        for sFile in dSceneryFilesDim3:
-            img = self.pilFromSvgFile('svg', sFile)
-            img = Image.alpha_composite(imgBg, img)
-            self.dSceneryDim3.append(img)
+        self.scenery_d3 = []
+        for file in scenery_files_d3:
+            img = self.pil_from_svg_file('svg', file)
+            img = Image.alpha_composite(img_back_ground, img)
+            self.scenery_d3.append(img)
 
-    def loadRailSVGs(self):
+    def load_rail(self):
         """ Load the rail SVG images, apply rotations, and store as PIL images.
         """
-        dRailFiles = {
+        rail_files = {
             "": "Background_Light_green.svg",
             "WE": "Gleis_Deadend.svg",
             "WW EE NN SS": "Gleis_Diamond_Crossing.svg",
@@ -404,7 +404,7 @@ class PILSVG(PILGL):
             "NE EN SW WS": "Gleis_Kurve_oben_links_unten_rechts.svg"
         }
 
-        dTargetFiles = {
+        target_files = {
             "EW": "Bahnhof_#d50000_Deadend_links.svg",
             "NS": "Bahnhof_#d50000_Deadend_oben.svg",
             "WE": "Bahnhof_#d50000_Deadend_rechts.svg",
@@ -413,114 +413,113 @@ class PILSVG(PILGL):
             "NN SS": "Bahnhof_#d50000_Gleis_vertikal.svg"}
 
         # Dict of rail cell images indexed by binary transitions
-        dPilRailFiles = self.loadSVGs(dRailFiles, rotate=True, backgroundImage="Background_rail.svg",
-                                      whitefilter="Background_white_filter.svg")
+        pil_rail_files = self.load_svgs(rail_files, rotate=True, background_image="Background_rail.svg",
+                                        whitefilter="Background_white_filter.svg")
 
         # Load the target files (which have rails and transitions of their own)
         # They are indexed by (binTrans, iAgent), ie a tuple of the binary transition and the agent index
-        dPilTargetFiles = self.loadSVGs(dTargetFiles, rotate=False, agent_colors=self.ltAgentColors,
-                                        backgroundImage="Background_rail.svg",
-                                        whitefilter="Background_white_filter.svg")
+        pil_target_files = self.load_svgs(target_files, rotate=False, agent_colors=self.agent_colors,
+                                          background_image="Background_rail.svg",
+                                          whitefilter="Background_white_filter.svg")
 
         # Load station and recolorize them
-        station = self.pilFromSvgFile("svg", "Bahnhof_#d50000_target.svg")
-        self.ltStationColors = self.recolorImage(station, [0, 0, 0], self.ltAgentColors, False)
+        station = self.pil_from_svg_file("svg", "Bahnhof_#d50000_target.svg")
+        self.station_colors = self.recolor_image(station, [0, 0, 0], self.agent_colors, False)
 
-        cellOccupied = self.pilFromSvgFile("svg", "Cell_occupied.svg")
-        self.ltCellOccupied = self.recolorImage(cellOccupied, [0, 0, 0], self.ltAgentColors, False)
+        cell_occupied = self.pil_from_svg_file("svg", "Cell_occupied.svg")
+        self.cell_occupied = self.recolor_image(cell_occupied, [0, 0, 0], self.agent_colors, False)
 
         # Merge them with the regular rails.
         # https://stackoverflow.com/questions/38987/how-to-merge-two-dictionaries-in-a-single-expression
-        self.dPilRail = {**dPilRailFiles, **dPilTargetFiles}
+        self.pil_rail = {**pil_rail_files, **pil_target_files}
 
-    def loadSVGs(self, dDirFile, rotate=False, agent_colors=False, backgroundImage=None, whitefilter=None):
-        dPil = {}
+    def load_svgs(self, file_directory, rotate=False, agent_colors=False, background_image=None, whitefilter=None):
+        pil = {}
 
         transitions = RailEnvTransitions()
 
-        lDirs = list("NESW")
+        directions = list("NESW")
 
-        for sTrans, sFile in dDirFile.items():
+        for transition, file in file_directory.items():
 
             # Translate the ascii transition description in the format  "NE WS" to the 
             # binary list of transitions as per RailEnv - NESW (in) x NESW (out)
-            lTrans16 = ["0"] * 16
-            for sTran in sTrans.split(" "):
+            transition_16_bit = ["0"] * 16
+            for sTran in transition.split(" "):
                 if len(sTran) == 2:
-                    iDirIn = lDirs.index(sTran[0])
-                    iDirOut = lDirs.index(sTran[1])
-                    iTrans = 4 * iDirIn + iDirOut
-                    lTrans16[iTrans] = "1"
-            sTrans16 = "".join(lTrans16)
-            binTrans = int(sTrans16, 2)
+                    in_direction = directions.index(sTran[0])
+                    out_direction = directions.index(sTran[1])
+                    transition_idx = 4 * in_direction + out_direction
+                    transition_16_bit[transition_idx] = "1"
+            transition_16_bit_string = "".join(transition_16_bit)
+            binary_trans = int(transition_16_bit_string, 2)
 
-            pilRail = self.pilFromSvgFile('svg', sFile)
+            pil_rail = self.pil_from_svg_file('svg', file)
 
-            if backgroundImage is not None:
-                imgBg = self.pilFromSvgFile('svg', backgroundImage)
-                pilRail = Image.alpha_composite(imgBg, pilRail)
+            if background_image is not None:
+                img_bg = self.pil_from_svg_file('svg', background_image)
+                pil_rail = Image.alpha_composite(img_bg, pil_rail)
 
             if whitefilter is not None:
-                imgBg = self.pilFromSvgFile('svg', whitefilter)
-                pilRail = Image.alpha_composite(pilRail, imgBg)
+                img_bg = self.pil_from_svg_file('svg', whitefilter)
+                pil_rail = Image.alpha_composite(pil_rail, img_bg)
 
             if rotate:
                 # For rotations, we also store the base image
-                dPil[binTrans] = pilRail
+                pil[binary_trans] = pil_rail
                 # Rotate both the transition binary and the image and save in the dict
                 for nRot in [90, 180, 270]:
-                    binTrans2 = transitions.rotate_transition(binTrans, nRot)
+                    binary_trans_2 = transitions.rotate_transition(binary_trans, nRot)
 
                     # PIL rotates anticlockwise for positive theta
-                    pilRail2 = pilRail.rotate(-nRot)
-                    dPil[binTrans2] = pilRail2
+                    pil_rail_2 = pil_rail.rotate(-nRot)
+                    pil[binary_trans_2] = pil_rail_2
 
             if agent_colors:
                 # For recoloring, we don't store the base image.
-                a3BaseColor = self.rgb_s2i("d50000")
-                lPils = self.recolorImage(pilRail, a3BaseColor, self.ltAgentColors)
-                for iColor, pilRail2 in enumerate(lPils):
-                    dPil[(binTrans, iColor)] = lPils[iColor]
+                base_color = self.rgb_s2i("d50000")
+                pils = self.recolor_image(pil_rail, base_color, self.agent_colors)
+                for color_idx, pil_rail_2 in enumerate(pils):
+                    pil[(binary_trans, color_idx)] = pils[color_idx]
 
-        return dPil
+        return pil
 
-    def setRailAt(self, row, col, binTrans, iTarget=None, isSelected=False, rail_grid=None):
-        if binTrans in self.dPilRail:
-            pilTrack = self.dPilRail[binTrans]
-            if iTarget is not None:
-                pilTrack = Image.alpha_composite(pilTrack, self.ltStationColors[iTarget % len(self.ltStationColors)])
+    def set_rail_at(self, row, col, binary_trans, target=None, is_selected=False, rail_grid=None):
+        if binary_trans in self.pil_rail:
+            pil_track = self.pil_rail[binary_trans]
+            if target is not None:
+                pil_track = Image.alpha_composite(pil_track, self.station_colors[target % len(self.station_colors)])
 
-            if binTrans == 0:
+            if binary_trans == 0:
                 if self.background_grid[col][row] <= 4:
                     a = int(self.background_grid[col][row])
                     a = a % len(self.dBuildings)
                     if (col + row + col * row) % 13 > 11:
-                        pilTrack = self.dScenery[a % len(self.dScenery)]
+                        pil_track = self.scenery[a % len(self.scenery)]
                     else:
                         if (col + row + col * row) % 3 == 0:
                             a = (a + (col + row + col * row)) % len(self.dBuildings)
-                        pilTrack = self.dBuildings[a]
+                        pil_track = self.dBuildings[a]
                 elif (self.background_grid[col][row] > 4) or ((col ** 3 + row ** 2 + col * row) % 10 == 0):
                     a = int(self.background_grid[col][row]) - 4
                     a2 = (a + (col + row + col * row + col ** 3 + row ** 4))
                     if a2 % 17 > 11:
                         a = a2
-                    pilTrack = self.dScenery[a % len(self.dScenery)]
+                    pil_track = self.scenery[a % len(self.scenery)]
 
-            self.drawImageRC(pilTrack, (row, col))
+            self.draw_image_row_col(pil_track, (row, col))
         else:
-            print("Illegal rail:", row, col, format(binTrans, "#018b")[2:], binTrans)
+            print("Illegal rail:", row, col, format(binary_trans, "#018b")[2:], binary_trans)
 
-        if iTarget is not None:
-            if isSelected:
-                svgBG = self.pilFromSvgFile("svg", "Selected_Target.svg")
+        if target is not None:
+            if is_selected:
+                svgBG = self.pil_from_svg_file("svg", "Selected_Target.svg")
                 self.clear_layer(3, 0)
-                self.drawImageRC(svgBG, (row, col), layer=3)
+                self.draw_image_row_col(svgBG, (row, col), layer=3)
 
-    def recolorImage(self, pil, a3BaseColor, ltColors, invert=False):
+    def recolor_image(self, pil, a3BaseColor, ltColors, invert=False):
         rgbaImg = array(pil)
-        lPils = []
-
+        pils = []
         for iColor, tnColor in enumerate(ltColors):
             # find the pixels which match the base paint color
             if invert:
@@ -532,67 +531,67 @@ class PILSVG(PILGL):
             # Repaint the base color with the new color
             rgbaImg2[xy_color_mask, 0:3] = tnColor
             pil2 = Image.fromarray(rgbaImg2)
-            lPils.append(pil2)
-        return lPils
+            pils.append(pil2)
+        return pils
 
-    def loadAgentSVGs(self):
+    def load_agent(self):
 
         # Seed initial train/zug files indexed by tuple(iDirIn, iDirOut):
-        dDirsFile = {
+        file_directory = {
             (0, 0): "Zug_Gleis_#0091ea.svg",
             (1, 2): "Zug_1_Weiche_#0091ea.svg",
             (0, 3): "Zug_2_Weiche_#0091ea.svg"
         }
 
         # "paint" color of the train images we load - this is the color we will change.
-        # a3BaseColor = self.rgb_s2i("0091ea") \#  noqa: E800
+        # base_color = self.rgb_s2i("0091ea") \#  noqa: E800
         # temporary workaround for trains / agents renamed with different colour:
-        a3BaseColor = self.rgb_s2i("d50000")
+        base_color = self.rgb_s2i("d50000")
 
-        self.dPilZug = {}
+        self.pil_zug = {}
 
-        for tDirs, sPathSvg in dDirsFile.items():
-            iDirIn, iDirOut = tDirs
+        for directions, path_svg in file_directory.items():
+            in_direction, out_direction = directions
 
-            pilZug = self.pilFromSvgFile("svg", sPathSvg)
+            pil_zug = self.pil_from_svg_file("svg", path_svg)
 
             # Rotate both the directions and the image and save in the dict
-            for iDirRot in range(4):
-                nDegRot = iDirRot * 90
-                iDirIn2 = (iDirIn + iDirRot) % 4
-                iDirOut2 = (iDirOut + iDirRot) % 4
+            for rot_direction in range(4):
+                rotation_degree = rot_direction * 90
+                in_direction_2 = (in_direction + rot_direction) % 4
+                out_direction_2 = (out_direction + rot_direction) % 4
 
                 # PIL rotates anticlockwise for positive theta
-                pilZug2 = pilZug.rotate(-nDegRot)
+                pil_zug_2 = pil_zug.rotate(-rotation_degree)
 
                 # Save colored versions of each rotation / variant
-                lPils = self.recolorImage(pilZug2, a3BaseColor, self.ltAgentColors)
-                for iColor, pilZug3 in enumerate(lPils):
-                    self.dPilZug[(iDirIn2, iDirOut2, iColor)] = lPils[iColor]
-
-    def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut, isSelected):
-        delta_dir = (iDirOut - iDirIn) % 4
-        iColor = iAgent % self.nAgentColors
-        # when flipping direction at a dead end, use the "iDirOut" direction.
+                pils = self.recolor_image(pil_zug_2, base_color, self.agent_colors)
+                for color_idx, pil_zug_3 in enumerate(pils):
+                    self.pil_zug[(in_direction_2, out_direction_2, color_idx)] = pils[color_idx]
+
+    def set_agent_at(self, agent_idx, row, col, in_direction, out_direction, is_selected):
+        delta_dir = (out_direction - in_direction) % 4
+        color_idx = agent_idx % self.n_agent_colors
+        # when flipping direction at a dead end, use the "out_direction" direction.
         if delta_dir == 2:
-            iDirIn = iDirOut
-        pilZug = self.dPilZug[(iDirIn % 4, iDirOut % 4, iColor)]
-        self.drawImageRC(pilZug, (row, col), layer=1)
+            in_direction = out_direction
+        pil_zug = self.pil_zug[(in_direction % 4, out_direction % 4, color_idx)]
+        self.draw_image_row_col(pil_zug, (row, col), layer=1)
 
-        if isSelected:
-            svgBG = self.pilFromSvgFile("svg", "Selected_Agent.svg")
+        if is_selected:
+            bg_svg = self.pil_from_svg_file("svg", "Selected_Agent.svg")
             self.clear_layer(2, 0)
-            self.drawImageRC(svgBG, (row, col), layer=2)
+            self.draw_image_row_col(bg_svg, (row, col), layer=2)
 
-    def setCellOccupied(self, iAgent, row, col):
-        occIm = self.ltCellOccupied[iAgent % len(self.ltCellOccupied)]
-        self.drawImageRC(occIm, (row, col), 1)
+    def set_cell_occupied(self, agent_idx, row, col):
+        occupied_im = self.cell_occupied[agent_idx % len(self.cell_occupied)]
+        self.draw_image_row_col(occupied_im, (row, col), 1)
 
 
 def main2():
     gl = PILSVG(10, 10)
     for i in range(10):
-        gl.beginFrame()
+        gl.begin_frame()
         gl.plot([3 + i, 4], [-4 - i, -5], color="r")
         gl.endFrame()
         time.sleep(1)
@@ -602,7 +601,7 @@ def main():
     gl = PILSVG(width=10, height=10)
 
     for i in range(1000):
-        gl.processEvents()
+        gl.process_events()
         time.sleep(0.1)
     time.sleep(1)
 
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 6fa68f02..a15f7d9f 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -25,26 +25,26 @@ class RenderTool(object):
         The lower / rail layer 0 is only redrawn after set_new_rail() has been called.
         Created with a "GraphicsLayer" or gl - now either PIL or PILSVG
     """
-    Visit = recordtype("Visit", ["rc", "iDir", "iDepth", "prev"])
+    visit = recordtype("visit", ["rc", "iDir", "iDepth", "prev"])
 
-    lColors = list("brgcmyk")
+    color_list = list("brgcmyk")
     # \delta RC for NESW
-    gTransRC = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
-    nPixCell = 1  # misnomer...
-    nPixHalf = nPixCell / 2
-    xyHalf = array([nPixHalf, -nPixHalf])
-    grc2xy = array([[0, -nPixCell], [nPixCell, 0]])
-    gGrid = array(np.meshgrid(np.arange(10), -np.arange(10))) * array([[[nPixCell]], [[nPixCell]]])
-    gTheta = np.linspace(0, np.pi / 2, 5)
-    gArc = array([np.cos(gTheta), np.sin(gTheta)]).T  # from [1,0] to [0,1]
-
-    def __init__(self, env, gl="PILSVG", jupyter=False, agentRenderVariant=AgentRenderVariant.ONE_STEP_BEHIND):
+    transitions_row_col = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
+    pix_per_cell = 1  # misnomer...
+    half_pix_per_cell = pix_per_cell / 2
+    x_y_half = array([half_pix_per_cell, -half_pix_per_cell])
+    row_col_to_xy = array([[0, -pix_per_cell], [pix_per_cell, 0]])
+    grid = array(np.meshgrid(np.arange(10), -np.arange(10))) * array([[[pix_per_cell]], [[pix_per_cell]]])
+    theta = np.linspace(0, np.pi / 2, 5)
+    arc = array([np.cos(theta), np.sin(theta)]).T  # from [1,0] to [0,1]
+
+    def __init__(self, env, gl="PILSVG", jupyter=False, agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND):
         self.env = env
-        self.iFrame = 0
-        self.time1 = time.time()
-        self.lTimes = deque()
+        self.frame_nr = 0
+        self.start_time = time.time()
+        self.times_list = deque()
 
-        self.agentRenderVariant = agentRenderVariant
+        self.agent_render_variant = agent_render_variant
 
         if gl == "PIL":
             self.gl = PILGL(env.width, env.height, jupyter)
@@ -59,12 +59,12 @@ class RenderTool(object):
 
     def update_background(self):
         # create background map
-        dTargets = {}
-        for iAgent, agent in enumerate(self.env.agents_static):
+        targets = {}
+        for agent_idx, agent in enumerate(self.env.agents_static):
             if agent is None:
                 continue
-            dTargets[tuple(agent.target)] = iAgent
-        self.gl.build_background_map(dTargets)
+            targets[tuple(agent.target)] = agent_idx
+        self.gl.build_background_map(targets)
 
     def resize(self):
         self.gl.resize(self.env)
@@ -75,47 +75,31 @@ class RenderTool(object):
         """
         self.new_rail = True
 
-    def plotTreeOnRail(self, lVisits, color="r"):
-        """
-        DEFUNCT
-        Derives and plots a tree of transitions starting at position rcPos
-        in direction iDir.
-        Returns a list of Visits which are the nodes / vertices in the tree.
-        """
-        rt = self.__class__
+    def plot_agents(self, targets=True, selected_agent=None):
+        color_map = self.gl.get_cmap('hsv',
+                                     lut=max(len(self.env.agents), len(self.env.agents_static) + 1))
 
-        for visit in lVisits:
-            # transition for next cell
-            tbTrans = self.env.rail.get_transitions(*visit.rc, visit.iDir)
-            giTrans = np.where(tbTrans)[0]  # RC list of transitions
-            gTransRCAg = rt.gTransRC[giTrans]
-            self.plotTrans(visit.rc, gTransRCAg, depth=str(visit.iDepth), color=color)
-
-    def plotAgents(self, targets=True, iSelectedAgent=None):
-        cmap = self.gl.get_cmap('hsv',
-                                lut=max(len(self.env.agents), len(self.env.agents_static) + 1))
-
-        for iAgent, agent in enumerate(self.env.agents_static):
+        for agent_idx, agent in enumerate(self.env.agents_static):
             if agent is None:
                 continue
-            oColor = cmap(iAgent)
-            self.plotAgent(agent.position, agent.direction, oColor, target=agent.target if targets else None,
-                           static=True, selected=iAgent == iSelectedAgent)
+            color = color_map(agent_idx)
+            self.plot_single_agent(agent.position, agent.direction, color, target=agent.target if targets else None,
+                                   static=True, selected=agent_idx == selected_agent)
 
-        for iAgent, agent in enumerate(self.env.agents):
+        for agent_idx, agent in enumerate(self.env.agents):
             if agent is None:
                 continue
-            oColor = cmap(iAgent)
-            self.plotAgent(agent.position, agent.direction, oColor, target=agent.target if targets else None)
+            color = color_map(agent_idx)
+            self.plot_single_agent(agent.position, agent.direction, color, target=agent.target if targets else None)
 
-    def getTransRC(self, rcPos, iDir, bgiTrans=False):
+    def get_transition_row_col(self, row_col_pos, direction, bgiTrans=False):
         """
-        Get the available transitions for rcPos in direction iDir,
+        Get the available transitions for row_col_pos in direction direction,
         as row & col deltas.
 
         If bgiTrans is True, return a grid of indices of available transitions.
 
-        eg for a cell rcPos = (4,5), in direction iDir = 0 (N),
+        eg for a cell row_col_pos = (4,5), in direction direction = 0 (N),
         where the available transitions are N and E, returns:
         [[-1,0], [0,1]] ie N=up one row, and E=right one col.
         and if bgiTrans is True, returns a tuple:
@@ -125,217 +109,79 @@ class RenderTool(object):
         )
         """
 
-        tbTrans = self.env.rail.get_transitions(*rcPos, iDir)
-        giTrans = np.where(tbTrans)[0]  # RC list of transitions
+        transitions = self.env.rail.get_transitions(*row_col_pos, direction)
+        transition_list = np.where(transitions)[0]  # RC list of transitions
 
         # HACK: workaround dead-end transitions
-        if len(giTrans) == 0:
-            iDirReverse = (iDir + 2) % 4
-            tbTrans = tuple(int(iDir2 == iDirReverse) for iDir2 in range(4))
-            giTrans = np.where(tbTrans)[0]  # RC list of transitions
+        if len(transition_list) == 0:
+            reverse_direciton = (direction + 2) % 4
+            transitions = tuple(int(tmp_dir == reverse_direciton) for tmp_dir in range(4))
+            transition_list = np.where(transitions)[0]  # RC list of transitions
 
-        gTransRCAg = self.__class__.gTransRC[giTrans]
+        transition_grid = self.__class__.transitions_row_col[transition_list]
 
         if bgiTrans:
-            return gTransRCAg, giTrans
+            return transition_grid, transition_list
         else:
-            return gTransRCAg
+            return transition_grid
 
-    def plotAgent(self, rcPos, iDir, color="r", target=None, static=False, selected=False):
+    def plot_single_agent(self, position_row_col, direction, color="r", target=None, static=False, selected=False):
         """
         Plot a simple agent.
         Assumes a working graphics layer context (cf a MPL figure).
         """
         rt = self.__class__
 
-        rcDir = rt.gTransRC[iDir]  # agent direction in RC
-        xyDir = np.matmul(rcDir, rt.grc2xy)  # agent direction in xy
+        direction_row_col = rt.transitions_row_col[direction]  # agent direction in RC
+        direction_xy = np.matmul(direction_row_col, rt.row_col_to_xy)  # agent direction in xy
 
-        xyPos = np.matmul(rcPos - rcDir / 2, rt.grc2xy) + rt.xyHalf
+        xyPos = np.matmul(position_row_col - direction_row_col / 2, rt.row_col_to_xy) + rt.x_y_half
 
         if static:
-            color = self.gl.adaptColor(color, lighten=True)
+            color = self.gl.adapt_color(color, lighten=True)
 
         color = color
 
         self.gl.scatter(*xyPos, color=color, layer=1, marker="o", s=100)  # agent location
-        xyDirLine = array([xyPos, xyPos + xyDir / 2]).T  # line for agent orient.
-        self.gl.plot(*xyDirLine, color=color, layer=1, lw=5, ms=0, alpha=0.6)
+        xy_dir_line = array([xyPos, xyPos + direction_xy / 2]).T  # line for agent orient.
+        self.gl.plot(*xy_dir_line, color=color, layer=1, lw=5, ms=0, alpha=0.6)
         if selected:
             self._draw_square(xyPos, 1, color)
 
         if target is not None:
-            rcTarget = array(target)
-            xyTarget = np.matmul(rcTarget, rt.grc2xy) + rt.xyHalf
-            self._draw_square(xyTarget, 1 / 3, color, layer=1)
+            target_row_col = array(target)
+            target_xy = np.matmul(target_row_col, rt.row_col_to_xy) + rt.x_y_half
+            self._draw_square(target_xy, 1 / 3, color, layer=1)
 
-    def plotTrans(self, rcPos, gTransRCAg, color="r", depth=None):
+    def plot_transition(self, position_row_col, transition_row_col, color="r", depth=None):
         """
-        plot the transitions in gTransRCAg at position rcPos.
-        gTransRCAg is a 2d numpy array containing a list of RC transitions,
+        plot the transitions in transition_row_col at position position_row_col.
+        transition_row_col is a 2d numpy array containing a list of RC transitions,
         eg [[-1,0], [0,1]] means N, E.
 
         """
 
         rt = self.__class__
-        xyPos = np.matmul(rcPos, rt.grc2xy) + rt.xyHalf
-        gxyTrans = xyPos + np.matmul(gTransRCAg, rt.grc2xy / 2.4)
-        self.gl.scatter(*gxyTrans.T, color=color, marker="o", s=50, alpha=0.2)
+        position_xy = np.matmul(position_row_col, rt.row_col_to_xy) + rt.x_y_half
+        transition_xy = position_xy + np.matmul(transition_row_col, rt.row_col_to_xy / 2.4)
+        self.gl.scatter(*transition_xy.T, color=color, marker="o", s=50, alpha=0.2)
         if depth is not None:
-            for x, y in gxyTrans:
+            for x, y in transition_xy:
                 self.gl.text(x, y, depth)
 
-    def getTreeFromRail(self, rcPos, iDir, nDepth=10, bBFS=True, bPlot=False):
-        """
-        DEFUNCT
-        Generate a tree from the env starting at rcPos, iDir.
-        """
-        rt = self.__class__
-        print(rcPos, iDir)
-        iPos = 0 if bBFS else -1  # BF / DF Search
-
-        iDepth = 0
-        visited = set()
-        lVisits = []
-        stack = [rt.Visit(rcPos, iDir, iDepth, None)]
-        while stack:
-            visit = stack.pop(iPos)
-            rcd = (visit.rc, visit.iDir)
-            if visit.iDepth > nDepth:
-                continue
-            lVisits.append(visit)
-
-            if rcd not in visited:
-                visited.add(rcd)
-
-                gTransRCAg, giTrans = self.getTransRC(visit.rc,
-                                                      visit.iDir,
-                                                      bgiTrans=True)
-                # enqueue the next nodes (ie transitions from this node)
-                for gTransRC2, iTrans in zip(gTransRCAg, giTrans):
-                    visitNext = rt.Visit(tuple(visit.rc + gTransRC2),
-                                         iTrans,
-                                         visit.iDepth + 1,
-                                         visit)
-                    stack.append(visitNext)
-
-                # plot the available transitions from this node
-                if bPlot:
-                    self.plotTrans(
-                        visit.rc, gTransRCAg,
-                        depth=str(visit.iDepth))
-
-        return lVisits
-
-    def plotTree(self, lVisits, xyTarg):
-        '''
-        Plot a vertical tree of transitions.
-        Returns the "visit" to the destination
-        (ie where euclidean distance is near zero) or None if absent.
-        '''
-
-        dPos = {}
-        iPos = 0
-
-        visitDest = None
-
-        for iVisit, visit in enumerate(lVisits):
-
-            if visit.rc in dPos:
-                xLoc = dPos[visit.rc]
-            else:
-                xLoc = dPos[visit.rc] = iPos
-                iPos += 1
-
-            rDist = np.linalg.norm(array(visit.rc) - array(xyTarg))
-
-            xLoc = rDist + visit.iDir / 4
-
-            # point labelled with distance
-            self.gl.scatter(xLoc, visit.iDepth, color="k", s=2)
-            self.gl.text(xLoc, visit.iDepth, visit.rc, color="k", rotation=45)
-
-            # if len(dPos)>1:
-            if visit.prev:
-                xLocPrev = dPos[visit.prev.rc]
-
-                rDistPrev = np.linalg.norm(array(visit.prev.rc) -
-                                           array(xyTarg))
-
-                xLocPrev = rDistPrev + visit.prev.iDir / 4
-
-                # line from prev node
-                self.gl.plot([xLocPrev, xLoc],
-                             [visit.iDepth - 1, visit.iDepth],
-                             color="k", alpha=0.5, lw=1)
-
-            if rDist < 0.1:
-                visitDest = visit
-
-        # Walk backwards from destination to origin, plotting in red
-        if visitDest is not None:
-            visit = visitDest
-            xLocPrev = None
-            while visit is not None:
-                rDist = np.linalg.norm(array(visit.rc) - array(xyTarg))
-                xLoc = rDist + visit.iDir / 4
-                if xLocPrev is not None:
-                    self.gl.plot([xLoc, xLocPrev], [visit.iDepth, visit.iDepth + 1],
-                                 color="r", alpha=0.5, lw=2)
-                xLocPrev = xLoc
-                visit = visit.prev
-
-        self.gl.prettify()
-        return visitDest
-
-    def plotPath(self, visitDest):
-        """
-        Given a "final" visit visitDest, plotPath recurses back through the path
-        using the visit.prev field (previous) to get back to the start of the path.
-        The path of transitions is plotted with arrows at 3/4 along the line.
-        The transition is plotted slightly to one side of the rail, so that
-        transitions in opposite directions are separate.
-        Currently, no attempt is made to make the transition arrows coincide
-        at corners, and they are straight only.
-        """
-
-        rt = self.__class__
-        # Walk backwards from destination to origin
-        if visitDest is not None:
-            visit = visitDest
-            xyPrev = None
-            while visit is not None:
-                xy = np.matmul(visit.rc, rt.grc2xy) + rt.xyHalf
-                if xyPrev is not None:
-                    dx, dy = (xyPrev - xy) / 20
-                    xyLine = array([xy, xyPrev]) + array([dy, dx])
-
-                    self.gl.plot(*xyLine.T, color="r", alpha=0.5, lw=1)
-
-                    xyMid = np.sum(xyLine * [[1 / 4], [3 / 4]], axis=0)
-
-                    xyArrow = array([
-                        xyMid + [-dx - dy, +dx - dy],
-                        xyMid,
-                        xyMid + [-dx + dy, -dx - dy]])
-                    self.gl.plot(*xyArrow.T, color="r")
-
-                visit = visit.prev
-                xyPrev = xy
-
-    def drawTrans(self, oFrom, oTo, sColor="gray"):
+    def draw_transition(self, origin, destination, color="gray"):
         self.gl.plot(
-            [oFrom[0], oTo[0]],  # x
-            [oFrom[1], oTo[1]],  # y
-            color=sColor
+            [origin[0], destination[0]],  # x
+            [origin[1], destination[1]],  # y
+            color=color
         )
 
-    def drawTrans2(self,
-                   xyLine, xyCentre,
-                   rotation, bDeadEnd=False,
-                   sColor="gray",
-                   bArrow=True,
-                   spacing=0.1):
+    def draw_transition_2(self,
+                          line, center,
+                          rotation, dead_end=False,
+                          color="gray",
+                          arrow=True,
+                          spacing=0.1):
         """
         gLine is a numpy 2d array of points,
         in the plotting space / coords.
@@ -345,66 +191,66 @@ class RenderTool(object):
         to   x=1, y=0.2
         """
         rt = self.__class__
-        bStraight = rotation in [0, 2]
-        dx, dy = np.squeeze(np.diff(xyLine, axis=0)) * spacing / 2
+        straight = rotation in [0, 2]
+        dx, dy = np.squeeze(np.diff(line, axis=0)) * spacing / 2
 
-        if bStraight:
+        if straight:
 
-            if sColor == "auto":
+            if color == "auto":
                 if dx > 0 or dy > 0:
-                    sColor = "C1"  # N or E
+                    color = "C1"  # N or E
                 else:
-                    sColor = "C2"  # S or W
+                    color = "C2"  # S or W
 
-            if bDeadEnd:
-                xyLine2 = array([
-                    xyLine[1] + [dy, dx],
-                    xyCentre,
-                    xyLine[1] - [dy, dx],
+            if dead_end:
+                line_xy = array([
+                    line[1] + [dy, dx],
+                    center,
+                    line[1] - [dy, dx],
                 ])
-                self.gl.plot(*xyLine2.T, color=sColor)
+                self.gl.plot(*line_xy.T, color=color)
             else:
-                xyLine2 = xyLine + [-dy, dx]
-                self.gl.plot(*xyLine2.T, color=sColor)
+                line_xy = line + [-dy, dx]
+                self.gl.plot(*line_xy.T, color=color)
 
-                if bArrow:
-                    xyMid = np.sum(xyLine2 * [[1 / 4], [3 / 4]], axis=0)
+                if arrow:
+                    middle_xy = np.sum(line_xy * [[1 / 4], [3 / 4]], axis=0)
 
-                    xyArrow = array([
-                        xyMid + [-dx - dy, +dx - dy],
-                        xyMid,
-                        xyMid + [-dx + dy, -dx - dy]])
-                    self.gl.plot(*xyArrow.T, color=sColor)
+                    arrow_xy = array([
+                        middle_xy + [-dx - dy, +dx - dy],
+                        middle_xy,
+                        middle_xy + [-dx + dy, -dx - dy]])
+                    self.gl.plot(*arrow_xy.T, color=color)
 
         else:
 
-            xyMid = np.mean(xyLine, axis=0)
-            dxy = xyMid - xyCentre
-            xyCorner = xyMid + dxy
+            middle_xy = np.mean(line, axis=0)
+            dxy = middle_xy - center
+            corner = middle_xy + dxy
             if rotation == 1:
-                rArcFactor = 1 - spacing
-                sColorAuto = "C1"
+                arc_factor = 1 - spacing
+                color_auto = "C1"
             else:
-                rArcFactor = 1 + spacing
-                sColorAuto = "C2"
-            dxy2 = (xyCentre - xyCorner) * rArcFactor  # for scaling the arc
-
-            if sColor == "auto":
-                sColor = sColorAuto
-
-            self.gl.plot(*(rt.gArc * dxy2 + xyCorner).T, color=sColor)
-
-            if bArrow:
-                dx, dy = np.squeeze(np.diff(xyLine, axis=0)) / 20
-                iArc = int(len(rt.gArc) / 2)
-                xyMid = xyCorner + rt.gArc[iArc] * dxy2
-                xyArrow = array([
-                    xyMid + [-dx - dy, +dx - dy],
-                    xyMid,
-                    xyMid + [-dx + dy, -dx - dy]])
-                self.gl.plot(*xyArrow.T, color=sColor)
-
-    def renderObs(self, agent_handles, observation_dict):
+                arc_factor = 1 + spacing
+                color_auto = "C2"
+            dxy2 = (center - corner) * arc_factor  # for scaling the arc
+
+            if color == "auto":
+                color = color_auto
+
+            self.gl.plot(*(rt.arc * dxy2 + corner).T, color=color)
+
+            if arrow:
+                dx, dy = np.squeeze(np.diff(line, axis=0)) / 20
+                iArc = int(len(rt.arc) / 2)
+                middle_xy = corner + rt.arc[iArc] * dxy2
+                arrow_xy = array([
+                    middle_xy + [-dx - dy, +dx - dy],
+                    middle_xy,
+                    middle_xy + [-dx + dy, -dx - dy]])
+                self.gl.plot(*arrow_xy.T, color=color)
+
+    def render_observation(self, agent_handles, observation_dict):
         """
         Render the extent of the observation of each agent. All cells that appear in the agent
         observation will be highlighted.
@@ -415,37 +261,54 @@ class RenderTool(object):
         rt = self.__class__
 
         for agent in agent_handles:
-            color = self.gl.getAgentColor(agent)
+            color = self.gl.get_agent_color(agent)
             for visited_cell in observation_dict[agent]:
                 cell_coord = array(visited_cell[:2])
-                cell_coord_trans = np.matmul(cell_coord, rt.grc2xy) + rt.xyHalf
+                cell_coord_trans = np.matmul(cell_coord, rt.row_col_to_xy) + rt.x_y_half
+                self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100)
+
+    def render_prediction(self, agent_handles, prediction_dict):
+        """
+        Render the extent of the observation of each agent. All cells that appear in the agent
+        observation will be highlighted.
+        :param agent_handles: List of agent indices to adapt color and get correct observation
+        :param observation_dict: dictionary containing sets of cells of the agent observation
+
+        """
+        rt = self.__class__
+
+        for agent in agent_handles:
+            color = self.gl.get_agent_color(agent)
+            for visited_cell in prediction_dict[agent]:
+                cell_coord = array(visited_cell[:2])
+                cell_coord_trans = np.matmul(cell_coord, rt.row_col_to_xy) + rt.x_y_half
                 self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100)
 
-    def renderRail(self, spacing=False, sRailColor="gray", curves=True, arrows=False):
+    def render_rail(self, spacing=False, rail_color="gray", curves=True, arrows=False):
 
         cell_size = 1  # TODO: remove cell_size
         env = self.env
 
         # Draw cells grid
         grid_color = [0.95, 0.95, 0.95]
-        for r in range(env.height + 1):
+        for row in range(env.height + 1):
             self.gl.plot([0, (env.width + 1) * cell_size],
-                         [-r * cell_size, -r * cell_size],
+                         [-row * cell_size, -row * cell_size],
                          color=grid_color, linewidth=2)
-        for c in range(env.width + 1):
-            self.gl.plot([c * cell_size, c * cell_size],
+        for col in range(env.width + 1):
+            self.gl.plot([col * cell_size, col * cell_size],
                          [0, -(env.height + 1) * cell_size],
                          color=grid_color, linewidth=2)
 
         # Draw each cell independently
-        for r in range(env.height):
-            for c in range(env.width):
+        for row in range(env.height):
+            for col in range(env.width):
 
                 # bounding box of the grid cell
-                x0 = cell_size * c  # left
-                x1 = cell_size * (c + 1)  # right
-                y0 = cell_size * -r  # top
-                y1 = cell_size * -(r + 1)  # bottom
+                x0 = cell_size * col  # left
+                x1 = cell_size * (col + 1)  # right
+                y0 = cell_size * -row  # top
+                y1 = cell_size * -(row + 1)  # bottom
 
                 # centres of cell edges
                 coords = [
@@ -456,16 +319,16 @@ class RenderTool(object):
                 ]
 
                 # cell centre
-                xyCentre = array([x0, y1]) + cell_size / 2
+                center_xy = array([x0, y1]) + cell_size / 2
 
                 # cell transition values
-                oCell = env.rail.get_full_transitions(r, c)
+                cell = env.rail.get_full_transitions(row, col)
 
-                bCellValid = env.rail.cell_neighbours_valid((r, c), check_this_cell=True)
+                cell_valid = env.rail.cell_neighbours_valid((row, col), check_this_cell=True)
 
                 # Special Case 7, with a single bit; terminate at center
                 nbits = 0
-                tmp = oCell
+                tmp = cell
 
                 while tmp > 0:
                     nbits += (tmp & 1)
@@ -473,110 +336,114 @@ class RenderTool(object):
 
                 # as above - move the from coord to the centre
                 # it's a dead env.
-                bDeadEnd = nbits == 1
+                is_dead_end = nbits == 1
 
-                if not bCellValid:
-                    self.gl.scatter(*xyCentre, color="r", s=30)
+                if not cell_valid:
+                    self.gl.scatter(*center_xy, color="r", s=30)
 
                 for orientation in range(4):  # ori is where we're heading
                     from_ori = (orientation + 2) % 4  # 0123=NESW -> 2301=SWNE
                     from_xy = coords[from_ori]
 
-                    tMoves = env.rail.get_transitions(r, c, orientation)
+                    moves = env.rail.get_transitions(row, col, orientation)
 
                     for to_ori in range(4):
                         to_xy = coords[to_ori]
                         rotation = (to_ori - from_ori) % 4
 
-                        if (tMoves[to_ori]):  # if we have this transition
+                        if (moves[to_ori]):  # if we have this transition
 
-                            if bDeadEnd:
-                                self.drawTrans2(
-                                    array([from_xy, to_xy]), xyCentre,
-                                    rotation, bDeadEnd=True, spacing=spacing,
-                                    sColor=sRailColor)
+                            if is_dead_end:
+                                self.draw_transition_2(
+                                    array([from_xy, to_xy]), center_xy,
+                                    rotation, dead_end=True, spacing=spacing,
+                                    color=rail_color)
 
                             else:
 
                                 if curves:
-                                    self.drawTrans2(
-                                        array([from_xy, to_xy]), xyCentre,
-                                        rotation, spacing=spacing, bArrow=arrows,
-                                        sColor=sRailColor)
+                                    self.draw_transition_2(
+                                        array([from_xy, to_xy]), center_xy,
+                                        rotation, spacing=spacing, arrow=arrows,
+                                        color=rail_color)
                                 else:
-                                    self.drawTrans(self, from_xy, to_xy, sRailColor)
+                                    self.draw_transition(self, from_xy, to_xy, color=rail_color)
 
                             if False:
                                 print(
-                                    "r,c,ori: ", r, c, orientation,
-                                    "cell:", "{0:b}".format(oCell),
-                                    "moves:", tMoves,
+                                    "r,c,ori: ", row, col, orientation,
+                                    "cell:", "{0:b}".format(cell),
+                                    "moves:", moves,
                                     "from:", from_ori, from_xy,
                                     "to: ", to_ori, to_xy,
-                                    "cen:", *xyCentre,
+                                    "cen:", *center_xy,
                                     "rot:", rotation,
                                 )
 
-    def renderEnv(self,
-                  show=False,  # whether to call matplotlib show() or equivalent after completion
-                  # use false when calling from Jupyter.  (and matplotlib no longer supported!)
-                  curves=True,  # draw turns as curves instead of straight diagonal lines
-                  spacing=False,  # defunct - size of spacing between rails
-                  arrows=False,  # defunct - draw arrows on rail lines
-                  agents=True,  # whether to include agents
-                  show_observations=True,  # whether to include observations
-                  sRailColor="gray",  # color to use in drawing rails (not used with SVG)
-                  frames=False,  # frame counter to show (intended since invocation)
-                  iEpisode=None,  # int episode number to show
-                  iStep=None,  # int step number to show in image
-                  iSelectedAgent=None,  # indicate which agent is "selected" in the editor
-                  action_dict=None):  # defunct - was used to indicate agent intention to turn
+    def render_env(self,
+                   show=False,  # whether to call matplotlib show() or equivalent after completion
+                   # use false when calling from Jupyter.  (and matplotlib no longer supported!)
+                   curves=True,  # draw turns as curves instead of straight diagonal lines
+                   spacing=False,  # defunct - size of spacing between rails
+                   arrows=False,  # defunct - draw arrows on rail lines
+                   agents=True,  # whether to include agents
+                   show_observations=True,  # whether to include observations
+                   show_predictions=True,  # whether to include predictions
+                   sRailColor="gray",  # color to use in drawing rails (not used with SVG)
+                   frames=False,  # frame counter to show (intended since invocation)
+                   episode=None,  # int episode number to show
+                   step=None,  # int step number to show in image
+                   iSelectedAgent=None,  # indicate which agent is "selected" in the editor
+                   action_dict=None):  # defunct - was used to indicate agent intention to turn
         """ Draw the environment using the GraphicsLayer this RenderTool was created with.
             (Use show=False from a Jupyter notebook with %matplotlib inline)
         """
 
         if not self.gl.is_raster():
-            self.renderEnv2(show=show, curves=curves, spacing=spacing,
-                            arrows=arrows, agents=agents, show_observations=show_observations,
-                            sRailColor=sRailColor,
-                            frames=frames, iEpisode=iEpisode, iStep=iStep,
-                            iSelectedAgent=iSelectedAgent, action_dict=action_dict)
+            self.render_env_2(show=show, curves=curves, spacing=spacing,
+                              arrows=arrows, agents=agents, show_observations=show_observations,
+                              show_predictions=show_predictions,
+                              rail_color=sRailColor,
+                              frames=frames, episode=episode, step=step,
+                              selected_agent=iSelectedAgent, action_dict=action_dict)
             return
 
         if type(self.gl) is PILGL:
-            self.gl.beginFrame()
+            self.gl.begin_frame()
 
         env = self.env
 
-        self.renderRail()
+        self.render_rail()
 
         # Draw each agent + its orientation + its target
         if agents:
-            self.plotAgents(targets=True, iSelectedAgent=iSelectedAgent)
+            self.plot_agents(targets=True, selected_agent=iSelectedAgent)
         if show_observations:
-            self.renderObs(range(env.get_num_agents()), env.dev_obs_dict)
+            self.render_observation(range(env.get_num_agents()), env.dev_obs_dict)
+        if show_predictions:
+            self.render_prediction(range(env.get_num_agents()), env.dev_pred_dict)
         # Draw some textual information like fps
-        yText = [-0.3, -0.6, -0.9]
+        text_y = [-0.3, -0.6, -0.9]
         if frames:
-            self.gl.text(0.1, yText[2], "Frame:{:}".format(self.iFrame))
-        self.iFrame += 1
+            self.gl.text(0.1, text_y[2], "Frame:{:}".format(self.frame_nr))
+        self.frame_nr += 1
 
-        if iEpisode is not None:
-            self.gl.text(0.1, yText[1], "Ep:{}".format(iEpisode))
+        if episode is not None:
+            self.gl.text(0.1, text_y[1], "Ep:{}".format(episode))
 
-        if iStep is not None:
-            self.gl.text(0.1, yText[0], "Step:{}".format(iStep))
+        if step is not None:
+            self.gl.text(0.1, text_y[0], "Step:{}".format(step))
 
-        tNow = time.time()
-        self.gl.text(2, yText[2], "elapsed:{:.2f}s".format(tNow - self.time1))
-        self.lTimes.append(tNow)
-        if len(self.lTimes) > 20:
-            self.lTimes.popleft()
-        if len(self.lTimes) > 1:
-            rFps = (len(self.lTimes) - 1) / (self.lTimes[-1] - self.lTimes[0])
-            self.gl.text(2, yText[1], "fps:{:.2f}".format(rFps))
+        time_now = time.time()
+        self.gl.text(2, text_y[2], "elapsed:{:.2f}s".format(time_now - self.start_time))
+        self.times_list.append(time_now)
+        if len(self.times_list) > 20:
+            self.times_list.popleft()
+        if len(self.times_list) > 1:
+            rFps = (len(self.times_list) - 1) / (self.times_list[-1] - self.times_list[0])
+            self.gl.text(2, text_y[1], "fps:{:.2f}".format(rFps))
 
-        self.gl.prettify2(env.width, env.height, self.nPixCell)
+        self.gl.prettify2(env.width, env.height, self.pix_per_cell)
 
         # TODO: for MPL, we don't want to call clf (called by endframe)
         # if not show:
@@ -595,38 +462,13 @@ class RenderTool(object):
         y1 = center[1] + size / 2
         self.gl.plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], color=color, layer=layer, opacity=opacity)
 
-    def getImage(self):
-        return self.gl.getImage()
-
-    def plotTreeObs(self, gObs):
-        nBranchFactor = 4
-
-        gP0 = array([[0, 0, 0]]).T
-        nDepth = 2
-        for i in range(nDepth):
-            nDepthNodes = nBranchFactor ** i
-            rShrinkDepth = 1 / (i + 1)
+    def get_image(self):
+        return self.gl.get_image()
 
-            gX1 = np.linspace(-(nDepthNodes - 1), (nDepthNodes - 1), nDepthNodes) * rShrinkDepth
-            gY1 = np.ones((nDepthNodes)) * i
-            gZ1 = np.zeros((nDepthNodes))
-
-            gP1 = array([gX1, gY1, gZ1])
-            gP01 = np.append(gP0, gP1, axis=1)
-
-            if nDepthNodes > 1:
-                nDepthNodesPrev = nDepthNodes / nBranchFactor
-                giP0 = np.repeat(np.arange(nDepthNodesPrev), nBranchFactor)
-                giP1 = np.arange(0, nDepthNodes) + nDepthNodesPrev
-                giLinePoints = np.stack([giP0, giP1]).ravel("F")
-                self.gl.plot(gP01[0], -gP01[1], lines=giLinePoints, color="gray")
-
-            gP0 = array([gX1, gY1, gZ1])
-
-    def renderEnv2(
+    def render_env_2(
         self, show=False, curves=True, spacing=False, arrows=False, agents=True,
-        show_observations=True, sRailColor="gray",
-        frames=False, iEpisode=None, iStep=None, iSelectedAgent=None,
+        show_observations=True, show_predictions=False, rail_color="gray",
+        frames=False, episode=None, step=None, selected_agent=None,
         action_dict=dict()
     ):
         """
@@ -639,45 +481,46 @@ class RenderTool(object):
 
         env = self.env
 
-        self.gl.beginFrame()
+        self.gl.begin_frame()
 
         if self.new_rail:
             self.new_rail = False
             self.gl.clear_rails()
 
             # store the targets
-            dTargets = {}
-            dSelected = {}
-            for iAgent, agent in enumerate(self.env.agents_static):
+            targets = {}
+            selected = {}
+            for agent_idx, agent in enumerate(self.env.agents_static):
                 if agent is None:
                     continue
-                dTargets[tuple(agent.target)] = iAgent
-                dSelected[tuple(agent.target)] = (iAgent == iSelectedAgent)
+                targets[tuple(agent.target)] = agent_idx
+                selected[tuple(agent.target)] = (agent_idx == selected_agent)
 
             # Draw each cell independently
             for r in range(env.height):
                 for c in range(env.width):
-                    binTrans = env.rail.grid[r, c]
-                    if (r, c) in dTargets:
-                        target = dTargets[(r, c)]
-                        isSelected = dSelected[(r, c)]
+                    transitions = env.rail.grid[r, c]
+                    if (r, c) in targets:
+                        target = targets[(r, c)]
+                        is_selected = selected[(r, c)]
                     else:
                         target = None
-                        isSelected = False
+                        is_selected = False
 
-                    self.gl.setRailAt(r, c, binTrans, iTarget=target, isSelected=isSelected, rail_grid=env.rail.grid)
+                    self.gl.set_rail_at(r, c, transitions, target=target, is_selected=is_selected,
+                                        rail_grid=env.rail.grid)
 
-            self.gl.build_background_map(dTargets)
+            self.gl.build_background_map(targets)
 
-        for iAgent, agent in enumerate(self.env.agents):
+        for agent_idx, agent in enumerate(self.env.agents):
 
             if agent is None:
                 continue
 
-            if self.agentRenderVariant == AgentRenderVariant.BOX_ONLY:
-                self.gl.setCellOccupied(iAgent, *(agent.position))
-            elif self.agentRenderVariant == AgentRenderVariant.ONE_STEP_BEHIND or \
-                self.agentRenderVariant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX:  # noqa: E125
+            if self.agent_render_variant == AgentRenderVariant.BOX_ONLY:
+                self.gl.set_cell_occupied(agent_idx, *(agent.position))
+            elif self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND or \
+                self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX:  # noqa: E125
                 if agent.old_position is not None:
                     position = agent.old_position
                     direction = agent.direction
@@ -687,10 +530,10 @@ class RenderTool(object):
                     direction = agent.direction
                     old_direction = agent.direction
 
-                # setAgentAt uses the agent index for the color
-                if self.agentRenderVariant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX:
-                    self.gl.setCellOccupied(iAgent, *(agent.position))
-                self.gl.setAgentAt(iAgent, *position, old_direction, direction, iSelectedAgent == iAgent)
+                # set_agent_at uses the agent index for the color
+                if self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX:
+                    self.gl.set_cell_occupied(agent_idx, *(agent.position))
+                self.gl.set_agent_at(agent_idx, *position, old_direction, direction, selected_agent == agent_idx)
             else:
                 position = agent.position
                 direction = agent.direction
@@ -700,23 +543,25 @@ class RenderTool(object):
                     if isValid:
                         direction = possible_directions
 
-                        # setAgentAt uses the agent index for the color
-                        self.gl.setAgentAt(iAgent, *position, agent.direction, direction, iSelectedAgent == iAgent)
+                        # set_agent_at uses the agent index for the color
+                        self.gl.set_agent_at(agent_idx, *position, agent.direction, direction,
+                                             selected_agent == agent_idx)
 
-                # setAgentAt uses the agent index for the color
-                if self.agentRenderVariant == AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX:
-                    self.gl.setCellOccupied(iAgent, *(agent.position))
-                self.gl.setAgentAt(iAgent, *position, agent.direction, direction, iSelectedAgent == iAgent)
+                # set_agent_at uses the agent index for the color
+                if self.agent_render_variant == AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX:
+                    self.gl.set_cell_occupied(agent_idx, *(agent.position))
+                self.gl.set_agent_at(agent_idx, *position, agent.direction, direction, selected_agent == agent_idx)
 
         if show_observations:
-            self.renderObs(range(env.get_num_agents()), env.dev_obs_dict)
-
+            self.render_observation(range(env.get_num_agents()), env.dev_obs_dict)
+        if show_predictions:
+            self.render_prediction(range(env.get_num_agents()), env.dev_pred_dict)
         if show:
             self.gl.show()
         for i in range(3):
-            self.gl.processEvents()
+            self.gl.process_events()
 
-        self.iFrame += 1
+        self.frame_nr += 1
         return
 
     def close_window(self):
diff --git a/notebooks/Scene_Editor.ipynb b/notebooks/Scene_Editor.ipynb
index 87b2d38c..bca9a788 100644
--- a/notebooks/Scene_Editor.ipynb
+++ b/notebooks/Scene_Editor.ipynb
@@ -70,7 +70,7 @@
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "86a323f55bb54ff78169a9c7ca63730f",
+       "model_id": "8285597cada74f51b739ee80c974521b",
        "version_major": 2,
        "version_minor": 0
       },
@@ -80,6 +80,50 @@
      },
      "metadata": {},
      "output_type": "display_data"
+    },
+    {
+     "ename": "AttributeError",
+     "evalue": "'View' object has no attribute 'redisplay_image'",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[1;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
+      "\u001b[1;32mc:\\users\\u224870\\appdata\\local\\programs\\python\\python36\\lib\\site-packages\\flatland_rl-0.2.0-py3.6.egg\\flatland\\utils\\editor.py\u001b[0m in \u001b[0;36mon_mouse_move\u001b[1;34m(self, wid, event)\u001b[0m\n\u001b[0;32m    308\u001b[0m                     \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0meditor\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdrag_path_element\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrcCell\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    309\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 310\u001b[1;33m                 \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mview\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mredisplay_image\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    311\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    312\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;31mAttributeError\u001b[0m: 'View' object has no attribute 'redisplay_image'"
+     ]
+    },
+    {
+     "ename": "AttributeError",
+     "evalue": "'View' object has no attribute 'redisplay_image'",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[1;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
+      "\u001b[1;32mc:\\users\\u224870\\appdata\\local\\programs\\python\\python36\\lib\\site-packages\\flatland_rl-0.2.0-py3.6.egg\\flatland\\utils\\editor.py\u001b[0m in \u001b[0;36mon_mouse_move\u001b[1;34m(self, wid, event)\u001b[0m\n\u001b[0;32m    308\u001b[0m                     \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0meditor\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdrag_path_element\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrcCell\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    309\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 310\u001b[1;33m                 \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mview\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mredisplay_image\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    311\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    312\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;31mAttributeError\u001b[0m: 'View' object has no attribute 'redisplay_image'"
+     ]
+    },
+    {
+     "ename": "AttributeError",
+     "evalue": "'View' object has no attribute 'redisplay_image'",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[1;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
+      "\u001b[1;32mc:\\users\\u224870\\appdata\\local\\programs\\python\\python36\\lib\\site-packages\\flatland_rl-0.2.0-py3.6.egg\\flatland\\utils\\editor.py\u001b[0m in \u001b[0;36mon_mouse_move\u001b[1;34m(self, wid, event)\u001b[0m\n\u001b[0;32m    308\u001b[0m                     \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0meditor\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdrag_path_element\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrcCell\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    309\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 310\u001b[1;33m                 \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mview\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mredisplay_image\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    311\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    312\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;31mAttributeError\u001b[0m: 'View' object has no attribute 'redisplay_image'"
+     ]
+    },
+    {
+     "ename": "AttributeError",
+     "evalue": "'View' object has no attribute 'redisplay_image'",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[1;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
+      "\u001b[1;32mc:\\users\\u224870\\appdata\\local\\programs\\python\\python36\\lib\\site-packages\\flatland_rl-0.2.0-py3.6.egg\\flatland\\utils\\editor.py\u001b[0m in \u001b[0;36mon_mouse_move\u001b[1;34m(self, wid, event)\u001b[0m\n\u001b[0;32m    308\u001b[0m                     \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0meditor\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdrag_path_element\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrcCell\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    309\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 310\u001b[1;33m                 \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mview\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mredisplay_image\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    311\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    312\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+      "\u001b[1;31mAttributeError\u001b[0m: 'View' object has no attribute 'redisplay_image'"
+     ]
     }
    ],
    "source": [
@@ -111,7 +155,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.6.8"
+   "version": "3.6.5"
   },
   "latex_envs": {
    "LaTeX_envs_menu_present": true,
diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py
index c2252619..00bcd0da 100644
--- a/tests/test_flatland_envs_observations.py
+++ b/tests/test_flatland_envs_observations.py
@@ -115,7 +115,7 @@ def test_reward_function_conflict(rendering=False):
 
     if rendering:
         renderer = RenderTool(env, gl="PILSVG")
-        renderer.renderEnv(show=True, show_observations=True)
+        renderer.render_env(show=True, show_observations=True)
 
     iteration = 0
     expected_positions = {
@@ -158,7 +158,7 @@ def test_reward_function_conflict(rendering=False):
                                                                                                   agent.position,
                                                                                                   expected_position)
         if rendering:
-            renderer.renderEnv(show=True, show_observations=True)
+            renderer.render_env(show=True, show_observations=True)
 
         iteration += 1
 
@@ -193,7 +193,7 @@ def test_reward_function_waiting(rendering=False):
 
     if rendering:
         renderer = RenderTool(env, gl="PILSVG")
-        renderer.renderEnv(show=True, show_observations=True)
+        renderer.render_env(show=True, show_observations=True)
 
     iteration = 0
     expectations = {
@@ -270,7 +270,7 @@ def test_reward_function_waiting(rendering=False):
         rewards = _step_along_shortest_path(env, obs_builder, rail)
 
         if rendering:
-            renderer.renderEnv(show=True, show_observations=True)
+            renderer.render_env(show=True, show_observations=True)
 
         print(env.dones["__all__"])
         for agent in env.agents:
diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py
index eec939e2..dd6d343b 100644
--- a/tests/test_flatland_envs_predictions.py
+++ b/tests/test_flatland_envs_predictions.py
@@ -37,7 +37,7 @@ def test_dummy_predictor(rendering=False):
 
     if rendering:
         renderer = RenderTool(env, gl="PILSVG")
-        renderer.renderEnv(show=True, show_observations=False)
+        renderer.render_env(show=True, show_observations=False)
         input("Continue?")
 
     # test assertions
@@ -130,7 +130,7 @@ def test_shortest_path_predictor(rendering=False):
 
     if rendering:
         renderer = RenderTool(env, gl="PILSVG")
-        renderer.renderEnv(show=True, show_observations=False)
+        renderer.render_env(show=True, show_observations=False)
         input("Continue?")
 
     # compute the observations and predictions
@@ -254,7 +254,7 @@ def test_shortest_path_predictor_conflicts(rendering=False):
 
     if rendering:
         renderer = RenderTool(env, gl="PILSVG")
-        renderer.renderEnv(show=True, show_observations=False)
+        renderer.render_env(show=True, show_observations=False)
         input("Continue?")
 
     # get the trees to test
diff --git a/tests/test_flatland_utils_rendertools.py b/tests/test_flatland_utils_rendertools.py
index ff7cbd01..ac5d7f41 100644
--- a/tests/test_flatland_utils_rendertools.py
+++ b/tests/test_flatland_utils_rendertools.py
@@ -6,7 +6,6 @@ Tests for `flatland` package.
 
 import sys
 
-import matplotlib.pyplot as plt
 import numpy as np
 from importlib_resources import path
 
@@ -21,7 +20,7 @@ def checkFrozenImage(oRT, sFileImage, resave=False):
     sDirRoot = "."
     sDirImages = sDirRoot + "/images/"
 
-    img_test = oRT.getImage()
+    img_test = oRT.get_image()
 
     if resave:
         np.savez_compressed(sDirImages + sFileImage, img=img_test)
@@ -45,35 +44,14 @@ def test_render_env(save_new_images=False):
                    )
     oEnv.rail.load_transition_map('env_data.tests', "test1.npy")
     oRT = rt.RenderTool(oEnv, gl="PILSVG")
-    oRT.renderEnv(show=False)
+    oRT.render_env(show=False)
 
     checkFrozenImage(oRT, "basic-env.npz", resave=save_new_images)
 
     oRT = rt.RenderTool(oEnv, gl="PIL")
-    oRT.renderEnv()
+    oRT.render_env()
     checkFrozenImage(oRT, "basic-env-PIL.npz", resave=save_new_images)
 
-    # disable the tree / observation tests until env-agent save/load is available
-    if False:
-        lVisits = oRT.getTreeFromRail(
-            oEnv.agents_position[0],
-            oEnv.agents_direction[0],
-            nDepth=17, bPlot=True)
-
-        checkFrozenImage("env-tree-spatial.png")
-
-        plt.figure(figsize=(8, 8))
-        xyTarg = oRT.env.agents_target[0]
-        visitDest = oRT.plotTree(lVisits, xyTarg)
-
-        checkFrozenImage("env-tree-graph.png")
-
-        plt.figure(figsize=(10, 10))
-        oRT.renderEnv()
-        oRT.plotPath(visitDest)
-
-        checkFrozenImage("env-path.png")
-
 
 def main():
     if len(sys.argv) == 2 and sys.argv[1] == "save":
-- 
GitLab