Forked from
Flatland / Flatland
1485 commits behind the upstream repository.
rendertools.py 24.01 KiB
import time
import warnings
from collections import deque
from enum import IntEnum
import numpy as np
from numpy import array
from recordtype import recordtype
from flatland.utils.graphics_pil import PILGL, PILSVG
# TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv!
class AgentRenderVariant(IntEnum):
BOX_ONLY = 0
ONE_STEP_BEHIND = 1
AGENT_SHOWS_OPTIONS = 2
ONE_STEP_BEHIND_AND_BOX = 3
AGENT_SHOWS_OPTIONS_AND_BOX = 4
class RenderTool(object):
""" Class to render the RailEnv and agents.
Uses two layers, layer 0 for rails (mostly static), layer 1 for agents etc (dynamic)
The lower / rail layer 0 is only redrawn after set_new_rail() has been called.
Created with a "GraphicsLayer" or gl - now either PIL or PILSVG
"""
visit = recordtype("visit", ["rc", "iDir", "iDepth", "prev"])
color_list = list("brgcmyk")
# \delta RC for NESW
transitions_row_col = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
pix_per_cell = 1 # misnomer...
half_pix_per_cell = pix_per_cell / 2
x_y_half = array([half_pix_per_cell, -half_pix_per_cell])
row_col_to_xy = array([[0, -pix_per_cell], [pix_per_cell, 0]])
grid = array(np.meshgrid(np.arange(10), -np.arange(10))) * array([[[pix_per_cell]], [[pix_per_cell]]])
theta = np.linspace(0, np.pi / 2, 5)
arc = array([np.cos(theta), np.sin(theta)]).T # from [1,0] to [0,1]
def __init__(self, env, gl="PILSVG", jupyter=False,
agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
show_debug=False, screen_width=800, screen_height=600):
self.env = env
self.frame_nr = 0
self.start_time = time.time()
self.times_list = deque()
self.agent_render_variant = agent_render_variant
if gl == "PIL":
self.gl = PILGL(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height)
elif gl == "PILSVG":
self.gl = PILSVG(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height)
else:
print("[", gl, "] not found, switch to PILSVG")
self.gl = PILSVG(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height)
self.new_rail = True
self.show_debug = show_debug
self.update_background()
def reset(self):
"""
Resets the environment
:return:
"""
self.set_new_rail()
self.frame_nr = 0
self.start_time = time.time()
self.times_list = deque()
return
def update_background(self):
# create background map
targets = {}
for agent_idx, agent in enumerate(self.env.agents_static):
if agent is None:
continue
targets[tuple(agent.target)] = agent_idx
self.gl.build_background_map(targets)
def resize(self):
self.gl.resize(self.env)
def set_new_rail(self):
""" Tell the renderer that the rail has changed.
eg when the rail has been regenerated, or updated in the editor.
"""
self.new_rail = True
def plot_agents(self, targets=True, selected_agent=None):
color_map = self.gl.get_cmap('hsv',
lut=max(len(self.env.agents), len(self.env.agents_static) + 1))
for agent_idx, agent in enumerate(self.env.agents_static):
if agent is None:
continue
color = color_map(agent_idx)
self.plot_single_agent(agent.position, agent.direction, color, target=agent.target if targets else None,
static=True, selected=agent_idx == selected_agent)
for agent_idx, agent in enumerate(self.env.agents):
if agent is None:
continue
color = color_map(agent_idx)
self.plot_single_agent(agent.position, agent.direction, color, target=agent.target if targets else None)
def get_transition_row_col(self, row_col_pos, direction, bgiTrans=False):
"""
Get the available transitions for row_col_pos in direction direction,
as row & col deltas.
If bgiTrans is True, return a grid of indices of available transitions.
eg for a cell row_col_pos = (4,5), in direction direction = 0 (N),
where the available transitions are N and E, returns:
[[-1,0], [0,1]] ie N=up one row, and E=right one col.
and if bgiTrans is True, returns a tuple:
(
[[-1,0], [0,1]], # deltas as before
[0, 1] # available transition indices, ie N, E
)
"""
transitions = self.env.rail.get_transitions(*row_col_pos, direction)
transition_list = np.where(transitions)[0] # RC list of transitions
# HACK: workaround dead-end transitions
if len(transition_list) == 0:
reverse_direciton = (direction + 2) % 4
transitions = tuple(int(tmp_dir == reverse_direciton) for tmp_dir in range(4))
transition_list = np.where(transitions)[0] # RC list of transitions
transition_grid = self.__class__.transitions_row_col[transition_list]
if bgiTrans:
return transition_grid, transition_list
else:
return transition_grid
def plot_single_agent(self, position_row_col, direction, color="r", target=None, static=False, selected=False):
"""
Plot a simple agent.
Assumes a working graphics layer context (cf a MPL figure).
"""
rt = self.__class__
direction_row_col = rt.transitions_row_col[direction] # agent direction in RC
direction_xy = np.matmul(direction_row_col, rt.row_col_to_xy) # agent direction in xy
xyPos = np.matmul(position_row_col - direction_row_col / 2, rt.row_col_to_xy) + rt.x_y_half
if static:
color = self.gl.adapt_color(color, lighten=True)
color = color
self.gl.scatter(*xyPos, color=color, layer=1, marker="o", s=100) # agent location
xy_dir_line = array([xyPos, xyPos + direction_xy / 2]).T # line for agent orient.
self.gl.plot(*xy_dir_line, color=color, layer=1, lw=5, ms=0, alpha=0.6)
if selected:
self._draw_square(xyPos, 1, color)
if target is not None:
target_row_col = array(target)
target_xy = np.matmul(target_row_col, rt.row_col_to_xy) + rt.x_y_half
self._draw_square(target_xy, 1 / 3, color, layer=1)
def plot_transition(self, position_row_col, transition_row_col, color="r", depth=None):
"""
plot the transitions in transition_row_col at position position_row_col.
transition_row_col is a 2d numpy array containing a list of RC transitions,
eg [[-1,0], [0,1]] means N, E.
"""
rt = self.__class__
position_xy = np.matmul(position_row_col, rt.row_col_to_xy) + rt.x_y_half
transition_xy = position_xy + np.matmul(transition_row_col, rt.row_col_to_xy / 2.4)
self.gl.scatter(*transition_xy.T, color=color, marker="o", s=50, alpha=0.2)
if depth is not None:
for x, y in transition_xy:
self.gl.text(x, y, depth)
def draw_transition(self,
line,
center,
rotation,
dead_end=False,
curves=False,
color="gray",
arrow=True,
spacing=0.1):
"""
gLine is a numpy 2d array of points,
in the plotting space / coords.
eg:
[[0,.5],[1,0.2]] means a line
from x=0, y=0.5
to x=1, y=0.2
"""
if not curves and not dead_end:
# just a straigt line, no curve nor dead_end included in this basic rail element
self.gl.plot(
[line[0][0], line[1][0]], # x
[line[0][1], line[1][1]], # y
color=color
)
else:
# it was not a simple line to draw: the rail has a curve or dead_end included.
rt = self.__class__
straight = rotation in [0, 2]
dx, dy = np.squeeze(np.diff(line, axis=0)) * spacing / 2
if straight:
if color == "auto":
if dx > 0 or dy > 0:
color = "C1" # N or E
else:
color = "C2" # S or W
if dead_end:
line_xy = array([
line[1] + [dy, dx],
center,
line[1] - [dy, dx],
])
self.gl.plot(*line_xy.T, color=color)
else:
line_xy = line + [-dy, dx]
self.gl.plot(*line_xy.T, color=color)
if arrow:
middle_xy = np.sum(line_xy * [[1 / 4], [3 / 4]], axis=0)
arrow_xy = array([
middle_xy + [-dx - dy, +dx - dy],
middle_xy,
middle_xy + [-dx + dy, -dx - dy]])
self.gl.plot(*arrow_xy.T, color=color)
else:
middle_xy = np.mean(line, axis=0)
dxy = middle_xy - center
corner = middle_xy + dxy
if rotation == 1:
arc_factor = 1 - spacing
color_auto = "C1"
else:
arc_factor = 1 + spacing
color_auto = "C2"
dxy2 = (center - corner) * arc_factor # for scaling the arc
if color == "auto":
color = color_auto
self.gl.plot(*(rt.arc * dxy2 + corner).T, color=color)
if arrow:
dx, dy = np.squeeze(np.diff(line, axis=0)) / 20
iArc = int(len(rt.arc) / 2)
middle_xy = corner + rt.arc[iArc] * dxy2
arrow_xy = array([
middle_xy + [-dx - dy, +dx - dy],
middle_xy,
middle_xy + [-dx + dy, -dx - dy]])
self.gl.plot(*arrow_xy.T, color=color)
def render_observation(self, agent_handles, observation_dict):
"""
Render the extent of the observation of each agent. All cells that appear in the agent
observation will be highlighted.
:param agent_handles: List of agent indices to adapt color and get correct observation
:param observation_dict: dictionary containing sets of cells of the agent observation
"""
rt = self.__class__
# Check if the observation builder provides an observation
if len(observation_dict) < 1:
warnings.warn(
"Predictor did not provide any predicted cells to render. \
Observation builder needs to populate: env.dev_obs_dict")
else:
for agent in agent_handles:
color = self.gl.get_agent_color(agent)
for visited_cell in observation_dict[agent]:
cell_coord = array(visited_cell[:2])
cell_coord_trans = np.matmul(cell_coord, rt.row_col_to_xy) + rt.x_y_half
self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100)
def render_prediction(self, agent_handles, prediction_dict):
"""
Render the extent of the observation of each agent. All cells that appear in the agent
observation will be highlighted.
:param agent_handles: List of agent indices to adapt color and get correct observation
:param observation_dict: dictionary containing sets of cells of the agent observation
"""
rt = self.__class__
if len(prediction_dict) < 1:
warnings.warn(
"Predictor did not provide any predicted cells to render. \
Predictors builder needs to populate: env.dev_pred_dict")
else:
for agent in agent_handles:
color = self.gl.get_agent_color(agent)
for visited_cell in prediction_dict[agent]:
cell_coord = array(visited_cell[:2])
if type(self.gl) is PILSVG:
# TODO : Track highlighting (Adrian)
r = cell_coord[0]
c = cell_coord[1]
transitions = self.env.rail.grid[r, c]
self.gl.set_predicion_path_at(r, c, transitions, agent_rail_color=color)
else:
cell_coord_trans = np.matmul(cell_coord, rt.row_col_to_xy) + rt.x_y_half
self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100)
def render_rail(self, spacing=False, rail_color="gray", curves=True, arrows=False):
cell_size = 1 # TODO: remove cell_size
env = self.env
# Draw cells grid
grid_color = [0.95, 0.95, 0.95]
for row in range(env.height + 1):
self.gl.plot([0, (env.width + 1) * cell_size],
[-row * cell_size, -row * cell_size],
color=grid_color, linewidth=2)
for col in range(env.width + 1):
self.gl.plot([col * cell_size, col * cell_size],
[0, -(env.height + 1) * cell_size],
color=grid_color, linewidth=2)
# Draw each cell independently
for row in range(env.height):
for col in range(env.width):
# bounding box of the grid cell
x0 = cell_size * col # left
x1 = cell_size * (col + 1) # right
y0 = cell_size * -row # top
y1 = cell_size * -(row + 1) # bottom
# centres of cell edges
coords = [
((x0 + x1) / 2.0, y0), # N middle top
(x1, (y0 + y1) / 2.0), # E middle right
((x0 + x1) / 2.0, y1), # S middle bottom
(x0, (y0 + y1) / 2.0) # W middle left
]
# cell centre
center_xy = array([x0, y1]) + cell_size / 2
# cell transition values
cell = env.rail.get_full_transitions(row, col)
cell_valid = env.rail.cell_neighbours_valid((row, col), check_this_cell=True)
# Special Case 7, with a single bit; terminate at center
nbits = 0
tmp = cell
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
# as above - move the from coord to the centre
# it's a dead env.
is_dead_end = nbits == 1
if not cell_valid:
self.gl.scatter(*center_xy, color="r", s=30)
for orientation in range(4): # ori is where we're heading
from_ori = (orientation + 2) % 4 # 0123=NESW -> 2301=SWNE
from_xy = coords[from_ori]
moves = env.rail.get_transitions(row, col, orientation)
for to_ori in range(4):
to_xy = coords[to_ori]
rotation = (to_ori - from_ori) % 4
if (moves[to_ori]): # if we have this transition
self.draw_transition(
array([from_xy, to_xy]), center_xy,
rotation, dead_end=is_dead_end, curves=curves and not is_dead_end, spacing=spacing,
color=rail_color)
def render_env(self,
show=False, # whether to call matplotlib show() or equivalent after completion
agents=True, # whether to include agents
show_observations=True, # whether to include observations
show_predictions=False, # whether to include predictions
frames=False, # frame counter to show (intended since invocation)
episode=None, # int episode number to show
step=None, # int step number to show in image
selected_agent=None): # indicate which agent is "selected" in the editor
""" Draw the environment using the GraphicsLayer this RenderTool was created with.
(Use show=False from a Jupyter notebook with %matplotlib inline)
"""
if type(self.gl) is PILSVG:
self.render_env_svg(show=show,
show_observations=show_observations,
show_predictions=show_predictions,
selected_agent=selected_agent
)
else:
self.render_env_pil(show=show,
agents=agents,
show_observations=show_observations,
show_predictions=show_predictions,
frames=frames,
episode=episode,
step=step,
selected_agent=selected_agent
)
def _draw_square(self, center, size, color, opacity=255, layer=0):
x0 = center[0] - size / 2
x1 = center[0] + size / 2
y0 = center[1] - size / 2
y1 = center[1] + size / 2
self.gl.plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], color=color, layer=layer, opacity=opacity)
def get_image(self):
return self.gl.get_image()
def render_env_pil(self,
show=False, # whether to call matplotlib show() or equivalent after completion
# use false when calling from Jupyter. (and matplotlib no longer supported!)
agents=True, # whether to include agents
show_observations=True, # whether to include observations
show_predictions=False, # whether to include predictions
frames=False, # frame counter to show (intended since invocation)
episode=None, # int episode number to show
step=None, # int step number to show in image
selected_agent=None # indicate which agent is "selected" in the editor
):
if type(self.gl) is PILGL:
self.gl.begin_frame()
env = self.env
self.render_rail()
# Draw each agent + its orientation + its target
if agents:
self.plot_agents(targets=True, selected_agent=selected_agent)
if show_observations:
self.render_observation(range(env.get_num_agents()), env.dev_obs_dict)
if show_predictions and len(env.dev_pred_dict) > 0:
self.render_prediction(range(env.get_num_agents()), env.dev_pred_dict)
# Draw some textual information like fps
text_y = [-0.3, -0.6, -0.9]
if frames:
self.gl.text(0.1, text_y[2], "Frame:{:}".format(self.frame_nr))
self.frame_nr += 1
if episode is not None:
self.gl.text(0.1, text_y[1], "Ep:{}".format(episode))
if step is not None:
self.gl.text(0.1, text_y[0], "Step:{}".format(step))
time_now = time.time()
self.gl.text(2, text_y[2], "elapsed:{:.2f}s".format(time_now - self.start_time))
self.times_list.append(time_now)
if len(self.times_list) > 20:
self.times_list.popleft()
if len(self.times_list) > 1:
rFps = (len(self.times_list) - 1) / (self.times_list[-1] - self.times_list[0])
self.gl.text(2, text_y[1], "fps:{:.2f}".format(rFps))
self.gl.prettify2(env.width, env.height, self.pix_per_cell)
# TODO: for MPL, we don't want to call clf (called by endframe)
# if not show:
if show and type(self.gl) is PILGL:
self.gl.show()
self.gl.pause(0.00001)
return
def render_env_svg(
self, show=False, show_observations=True, show_predictions=False, selected_agent=None
):
"""
Renders the environment with SVG support (nice image)
"""
env = self.env
self.gl.begin_frame()
if self.new_rail:
self.new_rail = False
self.gl.clear_rails()
# store the targets
targets = {}
selected = {}
for agent_idx, agent in enumerate(self.env.agents_static):
if agent is None:
continue
targets[tuple(agent.target)] = agent_idx
selected[tuple(agent.target)] = (agent_idx == selected_agent)
# Draw each cell independently
for r in range(env.height):
for c in range(env.width):
transitions = env.rail.grid[r, c]
if (r, c) in targets:
target = targets[(r, c)]
is_selected = selected[(r, c)]
else:
target = None
is_selected = False
self.gl.set_rail_at(r, c, transitions, target=target, is_selected=is_selected,
rail_grid=env.rail.grid, show_debug=self.show_debug)
self.gl.build_background_map(targets)
for agent_idx, agent in enumerate(self.env.agents):
if agent is None:
continue
if self.agent_render_variant == AgentRenderVariant.BOX_ONLY:
self.gl.set_cell_occupied(agent_idx, *(agent.position))
elif self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND or \
self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX: # noqa: E125
if agent.old_position is not None:
position = agent.old_position
direction = agent.direction
old_direction = agent.old_direction
else:
position = agent.position
direction = agent.direction
old_direction = agent.direction
# set_agent_at uses the agent index for the color
if self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX:
self.gl.set_cell_occupied(agent_idx, *(agent.position))
self.gl.set_agent_at(agent_idx, *position, old_direction, direction,
selected_agent == agent_idx, show_debug=self.show_debug)
else:
position = agent.position
direction = agent.direction
for possible_directions in range(4):
# Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
isValid = env.rail.get_transition((*agent.position, agent.direction), possible_directions)
if isValid:
direction = possible_directions
# set_agent_at uses the agent index for the color
self.gl.set_agent_at(agent_idx, *position, agent.direction, direction,
selected_agent == agent_idx, show_debug=self.show_debug)
# set_agent_at uses the agent index for the color
if self.agent_render_variant == AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX:
self.gl.set_cell_occupied(agent_idx, *(agent.position))
self.gl.set_agent_at(agent_idx, *position, agent.direction, direction, selected_agent == agent_idx)
if show_observations:
self.render_observation(range(env.get_num_agents()), env.dev_obs_dict)
if show_predictions:
self.render_prediction(range(env.get_num_agents()), env.dev_pred_dict)
if show:
self.gl.show()
for i in range(3):
self.gl.process_events()
self.frame_nr += 1
return
def close_window(self):
self.gl.close_window()