diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index 65235bf410f41770b3fbcb93265cc7b02ed92fa2..597f407225b51df657668591aff392ea89096a67 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -49,6 +49,9 @@ class PILGL(GraphicsLayer): """ convert a hex RGB string like 0091ea to 3-tuple of ints """ return tuple(int(sRGB[iRGB * 2:iRGB * 2 + 2], 16) for iRGB in [0, 1, 2]) + def getAgentColor(self, iAgent): + return self.ltAgentColors[iAgent % self.nAgentColors] + def plot(self, gX, gY, color=None, linewidth=3, layer=0, opacity=255, **kwargs): color = self.adaptColor(color) if len(color) == 3: diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 37d81519a104856ac3081a2f6d5ae9ebd03fd44a..15b5774c0e0198706ef5a959c3b4543deda04a4b 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -498,10 +498,11 @@ class RenderTool(object): """ rt = self.__class__ - cmap = self.gl.get_cmap('hsv', lut=max(len(self.env.agents), len(self.env.agents_static) + 1)) + # cmap = self.gl.get_cmap('hsv', lut=max(len(self.env.agents), len(self.env.agents_static) + 1)) for agent in agent_handles: - color = cmap(agent) + # color = cmap(agent) + color = self.gl.getAgentColor(agent) for visited_cell in observation_dict[agent]: cell_coord = array(visited_cell[:2]) cell_coord_trans = np.matmul(cell_coord, rt.grc2xy) + rt.xyHalf @@ -786,6 +787,9 @@ class RenderTool(object): # cmap = self.gl.get_cmap('hsv', lut=max(len(self.env.agents), len(self.env.agents_static) + 1)) self.gl.setAgentAt(iAgent, *position, old_direction, direction) # ,color=cmap(iAgent)) + if show_observations: + self.renderObs(range(env.get_num_agents()), env.dev_obs_dict) + if show: self.gl.show() for i in range(3):