Forked from
Flatland / Flatland
1974 commits behind the upstream repository.
-
Egli Adrian (IT-SCI-API-PFI) authoredEgli Adrian (IT-SCI-API-PFI) authored
rendertools.py 25.20 KiB
import time
from collections import deque
import numpy as np
from numpy import array
from recordtype import recordtype
from flatland.utils.graphics_pil import PILGL, PILSVG
# TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv!
class RenderTool(object):
""" Class to render the RailEnv and agents.
Uses two layers, layer 0 for rails (mostly static), layer 1 for agents etc (dynamic)
The lower / rail layer 0 is only redrawn after set_new_rail() has been called.
Created with a "GraphicsLayer" or gl - now either PIL or PILSVG
"""
Visit = recordtype("Visit", ["rc", "iDir", "iDepth", "prev"])
lColors = list("brgcmyk")
# \delta RC for NESW
gTransRC = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
nPixCell = 1 # misnomer...
nPixHalf = nPixCell / 2
xyHalf = array([nPixHalf, -nPixHalf])
grc2xy = array([[0, -nPixCell], [nPixCell, 0]])
gGrid = array(np.meshgrid(np.arange(10), -np.arange(10))) * array([[[nPixCell]], [[nPixCell]]])
gTheta = np.linspace(0, np.pi / 2, 5)
gArc = array([np.cos(gTheta), np.sin(gTheta)]).T # from [1,0] to [0,1]
def __init__(self, env, gl="PILSVG", jupyter=False):
self.env = env
self.iFrame = 0
self.time1 = time.time()
self.lTimes = deque()
if gl == "PIL":
self.gl = PILGL(env.width, env.height, jupyter)
elif gl == "PILSVG":
self.gl = PILSVG(env.width, env.height, jupyter)
else:
print("[", gl, "] not found, switch to PILSVG")
self.gl = PILSVG(env.width, env.height, jupyter)
self.new_rail = True
self.update_background()
def update_background(self):
# create background map
dTargets = {}
for iAgent, agent in enumerate(self.env.agents_static):
if agent is None:
continue
dTargets[tuple(agent.target)] = iAgent
self.gl.build_background_map(dTargets)
def resize(self):
self.gl.resize(self.env)
def set_new_rail(self):
""" Tell the renderer that the rail has changed.
eg when the rail has been regenerated, or updated in the editor.
"""
self.new_rail = True
def plotTreeOnRail(self, lVisits, color="r"):
"""
DEFUNCT
Derives and plots a tree of transitions starting at position rcPos
in direction iDir.
Returns a list of Visits which are the nodes / vertices in the tree.
"""
rt = self.__class__
for visit in lVisits:
# transition for next cell
tbTrans = self.env.rail.get_transitions((*visit.rc, visit.iDir))
giTrans = np.where(tbTrans)[0] # RC list of transitions
gTransRCAg = rt.gTransRC[giTrans]
self.plotTrans(visit.rc, gTransRCAg, depth=str(visit.iDepth), color=color)
def plotAgents(self, targets=True, iSelectedAgent=None):
cmap = self.gl.get_cmap('hsv',
lut=max(len(self.env.agents), len(self.env.agents_static) + 1))
for iAgent, agent in enumerate(self.env.agents_static):
if agent is None:
continue
oColor = cmap(iAgent)
self.plotAgent(agent.position, agent.direction, oColor, target=agent.target if targets else None,
static=True, selected=iAgent == iSelectedAgent)
for iAgent, agent in enumerate(self.env.agents):
if agent is None:
continue
oColor = cmap(iAgent)
self.plotAgent(agent.position, agent.direction, oColor, target=agent.target if targets else None)
def getTransRC(self, rcPos, iDir, bgiTrans=False):
"""
Get the available transitions for rcPos in direction iDir,
as row & col deltas.
If bgiTrans is True, return a grid of indices of available transitions.
eg for a cell rcPos = (4,5), in direction iDir = 0 (N),
where the available transitions are N and E, returns:
[[-1,0], [0,1]] ie N=up one row, and E=right one col.
and if bgiTrans is True, returns a tuple:
(
[[-1,0], [0,1]], # deltas as before
[0, 1] # available transition indices, ie N, E
)
"""
tbTrans = self.env.rail.get_transitions((*rcPos, iDir))
giTrans = np.where(tbTrans)[0] # RC list of transitions
# HACK: workaround dead-end transitions
if len(giTrans) == 0:
iDirReverse = (iDir + 2) % 4
tbTrans = tuple(int(iDir2 == iDirReverse) for iDir2 in range(4))
giTrans = np.where(tbTrans)[0] # RC list of transitions
gTransRCAg = self.__class__.gTransRC[giTrans]
if bgiTrans:
return gTransRCAg, giTrans
else:
return gTransRCAg
def plotAgent(self, rcPos, iDir, color="r", target=None, static=False, selected=False):
"""
Plot a simple agent.
Assumes a working graphics layer context (cf a MPL figure).
"""
rt = self.__class__
rcDir = rt.gTransRC[iDir] # agent direction in RC
xyDir = np.matmul(rcDir, rt.grc2xy) # agent direction in xy
xyPos = np.matmul(rcPos - rcDir / 2, rt.grc2xy) + rt.xyHalf
if static:
color = self.gl.adaptColor(color, lighten=True)
color = color
self.gl.scatter(*xyPos, color=color, layer=1, marker="o", s=100) # agent location
xyDirLine = array([xyPos, xyPos + xyDir / 2]).T # line for agent orient.
self.gl.plot(*xyDirLine, color=color, layer=1, lw=5, ms=0, alpha=0.6)
if selected:
self._draw_square(xyPos, 1, color)
if target is not None:
rcTarget = array(target)
xyTarget = np.matmul(rcTarget, rt.grc2xy) + rt.xyHalf
self._draw_square(xyTarget, 1 / 3, color, layer=1)
def plotTrans(self, rcPos, gTransRCAg, color="r", depth=None):
"""
plot the transitions in gTransRCAg at position rcPos.
gTransRCAg is a 2d numpy array containing a list of RC transitions,
eg [[-1,0], [0,1]] means N, E.
"""
rt = self.__class__
xyPos = np.matmul(rcPos, rt.grc2xy) + rt.xyHalf
gxyTrans = xyPos + np.matmul(gTransRCAg, rt.grc2xy / 2.4)
self.gl.scatter(*gxyTrans.T, color=color, marker="o", s=50, alpha=0.2)
if depth is not None:
for x, y in gxyTrans:
self.gl.text(x, y, depth)
def getTreeFromRail(self, rcPos, iDir, nDepth=10, bBFS=True, bPlot=False):
"""
DEFUNCT
Generate a tree from the env starting at rcPos, iDir.
"""
rt = self.__class__
print(rcPos, iDir)
iPos = 0 if bBFS else -1 # BF / DF Search
iDepth = 0
visited = set()
lVisits = []
stack = [rt.Visit(rcPos, iDir, iDepth, None)]
while stack:
visit = stack.pop(iPos)
rcd = (visit.rc, visit.iDir)
if visit.iDepth > nDepth:
continue
lVisits.append(visit)
if rcd not in visited:
visited.add(rcd)
gTransRCAg, giTrans = self.getTransRC(visit.rc,
visit.iDir,
bgiTrans=True)
# enqueue the next nodes (ie transitions from this node)
for gTransRC2, iTrans in zip(gTransRCAg, giTrans):
visitNext = rt.Visit(tuple(visit.rc + gTransRC2),
iTrans,
visit.iDepth + 1,
visit)
stack.append(visitNext)
# plot the available transitions from this node
if bPlot:
self.plotTrans(
visit.rc, gTransRCAg,
depth=str(visit.iDepth))
return lVisits
def plotTree(self, lVisits, xyTarg):
'''
Plot a vertical tree of transitions.
Returns the "visit" to the destination
(ie where euclidean distance is near zero) or None if absent.
'''
dPos = {}
iPos = 0
visitDest = None
for iVisit, visit in enumerate(lVisits):
if visit.rc in dPos:
xLoc = dPos[visit.rc]
else:
xLoc = dPos[visit.rc] = iPos
iPos += 1
rDist = np.linalg.norm(array(visit.rc) - array(xyTarg))
xLoc = rDist + visit.iDir / 4
# point labelled with distance
self.gl.scatter(xLoc, visit.iDepth, color="k", s=2)
self.gl.text(xLoc, visit.iDepth, visit.rc, color="k", rotation=45)
# if len(dPos)>1:
if visit.prev:
xLocPrev = dPos[visit.prev.rc]
rDistPrev = np.linalg.norm(array(visit.prev.rc) -
array(xyTarg))
xLocPrev = rDistPrev + visit.prev.iDir / 4
# line from prev node
self.gl.plot([xLocPrev, xLoc],
[visit.iDepth - 1, visit.iDepth],
color="k", alpha=0.5, lw=1)
if rDist < 0.1:
visitDest = visit
# Walk backwards from destination to origin, plotting in red
if visitDest is not None:
visit = visitDest
xLocPrev = None
while visit is not None:
rDist = np.linalg.norm(array(visit.rc) - array(xyTarg))
xLoc = rDist + visit.iDir / 4
if xLocPrev is not None:
self.gl.plot([xLoc, xLocPrev], [visit.iDepth, visit.iDepth + 1],
color="r", alpha=0.5, lw=2)
xLocPrev = xLoc
visit = visit.prev
self.gl.prettify()
return visitDest
def plotPath(self, visitDest):
"""
Given a "final" visit visitDest, plotPath recurses back through the path
using the visit.prev field (previous) to get back to the start of the path.
The path of transitions is plotted with arrows at 3/4 along the line.
The transition is plotted slightly to one side of the rail, so that
transitions in opposite directions are separate.
Currently, no attempt is made to make the transition arrows coincide
at corners, and they are straight only.
"""
rt = self.__class__
# Walk backwards from destination to origin
if visitDest is not None:
visit = visitDest
xyPrev = None
while visit is not None:
xy = np.matmul(visit.rc, rt.grc2xy) + rt.xyHalf
if xyPrev is not None:
dx, dy = (xyPrev - xy) / 20
xyLine = array([xy, xyPrev]) + array([dy, dx])
self.gl.plot(*xyLine.T, color="r", alpha=0.5, lw=1)
xyMid = np.sum(xyLine * [[1 / 4], [3 / 4]], axis=0)
xyArrow = array([
xyMid + [-dx - dy, +dx - dy],
xyMid,
xyMid + [-dx + dy, -dx - dy]])
self.gl.plot(*xyArrow.T, color="r")
visit = visit.prev
xyPrev = xy
def drawTrans(self, oFrom, oTo, sColor="gray"):
self.gl.plot(
[oFrom[0], oTo[0]], # x
[oFrom[1], oTo[1]], # y
color=sColor
)
def drawTrans2(self,
xyLine, xyCentre,
rotation, bDeadEnd=False,
sColor="gray",
bArrow=True,
spacing=0.1):
"""
gLine is a numpy 2d array of points,
in the plotting space / coords.
eg:
[[0,.5],[1,0.2]] means a line
from x=0, y=0.5
to x=1, y=0.2
"""
rt = self.__class__
bStraight = rotation in [0, 2]
dx, dy = np.squeeze(np.diff(xyLine, axis=0)) * spacing / 2
if bStraight:
if sColor == "auto":
if dx > 0 or dy > 0:
sColor = "C1" # N or E
else:
sColor = "C2" # S or W
if bDeadEnd:
xyLine2 = array([
xyLine[1] + [dy, dx],
xyCentre,
xyLine[1] - [dy, dx],
])
self.gl.plot(*xyLine2.T, color=sColor)
else:
xyLine2 = xyLine + [-dy, dx]
self.gl.plot(*xyLine2.T, color=sColor)
if bArrow:
xyMid = np.sum(xyLine2 * [[1 / 4], [3 / 4]], axis=0)
xyArrow = array([
xyMid + [-dx - dy, +dx - dy],
xyMid,
xyMid + [-dx + dy, -dx - dy]])
self.gl.plot(*xyArrow.T, color=sColor)
else:
xyMid = np.mean(xyLine, axis=0)
dxy = xyMid - xyCentre
xyCorner = xyMid + dxy
if rotation == 1:
rArcFactor = 1 - spacing
sColorAuto = "C1"
else:
rArcFactor = 1 + spacing
sColorAuto = "C2"
dxy2 = (xyCentre - xyCorner) * rArcFactor # for scaling the arc
if sColor == "auto":
sColor = sColorAuto
self.gl.plot(*(rt.gArc * dxy2 + xyCorner).T, color=sColor)
if bArrow:
dx, dy = np.squeeze(np.diff(xyLine, axis=0)) / 20
iArc = int(len(rt.gArc) / 2)
xyMid = xyCorner + rt.gArc[iArc] * dxy2
xyArrow = array([
xyMid + [-dx - dy, +dx - dy],
xyMid,
xyMid + [-dx + dy, -dx - dy]])
self.gl.plot(*xyArrow.T, color=sColor)
def renderObs(self, agent_handles, observation_dict):
"""
Render the extent of the observation of each agent. All cells that appear in the agent
observation will be highlighted.
:param agent_handles: List of agent indices to adapt color and get correct observation
:param observation_dict: dictionary containing sets of cells of the agent observation
"""
rt = self.__class__
for agent in agent_handles:
color = self.gl.getAgentColor(agent)
for visited_cell in observation_dict[agent]:
cell_coord = array(visited_cell[:2])
cell_coord_trans = np.matmul(cell_coord, rt.grc2xy) + rt.xyHalf
self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100)
def renderRail(self, spacing=False, sRailColor="gray", curves=True, arrows=False):
cell_size = 1 # TODO: remove cell_size
env = self.env
# Draw cells grid
grid_color = [0.95, 0.95, 0.95]
for r in range(env.height + 1):
self.gl.plot([0, (env.width + 1) * cell_size],
[-r * cell_size, -r * cell_size],
color=grid_color, linewidth=2)
for c in range(env.width + 1):
self.gl.plot([c * cell_size, c * cell_size],
[0, -(env.height + 1) * cell_size],
color=grid_color, linewidth=2)
# Draw each cell independently
for r in range(env.height):
for c in range(env.width):
# bounding box of the grid cell
x0 = cell_size * c # left
x1 = cell_size * (c + 1) # right
y0 = cell_size * -r # top
y1 = cell_size * -(r + 1) # bottom
# centres of cell edges
coords = [
((x0 + x1) / 2.0, y0), # N middle top
(x1, (y0 + y1) / 2.0), # E middle right
((x0 + x1) / 2.0, y1), # S middle bottom
(x0, (y0 + y1) / 2.0) # W middle left
]
# cell centre
xyCentre = array([x0, y1]) + cell_size / 2
# cell transition values
oCell = env.rail.get_transitions((r, c))
bCellValid = env.rail.cell_neighbours_valid((r, c), check_this_cell=True)
# Special Case 7, with a single bit; terminate at center
nbits = 0
tmp = oCell
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
# as above - move the from coord to the centre
# it's a dead env.
bDeadEnd = nbits == 1
if not bCellValid:
self.gl.scatter(*xyCentre, color="r", s=30)
for orientation in range(4): # ori is where we're heading
from_ori = (orientation + 2) % 4 # 0123=NESW -> 2301=SWNE
from_xy = coords[from_ori]
tMoves = env.rail.get_transitions((r, c, orientation))
for to_ori in range(4):
to_xy = coords[to_ori]
rotation = (to_ori - from_ori) % 4
if (tMoves[to_ori]): # if we have this transition
if bDeadEnd:
self.drawTrans2(
array([from_xy, to_xy]), xyCentre,
rotation, bDeadEnd=True, spacing=spacing,
sColor=sRailColor)
else:
if curves:
self.drawTrans2(
array([from_xy, to_xy]), xyCentre,
rotation, spacing=spacing, bArrow=arrows,
sColor=sRailColor)
else:
self.drawTrans(self, from_xy, to_xy, sRailColor)
if False:
print(
"r,c,ori: ", r, c, orientation,
"cell:", "{0:b}".format(oCell),
"moves:", tMoves,
"from:", from_ori, from_xy,
"to: ", to_ori, to_xy,
"cen:", *xyCentre,
"rot:", rotation,
)
def renderEnv(self,
show=False, # whether to call matplotlib show() or equivalent after completion
# use false when calling from Jupyter. (and matplotlib no longer supported!)
curves=True, # draw turns as curves instead of straight diagonal lines
spacing=False, # defunct - size of spacing between rails
arrows=False, # defunct - draw arrows on rail lines
agents=True, # whether to include agents
show_observations=True, # whether to include observations
sRailColor="gray", # color to use in drawing rails (not used with SVG)
frames=False, # frame counter to show (intended since invocation)
iEpisode=None, # int episode number to show
iStep=None, # int step number to show in image
iSelectedAgent=None, # indicate which agent is "selected" in the editor
action_dict=None): # defunct - was used to indicate agent intention to turn
""" Draw the environment using the GraphicsLayer this RenderTool was created with.
(Use show=False from a Jupyter notebook with %matplotlib inline)
"""
if not self.gl.is_raster():
self.renderEnv2(show=show, curves=curves, spacing=spacing,
arrows=arrows, agents=agents, show_observations=show_observations,
sRailColor=sRailColor,
frames=frames, iEpisode=iEpisode, iStep=iStep,
iSelectedAgent=iSelectedAgent, action_dict=action_dict)
return
if type(self.gl) is PILGL:
self.gl.beginFrame()
env = self.env
self.renderRail()
# Draw each agent + its orientation + its target
if agents:
self.plotAgents(targets=True, iSelectedAgent=iSelectedAgent)
if show_observations:
self.renderObs(range(env.get_num_agents()), env.dev_obs_dict)
# Draw some textual information like fps
yText = [-0.3, -0.6, -0.9]
if frames:
self.gl.text(0.1, yText[2], "Frame:{:}".format(self.iFrame))
self.iFrame += 1
if iEpisode is not None:
self.gl.text(0.1, yText[1], "Ep:{}".format(iEpisode))
if iStep is not None:
self.gl.text(0.1, yText[0], "Step:{}".format(iStep))
tNow = time.time()
self.gl.text(2, yText[2], "elapsed:{:.2f}s".format(tNow - self.time1))
self.lTimes.append(tNow)
if len(self.lTimes) > 20:
self.lTimes.popleft()
if len(self.lTimes) > 1:
rFps = (len(self.lTimes) - 1) / (self.lTimes[-1] - self.lTimes[0])
self.gl.text(2, yText[1], "fps:{:.2f}".format(rFps))
self.gl.prettify2(env.width, env.height, self.nPixCell)
# TODO: for MPL, we don't want to call clf (called by endframe)
# if not show:
if show and type(self.gl) is PILGL:
self.gl.show()
self.gl.pause(0.00001)
return
def _draw_square(self, center, size, color, opacity=255, layer=0):
x0 = center[0] - size / 2
x1 = center[0] + size / 2
y0 = center[1] - size / 2
y1 = center[1] + size / 2
self.gl.plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], color=color, layer=layer, opacity=opacity)
def getImage(self):
return self.gl.getImage()
def plotTreeObs(self, gObs):
nBranchFactor = 4
gP0 = array([[0, 0, 0]]).T
nDepth = 2
for i in range(nDepth):
nDepthNodes = nBranchFactor ** i
rShrinkDepth = 1 / (i + 1)
gX1 = np.linspace(-(nDepthNodes - 1), (nDepthNodes - 1), nDepthNodes) * rShrinkDepth
gY1 = np.ones((nDepthNodes)) * i
gZ1 = np.zeros((nDepthNodes))
gP1 = array([gX1, gY1, gZ1])
gP01 = np.append(gP0, gP1, axis=1)
if nDepthNodes > 1:
nDepthNodesPrev = nDepthNodes / nBranchFactor
giP0 = np.repeat(np.arange(nDepthNodesPrev), nBranchFactor)
giP1 = np.arange(0, nDepthNodes) + nDepthNodesPrev
giLinePoints = np.stack([giP0, giP1]).ravel("F")
self.gl.plot(gP01[0], -gP01[1], lines=giLinePoints, color="gray")
gP0 = array([gX1, gY1, gZ1])
def renderEnv2(
self, show=False, curves=True, spacing=False, arrows=False, agents=True,
show_observations=True, sRailColor="gray",
frames=False, iEpisode=None, iStep=None, iSelectedAgent=None,
action_dict=dict()
):
"""
Draw the environment using matplotlib.
Draw into the figure if provided.
Call pyplot.show() if show==True.
(Use show=False from a Jupyter notebook with %matplotlib inline)
"""
env = self.env
self.gl.beginFrame()
if self.new_rail:
self.new_rail = False
self.gl.clear_rails()
# store the targets
dTargets = {}
dSelected = {}
for iAgent, agent in enumerate(self.env.agents_static):
if agent is None:
continue
dTargets[tuple(agent.target)] = iAgent
dSelected[tuple(agent.target)] = (iAgent == iSelectedAgent)
# Draw each cell independently
for r in range(env.height):
for c in range(env.width):
binTrans = env.rail.grid[r, c]
if (r, c) in dTargets:
target = dTargets[(r, c)]
isSelected = dSelected[(r, c)]
else:
target = None
isSelected = False
self.gl.setRailAt(r, c, binTrans, iTarget=target, isSelected=isSelected, rail_grid=env.rail.grid)
self.gl.build_background_map(dTargets)
for iAgent, agent in enumerate(self.env.agents):
if agent is None:
continue
if agent.old_position is not None:
position = agent.old_position
direction = agent.direction
old_direction = agent.old_direction
else:
position = agent.position
direction = agent.direction
old_direction = agent.direction
# setAgentAt uses the agent index for the color
self.gl.setCellOccupied(iAgent, *(agent.position))
self.gl.setAgentAt(iAgent, *position, old_direction, direction, iSelectedAgent == iAgent)
if show_observations:
self.renderObs(range(env.get_num_agents()), env.dev_obs_dict)
if show:
self.gl.show()
for i in range(3):
self.gl.processEvents()
self.iFrame += 1
return
def close_window(self):
self.gl.close_window()