From 8062e72b7eec9bab1e71f4c0beb4d9ec9cd67026 Mon Sep 17 00:00:00 2001 From: hagrid67 <jdhwatson@gmail.com> Date: Thu, 23 May 2019 12:19:54 +0100 Subject: [PATCH] moving TK into PILGL renderer to allow regular window moved test_play.py to use player.py rather than contrived tkplay.py trying to improve logic around gl.show() show param to renderEnv --- examples/play_model.py | 10 +++++----- examples/tkplay.py | 3 +-- flatland/envs/agent_utils.py | 7 ++++++- flatland/utils/editor.py | 10 ++++++++-- flatland/utils/graphics_layer.py | 3 +++ flatland/utils/graphics_pil.py | 32 +++++++++++++++++++++++++++++--- flatland/utils/rendertools.py | 12 ++++++++---- tests/test_player.py | 6 +++--- 8 files changed, 63 insertions(+), 20 deletions(-) diff --git a/examples/play_model.py b/examples/play_model.py index 34c6aadf..7d7ed110 100644 --- a/examples/play_model.py +++ b/examples/play_model.py @@ -63,7 +63,8 @@ class Player(object): # Real Agent # action = self.agent.act(np.array(self.obs[handle]), eps=self.eps) # Random actions - action = random.randint(0, 3) + # action = random.randint(0, 3) + action = np.random.choice([0, 1, 2, 3], 1, p=[0.2, 0.1, 0.6, 0.1])[0] # Numpy version uses single random sequence # action = np.random.randint(0, 4, size=1) self.action_prob[action] += 1 @@ -106,7 +107,7 @@ def max_lt(seq, val): return None -def main(render=True, delay=0.0, n_trials=3, n_steps=50, sGL="QT"): +def main(render=True, delay=0.0, n_trials=3, n_steps=50, sGL="PIL"): random.seed(1) np.random.seed(1) @@ -116,8 +117,7 @@ def main(render=True, delay=0.0, n_trials=3, n_steps=50, sGL="QT"): number_of_agents=5) if render: - # env_renderer = RenderTool(env, gl="QTSVG") - env_renderer = RenderTool(env, gl=sGL) + env_renderer = RenderTool(env, gl=sGL, show=True) oPlayer = Player(env) @@ -159,7 +159,7 @@ def main_old(render=True, delay=0.0): number_of_agents=5) if render: - env_renderer = RenderTool(env, gl="QTSVG") + env_renderer = RenderTool(env, gl="PIL") # env_renderer = RenderTool(env, gl="QT") n_trials = 9999 diff --git a/examples/tkplay.py b/examples/tkplay.py index 95842e3b..05078fad 100644 --- a/examples/tkplay.py +++ b/examples/tkplay.py @@ -9,7 +9,7 @@ from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool -def tkmain(n_trials=2): +def tkmain(n_trials=2, n_steps=50): # This creates the main window of an application window = tk.Tk() window.title("Join") @@ -24,7 +24,6 @@ def tkmain(n_trials=2): oPlayer = Player(env) n_trials = 1 - n_steps = 20 delay = 0 for trials in range(1, n_trials + 1): diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index 05f81e43..db7f9ae0 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -7,6 +7,11 @@ import numpy as np @attrs class EnvDescription(object): + """ EnvDescription - This is a description of a random env, + based around the rail_generator and stats like size and n_agents. + It mirrors the parameters given to the RailEnv constructor. + Not currently used. + """ n_agents = attrib() height = attrib() width = attrib() @@ -16,7 +21,7 @@ class EnvDescription(object): @attrs class EnvAgentStatic(object): - """ TODO: EnvAgentStatic - To store initial position, direction and target. + """ EnvAgentStatic - Stores initial position, direction and target. This is like static data for the environment - it's where an agent starts, rather than where it is at the moment. The target should also be stored here. diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py index 7e813d76..32980d98 100644 --- a/flatland/utils/editor.py +++ b/flatland/utils/editor.py @@ -29,11 +29,15 @@ import jpy_canvas class EditorMVC(object): - def __init__(self, env=None, sGL="MPL"): + """ EditorMVC - a class to encompass and assemble the Jupyter Editor Model-View-Controller. + """ + def __init__(self, env=None, sGL="PIL"): + """ Create an Editor MVC assembly around a railenv, or create one if None. + """ if env is None: env = RailEnv(width=10, height=10, - rail_generator=random_rail_generator(), + rail_generator=empty_rail_generator(), number_of_agents=0, obs_builder_object=TreeObsForRailEnv(max_depth=2)) @@ -47,6 +51,8 @@ class EditorMVC(object): class View(object): + """ The Jupyter Editor View - creates and holds the widgets comprising the Editor. + """ def __init__(self, editor, sGL="MPL"): self.editor = self.model = editor self.sGL = sGL diff --git a/flatland/utils/graphics_layer.py b/flatland/utils/graphics_layer.py index 4cfcc64b..f65d87f0 100644 --- a/flatland/utils/graphics_layer.py +++ b/flatland/utils/graphics_layer.py @@ -7,6 +7,9 @@ class GraphicsLayer(object): def __init__(self): pass + def open_window(self): + pass + def is_raster(self): return True diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index b66c8dc5..949628fc 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -1,6 +1,7 @@ from flatland.utils.graphics_layer import GraphicsLayer -from PIL import Image, ImageDraw # , ImageFont +from PIL import Image, ImageDraw, ImageTk # , ImageFont +import tkinter as tk from numpy import array import numpy as np @@ -26,6 +27,9 @@ class PILGL(GraphicsLayer): self.tColRail = (0, 0, 0) # black rails self.tColGrid = (230,) * 3 # light grey for grid + self.window_open = False + # self.bShow = show + self.firstFrame = True self.beginFrame() def plot(self, gX, gY, color=None, linewidth=3, layer=0, opacity=255, **kwargs): @@ -45,6 +49,13 @@ class PILGL(GraphicsLayer): for x, y in gPoints: self.draws[layer].rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color) + def open_window(self): + assert self.window_open is False, "Window is already open!" + self.window = tk.Tk() + self.window.title("Flatland") + self.window.configure(background='grey') + self.window_open = True + def text(self, *args, **kwargs): pass @@ -59,8 +70,23 @@ class PILGL(GraphicsLayer): self.create_layer(1) def show(self, block=False): - pass - # plt.show(block=block) + img = self.alpha_composite_layers() + + if not self.window_open: + self.open_window() + + tkimg = ImageTk.PhotoImage(img) + + if self.firstFrame: + self.panel = tk.Label(self.window, image=tkimg) + self.panel.pack(side="bottom", fill="both", expand="yes") + else: + # update the image in situ + self.panel.configure(image=tkimg) + self.panel.image = tkimg + + self.window.update() + self.firstFrame = False def pause(self, seconds=0.00001): pass diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 30dccfac..700e3759 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -15,14 +15,15 @@ from flatland.utils.graphics_layer import GraphicsLayer class MPLGL(GraphicsLayer): - def __init__(self, width, height, show=False): + def __init__(self, width, height): self.width = width self.height = height self.yxBase = array([6, 21]) # pixel offset self.nPixCell = 700 / width self.img = None - if show: - plt.figure(figsize=(10, 10)) + + def open_window(self): + plt.figure(figsize=(10, 10)) def plot(self, *args, **kwargs): plt.plot(*args, **kwargs) @@ -126,7 +127,7 @@ class RenderTool(object): # self.gl = MPLGL() if gl == "MPL": - self.gl = MPLGL(env.width, env.height, show=show) + self.gl = MPLGL(env.width, env.height) elif gl == "QT": self.gl = QTGL(env.width, env.height) elif gl == "PIL": @@ -681,6 +682,9 @@ class RenderTool(object): self.gl.show(block=False) # self.gl.endFrame() + if show and type(self.gl) is PILGL: + self.gl.show() + self.gl.pause(0.00001) return diff --git a/tests/test_player.py b/tests/test_player.py index 7b2745f2..a0e580b9 100644 --- a/tests/test_player.py +++ b/tests/test_player.py @@ -1,8 +1,8 @@ -# from examples.play_model import main -from examples.tkplay import tkmain +from examples.play_model import main +# from examples.tkplay import tkmain def test_main(): - tkmain(n_trials=2) + main(render=True, n_steps=20, n_trials=2, sGL="PIL") -- GitLab