Skip to content
Snippets Groups Projects
Commit 216d1ae6 authored by hagrid67's avatar hagrid67
Browse files

added basic colors (from MPL colormap hsv) to SVG

parent 5bdf72e0
No related branches found
No related tags found
No related merge requests found
...@@ -178,7 +178,7 @@ class QTSVG(GraphicsLayer): ...@@ -178,7 +178,7 @@ class QTSVG(GraphicsLayer):
else: else:
print("Illegal rail:", row, col, format(binTrans, "#018b")[2:]) print("Illegal rail:", row, col, format(binTrans, "#018b")[2:])
def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut): def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut, color=None):
if iAgent < len(self.lwAgents): if iAgent < len(self.lwAgents):
wAgent = self.lwAgents[iAgent] wAgent = self.lwAgents[iAgent]
agentPrev = self.agents_prev[iAgent] agentPrev = self.agents_prev[iAgent]
...@@ -198,7 +198,7 @@ class QTSVG(GraphicsLayer): ...@@ -198,7 +198,7 @@ class QTSVG(GraphicsLayer):
# print("new dir:", iAgent, row, col, agentPrev.direction, iDirIn) # print("new dir:", iAgent, row, col, agentPrev.direction, iDirIn)
agentPrev.direction = iDirOut agentPrev.direction = iDirOut
agentPrev.old_direction = iDirIn agentPrev.old_direction = iDirIn
sSVG = self.zug.getSvg(iAgent, iDirIn, iDirOut).to_string() sSVG = self.zug.getSvg(iAgent, iDirIn, iDirOut, color=color).to_string()
bySVG = bytearray(sSVG, encoding='utf-8') bySVG = bytearray(sSVG, encoding='utf-8')
wAgent.renderer().load(bySVG) wAgent.renderer().load(bySVG)
return return
...@@ -209,7 +209,7 @@ class QTSVG(GraphicsLayer): ...@@ -209,7 +209,7 @@ class QTSVG(GraphicsLayer):
self.agents_prev.append(None) self.agents_prev.append(None)
# Create a new widget for the agent # Create a new widget for the agent
sSVG = self.zug.getSvg(iAgent, iDirIn, iDirOut).to_string() sSVG = self.zug.getSvg(iAgent, iDirIn, iDirOut, color=color).to_string()
bySVG = bytearray(sSVG, encoding='utf-8') bySVG = bytearray(sSVG, encoding='utf-8')
svgWidget = QtSvg.QSvgWidget() svgWidget = QtSvg.QSvgWidget()
svgWidget.renderer().load(bySVG) svgWidget.renderer().load(bySVG)
......
...@@ -714,10 +714,15 @@ class RenderTool(object): ...@@ -714,10 +714,15 @@ class RenderTool(object):
binTrans = env.rail.grid[r, c] binTrans = env.rail.grid[r, c]
self.gl.setRailAt(r, c, binTrans) self.gl.setRailAt(r, c, binTrans)
cmap = self.gl.get_cmap('hsv',
lut=max(len(self.env.agents), len(self.env.agents_static) + 1))
for iAgent, agent in enumerate(self.env.agents): for iAgent, agent in enumerate(self.env.agents):
if agent is None: if agent is None:
continue continue
oColor = self.gl.adaptColor(cmap(iAgent))
new_direction = agent.direction new_direction = agent.direction
action_isValid = False action_isValid = False
...@@ -726,7 +731,7 @@ class RenderTool(object): ...@@ -726,7 +731,7 @@ class RenderTool(object):
new_direction, action_isValid = self.env.check_action(agent, iAction) new_direction, action_isValid = self.env.check_action(agent, iAction)
if action_isValid: if action_isValid:
self.gl.setAgentAt(iAgent, *agent.position, agent.direction, new_direction) self.gl.setAgentAt(iAgent, *agent.position, agent.direction, new_direction, color=oColor)
else: else:
pass pass
# print("invalid action - agent ", iAgent, " bend ", agent.direction, new_direction) # print("invalid action - agent ", iAgent, " bend ", agent.direction, new_direction)
......
...@@ -7,8 +7,13 @@ from flatland.core.transitions import RailEnvTransitions ...@@ -7,8 +7,13 @@ from flatland.core.transitions import RailEnvTransitions
class SVG(object): class SVG(object):
def __init__(self, sfName): def __init__(self, sfName=None, svgETree=None):
self.svg = svgutils.transform.fromfile(sfName)
if sfName is not None:
self.svg = svgutils.transform.fromfile(sfName)
elif svgETree is not None:
self.svg = svgETree
self.init2() self.init2()
def init2(self): def init2(self):
...@@ -18,8 +23,9 @@ class SVG(object): ...@@ -18,8 +23,9 @@ class SVG(object):
self.dStyles = dict(ltMatch) self.dStyles = dict(ltMatch)
def copy(self): def copy(self):
self2 = copy.deepcopy(self) new_svg = copy.deepcopy(self.svg)
self2.init2()
self2 = SVG(svgETree=new_svg)
return self2 return self2
def merge(self, svg2): def merge(self, svg2):
...@@ -45,7 +51,17 @@ class SVG(object): ...@@ -45,7 +51,17 @@ class SVG(object):
sStyle2 = str(iStyle+offset) sStyle2 = str(iStyle+offset)
sNewStyle = "\t.st"+sStyle2+"{"+self.dStyles[sStyle]+"}\n" sNewStyle = "\t.st" + sStyle2 + "{" + self.dStyles[sStyle] + "}\n"
sNewStyles += sNewStyle
self.eStyle.text = sNewStyles
def set_style_color(self, style_name, color):
sNewStyles = "\n"
for sKey, sValue in self.dStyles.items():
if sKey == style_name:
sValue = "fill:#" + "".join([format(col, "#04x")[2:] for col in color]) + ";"
sNewStyle = "\t.st" + sKey + "{" + sValue + "}\n"
sNewStyles += sNewStyle sNewStyles += sNewStyle
self.eStyle.text = sNewStyles self.eStyle.text = sNewStyles
...@@ -63,7 +79,7 @@ class Zug(object): ...@@ -63,7 +79,7 @@ class Zug(object):
self.svg_curve1 = SVG("svg/Zug_1_Weiche_#0091ea.svg") self.svg_curve1 = SVG("svg/Zug_1_Weiche_#0091ea.svg")
self.svg_curve2 = SVG("svg/Zug_2_Weiche_#0091ea.svg") self.svg_curve2 = SVG("svg/Zug_2_Weiche_#0091ea.svg")
def getSvg(self, iAgent, iDirIn, iDirOut): def getSvg(self, iAgent, iDirIn, iDirOut, color=None):
delta_dir = (iDirOut - iDirIn) % 4 delta_dir = (iDirOut - iDirIn) % 4
# if delta_dir != 0: # if delta_dir != 0:
# print("Bend:", iAgent, iDirIn, iDirOut) # print("Bend:", iAgent, iDirIn, iDirOut)
...@@ -71,17 +87,19 @@ class Zug(object): ...@@ -71,17 +87,19 @@ class Zug(object):
if delta_dir in (0, 2): if delta_dir in (0, 2):
svg = self.svg_straight.copy() svg = self.svg_straight.copy()
svg.set_rotate(iDirIn * 90) svg.set_rotate(iDirIn * 90)
return svg
if delta_dir == 1: # bend to right, eg N->E, E->S if delta_dir == 1: # bend to right, eg N->E, E->S
svg = self.svg_curve1.copy() svg = self.svg_curve1.copy()
svg.set_rotate((iDirIn - 1) * 90) svg.set_rotate((iDirIn - 1) * 90)
return svg
elif delta_dir == 3: # bend to left, eg N->W elif delta_dir == 3: # bend to left, eg N->W
svg = self.svg_curve2.copy() svg = self.svg_curve2.copy()
svg.set_rotate(iDirIn * 90) svg.set_rotate(iDirIn * 90)
return svg
if color is not None:
svg.set_style_color("2", color)
return svg
class Track(object): class Track(object):
...@@ -145,17 +163,27 @@ class Track(object): ...@@ -145,17 +163,27 @@ class Track(object):
def main(): def main():
svg1 = SVG("./svg/Gleis_vertikal.svg") # svg1 = SVG("./svg/Gleis_vertikal.svg")
svg2 = SVG("./svg/Zug_1_Weiche_#0091ea.svg") # svg2 = SVG("./svg/Zug_1_Weiche_#0091ea.svg")
svg3 = svg2.merge(svg1)
svg3.set_rotate(90) # svg3 = svg2.merge(svg1)
# svg3.set_rotate(90)
# s = svg3.to_string()
# print(s)
# svg4 = svg2.copy()
# svg4.set_style_color("2", (255, 0, 0))
# print(svg4.to_string())
# print(svg2.to_string())
s = svg3.to_string() # track = Track()
print(s) # print(len(track.dSvg))
track = Track() zug = Zug()
print(len(track.dSvg)) svg = zug.getSvg(0, 0, 0, color=(255, 0, 0))
print(svg.to_string()[:800])
if __name__ == "__main__": if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment