diff --git a/examples/play_model.py b/examples/play_model.py index 5ed6feb7b9b3cc5a5ab6251657864fad7ee96a02..777f2d34254b9dba4c7d0a82e207397a61885d6a 100644 --- a/examples/play_model.py +++ b/examples/play_model.py @@ -100,7 +100,7 @@ def max_lt(seq, val): return None -def main(render=True, delay=0.0, n_trials=3, n_steps=50): +def main(render=True, delay=0.0, n_trials=3, n_steps=50, sGL="QT"): random.seed(1) np.random.seed(1) @@ -111,7 +111,7 @@ def main(render=True, delay=0.0, n_trials=3, n_steps=50): if render: # env_renderer = RenderTool(env, gl="QTSVG") - env_renderer = RenderTool(env, gl="QT") + env_renderer = RenderTool(env, gl=sGL) oPlayer = Player(env) diff --git a/examples/tkplay.py b/examples/tkplay.py new file mode 100644 index 0000000000000000000000000000000000000000..b337253814ebb55994ac5981cb09a5bae008317b --- /dev/null +++ b/examples/tkplay.py @@ -0,0 +1,59 @@ + +import tkinter as tk +from PIL import ImageTk, Image +from examples.play_model import Player +from flatland.envs.rail_env import RailEnv +from flatland.envs.generators import complex_rail_generator +from flatland.utils.rendertools import RenderTool +import time + + +def tkmain(n_trials=2): + # This creates the main window of an application + window = tk.Tk() + window.title("Join") + window.configure(background='grey') + + # Example generate a random rail + env = RailEnv(width=15, height=15, + rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12), + number_of_agents=5) + + env_renderer = RenderTool(env, gl="PIL") + + oPlayer = Player(env) + n_trials = 1 + n_steps = 20 + delay = 0 + for trials in range(1, n_trials + 1): + + # Reset environment8 + oPlayer.reset() + env_renderer.set_new_rail() + + first = True + + for step in range(n_steps): + oPlayer.step() + env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step, + action_dict=oPlayer.action_dict) + img = env_renderer.getImage() + img = Image.fromarray(img) + tkimg = ImageTk.PhotoImage(img) + + if first: + panel = tk.Label(window, image = tkimg) + panel.pack(side = "bottom", fill = "both", expand = "yes") + else: + # update the image in situ + panel.configure(image=tkimg) + panel.image = tkimg + + window.update() + if delay > 0: + time.sleep(delay) + first = False + + +if __name__ == "__main__": + tkmain() \ No newline at end of file diff --git a/flatland/utils/graphics_layer.py b/flatland/utils/graphics_layer.py index 4cfcc64bffb82f91a0f36822188db297bc1ed37e..40bf319da737c8bb49e6c228c98b70604734ddc4 100644 --- a/flatland/utils/graphics_layer.py +++ b/flatland/utils/graphics_layer.py @@ -51,7 +51,7 @@ class GraphicsLayer(object): elif type(color) is tuple: if type(color[0]) is not int: gcolor = array(color) - color = tuple((gcolor[:3] * 255).astype(int)) + color = tuple((gcolor[:4] * 255).astype(int)) else: color = self.tColGrid diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index 41516fd94737556a4b8abbc7ccfce0fd503a3d6e..ceda3fd9f4c7d16b8a516dca8af43ea3122d51c7 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -18,28 +18,32 @@ class PILGL(GraphicsLayer): # Total grid size at native scale self.widthPx = self.width * self.nPixCell + self.linewidth self.heightPx = self.height * self.nPixCell + self.linewidth - self.beginFrame() + self.layers = [] + self.draws = [] self.tColBg = (255, 255, 255) # white background # self.tColBg = (220, 120, 40) # background color self.tColRail = (0, 0, 0) # black rails self.tColGrid = (230,) * 3 # light grey for grid - def plot(self, gX, gY, color=None, linewidth=3, **kwargs): - color = self.adaptColor(color) + self.beginFrame() - # print(gX, gY) + def plot(self, gX, gY, color=None, linewidth=3, layer=0, opacity=255, **kwargs): + color = self.adaptColor(color) + if len(color) == 3: + color += (opacity,) + elif len(color) == 4: + color = color[:3] + (opacity,) gPoints = np.stack([array(gX), -array(gY)]).T * self.nPixCell gPoints = list(gPoints.ravel()) - # print(gPoints, color) - self.draw.line(gPoints, fill=color, width=self.linewidth) + self.draws[layer].line(gPoints, fill=color, width=self.linewidth) - def scatter(self, gX, gY, color=None, marker="o", s=50, *args, **kwargs): + def scatter(self, gX, gY, color=None, marker="o", s=50, layer=0, opacity=255, *args, **kwargs): color = self.adaptColor(color) r = np.sqrt(s) gPoints = np.stack([np.atleast_1d(gX), -np.atleast_1d(gY)]).T * self.nPixCell for x, y in gPoints: - self.draw.rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color) + self.draws[layer].rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color) def text(self, *args, **kwargs): pass @@ -51,8 +55,8 @@ class PILGL(GraphicsLayer): pass def beginFrame(self): - self.img = Image.new("RGBA", (self.widthPx, self.heightPx), (255, 255, 255, 255)) - self.draw = ImageDraw.Draw(self.img) + self.create_layer(0) + self.create_layer(1) def show(self, block=False): pass @@ -62,5 +66,35 @@ class PILGL(GraphicsLayer): pass # plt.pause(seconds) + def alpha_composite_layers(self): + img = self.layers[0] + for img2 in self.layers[1:]: + img = Image.alpha_composite(img, img2) + return img + def getImage(self): - return array(self.img) + """ return a blended / alpha composited image composed of all the layers, + with layer 0 at the "back". + """ + img = self.alpha_composite_layers() + return array(img) + + def create_image(self, opacity=255): + img = Image.new("RGBA", (self.widthPx, self.heightPx), (255, 255, 255, opacity)) + return img + + def create_layer(self, iLayer=0): + if len(self.layers) <= iLayer: + for i in range(len(self.layers), iLayer+1): + if i==0: + opacity = 255 # "bottom" layer is opaque (for rails) + else: + opacity = 0 # subsequent layers are transparent + img = self.create_image(opacity) + self.layers.append(img) + self.draws.append(ImageDraw.Draw(img)) + else: + opacity = 0 if iLayer > 0 else 255 + self.layers[iLayer] = img = self.create_image(opacity) + self.draws[iLayer] = ImageDraw.Draw(img) + diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index e9a94ec1fd18635b69d5f4bc47ce1d4f43daffcd..4def953a02b1154ff36827d0d4be5e023b6d1146 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -15,12 +15,14 @@ from flatland.utils.graphics_layer import GraphicsLayer class MPLGL(GraphicsLayer): - def __init__(self, width, height): + def __init__(self, width, height, show=False): self.width = width self.height = height self.yxBase = array([6, 21]) # pixel offset self.nPixCell = 700 / width self.img = None + if show: + plt.figure(figsize=(10, 10)) def plot(self, *args, **kwargs): plt.plot(*args, **kwargs) @@ -70,6 +72,7 @@ class MPLGL(GraphicsLayer): def beginFrame(self): self.img = None plt.figure(figsize=(10, 10)) + plt.clf() pass def endFrame(self): @@ -115,7 +118,7 @@ class RenderTool(object): gTheta = np.linspace(0, np.pi / 2, 5) gArc = array([np.cos(gTheta), np.sin(gTheta)]).T # from [1,0] to [0,1] - def __init__(self, env, gl="MPL"): + def __init__(self, env, gl="MPL", show=False): self.env = env self.iFrame = 0 self.time1 = time.time() @@ -123,7 +126,7 @@ class RenderTool(object): # self.gl = MPLGL() if gl == "MPL": - self.gl = MPLGL(env.width, env.height) + self.gl = MPLGL(env.width, env.height, show=show) elif gl == "QT": self.gl = QTGL(env.width, env.height) elif gl == "PIL": @@ -219,17 +222,19 @@ class RenderTool(object): if static: color = self.gl.adaptColor(color, lighten=True) + color = color + # print("Agent:", rcPos, iDir, rcDir, xyDir, xyPos) - self.gl.scatter(*xyPos, color=color, marker="o", s=100) # agent location + self.gl.scatter(*xyPos, color=color, layer=1, marker="o", s=100) # agent location xyDirLine = array([xyPos, xyPos + xyDir / 2]).T # line for agent orient. - self.gl.plot(*xyDirLine, color=color, lw=5, ms=0, alpha=0.6) + self.gl.plot(*xyDirLine, color=color, layer=1, lw=5, ms=0, alpha=0.6) if selected: self._draw_square(xyPos, 1, color) if target is not None: rcTarget = array(target) xyTarget = np.matmul(rcTarget, rt.grc2xy) + rt.xyHalf - self._draw_square(xyTarget, 1 / 3, color) + self._draw_square(xyTarget, 1 / 3, color, layer=1) def plotTrans(self, rcPos, gTransRCAg, color="r", depth=None): """ @@ -397,6 +402,13 @@ class RenderTool(object): visit = visit.prev xyPrev = xy + def drawTrans(self, oFrom, oTo, sColor="gray"): + self.gl.plot( + [oFrom[0], oTo[0]], # x + [oFrom[1], oTo[1]], # y + color=sColor + ) + def drawTrans2( self, xyLine, xyCentre, @@ -489,47 +501,14 @@ class RenderTool(object): for visited_cell in observation_dict[agent]: cell_coord = array(visited_cell[:2]) cell_coord_trans = np.matmul(cell_coord, rt.grc2xy) + rt.xyHalf - self._draw_square(cell_coord_trans, 1 / 3, color) - - def renderEnv( - self, show=False, curves=True, spacing=False, - arrows=False, agents=True, obsrender=True, sRailColor="gray", frames=False, iEpisode=None, iStep=None, - iSelectedAgent=None, action_dict=None): - """ - Draw the environment using matplotlib. - Draw into the figure if provided. - - Call pyplot.show() if show==True. - (Use show=False from a Jupyter notebook with %matplotlib inline) - """ - - if not self.gl.is_raster(): - self.renderEnv2(show, curves, spacing, - arrows, agents, sRailColor, - frames, iEpisode, iStep, - iSelectedAgent, action_dict) - return - - # cell_size is a bit pointless with matplotlib - it does not relate to pixels, - # so for now I've changed it to 1 (from 10) - cell_size = 1 - self.gl.beginFrame() + self._draw_square(cell_coord_trans, 1 / (agent+1.1), color, layer=1, opacity=100) - # self.gl.clf() - # if oFigure is None: - # oFigure = self.gl.figure() - def drawTrans(oFrom, oTo, sColor="gray"): - self.gl.plot( - [oFrom[0], oTo[0]], # x - [oFrom[1], oTo[1]], # y - color=sColor - ) + def renderRail(self, spacing=False, sRailColor="gray", curves=True, arrows=False): + cell_size = 1 # TODO: remove cell_size env = self.env - # t1 = time.time() - # Draw cells grid grid_color = [0.95, 0.95, 0.95] for r in range(env.height + 1): @@ -613,7 +592,7 @@ class RenderTool(object): rotation, spacing=spacing, bArrow=arrows, sColor=sRailColor) else: - drawTrans(from_xy, to_xy, sRailColor) + self.drawTrans(self, from_xy, to_xy, sRailColor) if False: print( @@ -626,6 +605,54 @@ class RenderTool(object): "rot:", rotation, ) + + + def renderEnv( + self, show=False, curves=True, spacing=False, + arrows=False, agents=True, obsrender=True, sRailColor="gray", frames=False, + iEpisode=None, iStep=None, + iSelectedAgent=None, action_dict=None): + """ + Draw the environment using matplotlib. + Draw into the figure if provided. + + Call pyplot.show() if show==True. + (Use show=False from a Jupyter notebook with %matplotlib inline) + """ + + if not self.gl.is_raster(): + self.renderEnv2(show, curves, spacing, + arrows, agents, sRailColor, + frames, iEpisode, iStep, + iSelectedAgent, action_dict) + return + + # cell_size is a bit pointless with matplotlib - it does not relate to pixels, + # so for now I've changed it to 1 (from 10) + cell_size = 1 + + if type(self.gl) in (QTGL, PILGL): + self.gl.beginFrame() + + if type(self.gl) is MPLGL: + #self.gl.clf() + # plt.clf() + self.gl.beginFrame() + pass + + # self.gl.clf() + # if oFigure is None: + # oFigure = self.gl.figure() + + + + env = self.env + + # t1 = time.time() + + + self.renderRail() + # Draw each agent + its orientation + its target if agents: self.plotAgents(targets=True, iSelectedAgent=iSelectedAgent) @@ -657,23 +684,26 @@ class RenderTool(object): # TODO: for MPL, we don't want to call clf (called by endframe) # for QT, we need to call endFrame() # if not show: - self.gl.endFrame() + if type(self.gl) is QTGL: + self.gl.endFrame() + if show: + self.gl.show(block=False) - # t2 = time.time() - # print(t2 - t1, "seconds") + if type(self.gl) is MPLGL: + if show: + self.gl.show(block=False) + #self.gl.endFrame() - if show: - self.gl.show(block=False) - self.gl.pause(0.00001) + self.gl.pause(0.00001) return - def _draw_square(self, center, size, color): + def _draw_square(self, center, size, color, opacity=255, layer=0): x0 = center[0] - size / 2 x1 = center[0] + size / 2 y0 = center[1] - size / 2 y1 = center[1] + size / 2 - self.gl.plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], color=color) + self.gl.plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], color=color, layer=layer, opacity=opacity) def getImage(self): return self.gl.getImage() diff --git a/images/basic-env.npz b/images/basic-env.npz index 356da5d70146b3b8081dd99c0fe5e6bd70646e53..8ffaf023e1116b0c92702212ddb04c71b82f0655 100644 Binary files a/images/basic-env.npz and b/images/basic-env.npz differ diff --git a/tests/test_player.py b/tests/test_player.py index 82f6f67f0fe29af9f7abad433e7c5ddb9df1012f..7b2745f2e372ca80cd2fb5cf9dcaa3db96fb910a 100644 --- a/tests/test_player.py +++ b/tests/test_player.py @@ -1,7 +1,8 @@ -from examples.play_model import main +# from examples.play_model import main +from examples.tkplay import tkmain def test_main(): - main(n_trials=2) + tkmain(n_trials=2) diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py index 3259ed387f1a09b7a6a5e73fe7976cc78f02f32f..7d3ddf63226b55217edd28bf25aed0307377433b 100644 --- a/tests/test_rendertools.py +++ b/tests/test_rendertools.py @@ -46,8 +46,8 @@ def test_render_env(save_new_images=False): ) sfTestEnv = "env-data/tests/test1.npy" oEnv.rail.load_transition_map(sfTestEnv) - oRT = rt.RenderTool(oEnv) - oRT.renderEnv() + oRT = rt.RenderTool(oEnv, gl="PIL", show=False) + oRT.renderEnv(show=False) checkFrozenImage(oRT, "basic-env.npz", resave=save_new_images)