Skip to content
Snippets Groups Projects
Commit 3b1ffed4 authored by hagrid67's avatar hagrid67
Browse files

merged a bit of refactoring in rendertools.py with TransitionMap changes

parent 994ef32a
No related branches found
No related tags found
No related merge requests found
......@@ -24,6 +24,8 @@ class RenderTool(object):
gCentres = xr.DataArray(gGrid,
dims=["xy", "p1", "p2"],
coords={"xy": ["x", "y"]}) + xyPixHalf
gTheta = np.linspace(0, np.pi/2, 10)
gArc = array([np.cos(gTheta), np.sin(gTheta)]).T # from [1,0] to [0,1]
def __init__(self, env):
self.env = env
......@@ -54,14 +56,13 @@ class RenderTool(object):
# TODO: this was `rcDir' but it was undefined
rcNext = rcPos + iDir
# transition for next cell
tbTrans = self.env.rail. \
get_transitions((rcNext[0], rcNext[1], iDir))
tbTrans = self.env.rail.get_transitions((*rcNext, iDir))
giTrans = np.where(tbTrans)[0] # RC list of transitions
gTransRCAg = self.__class__.gTransRC[giTrans]
gTransRCAg = rt.gTransRC[giTrans]
for visit in lVisits:
# transition for next cell
tbTrans = self.env.rail.get_transitions((visit.rc[0], visit.rc[1], visit.iDir))
tbTrans = self.env.rail.get_transitions((*visit.rc, visit.iDir))
giTrans = np.where(tbTrans)[0] # RC list of transitions
gTransRCAg = rt.gTransRC[giTrans]
self.plotTrans(visit.rc, gTransRCAg, depth=str(visit.iDepth), color=color)
......@@ -99,8 +100,7 @@ class RenderTool(object):
)
"""
# TODO: suggest we provide an accessor in RailEnv
tbTrans = self.env.rail.get_transitions((rcPos[0], rcPos[1], iDir))
tbTrans = self.env.rail.get_transitions((*rcPos, iDir))
giTrans = np.where(tbTrans)[0] # RC list of transitions
# HACK: workaround dead-end transitions
......@@ -310,7 +310,86 @@ class RenderTool(object):
visit = visit.prev
xyPrev = xy
def renderEnv(self, show=False, curves=True, spacing=False, arrows=False, agents=True):
def drawTrans2(
self,
xyLine, xyCentre,
rotation, bDeadEnd=False,
sColor="gray",
bArrow=True,
spacing=0.1):
"""
gLine is a numpy 2d array of points,
in the plotting space / coords.
eg:
[[0,.5],[1,0.2]] means a line
from x=0, y=0.5
to x=1, y=0.2
"""
rt = self.__class__
bStraight = rotation in [0, 2]
dx, dy = np.squeeze(np.diff(xyLine, axis=0)) * spacing / 2
if bStraight:
if sColor == "auto":
if dx > 0 or dy > 0:
sColor = "C1" # N or E
else:
sColor = "C2" # S or W
if bDeadEnd:
xyLine2 = array([
xyLine[1] + [dy, dx],
xyCentre,
xyLine[1] - [dy, dx],
])
plt.plot(*xyLine2.T, color=sColor)
else:
xyLine2 = xyLine + [-dy, dx]
plt.plot(*xyLine2.T, color=sColor)
if bArrow:
xyMid = np.sum(xyLine2 * [[1/4], [3/4]], axis=0)
xyArrow = array([
xyMid + [-dx-dy, +dx-dy],
xyMid,
xyMid + [-dx+dy, -dx-dy]
])
plt.plot(*xyArrow.T, color=sColor)
else:
xyMid = np.mean(xyLine, axis=0)
dxy = xyMid - xyCentre
xyCorner = xyMid + dxy
if rotation == 1:
rArcFactor = 1 - spacing
sColorAuto = "C1"
else:
rArcFactor = 1 + spacing
sColorAuto = "C2"
dxy2 = (xyCentre - xyCorner) * rArcFactor # for scaling the arc
if sColor == "auto":
sColor = sColorAuto
plt.plot(*(rt.gArc * dxy2 + xyCorner).T, color=sColor)
if bArrow:
dx, dy = np.squeeze(np.diff(xyLine, axis=0)) / 20
iArc = int(len(rt.gArc) / 2)
xyMid = xyCorner + rt.gArc[iArc] * dxy2
xyArrow = array([
xyMid + [-dx-dy, +dx-dy],
xyMid,
xyMid + [-dx+dy, -dx-dy]
])
plt.plot(*xyArrow.T, color=sColor)
def renderEnv(
self, show=False, curves=True, spacing=False,
arrows=False, agents=True, sRailColor="gray"):
"""
Draw the environment using matplotlib.
Draw into the figure if provided.
......@@ -326,9 +405,6 @@ class RenderTool(object):
# if oFigure is None:
# oFigure = plt.figure()
gTheta = np.linspace(0, np.pi/2, 10)
gArc = array([np.cos(gTheta), np.sin(gTheta)]).T # from [1,0] to [0,1]
def drawTrans(oFrom, oTo, sColor="gray"):
plt.plot(
[oFrom[0], oTo[0]], # x
......@@ -336,70 +412,6 @@ class RenderTool(object):
color=sColor
)
def drawTrans2(
xyLine, xyCentre,
rotation, bDeadEnd=False,
sColor="gray",
bArrow=True,
spacing=0.1):
"""
gLine is a numpy 2d array of points,
in the plotting space / coords.
eg:
[[0,.5],[1,0.2]] means a line
from x=0, y=0.5
to x=1, y=0.2
"""
bStraight = rotation in [0, 2]
if bStraight:
dx, dy = np.squeeze(np.diff(xyLine, axis=0)) * spacing / 2
if bDeadEnd:
xyLine2 = array([
xyLine[1] + [dy, dx],
xyCentre,
xyLine[1] - [dy, dx],
])
plt.plot(*xyLine2.T, color=sColor)
else:
xyLine2 = xyLine + [dy, dx]
plt.plot(*xyLine2.T, color=sColor)
if bArrow:
xyMid = np.sum(xyLine2 * [[1/4], [3/4]], axis=0)
xyArrow = array([
xyMid + [-dx-dy, +dx-dy],
xyMid,
xyMid + [-dx+dy, -dx-dy]
])
plt.plot(*xyArrow.T, color=sColor)
else:
xyMid = np.mean(xyLine, axis=0)
dxy = xyMid - xyCentre
xyCorner = xyMid + dxy
if rotation == 1:
rArcFactor = 1 - spacing
else:
rArcFactor = 1 + spacing
dxy2 = (xyCentre - xyCorner) * rArcFactor # for scaling the arc
plt.plot(*(gArc * dxy2 + xyCorner).T, color=sColor)
if bArrow:
dx, dy = np.squeeze(np.diff(xyLine, axis=0)) / 20
iArc = int(len(gArc) / 2)
xyMid = xyCorner + gArc[iArc] * dxy2
xyArrow = array([
xyMid + [-dx-dy, +dx-dy],
xyMid,
xyMid + [-dx+dy, -dx-dy]
])
plt.plot(*xyArrow.T, color=sColor)
env = self.env
# Draw cells grid
......@@ -466,18 +478,20 @@ class RenderTool(object):
if (tMoves[to_ori]): # if we have this transition
if bDeadEnd:
drawTrans2(
self.drawTrans2(
array([from_xy, to_xy]), xyCentre,
rotation, bDeadEnd=True, spacing=spacing)
rotation, bDeadEnd=True, spacing=spacing,
sColor=sRailColor)
else:
if curves:
drawTrans2(
self.drawTrans2(
array([from_xy, to_xy]), xyCentre,
rotation, spacing=spacing, bArrow=arrows)
rotation, spacing=spacing, bArrow=arrows,
sColor=sRailColor)
else:
drawTrans(from_xy, to_xy)
drawTrans(from_xy, to_xy, sRailColor)
if False:
print(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment