diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 1bd67c22ad4ba05bd6badfc991e8839a85fe9e30..d64601c7ee32d0f611aa98d1577eb785eea870f3 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -10,6 +10,7 @@ from flatland.utils.render_qt import QTGL, QTSVG from flatland.utils.graphics_pil import PILGL from flatland.utils.graphics_layer import GraphicsLayer + # TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv! @@ -100,12 +101,12 @@ class RenderTool(object): lColors = list("brgcmyk") # \delta RC for NESW gTransRC = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]]) - nPixCell = 1 # misnomer... + nPixCell = 1 # misnomer... nPixHalf = nPixCell / 2 xyHalf = array([nPixHalf, -nPixHalf]) grc2xy = array([[0, -nPixCell], [nPixCell, 0]]) gGrid = array(np.meshgrid(np.arange(10), -np.arange(10))) * \ - array([[[nPixCell]], [[nPixCell]]]) + array([[[nPixCell]], [[nPixCell]]]) # xyPixHalf = xr.DataArray([nPixHalf, -nPixHalf], # dims="xy", # coords={"xy": ["x", "y"]}) @@ -130,7 +131,7 @@ class RenderTool(object): self.gl = PILGL(env.width, env.height) elif gl == "QTSVG": self.gl = QTSVG(env.width, env.height) - + self.new_rail = True def set_new_rail(self): @@ -153,14 +154,14 @@ class RenderTool(object): def plotAgents(self, targets=True, iSelectedAgent=None): cmap = self.gl.get_cmap('hsv', - lut=max(len(self.env.agents), len(self.env.agents_static) + 1)) + lut=max(len(self.env.agents), len(self.env.agents_static) + 1)) for iAgent, agent in enumerate(self.env.agents_static): if agent is None: continue oColor = cmap(iAgent) self.plotAgent(agent.position, agent.direction, oColor, target=agent.target if targets else None, - static=True, selected=iAgent == iSelectedAgent) + static=True, selected=iAgent == iSelectedAgent) for iAgent, agent in enumerate(self.env.agents): if agent is None: @@ -211,8 +212,8 @@ class RenderTool(object): """ rt = self.__class__ - rcDir = rt.gTransRC[iDir] # agent direction in RC - xyDir = np.matmul(rcDir, rt.grc2xy) # agent direction in xy + rcDir = rt.gTransRC[iDir] # agent direction in RC + xyDir = np.matmul(rcDir, rt.grc2xy) # agent direction in xy xyPos = np.matmul(rcPos - rcDir / 2, rt.grc2xy) + rt.xyHalf @@ -220,7 +221,7 @@ class RenderTool(object): color = self.gl.adaptColor(color, lighten=True) # print("Agent:", rcPos, iDir, rcDir, xyDir, xyPos) - self.gl.scatter(*xyPos, color=color, marker="o", s=100) # agent location + self.gl.scatter(*xyPos, color=color, marker="o", s=100) # agent location xyDirLine = array([xyPos, xyPos + xyDir / 2]).T # line for agent orient. self.gl.plot(*xyDirLine, color=color, lw=5, ms=0, alpha=0.6) if selected: @@ -398,12 +399,12 @@ class RenderTool(object): xyPrev = xy def drawTrans2( - self, - xyLine, xyCentre, - rotation, bDeadEnd=False, - sColor="gray", - bArrow=True, - spacing=0.1): + self, + xyLine, xyCentre, + rotation, bDeadEnd=False, + sColor="gray", + bArrow=True, + spacing=0.1): """ gLine is a numpy 2d array of points, in the plotting space / coords. @@ -420,9 +421,9 @@ class RenderTool(object): if sColor == "auto": if dx > 0 or dy > 0: - sColor = "C1" # N or E + sColor = "C1" # N or E else: - sColor = "C2" # S or W + sColor = "C2" # S or W if bDeadEnd: xyLine2 = array([ @@ -471,6 +472,7 @@ class RenderTool(object): xyMid, xyMid + [-dx + dy, -dx - dy]]) self.gl.plot(*xyArrow.T, color=sColor) + def renderObs(self, agent_handles, observation_list): """ @@ -480,21 +482,21 @@ class RenderTool(object): """ rt = self.__class__ - cmap = self.gl.get_cmap('hsv',lut=max(len(self.env.agents),len(self.env.agents_static)+1)) + cmap = self.gl.get_cmap('hsv', lut=max(len(self.env.agents), len(self.env.agents_static) + 1)) for agent in agent_handles: color = cmap(agent) for visited_cell in observation_list[agent]: cell_coord = array(visited_cell[:2]) - cell_coord_trans = np.matmul(cell_coord,rt.grc2xy)+rt.xyHalf - self._draw_square(cell_coord_trans,1 / 3, color) + cell_coord_trans = np.matmul(cell_coord, rt.grc2xy) + rt.xyHalf + self._draw_square(cell_coord_trans, 1 / 3, color) def renderEnv( - self, show=False, curves=True, spacing=False, - arrows=False, agents=True, sRailColor="gray", - frames=False, iEpisode=None, iStep=None, - iSelectedAgent=None, - action_dict=None): + self, show=False, curves=True, spacing=False, + arrows=False, agents=True, obsrender=True, sRailColor="gray", + frames=False, iEpisode=None, iStep=None, + iSelectedAgent=None, + action_dict=None): """ Draw the environment using matplotlib. Draw into the figure if provided. @@ -505,15 +507,16 @@ class RenderTool(object): if not self.gl.is_raster(): self.renderEnv2(show, curves, spacing, - arrows, agents, sRailColor, - frames, iEpisode, iStep, - iSelectedAgent, action_dict) + arrows, agents, sRailColor, + frames, iEpisode, iStep, + iSelectedAgent, action_dict) return # cell_size is a bit pointless with matplotlib - it does not relate to pixels, # so for now I've changed it to 1 (from 10) cell_size = 1 self.gl.beginFrame() + # self.gl.clf() # if oFigure is None: # oFigure = self.gl.figure() @@ -545,9 +548,9 @@ class RenderTool(object): for c in range(env.width): # bounding box of the grid cell - x0 = cell_size * c # left - x1 = cell_size * (c + 1) # right - y0 = cell_size * -r # top + x0 = cell_size * c # left + x1 = cell_size * (c + 1) # right + y0 = cell_size * -r # top y1 = cell_size * -(r + 1) # bottom # centres of cell edges @@ -555,7 +558,7 @@ class RenderTool(object): ((x0 + x1) / 2.0, y0), # N middle top (x1, (y0 + y1) / 2.0), # E middle right ((x0 + x1) / 2.0, y1), # S middle bottom - (x0, (y0 + y1) / 2.0) # W middle left + (x0, (y0 + y1) / 2.0) # W middle left ] # cell centre @@ -628,8 +631,8 @@ class RenderTool(object): # Draw each agent + its orientation + its target if agents: self.plotAgents(targets=True, iSelectedAgent=iSelectedAgent) - - self.renderObs(range(env.get_num_agents()), env.dev_obs_dict) + if obsrender: + self.renderObs(range(env.get_num_agents()), env.dev_obs_dict) # Draw some textual information like fps yText = [-0.3, -0.6, -0.9] if frames: @@ -683,18 +686,18 @@ class RenderTool(object): gP0 = array([[0, 0, 0]]).T nDepth = 2 for i in range(nDepth): - nDepthNodes = nBranchFactor**i + nDepthNodes = nBranchFactor ** i # rScale = nBranchFactor ** (nDepth - i) - rShrinkDepth = 1/(i+1) + rShrinkDepth = 1 / (i + 1) # gX1 = np.linspace(-nDepthNodes / 2, nDepthNodes / 2, nDepthNodes) * rShrinkDepth - - gX1 = np.linspace(-(nDepthNodes-1), (nDepthNodes-1), nDepthNodes) * rShrinkDepth + + gX1 = np.linspace(-(nDepthNodes - 1), (nDepthNodes - 1), nDepthNodes) * rShrinkDepth gY1 = np.ones((nDepthNodes)) * i gZ1 = np.zeros((nDepthNodes)) - + gP1 = array([gX1, gY1, gZ1]) gP01 = np.append(gP0, gP1, axis=1) - + if nDepthNodes > 1: nDepthNodesPrev = nDepthNodes / nBranchFactor giP0 = np.repeat(np.arange(nDepthNodesPrev), nBranchFactor) @@ -705,13 +708,13 @@ class RenderTool(object): self.gl.plot(gP01[0], -gP01[1], lines=giLinePoints, color="gray") gP0 = array([gX1, gY1, gZ1]) - + def renderEnv2( - self, show=False, curves=True, spacing=False, - arrows=False, agents=True, sRailColor="gray", - frames=False, iEpisode=None, iStep=None, - iSelectedAgent=None, - action_dict=dict()): + self, show=False, curves=True, spacing=False, + arrows=False, agents=True, sRailColor="gray", + frames=False, iEpisode=None, iStep=None, + iSelectedAgent=None, + action_dict=dict()): """ Draw the environment using matplotlib. Draw into the figure if provided. @@ -728,12 +731,11 @@ class RenderTool(object): # Draw each cell independently for r in range(env.height): for c in range(env.width): - binTrans = env.rail.grid[r, c] self.gl.setRailAt(r, c, binTrans) cmap = self.gl.get_cmap('hsv', - lut=max(len(self.env.agents), len(self.env.agents_static) + 1)) + lut=max(len(self.env.agents), len(self.env.agents_static) + 1)) for iAgent, agent in enumerate(self.env.agents): if agent is None: @@ -747,14 +749,14 @@ class RenderTool(object): if iAgent in action_dict: iAction = action_dict[iAgent] new_direction, action_isValid = self.env.check_action(agent, iAction) - + if action_isValid: self.gl.setAgentAt(iAgent, *agent.position, agent.direction, new_direction, color=oColor) else: pass # print("invalid action - agent ", iAgent, " bend ", agent.direction, new_direction) # self.gl.setAgentAt(iAgent, *agent.position, agent.direction, new_direction) - + self.gl.show() for i in range(3): self.gl.processEvents()