Skip to content
Snippets Groups Projects
Commit 14c757ee authored by hagrid67's avatar hagrid67
Browse files

implemented curved transitions

parent d0665e51
No related branches found
No related tags found
No related merge requests found
......@@ -145,6 +145,13 @@ class RenderTool(object):
plt.scatter(*xyNext, color=sColor)
def plotTrans(self, rcPos, gTransRCAg, color="r", depth=None):
"""
plot the transitions in gTransRCAg at position rcPos.
gTransRCAg is a 2d numpy array containing a list of RC transitions,
eg [[-1,0], [0,1]] means N, E.
"""
rt = self.__class__
xyPos = np.matmul(rcPos, rt.grc2xy) + rt.xyHalf
gxyTrans = xyPos + np.matmul(gTransRCAg, rt.grc2xy/2.4)
......@@ -299,7 +306,7 @@ class RenderTool(object):
visit = visit.prev
xyPrev = xy
def renderEnv(self, show=False):
def renderEnv(self, show=False, curves=True):
"""
Draw the environment using matplotlib.
Draw into the figure if provided.
......@@ -307,11 +314,17 @@ class RenderTool(object):
Call pyplot.show() if show==True.
(Use show=False from a Jupyter notebook with %matplotlib inline)
"""
# cell_size is a bit pointless with matplotlib - it does not relate to pixels,
# so for now I've changed it to 1 (from 10)
cell_size = 1
# if oFigure is None:
# oFigure = plt.figure()
gTheta = np.linspace(0, np.pi/2, 10)
gArc = array([np.cos(gTheta), np.sin(gTheta)]).T # from [1,0] to [0,1]
def drawTrans(oFrom, oTo, sColor="gray"):
plt.plot(
[oFrom[0], oTo[0]], # x
......@@ -319,6 +332,26 @@ class RenderTool(object):
color=sColor
)
def drawTrans2(xyLine, xyCentre, rotation, sColor="gray"):
"""
gLine is a numpy 2d array of points,
in the plotting space / coords.
eg:
[[0,.5],[1,0.2]] means a line
from x=0, y=0.5
to x=1, y=0.2
"""
xyMid = np.mean(xyLine, axis=0)
dxy = xyMid - xyCentre
xyCorner = xyMid + dxy
dxy2 = xyCentre - xyCorner
bStraight = rotation in [0, 2]
if bStraight:
plt.plot(*xyLine.T, color=sColor)
else:
plt.plot(*(gArc * dxy2 + xyCorner).T, color=sColor)
RETrans = RailEnvTransitions()
env = self.env
......@@ -338,10 +371,10 @@ class RenderTool(object):
for c in range(env.width):
trans_ = env.rail[r][c]
x0 = cell_size * c
x1 = cell_size * (c+1)
y0 = cell_size * -r
y1 = cell_size * -(r+1)
x0 = cell_size * c # left
x1 = cell_size * (c+1) # right
y0 = cell_size * -r # top
y1 = cell_size * -(r+1) # bottom
coords = [
((x0+x1)/2.0, y0), # N middle top
......@@ -350,6 +383,8 @@ class RenderTool(object):
(x0, (y0+y1)/2.0) # W middle left
]
xyCentre = array([x0, y1]) + cell_size / 2
oCell = env.rail[r, c]
for orientation in range(4): # ori is where we're heading
......@@ -378,15 +413,23 @@ class RenderTool(object):
# to_ori = (orientation + 2) % 4
for to_ori in range(4):
to_xy = coords[to_ori]
rotation = (to_ori - from_ori) % 4
if False:
print("r,c,ori: ", r, c, orientation,
"cell:", "{0:b}".format(oCell),
"moves:", tMoves,
"from:", from_ori, from_xy,
"to: ", to_ori, to_xy)
if (tMoves[to_ori]):
drawTrans(from_xy, to_xy)
if curves:
drawTrans2(array([from_xy, to_xy]), xyCentre, rotation)
else:
drawTrans(from_xy, to_xy)
if False:
print(
"r,c,ori: ", r, c, orientation,
"cell:", "{0:b}".format(oCell),
"moves:", tMoves,
"from:", from_ori, from_xy,
"to: ", to_ori, to_xy,
"cen:", *xyCentre,
"rot:", rotation,
)
# Draw each agent + its orientation + its target
cmap = plt.get_cmap('hsv', lut=env.number_of_agents+1)
......
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment