graphics_pil.py 16.5 KB
Newer Older
hagrid67's avatar
hagrid67 committed
1
import io
2
import os
3
import site
u214892's avatar
u214892 committed
4
5
6
7
8
9
10
11
12
import time
import tkinter as tk

import numpy as np
from PIL import Image, ImageDraw, ImageTk  # , ImageFont
from numpy import array

from flatland.utils.graphics_layer import GraphicsLayer

13
14

def enable_windows_cairo_support():
    """Prepend per-site-packages ``cairo`` directories to PATH on Windows.

    When ``os.name`` is ``'nt'``, a ``<site-packages>\\cairo`` entry for every
    directory reported by ``site.getsitepackages()`` is placed in front of the
    existing ``PATH``.  If ``ctypes.util.find_library('cairo')`` still finds
    nothing, an error message is printed.  On other platforms nothing happens.
    """
    if os.name != 'nt':
        return

    import site
    import ctypes.util

    original_path = os.environ['PATH']
    # Each entry is introduced by ';' and the original PATH goes last.
    cairo_entries = ''.join(';' + pkg_dir + '\\cairo' for pkg_dir in site.getsitepackages())
    os.environ['PATH'] = cairo_entries + ';' + original_path

    if ctypes.util.find_library('cairo') is None:
        print("Error: cairo not installed")
25

Egli Adrian (IT-SCI-API-PFI)'s avatar
Egli Adrian (IT-SCI-API-PFI) committed
26

u214892's avatar
u214892 committed
27
28
29
# Adjust PATH (Windows only) before the imports below run; presumably cairosvg
# needs the cairo library to be locatable at import time, which is why these
# imports are placed after the call and carry E402 waivers.
enable_windows_cairo_support()
from cairosvg import svg2png  # noqa: E402
from screeninfo import get_monitors  # noqa: E402
Egli Adrian (IT-SCI-API-PFI)'s avatar
Egli Adrian (IT-SCI-API-PFI) committed
30

u214892's avatar
u214892 committed
31
from flatland.core.transitions import RailEnvTransitions  # noqa: E402
hagrid67's avatar
hagrid67 committed
32
33
34


class PILGL(GraphicsLayer):
Egli Adrian (IT-SCI-API-PFI)'s avatar
Egli Adrian (IT-SCI-API-PFI) committed
35
    def __init__(self, width, height, jupyter=False):
        """Create the PIL drawing layers for a grid of ``width`` x ``height`` cells.

        :param width: grid width in cells.
        :param height: grid height in cells.
        :param jupyter: when True no monitor is queried and a fixed cell size
            of 40 pixels is used; otherwise the cell size is derived from the
            smallest attached monitor.
        """
        self.yxBase = (0, 0)
        self.linewidth = 4
        self.nAgentColors = 1  # overridden in loadAgent

        self.width = width
        self.height = height

        if jupyter is False:
            monitors = list(get_monitors())
            if monitors:
                # Smallest dimensions over all monitors (capped at 99999).
                self.screen_width = min([99999] + [m.width for m in monitors])
                self.screen_height = min([99999] + [m.height for m in monitors])
            else:
                # No monitor reported (e.g. headless): fall back to a modest
                # size instead of deriving a gigantic cell size from the
                # 99999 sentinel.
                self.screen_width, self.screen_height = 800, 600

            w = (self.screen_width - self.width - 10) / (self.width + 1 + self.linewidth)
            h = (self.screen_height - self.height - 10) / (self.height + 1 + self.linewidth)
            self.nPixCell = int(max(1, np.ceil(min(w, h))))
        else:
            self.nPixCell = 40

        # Total grid size at native scale
        self.widthPx = self.width * self.nPixCell + self.linewidth
        self.heightPx = self.height * self.nPixCell + self.linewidth

        if jupyter is not False:
            # No screen was measured in this mode.  Use the image size itself
            # so the offsets below are 0 rather than raising AttributeError on
            # the otherwise missing screen_* attributes.
            self.screen_width = self.widthPx
            self.screen_height = self.heightPx

        self.xPx = int((self.screen_width - self.widthPx) / 2.0)
        self.yPx = int((self.screen_height - self.heightPx) / 2.0)

        self.layers = []
        self.draws = []

        self.tColBg = (255, 255, 255)  # white background
        self.tColRail = (0, 0, 0)  # black rails
        self.tColGrid = (230,) * 3  # light grey for grid

        sColors = "d50000#c51162#aa00ff#6200ea#304ffe#2962ff#0091ea#00b8d4#00bfa5#00c853" + \
                  "#64dd17#aeea00#ffd600#ffab00#ff6d00#ff3d00#5d4037#455a64"

        self.ltAgentColors = [self.rgb_s2i(sColor) for sColor in sColors.split("#")]
        self.nAgentColors = len(self.ltAgentColors)

        self.window_open = False
        self.firstFrame = True
        self.create_layers()
hagrid67's avatar
hagrid67 committed
79

hagrid67's avatar
hagrid67 committed
80
81
82
83
    def rgb_s2i(self, sRGB):
        """Turn a six-digit hex colour string such as ``0091ea`` into an (r, g, b) tuple of ints."""
        return tuple(int(sRGB[start:start + 2], 16) for start in (0, 2, 4))

84
85
86
    def getAgentColor(self, iAgent):
        """Return the palette colour for agent ``iAgent``; the palette wraps around."""
        palette_index = iAgent % self.nAgentColors
        return self.ltAgentColors[palette_index]

87
88
89
90
91
92
    def plot(self, gX, gY, color=None, linewidth=3, layer=0, opacity=255, **kwargs):
        """Draw a polyline through the points (gX[i], gY[i]) on the given layer.

        Coordinates are multiplied by the cell size with y negated.  A 3- or
        4-component colour gets ``opacity`` as its alpha value.  The line is
        drawn with ``self.linewidth``; the ``linewidth`` argument and extra
        keyword arguments are accepted but not used.
        """
        color = self.adaptColor(color)
        nChannels = len(color)
        if nChannels == 4:
            color = color[:3] + (opacity,)
        elif nChannels == 3:
            color += (opacity,)

        xs = array(gX)
        ys = -array(gY)
        pixel_pairs = np.stack([xs, ys]).T * self.nPixCell
        flat_points = list(pixel_pairs.ravel())
        self.draws[layer].line(flat_points, fill=color, width=self.linewidth)
hagrid67's avatar
hagrid67 committed
96

97
    def scatter(self, gX, gY, color=None, marker="o", s=50, layer=0, opacity=255, *args, **kwargs):
        """Draw a filled square of half-width ``sqrt(s)`` at each point.

        Coordinates are multiplied by the cell size with y negated.
        ``marker``, ``opacity`` and any extra arguments are accepted but not used.
        """
        fill = self.adaptColor(color)
        half = np.sqrt(s)
        xs = np.atleast_1d(gX)
        ys = -np.atleast_1d(gY)
        for px, py in np.stack([xs, ys]).T * self.nPixCell:
            top_left = (px - half, py - half)
            bottom_right = (px + half, py + half)
            self.draws[layer].rectangle([top_left, bottom_right], fill=fill, outline=fill)
hagrid67's avatar
hagrid67 committed
103

hagrid67's avatar
hagrid67 committed
104
    def drawImageXY(self, pil_img, xyPixLeftTop, layer=0):
        """Paste ``pil_img`` onto a layer with its top-left corner at pixel ``xyPixLeftTop``.

        An image in RGBA mode is passed as its own paste mask; any other mode
        is pasted without a mask.
        """
        mask = pil_img if pil_img.mode == "RGBA" else None
        self.layers[layer].paste(pil_img, xyPixLeftTop, mask)

    def drawImageRC(self, pil_img, rcTopLeft, layer=0):
        """Paste ``pil_img`` at a (row, col) cell position, converted to (x, y) pixels."""
        scaled = array(rcTopLeft) * self.nPixCell
        # Swap (row, col) -> (x, y): the column gives x, the row gives y.
        xy_pixels = tuple(scaled[[1, 0]])
        self.drawImageXY(pil_img, xy_pixels, layer=layer)

116
117
118
119
120
121
122
    def open_window(self):
        """Create the Tk root window titled "Flatland" and mark the window as open.

        Asserts that no window is currently open.
        """
        assert self.window_open is False, "Window is already open!"
        self.window = tk.Tk()
        root = self.window
        root.title("Flatland")
        root.configure(background='grey')
        self.window_open = True

hagrid67's avatar
hagrid67 committed
123
124
125
126
127
    def close_window(self):
        """Destroy the image panel and the Tk window, and mark the window closed.

        Does nothing when no window is open.  The ``window_open`` and
        ``firstFrame`` flags are reset so that a later ``show()`` opens a fresh
        window and panel instead of touching destroyed widgets.
        """
        if not self.window_open:
            return

        # The panel only exists once a first frame has been shown.
        panel = getattr(self, "panel", None)
        if panel is not None:
            panel.destroy()
            self.panel = None

        self.window.quit()
        self.window.destroy()

        self.window_open = False
        self.firstFrame = True

hagrid67's avatar
hagrid67 committed
128
129
130
131
132
133
134
135
136
137
    def text(self, *args, **kwargs):
        """No-op in this layer: accepts any arguments and draws nothing."""

    def prettify(self, *args, **kwargs):
        """No-op in this layer: accepts any arguments and does nothing."""

    def prettify2(self, width, height, cell_size):
        """No-op in this layer: the arguments are ignored."""

    def beginFrame(self):
        """Start a frame by creating layer 1 (agents), or clearing it if it already exists."""
        self.create_layer(iLayer=1, clear=True)
hagrid67's avatar
hagrid67 committed
140
141

    def show(self, block=False):
        """Composite all layers and display the result in the Tk window.

        The window is opened on first use.  The first frame creates the label
        widget holding the image; later frames replace the label's image.
        ``block`` is accepted but not used.
        """
        img = self.alpha_composite_layers()

        if not self.window_open:
            self.open_window()

        tkimg = ImageTk.PhotoImage(img)

        if self.firstFrame:
            # First frame: create and lay out the label that shows the image.
            self.panel = tk.Label(self.window, image=tkimg)
            self.panel.pack(side="bottom", fill="both", expand="yes")
        else:
            # update the image in situ
            self.panel.configure(image=tkimg)

        # Tk does not keep a Python reference to the image; hold one on the
        # widget for every frame (including the first) so it is not garbage
        # collected and blanked once this method returns.
        self.panel.image = tkimg

        self.window.update()
        self.firstFrame = False
hagrid67's avatar
hagrid67 committed
160
161
162
163

    def pause(self, seconds=0.00001):
        """No-op in this layer: ``seconds`` is ignored and no waiting happens."""

164
165
166
167
168
169
    def alpha_composite_layers(self):
        """Alpha-composite all layers in order, layer 0 first, and return the result."""
        composite = self.layers[0]
        for overlay in self.layers[1:]:
            composite = Image.alpha_composite(composite, overlay)
        return composite

hagrid67's avatar
hagrid67 committed
170
    def getImage(self):
        """Return the alpha-composited image of all layers as a numpy array.

        Layer 0 is at the "back" of the composite.
        """
        return array(self.alpha_composite_layers())

u214892's avatar
u214892 committed
177
    def saveImage(self, filename):
        """Composite all layers and save the resulting image to ``filename``."""
        self.alpha_composite_layers().save(filename)

181
182
183
184
    def create_image(self, opacity=255):
        """Return a new white RGBA image of the full grid size with the given alpha."""
        size = (self.widthPx, self.heightPx)
        background = (255, 255, 255, opacity)
        return Image.new("RGBA", size, background)

185
    def clear_layer(self, iLayer=0, opacity=None):
        """Replace layer ``iLayer`` with a fresh image and a matching Draw object.

        When ``opacity`` is None, layer 0 becomes opaque (255) and every
        higher layer transparent (0).
        """
        if opacity is None:
            if iLayer > 0:
                opacity = 0
            else:
                opacity = 255

        fresh = self.create_image(opacity)
        self.layers[iLayer] = fresh
        # Each layer needs its own Draw object bound to the new image.
        self.draws[iLayer] = ImageDraw.Draw(fresh)

hagrid67's avatar
hagrid67 committed
192
193
    def create_layer(self, iLayer=0, clear=True):
        """Make sure layer ``iLayer`` exists, clearing it if it already did.

        Missing layers up to and including ``iLayer`` are appended: layer 0 is
        opaque (for rails), all later layers are transparent.  An existing
        layer is cleared only when ``clear`` is true.
        """
        nExisting = len(self.layers)

        if iLayer < nExisting:
            # The layer is already there; wipe it on request.
            if clear:
                self.clear_layer(iLayer)
            return

        for iNew in range(nExisting, iLayer + 1):
            alpha = 255 if iNew == 0 else 0
            canvas = self.create_image(alpha)
            self.layers.append(canvas)
            self.draws.append(ImageDraw.Draw(canvas))
hagrid67's avatar
hagrid67 committed
207

208
209
210
211
212
    def create_layers(self, clear=True):
        """Create (or optionally clear) the four drawing layers.

        Layer 0: rail / background (scene); 1: agents; 2: drawing layer for
        the selected agent; 3: drawing layer for the selected agent's target.
        """
        for iLayer in range(4):
            self.create_layer(iLayer, clear=clear)
hagrid67's avatar
hagrid67 committed
213
214
215


class PILSVG(PILGL):
u214892's avatar
u214892 committed
216
    def __init__(self, width, height, jupyter=False):
        """Initialise the base layer, then load the rail and agent SVG images."""
        super().__init__(width, height, jupyter)

        self.lwAgents = []
        self.agents_prev = []

        self.loadRailSVGs()
        self.loadAgentSVGs()

    def is_raster(self):
        """Always return False."""
        return False

    def processEvents(self):
        """Sleep for one millisecond; no events are handled here."""
        time.sleep(0.001)

    def clear_rails(self):
        """Re-create / clear all drawing layers, then forget the tracked agents."""
        self.create_layers()
        self.clear_agents()

    def clear_agents(self):
        """Remove every tracked agent widget from the layout and reset the agent lists."""
        # NOTE(review): self.layout is not assigned anywhere in the code visible
        # here; confirm it exists before lwAgents is ever populated.
        for widget in self.lwAgents:
            self.layout.removeWidget(widget)
        self.lwAgents, self.agents_prev = [], []

    def pilFromSvgFile(self, sfPath):
        """Render an SVG file to a PIL image of one cell (nPixCell x nPixCell).

        ``sfPath`` is first opened as given.  If that fails with an OSError
        (typically because the file is missing), the site-packages directories
        are searched for an entry whose name starts with ``flatland`` and
        ``sfPath`` is resolved relative to it.

        :param sfPath: path of the SVG file.
        :return: the fully loaded PIL image.
        :raises OSError: if the file cannot be opened in either location.
        """

        def render(path):
            # Render one SVG file to PNG bytes at cell size.
            with open(path, "r") as fIn:
                return svg2png(file_obj=fIn, output_height=self.nPixCell, output_width=self.nPixCell)

        try:
            bytesPNG = render(sfPath)
        except OSError:
            # Only file-access problems trigger the fallback; other errors
            # (and KeyboardInterrupt) propagate instead of being swallowed.
            package_dir = ''
            for directory in site.getsitepackages():
                if not os.path.isdir(directory):
                    # Reported site-packages directories need not exist.
                    continue
                matches = [word for word in os.listdir(directory) if word.startswith('flatland')]
                if len(matches) > 0:
                    package_dir = directory + '/' + matches[0]
            bytesPNG = render(package_dir + '/' + sfPath)

        with io.BytesIO(bytesPNG) as fIn:
            pil_img = Image.open(fIn)
            # Force decoding now, while the byte buffer is still open.
            pil_img.load()

        return pil_img

    def pilFromSvgBytes(self, bytesSVG):
        """Render SVG bytes to a PIL image of one cell (nPixCell x nPixCell).

        :param bytesSVG: the SVG document as bytes.
        :return: the fully loaded PIL image.
        """
        bytesPNG = svg2png(bytesSVG, output_height=self.nPixCell, output_width=self.nPixCell)
        with io.BytesIO(bytesPNG) as fIn:
            pil_img = Image.open(fIn)
            # Image.open is lazy: decode the pixel data before the buffer is
            # closed, otherwise later use of the image reads a closed file.
            pil_img.load()
        return pil_img

    def loadRailSVGs(self):
        """Load the rail and target (station) SVGs into ``self.dPilRail``.

        Rail images are loaded with rotation and keyed by binary transition;
        target images are loaded per agent colour, without rotation, and keyed
        by ``(binTrans, iAgent)``.  Both sets end up in one dictionary.
        """
        # Transition description -> rail SVG file name.
        rail_files = {
            "": "Background_#91D1DD.svg",
            "WE": "Gleis_Deadend.svg",
            "WW EE NN SS": "Gleis_Diamond_Crossing.svg",
            "WW EE": "Gleis_horizontal.svg",
            "EN SW": "Gleis_Kurve_oben_links.svg",
            "WN SE": "Gleis_Kurve_oben_rechts.svg",
            "ES NW": "Gleis_Kurve_unten_links.svg",
            "NE WS": "Gleis_Kurve_unten_rechts.svg",
            "NN SS": "Gleis_vertikal.svg",
            "NN SS EE WW ES NW SE WN": "Weiche_Double_Slip.svg",
            "EE WW EN SW": "Weiche_horizontal_oben_links.svg",
            "EE WW SE WN": "Weiche_horizontal_oben_rechts.svg",
            "EE WW ES NW": "Weiche_horizontal_unten_links.svg",
            "EE WW NE WS": "Weiche_horizontal_unten_rechts.svg",
            "NN SS EE WW NW ES": "Weiche_Single_Slip.svg",
            "NE NW ES WS": "Weiche_Symetrical.svg",
            "NN SS EN SW": "Weiche_vertikal_oben_links.svg",
            "NN SS SE WN": "Weiche_vertikal_oben_rechts.svg",
            "NN SS NW ES": "Weiche_vertikal_unten_links.svg",
            "NN SS NE WS": "Weiche_vertikal_unten_rechts.svg",
        }

        # Transition description -> target (station) SVG file name.
        target_files = {
            "EW": "Bahnhof_#d50000_Deadend_links.svg",
            "NS": "Bahnhof_#d50000_Deadend_oben.svg",
            "WE": "Bahnhof_#d50000_Deadend_rechts.svg",
            "SN": "Bahnhof_#d50000_Deadend_unten.svg",
            "EE WW": "Bahnhof_#d50000_Gleis_horizontal.svg",
            "NN SS": "Bahnhof_#d50000_Gleis_vertikal.svg",
        }

        # Rail cell images indexed by binary transitions.
        self.dPilRail = self.loadSVGs(rail_files, rotate=True)

        # Target images carry rails and transitions of their own and are
        # indexed by the tuple (binTrans, iAgent).
        target_images = self.loadSVGs(target_files, rotate=False, agent_colors=self.ltAgentColors)

        # Combine both into a new dict; target entries win on a key clash.
        combined = dict(self.dPilRail)
        combined.update(target_images)
        self.dPilRail = combined
u214892's avatar
u214892 committed
308

hagrid67's avatar
hagrid67 committed
309
    def loadSVGs(self, dDirFile, rotate=False, agent_colors=False):
        """Load SVG files into a dict of PIL images keyed by binary transition.

        :param dDirFile: maps a transition description such as ``"NE WS"``
            (space-separated in/out direction pairs over N, E, S, W) to an SVG
            file name under ``./svg/``.
        :param rotate: when true, store the base image under its binary
            transition plus its 90/180/270 degree rotations under the rotated
            transitions.
        :param agent_colors: when truthy, store one recoloured copy per entry
            of ``self.ltAgentColors`` under ``(binTrans, iColor)``; the base
            image itself is not stored by this branch.
        :return: dict of PIL images.
        """
        images = {}
        transitions = RailEnvTransitions()
        compass = list("NESW")

        for sTrans, sFile in dDirFile.items():
            # Encode the description as RailEnv's 16-bit value: NESW (in) x
            # NESW (out), with pair index 0 as the most significant bit.
            binTrans = 0
            for pair in sTrans.split(" "):
                if len(pair) != 2:
                    continue
                iDirIn = compass.index(pair[0])
                iDirOut = compass.index(pair[1])
                binTrans |= 1 << (15 - (4 * iDirIn + iDirOut))

            tile = self.pilFromSvgFile("./svg/" + sFile)

            if rotate:
                # The unrotated image is stored as well.
                images[binTrans] = tile
                for nRot in (90, 180, 270):
                    rotated_key = transitions.rotate_transition(binTrans, nRot)
                    # PIL rotates anticlockwise for positive theta
                    images[rotated_key] = tile.rotate(-nRot)

            if agent_colors:
                base_rgb = self.rgb_s2i("d50000")
                variants = self.recolorImage(tile, base_rgb, self.ltAgentColors)
                for iColor, variant in enumerate(variants):
                    images[(binTrans, iColor)] = variant

        return images

355
    def setRailAt(self, row, col, binTrans, iTarget=None, isSelected=False):
        """Draw the rail (or target) image for one cell.

        Without ``iTarget`` the image is looked up by ``binTrans``; with it,
        by ``(binTrans, iTarget)``.  An unknown key is reported with ``print``
        and nothing is drawn.  For a selected target, layer 3 is cleared and
        the selection marker is drawn there.
        """
        has_target = iTarget is not None
        key = (binTrans, iTarget) if has_target else binTrans

        try:
            tile = self.dPilRail[key]
        except KeyError:
            label = "Illegal target rail:" if has_target else "Illegal rail:"
            print(label, row, col, format(binTrans, "#018b")[2:])
        else:
            self.drawImageRC(tile, (row, col))

        if has_target and isSelected:
            marker = self.pilFromSvgFile("./svg/Selected_Target.svg")
            self.clear_layer(3, 0)
            self.drawImageRC(marker, (row, col), layer=3)
373

hagrid67's avatar
hagrid67 committed
374
375
376
377
378
379
380
381
382
383
384
385
386
387
    def recolorImage(self, pil, a3BaseColor, ltColors):
        """Return one recoloured copy of ``pil`` per colour in ``ltColors``.

        Pixels whose first three channels equal ``a3BaseColor`` exactly are
        repainted with the new colour; any further channel is left as is.

        :param pil: source image (anything ``numpy.array`` turns into an
            array with the colour channels on axis 2).
        :param a3BaseColor: 3-tuple of the colour to replace.
        :param ltColors: iterable of 3-tuples, one output image per entry.
        :return: list of PIL images in the order of ``ltColors``.
        """
        rgbaImg = array(pil)

        # The mask of base-colour pixels does not depend on the target
        # colour, so it is computed once rather than once per colour.
        base_mask = np.all(rgbaImg[:, :, 0:3] == np.asarray(a3BaseColor), axis=2)

        lPils = []
        for tnColor in ltColors:
            recoloured = np.copy(rgbaImg)
            # Repaint the base color with the new color
            recoloured[base_mask, 0:3] = tnColor
            lPils.append(Image.fromarray(recoloured))
        return lPils
hagrid67's avatar
hagrid67 committed
388
389
390
391
392
393
394
395

    def loadAgentSVGs(self):
        """Load the train SVGs and fill ``self.dPilZug`` with all variants.

        Each seed image is rotated by 0/90/180/270 degrees and recoloured for
        every agent colour; entries are keyed by ``(iDirIn, iDirOut, iColor)``.
        """
        # Seed train ("Zug") files indexed by (iDirIn, iDirOut).
        seed_files = {
            (0, 0): "svg/Zug_Gleis_#0091ea.svg",
            (1, 2): "svg/Zug_1_Weiche_#0091ea.svg",
            (0, 3): "svg/Zug_2_Weiche_#0091ea.svg",
        }

        # "Paint" colour in the loaded images, i.e. the colour that gets
        # replaced.  The file names say 0091ea; d50000 is used as a temporary
        # workaround for trains / agents renamed with a different colour.
        a3BaseColor = self.rgb_s2i("d50000")

        self.dPilZug = {}

        for (iDirIn, iDirOut), sPathSvg in seed_files.items():
            pilZug = self.pilFromSvgFile(sPathSvg)

            # Rotate the directions and the image together.
            for iDirRot in range(4):
                iDirIn2 = (iDirIn + iDirRot) % 4
                iDirOut2 = (iDirOut + iDirRot) % 4

                # PIL rotates anticlockwise for positive theta
                rotated = pilZug.rotate(-(iDirRot * 90))

                # One coloured version per rotation / variant.
                variants = self.recolorImage(rotated, a3BaseColor, self.ltAgentColors)
                for iColor, variant in enumerate(variants):
                    self.dPilZug[(iDirIn2, iDirOut2, iColor)] = variant
hagrid67's avatar
hagrid67 committed
423

424
    def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut, isSelected):
        """Draw agent ``iAgent`` at cell (row, col) on layer 1.

        The image is chosen by the in/out directions and the agent's palette
        colour.  For a selected agent, layer 2 is cleared and the selection
        marker is drawn there.
        """
        # when flipping direction at a dead end, use the "iDirOut" direction.
        if (iDirOut - iDirIn) % 4 == 2:
            iDirIn = iDirOut

        key = (iDirIn % 4, iDirOut % 4, iAgent % self.nAgentColors)
        self.drawImageRC(self.dPilZug[key], (row, col), layer=1)

        if not isSelected:
            return

        marker = self.pilFromSvgFile("./svg/Selected_Agent.svg")
        self.clear_layer(2, 0)
        self.drawImageRC(marker, (row, col), layer=2)
437

hagrid67's avatar
hagrid67 committed
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458

def main2():
    """Demo: draw one red line segment per frame for ten frames, one second apart."""
    gl = PILSVG(10, 10)
    for step in range(10):
        gl.beginFrame()
        xs = [3 + step, 4]
        ys = [-4 - step, -5]
        gl.plot(xs, ys, color="r")
        gl.endFrame()
        time.sleep(1)


def main():
    """Demo: build a 10x10 PILSVG layer and poll processEvents 1000 times, 0.1 s apart."""
    gl = PILSVG(width=10, height=10)

    for _ in range(1000):
        gl.processEvents()
        time.sleep(0.1)
    time.sleep(1)


# Run the demo loop only when this file is executed as a script.
if __name__ == "__main__":
    main()