From d8b561336e2527e3144e4bb32699465485702012 Mon Sep 17 00:00:00 2001
From: "S.P. Mohanty" <spmohanty91@gmail.com>
Date: Thu, 11 Jun 2020 14:55:15 +0200
Subject: [PATCH] Fix recoloring issue of trains and stations

---
 flatland/utils/graphics_pil.py | 75 +++++++++++++++++++---------------
 1 file changed, 43 insertions(+), 32 deletions(-)

diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py
index 2d0713b1..4b51da2e 100644
--- a/flatland/utils/graphics_pil.py
+++ b/flatland/utils/graphics_pil.py
@@ -65,11 +65,9 @@ class PILGL(GraphicsLayer):
 
         sColors = "d50000#c51162#aa00ff#6200ea#304ffe#2962ff#0091ea#00b8d4#00bfa5#00c853" + \
                   "#64dd17#aeea00#ffd600#ffab00#ff6d00#ff3d00#5d4037#455a64"
-
         self.agent_colors = [self.rgb_s2i(sColor) for sColor in sColors.split("#")]
         self.n_agent_colors = len(self.agent_colors)
 
-        # self.window_open = False
         self.firstFrame = True
         self.old_background_image = (None, None, None)
         self.create_layers()
@@ -136,11 +134,24 @@ class PILGL(GraphicsLayer):
             self.draws[layer].rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color)
 
     def draw_image_xy(self, pil_img, xyPixLeftTop, layer=RAIL_LAYER, ):
+
+        # Resize all PIL images just before drawing them
+        # to ensure that resizing doesnt affect the 
+        # recolorizing strategies in place
+        # 
+        # That said : All the code in this file needs 
+        # some serious refactoring -_- to ensure the 
+        # code style and structure is consitent.
+        #                               - Mohanty
+        pil_img = pil_img.resize(
+            (self.nPixCell, self.nPixCell)
+        )
+
         if (pil_img.mode == "RGBA"):
             pil_mask = pil_img
         else:
             pil_mask = None
-
+        
         self.layers[layer].paste(pil_img, xyPixLeftTop, pil_mask)
 
     def draw_image_row_col(self, pil_img, rcTopLeft, layer=RAIL_LAYER, ):
@@ -241,6 +252,11 @@ class PILGL(GraphicsLayer):
 
 
 class PILSVG(PILGL):
+    """
+    Note : This class should now ideally be called as PILPNG,
+    but for backward compatibility, and to not introduce any breaking changes at this point
+    we are sticking to the legacy name of PILSVG (when in practice we are not using SVG anymore)
+    """
     def __init__(self, width, height, jupyter=False, screen_width=800, screen_height=600):
         oSuper = super()
         oSuper.__init__(width, height, jupyter, screen_width, screen_height)
@@ -266,14 +282,11 @@ class PILSVG(PILGL):
         self.lwAgents = []
         self.agents_prev = []
 
-    def pil_from_svg_file(self, package, resource):
+    def pil_from_png_file(self, package, resource):
         bytestring = resource_bytes(package, resource)
         with io.BytesIO(bytestring) as fIn:
             pil_img = Image.open(fIn)
-            pil_img = pil_img.resize(
-                    (self.nPixCell, self.nPixCell),
-                    Image.ANTIALIAS
-                )
+            pil_img.load()
         return pil_img
 
     def load_buildings(self):
@@ -299,15 +312,12 @@ class PILSVG(PILGL):
             "Buildings-Fabrik_I.png"
         ]
 
-        imgBg = self.pil_from_svg_file('flatland.png', "Background_city.png")
+        imgBg = self.pil_from_png_file('flatland.png', "Background_city.png")
         imgBg = imgBg.convert("RGBA")
-        #print("imgBg mode:", imgBg.mode)
 
         self.lBuildings = []
         for sFile in lBuildingFiles:
-            #print("Loading:", sFile)
-            img = self.pil_from_svg_file('flatland.png', sFile)
-            #print("img mode:", img.mode)
+            img = self.pil_from_png_file('flatland.png', sFile)
             img = Image.alpha_composite(imgBg, img)
             self.lBuildings.append(img)
 
@@ -336,31 +346,31 @@ class PILSVG(PILGL):
             "Scenery_Water.png"
         ]
 
-        img_back_ground = self.pil_from_svg_file('flatland.png', "Background_Light_green.png").convert("RGBA")
+        img_back_ground = self.pil_from_png_file('flatland.png', "Background_Light_green.png").convert("RGBA")
 
-        self.scenery_background_white = self.pil_from_svg_file('flatland.png', "Background_white.png").convert("RGBA")
+        self.scenery_background_white = self.pil_from_png_file('flatland.png', "Background_white.png").convert("RGBA")
 
         self.scenery = []
         for file in scenery_files:
-            img = self.pil_from_svg_file('flatland.png', file)
+            img = self.pil_from_png_file('flatland.png', file)
             img = Image.alpha_composite(img_back_ground, img)
             self.scenery.append(img)
 
         self.scenery_d2 = []
         for file in scenery_files_d2:
-            img = self.pil_from_svg_file('flatland.png', file)
+            img = self.pil_from_png_file('flatland.png', file)
             img = Image.alpha_composite(img_back_ground, img)
             self.scenery_d2.append(img)
 
         self.scenery_d3 = []
         for file in scenery_files_d3:
-            img = self.pil_from_svg_file('flatland.png', file)
+            img = self.pil_from_png_file('flatland.png', file)
             img = Image.alpha_composite(img_back_ground, img)
             self.scenery_d3.append(img)
 
         self.scenery_water = []
         for file in scenery_files_water:
-            img = self.pil_from_svg_file('flatland.png', file)
+            img = self.pil_from_png_file('flatland.png', file)
             img = Image.alpha_composite(img_back_ground, img)
             self.scenery_water.append(img)
 
@@ -401,22 +411,22 @@ class PILSVG(PILGL):
             "NN SS": "Bahnhof_#d50000_Gleis_vertikal.png"}
 
         # Dict of rail cell images indexed by binary transitions
-        pil_rail_files_org = self.load_svgs(rail_files, rotate=True)
-        pil_rail_files = self.load_svgs(rail_files, rotate=True, background_image="Background_rail.png",
+        pil_rail_files_org = self.load_pngs(rail_files, rotate=True)
+        pil_rail_files = self.load_pngs(rail_files, rotate=True, background_image="Background_rail.png",
                                         whitefilter="Background_white_filter.png")
 
         # Load the target files (which have rails and transitions of their own)
         # They are indexed by (binTrans, iAgent), ie a tuple of the binary transition and the agent index
-        pil_target_files_org = self.load_svgs(target_files, rotate=False, agent_colors=self.agent_colors)
-        pil_target_files = self.load_svgs(target_files, rotate=False, agent_colors=self.agent_colors,
+        pil_target_files_org = self.load_pngs(target_files, rotate=False, agent_colors=self.agent_colors)
+        pil_target_files = self.load_pngs(target_files, rotate=False, agent_colors=self.agent_colors,
                                           background_image="Background_rail.png",
                                           whitefilter="Background_white_filter.png")
 
         # Load station and recolorize them
-        station = self.pil_from_svg_file('flatland.png', "Bahnhof_#d50000_target.png")
+        station = self.pil_from_png_file('flatland.png', "Bahnhof_#d50000_target.png")
         self.station_colors = self.recolor_image(station, [0, 0, 0], self.agent_colors, False)
 
-        cell_occupied = self.pil_from_svg_file('flatland.png', "Cell_occupied.png")
+        cell_occupied = self.pil_from_png_file('flatland.png', "Cell_occupied.png")
         self.cell_occupied = self.recolor_image(cell_occupied, [0, 0, 0], self.agent_colors, False)
 
         # Merge them with the regular rails.
@@ -424,7 +434,7 @@ class PILSVG(PILGL):
         self.pil_rail = {**pil_rail_files, **pil_target_files}
         self.pil_rail_org = {**pil_rail_files_org, **pil_target_files_org}
 
-    def load_svgs(self, file_directory, rotate=False, agent_colors=False, background_image=None, whitefilter=None):
+    def load_pngs(self, file_directory, rotate=False, agent_colors=False, background_image=None, whitefilter=None):
         pil = {}
 
         transitions = RailEnvTransitions()
@@ -445,14 +455,14 @@ class PILSVG(PILGL):
             transition_16_bit_string = "".join(transition_16_bit)
             binary_trans = int(transition_16_bit_string, 2)
 
-            pil_rail = self.pil_from_svg_file('flatland.png', file).convert("RGBA")
+            pil_rail = self.pil_from_png_file('flatland.png', file).convert("RGBA")
 
             if background_image is not None:
-                img_bg = self.pil_from_svg_file('flatland.png', background_image).convert("RGBA")
+                img_bg = self.pil_from_png_file('flatland.png', background_image).convert("RGBA")
                 pil_rail = Image.alpha_composite(img_bg, pil_rail)
 
             if whitefilter is not None:
-                img_bg = self.pil_from_svg_file('flatland.png', whitefilter).convert("RGBA")
+                img_bg = self.pil_from_png_file('flatland.png', whitefilter).convert("RGBA")
                 pil_rail = Image.alpha_composite(pil_rail, img_bg)
 
             if rotate:
@@ -544,7 +554,7 @@ class PILSVG(PILGL):
 
         if target is not None:
             if is_selected:
-                svgBG = self.pil_from_svg_file('flatland.png', "Selected_Target.png")
+                svgBG = self.pil_from_png_file('flatland.png', "Selected_Target.png")
                 self.clear_layer(PILGL.SELECTED_TARGET_LAYER, 0)
                 self.draw_image_row_col(svgBG, (row, col), layer=PILGL.SELECTED_TARGET_LAYER)
 
@@ -557,6 +567,7 @@ class PILSVG(PILGL):
                 xy_color_mask = np.all(rgbaImg[:, :, 0:3] - a3BaseColor != 0, axis=2)
             else:
                 xy_color_mask = np.all(rgbaImg[:, :, 0:3] - a3BaseColor == 0, axis=2)
+            
             rgbaImg2 = np.copy(rgbaImg)
 
             # Repaint the base color with the new color
@@ -584,7 +595,7 @@ class PILSVG(PILGL):
         for directions, path_svg in file_directory.items():
             in_direction, out_direction = directions
 
-            pil_zug = self.pil_from_svg_file('flatland.png', path_svg)
+            pil_zug = self.pil_from_png_file('flatland.png', path_svg)
 
             # Rotate both the directions and the image and save in the dict
             for rot_direction in range(4):
@@ -614,7 +625,7 @@ class PILSVG(PILGL):
                 self.draw_image_row_col(self.scenery_background_white, (row, col), layer=PILGL.RAIL_LAYER)
 
         if is_selected:
-            bg_svg = self.pil_from_svg_file('flatland.png', "Selected_Agent.png")
+            bg_svg = self.pil_from_png_file('flatland.png', "Selected_Agent.png")
             self.clear_layer(PILGL.SELECTED_AGENT_LAYER, 0)
             self.draw_image_row_col(bg_svg, (row, col), layer=PILGL.SELECTED_AGENT_LAYER)
         if show_debug:
-- 
GitLab