From ecf84f78f50974828f8b0890f0945b9bc71666d7 Mon Sep 17 00:00:00 2001 From: hagrid67 <jdhwatson@gmail.com> Date: Tue, 28 May 2019 23:11:08 +0100 Subject: [PATCH] added colors to stations / targets. --- examples/demo.py | 3 ++ examples/tkplay.py | 2 +- flatland/utils/graphics_layer.py | 7 ++- flatland/utils/graphics_pil.py | 79 ++++++++++++++++++++++---------- flatland/utils/render_qt.py | 2 +- flatland/utils/rendertools.py | 10 ++-- flatland/utils/svg.py | 2 +- 7 files changed, 73 insertions(+), 32 deletions(-) diff --git a/examples/demo.py b/examples/demo.py index 86e986f..6e9f053 100644 --- a/examples/demo.py +++ b/examples/demo.py @@ -137,6 +137,8 @@ class Demo: # Reset environment _ = self.env.reset(False, False) + time.sleep(0.0001) # to satisfy lint... + for step in range(max_nbr_of_steps): # time.sleep(.1) @@ -171,6 +173,7 @@ class Demo: if done['__all__']: break + self.renderer.close_window() if True: diff --git a/examples/tkplay.py b/examples/tkplay.py index a46dcbc..c17ea51 100644 --- a/examples/tkplay.py +++ b/examples/tkplay.py @@ -26,7 +26,7 @@ def tkmain(n_trials=2, n_steps=50, sGL="PIL"): env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step, action_dict=oPlayer.action_dict) - env_renderer.gl.close_window() + env_renderer.close_window() if __name__ == "__main__": diff --git a/flatland/utils/graphics_layer.py b/flatland/utils/graphics_layer.py index a1de818..1b2ff7e 100644 --- a/flatland/utils/graphics_layer.py +++ b/flatland/utils/graphics_layer.py @@ -66,7 +66,12 @@ class GraphicsLayer(object): def get_cmap(self, *args, **kwargs): return plt.get_cmap(*args, **kwargs) - def setRailAt(self, row, col, binTrans): + def setRailAt(self, row, col, binTrans, target=None): + """ Set the rail at cell (row, col) to have transitions binTrans. + The target argument can contain the index of the agent to indicate + that agent's target is at that cell, so that a station can be + rendered in the static rail layer. + """ pass def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut): diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index 63a4be2..65235bf 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -33,12 +33,22 @@ class PILGL(GraphicsLayer): self.tColRail = (0, 0, 0) # black rails self.tColGrid = (230,) * 3 # light grey for grid + sColors = "d50000#c51162#aa00ff#6200ea#304ffe#2962ff#0091ea#00b8d4#00bfa5#00c853" + \ + "#64dd17#aeea00#ffd600#ffab00#ff6d00#ff3d00#5d4037#455a64" + + self.ltAgentColors = [self.rgb_s2i(sColor) for sColor in sColors.split("#")] + self.nAgentColors = len(self.ltAgentColors) + self.window_open = False # self.bShow = show self.firstFrame = True self.create_layers() # self.beginFrame() + def rgb_s2i(self, sRGB): + """ convert a hex RGB string like 0091ea to 3-tuple of ints """ + return tuple(int(sRGB[iRGB * 2:iRGB * 2 + 2], 16) for iRGB in [0, 1, 2]) + def plot(self, gX, gY, color=None, linewidth=3, layer=0, opacity=255, **kwargs): color = self.adaptColor(color) if len(color) == 3: @@ -246,10 +256,17 @@ class PILSVG(PILGL): "EE WW": "Bahnhof_#d50000_Gleis_horizontal.svg", "NN SS": "Bahnhof_#d50000_Gleis_vertikal.svg"} + # Dict of rail cell images indexed by binary transitions self.dPilRail = self.loadSVGs(dRailFiles, rotate=True) - self.dPilTarget = self.loadSVGs(dTargetFiles, rotate=False) - def loadSVGs(self, dDirFile, rotate=False): + # Load the target files (which have rails and transitions of their own) + # They are indexed by (binTrans, iAgent), ie a tuple of the binary transition and the agent index + dPilRail2 = self.loadSVGs(dTargetFiles, rotate=False, agent_colors=self.ltAgentColors) + # Merge them with the regular rails. + # https://stackoverflow.com/questions/38987/how-to-merge-two-dictionaries-in-a-single-expression + self.dPilRail = {**self.dPilRail, **dPilRail2} + + def loadSVGs(self, dDirFile, rotate=False, agent_colors=False): dPil = {} transitions = RailEnvTransitions() @@ -280,9 +297,10 @@ class PILSVG(PILGL): # svg = svg.merge(svgBG) pilRail = self.pilFromSvgFile(sPathSvg) - dPil[binTrans] = pilRail - + if rotate: + # For rotations, we also store the base image + dPil[binTrans] = pilRail # Rotate both the transition binary and the image and save in the dict for nRot in [90, 180, 270]: binTrans2 = transitions.rotate_transition(binTrans, nRot) @@ -290,25 +308,44 @@ class PILSVG(PILGL): # PIL rotates anticlockwise for positive theta pilRail2 = pilRail.rotate(-nRot) dPil[binTrans2] = pilRail2 + + if agent_colors: + # For recoloring, we don't store the base image. + a3BaseColor = self.rgb_s2i("d50000") + lPils = self.recolorImage(pilRail, a3BaseColor, self.ltAgentColors) + for iColor, pilRail2 in enumerate(lPils): + dPil[(binTrans, iColor)] = lPils[iColor] + return dPil - def setRailAt(self, row, col, binTrans, target=None): - if target is None: + def setRailAt(self, row, col, binTrans, iTarget=None): + if iTarget is None: if binTrans in self.dPilRail: pilTrack = self.dPilRail[binTrans] self.drawImageRC(pilTrack, (row, col)) else: print("Illegal rail:", row, col, format(binTrans, "#018b")[2:]) else: - if binTrans in self.dPilTarget: - pilTrack = self.dPilTarget[binTrans] + if (binTrans, iTarget) in self.dPilRail: + pilTrack = self.dPilRail[(binTrans, iTarget)] self.drawImageRC(pilTrack, (row, col)) else: print("Illegal target rail:", row, col, format(binTrans, "#018b")[2:]) - def rgb_s2i(self, sRGB): - """ convert a hex RGB string like 0091ea to 3-tuple of ints """ - return tuple(int(sRGB[iRGB * 2:iRGB * 2 + 2], 16) for iRGB in [0, 1, 2]) + def recolorImage(self, pil, a3BaseColor, ltColors): + rgbaImg = array(pil) + lPils = [] + + for iColor, tnColor in enumerate(ltColors): + # find the pixels which match the base paint color + xy_color_mask = np.all(rgbaImg[:, :, 0:3] - a3BaseColor == 0, axis=2) + rgbaImg2 = np.copy(rgbaImg) + + # Repaint the base color with the new color + rgbaImg2[xy_color_mask, 0:3] = tnColor + pil2 = Image.fromarray(rgbaImg2) + lPils.append(pil2) + return lPils def loadAgentSVGs(self): @@ -319,13 +356,8 @@ class PILSVG(PILGL): (0, 3): "svg/Zug_2_Weiche_#0091ea.svg" } - sColors = "d50000#c51162#aa00ff#6200ea#304ffe#2962ff#0091ea#00b8d4#00bfa5#00c853" + \ - "#64dd17#aeea00#ffd600#ffab00#ff6d00#ff3d00#5d4037#455a64" - lColors = sColors.split("#") - self.nAgentColors = len(lColors) - # "paint" color of the train images we load - a_base_color = self.rgb_s2i("0091ea") + a3BaseColor = self.rgb_s2i("0091ea") self.dPilZug = {} @@ -342,16 +374,13 @@ class PILSVG(PILGL): # PIL rotates anticlockwise for positive theta pilZug2 = pilZug.rotate(-nDegRot) - rgbaZug2 = array(pilZug2) - for iColor, sColor in enumerate(lColors): - tnNewColor = self.rgb_s2i(sColor) - xy_color_mask = np.all(rgbaZug2[:, :, 0:3] - a_base_color == 0, axis=2) - rgbaZug3 = np.copy(rgbaZug2) - rgbaZug3[xy_color_mask, 0:3] = tnNewColor - self.dPilZug[(iDirIn2, iDirOut2, iColor)] = Image.fromarray(rgbaZug3) + # Save colored versions of each rotation / variant + lPils = self.recolorImage(pilZug2, a3BaseColor, self.ltAgentColors) + for iColor, pilZug3 in enumerate(lPils): + self.dPilZug[(iDirIn2, iDirOut2, iColor)] = lPils[iColor] - def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut, color=None): + def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut): delta_dir = (iDirOut - iDirIn) % 4 iColor = iAgent % self.nAgentColors # when flipping direction at a dead end, use the "iDirOut" direction. diff --git a/flatland/utils/render_qt.py b/flatland/utils/render_qt.py index 233f07b..ff8ebd1 100644 --- a/flatland/utils/render_qt.py +++ b/flatland/utils/render_qt.py @@ -157,7 +157,7 @@ class QTSVG(GraphicsLayer): self.lwAgents = [] self.agents_prev = [] - def setRailAt(self, row, col, binTrans): + def setRailAt(self, row, col, binTrans, target=None): if binTrans in self.track.dSvg: sSVG = self.track.dSvg[binTrans].to_string() svgWidget = create_QtSvgWidget_from_svg_string(sSVG) diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index f2200b9..37d8151 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -767,7 +767,7 @@ class RenderTool(object): target = dTargets[(r, c)] else: target = None - self.gl.setRailAt(r, c, binTrans) + self.gl.setRailAt(r, c, binTrans, iTarget=target) for iAgent, agent in enumerate(self.env.agents): if agent is None: @@ -782,8 +782,9 @@ class RenderTool(object): direction = agent.direction old_direction = agent.direction - cmap = self.gl.get_cmap('hsv', lut=max(len(self.env.agents), len(self.env.agents_static) + 1)) - self.gl.setAgentAt(iAgent, *position, old_direction, direction,color=cmap(iAgent)) + # setAgentAt uses the agent index for the color + # cmap = self.gl.get_cmap('hsv', lut=max(len(self.env.agents), len(self.env.agents_static) + 1)) + self.gl.setAgentAt(iAgent, *position, old_direction, direction) # ,color=cmap(iAgent)) if show: self.gl.show() @@ -792,3 +793,6 @@ class RenderTool(object): self.iFrame += 1 return + + def close_window(self): + self.gl.close_window() diff --git a/flatland/utils/svg.py b/flatland/utils/svg.py index 0a6b895..e7b2ceb 100644 --- a/flatland/utils/svg.py +++ b/flatland/utils/svg.py @@ -60,7 +60,7 @@ class SVG(object): sNewStyles = "\n" for sKey, sValue in self.dStyles.items(): if sKey == style_name: - sValue = "fill:#" + "".join([ ('{:#04x}'.format(int(255.0*col))[2:4]) for col in color[0:3]]) + ";" + sValue = "fill:#" + "".join([('{:#04x}'.format(int(255.0*col))[2:4]) for col in color[0:3]]) + ";" sNewStyle = "\t.st" + sKey + "{" + sValue + "}\n" sNewStyles += sNewStyle -- GitLab