Skip to content
Snippets Groups Projects
Commit ecf84f78 authored by hagrid67's avatar hagrid67
Browse files

added colors to stations / targets.

parent 2ceb6044
No related branches found
No related tags found
No related merge requests found
...@@ -137,6 +137,8 @@ class Demo: ...@@ -137,6 +137,8 @@ class Demo:
# Reset environment # Reset environment
_ = self.env.reset(False, False) _ = self.env.reset(False, False)
time.sleep(0.0001) # to satisfy lint...
for step in range(max_nbr_of_steps): for step in range(max_nbr_of_steps):
# time.sleep(.1) # time.sleep(.1)
...@@ -171,6 +173,7 @@ class Demo: ...@@ -171,6 +173,7 @@ class Demo:
if done['__all__']: if done['__all__']:
break break
self.renderer.close_window()
if True: if True:
......
...@@ -26,7 +26,7 @@ def tkmain(n_trials=2, n_steps=50, sGL="PIL"): ...@@ -26,7 +26,7 @@ def tkmain(n_trials=2, n_steps=50, sGL="PIL"):
env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step, env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step,
action_dict=oPlayer.action_dict) action_dict=oPlayer.action_dict)
env_renderer.gl.close_window() env_renderer.close_window()
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -66,7 +66,12 @@ class GraphicsLayer(object): ...@@ -66,7 +66,12 @@ class GraphicsLayer(object):
def get_cmap(self, *args, **kwargs): def get_cmap(self, *args, **kwargs):
return plt.get_cmap(*args, **kwargs) return plt.get_cmap(*args, **kwargs)
def setRailAt(self, row, col, binTrans): def setRailAt(self, row, col, binTrans, target=None):
""" Set the rail at cell (row, col) to have transitions binTrans.
The target argument can contain the index of the agent to indicate
that agent's target is at that cell, so that a station can be
rendered in the static rail layer.
"""
pass pass
def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut): def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut):
......
...@@ -33,12 +33,22 @@ class PILGL(GraphicsLayer): ...@@ -33,12 +33,22 @@ class PILGL(GraphicsLayer):
self.tColRail = (0, 0, 0) # black rails self.tColRail = (0, 0, 0) # black rails
self.tColGrid = (230,) * 3 # light grey for grid self.tColGrid = (230,) * 3 # light grey for grid
sColors = "d50000#c51162#aa00ff#6200ea#304ffe#2962ff#0091ea#00b8d4#00bfa5#00c853" + \
"#64dd17#aeea00#ffd600#ffab00#ff6d00#ff3d00#5d4037#455a64"
self.ltAgentColors = [self.rgb_s2i(sColor) for sColor in sColors.split("#")]
self.nAgentColors = len(self.ltAgentColors)
self.window_open = False self.window_open = False
# self.bShow = show # self.bShow = show
self.firstFrame = True self.firstFrame = True
self.create_layers() self.create_layers()
# self.beginFrame() # self.beginFrame()
def rgb_s2i(self, sRGB):
""" convert a hex RGB string like 0091ea to 3-tuple of ints """
return tuple(int(sRGB[iRGB * 2:iRGB * 2 + 2], 16) for iRGB in [0, 1, 2])
def plot(self, gX, gY, color=None, linewidth=3, layer=0, opacity=255, **kwargs): def plot(self, gX, gY, color=None, linewidth=3, layer=0, opacity=255, **kwargs):
color = self.adaptColor(color) color = self.adaptColor(color)
if len(color) == 3: if len(color) == 3:
...@@ -246,10 +256,17 @@ class PILSVG(PILGL): ...@@ -246,10 +256,17 @@ class PILSVG(PILGL):
"EE WW": "Bahnhof_#d50000_Gleis_horizontal.svg", "EE WW": "Bahnhof_#d50000_Gleis_horizontal.svg",
"NN SS": "Bahnhof_#d50000_Gleis_vertikal.svg"} "NN SS": "Bahnhof_#d50000_Gleis_vertikal.svg"}
# Dict of rail cell images indexed by binary transitions
self.dPilRail = self.loadSVGs(dRailFiles, rotate=True) self.dPilRail = self.loadSVGs(dRailFiles, rotate=True)
self.dPilTarget = self.loadSVGs(dTargetFiles, rotate=False)
def loadSVGs(self, dDirFile, rotate=False): # Load the target files (which have rails and transitions of their own)
# They are indexed by (binTrans, iAgent), ie a tuple of the binary transition and the agent index
dPilRail2 = self.loadSVGs(dTargetFiles, rotate=False, agent_colors=self.ltAgentColors)
# Merge them with the regular rails.
# https://stackoverflow.com/questions/38987/how-to-merge-two-dictionaries-in-a-single-expression
self.dPilRail = {**self.dPilRail, **dPilRail2}
def loadSVGs(self, dDirFile, rotate=False, agent_colors=False):
dPil = {} dPil = {}
transitions = RailEnvTransitions() transitions = RailEnvTransitions()
...@@ -280,9 +297,10 @@ class PILSVG(PILGL): ...@@ -280,9 +297,10 @@ class PILSVG(PILGL):
# svg = svg.merge(svgBG) # svg = svg.merge(svgBG)
pilRail = self.pilFromSvgFile(sPathSvg) pilRail = self.pilFromSvgFile(sPathSvg)
dPil[binTrans] = pilRail
if rotate: if rotate:
# For rotations, we also store the base image
dPil[binTrans] = pilRail
# Rotate both the transition binary and the image and save in the dict # Rotate both the transition binary and the image and save in the dict
for nRot in [90, 180, 270]: for nRot in [90, 180, 270]:
binTrans2 = transitions.rotate_transition(binTrans, nRot) binTrans2 = transitions.rotate_transition(binTrans, nRot)
...@@ -290,25 +308,44 @@ class PILSVG(PILGL): ...@@ -290,25 +308,44 @@ class PILSVG(PILGL):
# PIL rotates anticlockwise for positive theta # PIL rotates anticlockwise for positive theta
pilRail2 = pilRail.rotate(-nRot) pilRail2 = pilRail.rotate(-nRot)
dPil[binTrans2] = pilRail2 dPil[binTrans2] = pilRail2
if agent_colors:
# For recoloring, we don't store the base image.
a3BaseColor = self.rgb_s2i("d50000")
lPils = self.recolorImage(pilRail, a3BaseColor, self.ltAgentColors)
for iColor, pilRail2 in enumerate(lPils):
dPil[(binTrans, iColor)] = lPils[iColor]
return dPil return dPil
def setRailAt(self, row, col, binTrans, target=None): def setRailAt(self, row, col, binTrans, iTarget=None):
if target is None: if iTarget is None:
if binTrans in self.dPilRail: if binTrans in self.dPilRail:
pilTrack = self.dPilRail[binTrans] pilTrack = self.dPilRail[binTrans]
self.drawImageRC(pilTrack, (row, col)) self.drawImageRC(pilTrack, (row, col))
else: else:
print("Illegal rail:", row, col, format(binTrans, "#018b")[2:]) print("Illegal rail:", row, col, format(binTrans, "#018b")[2:])
else: else:
if binTrans in self.dPilTarget: if (binTrans, iTarget) in self.dPilRail:
pilTrack = self.dPilTarget[binTrans] pilTrack = self.dPilRail[(binTrans, iTarget)]
self.drawImageRC(pilTrack, (row, col)) self.drawImageRC(pilTrack, (row, col))
else: else:
print("Illegal target rail:", row, col, format(binTrans, "#018b")[2:]) print("Illegal target rail:", row, col, format(binTrans, "#018b")[2:])
def rgb_s2i(self, sRGB): def recolorImage(self, pil, a3BaseColor, ltColors):
""" convert a hex RGB string like 0091ea to 3-tuple of ints """ rgbaImg = array(pil)
return tuple(int(sRGB[iRGB * 2:iRGB * 2 + 2], 16) for iRGB in [0, 1, 2]) lPils = []
for iColor, tnColor in enumerate(ltColors):
# find the pixels which match the base paint color
xy_color_mask = np.all(rgbaImg[:, :, 0:3] - a3BaseColor == 0, axis=2)
rgbaImg2 = np.copy(rgbaImg)
# Repaint the base color with the new color
rgbaImg2[xy_color_mask, 0:3] = tnColor
pil2 = Image.fromarray(rgbaImg2)
lPils.append(pil2)
return lPils
def loadAgentSVGs(self): def loadAgentSVGs(self):
...@@ -319,13 +356,8 @@ class PILSVG(PILGL): ...@@ -319,13 +356,8 @@ class PILSVG(PILGL):
(0, 3): "svg/Zug_2_Weiche_#0091ea.svg" (0, 3): "svg/Zug_2_Weiche_#0091ea.svg"
} }
sColors = "d50000#c51162#aa00ff#6200ea#304ffe#2962ff#0091ea#00b8d4#00bfa5#00c853" + \
"#64dd17#aeea00#ffd600#ffab00#ff6d00#ff3d00#5d4037#455a64"
lColors = sColors.split("#")
self.nAgentColors = len(lColors)
# "paint" color of the train images we load # "paint" color of the train images we load
a_base_color = self.rgb_s2i("0091ea") a3BaseColor = self.rgb_s2i("0091ea")
self.dPilZug = {} self.dPilZug = {}
...@@ -342,16 +374,13 @@ class PILSVG(PILGL): ...@@ -342,16 +374,13 @@ class PILSVG(PILGL):
# PIL rotates anticlockwise for positive theta # PIL rotates anticlockwise for positive theta
pilZug2 = pilZug.rotate(-nDegRot) pilZug2 = pilZug.rotate(-nDegRot)
rgbaZug2 = array(pilZug2)
for iColor, sColor in enumerate(lColors): # Save colored versions of each rotation / variant
tnNewColor = self.rgb_s2i(sColor) lPils = self.recolorImage(pilZug2, a3BaseColor, self.ltAgentColors)
xy_color_mask = np.all(rgbaZug2[:, :, 0:3] - a_base_color == 0, axis=2) for iColor, pilZug3 in enumerate(lPils):
rgbaZug3 = np.copy(rgbaZug2) self.dPilZug[(iDirIn2, iDirOut2, iColor)] = lPils[iColor]
rgbaZug3[xy_color_mask, 0:3] = tnNewColor
self.dPilZug[(iDirIn2, iDirOut2, iColor)] = Image.fromarray(rgbaZug3)
def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut, color=None): def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut):
delta_dir = (iDirOut - iDirIn) % 4 delta_dir = (iDirOut - iDirIn) % 4
iColor = iAgent % self.nAgentColors iColor = iAgent % self.nAgentColors
# when flipping direction at a dead end, use the "iDirOut" direction. # when flipping direction at a dead end, use the "iDirOut" direction.
......
...@@ -157,7 +157,7 @@ class QTSVG(GraphicsLayer): ...@@ -157,7 +157,7 @@ class QTSVG(GraphicsLayer):
self.lwAgents = [] self.lwAgents = []
self.agents_prev = [] self.agents_prev = []
def setRailAt(self, row, col, binTrans): def setRailAt(self, row, col, binTrans, target=None):
if binTrans in self.track.dSvg: if binTrans in self.track.dSvg:
sSVG = self.track.dSvg[binTrans].to_string() sSVG = self.track.dSvg[binTrans].to_string()
svgWidget = create_QtSvgWidget_from_svg_string(sSVG) svgWidget = create_QtSvgWidget_from_svg_string(sSVG)
......
...@@ -767,7 +767,7 @@ class RenderTool(object): ...@@ -767,7 +767,7 @@ class RenderTool(object):
target = dTargets[(r, c)] target = dTargets[(r, c)]
else: else:
target = None target = None
self.gl.setRailAt(r, c, binTrans) self.gl.setRailAt(r, c, binTrans, iTarget=target)
for iAgent, agent in enumerate(self.env.agents): for iAgent, agent in enumerate(self.env.agents):
if agent is None: if agent is None:
...@@ -782,8 +782,9 @@ class RenderTool(object): ...@@ -782,8 +782,9 @@ class RenderTool(object):
direction = agent.direction direction = agent.direction
old_direction = agent.direction old_direction = agent.direction
cmap = self.gl.get_cmap('hsv', lut=max(len(self.env.agents), len(self.env.agents_static) + 1)) # setAgentAt uses the agent index for the color
self.gl.setAgentAt(iAgent, *position, old_direction, direction,color=cmap(iAgent)) # cmap = self.gl.get_cmap('hsv', lut=max(len(self.env.agents), len(self.env.agents_static) + 1))
self.gl.setAgentAt(iAgent, *position, old_direction, direction) # ,color=cmap(iAgent))
if show: if show:
self.gl.show() self.gl.show()
...@@ -792,3 +793,6 @@ class RenderTool(object): ...@@ -792,3 +793,6 @@ class RenderTool(object):
self.iFrame += 1 self.iFrame += 1
return return
def close_window(self):
self.gl.close_window()
...@@ -60,7 +60,7 @@ class SVG(object): ...@@ -60,7 +60,7 @@ class SVG(object):
sNewStyles = "\n" sNewStyles = "\n"
for sKey, sValue in self.dStyles.items(): for sKey, sValue in self.dStyles.items():
if sKey == style_name: if sKey == style_name:
sValue = "fill:#" + "".join([ ('{:#04x}'.format(int(255.0*col))[2:4]) for col in color[0:3]]) + ";" sValue = "fill:#" + "".join([('{:#04x}'.format(int(255.0*col))[2:4]) for col in color[0:3]]) + ";"
sNewStyle = "\t.st" + sKey + "{" + sValue + "}\n" sNewStyle = "\t.st" + sKey + "{" + sValue + "}\n"
sNewStyles += sNewStyle sNewStyles += sNewStyle
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment