diff --git a/examples/play_model.py b/examples/play_model.py index 34c6aadfeefd44771fd335e2957e1fbd0b2f740f..7d7ed1104e7689e4ff49d993476a6f1fc6d4b6e8 100644 --- a/examples/play_model.py +++ b/examples/play_model.py @@ -63,7 +63,8 @@ class Player(object): # Real Agent # action = self.agent.act(np.array(self.obs[handle]), eps=self.eps) # Random actions - action = random.randint(0, 3) + # action = random.randint(0, 3) + action = np.random.choice([0, 1, 2, 3], 1, p=[0.2, 0.1, 0.6, 0.1])[0] # Numpy version uses single random sequence # action = np.random.randint(0, 4, size=1) self.action_prob[action] += 1 @@ -106,7 +107,7 @@ def max_lt(seq, val): return None -def main(render=True, delay=0.0, n_trials=3, n_steps=50, sGL="QT"): +def main(render=True, delay=0.0, n_trials=3, n_steps=50, sGL="PIL"): random.seed(1) np.random.seed(1) @@ -116,8 +117,7 @@ def main(render=True, delay=0.0, n_trials=3, n_steps=50, sGL="QT"): number_of_agents=5) if render: - # env_renderer = RenderTool(env, gl="QTSVG") - env_renderer = RenderTool(env, gl=sGL) + env_renderer = RenderTool(env, gl=sGL, show=True) oPlayer = Player(env) @@ -159,7 +159,7 @@ def main_old(render=True, delay=0.0): number_of_agents=5) if render: - env_renderer = RenderTool(env, gl="QTSVG") + env_renderer = RenderTool(env, gl="PIL") # env_renderer = RenderTool(env, gl="QT") n_trials = 9999 diff --git a/examples/tkplay.py b/examples/tkplay.py index 95842e3b430000169093d27c3c9de02ebe037de9..05078fadc225602dccc30bd5159b50a1ac7a8713 100644 --- a/examples/tkplay.py +++ b/examples/tkplay.py @@ -9,7 +9,7 @@ from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool -def tkmain(n_trials=2): +def tkmain(n_trials=2, n_steps=50): # This creates the main window of an application window = tk.Tk() window.title("Join") @@ -24,7 +24,6 @@ def tkmain(n_trials=2): oPlayer = Player(env) n_trials = 1 - n_steps = 20 delay = 0 for trials in range(1, n_trials + 1): diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index 05f81e43be3f33bfdfc81911d6cf6272bfba2d7e..db7f9ae05483f4b4879e24f93716a384690f6432 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -7,6 +7,11 @@ import numpy as np @attrs class EnvDescription(object): + """ EnvDescription - This is a description of a random env, + based around the rail_generator and stats like size and n_agents. + It mirrors the parameters given to the RailEnv constructor. + Not currently used. + """ n_agents = attrib() height = attrib() width = attrib() @@ -16,7 +21,7 @@ class EnvDescription(object): @attrs class EnvAgentStatic(object): - """ TODO: EnvAgentStatic - To store initial position, direction and target. + """ EnvAgentStatic - Stores initial position, direction and target. This is like static data for the environment - it's where an agent starts, rather than where it is at the moment. The target should also be stored here. diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py index 7e813d763d7f26a96a1e1ca1f4c3d1bceef68ee3..32980d98856f31aef9ec1c53af311618452ae32e 100644 --- a/flatland/utils/editor.py +++ b/flatland/utils/editor.py @@ -29,11 +29,15 @@ import jpy_canvas class EditorMVC(object): - def __init__(self, env=None, sGL="MPL"): + """ EditorMVC - a class to encompass and assemble the Jupyter Editor Model-View-Controller. + """ + def __init__(self, env=None, sGL="PIL"): + """ Create an Editor MVC assembly around a railenv, or create one if None. + """ if env is None: env = RailEnv(width=10, height=10, - rail_generator=random_rail_generator(), + rail_generator=empty_rail_generator(), number_of_agents=0, obs_builder_object=TreeObsForRailEnv(max_depth=2)) @@ -47,6 +51,8 @@ class EditorMVC(object): class View(object): + """ The Jupyter Editor View - creates and holds the widgets comprising the Editor. + """ def __init__(self, editor, sGL="MPL"): self.editor = self.model = editor self.sGL = sGL diff --git a/flatland/utils/graphics_layer.py b/flatland/utils/graphics_layer.py index 4cfcc64bffb82f91a0f36822188db297bc1ed37e..f65d87f0b25d0ef22e14f3893896d6c80d1b4080 100644 --- a/flatland/utils/graphics_layer.py +++ b/flatland/utils/graphics_layer.py @@ -7,6 +7,9 @@ class GraphicsLayer(object): def __init__(self): pass + def open_window(self): + pass + def is_raster(self): return True diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index b66c8dc55f38c321d038306f933de66493a6e6b3..949628fcd093c404ffef8d313dbeea4379d86708 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -1,6 +1,7 @@ from flatland.utils.graphics_layer import GraphicsLayer -from PIL import Image, ImageDraw # , ImageFont +from PIL import Image, ImageDraw, ImageTk # , ImageFont +import tkinter as tk from numpy import array import numpy as np @@ -26,6 +27,9 @@ class PILGL(GraphicsLayer): self.tColRail = (0, 0, 0) # black rails self.tColGrid = (230,) * 3 # light grey for grid + self.window_open = False + # self.bShow = show + self.firstFrame = True self.beginFrame() def plot(self, gX, gY, color=None, linewidth=3, layer=0, opacity=255, **kwargs): @@ -45,6 +49,13 @@ class PILGL(GraphicsLayer): for x, y in gPoints: self.draws[layer].rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color) + def open_window(self): + assert self.window_open is False, "Window is already open!" + self.window = tk.Tk() + self.window.title("Flatland") + self.window.configure(background='grey') + self.window_open = True + def text(self, *args, **kwargs): pass @@ -59,8 +70,23 @@ class PILGL(GraphicsLayer): self.create_layer(1) def show(self, block=False): - pass - # plt.show(block=block) + img = self.alpha_composite_layers() + + if not self.window_open: + self.open_window() + + tkimg = ImageTk.PhotoImage(img) + + if self.firstFrame: + self.panel = tk.Label(self.window, image=tkimg) + self.panel.pack(side="bottom", fill="both", expand="yes") + else: + # update the image in situ + self.panel.configure(image=tkimg) + self.panel.image = tkimg + + self.window.update() + self.firstFrame = False def pause(self, seconds=0.00001): pass diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 30dccfaca0f0d8506c9fed7bf34601275148bd7b..700e375946e910c05b6234a4c0bfb9d1bc67f4cd 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -15,14 +15,15 @@ from flatland.utils.graphics_layer import GraphicsLayer class MPLGL(GraphicsLayer): - def __init__(self, width, height, show=False): + def __init__(self, width, height): self.width = width self.height = height self.yxBase = array([6, 21]) # pixel offset self.nPixCell = 700 / width self.img = None - if show: - plt.figure(figsize=(10, 10)) + + def open_window(self): + plt.figure(figsize=(10, 10)) def plot(self, *args, **kwargs): plt.plot(*args, **kwargs) @@ -126,7 +127,7 @@ class RenderTool(object): # self.gl = MPLGL() if gl == "MPL": - self.gl = MPLGL(env.width, env.height, show=show) + self.gl = MPLGL(env.width, env.height) elif gl == "QT": self.gl = QTGL(env.width, env.height) elif gl == "PIL": @@ -681,6 +682,9 @@ class RenderTool(object): self.gl.show(block=False) # self.gl.endFrame() + if show and type(self.gl) is PILGL: + self.gl.show() + self.gl.pause(0.00001) return diff --git a/tests/test_player.py b/tests/test_player.py index 7b2745f2e372ca80cd2fb5cf9dcaa3db96fb910a..a0e580b92ee04f41ec1bab7c4e99da1339767c96 100644 --- a/tests/test_player.py +++ b/tests/test_player.py @@ -1,8 +1,8 @@ -# from examples.play_model import main -from examples.tkplay import tkmain +from examples.play_model import main +# from examples.tkplay import tkmain def test_main(): - tkmain(n_trials=2) + main(render=True, n_steps=20, n_trials=2, sGL="PIL")