Skip to content
Snippets Groups Projects
Commit 68fde32f authored by hagrid67's avatar hagrid67
Browse files

SVG with PIL

parent dcc27491
No related branches found
No related tags found
No related merge requests found
...@@ -107,7 +107,7 @@ def max_lt(seq, val): ...@@ -107,7 +107,7 @@ def max_lt(seq, val):
return None return None
def main(render=True, delay=0.0, n_trials=3, n_steps=50, sGL="PIL"): def main(render=True, delay=0.0, n_trials=3, n_steps=50, sGL="PILSVG"):
random.seed(1) random.seed(1)
np.random.seed(1) np.random.seed(1)
...@@ -277,4 +277,4 @@ def main_old(render=True, delay=0.0): ...@@ -277,4 +277,4 @@ def main_old(render=True, delay=0.0):
if __name__ == "__main__": if __name__ == "__main__":
main(render=True, delay=0) main(render=True, delay=0.5)
...@@ -4,6 +4,12 @@ from PIL import Image, ImageDraw, ImageTk # , ImageFont ...@@ -4,6 +4,12 @@ from PIL import Image, ImageDraw, ImageTk # , ImageFont
import tkinter as tk import tkinter as tk
from numpy import array from numpy import array
import numpy as np import numpy as np
# from flatland.utils.svg import Track, Zug
import time
import io
from cairosvg import svg2png
from flatland.core.transitions import RailEnvTransitions
# from copy import copy
class PILGL(GraphicsLayer): class PILGL(GraphicsLayer):
...@@ -11,6 +17,7 @@ class PILGL(GraphicsLayer): ...@@ -11,6 +17,7 @@ class PILGL(GraphicsLayer):
self.nPixCell = 60 self.nPixCell = 60
self.yxBase = (0, 0) self.yxBase = (0, 0)
self.linewidth = 4 self.linewidth = 4
self.nAgentColors = 1 # overridden in loadAgent
# self.tile_size = self.nPixCell # self.tile_size = self.nPixCell
self.width = width self.width = width
...@@ -30,6 +37,7 @@ class PILGL(GraphicsLayer): ...@@ -30,6 +37,7 @@ class PILGL(GraphicsLayer):
self.window_open = False self.window_open = False
# self.bShow = show # self.bShow = show
self.firstFrame = True self.firstFrame = True
self.create_layers()
self.beginFrame() self.beginFrame()
def plot(self, gX, gY, color=None, linewidth=3, layer=0, opacity=255, **kwargs): def plot(self, gX, gY, color=None, linewidth=3, layer=0, opacity=255, **kwargs):
...@@ -49,6 +57,20 @@ class PILGL(GraphicsLayer): ...@@ -49,6 +57,20 @@ class PILGL(GraphicsLayer):
for x, y in gPoints: for x, y in gPoints:
self.draws[layer].rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color) self.draws[layer].rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color)
def drawImageXY(self, pil_img, xyPixLeftTop, layer=0):
# self.layers[layer].alpha_composite(pil_img, offset=xyPixLeftTop)
if (pil_img.mode == "RGBA"):
pil_mask = pil_img
else:
pil_mask = None
# print(pil_img, pil_img.mode, xyPixLeftTop, layer)
self.layers[layer].paste(pil_img, xyPixLeftTop, pil_mask)
def drawImageRC(self, pil_img, rcTopLeft, layer=0):
xyPixLeftTop = tuple((array(rcTopLeft) * self.nPixCell)[[1, 0]])
self.drawImageXY(pil_img, xyPixLeftTop, layer=layer)
def open_window(self): def open_window(self):
assert self.window_open is False, "Window is already open!" assert self.window_open is False, "Window is already open!"
self.window = tk.Tk() self.window = tk.Tk()
...@@ -66,8 +88,8 @@ class PILGL(GraphicsLayer): ...@@ -66,8 +88,8 @@ class PILGL(GraphicsLayer):
pass pass
def beginFrame(self): def beginFrame(self):
self.create_layer(0) # Create a new agent layer
self.create_layer(1) self.create_layer(iLayer=1, clear=True)
def show(self, block=False): def show(self, block=False):
img = self.alpha_composite_layers() img = self.alpha_composite_layers()
...@@ -78,6 +100,7 @@ class PILGL(GraphicsLayer): ...@@ -78,6 +100,7 @@ class PILGL(GraphicsLayer):
tkimg = ImageTk.PhotoImage(img) tkimg = ImageTk.PhotoImage(img)
if self.firstFrame: if self.firstFrame:
# Do TK actions for a new panel (not sure what they really do)
self.panel = tk.Label(self.window, image=tkimg) self.panel = tk.Label(self.window, image=tkimg)
self.panel.pack(side="bottom", fill="both", expand="yes") self.panel.pack(side="bottom", fill="both", expand="yes")
else: else:
...@@ -109,7 +132,8 @@ class PILGL(GraphicsLayer): ...@@ -109,7 +132,8 @@ class PILGL(GraphicsLayer):
img = Image.new("RGBA", (self.widthPx, self.heightPx), (255, 255, 255, opacity)) img = Image.new("RGBA", (self.widthPx, self.heightPx), (255, 255, 255, opacity))
return img return img
def create_layer(self, iLayer=0): def create_layer(self, iLayer=0, clear=True):
# If we don't have the layers already, create them
if len(self.layers) <= iLayer: if len(self.layers) <= iLayer:
for i in range(len(self.layers), iLayer+1): for i in range(len(self.layers), iLayer+1):
if i == 0: if i == 0:
...@@ -120,7 +144,216 @@ class PILGL(GraphicsLayer): ...@@ -120,7 +144,216 @@ class PILGL(GraphicsLayer):
self.layers.append(img) self.layers.append(img)
self.draws.append(ImageDraw.Draw(img)) self.draws.append(ImageDraw.Draw(img))
else: else:
opacity = 0 if iLayer > 0 else 255 # We do already have this iLayer. Clear it if requested.
self.layers[iLayer] = img = self.create_image(opacity) if clear:
self.draws[iLayer] = ImageDraw.Draw(img) opacity = 0 if iLayer > 0 else 255
self.layers[iLayer] = img = self.create_image(opacity)
# We also need to maintain a Draw object for each layer
self.draws[iLayer] = ImageDraw.Draw(img)
def create_layers(self, clear=True):
self.create_layer(0, clear=clear)
self.create_layer(1, clear=clear)
class PILSVG(PILGL):
def __init__(self, width, height):
print(self, type(self))
oSuper = super()
print(oSuper, type(oSuper))
oSuper.__init__(width, height)
# self.track = self.track = Track()
# self.lwTrack = []
# self.zug = Zug()
self.lwAgents = []
self.agents_prev = []
self.loadRailSVGs()
self.loadAgentSVGs()
def is_raster(self):
return False
def processEvents(self):
# self.app.processEvents()
time.sleep(0.001)
def clear_rails(self):
print("Clear rails")
self.create_layers()
self.clear_agents()
def clear_agents(self):
# print("Clear Agents: ", len(self.lwAgents))
for wAgent in self.lwAgents:
self.layout.removeWidget(wAgent)
self.lwAgents = []
self.agents_prev = []
def pilFromSvgFile(self, sfPath):
with open(sfPath, "r") as fIn:
bytesPNG = svg2png(file_obj=fIn, output_height=self.nPixCell, output_width=self.nPixCell)
with io.BytesIO(bytesPNG) as fIn:
pil_img = Image.open(fIn)
pil_img.load()
# print(pil_img.mode)
return pil_img
def pilFromSvgBytes(self, bytesSVG):
bytesPNG = svg2png(bytesSVG, output_height=self.nPixCell, output_width=self.nPixCell)
with io.BytesIO(bytesPNG) as fIn:
pil_img = Image.open(fIn)
return pil_img
def loadRailSVGs(self):
""" Load the rail SVG images, apply rotations, and store as PIL images.
"""
dFiles = {
"": "Background_#91D1DD.svg",
"WE": "Gleis_Deadend.svg",
"WW EE NN SS": "Gleis_Diamond_Crossing.svg",
"WW EE": "Gleis_horizontal.svg",
"EN SW": "Gleis_Kurve_oben_links.svg",
"WN SE": "Gleis_Kurve_oben_rechts.svg",
"ES NW": "Gleis_Kurve_unten_links.svg",
"NE WS": "Gleis_Kurve_unten_rechts.svg",
"NN SS": "Gleis_vertikal.svg",
"NN SS EE WW ES NW SE WN": "Weiche_Double_Slip.svg",
"EE WW EN SW": "Weiche_horizontal_oben_links.svg",
"EE WW SE WN": "Weiche_horizontal_oben_rechts.svg",
"EE WW ES NW": "Weiche_horizontal_unten_links.svg",
"EE WW NE WS": "Weiche_horizontal_unten_rechts.svg",
"NN SS EE WW NW ES": "Weiche_Single_Slip.svg",
"NE NW ES WS": "Weiche_Symetrical.svg",
"NN SS EN SW": "Weiche_vertikal_oben_links.svg",
"NN SS SE WN": "Weiche_vertikal_oben_rechts.svg",
"NN SS NW ES": "Weiche_vertikal_unten_links.svg",
"NN SS NE WS": "Weiche_vertikal_unten_rechts.svg"}
self.dPil = {}
transitions = RailEnvTransitions()
lDirs = list("NESW")
# svgBG = SVG("./svg/Background_#91D1DD.svg")
for sTrans, sFile in dFiles.items():
sPathSvg = "./svg/" + sFile
# Translate the ascii transition descption in the format "NE WS" to the
# binary list of transitions as per RailEnv - NESW (in) x NESW (out)
lTrans16 = ["0"] * 16
for sTran in sTrans.split(" "):
if len(sTran) == 2:
iDirIn = lDirs.index(sTran[0])
iDirOut = lDirs.index(sTran[1])
iTrans = 4 * iDirIn + iDirOut
lTrans16[iTrans] = "1"
sTrans16 = "".join(lTrans16)
binTrans = int(sTrans16, 2)
print(sTrans, sTrans16, sFile)
# Merge the transition svg image with the background colour.
# This is a shortcut / hack and will need re-working.
# if binTrans > 0:
# svg = svg.merge(svgBG)
pilRail = self.pilFromSvgFile(sPathSvg)
self.dPil[binTrans] = pilRail
# Rotate both the transition binary and the image and save in the dict
for nRot in [90, 180, 270]:
binTrans2 = transitions.rotate_transition(binTrans, nRot)
# PIL rotates anticlockwise for positive theta
pilRail2 = pilRail.rotate(-nRot)
self.dPil[binTrans2] = pilRail2
def setRailAt(self, row, col, binTrans):
if binTrans in self.dPil:
pilTrack = self.dPil[binTrans]
self.drawImageRC(pilTrack, (row, col))
else:
print("Illegal rail:", row, col, format(binTrans, "#018b")[2:])
def rgb_s2i(self, sRGB):
""" convert a hex RGB string like 0091ea to 3-tuple of ints """
return tuple(int(sRGB[iRGB * 2:iRGB * 2 + 2], 16) for iRGB in [0, 1, 2])
def loadAgentSVGs(self):
# Seed initial train/zug files indexed by tuple(iDirIn, iDirOut):
dDirsFile = {
(0, 0): "svg/Zug_Gleis_#0091ea.svg",
(1, 2): "svg/Zug_1_Weiche_#0091ea.svg",
(0, 3): "svg/Zug_2_Weiche_#0091ea.svg"
}
sColors = "d50000#c51162#aa00ff#6200ea#304ffe#2962ff#0091ea#00b8d4#00bfa5#00c853" + \
"#64dd17#aeea00#ffd600#ffab00#ff6d00#ff3d00#5d4037#455a64"
lColors = sColors.split("#")
self.nAgentColors = len(lColors)
# "paint" color of the train images we load
a_base_color = self.rgb_s2i("0091ea")
self.dPilZug = {}
for tDirs, sPathSvg in dDirsFile.items():
iDirIn, iDirOut = tDirs
pilZug = self.pilFromSvgFile(sPathSvg)
# Rotate both the directions and the image and save in the dict
for iDirRot in range(4):
nDegRot = iDirRot * 90
iDirIn2 = (iDirIn + iDirRot) % 4
iDirOut2 = (iDirOut + iDirRot) % 4
# PIL rotates anticlockwise for positive theta
pilZug2 = pilZug.rotate(-nDegRot)
rgbaZug2 = array(pilZug2)
for iColor, sColor in enumerate(lColors):
tnNewColor = self.rgb_s2i(sColor)
xy_color_mask = np.all(rgbaZug2[:, :, 0:3] - a_base_color == 0, axis=2)
rgbaZug3 = np.copy(rgbaZug2)
rgbaZug3[xy_color_mask, 0:3] = tnNewColor
self.dPilZug[(iDirIn2, iDirOut2, iColor)] = Image.fromarray(rgbaZug3)
def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut, color=None):
delta_dir = (iDirOut - iDirIn) % 4
iColor = iAgent % self.nAgentColors
# when flipping direction at a dead end, use the "iDirOut" direction.
if delta_dir == 2:
iDirIn = iDirOut
pilZug = self.dPilZug[(iDirIn % 4, iDirOut % 4, iColor)]
self.drawImageRC(pilZug, (row, col), layer=1)
def main2():
gl = PILSVG(10, 10)
for i in range(10):
gl.beginFrame()
gl.plot([3 + i, 4], [-4 - i, -5], color="r")
gl.endFrame()
time.sleep(1)
def main():
gl = PILSVG(width=10, height=10)
for i in range(1000):
gl.processEvents()
time.sleep(0.1)
time.sleep(1)
if __name__ == "__main__":
main()
...@@ -16,10 +16,11 @@ def transform_string_svg(sSVG): ...@@ -16,10 +16,11 @@ def transform_string_svg(sSVG):
bySVG = bytearray(sSVG, encoding='utf-8') bySVG = bytearray(sSVG, encoding='utf-8')
return bySVG return bySVG
def create_QtSvgWidget_from_svg_string(sSVG): def create_QtSvgWidget_from_svg_string(sSVG):
svgWidget = QtSvg.QSvgWidget() svgWidget = QtSvg.QSvgWidget()
ret = svgWidget.renderer().load(transform_string_svg(sSVG)) ret = svgWidget.renderer().load(transform_string_svg(sSVG))
if ret == False: if ret is False:
print("create_QtSvgWidget_from_svg_string : failed to parse:", sSVG) print("create_QtSvgWidget_from_svg_string : failed to parse:", sSVG)
return svgWidget return svgWidget
...@@ -132,26 +133,7 @@ class QTSVG(GraphicsLayer): ...@@ -132,26 +133,7 @@ class QTSVG(GraphicsLayer):
self.lwAgents = [] self.lwAgents = []
self.agents_prev = [] self.agents_prev = []
svgWidget = None # svgWidget = None
iArt = 0
iCol = 0
iRow = 0
nCols = 10
if False:
for binTrans in self.track.dSvg.keys():
sSVG = self.track.dSvg[binTrans].to_string()
self.layout.addWidget(create_QtSvgWidget_from_svg_string(sSVG), iRow, iCol)
iArt += 1
iRow = int(iArt / nCols)
iCol = iArt % nCols
svgWidget2 = QtSvg.QSvgWidget()
svgWidget2.renderer().load(bySVG)
self.layout.addWidget(svgWidget2, 0, 0)
def is_raster(self): def is_raster(self):
return False return False
......
...@@ -3,14 +3,12 @@ from collections import deque ...@@ -3,14 +3,12 @@ from collections import deque
# import xarray as xr # import xarray as xr
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np
from numpy import array
from recordtype import recordtype
from flatland.utils.graphics_layer import GraphicsLayer
from flatland.utils.graphics_pil import PILGL
from flatland.utils.render_qt import QTGL, QTSVG from flatland.utils.render_qt import QTGL, QTSVG
from flatland.utils.graphics_pil import PILGL, PILSVG
from flatland.utils.graphics_layer import GraphicsLayer
import recordtype
from numpy import array
import numpy as np
# TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv! # TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv!
...@@ -133,6 +131,8 @@ class RenderTool(object): ...@@ -133,6 +131,8 @@ class RenderTool(object):
self.gl = QTGL(env.width, env.height) self.gl = QTGL(env.width, env.height)
elif gl == "PIL": elif gl == "PIL":
self.gl = PILGL(env.width, env.height) self.gl = PILGL(env.width, env.height)
elif gl == "PILSVG":
self.gl = PILSVG(env.width, env.height)
elif gl == "QTSVG": elif gl == "QTSVG":
self.gl = QTSVG(env.width, env.height) self.gl = QTSVG(env.width, env.height)
...@@ -618,11 +618,11 @@ class RenderTool(object): ...@@ -618,11 +618,11 @@ class RenderTool(object):
""" """
if not self.gl.is_raster(): if not self.gl.is_raster():
self.renderEnv2(show, curves, spacing, self.renderEnv2(show=show, curves=curves, spacing=spacing,
arrows, agents, renderobs,sRailColor, arrows=arrows, agents=agents, show_observations=show_observations,
frames, iEpisode, iStep, sRailColor=sRailColor,
iSelectedAgent, action_dict) frames=frames, iEpisode=iEpisode, iStep=iStep,
iSelectedAgent=iSelectedAgent, action_dict=action_dict)
return return
if type(self.gl) in (QTGL, PILGL): if type(self.gl) in (QTGL, PILGL):
...@@ -728,9 +728,11 @@ class RenderTool(object): ...@@ -728,9 +728,11 @@ class RenderTool(object):
gP0 = array([gX1, gY1, gZ1]) gP0 = array([gX1, gY1, gZ1])
def renderEnv2(self, show=False, curves=True, spacing=False, arrows=False, agents=True, renderobs=True, def renderEnv2(
sRailColor="gray", frames=False, iEpisode=None, iStep=None, iSelectedAgent=None, self, show=False, curves=True, spacing=False, arrows=False, agents=True,
action_dict=dict()): show_observations=True, sRailColor="gray",
frames=False, iEpisode=None, iStep=None, iSelectedAgent=None,
action_dict=dict()):
""" """
Draw the environment using matplotlib. Draw the environment using matplotlib.
Draw into the figure if provided. Draw into the figure if provided.
...@@ -741,6 +743,8 @@ class RenderTool(object): ...@@ -741,6 +743,8 @@ class RenderTool(object):
env = self.env env = self.env
self.gl.beginFrame()
if self.new_rail: if self.new_rail:
self.new_rail = False self.new_rail = False
self.gl.clear_rails() self.gl.clear_rails()
...@@ -766,7 +770,6 @@ class RenderTool(object): ...@@ -766,7 +770,6 @@ class RenderTool(object):
iAction = action_dict[iAgent] iAction = action_dict[iAgent]
new_direction, action_isValid = self.env.check_action(agent, iAction) new_direction, action_isValid = self.env.check_action(agent, iAction)
# ** TODO *** # ** TODO ***
# why should we only update if the action is valid ? # why should we only update if the action is valid ?
if True: if True:
...@@ -779,7 +782,8 @@ class RenderTool(object): ...@@ -779,7 +782,8 @@ class RenderTool(object):
else: else:
self.gl.setAgentAt(iAgent, *agent.position, agent.direction, new_direction, color=oColor) self.gl.setAgentAt(iAgent, *agent.position, agent.direction, new_direction, color=oColor)
self.gl.show() if show:
self.gl.show()
for i in range(3): for i in range(3):
self.gl.processEvents() self.gl.processEvents()
......
...@@ -104,6 +104,13 @@ class Zug(object): ...@@ -104,6 +104,13 @@ class Zug(object):
class Track(object): class Track(object):
""" Class to load and hold SVG track images.
Creates a mapping between
- cell entry and exit directions (ie transitions), and
- specific images provided by the SBB graphic artist.
The directions and images are also rotated by 90, 180 & 270 degrees.
(There is some redundancy in this process, given the images provided)
"""
def __init__(self): def __init__(self):
dFiles = { dFiles = {
"": "Background_#9CCB89.svg", "": "Background_#9CCB89.svg",
...@@ -138,6 +145,8 @@ class Track(object): ...@@ -138,6 +145,8 @@ class Track(object):
for sTrans, sFile in dFiles.items(): for sTrans, sFile in dFiles.items():
svg = SVG("./svg/" + sFile) svg = SVG("./svg/" + sFile)
# Translate the ascii transition descption in the format "NE WS" to the
# binary list of transitions as per RailEnv - NESW (in) x NESW (out)
lTrans16 = ["0"] * 16 lTrans16 = ["0"] * 16
for sTran in sTrans.split(" "): for sTran in sTrans.split(" "):
if len(sTran) == 2: if len(sTran) == 2:
...@@ -149,11 +158,14 @@ class Track(object): ...@@ -149,11 +158,14 @@ class Track(object):
binTrans = int(sTrans16, 2) binTrans = int(sTrans16, 2)
print(sTrans, sTrans16, sFile) print(sTrans, sTrans16, sFile)
# Merge the transition svg image with the background colour.
# This is a shortcut / hack and will need re-working.
if binTrans > 0: if binTrans > 0:
svg = svg.merge(svgBG) svg = svg.merge(svgBG)
self.dSvg[binTrans] = svg self.dSvg[binTrans] = svg
# Rotate both the transition binary and the image and save in the dict
for nRot in [90, 180, 270]: for nRot in [90, 180, 270]:
binTrans2 = transitions.rotate_transition(binTrans, nRot) binTrans2 = transitions.rotate_transition(binTrans, nRot)
svg2 = svg.copy() svg2 = svg.copy()
......
No preview for this file type
...@@ -4,5 +4,9 @@ from examples.play_model import main ...@@ -4,5 +4,9 @@ from examples.play_model import main
def test_main(): def test_main():
main(render=True, n_steps=20, n_trials=2, sGL="PIL") main(render=True, n_steps=20, n_trials=2, sGL="PILSVG")
# main(render=True, n_steps=20, n_trials=2, sGL="PIL")
if __name__ == "__main__":
test_main()
...@@ -12,6 +12,7 @@ import numpy as np ...@@ -12,6 +12,7 @@ import numpy as np
import flatland.utils.rendertools as rt import flatland.utils.rendertools as rt
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv, random_rail_generator from flatland.envs.rail_env import RailEnv, random_rail_generator
from flatland.envs.generators import empty_rail_generator
def checkFrozenImage(oRT, sFileImage, resave=False): def checkFrozenImage(oRT, sFileImage, resave=False):
...@@ -39,14 +40,15 @@ def test_render_env(save_new_images=False): ...@@ -39,14 +40,15 @@ def test_render_env(save_new_images=False):
# random.seed(100) # random.seed(100)
np.random.seed(100) np.random.seed(100)
oEnv = RailEnv(width=10, height=10, oEnv = RailEnv(width=10, height=10,
rail_generator=random_rail_generator(), # rail_generator=random_rail_generator(),
rail_generator=empty_rail_generator(),
number_of_agents=0, number_of_agents=0,
# obs_builder_object=GlobalObsForRailEnv()) # obs_builder_object=GlobalObsForRailEnv())
obs_builder_object=TreeObsForRailEnv(max_depth=2) obs_builder_object=TreeObsForRailEnv(max_depth=2)
) )
sfTestEnv = "env-data/tests/test1.npy" sfTestEnv = "env-data/tests/test1.npy"
oEnv.rail.load_transition_map(sfTestEnv) oEnv.rail.load_transition_map(sfTestEnv)
oRT = rt.RenderTool(oEnv, gl="PIL", show=False) oRT = rt.RenderTool(oEnv, gl="PILSVG", show=False)
oRT.renderEnv(show=False) oRT.renderEnv(show=False)
checkFrozenImage(oRT, "basic-env.npz", resave=save_new_images) checkFrozenImage(oRT, "basic-env.npz", resave=save_new_images)
...@@ -82,6 +84,7 @@ def main(): ...@@ -82,6 +84,7 @@ def main():
test_render_env(save_new_images=True) test_render_env(save_new_images=True)
else: else:
print("Run 'python test_rendertools.py save' to regenerate images") print("Run 'python test_rendertools.py save' to regenerate images")
test_render_env()
if __name__ == "__main__": if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment