Skip to content
Snippets Groups Projects
Commit a415c73c authored by hagrid67's avatar hagrid67
Browse files

added transparency / opacity with PIL.

moved test_player.py to using PIL and tkinter instead of QT
parent 32d2635f
No related branches found
No related tags found
No related merge requests found
......@@ -100,7 +100,7 @@ def max_lt(seq, val):
return None
def main(render=True, delay=0.0, n_trials=3, n_steps=50):
def main(render=True, delay=0.0, n_trials=3, n_steps=50, sGL="QT"):
random.seed(1)
np.random.seed(1)
......@@ -111,7 +111,7 @@ def main(render=True, delay=0.0, n_trials=3, n_steps=50):
if render:
# env_renderer = RenderTool(env, gl="QTSVG")
env_renderer = RenderTool(env, gl="QT")
env_renderer = RenderTool(env, gl=sGL)
oPlayer = Player(env)
......
import tkinter as tk
from PIL import ImageTk, Image
from examples.play_model import Player
from flatland.envs.rail_env import RailEnv
from flatland.envs.generators import complex_rail_generator
from flatland.utils.rendertools import RenderTool
import time
def tkmain(n_trials=2):
# This creates the main window of an application
window = tk.Tk()
window.title("Join")
window.configure(background='grey')
# Example generate a random rail
env = RailEnv(width=15, height=15,
rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12),
number_of_agents=5)
env_renderer = RenderTool(env, gl="PIL")
oPlayer = Player(env)
n_trials = 1
n_steps = 20
delay = 0
for trials in range(1, n_trials + 1):
# Reset environment8
oPlayer.reset()
env_renderer.set_new_rail()
first = True
for step in range(n_steps):
oPlayer.step()
env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step,
action_dict=oPlayer.action_dict)
img = env_renderer.getImage()
img = Image.fromarray(img)
tkimg = ImageTk.PhotoImage(img)
if first:
panel = tk.Label(window, image = tkimg)
panel.pack(side = "bottom", fill = "both", expand = "yes")
else:
# update the image in situ
panel.configure(image=tkimg)
panel.image = tkimg
window.update()
if delay > 0:
time.sleep(delay)
first = False
if __name__ == "__main__":
tkmain()
\ No newline at end of file
......@@ -51,7 +51,7 @@ class GraphicsLayer(object):
elif type(color) is tuple:
if type(color[0]) is not int:
gcolor = array(color)
color = tuple((gcolor[:3] * 255).astype(int))
color = tuple((gcolor[:4] * 255).astype(int))
else:
color = self.tColGrid
......
......@@ -18,28 +18,32 @@ class PILGL(GraphicsLayer):
# Total grid size at native scale
self.widthPx = self.width * self.nPixCell + self.linewidth
self.heightPx = self.height * self.nPixCell + self.linewidth
self.beginFrame()
self.layers = []
self.draws = []
self.tColBg = (255, 255, 255) # white background
# self.tColBg = (220, 120, 40) # background color
self.tColRail = (0, 0, 0) # black rails
self.tColGrid = (230,) * 3 # light grey for grid
def plot(self, gX, gY, color=None, linewidth=3, **kwargs):
color = self.adaptColor(color)
self.beginFrame()
# print(gX, gY)
def plot(self, gX, gY, color=None, linewidth=3, layer=0, opacity=255, **kwargs):
color = self.adaptColor(color)
if len(color) == 3:
color += (opacity,)
elif len(color) == 4:
color = color[:3] + (opacity,)
gPoints = np.stack([array(gX), -array(gY)]).T * self.nPixCell
gPoints = list(gPoints.ravel())
# print(gPoints, color)
self.draw.line(gPoints, fill=color, width=self.linewidth)
self.draws[layer].line(gPoints, fill=color, width=self.linewidth)
def scatter(self, gX, gY, color=None, marker="o", s=50, *args, **kwargs):
def scatter(self, gX, gY, color=None, marker="o", s=50, layer=0, opacity=255, *args, **kwargs):
color = self.adaptColor(color)
r = np.sqrt(s)
gPoints = np.stack([np.atleast_1d(gX), -np.atleast_1d(gY)]).T * self.nPixCell
for x, y in gPoints:
self.draw.rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color)
self.draws[layer].rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color)
def text(self, *args, **kwargs):
pass
......@@ -51,8 +55,8 @@ class PILGL(GraphicsLayer):
pass
def beginFrame(self):
self.img = Image.new("RGBA", (self.widthPx, self.heightPx), (255, 255, 255, 255))
self.draw = ImageDraw.Draw(self.img)
self.create_layer(0)
self.create_layer(1)
def show(self, block=False):
pass
......@@ -62,5 +66,35 @@ class PILGL(GraphicsLayer):
pass
# plt.pause(seconds)
def alpha_composite_layers(self):
img = self.layers[0]
for img2 in self.layers[1:]:
img = Image.alpha_composite(img, img2)
return img
def getImage(self):
return array(self.img)
""" return a blended / alpha composited image composed of all the layers,
with layer 0 at the "back".
"""
img = self.alpha_composite_layers()
return array(img)
def create_image(self, opacity=255):
img = Image.new("RGBA", (self.widthPx, self.heightPx), (255, 255, 255, opacity))
return img
def create_layer(self, iLayer=0):
if len(self.layers) <= iLayer:
for i in range(len(self.layers), iLayer+1):
if i==0:
opacity = 255 # "bottom" layer is opaque (for rails)
else:
opacity = 0 # subsequent layers are transparent
img = self.create_image(opacity)
self.layers.append(img)
self.draws.append(ImageDraw.Draw(img))
else:
opacity = 0 if iLayer > 0 else 255
self.layers[iLayer] = img = self.create_image(opacity)
self.draws[iLayer] = ImageDraw.Draw(img)
......@@ -15,12 +15,14 @@ from flatland.utils.graphics_layer import GraphicsLayer
class MPLGL(GraphicsLayer):
def __init__(self, width, height):
def __init__(self, width, height, show=False):
self.width = width
self.height = height
self.yxBase = array([6, 21]) # pixel offset
self.nPixCell = 700 / width
self.img = None
if show:
plt.figure(figsize=(10, 10))
def plot(self, *args, **kwargs):
plt.plot(*args, **kwargs)
......@@ -70,6 +72,7 @@ class MPLGL(GraphicsLayer):
def beginFrame(self):
self.img = None
plt.figure(figsize=(10, 10))
plt.clf()
pass
def endFrame(self):
......@@ -115,7 +118,7 @@ class RenderTool(object):
gTheta = np.linspace(0, np.pi / 2, 5)
gArc = array([np.cos(gTheta), np.sin(gTheta)]).T # from [1,0] to [0,1]
def __init__(self, env, gl="MPL"):
def __init__(self, env, gl="MPL", show=False):
self.env = env
self.iFrame = 0
self.time1 = time.time()
......@@ -123,7 +126,7 @@ class RenderTool(object):
# self.gl = MPLGL()
if gl == "MPL":
self.gl = MPLGL(env.width, env.height)
self.gl = MPLGL(env.width, env.height, show=show)
elif gl == "QT":
self.gl = QTGL(env.width, env.height)
elif gl == "PIL":
......@@ -219,17 +222,19 @@ class RenderTool(object):
if static:
color = self.gl.adaptColor(color, lighten=True)
color = color
# print("Agent:", rcPos, iDir, rcDir, xyDir, xyPos)
self.gl.scatter(*xyPos, color=color, marker="o", s=100) # agent location
self.gl.scatter(*xyPos, color=color, layer=1, marker="o", s=100) # agent location
xyDirLine = array([xyPos, xyPos + xyDir / 2]).T # line for agent orient.
self.gl.plot(*xyDirLine, color=color, lw=5, ms=0, alpha=0.6)
self.gl.plot(*xyDirLine, color=color, layer=1, lw=5, ms=0, alpha=0.6)
if selected:
self._draw_square(xyPos, 1, color)
if target is not None:
rcTarget = array(target)
xyTarget = np.matmul(rcTarget, rt.grc2xy) + rt.xyHalf
self._draw_square(xyTarget, 1 / 3, color)
self._draw_square(xyTarget, 1 / 3, color, layer=1)
def plotTrans(self, rcPos, gTransRCAg, color="r", depth=None):
"""
......@@ -397,6 +402,13 @@ class RenderTool(object):
visit = visit.prev
xyPrev = xy
def drawTrans(self, oFrom, oTo, sColor="gray"):
self.gl.plot(
[oFrom[0], oTo[0]], # x
[oFrom[1], oTo[1]], # y
color=sColor
)
def drawTrans2(
self,
xyLine, xyCentre,
......@@ -489,47 +501,14 @@ class RenderTool(object):
for visited_cell in observation_dict[agent]:
cell_coord = array(visited_cell[:2])
cell_coord_trans = np.matmul(cell_coord, rt.grc2xy) + rt.xyHalf
self._draw_square(cell_coord_trans, 1 / 3, color)
def renderEnv(
self, show=False, curves=True, spacing=False,
arrows=False, agents=True, obsrender=True, sRailColor="gray", frames=False, iEpisode=None, iStep=None,
iSelectedAgent=None, action_dict=None):
"""
Draw the environment using matplotlib.
Draw into the figure if provided.
Call pyplot.show() if show==True.
(Use show=False from a Jupyter notebook with %matplotlib inline)
"""
if not self.gl.is_raster():
self.renderEnv2(show, curves, spacing,
arrows, agents, sRailColor,
frames, iEpisode, iStep,
iSelectedAgent, action_dict)
return
# cell_size is a bit pointless with matplotlib - it does not relate to pixels,
# so for now I've changed it to 1 (from 10)
cell_size = 1
self.gl.beginFrame()
self._draw_square(cell_coord_trans, 1 / (agent+1.1), color, layer=1, opacity=100)
# self.gl.clf()
# if oFigure is None:
# oFigure = self.gl.figure()
def drawTrans(oFrom, oTo, sColor="gray"):
self.gl.plot(
[oFrom[0], oTo[0]], # x
[oFrom[1], oTo[1]], # y
color=sColor
)
def renderRail(self, spacing=False, sRailColor="gray", curves=True, arrows=False):
cell_size = 1 # TODO: remove cell_size
env = self.env
# t1 = time.time()
# Draw cells grid
grid_color = [0.95, 0.95, 0.95]
for r in range(env.height + 1):
......@@ -613,7 +592,7 @@ class RenderTool(object):
rotation, spacing=spacing, bArrow=arrows,
sColor=sRailColor)
else:
drawTrans(from_xy, to_xy, sRailColor)
self.drawTrans(self, from_xy, to_xy, sRailColor)
if False:
print(
......@@ -626,6 +605,54 @@ class RenderTool(object):
"rot:", rotation,
)
def renderEnv(
self, show=False, curves=True, spacing=False,
arrows=False, agents=True, obsrender=True, sRailColor="gray", frames=False,
iEpisode=None, iStep=None,
iSelectedAgent=None, action_dict=None):
"""
Draw the environment using matplotlib.
Draw into the figure if provided.
Call pyplot.show() if show==True.
(Use show=False from a Jupyter notebook with %matplotlib inline)
"""
if not self.gl.is_raster():
self.renderEnv2(show, curves, spacing,
arrows, agents, sRailColor,
frames, iEpisode, iStep,
iSelectedAgent, action_dict)
return
# cell_size is a bit pointless with matplotlib - it does not relate to pixels,
# so for now I've changed it to 1 (from 10)
cell_size = 1
if type(self.gl) in (QTGL, PILGL):
self.gl.beginFrame()
if type(self.gl) is MPLGL:
#self.gl.clf()
# plt.clf()
self.gl.beginFrame()
pass
# self.gl.clf()
# if oFigure is None:
# oFigure = self.gl.figure()
env = self.env
# t1 = time.time()
self.renderRail()
# Draw each agent + its orientation + its target
if agents:
self.plotAgents(targets=True, iSelectedAgent=iSelectedAgent)
......@@ -657,23 +684,26 @@ class RenderTool(object):
# TODO: for MPL, we don't want to call clf (called by endframe)
# for QT, we need to call endFrame()
# if not show:
self.gl.endFrame()
if type(self.gl) is QTGL:
self.gl.endFrame()
if show:
self.gl.show(block=False)
# t2 = time.time()
# print(t2 - t1, "seconds")
if type(self.gl) is MPLGL:
if show:
self.gl.show(block=False)
#self.gl.endFrame()
if show:
self.gl.show(block=False)
self.gl.pause(0.00001)
self.gl.pause(0.00001)
return
def _draw_square(self, center, size, color):
def _draw_square(self, center, size, color, opacity=255, layer=0):
x0 = center[0] - size / 2
x1 = center[0] + size / 2
y0 = center[1] - size / 2
y1 = center[1] + size / 2
self.gl.plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], color=color)
self.gl.plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], color=color, layer=layer, opacity=opacity)
def getImage(self):
return self.gl.getImage()
......
No preview for this file type
from examples.play_model import main
# from examples.play_model import main
from examples.tkplay import tkmain
def test_main():
main(n_trials=2)
tkmain(n_trials=2)
......@@ -46,8 +46,8 @@ def test_render_env(save_new_images=False):
)
sfTestEnv = "env-data/tests/test1.npy"
oEnv.rail.load_transition_map(sfTestEnv)
oRT = rt.RenderTool(oEnv)
oRT.renderEnv()
oRT = rt.RenderTool(oEnv, gl="PIL", show=False)
oRT.renderEnv(show=False)
checkFrozenImage(oRT, "basic-env.npz", resave=save_new_images)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment