diff --git a/examples/play_model.py b/examples/play_model.py index 7d7ed1104e7689e4ff49d993476a6f1fc6d4b6e8..08dadeb1d3d1e64b012c00e99594bcefd2905d1c 100644 --- a/examples/play_model.py +++ b/examples/play_model.py @@ -107,7 +107,7 @@ def max_lt(seq, val): return None -def main(render=True, delay=0.0, n_trials=3, n_steps=50, sGL="PIL"): +def main(render=True, delay=0.0, n_trials=3, n_steps=50, sGL="PILSVG"): random.seed(1) np.random.seed(1) @@ -277,4 +277,4 @@ def main_old(render=True, delay=0.0): if __name__ == "__main__": - main(render=True, delay=0) + main(render=True, delay=0.5) diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index 949628fcd093c404ffef8d313dbeea4379d86708..25ee27e56420f55308ca9a7047737219337833af 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -4,6 +4,12 @@ from PIL import Image, ImageDraw, ImageTk # , ImageFont import tkinter as tk from numpy import array import numpy as np +# from flatland.utils.svg import Track, Zug +import time +import io +from cairosvg import svg2png +from flatland.core.transitions import RailEnvTransitions +# from copy import copy class PILGL(GraphicsLayer): @@ -11,6 +17,7 @@ class PILGL(GraphicsLayer): self.nPixCell = 60 self.yxBase = (0, 0) self.linewidth = 4 + self.nAgentColors = 1 # overridden in loadAgent # self.tile_size = self.nPixCell self.width = width @@ -30,6 +37,7 @@ class PILGL(GraphicsLayer): self.window_open = False # self.bShow = show self.firstFrame = True + self.create_layers() self.beginFrame() def plot(self, gX, gY, color=None, linewidth=3, layer=0, opacity=255, **kwargs): @@ -49,6 +57,20 @@ class PILGL(GraphicsLayer): for x, y in gPoints: self.draws[layer].rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color) + def drawImageXY(self, pil_img, xyPixLeftTop, layer=0): + # self.layers[layer].alpha_composite(pil_img, offset=xyPixLeftTop) + if (pil_img.mode == "RGBA"): + pil_mask = pil_img + else: + pil_mask = None + # print(pil_img, pil_img.mode, xyPixLeftTop, layer) + + self.layers[layer].paste(pil_img, xyPixLeftTop, pil_mask) + + def drawImageRC(self, pil_img, rcTopLeft, layer=0): + xyPixLeftTop = tuple((array(rcTopLeft) * self.nPixCell)[[1, 0]]) + self.drawImageXY(pil_img, xyPixLeftTop, layer=layer) + def open_window(self): assert self.window_open is False, "Window is already open!" self.window = tk.Tk() @@ -66,8 +88,8 @@ class PILGL(GraphicsLayer): pass def beginFrame(self): - self.create_layer(0) - self.create_layer(1) + # Create a new agent layer + self.create_layer(iLayer=1, clear=True) def show(self, block=False): img = self.alpha_composite_layers() @@ -78,6 +100,7 @@ class PILGL(GraphicsLayer): tkimg = ImageTk.PhotoImage(img) if self.firstFrame: + # Do TK actions for a new panel (not sure what they really do) self.panel = tk.Label(self.window, image=tkimg) self.panel.pack(side="bottom", fill="both", expand="yes") else: @@ -109,7 +132,8 @@ class PILGL(GraphicsLayer): img = Image.new("RGBA", (self.widthPx, self.heightPx), (255, 255, 255, opacity)) return img - def create_layer(self, iLayer=0): + def create_layer(self, iLayer=0, clear=True): + # If we don't have the layers already, create them if len(self.layers) <= iLayer: for i in range(len(self.layers), iLayer+1): if i == 0: @@ -120,7 +144,216 @@ class PILGL(GraphicsLayer): self.layers.append(img) self.draws.append(ImageDraw.Draw(img)) else: - opacity = 0 if iLayer > 0 else 255 - self.layers[iLayer] = img = self.create_image(opacity) - self.draws[iLayer] = ImageDraw.Draw(img) + # We do already have this iLayer. Clear it if requested. + if clear: + opacity = 0 if iLayer > 0 else 255 + self.layers[iLayer] = img = self.create_image(opacity) + # We also need to maintain a Draw object for each layer + self.draws[iLayer] = ImageDraw.Draw(img) + + def create_layers(self, clear=True): + self.create_layer(0, clear=clear) + self.create_layer(1, clear=clear) + + +class PILSVG(PILGL): + def __init__(self, width, height): + print(self, type(self)) + oSuper = super() + print(oSuper, type(oSuper)) + oSuper.__init__(width, height) + + # self.track = self.track = Track() + # self.lwTrack = [] + # self.zug = Zug() + + self.lwAgents = [] + self.agents_prev = [] + + self.loadRailSVGs() + self.loadAgentSVGs() + + def is_raster(self): + return False + + def processEvents(self): + # self.app.processEvents() + time.sleep(0.001) + + def clear_rails(self): + print("Clear rails") + self.create_layers() + self.clear_agents() + + def clear_agents(self): + # print("Clear Agents: ", len(self.lwAgents)) + for wAgent in self.lwAgents: + self.layout.removeWidget(wAgent) + self.lwAgents = [] + self.agents_prev = [] + + def pilFromSvgFile(self, sfPath): + with open(sfPath, "r") as fIn: + bytesPNG = svg2png(file_obj=fIn, output_height=self.nPixCell, output_width=self.nPixCell) + + with io.BytesIO(bytesPNG) as fIn: + pil_img = Image.open(fIn) + pil_img.load() + # print(pil_img.mode) + + return pil_img + + def pilFromSvgBytes(self, bytesSVG): + bytesPNG = svg2png(bytesSVG, output_height=self.nPixCell, output_width=self.nPixCell) + with io.BytesIO(bytesPNG) as fIn: + pil_img = Image.open(fIn) + return pil_img + + def loadRailSVGs(self): + """ Load the rail SVG images, apply rotations, and store as PIL images. + """ + dFiles = { + "": "Background_#91D1DD.svg", + "WE": "Gleis_Deadend.svg", + "WW EE NN SS": "Gleis_Diamond_Crossing.svg", + "WW EE": "Gleis_horizontal.svg", + "EN SW": "Gleis_Kurve_oben_links.svg", + "WN SE": "Gleis_Kurve_oben_rechts.svg", + "ES NW": "Gleis_Kurve_unten_links.svg", + "NE WS": "Gleis_Kurve_unten_rechts.svg", + "NN SS": "Gleis_vertikal.svg", + "NN SS EE WW ES NW SE WN": "Weiche_Double_Slip.svg", + "EE WW EN SW": "Weiche_horizontal_oben_links.svg", + "EE WW SE WN": "Weiche_horizontal_oben_rechts.svg", + "EE WW ES NW": "Weiche_horizontal_unten_links.svg", + "EE WW NE WS": "Weiche_horizontal_unten_rechts.svg", + "NN SS EE WW NW ES": "Weiche_Single_Slip.svg", + "NE NW ES WS": "Weiche_Symetrical.svg", + "NN SS EN SW": "Weiche_vertikal_oben_links.svg", + "NN SS SE WN": "Weiche_vertikal_oben_rechts.svg", + "NN SS NW ES": "Weiche_vertikal_unten_links.svg", + "NN SS NE WS": "Weiche_vertikal_unten_rechts.svg"} + + self.dPil = {} + + transitions = RailEnvTransitions() + + lDirs = list("NESW") + + # svgBG = SVG("./svg/Background_#91D1DD.svg") + + for sTrans, sFile in dFiles.items(): + sPathSvg = "./svg/" + sFile + + # Translate the ascii transition descption in the format "NE WS" to the + # binary list of transitions as per RailEnv - NESW (in) x NESW (out) + lTrans16 = ["0"] * 16 + for sTran in sTrans.split(" "): + if len(sTran) == 2: + iDirIn = lDirs.index(sTran[0]) + iDirOut = lDirs.index(sTran[1]) + iTrans = 4 * iDirIn + iDirOut + lTrans16[iTrans] = "1" + sTrans16 = "".join(lTrans16) + binTrans = int(sTrans16, 2) + print(sTrans, sTrans16, sFile) + + # Merge the transition svg image with the background colour. + # This is a shortcut / hack and will need re-working. + # if binTrans > 0: + # svg = svg.merge(svgBG) + + pilRail = self.pilFromSvgFile(sPathSvg) + self.dPil[binTrans] = pilRail + + # Rotate both the transition binary and the image and save in the dict + for nRot in [90, 180, 270]: + binTrans2 = transitions.rotate_transition(binTrans, nRot) + + # PIL rotates anticlockwise for positive theta + pilRail2 = pilRail.rotate(-nRot) + self.dPil[binTrans2] = pilRail2 + + def setRailAt(self, row, col, binTrans): + if binTrans in self.dPil: + pilTrack = self.dPil[binTrans] + self.drawImageRC(pilTrack, (row, col)) + else: + print("Illegal rail:", row, col, format(binTrans, "#018b")[2:]) + + def rgb_s2i(self, sRGB): + """ convert a hex RGB string like 0091ea to 3-tuple of ints """ + return tuple(int(sRGB[iRGB * 2:iRGB * 2 + 2], 16) for iRGB in [0, 1, 2]) + + def loadAgentSVGs(self): + + # Seed initial train/zug files indexed by tuple(iDirIn, iDirOut): + dDirsFile = { + (0, 0): "svg/Zug_Gleis_#0091ea.svg", + (1, 2): "svg/Zug_1_Weiche_#0091ea.svg", + (0, 3): "svg/Zug_2_Weiche_#0091ea.svg" + } + + sColors = "d50000#c51162#aa00ff#6200ea#304ffe#2962ff#0091ea#00b8d4#00bfa5#00c853" + \ + "#64dd17#aeea00#ffd600#ffab00#ff6d00#ff3d00#5d4037#455a64" + lColors = sColors.split("#") + self.nAgentColors = len(lColors) + + # "paint" color of the train images we load + a_base_color = self.rgb_s2i("0091ea") + + self.dPilZug = {} + + for tDirs, sPathSvg in dDirsFile.items(): + iDirIn, iDirOut = tDirs + + pilZug = self.pilFromSvgFile(sPathSvg) + + # Rotate both the directions and the image and save in the dict + for iDirRot in range(4): + nDegRot = iDirRot * 90 + iDirIn2 = (iDirIn + iDirRot) % 4 + iDirOut2 = (iDirOut + iDirRot) % 4 + + # PIL rotates anticlockwise for positive theta + pilZug2 = pilZug.rotate(-nDegRot) + rgbaZug2 = array(pilZug2) + + for iColor, sColor in enumerate(lColors): + tnNewColor = self.rgb_s2i(sColor) + xy_color_mask = np.all(rgbaZug2[:, :, 0:3] - a_base_color == 0, axis=2) + rgbaZug3 = np.copy(rgbaZug2) + rgbaZug3[xy_color_mask, 0:3] = tnNewColor + self.dPilZug[(iDirIn2, iDirOut2, iColor)] = Image.fromarray(rgbaZug3) + + def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut, color=None): + delta_dir = (iDirOut - iDirIn) % 4 + iColor = iAgent % self.nAgentColors + # when flipping direction at a dead end, use the "iDirOut" direction. + if delta_dir == 2: + iDirIn = iDirOut + pilZug = self.dPilZug[(iDirIn % 4, iDirOut % 4, iColor)] + self.drawImageRC(pilZug, (row, col), layer=1) + + +def main2(): + gl = PILSVG(10, 10) + for i in range(10): + gl.beginFrame() + gl.plot([3 + i, 4], [-4 - i, -5], color="r") + gl.endFrame() + time.sleep(1) + + +def main(): + gl = PILSVG(width=10, height=10) + + for i in range(1000): + gl.processEvents() + time.sleep(0.1) + time.sleep(1) + + +if __name__ == "__main__": + main() diff --git a/flatland/utils/render_qt.py b/flatland/utils/render_qt.py index 73b8ca77a33042bf181097d4b1a0a1afcb48b56e..a8ccc780165fe4bcbf5908969f9284906b2571fe 100644 --- a/flatland/utils/render_qt.py +++ b/flatland/utils/render_qt.py @@ -16,10 +16,11 @@ def transform_string_svg(sSVG): bySVG = bytearray(sSVG, encoding='utf-8') return bySVG + def create_QtSvgWidget_from_svg_string(sSVG): svgWidget = QtSvg.QSvgWidget() ret = svgWidget.renderer().load(transform_string_svg(sSVG)) - if ret == False: + if ret is False: print("create_QtSvgWidget_from_svg_string : failed to parse:", sSVG) return svgWidget @@ -132,26 +133,7 @@ class QTSVG(GraphicsLayer): self.lwAgents = [] self.agents_prev = [] - svgWidget = None - - iArt = 0 - iCol = 0 - iRow = 0 - nCols = 10 - - if False: - for binTrans in self.track.dSvg.keys(): - sSVG = self.track.dSvg[binTrans].to_string() - self.layout.addWidget(create_QtSvgWidget_from_svg_string(sSVG), iRow, iCol) - - iArt += 1 - iRow = int(iArt / nCols) - iCol = iArt % nCols - - svgWidget2 = QtSvg.QSvgWidget() - svgWidget2.renderer().load(bySVG) - - self.layout.addWidget(svgWidget2, 0, 0) + # svgWidget = None def is_raster(self): return False diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 1aa1748f6b61129509584c980a73086b87300d4e..9b99d1ecffb7bf0974bc70a63ed89ff687e7434e 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -3,14 +3,12 @@ from collections import deque # import xarray as xr import matplotlib.pyplot as plt -import numpy as np -from numpy import array -from recordtype import recordtype - -from flatland.utils.graphics_layer import GraphicsLayer -from flatland.utils.graphics_pil import PILGL from flatland.utils.render_qt import QTGL, QTSVG - +from flatland.utils.graphics_pil import PILGL, PILSVG +from flatland.utils.graphics_layer import GraphicsLayer +import recordtype +from numpy import array +import numpy as np # TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv! @@ -133,6 +131,8 @@ class RenderTool(object): self.gl = QTGL(env.width, env.height) elif gl == "PIL": self.gl = PILGL(env.width, env.height) + elif gl == "PILSVG": + self.gl = PILSVG(env.width, env.height) elif gl == "QTSVG": self.gl = QTSVG(env.width, env.height) @@ -618,11 +618,11 @@ class RenderTool(object): """ if not self.gl.is_raster(): - self.renderEnv2(show, curves, spacing, - arrows, agents, renderobs,sRailColor, - frames, iEpisode, iStep, - iSelectedAgent, action_dict) - + self.renderEnv2(show=show, curves=curves, spacing=spacing, + arrows=arrows, agents=agents, show_observations=show_observations, + sRailColor=sRailColor, + frames=frames, iEpisode=iEpisode, iStep=iStep, + iSelectedAgent=iSelectedAgent, action_dict=action_dict) return if type(self.gl) in (QTGL, PILGL): @@ -728,9 +728,11 @@ class RenderTool(object): gP0 = array([gX1, gY1, gZ1]) - def renderEnv2(self, show=False, curves=True, spacing=False, arrows=False, agents=True, renderobs=True, - sRailColor="gray", frames=False, iEpisode=None, iStep=None, iSelectedAgent=None, - action_dict=dict()): + def renderEnv2( + self, show=False, curves=True, spacing=False, arrows=False, agents=True, + show_observations=True, sRailColor="gray", + frames=False, iEpisode=None, iStep=None, iSelectedAgent=None, + action_dict=dict()): """ Draw the environment using matplotlib. Draw into the figure if provided. @@ -741,6 +743,8 @@ class RenderTool(object): env = self.env + self.gl.beginFrame() + if self.new_rail: self.new_rail = False self.gl.clear_rails() @@ -766,7 +770,6 @@ class RenderTool(object): iAction = action_dict[iAgent] new_direction, action_isValid = self.env.check_action(agent, iAction) - # ** TODO *** # why should we only update if the action is valid ? if True: @@ -779,7 +782,8 @@ class RenderTool(object): else: self.gl.setAgentAt(iAgent, *agent.position, agent.direction, new_direction, color=oColor) - self.gl.show() + if show: + self.gl.show() for i in range(3): self.gl.processEvents() diff --git a/flatland/utils/svg.py b/flatland/utils/svg.py index fb8b987cae83f4b8a3696e0e9496659045709c17..89730af2c41a442e65d8c5976f24b1ea84468330 100644 --- a/flatland/utils/svg.py +++ b/flatland/utils/svg.py @@ -104,6 +104,13 @@ class Zug(object): class Track(object): + """ Class to load and hold SVG track images. + Creates a mapping between + - cell entry and exit directions (ie transitions), and + - specific images provided by the SBB graphic artist. + The directions and images are also rotated by 90, 180 & 270 degrees. + (There is some redundancy in this process, given the images provided) + """ def __init__(self): dFiles = { "": "Background_#9CCB89.svg", @@ -138,6 +145,8 @@ class Track(object): for sTrans, sFile in dFiles.items(): svg = SVG("./svg/" + sFile) + # Translate the ascii transition descption in the format "NE WS" to the + # binary list of transitions as per RailEnv - NESW (in) x NESW (out) lTrans16 = ["0"] * 16 for sTran in sTrans.split(" "): if len(sTran) == 2: @@ -149,11 +158,14 @@ class Track(object): binTrans = int(sTrans16, 2) print(sTrans, sTrans16, sFile) + # Merge the transition svg image with the background colour. + # This is a shortcut / hack and will need re-working. if binTrans > 0: svg = svg.merge(svgBG) self.dSvg[binTrans] = svg + # Rotate both the transition binary and the image and save in the dict for nRot in [90, 180, 270]: binTrans2 = transitions.rotate_transition(binTrans, nRot) svg2 = svg.copy() diff --git a/images/basic-env.npz b/images/basic-env.npz index 8ffaf023e1116b0c92702212ddb04c71b82f0655..e645113154b9e953575a12dcbadf4f9f3195b4ad 100644 Binary files a/images/basic-env.npz and b/images/basic-env.npz differ diff --git a/tests/test_player.py b/tests/test_player.py index a0e580b92ee04f41ec1bab7c4e99da1339767c96..668afb96a502ce5c25fa1e6d6fe9383a1facbf87 100644 --- a/tests/test_player.py +++ b/tests/test_player.py @@ -4,5 +4,9 @@ from examples.play_model import main def test_main(): - main(render=True, n_steps=20, n_trials=2, sGL="PIL") + main(render=True, n_steps=20, n_trials=2, sGL="PILSVG") + # main(render=True, n_steps=20, n_trials=2, sGL="PIL") + +if __name__ == "__main__": + test_main() diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py index 8204a305328df746a772d034f3c763c848cceb93..1bfce0323322638447ed11191e7d5d9bbea565b5 100644 --- a/tests/test_rendertools.py +++ b/tests/test_rendertools.py @@ -12,6 +12,7 @@ import numpy as np import flatland.utils.rendertools as rt from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv, random_rail_generator +from flatland.envs.generators import empty_rail_generator def checkFrozenImage(oRT, sFileImage, resave=False): @@ -39,14 +40,15 @@ def test_render_env(save_new_images=False): # random.seed(100) np.random.seed(100) oEnv = RailEnv(width=10, height=10, - rail_generator=random_rail_generator(), + # rail_generator=random_rail_generator(), + rail_generator=empty_rail_generator(), number_of_agents=0, # obs_builder_object=GlobalObsForRailEnv()) obs_builder_object=TreeObsForRailEnv(max_depth=2) ) sfTestEnv = "env-data/tests/test1.npy" oEnv.rail.load_transition_map(sfTestEnv) - oRT = rt.RenderTool(oEnv, gl="PIL", show=False) + oRT = rt.RenderTool(oEnv, gl="PILSVG", show=False) oRT.renderEnv(show=False) checkFrozenImage(oRT, "basic-env.npz", resave=save_new_images) @@ -82,6 +84,7 @@ def main(): test_render_env(save_new_images=True) else: print("Run 'python test_rendertools.py save' to regenerate images") + test_render_env() if __name__ == "__main__":