Skip to content
Snippets Groups Projects
Commit aa6f5e62 authored by spiglerg's avatar spiglerg
Browse files

fixed pylint errors in rendertools.py

parent fdc4d2af
No related branches found
No related tags found
No related merge requests found
from recordtype import recordtype
import numpy as np
......@@ -9,21 +6,25 @@ import xarray as xr
import matplotlib.pyplot as plt
from flatland.core.transitions import RailEnvTransitions
class RenderTool(object):
class RenderTool(object):
Visit = recordtype("Visit", ["rc", "iDir", "iDepth", "prev"])
lColors = list("brgcmyk")
gTransRC = np.array([[-1,0],[0,1],[1,0],[0,-1]]) # \delta RC for NESW
# \delta RC for NESW
gTransRC = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
nPixCell = 1
nPixHalf = nPixCell / 2
xyHalf = array([nPixHalf, -nPixHalf])
grc2xy = array([[0,-nPixCell],[nPixCell,0]])
gGrid = array(np.meshgrid(np.arange(10), -np.arange(10))) * array([[[nPixCell]],[[nPixCell]]])
xyPixHalf = xr.DataArray([nPixHalf, -nPixHalf], dims="xy", coords={"xy": ["x", "y"]})
gCentres = xr.DataArray(gGrid, dims=["xy", "p1", "p2"], coords={"xy":["x", "y"]}) + xyPixHalf
grc2xy = array([[0, -nPixCell], [nPixCell, 0]])
gGrid = array(np.meshgrid(np.arange(10), -np.arange(10))) * \
array([[[nPixCell]], [[nPixCell]]])
xyPixHalf = xr.DataArray([nPixHalf, -nPixHalf],
dims="xy",
coords={"xy": ["x", "y"]})
gCentres = xr.DataArray(gGrid,
dims=["xy", "p1", "p2"],
coords={"xy": ["x", "y"]}) + xyPixHalf
RETrans = RailEnvTransitions()
def __init__(self, env):
......@@ -31,95 +32,93 @@ class RenderTool(object):
def plotTreeOnRail(self, rcPos, iDir, nDepth=10):
"""
Derives and plots a tree of transitions starting at position rcPos in direction iDir
Derives and plots a tree of transitions starting at position rcPos
in direction iDir.
Returns a list of Visits which are the nodes / vertices in the tree.
"""
#gGrid = np.meshgrid(np.arange(10), -np.arange(10))
rt=self.__class__
#plt.scatter(*rt.gCentres, s=5, color="r")
# gGrid = np.meshgrid(np.arange(10), -np.arange(10))
rt = self.__class__
# plt.scatter(*rt.gCentres, s=5, color="r")
if False:
for iAgent in range(self.env.number_of_agents):
sColor = rt.lColors[iAgent]
rcPos = self.env.agents_position[iAgent]
iDir = self.env.agents_direction[iAgent] # agent direction index
iDir = self.env.agents_direction[iAgent] # agent dir index
self.plotAgent(rcPos, iDir, sColor)
gTransRCAg = self.getTransRC(rcPos, iDir)
self.plotTrans(rcPos, gTransRCAg)
if False:
rcNext = rcPos + rcDir
oTrans = self.env.rail[rcNext[0]][rcNext[1]] # transition for next cell
tbTrans = RailEnvTransitions.get_transitions_from_orientation(oTrans, iDir)
# TODO: this was `rcDir' but it was undefined
rcNext = rcPos + iDir
# transition for next cell
oTrans = self.env.rail[rcNext[0]][rcNext[1]]
tbTrans = RailEnvTransitions. \
get_transitions_from_orientation(oTrans, iDir)
giTrans = np.where(tbTrans)[0] # RC list of transitions
#print("agent", iAgent, array(list("NESW"))[giTrans], gTransRC[giTrans])
gTransRCAg = gTransRC[giTrans]
gTransRCAg = self.__class__.gTransRC[giTrans]
#rcPos=(6,4)
#iDir=2
# rcPos=(6,4)
# iDir=2
gTransRCAg = self.getTransRC(rcPos, iDir)
self.plotTrans(rcPos, gTransRCAg )
self.plotTrans(rcPos, gTransRCAg)
lVisits = self.getTreeFromRail(rcPos, iDir, nDepth=nDepth)
return lVisits
def plotAgents(self):
rt=self.__class__
#plt.scatter(*rt.gCentres, s=5, color="r")
rt = self.__class__
# plt.scatter(*rt.gCentres, s=5, color="r")
for iAgent in range(self.env.number_of_agents):
sColor = rt.lColors[iAgent]
rcPos = self.env.agents_position[iAgent]
iDir = self.env.agents_direction[iAgent] # agent direction index
iDir = self.env.agents_direction[iAgent] # agent direction index
self.plotAgent(rcPos, iDir, sColor)
gTransRCAg = self.getTransRC(rcPos, iDir)
self.plotTrans(rcPos, gTransRCAg)
def getTransRC(self, rcPos, iDir, bgiTrans=False):
"""
get the available transitions for rcPos in direction iDir,
as row & col deltas.
if bgiTrans is True, return a grid of indices of available transitions.
eg for a cell rcPos = (4,5), in direction iDir = 0 (N),
where the available transitions are N and E, returns:
[[-1,0], [0,1]] ie N=up one row, and E=right one col.
and if bgiTrans is True, returns a tuple:
(
[[-1,0], [0,1]], # deltas as before
[0, 1] # available transition indices, ie N, E
)
"""
Get the available transitions for rcPos in direction iDir,
as row & col deltas.
If bgiTrans is True, return a grid of indices of available transitions.
eg for a cell rcPos = (4,5), in direction iDir = 0 (N),
where the available transitions are N and E, returns:
[[-1,0], [0,1]] ie N=up one row, and E=right one col.
and if bgiTrans is True, returns a tuple:
(
[[-1,0], [0,1]], # deltas as before
[0, 1] # available transition indices, ie N, E
)
"""
rt = self.__class__
gTransRC = np.array([[-1,0],[0,1],[1,0],[0,-1]]) # \delta RC for NESW
# TODO: suggest we provide an accessor in RailEnv
oTrans = self.env.rail[rcPos] # transition for current cell
oTrans = self.env.rail[rcPos] # transition for current cell
tbTrans = rt.RETrans.get_transitions_from_orientation(oTrans, iDir)
giTrans = np.where(tbTrans)[0] # RC list of transitions
# HACK: workaround dead-end transitions
if len(giTrans) == 0:
#print("Dead End", rcPos, iDir, tbTrans, giTrans)
# print("Dead End", rcPos, iDir, tbTrans, giTrans)
iDirReverse = (iDir + 2) % 4
tbTrans = tuple(int(iDir2 == iDirReverse) for iDir2 in range(4) )
tbTrans = tuple(int(iDir2 == iDirReverse) for iDir2 in range(4))
giTrans = np.where(tbTrans)[0] # RC list of transitions
#print("Dead End2", rcPos, iDirReverse, tbTrans, giTrans)
#print("agent", array(list("NESW"))[giTrans], gTransRC[giTrans])
gTransRCAg = gTransRC[giTrans]
# print("Dead End2", rcPos, iDirReverse, tbTrans, giTrans)
# print("agent", array(list("NESW"))[giTrans], self.gTransRC[giTrans])
gTransRCAg = self.__class__.gTransRC[giTrans]
if bgiTrans:
return gTransRCAg, giTrans
......@@ -133,33 +132,27 @@ class RenderTool(object):
"""
rt = self.__class__
xyPos = np.matmul(rcPos, rt.grc2xy) + rt.xyHalf
plt.scatter(*xyPos, color=sColor) # agent location
rcDir = rt.gTransRC[iDir] # agent direction in RC
xyDir = np.matmul(rcDir, rt.grc2xy) # agent direction in xy
xyDirLine = array([xyPos, xyPos+xyDir/2]).T # xy line showing agent orientation
plt.scatter(*xyPos, color=sColor) # agent location
rcDir = rt.gTransRC[iDir] # agent direction in RC
xyDir = np.matmul(rcDir, rt.grc2xy) # agent direction in xy
xyDirLine = array([xyPos, xyPos+xyDir/2]).T # line for agent orient.
plt.plot(*xyDirLine, color=sColor, lw=5, ms=0, alpha=0.6)
# just mark the next cell we're heading into
rcNext = rcPos + rcDir
xyNext = np.matmul(rcNext, rt.grc2xy) + rt.xyHalf
plt.scatter(*xyNext, color = sColor)
plt.scatter(*xyNext, color=sColor)
def plotTrans(self, rcPos, gTransRCAg, color="r", depth=None):
"""
"""
rt = self.__class__
xyPos = np.matmul(rcPos, rt.grc2xy) + rt.xyHalf
gxyTrans = xyPos + np.matmul(gTransRCAg, rt.grc2xy/2.4)
#print(gxyTrans)
plt.scatter(*gxyTrans.T, color = color, marker="o", s=50, alpha=0.2)
# print(gxyTrans)
plt.scatter(*gxyTrans.T, color=color, marker="o", s=50, alpha=0.2)
if depth is not None:
for x,y in gxyTrans:
plt.text(x,y,depth)
for x, y in gxyTrans:
plt.text(x, y, depth)
def getTreeFromRail(self, rcPos, iDir, nDepth=10, bBFS=True):
"""
......@@ -167,48 +160,53 @@ class RenderTool(object):
"""
rt = self.__class__
print(rcPos, iDir)
iPos = 0 if bBFS else -1 # BF / DF Search
iPos = 0 if bBFS else -1 # BF / DF Search
iDepth = 0
visited = set()
lVisits = []
#stack = [ (rcPos,iDir,nDepth) ]
stack = [ rt.Visit(rcPos,iDir,iDepth,None) ]
# stack = [ (rcPos,iDir,nDepth) ]
stack = [rt.Visit(rcPos, iDir, iDepth, None)]
while stack:
visit = stack.pop(iPos)
rcd = (visit.rc, visit.iDir)
if visit.iDepth > nDepth:
continue
lVisits.append(visit)
if rcd not in visited:
visited.add(rcd)
#moves = self._get_valid_transitions( node[0], node[1] )
gTransRCAg, giTrans = self.getTransRC(visit.rc, visit.iDir, bgiTrans=True)
#nodePos = node[0]
# moves = self._get_valid_transitions( node[0], node[1] )
gTransRCAg, giTrans = self.getTransRC(visit.rc,
visit.iDir,
bgiTrans=True)
# nodePos = node[0]
# enqueue the next nodes (ie transitions from this node)
for gTransRC2, iTrans in zip(gTransRCAg, giTrans):
#print("Trans:", gTransRC2)
visitNext = rt.Visit(tuple(visit.rc + gTransRC2), iTrans, visit.iDepth+1, visit)
#print("node2: ", node2)
stack.append( visitNext )
# print("Trans:", gTransRC2)
visitNext = rt.Visit(tuple(visit.rc + gTransRC2),
iTrans,
visit.iDepth+1,
visit)
# print("node2: ", node2)
stack.append(visitNext)
# plot the available transitions from this node
self.plotTrans(visit.rc, gTransRCAg, depth=str(visit.iDepth))
return lVisits
def plotTree(self, lVisits, xyTarg):
'''
Plot a vertical tree of transitions.
Returns the "visit" to the destination (ie where euclidean distance is near zero) or None if absent.
Returns the "visit" to the destination
(ie where euclidean distance is near zero) or None if absent.
'''
dPos = {}
iPos = 0
visitDest = None
......@@ -220,30 +218,33 @@ class RenderTool(object):
xLoc = dPos[visit.rc] = iPos
iPos += 1
rDist = np.linalg.norm(array(visit.rc) - array(xyTarg))
sDist = "%.1f" % rDist
rDist = np.linalg.norm(array(visit.rc) - array(xyTarg))
# sDist = "%.1f" % rDist
xLoc = rDist + visit.iDir / 4
# point labelled with distance
plt.scatter(xLoc, visit.iDepth, color="k", s=2)
#plt.text(xLoc, visit.iDepth, sDist, color="k", rotation=45)
# plt.text(xLoc, visit.iDepth, sDist, color="k", rotation=45)
plt.text(xLoc, visit.iDepth, visit.rc, color="k", rotation=45)
#if len(dPos)>1:
# if len(dPos)>1:
if visit.prev:
#print(dPos)
#print(tNodeDepth)
# print(dPos)
# print(tNodeDepth)
xLocPrev = dPos[visit.prev.rc]
rDistPrev = np.linalg.norm(array(visit.prev.rc) - array(xyTarg))
sDist = "%.1f" % rDistPrev
rDistPrev = np.linalg.norm(array(visit.prev.rc) -
array(xyTarg))
# sDist = "%.1f" % rDistPrev
xLocPrev = rDistPrev + visit.prev.iDir / 4
# line from prev node
plt.plot([xLocPrev, xLoc], [ visit.iDepth-1, visit.iDepth], color="k", alpha=0.5, lw=1)
plt.plot([xLocPrev, xLoc],
[visit.iDepth-1, visit.iDepth],
color="k", alpha=0.5, lw=1)
if rDist < 0.1:
visitDest = visit
......@@ -252,16 +253,17 @@ class RenderTool(object):
visit = visitDest
xLocPrev = None
while visit is not None:
rDist = np.linalg.norm(array(visit.rc) - array(xyTarg))
rDist = np.linalg.norm(array(visit.rc) - array(xyTarg))
xLoc = rDist + visit.iDir / 4
if xLocPrev is not None:
plt.plot([xLoc, xLocPrev], [ visit.iDepth, visit.iDepth+1], color="r", alpha=0.5, lw=2)
plt.plot([xLoc, xLocPrev], [visit.iDepth, visit.iDepth+1],
color="r", alpha=0.5, lw=2)
xLocPrev = xLoc
visit = visit.prev
# prev = prev.prev
# prev = prev.prev
#plt.xticks(range(7)); plt.yticks(range(11))
ax=plt.gca()
# plt.xticks(range(7)); plt.yticks(range(11))
ax = plt.gca()
plt.xticks(range(int(ax.get_xlim()[1])+1))
plt.yticks(range(int(ax.get_ylim()[1])+1))
plt.grid()
......@@ -276,33 +278,34 @@ class RenderTool(object):
visit = visitDest
xyPrev = None
while visit.prev is not None:
#rDist = np.linalg.norm(array(visit.xy) - array(xyTarg))
#xLoc = rDist + visit.iDir / 4
#print (visit.xy)
# rDist = np.linalg.norm(array(visit.xy) - array(xyTarg))
# xLoc = rDist + visit.iDir / 4
# print (visit.xy)
xy = np.matmul(visit.rc, rt.grc2xy) + rt.xyHalf
if xyPrev is not None:
plt.plot([xy[0], xyPrev[0]], [ xy[1], xyPrev[1]], color="r", alpha=0.5, lw=3)
plt.plot([xy[0], xyPrev[0]],
[xy[1], xyPrev[1]],
color="r", alpha=0.5, lw=3)
visit = visit.prev
xyPrev = xy
def renderEnv(self):
"""
Draw the environment using matplotlib. Draw into the figure if provided.
Call pyplot.show() if show==True. (Use show=False from a Jupyter notebook with %matplotlib inline)
Draw the environment using matplotlib.
Draw into the figure if provided.
Call pyplot.show() if show==True.
(Use show=False from a Jupyter notebook with %matplotlib inline)
"""
cell_size = 1
#if oFigure is None:
# if oFigure is None:
# oFigure = plt.figure()
def drawTrans(oFrom, oTo, sColor="gray"):
plt.plot(
[ oFrom[0], oTo[0] ], # x
[ oFrom[1], oTo[1] ], # y
[oFrom[0], oTo[0]], # x
[oFrom[1], oTo[1]], # y
color=sColor
)
......@@ -310,85 +313,104 @@ class RenderTool(object):
env = self.env
# Draw cells grid
grid_color = [0.95,0.95,0.95]
grid_color = [0.95, 0.95, 0.95]
for r in range(env.height+1):
plt.plot([0, (env.width+1)*cell_size], [-r*cell_size, -r*cell_size], color=grid_color)
plt.plot([0, (env.width+1)*cell_size],
[-r*cell_size, -r*cell_size],
color=grid_color)
for c in range(env.width+1):
plt.plot([c*cell_size, c*cell_size], [0, -(env.height+1)*cell_size], color=grid_color)
plt.plot([c*cell_size, c*cell_size],
[0, -(env.height+1)*cell_size],
color=grid_color)
# Draw each cell independently
for r in range(env.height):
for c in range(env.width):
trans_ = env.rail[r][c]
x0 = c * cell_size
x1 = (c+1) * cell_size
y0 = -r * cell_size
y1 = -(r+1) * cell_size
x0 = cell_size * c
x1 = cell_size * (c+1)
y0 = cell_size * -r
y1 = cell_size * -(r+1)
coords = [
((x0+x1)/2.0, y0), # N middle top
(x1, (y0+y1)/2.0), # E middle right
((x0+x1)/2.0,y1), # S middle bottom
(x0, (y0+y1)/2.0) # W middle left
coords = [
((x0+x1)/2.0, y0), # N middle top
(x1, (y0+y1)/2.0), # E middle right
((x0+x1)/2.0, y1), # S middle bottom
(x0, (y0+y1)/2.0) # W middle left
]
oCell = env.rail[r, c]
oCell = env.rail[r,c]
for orientation in range(4): # orientation is where we're heading
from_ori = (orientation + 2) % 4 # 0123 = NESW -> 2301 = SWNE
for orientation in range(4): # ori is where we're heading
from_ori = (orientation + 2) % 4 # 0123=NESW -> 2301=SWNE
from_xy = coords[from_ori]
## Special Case 7, with a single bit; terminate at center
## TODO: for the future, this should rather check whether the movement is allowed in a single direction or both, as per the transitions bits. Here we only check for Case 7 as a cheap hack that will hold for the competition, but it won't hold for environments specified by arbitrary transitions.
# Special Case 7, with a single bit; terminate at center
nbits = 0
tmp = trans_
while tmp>0:
nbits += (tmp&1)
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
# as above - move the from coord to the centre - it's a dead env.
if nbits==1:
# as above - move the from coord to the centre
# it's a dead env.
if nbits == 1:
from_xy = ((x0+x1)/2.0, (y0+y1)/2.0)
#renderer.push()
#renderer.translate(c * CELL_PIXELS, r * CELL_PIXELS)
# renderer.push()
# renderer.translate(c * CELL_PIXELS, r * CELL_PIXELS)
if True:
tMoves = RETrans.get_transitions_from_orientation(oCell, orientation)
#to_ori = (orientation + 2) % 4
tMoves = RETrans.get_transitions_from_orientation(
oCell, orientation)
# to_ori = (orientation + 2) % 4
for to_ori in range(4):
to_xy = coords[to_ori]
if False:
print("r,c,ori: ", r, c, orientation, "cell:", "{0:b}".format(oCell),
"moves:", tMoves,
"from:", from_ori, from_xy,
"to: ", to_ori, to_xy)
print("r,c,ori: ", r, c, orientation,
"cell:", "{0:b}".format(oCell),
"moves:", tMoves,
"from:", from_ori, from_xy,
"to: ", to_ori, to_xy)
if (tMoves[to_ori]):
drawTrans(from_xy, to_xy)
# Draw each agent + its orientation + its target
cmap = plt.get_cmap('hsv', lut=env.number_of_agents+1)
for i in range(env.number_of_agents):
self._draw_square( (env.agents_position[i][1]*cell_size+cell_size/2, -env.agents_position[i][0]*cell_size-cell_size/2), cell_size/8, cmap(i) )
self._draw_square((env.agents_position[i][1] *
cell_size+cell_size/2,
-env.agents_position[i][0] *
cell_size-cell_size/2),
cell_size/8, cmap(i))
for i in range(env.number_of_agents):
self._draw_square( (env.agents_target[i][1]*cell_size+cell_size/2, -env.agents_target[i][0]*cell_size-cell_size/2), cell_size/3, [c for c in cmap(i)] )
# orientation is a line connecting the center of the cell to the side of the square of the agent
new_position = env._new_position(env.agents_position[i], env.agents_direction[i])
new_position = ( (new_position[0]+env.agents_position[i][0])/2*cell_size, (new_position[1]+env.agents_position[i][1])/2*cell_size )
plt.plot( [ env.agents_position[i][1]*cell_size+cell_size/2, new_position[1]+cell_size/2], [-env.agents_position[i][0]*cell_size-cell_size/2, -new_position[0]-cell_size/2], color=cmap(i), linewidth=2.0 )
plt.xlim([0,env.width*cell_size])
plt.ylim([-env.height*cell_size,0])
self._draw_square((env.agents_target[i][1] *
cell_size+cell_size/2,
-env.agents_target[i][0] *
cell_size-cell_size/2),
cell_size/3, [c for c in cmap(i)])
# orientation is a line connecting the center of the cell to the
# side of the square of the agent
new_position = env._new_position(env.agents_position[i],
env.agents_direction[i])
new_position = ((new_position[0] +
env.agents_position[i][0])/2*cell_size,
(new_position[1] +
env.agents_position[i][1])/2*cell_size)
plt.plot([env.agents_position[i][1] * cell_size+cell_size/2,
new_position[1]+cell_size/2],
[-env.agents_position[i][0] * cell_size-cell_size/2,
-new_position[0]-cell_size/2], color=cmap(i),
linewidth=2.0)
plt.xlim([0, env.width * cell_size])
plt.ylim([-env.height * cell_size, 0])
plt.xticks(np.linspace(0, env.width * cell_size, env.width+1))
plt.yticks(np.linspace(-env.height * cell_size, 0, env.height+1))
......@@ -398,5 +420,4 @@ class RenderTool(object):
x1 = center[0]+size/2
y0 = center[1]-size/2
y1 = center[1]+size/2
plt.plot( [x0,x1,x1,x0,x0], [y0,y0,y1,y1,y0], color=color )
plt.plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], color=color)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment