Skip to content
Snippets Groups Projects
Forked from Flatland / Flatland
2316 commits behind the upstream repository.
graphics_pil.py 13.38 KiB
from flatland.utils.graphics_layer import GraphicsLayer
from PIL import Image, ImageDraw, ImageTk   # , ImageFont
import tkinter as tk
from numpy import array
import numpy as np
# from flatland.utils.svg import Track, Zug
import time
import io
from cairosvg import svg2png
from flatland.core.transitions import RailEnvTransitions
# from copy import copy


class PILGL(GraphicsLayer):
    """Raster graphics layer backed by a stack of PIL RGBA images.

    Layer 0 is created opaque white and every higher layer is created fully
    transparent; ``beginFrame`` rebuilds layer 1 empty.  ``show`` and
    ``getImage`` alpha-composite the layers with layer 0 at the back.
    Display goes through a Tk window that is opened lazily by ``show``.
    """

    def __init__(self, width, height, nPixCell=60):
        """Create the layer stack for a ``width`` x ``height`` cell grid.

        Args:
            width: number of grid cells horizontally.
            height: number of grid cells vertically.
            nPixCell: side length of one grid cell in pixels.
        """
        # Honour the argument (it used to be ignored in favour of a
        # hard-coded 60, which remains the default).
        self.nPixCell = nPixCell
        self.yxBase = (0, 0)
        self.linewidth = 4
        self.nAgentColors = 1  # overridden in loadAgent

        self.width = width
        self.height = height

        # Total grid size at native scale, plus one line width.
        self.widthPx = self.width * self.nPixCell + self.linewidth
        self.heightPx = self.height * self.nPixCell + self.linewidth
        self.layers = []  # one PIL RGBA image per layer, index 0 at the back
        self.draws = []   # ImageDraw.Draw object for each entry of self.layers

        self.tColBg = (255, 255, 255)     # white background
        self.tColRail = (0, 0, 0)         # black rails
        self.tColGrid = (230,) * 3        # light grey for grid

        self.window_open = False
        self.firstFrame = True
        self.create_layers()

    def plot(self, gX, gY, color=None, linewidth=3, layer=0, opacity=255, **kwargs):
        """Draw a polyline through the grid points (gX, gY) onto ``layer``.

        Coordinates are multiplied by ``nPixCell``; gY is negated before
        scaling.  ``opacity`` replaces (or supplies) the alpha of a 3- or
        4-component colour.

        NOTE(review): ``linewidth`` is accepted for interface compatibility,
        but the line is drawn with ``self.linewidth``.
        """
        color = self.adaptColor(color)
        if len(color) in (3, 4):
            # Build a fresh tuple so a list/array returned by adaptColor is
            # neither mutated in place nor broadcast-added to.
            color = tuple(color[:3]) + (opacity,)
        gPoints = np.stack([array(gX), -array(gY)]).T * self.nPixCell
        gPoints = list(gPoints.ravel())
        self.draws[layer].line(gPoints, fill=color, width=self.linewidth)

    def scatter(self, gX, gY, color=None, marker="o", s=50, layer=0, opacity=255, *args, **kwargs):
        """Draw a filled square of half-side sqrt(s) pixels at each grid point.

        NOTE(review): ``marker`` and ``opacity`` are currently unused.
        """
        color = self.adaptColor(color)
        r = np.sqrt(s)
        gPoints = np.stack([np.atleast_1d(gX), -np.atleast_1d(gY)]).T * self.nPixCell
        for x, y in gPoints:
            self.draws[layer].rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color)

    def drawImageXY(self, pil_img, xyPixLeftTop, layer=0):
        """Paste ``pil_img`` onto ``layer`` with its top-left at pixel (x, y).

        RGBA images are used as their own paste mask, so pixels where their
        alpha is 0 leave the layer unchanged.
        """
        pil_mask = pil_img if pil_img.mode == "RGBA" else None
        self.layers[layer].paste(pil_img, xyPixLeftTop, pil_mask)

    def drawImageRC(self, pil_img, rcTopLeft, layer=0):
        """Paste ``pil_img`` with its top-left corner at grid cell (row, col)."""
        row, col = rcTopLeft[0], rcTopLeft[1]
        # PIL wants plain ints for the paste box.
        xyPixLeftTop = (int(col * self.nPixCell), int(row * self.nPixCell))
        self.drawImageXY(pil_img, xyPixLeftTop, layer=layer)

    def open_window(self):
        """Create the Tk root window; it must not already be open."""
        assert self.window_open is False, "Window is already open!"
        self.window = tk.Tk()
        self.window.title("Flatland")
        self.window.configure(background='grey')
        self.window_open = True

    def close_window(self):
        """Destroy the Tk window, if open, and reset so ``show`` can reopen it."""
        if not self.window_open:
            return
        # The panel only exists once show() has displayed a frame.
        panel = getattr(self, "panel", None)
        if panel is not None:
            panel.destroy()
        self.panel = None
        self.window.quit()
        self.window.destroy()
        self.window_open = False
        self.firstFrame = True

    def text(self, *args, **kwargs):
        """No-op: text rendering is not implemented for this backend."""
        pass

    def prettify(self, *args, **kwargs):
        """No-op for this backend."""
        pass

    def prettify2(self, width, height, cell_size):
        """No-op for this backend."""
        pass

    def beginFrame(self):
        """Start a new frame by replacing layer 1 with a fresh transparent image."""
        self.create_layer(iLayer=1, clear=True)

    def show(self, block=False):
        """Composite all layers and display the result in the Tk window.

        The window is opened on first use.  ``block`` is currently unused.
        """
        img = self.alpha_composite_layers()

        if not self.window_open:
            self.open_window()

        tkimg = ImageTk.PhotoImage(img)

        if self.firstFrame:
            self.panel = tk.Label(self.window, image=tkimg)
            self.panel.pack(side="bottom", fill="both", expand="yes")
        else:
            # Update the image in situ.
            self.panel.configure(image=tkimg)
        # Tk itself does not keep the PhotoImage alive; without a Python
        # reference it is garbage-collected and the label goes blank.  This
        # used to be missing for the first frame.
        self.panel.image = tkimg

        self.window.update()
        self.firstFrame = False

    def pause(self, seconds=0.00001):
        """No-op; kept for interface compatibility."""
        pass

    def alpha_composite_layers(self):
        """Return all layers alpha-composited in order, layer 0 at the back."""
        img = self.layers[0]
        for img2 in self.layers[1:]:
            img = Image.alpha_composite(img, img2)
        return img

    def getImage(self):
        """ return a blended / alpha composited image composed of all the layers,
            with layer 0 at the "back", as a numpy array.
        """
        img = self.alpha_composite_layers()
        return array(img)

    def create_image(self, opacity=255):
        """Return a new white RGBA image of the full grid size with alpha ``opacity``."""
        img = Image.new("RGBA", (self.widthPx, self.heightPx), (255, 255, 255, opacity))
        return img

    def create_layer(self, iLayer=0, clear=True):
        """Ensure layer ``iLayer`` exists, creating any missing lower layers too.

        Newly created layers, and an existing layer cleared with ``clear``,
        are opaque for layer 0 and fully transparent otherwise.  An ImageDraw
        object is kept in ``self.draws`` for every layer.
        """
        if len(self.layers) <= iLayer:
            for i in range(len(self.layers), iLayer + 1):
                # "bottom" layer is opaque (for rails); the rest are transparent.
                opacity = 255 if i == 0 else 0
                img = self.create_image(opacity)
                self.layers.append(img)
                self.draws.append(ImageDraw.Draw(img))
        elif clear:
            # We already have this iLayer; replace it with a blank one.
            opacity = 0 if iLayer > 0 else 255
            self.layers[iLayer] = img = self.create_image(opacity)
            self.draws[iLayer] = ImageDraw.Draw(img)

    def create_layers(self, clear=True):
        """Create (or clear) the background layer 0 and the agent layer 1."""
        self.create_layer(0, clear=clear)
        self.create_layer(1, clear=clear)


class PILSVG(PILGL):
    """PILGL renderer that draws rails, targets and agents from SVG tiles.

    The SVG files are rasterised with cairosvg to ``nPixCell`` x ``nPixCell``
    PIL images once, at construction time.  They are cached in:
    - ``dPilRail`` / ``dPilTarget``, keyed by the 16-bit transition integer;
    - ``dPilZug``, keyed by ``(iDirIn, iDirOut, iColor)``.
    """

    def __init__(self, width, height):
        super().__init__(width, height)

        self.lwAgents = []
        self.agents_prev = []

        self.loadRailSVGs()
        self.loadAgentSVGs()

    def is_raster(self):
        return False

    def processEvents(self):
        """Yield briefly; no GUI event processing is done here."""
        time.sleep(0.001)

    def clear_rails(self):
        """Reset both layers and forget all agents."""
        print("Clear rails")
        self.create_layers()
        self.clear_agents()

    def clear_agents(self):
        """Forget previously drawn agents."""
        # NOTE(review): self.layout is not defined in this file (looks like a
        # leftover from a widget-based backend); the loop only runs if
        # lwAgents has been populated elsewhere.
        for wAgent in self.lwAgents:
            self.layout.removeWidget(wAgent)
        self.lwAgents = []
        self.agents_prev = []

    @staticmethod
    def _pilFromPngBytes(bytesPNG):
        """Decode PNG bytes into a fully loaded PIL image."""
        with io.BytesIO(bytesPNG) as fIn:
            pil_img = Image.open(fIn)
            # Image.open is lazy: force the decode while the buffer is open.
            pil_img.load()
        return pil_img

    def pilFromSvgFile(self, sfPath):
        """Rasterise the SVG file at ``sfPath`` to an nPixCell-square PIL image."""
        # Binary mode: no locale-dependent text decoding; the XML parser
        # handles the bytes itself.
        with open(sfPath, "rb") as fIn:
            bytesPNG = svg2png(file_obj=fIn, output_height=self.nPixCell, output_width=self.nPixCell)
        return self._pilFromPngBytes(bytesPNG)

    def pilFromSvgBytes(self, bytesSVG):
        """Rasterise SVG source bytes to an nPixCell-square PIL image."""
        bytesPNG = svg2png(bytesSVG, output_height=self.nPixCell, output_width=self.nPixCell)
        # Previously the lazily-opened image was returned after its BytesIO
        # was closed, so its pixel data could no longer be read.
        return self._pilFromPngBytes(bytesPNG)

    def loadRailSVGs(self):
        """ Load the rail SVG images, apply rotations, and store as PIL images.
        """
        dRailFiles = {
            "": "Background_#91D1DD.svg",
            "WE": "Gleis_Deadend.svg",
            "WW EE NN SS": "Gleis_Diamond_Crossing.svg",
            "WW EE": "Gleis_horizontal.svg",
            "EN SW": "Gleis_Kurve_oben_links.svg",
            "WN SE": "Gleis_Kurve_oben_rechts.svg",
            "ES NW": "Gleis_Kurve_unten_links.svg",
            "NE WS": "Gleis_Kurve_unten_rechts.svg",
            "NN SS": "Gleis_vertikal.svg",
            "NN SS EE WW ES NW SE WN": "Weiche_Double_Slip.svg",
            "EE WW EN SW": "Weiche_horizontal_oben_links.svg",
            "EE WW SE WN": "Weiche_horizontal_oben_rechts.svg",
            "EE WW ES NW": "Weiche_horizontal_unten_links.svg",
            "EE WW NE WS": "Weiche_horizontal_unten_rechts.svg",
            "NN SS EE WW NW ES": "Weiche_Single_Slip.svg",
            "NE NW ES WS": "Weiche_Symetrical.svg",
            "NN SS EN SW": "Weiche_vertikal_oben_links.svg",
            "NN SS SE WN": "Weiche_vertikal_oben_rechts.svg",
            "NN SS NW ES": "Weiche_vertikal_unten_links.svg",
            "NN SS NE WS": "Weiche_vertikal_unten_rechts.svg"}

        dTargetFiles = {
            "EW": "Bahnhof_#d50000_Deadend_links.svg",
            "NS": "Bahnhof_#d50000_Deadend_oben.svg",
            "WE": "Bahnhof_#d50000_Deadend_rechts.svg",
            "SN": "Bahnhof_#d50000_Deadend_unten.svg",
            "EE WW": "Bahnhof_#d50000_Gleis_horizontal.svg",
            "NN SS": "Bahnhof_#d50000_Gleis_vertikal.svg"}

        self.dPilRail = self.loadSVGs(dRailFiles, rotate=True)
        self.dPilTarget = self.loadSVGs(dTargetFiles, rotate=False)

    def loadSVGs(self, dDirFile, rotate=False):
        """Rasterise SVG tiles into a dict keyed by RailEnv transition integer.

        Args:
            dDirFile: maps a space-separated transition description such as
                "NE WS" to an SVG filename under ``./svg/``.  Each two-letter
                token is an (in, out) pair of directions from "NESW".
            rotate: if True, also store each tile rotated clockwise by
                90/180/270 degrees, keyed by the code returned by
                ``RailEnvTransitions.rotate_transition``.

        Returns:
            dict mapping the transition integer to a PIL image.  Later entries
            overwrite earlier ones that share a key.
        """
        dPil = {}
        transitions = RailEnvTransitions()
        lDirs = list("NESW")

        for sTrans, sFile in dDirFile.items():
            # NOTE(review): resolved relative to the current working directory.
            sPathSvg = "./svg/" + sFile

            # Translate "NE WS" into the 16-character bit string used by
            # RailEnv (NESW in x NESW out): character 4 * in + out is set,
            # with character 0 as the most significant bit.
            lTrans16 = ["0"] * 16
            for sTran in sTrans.split(" "):
                if len(sTran) == 2:
                    iDirIn = lDirs.index(sTran[0])
                    iDirOut = lDirs.index(sTran[1])
                    lTrans16[4 * iDirIn + iDirOut] = "1"
            binTrans = int("".join(lTrans16), 2)

            pilRail = self.pilFromSvgFile(sPathSvg)
            dPil[binTrans] = pilRail

            if rotate:
                for nRot in [90, 180, 270]:
                    binTrans2 = transitions.rotate_transition(binTrans, nRot)
                    # PIL rotates anticlockwise for positive angles, so negate.
                    dPil[binTrans2] = pilRail.rotate(-nRot)
        return dPil

    def setRailAt(self, row, col, binTrans, target=None):
        """Draw the rail (or target, if ``target`` is given) tile for ``binTrans``
        at cell (row, col) on layer 0.  Unknown codes are reported and skipped.
        """
        if target is None:
            if binTrans in self.dPilRail:
                pilTrack = self.dPilRail[binTrans]
                self.drawImageRC(pilTrack, (row, col))
            else:
                print("Illegal rail:", row, col, format(binTrans, "#018b")[2:])
        else:
            if binTrans in self.dPilTarget:
                pilTrack = self.dPilTarget[binTrans]
                self.drawImageRC(pilTrack, (row, col))
            else:
                print("Illegal target rail:", row, col, format(binTrans, "#018b")[2:])

    def rgb_s2i(self, sRGB):
        """ convert a hex RGB string like 0091ea to 3-tuple of ints """
        return tuple(int(sRGB[iRGB * 2:iRGB * 2 + 2], 16) for iRGB in [0, 1, 2])

    def loadAgentSVGs(self):
        """Build ``self.dPilZug``: agent images for every rotation and colour.

        Each seed SVG is rotated by 0/90/180/270 degrees clockwise.  Its
        (iDirIn, iDirOut) key is shifted by the same number of quarter turns.
        Every pixel exactly matching the base paint colour #0091ea is
        recoloured with each palette entry.
        """
        # Seed train/zug files indexed by tuple(iDirIn, iDirOut).
        dDirsFile = {
            (0, 0): "svg/Zug_Gleis_#0091ea.svg",
            (1, 2): "svg/Zug_1_Weiche_#0091ea.svg",
            (0, 3): "svg/Zug_2_Weiche_#0091ea.svg"
            }

        sColors = "d50000#c51162#aa00ff#6200ea#304ffe#2962ff#0091ea#00b8d4#00bfa5#00c853" + \
            "#64dd17#aeea00#ffd600#ffab00#ff6d00#ff3d00#5d4037#455a64"
        lColors = sColors.split("#")
        self.nAgentColors = len(lColors)
        # Parse the palette once instead of once per rotated image.
        lRGBColors = [self.rgb_s2i(sColor) for sColor in lColors]

        # "paint" color of the train images we load
        a_base_color = self.rgb_s2i("0091ea")

        self.dPilZug = {}

        for (iDirIn, iDirOut), sPathSvg in dDirsFile.items():
            pilZug = self.pilFromSvgFile(sPathSvg)

            for iDirRot in range(4):
                iDirIn2 = (iDirIn + iDirRot) % 4
                iDirOut2 = (iDirOut + iDirRot) % 4

                # PIL rotates anticlockwise for positive theta.
                rgbaZug2 = array(pilZug.rotate(-iDirRot * 90))

                # The mask depends only on the rotated image, not the target colour.
                xy_color_mask = np.all(rgbaZug2[:, :, 0:3] == a_base_color, axis=2)

                for iColor, tnNewColor in enumerate(lRGBColors):
                    rgbaZug3 = np.copy(rgbaZug2)
                    rgbaZug3[xy_color_mask, 0:3] = tnNewColor
                    self.dPilZug[(iDirIn2, iDirOut2, iColor)] = Image.fromarray(rgbaZug3)

    def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut, color=None):
        """Draw agent ``iAgent`` at cell (row, col) on layer 1.

        The colour is picked from the palette by ``iAgent``; ``color`` is
        currently unused.  A reversal (in/out directions 180 degrees apart)
        is drawn as a straight move in the ``iDirOut`` direction.
        """
        delta_dir = (iDirOut - iDirIn) % 4
        iColor = iAgent % self.nAgentColors
        # When flipping direction at a dead end, use the "iDirOut" direction.
        if delta_dir == 2:
            iDirIn = iDirOut
        pilZug = self.dPilZug[(iDirIn % 4, iDirOut % 4, iColor)]
        self.drawImageRC(pilZug, (row, col), layer=1)


def main2():
    """Demo: plot a short line segment that shifts each frame, once a second."""
    renderer = PILSVG(10, 10)
    for step in range(10):
        renderer.beginFrame()
        xs = [3 + step, 4]
        ys = [-4 - step, -5]
        renderer.plot(xs, ys, color="r")
        renderer.endFrame()
        time.sleep(1)


def main():
    """Build a 10x10 PILSVG renderer, then idle through 1000 event ticks."""
    renderer = PILSVG(width=10, height=10)

    ticks_left = 1000
    while ticks_left > 0:
        renderer.processEvents()
        time.sleep(0.1)
        ticks_left -= 1
    time.sleep(1)


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()