From 3b1ffed4e6af0e70c34fe0bc124548dd16e09bf7 Mon Sep 17 00:00:00 2001 From: hagrid67 <jdhwatson@gmail.com> Date: Wed, 17 Apr 2019 20:22:48 +0100 Subject: [PATCH] merged a bit of refactoring in rendertools.py with TransitionMap changes --- flatland/utils/rendertools.py | 172 ++++++++++++++++++---------------- 1 file changed, 93 insertions(+), 79 deletions(-) diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 771f228f..d0a78917 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -24,6 +24,8 @@ class RenderTool(object): gCentres = xr.DataArray(gGrid, dims=["xy", "p1", "p2"], coords={"xy": ["x", "y"]}) + xyPixHalf + gTheta = np.linspace(0, np.pi/2, 10) + gArc = array([np.cos(gTheta), np.sin(gTheta)]).T # from [1,0] to [0,1] def __init__(self, env): self.env = env @@ -54,14 +56,13 @@ class RenderTool(object): # TODO: this was `rcDir' but it was undefined rcNext = rcPos + iDir # transition for next cell - tbTrans = self.env.rail. \ - get_transitions((rcNext[0], rcNext[1], iDir)) + tbTrans = self.env.rail.get_transitions((*rcNext, iDir)) giTrans = np.where(tbTrans)[0] # RC list of transitions - gTransRCAg = self.__class__.gTransRC[giTrans] + gTransRCAg = rt.gTransRC[giTrans] for visit in lVisits: # transition for next cell - tbTrans = self.env.rail.get_transitions((visit.rc[0], visit.rc[1], visit.iDir)) + tbTrans = self.env.rail.get_transitions((*visit.rc, visit.iDir)) giTrans = np.where(tbTrans)[0] # RC list of transitions gTransRCAg = rt.gTransRC[giTrans] self.plotTrans(visit.rc, gTransRCAg, depth=str(visit.iDepth), color=color) @@ -99,8 +100,7 @@ class RenderTool(object): ) """ - # TODO: suggest we provide an accessor in RailEnv - tbTrans = self.env.rail.get_transitions((rcPos[0], rcPos[1], iDir)) + tbTrans = self.env.rail.get_transitions((*rcPos, iDir)) giTrans = np.where(tbTrans)[0] # RC list of transitions # HACK: workaround dead-end transitions @@ -310,7 +310,86 @@ class RenderTool(object): visit = visit.prev xyPrev = xy - def renderEnv(self, show=False, curves=True, spacing=False, arrows=False, agents=True): + def drawTrans2( + self, + xyLine, xyCentre, + rotation, bDeadEnd=False, + sColor="gray", + bArrow=True, + spacing=0.1): + """ + gLine is a numpy 2d array of points, + in the plotting space / coords. + eg: + [[0,.5],[1,0.2]] means a line + from x=0, y=0.5 + to x=1, y=0.2 + """ + rt = self.__class__ + bStraight = rotation in [0, 2] + dx, dy = np.squeeze(np.diff(xyLine, axis=0)) * spacing / 2 + + if bStraight: + + if sColor == "auto": + if dx > 0 or dy > 0: + sColor = "C1" # N or E + else: + sColor = "C2" # S or W + + if bDeadEnd: + xyLine2 = array([ + xyLine[1] + [dy, dx], + xyCentre, + xyLine[1] - [dy, dx], + ]) + plt.plot(*xyLine2.T, color=sColor) + else: + xyLine2 = xyLine + [-dy, dx] + plt.plot(*xyLine2.T, color=sColor) + + if bArrow: + xyMid = np.sum(xyLine2 * [[1/4], [3/4]], axis=0) + + xyArrow = array([ + xyMid + [-dx-dy, +dx-dy], + xyMid, + xyMid + [-dx+dy, -dx-dy] + ]) + plt.plot(*xyArrow.T, color=sColor) + + else: + + xyMid = np.mean(xyLine, axis=0) + dxy = xyMid - xyCentre + xyCorner = xyMid + dxy + if rotation == 1: + rArcFactor = 1 - spacing + sColorAuto = "C1" + else: + rArcFactor = 1 + spacing + sColorAuto = "C2" + dxy2 = (xyCentre - xyCorner) * rArcFactor # for scaling the arc + + if sColor == "auto": + sColor = sColorAuto + + plt.plot(*(rt.gArc * dxy2 + xyCorner).T, color=sColor) + + if bArrow: + dx, dy = np.squeeze(np.diff(xyLine, axis=0)) / 20 + iArc = int(len(rt.gArc) / 2) + xyMid = xyCorner + rt.gArc[iArc] * dxy2 + xyArrow = array([ + xyMid + [-dx-dy, +dx-dy], + xyMid, + xyMid + [-dx+dy, -dx-dy] + ]) + plt.plot(*xyArrow.T, color=sColor) + + def renderEnv( + self, show=False, curves=True, spacing=False, + arrows=False, agents=True, sRailColor="gray"): """ Draw the environment using matplotlib. Draw into the figure if provided. @@ -326,9 +405,6 @@ class RenderTool(object): # if oFigure is None: # oFigure = plt.figure() - gTheta = np.linspace(0, np.pi/2, 10) - gArc = array([np.cos(gTheta), np.sin(gTheta)]).T # from [1,0] to [0,1] - def drawTrans(oFrom, oTo, sColor="gray"): plt.plot( [oFrom[0], oTo[0]], # x @@ -336,70 +412,6 @@ class RenderTool(object): color=sColor ) - def drawTrans2( - xyLine, xyCentre, - rotation, bDeadEnd=False, - sColor="gray", - bArrow=True, - spacing=0.1): - """ - gLine is a numpy 2d array of points, - in the plotting space / coords. - eg: - [[0,.5],[1,0.2]] means a line - from x=0, y=0.5 - to x=1, y=0.2 - """ - - bStraight = rotation in [0, 2] - if bStraight: - - dx, dy = np.squeeze(np.diff(xyLine, axis=0)) * spacing / 2 - if bDeadEnd: - xyLine2 = array([ - xyLine[1] + [dy, dx], - xyCentre, - xyLine[1] - [dy, dx], - ]) - plt.plot(*xyLine2.T, color=sColor) - else: - xyLine2 = xyLine + [dy, dx] - plt.plot(*xyLine2.T, color=sColor) - - if bArrow: - xyMid = np.sum(xyLine2 * [[1/4], [3/4]], axis=0) - - xyArrow = array([ - xyMid + [-dx-dy, +dx-dy], - xyMid, - xyMid + [-dx+dy, -dx-dy] - ]) - plt.plot(*xyArrow.T, color=sColor) - - else: - - xyMid = np.mean(xyLine, axis=0) - dxy = xyMid - xyCentre - xyCorner = xyMid + dxy - if rotation == 1: - rArcFactor = 1 - spacing - else: - rArcFactor = 1 + spacing - dxy2 = (xyCentre - xyCorner) * rArcFactor # for scaling the arc - - plt.plot(*(gArc * dxy2 + xyCorner).T, color=sColor) - - if bArrow: - dx, dy = np.squeeze(np.diff(xyLine, axis=0)) / 20 - iArc = int(len(gArc) / 2) - xyMid = xyCorner + gArc[iArc] * dxy2 - xyArrow = array([ - xyMid + [-dx-dy, +dx-dy], - xyMid, - xyMid + [-dx+dy, -dx-dy] - ]) - plt.plot(*xyArrow.T, color=sColor) - env = self.env # Draw cells grid @@ -466,18 +478,20 @@ class RenderTool(object): if (tMoves[to_ori]): # if we have this transition if bDeadEnd: - drawTrans2( + self.drawTrans2( array([from_xy, to_xy]), xyCentre, - rotation, bDeadEnd=True, spacing=spacing) + rotation, bDeadEnd=True, spacing=spacing, + sColor=sRailColor) else: if curves: - drawTrans2( + self.drawTrans2( array([from_xy, to_xy]), xyCentre, - rotation, spacing=spacing, bArrow=arrows) + rotation, spacing=spacing, bArrow=arrows, + sColor=sRailColor) else: - drawTrans(from_xy, to_xy) + drawTrans(from_xy, to_xy, sRailColor) if False: print( -- GitLab