Skip to content
Snippets Groups Projects
Commit 14c757ee authored by hagrid67's avatar hagrid67
Browse files

implemented curved transitions

parent d0665e51
No related branches found
No related tags found
No related merge requests found
...@@ -145,6 +145,13 @@ class RenderTool(object): ...@@ -145,6 +145,13 @@ class RenderTool(object):
plt.scatter(*xyNext, color=sColor) plt.scatter(*xyNext, color=sColor)
def plotTrans(self, rcPos, gTransRCAg, color="r", depth=None): def plotTrans(self, rcPos, gTransRCAg, color="r", depth=None):
"""
plot the transitions in gTransRCAg at position rcPos.
gTransRCAg is a 2d numpy array containing a list of RC transitions,
eg [[-1,0], [0,1]] means N, E.
"""
rt = self.__class__ rt = self.__class__
xyPos = np.matmul(rcPos, rt.grc2xy) + rt.xyHalf xyPos = np.matmul(rcPos, rt.grc2xy) + rt.xyHalf
gxyTrans = xyPos + np.matmul(gTransRCAg, rt.grc2xy/2.4) gxyTrans = xyPos + np.matmul(gTransRCAg, rt.grc2xy/2.4)
...@@ -299,7 +306,7 @@ class RenderTool(object): ...@@ -299,7 +306,7 @@ class RenderTool(object):
visit = visit.prev visit = visit.prev
xyPrev = xy xyPrev = xy
def renderEnv(self, show=False): def renderEnv(self, show=False, curves=True):
""" """
Draw the environment using matplotlib. Draw the environment using matplotlib.
Draw into the figure if provided. Draw into the figure if provided.
...@@ -307,11 +314,17 @@ class RenderTool(object): ...@@ -307,11 +314,17 @@ class RenderTool(object):
Call pyplot.show() if show==True. Call pyplot.show() if show==True.
(Use show=False from a Jupyter notebook with %matplotlib inline) (Use show=False from a Jupyter notebook with %matplotlib inline)
""" """
# cell_size is a bit pointless with matplotlib - it does not relate to pixels,
# so for now I've changed it to 1 (from 10)
cell_size = 1 cell_size = 1
# if oFigure is None: # if oFigure is None:
# oFigure = plt.figure() # oFigure = plt.figure()
gTheta = np.linspace(0, np.pi/2, 10)
gArc = array([np.cos(gTheta), np.sin(gTheta)]).T # from [1,0] to [0,1]
def drawTrans(oFrom, oTo, sColor="gray"): def drawTrans(oFrom, oTo, sColor="gray"):
plt.plot( plt.plot(
[oFrom[0], oTo[0]], # x [oFrom[0], oTo[0]], # x
...@@ -319,6 +332,26 @@ class RenderTool(object): ...@@ -319,6 +332,26 @@ class RenderTool(object):
color=sColor color=sColor
) )
def drawTrans2(xyLine, xyCentre, rotation, sColor="gray"):
"""
gLine is a numpy 2d array of points,
in the plotting space / coords.
eg:
[[0,.5],[1,0.2]] means a line
from x=0, y=0.5
to x=1, y=0.2
"""
xyMid = np.mean(xyLine, axis=0)
dxy = xyMid - xyCentre
xyCorner = xyMid + dxy
dxy2 = xyCentre - xyCorner
bStraight = rotation in [0, 2]
if bStraight:
plt.plot(*xyLine.T, color=sColor)
else:
plt.plot(*(gArc * dxy2 + xyCorner).T, color=sColor)
RETrans = RailEnvTransitions() RETrans = RailEnvTransitions()
env = self.env env = self.env
...@@ -338,10 +371,10 @@ class RenderTool(object): ...@@ -338,10 +371,10 @@ class RenderTool(object):
for c in range(env.width): for c in range(env.width):
trans_ = env.rail[r][c] trans_ = env.rail[r][c]
x0 = cell_size * c x0 = cell_size * c # left
x1 = cell_size * (c+1) x1 = cell_size * (c+1) # right
y0 = cell_size * -r y0 = cell_size * -r # top
y1 = cell_size * -(r+1) y1 = cell_size * -(r+1) # bottom
coords = [ coords = [
((x0+x1)/2.0, y0), # N middle top ((x0+x1)/2.0, y0), # N middle top
...@@ -350,6 +383,8 @@ class RenderTool(object): ...@@ -350,6 +383,8 @@ class RenderTool(object):
(x0, (y0+y1)/2.0) # W middle left (x0, (y0+y1)/2.0) # W middle left
] ]
xyCentre = array([x0, y1]) + cell_size / 2
oCell = env.rail[r, c] oCell = env.rail[r, c]
for orientation in range(4): # ori is where we're heading for orientation in range(4): # ori is where we're heading
...@@ -378,15 +413,23 @@ class RenderTool(object): ...@@ -378,15 +413,23 @@ class RenderTool(object):
# to_ori = (orientation + 2) % 4 # to_ori = (orientation + 2) % 4
for to_ori in range(4): for to_ori in range(4):
to_xy = coords[to_ori] to_xy = coords[to_ori]
rotation = (to_ori - from_ori) % 4
if False:
print("r,c,ori: ", r, c, orientation,
"cell:", "{0:b}".format(oCell),
"moves:", tMoves,
"from:", from_ori, from_xy,
"to: ", to_ori, to_xy)
if (tMoves[to_ori]): if (tMoves[to_ori]):
drawTrans(from_xy, to_xy) if curves:
drawTrans2(array([from_xy, to_xy]), xyCentre, rotation)
else:
drawTrans(from_xy, to_xy)
if False:
print(
"r,c,ori: ", r, c, orientation,
"cell:", "{0:b}".format(oCell),
"moves:", tMoves,
"from:", from_ori, from_xy,
"to: ", to_ori, to_xy,
"cen:", *xyCentre,
"rot:", rotation,
)
# Draw each agent + its orientation + its target # Draw each agent + its orientation + its target
cmap = plt.get_cmap('hsv', lut=env.number_of_agents+1) cmap = plt.get_cmap('hsv', lut=env.number_of_agents+1)
......
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment