From 8062e72b7eec9bab1e71f4c0beb4d9ec9cd67026 Mon Sep 17 00:00:00 2001
From: hagrid67 <jdhwatson@gmail.com>
Date: Thu, 23 May 2019 12:19:54 +0100
Subject: [PATCH] moving TK into PILGL renderer to allow regular window moved
 test_play.py to use player.py rather than contrived tkplay.py trying to
 improve logic around gl.show() show param to renderEnv

---
 examples/play_model.py           | 10 +++++-----
 examples/tkplay.py               |  3 +--
 flatland/envs/agent_utils.py     |  7 ++++++-
 flatland/utils/editor.py         | 10 ++++++++--
 flatland/utils/graphics_layer.py |  3 +++
 flatland/utils/graphics_pil.py   | 32 +++++++++++++++++++++++++++++---
 flatland/utils/rendertools.py    | 12 ++++++++----
 tests/test_player.py             |  6 +++---
 8 files changed, 63 insertions(+), 20 deletions(-)

diff --git a/examples/play_model.py b/examples/play_model.py
index 34c6aadf..7d7ed110 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -63,7 +63,8 @@ class Player(object):
             # Real Agent
             # action = self.agent.act(np.array(self.obs[handle]), eps=self.eps)
             # Random actions
-            action = random.randint(0, 3)
+            # action = random.randint(0, 3)
+            action = np.random.choice([0, 1, 2, 3], 1, p=[0.2, 0.1, 0.6, 0.1])[0]
             # Numpy version uses single random sequence
             # action = np.random.randint(0, 4, size=1)
             self.action_prob[action] += 1
@@ -106,7 +107,7 @@ def max_lt(seq, val):
     return None
 
 
-def main(render=True, delay=0.0, n_trials=3, n_steps=50, sGL="QT"):
+def main(render=True, delay=0.0, n_trials=3, n_steps=50, sGL="PIL"):
     random.seed(1)
     np.random.seed(1)
 
@@ -116,8 +117,7 @@ def main(render=True, delay=0.0, n_trials=3, n_steps=50, sGL="QT"):
                   number_of_agents=5)
 
     if render:
-        # env_renderer = RenderTool(env, gl="QTSVG")
-        env_renderer = RenderTool(env, gl=sGL)
+        env_renderer = RenderTool(env, gl=sGL, show=True)
 
     oPlayer = Player(env)
 
@@ -159,7 +159,7 @@ def main_old(render=True, delay=0.0):
                   number_of_agents=5)
 
     if render:
-        env_renderer = RenderTool(env, gl="QTSVG")
+        env_renderer = RenderTool(env, gl="PIL")
         # env_renderer = RenderTool(env, gl="QT")
 
     n_trials = 9999
diff --git a/examples/tkplay.py b/examples/tkplay.py
index 95842e3b..05078fad 100644
--- a/examples/tkplay.py
+++ b/examples/tkplay.py
@@ -9,7 +9,7 @@ from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
 
 
-def tkmain(n_trials=2):
+def tkmain(n_trials=2, n_steps=50):
     # This creates the main window of an application
     window = tk.Tk()
     window.title("Join")
@@ -24,7 +24,6 @@ def tkmain(n_trials=2):
 
     oPlayer = Player(env)
     n_trials = 1
-    n_steps = 20
     delay = 0
     for trials in range(1, n_trials + 1):
 
diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index 05f81e43..db7f9ae0 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -7,6 +7,11 @@ import numpy as np
 
 @attrs
 class EnvDescription(object):
+    """ EnvDescription - This is a description of a random env,
+        based around the rail_generator and stats like size and n_agents.
+        It mirrors the parameters given to the RailEnv constructor.
+        Not currently used.
+    """
     n_agents = attrib()
     height = attrib()
     width = attrib()
@@ -16,7 +21,7 @@ class EnvDescription(object):
 
 @attrs
 class EnvAgentStatic(object):
-    """ TODO: EnvAgentStatic - To store initial position, direction and target.
+    """ EnvAgentStatic - Stores initial position, direction and target.
         This is like static data for the environment - it's where an agent starts,
         rather than where it is at the moment.
         The target should also be stored here.
diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py
index 7e813d76..32980d98 100644
--- a/flatland/utils/editor.py
+++ b/flatland/utils/editor.py
@@ -29,11 +29,15 @@ import jpy_canvas
 
 
 class EditorMVC(object):
-    def __init__(self, env=None, sGL="MPL"):
+    """ EditorMVC - a class to encompass and assemble the Jupyter Editor Model-View-Controller.
+    """
+    def __init__(self, env=None, sGL="PIL"):
+        """ Create an Editor MVC assembly around a railenv, or create one if None.
+        """
         if env is None:
             env = RailEnv(width=10,
                           height=10,
-                          rail_generator=random_rail_generator(),
+                          rail_generator=empty_rail_generator(),
                           number_of_agents=0,
                           obs_builder_object=TreeObsForRailEnv(max_depth=2))
 
@@ -47,6 +51,8 @@ class EditorMVC(object):
 
 
 class View(object):
+    """ The Jupyter Editor View - creates and holds the widgets comprising the Editor.
+    """
     def __init__(self, editor, sGL="MPL"):
         self.editor = self.model = editor
         self.sGL = sGL
diff --git a/flatland/utils/graphics_layer.py b/flatland/utils/graphics_layer.py
index 4cfcc64b..f65d87f0 100644
--- a/flatland/utils/graphics_layer.py
+++ b/flatland/utils/graphics_layer.py
@@ -7,6 +7,9 @@ class GraphicsLayer(object):
     def __init__(self):
         pass
 
+    def open_window(self):
+        pass
+
     def is_raster(self):
         return True
 
diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py
index b66c8dc5..949628fc 100644
--- a/flatland/utils/graphics_pil.py
+++ b/flatland/utils/graphics_pil.py
@@ -1,6 +1,7 @@
 
 from flatland.utils.graphics_layer import GraphicsLayer
-from PIL import Image, ImageDraw   # , ImageFont
+from PIL import Image, ImageDraw, ImageTk   # , ImageFont
+import tkinter as tk
 from numpy import array
 import numpy as np
 
@@ -26,6 +27,9 @@ class PILGL(GraphicsLayer):
         self.tColRail = (0, 0, 0)         # black rails
         self.tColGrid = (230,) * 3        # light grey for grid
 
+        self.window_open = False
+        # self.bShow = show
+        self.firstFrame = True
         self.beginFrame()
 
     def plot(self, gX, gY, color=None, linewidth=3, layer=0, opacity=255, **kwargs):
@@ -45,6 +49,13 @@ class PILGL(GraphicsLayer):
         for x, y in gPoints:
             self.draws[layer].rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color)
 
+    def open_window(self):
+        assert self.window_open is False, "Window is already open!"
+        self.window = tk.Tk()
+        self.window.title("Flatland")
+        self.window.configure(background='grey')
+        self.window_open = True
+
     def text(self, *args, **kwargs):
         pass
 
@@ -59,8 +70,23 @@ class PILGL(GraphicsLayer):
         self.create_layer(1)
 
     def show(self, block=False):
-        pass
-        # plt.show(block=block)
+        img = self.alpha_composite_layers()
+        
+        if not self.window_open:
+            self.open_window()
+        
+        tkimg = ImageTk.PhotoImage(img)
+        
+        if self.firstFrame:
+            self.panel = tk.Label(self.window, image=tkimg)
+            self.panel.pack(side="bottom", fill="both", expand="yes")
+        else:
+            # update the image in situ
+            self.panel.configure(image=tkimg)
+            self.panel.image = tkimg
+
+        self.window.update()
+        self.firstFrame = False
 
     def pause(self, seconds=0.00001):
         pass
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 30dccfac..700e3759 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -15,14 +15,15 @@ from flatland.utils.graphics_layer import GraphicsLayer
 
 
 class MPLGL(GraphicsLayer):
-    def __init__(self, width, height, show=False):
+    def __init__(self, width, height):
         self.width = width
         self.height = height
         self.yxBase = array([6, 21])  # pixel offset
         self.nPixCell = 700 / width
         self.img = None
-        if show:
-            plt.figure(figsize=(10, 10))
+
+    def open_window(self):
+        plt.figure(figsize=(10, 10))
 
     def plot(self, *args, **kwargs):
         plt.plot(*args, **kwargs)
@@ -126,7 +127,7 @@ class RenderTool(object):
         # self.gl = MPLGL()
 
         if gl == "MPL":
-            self.gl = MPLGL(env.width, env.height, show=show)
+            self.gl = MPLGL(env.width, env.height)
         elif gl == "QT":
             self.gl = QTGL(env.width, env.height)
         elif gl == "PIL":
@@ -681,6 +682,9 @@ class RenderTool(object):
                 self.gl.show(block=False)
             # self.gl.endFrame()
 
+        if show and type(self.gl) is PILGL:
+            self.gl.show()
+
         self.gl.pause(0.00001)
 
         return
diff --git a/tests/test_player.py b/tests/test_player.py
index 7b2745f2..a0e580b9 100644
--- a/tests/test_player.py
+++ b/tests/test_player.py
@@ -1,8 +1,8 @@
 
-# from examples.play_model import main
-from examples.tkplay import tkmain
+from examples.play_model import main
+# from examples.tkplay import tkmain
 
 
 def test_main():
-    tkmain(n_trials=2)
+    main(render=True, n_steps=20, n_trials=2, sGL="PIL")
 
-- 
GitLab