Skip to content
Snippets Groups Projects
Commit 2f6d2729 authored by Erik Nygren's avatar Erik Nygren
Browse files

fixed formatting

parent 53c34f90
No related branches found
No related tags found
No related merge requests found
......@@ -10,6 +10,7 @@ from flatland.utils.render_qt import QTGL, QTSVG
from flatland.utils.graphics_pil import PILGL
from flatland.utils.graphics_layer import GraphicsLayer
# TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv!
......@@ -100,12 +101,12 @@ class RenderTool(object):
lColors = list("brgcmyk")
# \delta RC for NESW
gTransRC = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
nPixCell = 1 # misnomer...
nPixCell = 1 # misnomer...
nPixHalf = nPixCell / 2
xyHalf = array([nPixHalf, -nPixHalf])
grc2xy = array([[0, -nPixCell], [nPixCell, 0]])
gGrid = array(np.meshgrid(np.arange(10), -np.arange(10))) * \
array([[[nPixCell]], [[nPixCell]]])
array([[[nPixCell]], [[nPixCell]]])
# xyPixHalf = xr.DataArray([nPixHalf, -nPixHalf],
# dims="xy",
# coords={"xy": ["x", "y"]})
......@@ -130,7 +131,7 @@ class RenderTool(object):
self.gl = PILGL(env.width, env.height)
elif gl == "QTSVG":
self.gl = QTSVG(env.width, env.height)
self.new_rail = True
def set_new_rail(self):
......@@ -153,14 +154,14 @@ class RenderTool(object):
def plotAgents(self, targets=True, iSelectedAgent=None):
cmap = self.gl.get_cmap('hsv',
lut=max(len(self.env.agents), len(self.env.agents_static) + 1))
lut=max(len(self.env.agents), len(self.env.agents_static) + 1))
for iAgent, agent in enumerate(self.env.agents_static):
if agent is None:
continue
oColor = cmap(iAgent)
self.plotAgent(agent.position, agent.direction, oColor, target=agent.target if targets else None,
static=True, selected=iAgent == iSelectedAgent)
static=True, selected=iAgent == iSelectedAgent)
for iAgent, agent in enumerate(self.env.agents):
if agent is None:
......@@ -211,8 +212,8 @@ class RenderTool(object):
"""
rt = self.__class__
rcDir = rt.gTransRC[iDir] # agent direction in RC
xyDir = np.matmul(rcDir, rt.grc2xy) # agent direction in xy
rcDir = rt.gTransRC[iDir] # agent direction in RC
xyDir = np.matmul(rcDir, rt.grc2xy) # agent direction in xy
xyPos = np.matmul(rcPos - rcDir / 2, rt.grc2xy) + rt.xyHalf
......@@ -220,7 +221,7 @@ class RenderTool(object):
color = self.gl.adaptColor(color, lighten=True)
# print("Agent:", rcPos, iDir, rcDir, xyDir, xyPos)
self.gl.scatter(*xyPos, color=color, marker="o", s=100) # agent location
self.gl.scatter(*xyPos, color=color, marker="o", s=100) # agent location
xyDirLine = array([xyPos, xyPos + xyDir / 2]).T # line for agent orient.
self.gl.plot(*xyDirLine, color=color, lw=5, ms=0, alpha=0.6)
if selected:
......@@ -398,12 +399,12 @@ class RenderTool(object):
xyPrev = xy
def drawTrans2(
self,
xyLine, xyCentre,
rotation, bDeadEnd=False,
sColor="gray",
bArrow=True,
spacing=0.1):
self,
xyLine, xyCentre,
rotation, bDeadEnd=False,
sColor="gray",
bArrow=True,
spacing=0.1):
"""
gLine is a numpy 2d array of points,
in the plotting space / coords.
......@@ -420,9 +421,9 @@ class RenderTool(object):
if sColor == "auto":
if dx > 0 or dy > 0:
sColor = "C1" # N or E
sColor = "C1" # N or E
else:
sColor = "C2" # S or W
sColor = "C2" # S or W
if bDeadEnd:
xyLine2 = array([
......@@ -471,6 +472,7 @@ class RenderTool(object):
xyMid,
xyMid + [-dx + dy, -dx - dy]])
self.gl.plot(*xyArrow.T, color=sColor)
def renderObs(self, agent_handles, observation_list):
"""
......@@ -480,21 +482,21 @@ class RenderTool(object):
"""
rt = self.__class__
cmap = self.gl.get_cmap('hsv',lut=max(len(self.env.agents),len(self.env.agents_static)+1))
cmap = self.gl.get_cmap('hsv', lut=max(len(self.env.agents), len(self.env.agents_static) + 1))
for agent in agent_handles:
color = cmap(agent)
for visited_cell in observation_list[agent]:
cell_coord = array(visited_cell[:2])
cell_coord_trans = np.matmul(cell_coord,rt.grc2xy)+rt.xyHalf
self._draw_square(cell_coord_trans,1 / 3, color)
cell_coord_trans = np.matmul(cell_coord, rt.grc2xy) + rt.xyHalf
self._draw_square(cell_coord_trans, 1 / 3, color)
def renderEnv(
self, show=False, curves=True, spacing=False,
arrows=False, agents=True, sRailColor="gray",
frames=False, iEpisode=None, iStep=None,
iSelectedAgent=None,
action_dict=None):
self, show=False, curves=True, spacing=False,
arrows=False, agents=True, obsrender=True, sRailColor="gray",
frames=False, iEpisode=None, iStep=None,
iSelectedAgent=None,
action_dict=None):
"""
Draw the environment using matplotlib.
Draw into the figure if provided.
......@@ -505,15 +507,16 @@ class RenderTool(object):
if not self.gl.is_raster():
self.renderEnv2(show, curves, spacing,
arrows, agents, sRailColor,
frames, iEpisode, iStep,
iSelectedAgent, action_dict)
arrows, agents, sRailColor,
frames, iEpisode, iStep,
iSelectedAgent, action_dict)
return
# cell_size is a bit pointless with matplotlib - it does not relate to pixels,
# so for now I've changed it to 1 (from 10)
cell_size = 1
self.gl.beginFrame()
# self.gl.clf()
# if oFigure is None:
# oFigure = self.gl.figure()
......@@ -545,9 +548,9 @@ class RenderTool(object):
for c in range(env.width):
# bounding box of the grid cell
x0 = cell_size * c # left
x1 = cell_size * (c + 1) # right
y0 = cell_size * -r # top
x0 = cell_size * c # left
x1 = cell_size * (c + 1) # right
y0 = cell_size * -r # top
y1 = cell_size * -(r + 1) # bottom
# centres of cell edges
......@@ -555,7 +558,7 @@ class RenderTool(object):
((x0 + x1) / 2.0, y0), # N middle top
(x1, (y0 + y1) / 2.0), # E middle right
((x0 + x1) / 2.0, y1), # S middle bottom
(x0, (y0 + y1) / 2.0) # W middle left
(x0, (y0 + y1) / 2.0) # W middle left
]
# cell centre
......@@ -628,8 +631,8 @@ class RenderTool(object):
# Draw each agent + its orientation + its target
if agents:
self.plotAgents(targets=True, iSelectedAgent=iSelectedAgent)
self.renderObs(range(env.get_num_agents()), env.dev_obs_dict)
if obsrender:
self.renderObs(range(env.get_num_agents()), env.dev_obs_dict)
# Draw some textual information like fps
yText = [-0.3, -0.6, -0.9]
if frames:
......@@ -683,18 +686,18 @@ class RenderTool(object):
gP0 = array([[0, 0, 0]]).T
nDepth = 2
for i in range(nDepth):
nDepthNodes = nBranchFactor**i
nDepthNodes = nBranchFactor ** i
# rScale = nBranchFactor ** (nDepth - i)
rShrinkDepth = 1/(i+1)
rShrinkDepth = 1 / (i + 1)
# gX1 = np.linspace(-nDepthNodes / 2, nDepthNodes / 2, nDepthNodes) * rShrinkDepth
gX1 = np.linspace(-(nDepthNodes-1), (nDepthNodes-1), nDepthNodes) * rShrinkDepth
gX1 = np.linspace(-(nDepthNodes - 1), (nDepthNodes - 1), nDepthNodes) * rShrinkDepth
gY1 = np.ones((nDepthNodes)) * i
gZ1 = np.zeros((nDepthNodes))
gP1 = array([gX1, gY1, gZ1])
gP01 = np.append(gP0, gP1, axis=1)
if nDepthNodes > 1:
nDepthNodesPrev = nDepthNodes / nBranchFactor
giP0 = np.repeat(np.arange(nDepthNodesPrev), nBranchFactor)
......@@ -705,13 +708,13 @@ class RenderTool(object):
self.gl.plot(gP01[0], -gP01[1], lines=giLinePoints, color="gray")
gP0 = array([gX1, gY1, gZ1])
def renderEnv2(
self, show=False, curves=True, spacing=False,
arrows=False, agents=True, sRailColor="gray",
frames=False, iEpisode=None, iStep=None,
iSelectedAgent=None,
action_dict=dict()):
self, show=False, curves=True, spacing=False,
arrows=False, agents=True, sRailColor="gray",
frames=False, iEpisode=None, iStep=None,
iSelectedAgent=None,
action_dict=dict()):
"""
Draw the environment using matplotlib.
Draw into the figure if provided.
......@@ -728,12 +731,11 @@ class RenderTool(object):
# Draw each cell independently
for r in range(env.height):
for c in range(env.width):
binTrans = env.rail.grid[r, c]
self.gl.setRailAt(r, c, binTrans)
cmap = self.gl.get_cmap('hsv',
lut=max(len(self.env.agents), len(self.env.agents_static) + 1))
lut=max(len(self.env.agents), len(self.env.agents_static) + 1))
for iAgent, agent in enumerate(self.env.agents):
if agent is None:
......@@ -747,14 +749,14 @@ class RenderTool(object):
if iAgent in action_dict:
iAction = action_dict[iAgent]
new_direction, action_isValid = self.env.check_action(agent, iAction)
if action_isValid:
self.gl.setAgentAt(iAgent, *agent.position, agent.direction, new_direction, color=oColor)
else:
pass
# print("invalid action - agent ", iAgent, " bend ", agent.direction, new_direction)
# self.gl.setAgentAt(iAgent, *agent.position, agent.direction, new_direction)
self.gl.show()
for i in range(3):
self.gl.processEvents()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment