Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • flatland/flatland
  • stefan_otte/flatland
  • jiaodaxiaozi/flatland
  • sfwatergit/flatland
  • utozx126/flatland
  • ChenKuanSun/flatland
  • ashivani/flatland
  • minhhoa/flatland
  • pranjal_dhole/flatland
  • darthgera123/flatland
  • rivesunder/flatland
  • thomaslecat/flatland
  • joel_joseph/flatland
  • kchour/flatland
  • alex_zharichenko/flatland
  • yoogottamk/flatland
  • troye_fang/flatland
  • elrichgro/flatland
  • jun_jin/flatland
  • nimishsantosh107/flatland
20 results
Show changes
Showing
with 1505 additions and 494 deletions
import time
import warnings
from collections import deque
from enum import IntEnum
import numpy as np
from numpy import array
from recordtype import recordtype
from flatland.envs.step_utils.states import TrainState
from flatland.utils.graphics_pil import PILGL, PILSVG
from flatland.utils.graphics_pgl import PGLGL
# TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv!
class AgentRenderVariant(IntEnum):
    """Integer codes naming the ways an agent can be drawn.

    Only the member names and values are defined here.
    NOTE(review): how each variant is actually drawn is decided by the
    graphics layer, which is not part of this file - confirm there.
    """

    BOX_ONLY = 0
    ONE_STEP_BEHIND = 1
    AGENT_SHOWS_OPTIONS = 2
    ONE_STEP_BEHIND_AND_BOX = 3
    AGENT_SHOWS_OPTIONS_AND_BOX = 4
class RenderTool(object):
    """ RenderTool is a facade to a renderer.
    (This was introduced for the Browser / JS renderer which has now been removed.)

    Every public method forwards to ``self.renderer``, a ``RenderLocal``.
    """

    def __init__(self, env, gl="PGL", jupyter=False,
                 agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
                 show_debug=False, clear_debug_text=True, screen_width=800, screen_height=600,
                 host="localhost", port=None):
        """Create the delegate renderer for ``env``.

        :param env: environment to render; stored and passed to the delegate
        :param gl: graphics layer name, one of "PIL", "PILSVG" or "PGL";
            any other value is reported and replaced by "PGL"
        :param jupyter: forwarded to the delegate
        :param agent_render_variant: forwarded to the delegate
        :param show_debug: forwarded to the delegate
        :param clear_debug_text: forwarded to the delegate
        :param screen_width: forwarded to the delegate
        :param screen_height: forwarded to the delegate
        :param host: unused; kept so existing callers keep working
        :param port: unused; kept so existing callers keep working
        """
        self.env = env
        self.frame_nr = 0
        self.start_time = time.time()
        self.times_list = deque()

        self.agent_render_variant = agent_render_variant

        if gl not in ("PIL", "PILSVG", "PGL"):
            # The message promises a switch to PGL, so actually perform it:
            # without a renderer every later call on this facade would fail.
            print("[", gl, "] not found, switch to PGL")
            gl = "PGL"

        self.renderer = RenderLocal(env, gl, jupyter,
                                    agent_render_variant,
                                    show_debug, clear_debug_text, screen_width, screen_height)
        self.gl = self.renderer.gl

    def render_env(self,
                   show=False,  # whether to call matplotlib show() or equivalent after completion
                   show_agents=True,  # whether to include agents
                   show_inactive_agents=False,  # whether to show agents before they start
                   show_observations=True,  # whether to include observations
                   show_predictions=False,  # whether to include predictions
                   show_rowcols=False,  # label the rows and columns
                   frames=False,  # frame counter to show (intended since invocation)
                   episode=None,  # int episode number to show
                   step=None,  # int step number to show in image
                   selected_agent=None,  # indicate which agent is "selected" in the editor
                   return_image=False):  # indicate if image is returned for use in monitor
        """Forward all arguments, positionally and in order, to the delegate's ``render_env``.

        :return: whatever the delegate returns
        """
        return self.renderer.render_env(show, show_agents, show_inactive_agents, show_observations,
                                        show_predictions, show_rowcols, frames, episode, step,
                                        selected_agent, return_image)

    def close_window(self):
        """Forward to the delegate's ``close_window``."""
        self.renderer.close_window()

    def reset(self):
        """Forward to the delegate's ``reset``."""
        self.renderer.reset()

    def set_new_rail(self):
        """Tell the delegate the rail changed, then hand it our current env."""
        self.renderer.set_new_rail()
        self.renderer.env = self.env  # bit of a hack - copy our env to the delegate

    def update_background(self):
        """Forward to the delegate's ``update_background``."""
        self.renderer.update_background()

    def get_endpoint_URL(self):
        """ Returns a string URL for the root of the HTTP server
            TODO: Need to update this to work on a remote server!  May be tricky...

        :return: the delegate's endpoint URL, or None (after printing a
            notice) when the delegate has no ``get_endpoint_url`` method
        """
        if hasattr(self.renderer, "get_endpoint_url"):
            return self.renderer.get_endpoint_url()
        else:
            print("Attempt to get_endpoint_url from RenderTool - only supported with BROWSER")
            return None

    def get_image(self):
        """Return the image held by the delegate's graphics layer.

        :return: ``self.renderer.gl.get_image()``, or None (after printing a
            notice) when the delegate has no ``gl`` attribute
        """
        if hasattr(self.renderer, "gl"):
            return self.renderer.gl.get_image()
        else:
            print("Attempt to retrieve image from RenderTool - not supported with BROWSER")
            return None
class RenderBase(object):
    """Do-nothing renderer interface: every method is a no-op returning None."""

    def __init__(self, env):
        """Accept the environment; the base class does not store it."""

    def render_env(self,
                   show=False,
                   show_agents=True,
                   show_inactive_agents=False,
                   show_observations=True,
                   show_predictions=False,
                   show_rowcols=False,
                   frames=False,
                   episode=None,
                   step=None,
                   selected_agent=None,
                   return_image=False):
        """Render nothing.

        The parameters mirror, in order, the arguments RenderTool.render_env
        passes to its delegate, so calling the base implementation with them
        no longer raises TypeError. Calling it with no arguments still works.
        """

    def close_window(self):
        """Nothing to close in the base class."""

    def reset(self):
        """Nothing to reset in the base class."""

    def set_new_rail(self):
        """ Signal to the renderer that the env has changed and will need re-rendering.
        """

    def update_background(self):
        """ A lesser version of set_new_rail?
            TODO: can update_background be pruned for simplicity?
        """
class RenderLocal(RenderBase):
""" Class to render the RailEnv and agents.
Uses two layers, layer 0 for rails (mostly static), layer 1 for agents etc (dynamic)
The lower / rail layer 0 is only redrawn after set_new_rail() has been called.
Created with a "GraphicsLayer" or gl - now either PIL or PILSVG
"""
# Record types with the fields rc, iDir, iDepth and prev, under both the
# legacy name (Visit) and the renamed one (visit).
Visit = recordtype("Visit", ["rc", "iDir", "iDepth", "prev"])
visit = recordtype("visit", ["rc", "iDir", "iDepth", "prev"])

color_list = list("brgcmyk")
lColors = list(color_list)  # legacy name; kept as its own list object

# \delta RC for NESW
transitions_row_col = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
pix_per_cell = 1  # misnomer...
half_pix_per_cell = pix_per_cell / 2
x_y_half = array([half_pix_per_cell, -half_pix_per_cell])
row_col_to_xy = array([[0, -pix_per_cell], [pix_per_cell, 0]])
grid = array(np.meshgrid(np.arange(10), -np.arange(10))) * array([[[pix_per_cell]], [[pix_per_cell]]])
theta = np.linspace(0, np.pi / 2, 5)
arc = array([np.cos(theta), np.sin(theta)]).T  # from [1,0] to [0,1]

# Legacy names for the constants above. They are bound to the same values
# instead of being computed a second time, so the two spellings cannot drift.
gTransRC = transitions_row_col
nPixCell = pix_per_cell
nPixHalf = half_pix_per_cell
xyHalf = x_y_half
grc2xy = row_col_to_xy
gGrid = grid
gTheta = theta
gArc = arc
def __init__(self, env, gl="PILSVG", jupyter=False,
             agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
             show_debug=False, clear_debug_text=True, screen_width=800, screen_height=600):
    """Create the graphics layer for ``env`` and build the background map.

    :param env: environment to render; its ``width`` and ``height`` size the layer
    :param gl: "PIL", "PILSVG" or "PGL"; any other value is reported and PGL is used
    :param jupyter: forwarded to the graphics layer
    :param agent_render_variant: stored on the instance
    :param show_debug: stored on the instance
    :param clear_debug_text: stored on the instance
    :param screen_width: forwarded to the graphics layer
    :param screen_height: forwarded to the graphics layer
    """
    self.env = env
    self.frame_nr = 0
    self.start_time = time.time()
    self.times_list = deque()

    self.agent_render_variant = agent_render_variant

    self.gl_str = gl

    if gl == "PIL":
        self.gl = PILGL(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height)
    elif gl == "PILSVG":
        self.gl = PILSVG(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height)
    else:
        if gl != "PGL":
            print("[", gl, "] not found, switch to PGL, PILSVG")
            # Record the layer actually in use: render_env dispatches on
            # gl_str, and an unknown name would send a PGL layer down the
            # PIL rendering path.
            self.gl_str = "PGL"
        print("Using PGL")
        self.gl = PGLGL(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height)

    self.new_rail = True
    self.show_debug = show_debug
    self.clear_debug_text = clear_debug_text
    self.update_background()
def reset(self):
    """Flag the rail for redrawing and restart the frame/time bookkeeping.

    :return: None
    """
    self.set_new_rail()
    # Same three fields __init__ sets up: frame counter, start time, timestamps.
    self.frame_nr, self.start_time, self.times_list = 0, time.time(), deque()
def update_background(self):
    """Rebuild the graphics layer's background map from the agents' targets.

    Builds a dict mapping each agent's target (as a tuple) to the agent's
    index, skipping ``None`` entries. When several agents share a target the
    highest index wins. The dict is passed once to
    ``self.gl.build_background_map``.
    """
    targets = {
        tuple(agent.target): agent_idx
        for agent_idx, agent in enumerate(self.env.agents)
        if agent is not None
    }
    self.gl.build_background_map(targets)
def resize(self):
    """Ask the graphics layer to resize itself, passing the current environment."""
    environment = self.env
    self.gl.resize(environment)
......@@ -65,47 +199,30 @@ class RenderTool(object):
"""
self.new_rail = True
def plotTreeOnRail(self, lVisits, color="r"):
    """
    DEFUNCT
    Plot, for every visit in lVisits, the transitions the rail allows from
    the visit's cell (visit.rc) when heading in visit.iDir. Each plot is
    labelled with the visit's depth. Nothing is returned.
    """
    deltas_by_direction = self.__class__.gTransRC
    for node in lVisits:
        # flags for the next cell: one entry per direction, truthy = allowed
        allowed_flags = self.env.rail.get_transitions((*node.rc, node.iDir))
        allowed_dirs = np.where(allowed_flags)[0]
        self.plotTrans(node.rc,
                       deltas_by_direction[allowed_dirs],
                       depth=str(node.iDepth),
                       color=color)
def plot_agents(self, targets=True, selected_agent=None):
    """Draw every agent of the environment, in two passes.

    The first pass draws each agent with ``static=True`` and marks the one
    whose index equals ``selected_agent``. The second pass draws each agent
    again with the default arguments. ``None`` entries in ``env.agents`` are
    skipped in both passes.

    :param targets: when true, each agent's target is passed on for drawing
    :param selected_agent: index of the agent to mark as selected, or None
    """
    color_map = self.gl.get_cmap('hsv', lut=(len(self.env.agents) + 1))

    for agent_idx, agent in enumerate(self.env.agents):
        if agent is None:
            continue
        color = color_map(agent_idx)
        self.plot_single_agent(agent.position, agent.direction, color,
                               target=agent.target if targets else None,
                               static=True, selected=agent_idx == selected_agent)

    for agent_idx, agent in enumerate(self.env.agents):
        if agent is None:
            continue
        color = color_map(agent_idx)
        self.plot_single_agent(agent.position, agent.direction, color,
                               target=agent.target if targets else None)
def getTransRC(self, rcPos, iDir, bgiTrans=False):
def get_transition_row_col(self, row_col_pos, direction, bgiTrans=False):
"""
Get the available transitions for rcPos in direction iDir,
Get the available transitions for row_col_pos in direction direction,
as row & col deltas.
If bgiTrans is True, return a grid of indices of available transitions.
eg for a cell rcPos = (4,5), in direction iDir = 0 (N),
eg for a cell row_col_pos = (4,5), in direction direction = 0 (N),
where the available transitions are N and E, returns:
[[-1,0], [0,1]] ie N=up one row, and E=right one col.
and if bgiTrans is True, returns a tuple:
......@@ -115,217 +232,78 @@ class RenderTool(object):
)
"""
tbTrans = self.env.rail.get_transitions((*rcPos, iDir))
giTrans = np.where(tbTrans)[0] # RC list of transitions
transitions = self.env.rail.get_transitions(*row_col_pos, direction)
transition_list = np.where(transitions)[0] # RC list of transitions
# HACK: workaround dead-end transitions
if len(giTrans) == 0:
iDirReverse = (iDir + 2) % 4
tbTrans = tuple(int(iDir2 == iDirReverse) for iDir2 in range(4))
giTrans = np.where(tbTrans)[0] # RC list of transitions
if len(transition_list) == 0:
reverse_direciton = (direction + 2) % 4
transitions = tuple(int(tmp_dir == reverse_direciton) for tmp_dir in range(4))
transition_list = np.where(transitions)[0] # RC list of transitions
gTransRCAg = self.__class__.gTransRC[giTrans]
transition_grid = self.__class__.transitions_row_col[transition_list]
if bgiTrans:
return gTransRCAg, giTrans
return transition_grid, transition_list
else:
return gTransRCAg
return transition_grid
def plot_single_agent(self, position_row_col, direction, color="r", target=None, static=False, selected=False):
    """
    Plot a simple agent.
    Assumes a working graphics layer context (cf a MPL figure).

    :param position_row_col: (row, col) of the agent; nothing is drawn when None
    :param direction: index into the class's ``transitions_row_col`` table
    :param color: drawing colour; passed through ``gl.adapt_color(..., lighten=True)``
        first when ``static`` is true
    :param target: (row, col) of the target, marked with a 1/3-size square, or None
    :param static: whether to adapt the colour as described above
    :param selected: whether to draw a unit square around the agent
    """
    if position_row_col is None:
        return

    rt = self.__class__

    direction_row_col = rt.transitions_row_col[direction]  # agent direction in RC
    direction_xy = np.matmul(direction_row_col, rt.row_col_to_xy)  # agent direction in xy

    # The marker sits half a cell behind the agent's position, against its heading.
    xyPos = np.matmul(position_row_col - direction_row_col / 2, rt.row_col_to_xy) + rt.x_y_half

    if static:
        color = self.gl.adapt_color(color, lighten=True)

    self.gl.scatter(*xyPos, color=color, layer=1, marker="o", s=100)  # agent location
    xy_dir_line = array([xyPos, xyPos + direction_xy / 2]).T  # line for agent orient.
    self.gl.plot(*xy_dir_line, color=color, layer=1, lw=5, ms=0, alpha=0.6)
    if selected:
        self._draw_square(xyPos, 1, color)

    if target is not None:
        target_row_col = array(target)
        target_xy = np.matmul(target_row_col, rt.row_col_to_xy) + rt.x_y_half
        self._draw_square(target_xy, 1 / 3, color, layer=1)
def plot_transition(self, position_row_col, transition_row_col, color="r", depth=None):
    """
    plot the transitions in transition_row_col at position position_row_col.
    transition_row_col is a 2d numpy array containing a list of RC transitions,
    eg [[-1,0], [0,1]] means N, E.

    :param color: colour of the scattered markers
    :param depth: optional label written at every marker
    """
    rt = self.__class__
    position_xy = np.matmul(position_row_col, rt.row_col_to_xy) + rt.x_y_half
    # One marker per transition, offset from the cell position by the
    # transition delta scaled down by 2.4.
    transition_xy = position_xy + np.matmul(transition_row_col, rt.row_col_to_xy / 2.4)
    self.gl.scatter(*transition_xy.T, color=color, marker="o", s=50, alpha=0.2)
    if depth is not None:
        for x, y in transition_xy:
            self.gl.text(x, y, depth)
def getTreeFromRail(self, rcPos, iDir, nDepth=10, bBFS=True, bPlot=False):
    """
    DEFUNCT
    Generate a tree from the env starting at rcPos, iDir.

    Nodes are Visit records (rc, iDir, iDepth, prev). Pending nodes are taken
    from the front of the list when bBFS is true and from the back otherwise.
    Nodes deeper than nDepth are dropped. A (rc, iDir) pair is expanded only
    the first time it is met, but every visit within the depth limit is
    appended to the result, so the returned list can contain repeats.

    :param rcPos: starting cell, presumably a (row, col) pair - verify against caller
    :param iDir: starting direction index
    :param nDepth: maximum depth of visits to keep
    :param bBFS: pop from the front (True) or the back (False) of the pending list
    :param bPlot: when true, plot the transitions of each expanded node
    :return: list of Visit records in the order they were taken from the list
    """
    rt = self.__class__
    print(rcPos, iDir)  # leftover debug output
    iPos = 0 if bBFS else -1  # BF / DF Search
    iDepth = 0
    visited = set()
    lVisits = []
    stack = [rt.Visit(rcPos, iDir, iDepth, None)]
    while stack:
        visit = stack.pop(iPos)
        rcd = (visit.rc, visit.iDir)
        if visit.iDepth > nDepth:
            continue
        lVisits.append(visit)
        if rcd not in visited:
            visited.add(rcd)
            gTransRCAg, giTrans = self.getTransRC(visit.rc,
                                                  visit.iDir,
                                                  bgiTrans=True)
            # enqueue the next nodes (ie transitions from this node)
            for gTransRC2, iTrans in zip(gTransRCAg, giTrans):
                visitNext = rt.Visit(tuple(visit.rc + gTransRC2),
                                     iTrans,
                                     visit.iDepth + 1,
                                     visit)
                stack.append(visitNext)
            # plot the available transitions from this node
            if bPlot:
                self.plotTrans(
                    visit.rc, gTransRCAg,
                    depth=str(visit.iDepth))
    return lVisits
def plotTree(self, lVisits, xyTarg):
    '''
    Plot a vertical tree of transitions.
    Returns the "visit" to the destination
    (ie where euclidean distance is near zero) or None if absent.

    Each visit is drawn at x = its euclidean distance to xyTarg plus iDir / 4,
    and y = its depth. Each visit that has a predecessor is joined to it by a
    line. The path from the destination visit back to the origin is then
    redrawn in red.
    '''
    dPos = {}
    iPos = 0
    visitDest = None
    for iVisit, visit in enumerate(lVisits):
        # NOTE(review): the xLoc looked up / assigned from dPos here is
        # overwritten a few lines below before it is used.
        if visit.rc in dPos:
            xLoc = dPos[visit.rc]
        else:
            xLoc = dPos[visit.rc] = iPos
            iPos += 1
        rDist = np.linalg.norm(array(visit.rc) - array(xyTarg))
        xLoc = rDist + visit.iDir / 4
        # point labelled with distance
        self.gl.scatter(xLoc, visit.iDepth, color="k", s=2)
        self.gl.text(xLoc, visit.iDepth, visit.rc, color="k", rotation=45)
        # if len(dPos)>1:
        if visit.prev:
            # NOTE(review): this lookup is likewise overwritten two lines below.
            xLocPrev = dPos[visit.prev.rc]
            rDistPrev = np.linalg.norm(array(visit.prev.rc) -
                                       array(xyTarg))
            xLocPrev = rDistPrev + visit.prev.iDir / 4
            # line from prev node
            self.gl.plot([xLocPrev, xLoc],
                         [visit.iDepth - 1, visit.iDepth],
                         color="k", alpha=0.5, lw=1)
        if rDist < 0.1:
            # the last visit this close to the target is kept as the destination
            visitDest = visit
    # Walk backwards from destination to origin, plotting in red
    if visitDest is not None:
        visit = visitDest
        xLocPrev = None
        while visit is not None:
            rDist = np.linalg.norm(array(visit.rc) - array(xyTarg))
            xLoc = rDist + visit.iDir / 4
            if xLocPrev is not None:
                self.gl.plot([xLoc, xLocPrev], [visit.iDepth, visit.iDepth + 1],
                             color="r", alpha=0.5, lw=2)
            xLocPrev = xLoc
            visit = visit.prev
    self.gl.prettify()
    return visitDest
def plotPath(self, visitDest):
    """
    Given a "final" visit visitDest, plotPath walks back through the path
    using the visit.prev field (previous) to get back to the start of the path.
    The path of transitions is plotted with arrows at 3/4 along the line.
    The transition is plotted slightly to one side of the rail, so that
    transitions in opposite directions are separate.
    Currently, no attempt is made to make the transition arrows coincide
    at corners, and they are straight only.

    Nothing is drawn when visitDest is None. Returns None.
    """
    rt = self.__class__
    # Walk backwards from destination to origin
    if visitDest is not None:
        visit = visitDest
        xyPrev = None
        while visit is not None:
            xy = np.matmul(visit.rc, rt.grc2xy) + rt.xyHalf
            if xyPrev is not None:
                # 1/20 of the segment: used as sideways offset and arrow-head size
                dx, dy = (xyPrev - xy) / 20
                xyLine = array([xy, xyPrev]) + array([dy, dx])
                self.gl.plot(*xyLine.T, color="r", alpha=0.5, lw=1)
                # weighted sum: 1/4 of the first point plus 3/4 of the second
                xyMid = np.sum(xyLine * [[1 / 4], [3 / 4]], axis=0)
                xyArrow = array([
                    xyMid + [-dx - dy, +dx - dy],
                    xyMid,
                    xyMid + [-dx + dy, -dx - dy]])
                self.gl.plot(*xyArrow.T, color="r")
            visit = visit.prev
            xyPrev = xy
def drawTrans(self, oFrom, oTo, sColor="gray"):
    """Draw a straight segment from point oFrom to point oTo in colour sColor.

    Element 0 of each point is used as x and element 1 as y.
    """
    x_values = [oFrom[0], oTo[0]]
    y_values = [oFrom[1], oTo[1]]
    self.gl.plot(x_values, y_values, color=sColor)
def drawTrans2(self,
xyLine, xyCentre,
rotation, bDeadEnd=False,
sColor="gray",
bArrow=True,
spacing=0.1):
def draw_transition(self,
line,
center,
rotation,
dead_end=False,
curves=False,
color="gray",
arrow=True,
spacing=0.1):
"""
gLine is a numpy 2d array of points,
in the plotting space / coords.
......@@ -334,67 +312,77 @@ class RenderTool(object):
from x=0, y=0.5
to x=1, y=0.2
"""
rt = self.__class__
bStraight = rotation in [0, 2]
dx, dy = np.squeeze(np.diff(xyLine, axis=0)) * spacing / 2
if bStraight:
if not curves and not dead_end:
# just a straigt line, no curve nor dead_end included in this basic rail element
self.gl.plot(
[line[0][0], line[1][0]], # x
[line[0][1], line[1][1]], # y
color=color
)
else:
# it was not a simple line to draw: the rail has a curve or dead_end included.
rt = self.__class__
straight = rotation in [0, 2]
dx, dy = np.squeeze(np.diff(line, axis=0)) * spacing / 2
if sColor == "auto":
if dx > 0 or dy > 0:
sColor = "C1" # N or E
else:
sColor = "C2" # S or W
if bDeadEnd:
xyLine2 = array([
xyLine[1] + [dy, dx],
xyCentre,
xyLine[1] - [dy, dx],
])
self.gl.plot(*xyLine2.T, color=sColor)
else:
xyLine2 = xyLine + [-dy, dx]
self.gl.plot(*xyLine2.T, color=sColor)
if straight:
if bArrow:
xyMid = np.sum(xyLine2 * [[1 / 4], [3 / 4]], axis=0)
if color == "auto":
if dx > 0 or dy > 0:
color = "C1" # N or E
else:
color = "C2" # S or W
if dead_end:
line_xy = array([
line[1] + [dy, dx],
center,
line[1] - [dy, dx],
])
self.gl.plot(*line_xy.T, color=color)
else:
line_xy = line + [-dy, dx]
self.gl.plot(*line_xy.T, color=color)
xyArrow = array([
xyMid + [-dx - dy, +dx - dy],
xyMid,
xyMid + [-dx + dy, -dx - dy]])
self.gl.plot(*xyArrow.T, color=sColor)
if arrow:
middle_xy = np.sum(line_xy * [[1 / 4], [3 / 4]], axis=0)
else:
arrow_xy = array([
middle_xy + [-dx - dy, +dx - dy],
middle_xy,
middle_xy + [-dx + dy, -dx - dy]])
self.gl.plot(*arrow_xy.T, color=color)
xyMid = np.mean(xyLine, axis=0)
dxy = xyMid - xyCentre
xyCorner = xyMid + dxy
if rotation == 1:
rArcFactor = 1 - spacing
sColorAuto = "C1"
else:
rArcFactor = 1 + spacing
sColorAuto = "C2"
dxy2 = (xyCentre - xyCorner) * rArcFactor # for scaling the arc
if sColor == "auto":
sColor = sColorAuto
self.gl.plot(*(rt.gArc * dxy2 + xyCorner).T, color=sColor)
if bArrow:
dx, dy = np.squeeze(np.diff(xyLine, axis=0)) / 20
iArc = int(len(rt.gArc) / 2)
xyMid = xyCorner + rt.gArc[iArc] * dxy2
xyArrow = array([
xyMid + [-dx - dy, +dx - dy],
xyMid,
xyMid + [-dx + dy, -dx - dy]])
self.gl.plot(*xyArrow.T, color=sColor)
def renderObs(self, agent_handles, observation_dict):
middle_xy = np.mean(line, axis=0)
dxy = middle_xy - center
corner = middle_xy + dxy
if rotation == 1:
arc_factor = 1 - spacing
color_auto = "C1"
else:
arc_factor = 1 + spacing
color_auto = "C2"
dxy2 = (center - corner) * arc_factor # for scaling the arc
if color == "auto":
color = color_auto
self.gl.plot(*(rt.arc * dxy2 + corner).T, color=color)
if arrow:
dx, dy = np.squeeze(np.diff(line, axis=0)) / 20
iArc = int(len(rt.arc) / 2)
middle_xy = corner + rt.arc[iArc] * dxy2
arrow_xy = array([
middle_xy + [-dx - dy, +dx - dy],
middle_xy,
middle_xy + [-dx + dy, -dx - dy]])
self.gl.plot(*arrow_xy.T, color=color)
def render_observation(self, agent_handles, observation_dict):
"""
Render the extent of the observation of each agent. All cells that appear in the agent
observation will be highlighted.
......@@ -404,38 +392,72 @@ class RenderTool(object):
"""
rt = self.__class__
for agent in agent_handles:
color = self.gl.getAgentColor(agent)
for visited_cell in observation_dict[agent]:
cell_coord = array(visited_cell[:2])
cell_coord_trans = np.matmul(cell_coord, rt.grc2xy) + rt.xyHalf
self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100)
# Check if the observation builder provides an observation
if len(observation_dict) < 1:
warnings.warn(
"Predictor did not provide any predicted cells to render. \
Observation builder needs to populate: env.dev_obs_dict")
else:
for agent in agent_handles:
color = self.gl.get_agent_color(agent)
for visited_cell in observation_dict[agent]:
cell_coord = array(visited_cell[:2])
cell_coord_trans = np.matmul(cell_coord, rt.row_col_to_xy) + rt.x_y_half
self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100)
def render_prediction(self, agent_handles, prediction_dict):
    """
    Render the predicted cells of each agent. All cells that appear in an
    agent's prediction will be highlighted.

    With a PILSVG graphics layer the rail at each predicted cell is passed to
    the layer's prediction-path drawing; with any other layer a square is
    drawn on the cell instead. An empty prediction_dict only emits a warning.

    :param agent_handles: List of agent indices to adapt color and get correct prediction
    :param prediction_dict: dictionary containing, per agent, the predicted cells;
        the first two items of each cell are used as (row, col)
    """
    if len(prediction_dict) < 1:
        warnings.warn(
            "Predictor did not provide any predicted cells to render. "
            "Predictors builder needs to populate: env.dev_pred_dict")
        return

    rt = self.__class__
    # isinstance rather than an exact type comparison, so that a graphics
    # layer derived from PILSVG takes the same branch as PILSVG itself.
    draws_prediction_path = isinstance(self.gl, PILSVG)

    for agent in agent_handles:
        color = self.gl.get_agent_color(agent)
        for visited_cell in prediction_dict[agent]:
            cell_coord = array(visited_cell[:2])
            if draws_prediction_path:
                # TODO : Track highlighting (Adrian)
                r = cell_coord[0]
                c = cell_coord[1]
                transitions = self.env.rail.grid[r, c]
                self.gl.set_predicion_path_at(r, c, transitions, agent_rail_color=color)
            else:
                cell_coord_trans = np.matmul(cell_coord, rt.row_col_to_xy) + rt.x_y_half
                # square size shrinks as the agent index grows
                self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100)
def renderRail(self, spacing=False, sRailColor="gray", curves=True, arrows=False):
def render_rail(self, spacing=False, rail_color="gray", curves=True, arrows=False):
cell_size = 1 # TODO: remove cell_size
env = self.env
# Draw cells grid
grid_color = [0.95, 0.95, 0.95]
for r in range(env.height + 1):
for row in range(env.height + 1):
self.gl.plot([0, (env.width + 1) * cell_size],
[-r * cell_size, -r * cell_size],
[-row * cell_size, -row * cell_size],
color=grid_color, linewidth=2)
for c in range(env.width + 1):
self.gl.plot([c * cell_size, c * cell_size],
for col in range(env.width + 1):
self.gl.plot([col * cell_size, col * cell_size],
[0, -(env.height + 1) * cell_size],
color=grid_color, linewidth=2)
# Draw each cell independently
for r in range(env.height):
for c in range(env.width):
for row in range(env.height):
for col in range(env.width):
# bounding box of the grid cell
x0 = cell_size * c # left
x1 = cell_size * (c + 1) # right
y0 = cell_size * -r # top
y1 = cell_size * -(r + 1) # bottom
x0 = cell_size * col # left
x1 = cell_size * (col + 1) # right
y0 = cell_size * -row # top
y1 = cell_size * -(row + 1) # bottom
# centres of cell edges
coords = [
......@@ -446,16 +468,16 @@ class RenderTool(object):
]
# cell centre
xyCentre = array([x0, y1]) + cell_size / 2
center_xy = array([x0, y1]) + cell_size / 2
# cell transition values
oCell = env.rail.get_transitions((r, c))
cell = env.rail.get_full_transitions(row, col)
bCellValid = env.rail.cell_neighbours_valid((r, c), check_this_cell=True)
cell_valid = env.rail.cell_neighbours_valid((row, col), check_this_cell=True)
# Special Case 7, with a single bit; terminate at center
nbits = 0
tmp = oCell
tmp = cell
while tmp > 0:
nbits += (tmp & 1)
......@@ -463,110 +485,128 @@ class RenderTool(object):
# as above - move the from coord to the centre
# it's a dead env.
bDeadEnd = nbits == 1
is_dead_end = nbits == 1
if not bCellValid:
self.gl.scatter(*xyCentre, color="r", s=30)
if not cell_valid:
self.gl.scatter(*center_xy, color="r", s=30)
for orientation in range(4): # ori is where we're heading
from_ori = (orientation + 2) % 4 # 0123=NESW -> 2301=SWNE
from_xy = coords[from_ori]
tMoves = env.rail.get_transitions((r, c, orientation))
moves = env.rail.get_transitions(row, col, orientation)
for to_ori in range(4):
to_xy = coords[to_ori]
rotation = (to_ori - from_ori) % 4
if (tMoves[to_ori]): # if we have this transition
if bDeadEnd:
self.drawTrans2(
array([from_xy, to_xy]), xyCentre,
rotation, bDeadEnd=True, spacing=spacing,
sColor=sRailColor)
else:
if curves:
self.drawTrans2(
array([from_xy, to_xy]), xyCentre,
rotation, spacing=spacing, bArrow=arrows,
sColor=sRailColor)
else:
self.drawTrans(self, from_xy, to_xy, sRailColor)
if False:
print(
"r,c,ori: ", r, c, orientation,
"cell:", "{0:b}".format(oCell),
"moves:", tMoves,
"from:", from_ori, from_xy,
"to: ", to_ori, to_xy,
"cen:", *xyCentre,
"rot:", rotation,
)
def renderEnv(self,
show=False, # whether to call matplotlib show() or equivalent after completion
# use false when calling from Jupyter. (and matplotlib no longer supported!)
curves=True, # draw turns as curves instead of straight diagonal lines
spacing=False, # defunct - size of spacing between rails
arrows=False, # defunct - draw arrows on rail lines
agents=True, # whether to include agents
show_observations=True, # whether to include observations
sRailColor="gray", # color to use in drawing rails (not used with SVG)
frames=False, # frame counter to show (intended since invocation)
iEpisode=None, # int episode number to show
iStep=None, # int step number to show in image
iSelectedAgent=None, # indicate which agent is "selected" in the editor
action_dict=None): # defunct - was used to indicate agent intention to turn
if (moves[to_ori]): # if we have this transition
self.draw_transition(
array([from_xy, to_xy]), center_xy,
rotation, dead_end=is_dead_end, curves=curves and not is_dead_end, spacing=spacing,
color=rail_color)
def render_env(self,
               show=False,  # whether to call matplotlib show() or equivalent after completion
               show_agents=True,  # whether to include agents
               show_inactive_agents=False,
               show_observations=True,  # whether to include observations
               show_predictions=False,  # whether to include predictions
               show_rowcols=False,  # label the rows and columns
               frames=False,  # frame counter to show (intended since invocation)
               episode=None,  # int episode number to show
               step=None,  # int step number to show in image
               selected_agent=None,  # indicate which agent is "selected" in the editor
               return_image=False):  # indicate if image is returned for use in monitor
    """ Draw the environment using the GraphicsLayer this RenderTool was created with.
        (Use show=False from a Jupyter notebook with %matplotlib inline)

        Dispatches on the graphics layer name given at construction:
        "PILSVG" and "PGL" go to render_env_svg, anything else to
        render_env_pil. frames, episode and step are only passed to
        render_env_pil.

        :return: whatever the chosen render method returns
    """
    if self.gl_str in ["PILSVG", "PGL"]:
        return self.render_env_svg(show=show,
                                   show_observations=show_observations,
                                   show_predictions=show_predictions,
                                   selected_agent=selected_agent,
                                   show_agents=show_agents,
                                   show_inactive_agents=show_inactive_agents,
                                   show_rowcols=show_rowcols,
                                   return_image=return_image)

    return self.render_env_pil(show=show,
                               show_agents=show_agents,
                               show_inactive_agents=show_inactive_agents,
                               show_observations=show_observations,
                               show_predictions=show_predictions,
                               show_rowcols=show_rowcols,
                               frames=frames,
                               episode=episode,
                               step=step,
                               selected_agent=selected_agent,
                               return_image=return_image)
def _draw_square(self, center, size, color, opacity=255, layer=0):
x0 = center[0] - size / 2
x1 = center[0] + size / 2
y0 = center[1] - size / 2
y1 = center[1] + size / 2
self.gl.plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], color=color, layer=layer, opacity=opacity)
def get_image(self):
    """Return whatever the graphics layer's ``get_image`` returns."""
    image = self.gl.get_image()
    return image
def render_env_pil(self,
show=False, # whether to call matplotlib show() or equivalent after completion
# use false when calling from Jupyter. (and matplotlib no longer supported!)
show_agents=True, # whether to include agents
show_inactive_agents=False,
show_observations=True, # whether to include observations
show_predictions=False, # whether to include predictions
show_rowcols=False, # label the rows and columns
frames=False, # frame counter to show (intended since invocation)
episode=None, # int episode number to show
step=None, # int step number to show in image
selected_agent=None, # indicate which agent is "selected" in the editor
return_image=False # indicate if image is returned for use in monitor:
):
if type(self.gl) is PILGL:
self.gl.beginFrame()
self.gl.begin_frame()
env = self.env
self.renderRail()
self.render_rail()
# Draw each agent + its orientation + its target
if agents:
self.plotAgents(targets=True, iSelectedAgent=iSelectedAgent)
if show_agents:
self.plot_agents(targets=True, selected_agent=selected_agent)
if show_observations:
self.renderObs(range(env.get_num_agents()), env.dev_obs_dict)
self.render_observation(range(env.get_num_agents()), env.dev_obs_dict)
if show_predictions and len(env.dev_pred_dict) > 0:
self.render_prediction(range(env.get_num_agents()), env.dev_pred_dict)
# Draw some textual information like fps
yText = [-0.3, -0.6, -0.9]
text_y = [-0.3, -0.6, -0.9]
if frames:
self.gl.text(0.1, yText[2], "Frame:{:}".format(self.iFrame))
self.iFrame += 1
self.gl.text(0.1, text_y[2], "Frame:{:}".format(self.frame_nr))
self.frame_nr += 1
if iEpisode is not None:
self.gl.text(0.1, yText[1], "Ep:{}".format(iEpisode))
if episode is not None:
self.gl.text(0.1, text_y[1], "Ep:{}".format(episode))
if iStep is not None:
self.gl.text(0.1, yText[0], "Step:{}".format(iStep))
if step is not None:
self.gl.text(0.1, text_y[0], "Step:{}".format(step))
tNow = time.time()
self.gl.text(2, yText[2], "elapsed:{:.2f}s".format(tNow - self.time1))
self.lTimes.append(tNow)
if len(self.lTimes) > 20:
self.lTimes.popleft()
if len(self.lTimes) > 1:
rFps = (len(self.lTimes) - 1) / (self.lTimes[-1] - self.lTimes[0])
self.gl.text(2, yText[1], "fps:{:.2f}".format(rFps))
time_now = time.time()
self.gl.text(2, text_y[2], "elapsed:{:.2f}s".format(time_now - self.start_time))
self.times_list.append(time_now)
if len(self.times_list) > 20:
self.times_list.popleft()
if len(self.times_list) > 1:
rFps = (len(self.times_list) - 1) / (self.times_list[-1] - self.times_list[0])
self.gl.text(2, text_y[1], "fps:{:.2f}".format(rFps))
self.gl.prettify2(env.width, env.height, self.nPixCell)
self.gl.prettify2(env.width, env.height, self.pix_per_cell)
# TODO: for MPL, we don't want to call clf (called by endframe)
# if not show:
......@@ -576,115 +616,155 @@ class RenderTool(object):
self.gl.pause(0.00001)
if return_image:
return self.get_image()
return
def _draw_square(self, center, size, color, opacity=255, layer=0):
x0 = center[0] - size / 2
x1 = center[0] + size / 2
y0 = center[1] - size / 2
y1 = center[1] + size / 2
self.gl.plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], color=color, layer=layer, opacity=opacity)
def getImage(self):
    """Return whatever the graphics layer's ``getImage`` returns."""
    image = self.gl.getImage()
    return image
def plotTreeObs(self, gObs):
    """Plot a fixed tree skeleton with branching factor 4 and two levels.

    NOTE(review): gObs is never read in this method - the plotted points
    depend only on the constants below.
    """
    nBranchFactor = 4

    gP0 = array([[0, 0, 0]]).T
    nDepth = 2
    for i in range(nDepth):
        nDepthNodes = nBranchFactor ** i  # number of nodes on this level
        rShrinkDepth = 1 / (i + 1)  # x spacing shrinks on deeper levels

        # x spread symmetrically around 0, y = level index, z = 0
        gX1 = np.linspace(-(nDepthNodes - 1), (nDepthNodes - 1), nDepthNodes) * rShrinkDepth
        gY1 = np.ones((nDepthNodes)) * i
        gZ1 = np.zeros((nDepthNodes))

        gP1 = array([gX1, gY1, gZ1])
        # previous level's points followed by this level's points
        gP01 = np.append(gP0, gP1, axis=1)

        if nDepthNodes > 1:
            # NOTE(review): true division makes this, and the index arrays
            # derived from it, floating point.
            nDepthNodesPrev = nDepthNodes / nBranchFactor
            giP0 = np.repeat(np.arange(nDepthNodesPrev), nBranchFactor)
            giP1 = np.arange(0, nDepthNodes) + nDepthNodesPrev
            # interleave parent and child indices: p0, c0, p1, c1, ...
            giLinePoints = np.stack([giP0, giP1]).ravel("F")
            self.gl.plot(gP01[0], -gP01[1], lines=giLinePoints, color="gray")

        gP0 = array([gX1, gY1, gZ1])
def renderEnv2(
self, show=False, curves=True, spacing=False, arrows=False, agents=True,
show_observations=True, sRailColor="gray",
frames=False, iEpisode=None, iStep=None, iSelectedAgent=None,
action_dict=dict()
def render_env_svg(
self, show=False, show_observations=True, show_predictions=False, selected_agent=None,
show_agents=True, show_inactive_agents=False, show_rowcols=False, return_image=False
):
"""
Draw the environment using matplotlib.
Draw into the figure if provided.
Call pyplot.show() if show==True.
(Use show=False from a Jupyter notebook with %matplotlib inline)
Renders the environment with SVG support (nice image)
"""
env = self.env
self.gl.beginFrame()
self.gl.begin_frame()
if self.new_rail:
self.new_rail = False
self.gl.clear_rails()
# store the targets
dTargets = {}
dSelected = {}
for iAgent, agent in enumerate(self.env.agents_static):
targets = {}
selected = {}
for agent_idx, agent in enumerate(self.env.agents):
if agent is None:
continue
dTargets[tuple(agent.target)] = iAgent
dSelected[tuple(agent.target)] = (iAgent == iSelectedAgent)
targets[tuple(agent.target)] = agent_idx
selected[tuple(agent.target)] = (agent_idx == selected_agent)
# Draw each cell independently
for r in range(env.height):
for c in range(env.width):
binTrans = env.rail.grid[r, c]
if (r, c) in dTargets:
target = dTargets[(r, c)]
isSelected = dSelected[(r, c)]
transitions = env.rail.grid[r, c]
if (r, c) in targets:
target = targets[(r, c)]
is_selected = selected[(r, c)]
else:
target = None
isSelected = False
is_selected = False
self.gl.setRailAt(r, c, binTrans, iTarget=target, isSelected=isSelected)
self.gl.set_rail_at(r, c, transitions, target=target, is_selected=is_selected,
rail_grid=env.rail.grid, num_agents=env.get_num_agents(),
show_debug=self.show_debug)
self.gl.build_background_map(dTargets)
self.gl.build_background_map(targets)
for iAgent, agent in enumerate(self.env.agents):
if show_rowcols:
# label rows, cols
for iRow in range(env.height):
self.gl.text_rowcol((iRow, 0), str(iRow), layer=self.gl.RAIL_LAYER)
for iCol in range(env.width):
self.gl.text_rowcol((0, iCol), str(iCol), layer=self.gl.RAIL_LAYER)
if agent is None:
continue
if agent.old_position is not None:
position = agent.old_position
direction = agent.direction
old_direction = agent.old_direction
else:
position = agent.position
direction = agent.direction
old_direction = agent.direction
if show_agents:
for agent_idx, agent in enumerate(self.env.agents):
if agent is None:
continue
# Show an agent even if it hasn't already started
if agent.position is None:
if show_inactive_agents:
# print("agent ", agent_idx, agent.position, agent.old_position, agent.initial_position)
self.gl.set_agent_at(agent_idx, *(agent.initial_position),
agent.initial_direction, agent.initial_direction,
is_selected=(selected_agent == agent_idx),
rail_grid=env.rail.grid,
show_debug=self.show_debug, clear_debug_text=self.clear_debug_text,
malfunction=False)
continue
# setAgentAt uses the agent index for the color
self.gl.setAgentAt(iAgent, *position, old_direction, direction, iSelectedAgent == iAgent)
is_malfunction = agent.malfunction_handler.malfunction_down_counter > 0
if self.agent_render_variant == AgentRenderVariant.BOX_ONLY:
self.gl.set_cell_occupied(agent_idx, *(agent.position))
elif self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND or \
self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX: # noqa: E125
# Most common case - the agent has been running for >1 steps
if agent.old_position is not None:
position = agent.old_position
direction = agent.direction
old_direction = agent.old_direction
# the agent's first step - it doesn't have an old position yet
elif agent.position is not None:
position = agent.position
direction = agent.direction
old_direction = agent.direction
# When the editor has just added an agent
elif agent.initial_position is not None:
position = agent.initial_position
direction = agent.initial_direction
old_direction = agent.initial_direction
# set_agent_at uses the agent index for the color
if self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX:
self.gl.set_cell_occupied(agent_idx, *(agent.position))
self.gl.set_agent_at(agent_idx, *position, old_direction, direction,
selected_agent == agent_idx, rail_grid=env.rail.grid,
show_debug=self.show_debug, clear_debug_text=self.clear_debug_text,
malfunction=is_malfunction)
else:
position = agent.position
direction = agent.direction
for possible_direction in range(4):
# Is a transition along movement `desired_movement_from_new_cell` to the current cell possible?
isValid = env.rail.get_transition((*agent.position, agent.direction), possible_direction)
if isValid:
direction = possible_direction
# set_agent_at uses the agent index for the color
self.gl.set_agent_at(agent_idx, *position, agent.direction, direction,
selected_agent == agent_idx, rail_grid=env.rail.grid,
show_debug=self.show_debug, clear_debug_text=self.clear_debug_text,
malfunction=is_malfunction)
# set_agent_at uses the agent index for the color
if self.agent_render_variant == AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX:
self.gl.set_cell_occupied(agent_idx, *(agent.position))
if show_inactive_agents:
show_this_agent = True
else:
show_this_agent = agent.state.is_on_map_state()
if show_this_agent:
self.gl.set_agent_at(agent_idx, *position, agent.direction, direction,
selected_agent == agent_idx,
rail_grid=env.rail.grid, malfunction=is_malfunction)
if show_observations:
self.renderObs(range(env.get_num_agents()), env.dev_obs_dict)
self.render_observation(range(env.get_num_agents()), env.dev_obs_dict)
if show_predictions:
self.render_prediction(range(env.get_num_agents()), env.dev_pred_dict)
if show:
self.gl.show()
for i in range(3):
self.gl.processEvents()
self.gl.process_events()
self.iFrame += 1
self.frame_nr += 1
if return_image:
return self.get_image()
return
def close_window(self):
......
from typing import Tuple, Dict
import numpy as np
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
def make_simple_rail() -> Tuple[GridTransitionMap, np.ndarray, Dict[str, dict]]:
    """Build a small hand-made 7x10 rail grid.

    Layout as (row, col): a horizontal line runs along row 3 from a dead end
    at column 0 to a dead end at column 9. A vertical branch in column 3
    (rows 0-2, dead end at row 0) meets it at a switch in (3, 3); a second
    vertical branch in column 6 (rows 4-6, dead end at row 6) meets it at a
    switch in (3, 6).

    Note that cells have invalid RailEnvTransitions!

    :return: ``(rail, rail_map, optionals)`` - the GridTransitionMap whose
        ``grid`` is ``rail_map``, the uint16 transition array itself, and a
        dict holding the ``'agents_hints'`` (city positions, train stations,
        city orientations).
    """
    transitions = RailEnvTransitions()
    cells = transitions.transition_list

    empty = cells[0]
    dead_end_from_south = cells[7]
    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
    vertical_straight = cells[1]
    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
    simple_switch_north_left = cells[2]
    simple_switch_north_right = cells[10]
    simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270)
    simple_switch_east_west_south = transitions.rotate_transition(simple_switch_north_left, 270)

    rail_map = np.array(
        [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
        [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
        [[dead_end_from_east] + [horizontal_straight] * 2 +
         [simple_switch_east_west_north] +
         [horizontal_straight] * 2 + [simple_switch_east_west_south] +
         [horizontal_straight] * 2 + [dead_end_from_west]] +
        [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)

    rail = GridTransitionMap(width=rail_map.shape[1],
                             height=rail_map.shape[0], transitions=transitions)
    rail.grid = rail_map

    # One city (with a single station) at the outer end of each vertical branch.
    agents_hints = {
        'city_positions': [(0, 3), (6, 6)],
        'train_stations': [
            [((0, 3), 0)],
            [((6, 6), 0)],
        ],
        'city_orientations': [0, 2],
    }
    optionals = {'agents_hints': agents_hints}
    return rail, rail_map, optionals
def make_disconnected_simple_rail() -> Tuple[GridTransitionMap, np.ndarray, Dict[str, dict]]:
    """Build a 7x10 rail grid whose horizontal line is cut in the middle.

    Layout as (row, col): row 3 starts with a dead end at column 0 and ends
    with a dead end at column 9, but columns 4 and 5 hold two dead ends
    instead of straight track, so the line is interrupted between them.
    A vertical branch in column 3 (rows 0-2) meets row 3 at a switch in
    (3, 3); a vertical branch in column 6 (rows 4-6) meets it at a switch in
    (3, 6).

    Note that cells have invalid RailEnvTransitions!

    :return: ``(rail, rail_map, optionals)`` - the GridTransitionMap whose
        ``grid`` is ``rail_map``, the uint16 transition array itself, and a
        dict holding the ``'agents_hints'``.
    """
    transitions = RailEnvTransitions()
    cells = transitions.transition_list

    empty = cells[0]
    dead_end_from_south = cells[7]
    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
    vertical_straight = cells[1]
    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
    simple_switch_north_left = cells[2]
    simple_switch_north_right = cells[10]
    simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270)
    simple_switch_east_west_south = transitions.rotate_transition(simple_switch_north_left, 270)

    rail_map = np.array(
        [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
        [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
        [[dead_end_from_east] + [horizontal_straight] * 2 +
         [simple_switch_east_west_north] +
         [dead_end_from_west] + [dead_end_from_east] + [simple_switch_east_west_south] +
         [horizontal_straight] * 2 + [dead_end_from_west]] +
        [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)

    rail = GridTransitionMap(width=rail_map.shape[1],
                             height=rail_map.shape[0], transitions=transitions)
    rail.grid = rail_map

    # One city (with a single station) at the outer end of each vertical branch.
    agents_hints = {
        'city_positions': [(0, 3), (6, 6)],
        'train_stations': [
            [((0, 3), 0)],
            [((6, 6), 0)],
        ],
        'city_orientations': [0, 2],
    }
    optionals = {'agents_hints': agents_hints}
    return rail, rail_map, optionals
def make_simple_rail2() -> Tuple[GridTransitionMap, np.ndarray, Dict[str, dict]]:
    """Build a small hand-made 7x10 rail grid (variant of the simple rail).

    Layout as (row, col): a horizontal line runs along row 3 between dead
    ends at columns 0 and 9. A vertical branch in column 3 (rows 0-2) meets
    it at a switch in (3, 3); a vertical branch in column 6 (rows 4-6) meets
    it at a switch in (3, 6). The switch at (3, 6) is ``cells[10]`` rotated
    by 90 degrees.

    :return: ``(rail, rail_map, optionals)`` - the GridTransitionMap whose
        ``grid`` is ``rail_map``, the uint16 transition array itself, and a
        dict holding the ``'agents_hints'``.
    """
    transitions = RailEnvTransitions()
    cells = transitions.transition_list

    empty = cells[0]
    dead_end_from_south = cells[7]
    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
    vertical_straight = cells[1]
    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
    simple_switch_north_right = cells[10]
    simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270)
    simple_switch_west_east_south = transitions.rotate_transition(simple_switch_north_right, 90)

    rail_map = np.array(
        [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
        [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
        [[dead_end_from_east] + [horizontal_straight] * 2 +
         [simple_switch_east_west_north] +
         [horizontal_straight] * 2 + [simple_switch_west_east_south] +
         [horizontal_straight] * 2 + [dead_end_from_west]] +
        [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)

    rail = GridTransitionMap(width=rail_map.shape[1],
                             height=rail_map.shape[0], transitions=transitions)
    rail.grid = rail_map

    # One city (with a single station) at the outer end of each vertical branch.
    agents_hints = {
        'city_positions': [(0, 3), (6, 6)],
        'train_stations': [
            [((0, 3), 0)],
            [((6, 6), 0)],
        ],
        'city_orientations': [0, 2],
    }
    optionals = {'agents_hints': agents_hints}
    return rail, rail_map, optionals
def make_simple_rail_unconnected() -> Tuple[GridTransitionMap, np.ndarray, Dict[str, dict]]:
    """Build a 7x10 rail grid whose upper branch does not touch the main line.

    Layout as (row, col): column 3 holds a short vertical segment in rows
    0-2 that is closed by a dead end at both ends, so it has no junction
    with row 3. Row 3 is a horizontal line between dead ends at columns 0
    and 9, with a switch at (3, 6) leading to a vertical branch in column 6
    (rows 4-6, dead end at row 6).

    Note that cells have invalid RailEnvTransitions!

    :return: ``(rail, rail_map, optionals)`` - the GridTransitionMap whose
        ``grid`` is ``rail_map``, the uint16 transition array itself, and a
        dict holding the ``'agents_hints'``.
    """
    transitions = RailEnvTransitions()
    cells = transitions.transition_list

    empty = cells[0]
    dead_end_from_south = cells[7]
    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
    vertical_straight = cells[1]
    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
    simple_switch_north_left = cells[2]
    simple_switch_east_west_south = transitions.rotate_transition(simple_switch_north_left, 270)

    rail_map = np.array(
        [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
        [[empty] * 3 + [vertical_straight] + [empty] * 6] +
        [[empty] * 3 + [dead_end_from_north] + [empty] * 6] +
        [[dead_end_from_east] + [horizontal_straight] * 5 + [simple_switch_east_west_south] +
         [horizontal_straight] * 2 + [dead_end_from_west]] +
        [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)

    rail = GridTransitionMap(width=rail_map.shape[1],
                             height=rail_map.shape[0], transitions=transitions)
    rail.grid = rail_map

    # One city (with a single station) at the outer end of each vertical segment.
    agents_hints = {
        'city_positions': [(0, 3), (6, 6)],
        'train_stations': [
            [((0, 3), 0)],
            [((6, 6), 0)],
        ],
        'city_orientations': [0, 2],
    }
    optionals = {'agents_hints': agents_hints}
    return rail, rail_map, optionals
def make_simple_rail_with_alternatives() -> Tuple[GridTransitionMap, np.ndarray, Dict[str, dict]]:
    """Build a 7x10 rail grid that offers two routes between its branches.

    Layout as (row, col):

    - row 0: a turn at (0, 3), straight track in columns 4-8, a turn at (0, 9)
    - rows 1-2: vertical track in columns 3 and 9
    - row 3: dead end at column 0, straight track in columns 1-2, a switch at
      (3, 3), straight track in columns 4-5, a turn at (3, 6) and vertical
      track at (3, 9)
    - row 4: a switch at (4, 6), straight track in columns 7-8, a turn at (4, 9)
    - rows 5-6: vertical track at (5, 6) and a dead end at (6, 6)

    :return: ``(rail, rail_map, optionals)`` - the GridTransitionMap whose
        ``grid`` is ``rail_map``, the uint16 transition array itself, and a
        dict holding the ``'agents_hints'``.
    """
    transitions = RailEnvTransitions()
    cells = transitions.transition_list

    empty = cells[0]
    dead_end_from_south = cells[7]
    right_turn_from_south = cells[8]
    right_turn_from_west = transitions.rotate_transition(right_turn_from_south, 90)
    right_turn_from_north = transitions.rotate_transition(right_turn_from_south, 180)
    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
    vertical_straight = cells[1]
    simple_switch_north_left = cells[2]
    simple_switch_north_right = cells[10]
    simple_switch_left_east = transitions.rotate_transition(simple_switch_north_left, 90)
    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)

    rail_map = np.array(
        [[empty] * 3 + [right_turn_from_south] + [horizontal_straight] * 5 + [right_turn_from_west]] +
        [[empty] * 3 + [vertical_straight] + [empty] * 5 + [vertical_straight]] * 2 +
        [[dead_end_from_east] + [horizontal_straight] * 2 + [simple_switch_left_east] + [horizontal_straight] * 2 + [
            right_turn_from_west] + [empty] * 2 + [vertical_straight]] +
        [[empty] * 6 + [simple_switch_north_right] + [horizontal_straight] * 2 + [right_turn_from_north]] +
        [[empty] * 6 + [vertical_straight] + [empty] * 3] +
        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)

    rail = GridTransitionMap(width=rail_map.shape[1],
                             height=rail_map.shape[0], transitions=transitions)
    rail.grid = rail_map

    agents_hints = {
        'city_positions': [(0, 3), (6, 6)],
        'train_stations': [
            [((0, 3), 0)],
            [((6, 6), 0)],
        ],
        'city_orientations': [0, 2],
    }
    optionals = {'agents_hints': agents_hints}
    return rail, rail_map, optionals
def make_invalid_simple_rail() -> Tuple[GridTransitionMap, np.ndarray, Dict[str, dict]]:
    """Build a 7x10 rail grid shaped like the simple rail but with hand-built junction cells.

    Layout as (row, col): a horizontal line runs along row 3 between dead
    ends at columns 0 and 9; a vertical branch in column 3 (rows 0-2) and a
    vertical branch in column 6 (rows 4-6) join it. The junction cell at
    (3, 6) is the sum of the horizontal straight and ``cells[6]``; the
    junction cell at (3, 3) is that same value rotated by 180 degrees.

    :return: ``(rail, rail_map, optionals)`` - the GridTransitionMap whose
        ``grid`` is ``rail_map``, the uint16 transition array itself, and a
        dict holding the ``'agents_hints'``.
    """
    transitions = RailEnvTransitions()
    cells = transitions.transition_list

    empty = cells[0]
    dead_end_from_south = cells[7]
    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
    vertical_straight = cells[1]
    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
    double_switch_south_horizontal_straight = horizontal_straight + cells[6]
    double_switch_north_horizontal_straight = transitions.rotate_transition(
        double_switch_south_horizontal_straight, 180)

    rail_map = np.array(
        [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
        [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
        [[dead_end_from_east] + [horizontal_straight] * 2 +
         [double_switch_north_horizontal_straight] +
         [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
         [horizontal_straight] * 2 + [dead_end_from_west]] +
        [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)

    rail = GridTransitionMap(width=rail_map.shape[1],
                             height=rail_map.shape[0], transitions=transitions)
    rail.grid = rail_map

    # One city (with a single station) at the outer end of each vertical branch.
    agents_hints = {
        'city_positions': [(0, 3), (6, 6)],
        'train_stations': [
            [((0, 3), 0)],
            [((6, 6), 0)],
        ],
        'city_orientations': [0, 2],
    }
    optionals = {'agents_hints': agents_hints}
    return rail, rail_map, optionals
def make_oval_rail() -> Tuple[GridTransitionMap, np.ndarray, Dict[str, dict]]:
    """Build a 6x9 grid containing one closed rectangular loop of track.

    Layout as (row, col): turns sit at (1, 1), (1, 7), (4, 1) and (4, 7);
    rows 1 and 4 carry horizontal track in columns 2-6; columns 1 and 7
    carry vertical track in rows 2-3. The outermost rows and columns are
    empty.

    :return: ``(rail, rail_map, optionals)`` - the GridTransitionMap whose
        ``grid`` is ``rail_map``, the uint16 transition array itself, and a
        dict holding the ``'agents_hints'``.
    """
    transitions = RailEnvTransitions()
    cells = transitions.transition_list

    empty = cells[0]
    vertical_straight = cells[1]
    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
    right_turn_from_south = cells[8]
    right_turn_from_west = transitions.rotate_transition(right_turn_from_south, 90)
    right_turn_from_north = transitions.rotate_transition(right_turn_from_south, 180)
    right_turn_from_east = transitions.rotate_transition(right_turn_from_south, 270)

    rail_map = np.array(
        [[empty] * 9] +
        [[empty] + [right_turn_from_south] + [horizontal_straight] * 5 + [right_turn_from_west] + [empty]] +
        [[empty] + [vertical_straight] + [empty] * 5 + [vertical_straight] + [empty]] +
        [[empty] + [vertical_straight] + [empty] * 5 + [vertical_straight] + [empty]] +
        [[empty] + [right_turn_from_east] + [horizontal_straight] * 5 + [right_turn_from_north] + [empty]] +
        [[empty] * 9], dtype=np.uint16)

    rail = GridTransitionMap(width=rail_map.shape[1],
                             height=rail_map.shape[0], transitions=transitions)
    rail.grid = rail_map

    # One city (with a single station) in the middle of each horizontal side.
    agents_hints = {
        'city_positions': [(1, 4), (4, 4)],
        'train_stations': [
            [((1, 4), 0)],
            [((4, 4), 0)],
        ],
        'city_orientations': [1, 3],
    }
    optionals = {'agents_hints': agents_hints}
    return rail, rail_map, optionals
......@@ -3,7 +3,7 @@ import re
import svgutils
from flatland.core.transitions import RailEnvTransitions
from flatland.core.grid.rail_env_grid import RailEnvTransitions
class SVG(object):
......@@ -14,9 +14,6 @@ class SVG(object):
elif svgETree is not None:
self.svg = svgETree
self.init2()
def init2(self):
expr = "//*[local-name() = $name]"
self.eStyle = self.svg.root.xpath(expr, name="style")[0]
ltMatch = re.findall(r".st([a-zA-Z0-9]+)[{]([^}]*)}", self.eStyle.text)
......@@ -25,8 +22,7 @@ class SVG(object):
def copy(self):
new_svg = copy.deepcopy(self.svg)
self2 = SVG(svgETree=new_svg)
return self2
return SVG(svgETree=new_svg)
def merge(self, svg2):
svg3 = svg2.copy()
......
Flatland 2.0 Introduction
=========================
## What's new?
In this version of **Flat**land, we are moving closer to realistic and more complex railway problems.
Earlier versions of **Flat**land introduced you to the concept of restricted transitions, but they were still too simplistic to give us feasible solutions for daily operations.
Thus the following changes are coming in the next version to be closer to real railway network challenges:
- **New Level Generator** provides fewer connections between different nodes in the network and thus agent densities on rails are much higher.
- **Stochastic Events** cause agents to stop and get stuck for different numbers of time steps.
- **Different Speed Classes** allow agents to move at different speeds and thus enhance complexity in the search for optimal solutions.
We explain these changes in more detail and how you can play with their parametrization in Tutorials 3--5:
* [Tutorials](https://gitlab.aicrowd.com/flatland/flatland/tree/master/docs/tutorials)
We appreciate *your feedback* on the performance and the difficulty of these levels to help us shape the best possible **Flat**land 2.0 environment.
## Example code
To see all the changes in action you can just run the
* [examples/flatland_2_0_example.py](https://gitlab.aicrowd.com/flatland/flatland/blob/master/examples/flatland_2_0_example.py)
example.
@echo on
REM Installs tox into the active conda environment, runs the tox suite from
REM the repository root, then runs the start_jupyter tox environment.

set FLATLAND_BASEDIR=%~dp0\..

REM /d also switches the drive letter; the quotes protect paths with spaces.
cd /d "%FLATLAND_BASEDIR%"

call conda install -y -c conda-forge tox-conda || goto :error
call conda install -y tox || goto :error
call tox -v --recreate || goto :error
call tox -v -e start_jupyter --recreate || goto :error

goto :EOF

:error
REM Remember the failing code first: echo and pause would otherwise replace it.
set FLATLAND_EXIT_CODE=%errorlevel%
echo Failed with error #%FLATLAND_EXIT_CODE%.
pause
exit /b %FLATLAND_EXIT_CODE%
#!/bin/bash
# Installs tox into the active conda environment, runs the tox suite from the
# repository root, then starts the start_jupyter tox environment in the background.
set -e # stop on error
set -x # echo commands

# Quote the path so a checkout location containing spaces still works.
FLATLAND_BASEDIR="$(dirname "${BASH_SOURCE[0]}")/.."
cd "${FLATLAND_BASEDIR}"

conda install -y -c conda-forge tox-conda
conda install -y tox
tox -v
tox -v -e start_jupyter &
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
#!/usr/bin/env python
import glob
import os
import shutil
import subprocess
import webbrowser
from urllib.request import pathname2url
......@@ -18,16 +19,43 @@ def remove_exists(filename):
# clean docs config and html files, and rebuild everything
remove_exists('docs/flatland.rst')
# wildcards do not work under Windows
for image_file in glob.glob(r'./docs/flatland*.rst'):
remove_exists(image_file)
remove_exists('docs/modules.rst')
subprocess.call(['sphinx-apidoc', '-o', 'docs/', 'flatland'])
for md_file in glob.glob(r'./*.md') + glob.glob(r'./docs/specifications/*.md') + glob.glob(r'./docs/tutorials/*.md') + glob.glob(r'./docs/interface/*.md'):
from m2r import parse_from_file
rst_content = parse_from_file(md_file)
rst_file = md_file.replace(".md", ".rst")
remove_exists(rst_file)
with open(rst_file, 'w') as out:
print("m2r {}->{}".format(md_file, rst_file))
out.write(rst_content)
out.flush()
img_dest = 'docs/images/'
if not os.path.exists(img_dest):
os.makedirs(img_dest)
for image_file in glob.glob(r'./images/*.png'):
shutil.copy(image_file, img_dest)
subprocess.call(['sphinx-apidoc', '--force', '-a', '-e', '-o', 'docs/', 'flatland', '-H', 'API Reference', '--tocfile',
'05_apidoc'])
os.environ["SPHINXPROJ"] = "flatland"
os.environ["SPHINXPROJ"] = "Flatland"
os.chdir('docs')
subprocess.call(['python', '-msphinx', '-M', 'clean', '.', '_build'])
# TODO fix sphinx warnings instead of suppressing them...
subprocess.call(['python', '-msphinx', '-M', 'html', '.', '_build', '-Q'])
subprocess.call(['python', '-mpydeps', '../flatland', '-o', '_build/html/flatland.svg', '--no-config', '--noshow'])
img_dest = '_build/html/img'
if not os.path.exists(img_dest):
os.makedirs(img_dest)
for image_file in glob.glob(r'./specifications/img/*'):
shutil.copy(image_file, img_dest)
subprocess.call(['python', '-msphinx', '-M', 'html', '.', '_build'])
# we do not currently use pydeps, commented out https://gitlab.aicrowd.com/flatland/flatland/issues/149
# subprocess.call(['python', '-mpydeps', '../flatland', '-o', '_build/html/flatland.svg', '--no-config', '--noshow'])
browser('_build/html/index.html')
File added
No preview for this file type
File added
%% Cell type:markdown id: tags:
# Simple Animation Demo
%% Cell type:code id: tags:
``` python
%load_ext autoreload
%autoreload 2
```
%% Cell type:code id: tags:
``` python
import numpy as np
import time
from IPython import display
from ipycanvas import canvas
from flatland.utils.rendertools import RenderTool
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env import RailEnvActions as rea
from flatland.envs.persistence import RailEnvPersister
```
%% Cell type:code id: tags:
``` python
env, env_dict = RailEnvPersister.load_new("complex_scene_2.pkl", load_from_package="env_data.railway")
_ = env.reset()
env._max_episode_steps = 100
```
%% Output
pickle failed to load file: complex_scene_2.pkl trying msgpack (deprecated)...
pickle failed to load file: complex_scene_2.pkl trying msgpack (deprecated)...
pickle failed to load file: complex_scene_2.pkl trying msgpack (deprecated)...
This env file has no max_episode_steps (deprecated) - setting to 100
%% Cell type:code id: tags:
``` python
oRT = RenderTool(env, gl="PILSVG", jupyter=False, show_debug=True)
image_arr = oRT.get_image()
oCanvas = canvas.Canvas()
oCanvas.put_image_data(image_arr[:,:,0:3])
display.display(oCanvas)
done={"__all__":False}
while not done["__all__"]:
actions = {}
for agent_handle, agents in enumerate(env.agents):
actions.update({agent_handle:rea.MOVE_FORWARD})
obs, rew, done, info = env.step(actions)
oRT.render_env(show_observations=False,show_predictions=False)
gIm = oRT.get_image()
oCanvas.put_image_data(gIm[:,:,0:3])
time.sleep(0.1)
```
%% Output
# list of notebooks to include in run-all-notebooks.py test
simple_example_manual_control.ipynb
simple_rendering_demo.ipynb
flatland_animate.ipynb
render_episode.ipynb
scene_editor.ipynb
test_saved_envs.ipynb
test_service.ipynb
%% Cell type:markdown id: tags:
# Render Episode
Render a stored episode. Env file needs to have "episode" and "action" keys.
- creates a moving gif file of the episode
- displays the episode in a widget with a slider for the time steps.
%% Cell type:markdown id: tags:
# Setup
%% Cell type:code id: tags:
``` python
#!apt -qq install graphviz libgraphviz-dev pkg-config
#!pip install -qq git+https://gitlab.aicrowd.com/flatland/flatland.git
```
%% Cell type:code id: tags:
``` python
%load_ext autoreload
%autoreload 2
```
%% Cell type:code id: tags:
``` python
from IPython import display
```
%% Cell type:code id: tags:
``` python
import os
import pandas as pd
import PIL
import imageio
```
%% Cell type:code id: tags:
``` python
from flatland.utils.rendertools import RenderTool
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.malfunction_generators import malfunction_from_file, no_malfunction_generator
from flatland.envs.rail_generators import rail_from_file
from flatland.envs.rail_env import RailEnvActions
from flatland.envs.step_utils.states import TrainState
from flatland.envs.persistence import RailEnvPersister
```
%% Cell type:code id: tags:
``` python
def render_env(env_renderer):
    """Render the renderer's env with row/column labels and return it as a PIL image.

    :param env_renderer: object exposing ``render_env(show_rowcols=..., return_image=...)``
    :return: ``PIL.Image`` built from the array that ``render_env`` returns
    """
    # No agent lookup here: indexing agents[0] was only feeding a disabled
    # debug print and failed on an env without agents.
    aImage = env_renderer.render_env(show_rowcols=True, return_image=True)
    pil_image = PIL.Image.fromarray(aImage)
    return pil_image
```
%% Cell type:markdown id: tags:
# Experiments
This has been mostly changed to load envs using `importlib_resources`. It's getting them from the package `envdata.tests`
%% Cell type:code id: tags:
``` python
env, env_dict = RailEnvPersister.load_new("complex_scene_2.pkl", load_from_package="env_data.railway")
_ = env.reset()
env._max_episode_steps = 100
```
%% Output
pickle failed to load file: complex_scene_2.pkl trying msgpack (deprecated)...
pickle failed to load file: complex_scene_2.pkl trying msgpack (deprecated)...
pickle failed to load file: complex_scene_2.pkl trying msgpack (deprecated)...
This env file has no max_episode_steps (deprecated) - setting to 100
%% Cell type:code id: tags:
``` python
# the seed has to match that used to record the episode, in order for the malfunctions to match.
oRT = RenderTool(env, show_debug=True)
aImg = oRT.render_env(show_rowcols=True, return_image=True, show_inactive_agents=True)
print(env._max_episode_steps)
```
%% Cell type:code id: tags:
``` python
loAgs = env_dict["agents"]
lCols = "initial_direction,direction,initial_position,position".split(",")
pd.DataFrame([ [getattr(oAg, sCol) for sCol in lCols]
for oAg in loAgs], columns=lCols)
```
%% Cell type:code id: tags:
``` python
pd.DataFrame([ [getattr(oAg, sCol) for sCol in lCols]
for oAg in env.agents], columns=lCols)
```
%% Cell type:code id: tags:
``` python
pd.DataFrame([ vars(oAg) for oAg in env.agents])
```
%% Cell type:code id: tags:
``` python
# from persistence.py
def get_agent_state(env):
    """Collect ``[row, col, direction, malfunction_handler]`` for every agent in ``env.agents``.

    Row, column and direction are cast to plain ``int`` to avoid numpy types,
    which may cause problems with msgpack. In env v2 an agent may have
    position ``None`` before starting; such an agent is reported at (0, 0).
    """
    def _cell_of(agent):
        # Agents that are not placed yet fall back to the origin cell.
        if agent.position is None:
            return (0, 0)
        row, col = agent.position[0], agent.position[1]
        return (int(row), int(col))

    return [
        [*_cell_of(agent), int(agent.direction), agent.malfunction_handler]
        for agent in env.agents
    ]
```
%% Cell type:code id: tags:
``` python
pd.DataFrame([ vars(oAg) for oAg in env.agents])
```
%% Cell type:code id: tags:
``` python
expert_actions = []
action = {}
```
%% Cell type:code id: tags:
``` python
env_renderer = RenderTool(env, gl="PGL", show_debug=True)
n_agents = env.get_num_agents()
x_dim, y_dim = env.width, env.height
max_steps = env._max_episode_steps
action_dict = {}
frames = []
# log everything in original state
statuses = []
for a in range(n_agents):
statuses.append(env.agents[a].state)
pilImg = render_env(env_renderer)
frames.append({
'image': pilImg,
'statuses': statuses
})
step = 0
all_done = False
failed_action_check = False
print("Processing episode steps:")
while not all_done:
print(step, end=", ")
for agent_handle, agent in enumerate(env.agents):
action_dict.update({agent_handle: RailEnvActions.MOVE_FORWARD})
next_obs, all_rewards, done, info = env.step(action_dict)
statuses = []
for a in range(n_agents):
statuses.append(env.agents[a].state)
#clear_output(wait=True)
pilImg = render_env(env_renderer)
frames.append({
'image': pilImg,
'statuses': statuses
})
#print("Replaying {}/{}".format(step, max_steps))
if done['__all__']:
all_done = True
max_steps = step + 1
print("done")
step += 1
```
%% Cell type:code id: tags:
``` python
assert failed_action_check == False, "Realised states did not match stored states."
```
%% Cell type:code id: tags:
``` python
from ipywidgets import interact, interactive, fixed, interact_manual, Play
import ipywidgets as widgets
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from IPython.display import HTML
display.display(HTML('<link rel="stylesheet" href="//stackpath.bootstrapcdn.com/font-awesome/4.7.0/css/font-awesome.min.css"/>'))
def plot_func(frame_idx):
    """Display the stored image of the frame at index ``frame_idx`` (cast to int before indexing)."""
    idx = int(frame_idx)
    display.display(frames[idx]['image'])
#print(frame['statuses'])
slider = widgets.FloatSlider(value=0, min=0, max=max_steps, step=1)
interact(plot_func, frame_idx = slider)
play = Play(
max=max_steps,
value=0,
step=1,
interval=250
)
widgets.link((play, 'value'), (slider, 'value'))
widgets.VBox([play])
```
%% Cell type:code id: tags:
``` python
```
import shlex
import sys
from subprocess import Popen, PIPE
import importlib_resources
import pkg_resources
from importlib_resources import path
import importlib_resources as ir
from ipython_genutils.py3compat import string_types, bytes_to_str
# taken from https://github.com/jupyter/nbconvert/blob/master/nbconvert/tests/base.py
def run_python(parameters, ignore_return_code=False, stdin=None):
    """
    Run the current Python interpreter as a child process and capture its output.

    Parameters
    ----------
    parameters : str, list(str)
        Arguments to pass to the interpreter. A string is split with
        ``shlex.split`` (on win32 it is appended to the command line as is).
    ignore_return_code : optional bool (default False)
        When False, a non-zero return code raises an OSError carrying the
        decoded stderr output. When True, the return code is not checked.
    stdin : optional bytes
        Data sent to the child's standard input.

    Returns
    -------
    tuple(str, str)
        ``(stdout, stderr)`` decoded as UTF-8 with undecodable bytes replaced.
    """
    cmd = [sys.executable]
    if sys.platform == 'win32':
        # On win32 the command is handed to Popen as one space-joined string.
        if isinstance(parameters, str):
            cmd = ' '.join(cmd) + ' ' + parameters
        else:
            cmd = ' '.join(cmd + list(parameters))
    else:
        if isinstance(parameters, str):
            parameters = shlex.split(parameters)
        cmd += parameters

    p = Popen(cmd, stdout=PIPE, stderr=PIPE, stdin=PIPE)
    stdout, stderr = p.communicate(input=stdin)

    # Decode the same way for the exception text and for the returned values.
    out_text = stdout.decode('utf8', 'replace')
    err_text = stderr.decode('utf8', 'replace')
    if not (p.returncode == 0 or ignore_return_code):
        raise OSError(err_text)
    return out_text, err_text
def main():
    """Execute the notebooks of the ``notebooks`` package in place via nbconvert.

    The notebooks to run come from the ``notebook-list`` resource when it
    exists, otherwise from every ``*.ipynb`` resource in the package. Each
    notebook is run through ``run_python`` and the child's stderr and stdout
    are forwarded to this process's streams.

    Raises
    ------
    OSError
        Propagated from ``run_python`` when a notebook run returns a
        non-zero code.
    """
    # If the file notebook-list exists, use it as a definitive list of notebooks to run
    # This in effect ignores any local notebooks you might be working on, so you can run tox
    # without them causing the notebooks task / testenv to fail.
    if importlib_resources.is_resource("notebooks", "notebook-list"):
        print("Using the notebook-list file to designate which notebooks to run")
        # splitlines() + strip() so CRLF line endings or stray blanks do not
        # end up inside a notebook file name.
        listed = (
            sLine.strip()
            for sLine in ir.read_text("notebooks", "notebook-list").splitlines()
        )
        lsNB = [
            sLine for sLine in listed
            if len(sLine) > 3 and not sLine.startswith("#")
        ]
    else:
        lsNB = [
            entry for entry in importlib_resources.contents('notebooks') if
            not pkg_resources.resource_isdir('notebooks', entry)
            and entry.endswith(".ipynb")
        ]

    print("Running notebooks:", " ".join(lsNB))

    for entry in lsNB:
        print("*****************************************************************")
        print("Converting and running {}".format(entry))
        print("*****************************************************************")

        with path('notebooks', entry) as file_in:
            # Pass an argument list rather than one concatenated string, so a
            # path containing spaces, quotes or backslashes is not re-split.
            out, err = run_python([
                "-m", "jupyter", "nbconvert",
                "--ExecutePreprocessor.timeout=120",
                "--execute", "--to", "notebook", "--inplace",
                str(file_in),
            ])
            sys.stderr.write(err)
            sys.stderr.flush()
            sys.stdout.write(out)
            sys.stdout.flush()


if __name__ == "__main__":
    main()
\ No newline at end of file
%% Cell type:markdown id: tags:
# Railway Scene Editor
%% Cell type:code id: tags:
``` python
# IPython magics: load the autoreload extension and switch it to mode 2.
%load_ext autoreload
%autoreload 2
```
%% Cell type:code id: tags:
``` python
import numpy as np
from numpy import array
import ipywidgets
import IPython
from IPython.core.display import display, HTML
```
%% Cell type:code id: tags:
``` python
# Inject a CSS rule that forces the ".container" element to 95% width.
display(HTML("<style>.container { width:95% !important; }</style>"))
```
%% Output
%% Cell type:code id: tags:
``` python
from flatland.utils.editor import EditorMVC
```
%% Cell type:code id: tags:
``` python
# Create the editor object; sGL="PILSVG" presumably selects the PILSVG renderer — confirm against EditorMVC.
mvc = EditorMVC(sGL="PILSVG" )
```
%% Cell type:markdown id: tags:
## Instructions
- Drag to draw (improved dead-ends)
- ctrl-click to add agent or select agent
- if agent is selected:
- ctrl-click to move agent position
- use rotate agent to rotate 90°
- ctrl-shift-click to set target for selected agent
- target can be moved by repeating
- to resize the env (existing work cannot be preserved):
- select "Regen" tab, set regen size slider, click regenerate.
- alt-click to remove all rails from a cell
Demo Scene: complex_scene.pkl
%% Cell type:code id: tags:
``` python
# Ask the editor's view to display itself in this cell's output.
mvc.view.display()
```
%% Output
......
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
%% Cell type:markdown id: tags:
### Simple Example 3 - Manual Control
By default this runs a few "move forward" actions for two agents, in a separate window.
If you uncomment the "input" line below, it opens a text box in the Jupyter notebook, allowing basic manual control.
E.g. enter `"0 2 s<enter>"` to tell agent 0 to move forward, and step the environment.
You should be able to see the red agent step forward, and get a reward from the env, looking like this:
`Rewards: {0: -1.0, 1: -1.0} [done= {0: False, 1: False, '__all__': False} ]`
Note that this example is set up to use the straightforward "PIL" renderer - without the special SBB artwork!
The agent observations are displayed as squares of varying sizes, with a paler version of the agent colour. The targets are half-size squares in the full agent colour.
You can switch to the "PILSVG" renderer which is prettier but currently renders the agents one step behind, because it needs to know which way the agent is turning. This can be confusing if you are debugging step-by-step.
The image below is what the separate window should look like.
%% Cell type:markdown id: tags:
![simple_example_3.png](simple_example_3.png)
%% Cell type:code id: tags:
``` python
import random
import numpy as np
import time
```
%% Cell type:code id: tags:
``` python
from IPython import display
from ipycanvas import canvas
```
%% Cell type:code id: tags:
``` python
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool
```
%% Cell type:code id: tags:
``` python
# Seed both the stdlib and the NumPy global random generators with a fixed value.
random.seed(1)
np.random.seed(1)
```
%% Cell type:code id: tags:
``` python
# Scenario parameters for the generated environment.
nAgents = 3
n_cities = 2
max_rails_between_cities = 2
max_rails_in_city = 4
seed = 0

# Build the generators and the observation builder first, in the same order
# in which the original single call evaluated them.
rail_generator = sparse_rail_generator(
    max_num_cities=n_cities,
    seed=seed,
    grid_mode=True,
    max_rails_between_cities=max_rails_between_cities,
    max_rail_pairs_in_city=max_rails_in_city,
)
line_generator = sparse_line_generator()
tree_observation_builder = TreeObsForRailEnv(
    max_depth=3,
    predictor=ShortestPathPredictorForRailEnv(),
)

# 20 wide, 30 high environment populated with nAgents agents.
env = RailEnv(
    width=20,
    height=30,
    rail_generator=rail_generator,
    line_generator=line_generator,
    number_of_agents=nAgents,
    obs_builder_object=tree_observation_builder,
)

# Keep whatever reset() returns for later inspection.
init_observation = env.reset()
```
%% Cell type:code id: tags:
``` python
# Step the environment once, sending action 0 for agent 0 only, then print
# the observation subtree of every agent.
first_actions = {0: 0}
obs, all_rewards, done, _ = env.step(first_actions)
for agent_handle in range(env.get_num_agents()):
    agent_tree = obs[agent_handle]
    env.obs_builder.util_print_obs_subtree(tree=agent_tree)

# Create the renderer with the "PIL" graphics layer and draw the first frame.
env_renderer = RenderTool(env, gl="PIL")
env_renderer.render_env(show=True, frames=True)

# Usage hint for the manual-control command format used in the next cell.
print("Manual control: s=perform step, q=quit, [agent id] [1-2-3 action] \
(turnleft+move, move to front, turnright+move)")
```
%% Cell type:code id: tags:
``` python
# Put the current render onto a canvas widget and show it; only the first
# three entries of the last axis are used (presumably dropping an alpha
# channel — confirm against get_image()).
image_arr = env_renderer.get_image()
oCanvas = canvas.Canvas()
oCanvas.put_image_data(image_arr[:, :, 0:3])
display.display(oCanvas)

for step in range(10):
    # This is an example command, setting agent 0's action to 2 (move forward), and agent 1's action to 2,
    # then stepping the environment.
    cmd = "0 2 1 2 s"

    # uncomment this input statement if you want to try interactive manual commands
    # cmd = input(">> ")

    tokens = cmd.split(" ")
    action_dict = {}

    # Walk the tokens: 'q' quits, 's' steps the environment with the actions
    # collected so far, anything else is read as an "<agent id> <action>" pair.
    pos = 0
    while pos < len(tokens):
        token = tokens[pos]
        if token == 'q':
            import sys
            sys.exit()
        elif token == 's':
            obs, all_rewards, done, _ = env.step(action_dict)
            action_dict = {}
            print("Rewards: ", all_rewards, " [done=", done, "]")
            pos += 1
        else:
            agent_id = int(token)
            action = int(tokens[pos + 1])
            action_dict[agent_id] = action
            # Skip both the agent id and its action.
            pos += 2

    # Redraw and push the new frame to the canvas after each command line.
    env_renderer.render_env(show=True, frames=True)
    gIm = env_renderer.get_image()
    oCanvas.put_image_data(gIm[:, :, 0:3])
    time.sleep(0.1)
```