Commit 1a7d49ff authored by Christian Eichenberger

Merge branch '29-jw-test-play-render' into 'master'

Resolve "test for agent player and rendering to a window"

Closes #29

See merge request flatland/flatland!14
parents eae0c40f 281b7b4f
......@@ -54,13 +54,14 @@ lint: ## check style with flake8
flake8 flatland tests examples
test: ## run tests quickly with the default Python
echo "$$DISPLAY"
py.test
test-all: ## run tests on every Python version with tox
tox
coverage: ## check code coverage quickly with the default Python
coverage run --source flatland -m pytest
xvfb-run -a coverage run --source flatland -m pytest
coverage report -m
coverage html
$(BROWSER) htmlcov/index.html
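The xvfb-run -a wrapper added above supplies a virtual X display so the rendering tests can run on headless CI machines. As a sketch of the same idea from the test side (a hypothetical helper, not part of this commit), display-bound tests can be skipped when no DISPLAY is set:

import os
import pytest

# Hypothetical marker: skip rendering tests when no X display is available;
# `xvfb-run -a` provides one on headless runners.
requires_display = pytest.mark.skipif(
    os.environ.get("DISPLAY") is None,
    reason="needs an X display (run under xvfb-run)")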
......
# import torch
import random
import time
# from flatland.baselines.dueling_double_dqn import Agent
from collections import deque
import numpy as np
import torch
from flatland.baselines.dueling_double_dqn import Agent
from flatland.envs.generators import complex_rail_generator
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool
......@@ -28,10 +28,12 @@ class Player(object):
self.scores = []
self.dones_list = []
self.action_prob = [0] * 4
self.agent = Agent(self.state_size, self.action_size, "FC", 0)
# Removing refs to a real agent for now.
# self.agent = Agent(self.state_size, self.action_size, "FC", 0)
# self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
self.agent.qnetwork_local.load_state_dict(torch.load(
'../flatland/flatland/baselines/Nets/avoid_checkpoint15000.pth'))
# self.agent.qnetwork_local.load_state_dict(torch.load(
# '../flatland/flatland/baselines/Nets/avoid_checkpoint15000.pth'))
self.iFrame = 0
self.tStart = time.time()
......@@ -49,12 +51,21 @@ class Player(object):
self.score = 0
self.env_done = 0
def reset(self):
self.obs = self.env.reset()
return self.obs
def step(self):
env = self.env
# Pass the (stored) observation to the agent network and retrieve the action
for handle in env.get_agent_handles():
action = self.agent.act(np.array(self.obs[handle]), eps=self.eps)
# Real Agent
# action = self.agent.act(np.array(self.obs[handle]), eps=self.eps)
# Random actions
action = random.randint(0, 3)
# Numpy version uses a single random sequence
# action = np.random.randint(0, 4, size=1)
self.action_prob[action] += 1
self.action_dict.update({handle: action})
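A note on the stand-in policy above: random.randint draws from Python's global RNG (seeded via random.seed(1) in main()), while the commented-out numpy variant would draw from numpy's separate stream, and the two use different bound conventions. A minimal illustration, assuming only the standard libraries:

import random
import numpy as np

random.seed(1)
np.random.seed(1)
# Two independent streams: drawing from one does not advance the other.
py_actions = [random.randint(0, 3) for _ in range(5)]  # stdlib: both bounds inclusive, 0..3
np_actions = np.random.randint(0, 4, size=5)           # numpy: high bound exclusive, 0..3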
......@@ -67,11 +78,12 @@ class Player(object):
next_obs[handle] = np.clip(np.array(next_obs[handle]) / norm, -1, 1)
# Update replay buffer and train agent
for handle in self.env.get_agent_handles():
self.agent.step(self.obs[handle], self.action_dict[handle],
all_rewards[handle], next_obs[handle], done[handle],
train=False)
self.score += all_rewards[handle]
if False:
for handle in self.env.get_agent_handles():
self.agent.step(self.obs[handle], self.action_dict[handle],
all_rewards[handle], next_obs[handle], done[handle],
train=False)
self.score += all_rewards[handle]
self.iFrame += 1
......@@ -94,7 +106,50 @@ def max_lt(seq, val):
return None
def main(render=True, delay=0.0):
def main(render=True, delay=0.0, n_trials=3, n_steps=50, sGL="QT"):
random.seed(1)
np.random.seed(1)
# Example: generate a random rail
env = RailEnv(width=15, height=15,
rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12),
number_of_agents=5)
if render:
# env_renderer = RenderTool(env, gl="QTSVG")
env_renderer = RenderTool(env, gl=sGL)
oPlayer = Player(env)
for trials in range(1, n_trials + 1):
# Reset environment
oPlayer.reset()
if render: env_renderer.set_new_rail()  # env_renderer exists only when render is True
# env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
# score = 0
# env_done = 0
# Run episode
for step in range(n_steps):
oPlayer.step()
if render:
env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step,
action_dict=oPlayer.action_dict)
# time.sleep(10)
if delay > 0:
time.sleep(delay)
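For reference, a hypothetical invocation of the reworked main() (argument names as defined above; the PIL backend avoids opening a Qt window):

main(render=True, delay=0.0, n_trials=1, n_steps=20, sGL="PIL")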
def main_old(render=True, delay=0.0):
''' DEPRECATED main which drives the agent directly.
Please use the new main(), which creates a Player object that is also used by the Editor.
Please fix any bugs in main() and Player rather than here; this function will be deleted shortly.
'''
random.seed(1)
np.random.seed(1)
......@@ -107,8 +162,6 @@ def main(render=True, delay=0.0):
env_renderer = RenderTool(env, gl="QTSVG")
# env_renderer = RenderTool(env, gl="QT")
state_size = 105
action_size = 4
n_trials = 9999
eps = 1.
eps_end = 0.005
......@@ -119,8 +172,11 @@ def main(render=True, delay=0.0):
scores = []
dones_list = []
action_prob = [0] * 4
agent = Agent(state_size, action_size, "FC", 0)
# Real Agent
# state_size = 105
# action_size = 4
# agent = Agent(state_size, action_size, "FC", 0)
# agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
def max_lt(seq, val):
......@@ -161,7 +217,7 @@ def main(render=True, delay=0.0):
# print(step)
# Action
for a in range(env.get_num_agents()):
action = agent.act(np.array(obs[a]), eps=eps)
action = random.randint(0, 3) # agent.act(np.array(obs[a]), eps=eps)
action_prob[action] += 1
action_dict.update({a: action})
......@@ -174,13 +230,16 @@ def main(render=True, delay=0.0):
# Environment step
next_obs, all_rewards, done, _ = env.step(action_dict)
for a in range(env.get_num_agents()):
norm = max(1, max_lt(next_obs[a], np.inf))
next_obs[a] = np.clip(np.array(next_obs[a]) / norm, -1, 1)
# Update replay buffer and train agent
for a in range(env.get_num_agents()):
agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
score += all_rewards[a]
# only needed for "real" agent
# for a in range(env.get_num_agents()):
# agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
# score += all_rewards[a]
obs = next_obs.copy()
if done['__all__']:
......@@ -212,8 +271,8 @@ def main(render=True, delay=0.0):
np.mean(scores_window),
100 * np.mean(done_window),
eps, rFps, action_prob / np.sum(action_prob)))
torch.save(agent.qnetwork_local.state_dict(),
'../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
# torch.save(agent.qnetwork_local.state_dict(),
# '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
action_prob = [1] * 4
......
import time
import tkinter as tk
from PIL import ImageTk, Image
from examples.play_model import Player
from flatland.envs.generators import complex_rail_generator
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool
def tkmain(n_trials=2):
# This creates the main window of an application
window = tk.Tk()
window.title("Join")
window.configure(background='grey')
# Example: generate a random rail
env = RailEnv(width=15, height=15,
rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12),
number_of_agents=5)
env_renderer = RenderTool(env, gl="PIL")
oPlayer = Player(env)
n_steps = 20
delay = 0
for trials in range(1, n_trials + 1):
# Reset environment
oPlayer.reset()
env_renderer.set_new_rail()
first = True
for step in range(n_steps):
oPlayer.step()
env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step,
action_dict=oPlayer.action_dict)
img = env_renderer.getImage()
img = Image.fromarray(img)
tkimg = ImageTk.PhotoImage(img)
if first:
panel = tk.Label(window, image=tkimg)
panel.pack(side="bottom", fill="both", expand="yes")
else:
# update the image in situ
panel.configure(image=tkimg)
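# Keep a reference on the widget: Tkinter holds no reference of its own to
# the PhotoImage, so without this the image can be garbage-collected mid-display.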
panel.image = tkimg
window.update()
if delay > 0:
time.sleep(delay)
first = False
if __name__ == "__main__":
tkmain()
......@@ -51,7 +51,7 @@ class GraphicsLayer(object):
elif type(color) is tuple:
if type(color[0]) is not int:
gcolor = array(color)
color = tuple((gcolor[:3] * 255).astype(int))
color = tuple((gcolor[:4] * 255).astype(int))
else:
color = self.tColGrid
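The slice change above ([:3] to [:4]) lets adaptColor carry an alpha channel through to the PIL layers. A small sketch of the conversion, assuming matplotlib-style float colors in [0, 1]:

from numpy import array

gcolor = array((1.0, 0.0, 0.0, 0.5))           # red at 50% opacity, float components
color = tuple((gcolor[:4] * 255).astype(int))  # -> (255, 0, 0, 127), 8-bit RGBA for PIL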
......
......@@ -18,28 +18,32 @@ class PILGL(GraphicsLayer):
# Total grid size at native scale
self.widthPx = self.width * self.nPixCell + self.linewidth
self.heightPx = self.height * self.nPixCell + self.linewidth
self.beginFrame()
self.layers = []
self.draws = []
self.tColBg = (255, 255, 255) # white background
# self.tColBg = (220, 120, 40) # background color
self.tColRail = (0, 0, 0) # black rails
self.tColGrid = (230,) * 3 # light grey for grid
def plot(self, gX, gY, color=None, linewidth=3, **kwargs):
color = self.adaptColor(color)
self.beginFrame()
# print(gX, gY)
def plot(self, gX, gY, color=None, linewidth=3, layer=0, opacity=255, **kwargs):
color = self.adaptColor(color)
if len(color) == 3:
color += (opacity,)
elif len(color) == 4:
color = color[:3] + (opacity,)
gPoints = np.stack([array(gX), -array(gY)]).T * self.nPixCell
gPoints = list(gPoints.ravel())
# print(gPoints, color)
self.draw.line(gPoints, fill=color, width=self.linewidth)
self.draws[layer].line(gPoints, fill=color, width=self.linewidth)
def scatter(self, gX, gY, color=None, marker="o", s=50, *args, **kwargs):
def scatter(self, gX, gY, color=None, marker="o", s=50, layer=0, opacity=255, *args, **kwargs):
color = self.adaptColor(color)
r = np.sqrt(s)
gPoints = np.stack([np.atleast_1d(gX), -np.atleast_1d(gY)]).T * self.nPixCell
for x, y in gPoints:
self.draw.rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color)
self.draws[layer].rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color)
def text(self, *args, **kwargs):
pass
......@@ -51,8 +55,8 @@ class PILGL(GraphicsLayer):
pass
def beginFrame(self):
self.img = Image.new("RGBA", (self.widthPx, self.heightPx), (255, 255, 255, 255))
self.draw = ImageDraw.Draw(self.img)
self.create_layer(0)
self.create_layer(1)
def show(self, block=False):
pass
......@@ -62,5 +66,35 @@ class PILGL(GraphicsLayer):
pass
# plt.pause(seconds)
def alpha_composite_layers(self):
img = self.layers[0]
for img2 in self.layers[1:]:
img = Image.alpha_composite(img, img2)
return img
def getImage(self):
return array(self.img)
""" return a blended / alpha composited image composed of all the layers,
with layer 0 at the "back".
"""
img = self.alpha_composite_layers()
return array(img)
def create_image(self, opacity=255):
img = Image.new("RGBA", (self.widthPx, self.heightPx), (255, 255, 255, opacity))
return img
def create_layer(self, iLayer=0):
if len(self.layers) <= iLayer:
for i in range(len(self.layers), iLayer+1):
if i == 0:
opacity = 255 # "bottom" layer is opaque (for rails)
else:
opacity = 0 # subsequent layers are transparent
img = self.create_image(opacity)
self.layers.append(img)
self.draws.append(ImageDraw.Draw(img))
else:
opacity = 0 if iLayer > 0 else 255
self.layers[iLayer] = img = self.create_image(opacity)
self.draws[iLayer] = ImageDraw.Draw(img)
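The layer machinery above draws rails on an opaque base layer and agents on transparent overlays, then flattens the stack with PIL's alpha compositing. A self-contained sketch of the same idiom, assuming only Pillow:

from PIL import Image, ImageDraw

base = Image.new("RGBA", (100, 100), (255, 255, 255, 255))   # opaque base (rails, layer 0)
overlay = Image.new("RGBA", (100, 100), (255, 255, 255, 0))  # transparent overlay (agents, layer 1)
ImageDraw.Draw(overlay).rectangle([20, 20, 60, 60], fill=(255, 0, 0, 100))  # semi-transparent mark
blended = Image.alpha_composite(base, overlay)               # overlay composited in front of base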
......@@ -15,12 +15,14 @@ from flatland.utils.graphics_layer import GraphicsLayer
class MPLGL(GraphicsLayer):
def __init__(self, width, height):
def __init__(self, width, height, show=False):
self.width = width
self.height = height
self.yxBase = array([6, 21]) # pixel offset
self.nPixCell = 700 / width
self.img = None
if show:
plt.figure(figsize=(10, 10))
def plot(self, *args, **kwargs):
plt.plot(*args, **kwargs)
......@@ -70,6 +72,7 @@ class MPLGL(GraphicsLayer):
def beginFrame(self):
self.img = None
plt.figure(figsize=(10, 10))
plt.clf()
pass
def endFrame(self):
......@@ -115,7 +118,7 @@ class RenderTool(object):
gTheta = np.linspace(0, np.pi / 2, 5)
gArc = array([np.cos(gTheta), np.sin(gTheta)]).T # from [1,0] to [0,1]
def __init__(self, env, gl="MPL"):
def __init__(self, env, gl="MPL", show=False):
self.env = env
self.iFrame = 0
self.time1 = time.time()
......@@ -123,7 +126,7 @@ class RenderTool(object):
# self.gl = MPLGL()
if gl == "MPL":
self.gl = MPLGL(env.width, env.height)
self.gl = MPLGL(env.width, env.height, show=show)
elif gl == "QT":
self.gl = QTGL(env.width, env.height)
elif gl == "PIL":
......@@ -219,17 +222,19 @@ class RenderTool(object):
if static:
color = self.gl.adaptColor(color, lighten=True)
color = color
# print("Agent:", rcPos, iDir, rcDir, xyDir, xyPos)
self.gl.scatter(*xyPos, color=color, marker="o", s=100) # agent location
self.gl.scatter(*xyPos, color=color, layer=1, marker="o", s=100) # agent location
xyDirLine = array([xyPos, xyPos + xyDir / 2]).T # line for agent orient.
self.gl.plot(*xyDirLine, color=color, lw=5, ms=0, alpha=0.6)
self.gl.plot(*xyDirLine, color=color, layer=1, lw=5, ms=0, alpha=0.6)
if selected:
self._draw_square(xyPos, 1, color)
if target is not None:
rcTarget = array(target)
xyTarget = np.matmul(rcTarget, rt.grc2xy) + rt.xyHalf
self._draw_square(xyTarget, 1 / 3, color)
self._draw_square(xyTarget, 1 / 3, color, layer=1)
def plotTrans(self, rcPos, gTransRCAg, color="r", depth=None):
"""
......@@ -397,6 +402,13 @@ class RenderTool(object):
visit = visit.prev
xyPrev = xy
def drawTrans(self, oFrom, oTo, sColor="gray"):
self.gl.plot(
[oFrom[0], oTo[0]], # x
[oFrom[1], oTo[1]], # y
color=sColor
)
def drawTrans2(
self,
xyLine, xyCentre,
......@@ -474,8 +486,8 @@ class RenderTool(object):
def renderObs(self, agent_handles, observation_dict):
"""
Render the extent of the observation of each agent. All cells that appear in the agent obsrevation will be
highlighted.
Render the extent of the observation of each agent. All cells that appear in the agent
observation will be highlighted.
:param agent_handles: List of agent indices to adapt color and get correct observation
:param observation_dict: dictionary containing sets of cells of the agent observation
......@@ -489,47 +501,13 @@ class RenderTool(object):
for visited_cell in observation_dict[agent]:
cell_coord = array(visited_cell[:2])
cell_coord_trans = np.matmul(cell_coord, rt.grc2xy) + rt.xyHalf
self._draw_square(cell_coord_trans, 1 / 3, color)
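# Presumably (not stated in the commit): per-agent square size plus partial
# opacity on layer 1 keeps overlapping agents' observations distinguishable.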
self._draw_square(cell_coord_trans, 1 / (agent+1.1), color, layer=1, opacity=100)
def renderEnv(
self, show=False, curves=True, spacing=False,
arrows=False, agents=True, obsrender=True, sRailColor="gray", frames=False, iEpisode=None, iStep=None,
iSelectedAgent=None, action_dict=None):
"""
Draw the environment using matplotlib.
Draw into the figure if provided.
Call pyplot.show() if show==True.
(Use show=False from a Jupyter notebook with %matplotlib inline)
"""
if not self.gl.is_raster():
self.renderEnv2(show, curves, spacing,
arrows, agents, sRailColor,
frames, iEpisode, iStep,
iSelectedAgent, action_dict)
return
# cell_size is a bit pointless with matplotlib - it does not relate to pixels,
# so for now I've changed it to 1 (from 10)
cell_size = 1
self.gl.beginFrame()
# self.gl.clf()
# if oFigure is None:
# oFigure = self.gl.figure()
def drawTrans(oFrom, oTo, sColor="gray"):
self.gl.plot(
[oFrom[0], oTo[0]], # x
[oFrom[1], oTo[1]], # y
color=sColor
)
def renderRail(self, spacing=False, sRailColor="gray", curves=True, arrows=False):
cell_size = 1 # TODO: remove cell_size
env = self.env
# t1 = time.time()
# Draw cells grid
grid_color = [0.95, 0.95, 0.95]
for r in range(env.height + 1):
......@@ -613,7 +591,7 @@ class RenderTool(object):
rotation, spacing=spacing, bArrow=arrows,
sColor=sRailColor)
else:
drawTrans(from_xy, to_xy, sRailColor)
self.drawTrans(from_xy, to_xy, sRailColor)
if False:
print(
......@@ -626,6 +604,42 @@ class RenderTool(object):
"rot:", rotation,
)
def renderEnv(
self, show=False, curves=True, spacing=False,
arrows=False, agents=True, obsrender=True, sRailColor="gray", frames=False,
iEpisode=None, iStep=None,
iSelectedAgent=None, action_dict=None):
"""
Draw the environment using matplotlib.
Draw into the figure if provided.
Call pyplot.show() if show==True.
(Use show=False from a Jupyter notebook with %matplotlib inline)
"""
if not self.gl.is_raster():
self.renderEnv2(show, curves, spacing,
arrows, agents, sRailColor,
frames, iEpisode, iStep,
iSelectedAgent, action_dict)
return
if type(self.gl) in (QTGL, PILGL):
self.gl.beginFrame()
if type(self.gl) is MPLGL:
# self.gl.clf()
self.gl.beginFrame()
pass
# self.gl.clf()
# if oFigure is None:
# oFigure = self.gl.figure()
env = self.env
self.renderRail()
# Draw each agent + its orientation + its target
if agents:
self.plotAgents(targets=True, iSelectedAgent=iSelectedAgent)
......@@ -657,23 +671,26 @@ class RenderTool(object):
# TODO: for MPL, we don't want to call clf (called by endframe)
# for QT, we need to call endFrame()
# if not show:
self.gl.endFrame()
if type(self.gl) is QTGL:
self.gl.endFrame()
if show:
self.gl.show(block=False)
# t2 = time.time()
# print(t2 - t1, "seconds")
if type(self.gl) is MPLGL:
if show:
self.gl.show(block=False)
# self.gl.endFrame()
if show:
self.gl.show(block=False)
self.gl.pause(0.00001)
self.gl.pause(0.00001)
return
def _draw_square(self, center, size, color):
def _draw_square(self, center, size, color, opacity=255, layer=0):
x0 = center[0] - size / 2
x1 = center[0] + size / 2
y0 = center[1] - size / 2
y1 = center[1] + size / 2
self.gl.plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], color=color)
self.gl.plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], color=color, layer=layer, opacity=opacity)
def getImage(self):
return self.gl.getImage()
......
No preview for this file type
# from examples.play_model import main
from examples.tkplay import tkmain
def test_main():
tkmain(n_trials=2)
......@@ -46,8 +46,8 @@ def test_render_env(save_new_images=False):
)
sfTestEnv = "env-data/tests/test1.npy"
oEnv.rail.load_transition_map(sfTestEnv)
oRT = rt.RenderTool(oEnv)
oRT.renderEnv()
oRT = rt.RenderTool(oEnv, gl="PIL", show=False)
oRT.renderEnv(show=False)
checkFrozenImage(oRT, "basic-env.npz", resave=save_new_images)
......
[tox]
envlist = py36, py37, flake8, docs, coverage
envlist = py36, py37, flake8, docs, coverage, xvfb-run, sh
[travis]
python =
......@@ -8,7 +8,7 @@ python =
[flake8]
max-line-length = 120
ignore = E121 E126 E123 E128 E133 E226 E241 E242 E704 W291 W293 W391 W503 W504 W505
ignore = E121 E126 E123 E128 E133 E226 E241 E242 E704 W291 W293 W391 W503 W504 W505
[testenv:flake8]
basepython = python
......@@ -23,12 +23,15 @@ commands = make docs
[testenv:coverage]
basepython = python
whitelist_externals = make
commands =
commands =
pip install -U pip
pip install -r requirements_dev.txt
make coverage
[testenv]
whitelist_externals = xvfb-run
sh
pip
setenv =
PYTHONPATH = {toxinidir}
deps =
......@@ -39,6 +42,7 @@ deps =
commands =
pip install -U pip
pip install -r requirements_dev.txt
py.test --basetemp={envtmpdir}
sh -c 'echo DISPLAY: $DISPLAY'
xvfb-run -a py.test --basetemp={envtmpdir}