From 216d1ae6fcd9a3d06951b30a66ecbaa46a1ee586 Mon Sep 17 00:00:00 2001
From: hagrid67 <jdhwatson@gmail.com>
Date: Tue, 14 May 2019 22:47:53 +0100
Subject: [PATCH] added basic colors (from MPL colormap hsv) to SVG

---
 flatland/utils/render_qt.py   |  6 ++--
 flatland/utils/rendertools.py |  7 +++-
 flatland/utils/svg.py         | 62 +++++++++++++++++++++++++----------
 3 files changed, 54 insertions(+), 21 deletions(-)

diff --git a/flatland/utils/render_qt.py b/flatland/utils/render_qt.py
index f70afae..af40c05 100644
--- a/flatland/utils/render_qt.py
+++ b/flatland/utils/render_qt.py
@@ -178,7 +178,7 @@ class QTSVG(GraphicsLayer):
         else:
             print("Illegal rail:", row, col, format(binTrans, "#018b")[2:])
 
-    def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut):
+    def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut, color=None):
         if iAgent < len(self.lwAgents):
             wAgent = self.lwAgents[iAgent]
             agentPrev = self.agents_prev[iAgent]
@@ -198,7 +198,7 @@ class QTSVG(GraphicsLayer):
                     # print("new dir:", iAgent, row, col, agentPrev.direction, iDirIn)
                     agentPrev.direction = iDirOut
                     agentPrev.old_direction = iDirIn
-                    sSVG = self.zug.getSvg(iAgent, iDirIn, iDirOut).to_string()
+                    sSVG = self.zug.getSvg(iAgent, iDirIn, iDirOut, color=color).to_string()
                     bySVG = bytearray(sSVG, encoding='utf-8')
                     wAgent.renderer().load(bySVG)
                     return
@@ -209,7 +209,7 @@ class QTSVG(GraphicsLayer):
             self.agents_prev.append(None)
 
         # Create a new widget for the agent
-        sSVG = self.zug.getSvg(iAgent, iDirIn, iDirOut).to_string()
+        sSVG = self.zug.getSvg(iAgent, iDirIn, iDirOut, color=color).to_string()
         bySVG = bytearray(sSVG, encoding='utf-8')
         svgWidget = QtSvg.QSvgWidget()
         svgWidget.renderer().load(bySVG)
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 18bc3c8..4921def 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -714,10 +714,15 @@ class RenderTool(object):
                     binTrans = env.rail.grid[r, c]
                     self.gl.setRailAt(r, c, binTrans)
 
+        cmap = self.gl.get_cmap('hsv',
+            lut=max(len(self.env.agents), len(self.env.agents_static) + 1))
+
         for iAgent, agent in enumerate(self.env.agents):
             if agent is None:
                 continue
 
+            oColor = self.gl.adaptColor(cmap(iAgent))
+
             new_direction = agent.direction
             action_isValid = False
 
@@ -726,7 +731,7 @@ class RenderTool(object):
                 new_direction, action_isValid = self.env.check_action(agent, iAction)
             
             if action_isValid:
-                self.gl.setAgentAt(iAgent, *agent.position, agent.direction, new_direction)
+                self.gl.setAgentAt(iAgent, *agent.position, agent.direction, new_direction, color=oColor)
             else:
                 pass
                 # print("invalid action - agent ", iAgent, " bend ", agent.direction, new_direction)
diff --git a/flatland/utils/svg.py b/flatland/utils/svg.py
index bd6ab4e..b219560 100644
--- a/flatland/utils/svg.py
+++ b/flatland/utils/svg.py
@@ -7,8 +7,13 @@ from flatland.core.transitions import RailEnvTransitions
 
 
 class SVG(object):
-    def __init__(self, sfName):
-        self.svg = svgutils.transform.fromfile(sfName)
+    def __init__(self, sfName=None, svgETree=None):
+
+        if sfName is not None:
+            self.svg = svgutils.transform.fromfile(sfName)
+        elif svgETree is not None:
+            self.svg = svgETree
+
         self.init2()
 
     def init2(self):
@@ -18,8 +23,9 @@ class SVG(object):
         self.dStyles = dict(ltMatch)
 
     def copy(self):
-        self2 = copy.deepcopy(self)
-        self2.init2()
+        new_svg = copy.deepcopy(self.svg)
+
+        self2 = SVG(svgETree=new_svg)
         return self2
 
     def merge(self, svg2):
@@ -45,7 +51,17 @@ class SVG(object):
         
             sStyle2 = str(iStyle+offset)
 
-            sNewStyle = "\t.st"+sStyle2+"{"+self.dStyles[sStyle]+"}\n"
+            sNewStyle = "\t.st" + sStyle2 + "{" + self.dStyles[sStyle] + "}\n"
+            sNewStyles += sNewStyle
+        
+        self.eStyle.text = sNewStyles
+
+    def set_style_color(self, style_name, color):
+        sNewStyles = "\n"
+        for sKey, sValue in self.dStyles.items():
+            if sKey == style_name:
+                sValue = "fill:#" + "".join([format(col, "#04x")[2:] for col in color]) + ";"
+            sNewStyle = "\t.st" + sKey + "{" + sValue + "}\n"
             sNewStyles += sNewStyle
         
         self.eStyle.text = sNewStyles
@@ -63,7 +79,7 @@ class Zug(object):
         self.svg_curve1 = SVG("svg/Zug_1_Weiche_#0091ea.svg")
         self.svg_curve2 = SVG("svg/Zug_2_Weiche_#0091ea.svg")
 
-    def getSvg(self, iAgent, iDirIn, iDirOut):
+    def getSvg(self, iAgent, iDirIn, iDirOut, color=None):
         delta_dir = (iDirOut - iDirIn) % 4
         # if delta_dir != 0:
         #    print("Bend:", iAgent, iDirIn, iDirOut)
@@ -71,17 +87,19 @@ class Zug(object):
         if delta_dir in (0, 2):
             svg = self.svg_straight.copy()
             svg.set_rotate(iDirIn * 90)
-            return svg
         
         if delta_dir == 1:  # bend to right, eg N->E, E->S
             svg = self.svg_curve1.copy()
             svg.set_rotate((iDirIn - 1) * 90)
-            return svg
 
         elif delta_dir == 3:  # bend to left, eg N->W
             svg = self.svg_curve2.copy()
             svg.set_rotate(iDirIn * 90)
-            return svg
+
+        if color is not None:
+            svg.set_style_color("2", color)
+
+        return svg
 
 
 class Track(object):
@@ -145,17 +163,27 @@ class Track(object):
 
 
 def main():
-    svg1 = SVG("./svg/Gleis_vertikal.svg")
-    svg2 = SVG("./svg/Zug_1_Weiche_#0091ea.svg")
-    svg3 = svg2.merge(svg1)
-    svg3.set_rotate(90)
+    # svg1 = SVG("./svg/Gleis_vertikal.svg")
+    # svg2 = SVG("./svg/Zug_1_Weiche_#0091ea.svg")
+    
+    # svg3 = svg2.merge(svg1)
+    # svg3.set_rotate(90)
+
+    # s = svg3.to_string()
+    # print(s)
+
+    # svg4 = svg2.copy()
+    # svg4.set_style_color("2", (255, 0, 0))
+    # print(svg4.to_string())
+    # print(svg2.to_string())
 
-    s = svg3.to_string()
-    print(s)
+    # track = Track()
+    # print(len(track.dSvg))
 
-    track = Track()
+    zug = Zug()
 
-    print(len(track.dSvg))
+    svg = zug.getSvg(0, 0, 0, color=(255, 0, 0))
+    print(svg.to_string()[:800])
 
 
 if __name__ == "__main__":
-- 
GitLab