diff --git a/Makefile b/Makefile
index e9c25bbdfed174ea0ebc4570aed4949c53b31c48..69ad1b42fd51ef9ec9420f5473dc8acef5468572 100644
--- a/Makefile
+++ b/Makefile
@@ -54,13 +54,14 @@ lint: ## check style with flake8
 	flake8 flatland tests examples
 
 test: ## run tests quickly with the default Python
+	echo "$$DISPLAY"
 	py.test
 
 test-all: ## run tests on every Python version with tox
 	tox
 
 coverage: ## check code coverage quickly with the default Python
-	coverage run --source flatland -m pytest
+	xvfb-run -a coverage run --source flatland -m pytest
 	coverage report -m
 	coverage html
 	$(BROWSER) htmlcov/index.html
diff --git a/examples/play_model.py b/examples/play_model.py
index 174568177a4a886cfe38e53125d0f73f2dae52de..34c6aadfeefd44771fd335e2957e1fbd0b2f740f 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -1,11 +1,11 @@
+# import torch
 import random
 import time
+# from flatland.baselines.dueling_double_dqn import Agent
 from collections import deque
 
 import numpy as np
-import torch
 
-from flatland.baselines.dueling_double_dqn import Agent
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
@@ -28,10 +28,12 @@ class Player(object):
         self.scores = []
         self.dones_list = []
         self.action_prob = [0] * 4
-        self.agent = Agent(self.state_size, self.action_size, "FC", 0)
+
+        # Removing refs to a real agent for now.
+        # self.agent = Agent(self.state_size, self.action_size, "FC", 0)
         # self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
-        self.agent.qnetwork_local.load_state_dict(torch.load(
-            '../flatland/flatland/baselines/Nets/avoid_checkpoint15000.pth'))
+        # self.agent.qnetwork_local.load_state_dict(torch.load(
+        #     '../flatland/flatland/baselines/Nets/avoid_checkpoint15000.pth'))
 
         self.iFrame = 0
         self.tStart = time.time()
@@ -49,12 +51,21 @@ class Player(object):
         self.score = 0
         self.env_done = 0
 
+    def reset(self):
+        self.obs = self.env.reset()
+        return self.obs
+
     def step(self):
         env = self.env
 
         # Pass the (stored) observation to the agent network and retrieve the action
         for handle in env.get_agent_handles():
-            action = self.agent.act(np.array(self.obs[handle]), eps=self.eps)
+            # Real agent:
+            # action = self.agent.act(np.array(self.obs[handle]), eps=self.eps)
+            # Random actions for now:
+            action = random.randint(0, 3)
+            # NumPy alternative (uses the single global NumPy random sequence):
+            # action = np.random.randint(0, 4, size=1)
             self.action_prob[action] += 1
             self.action_dict.update({handle: action})
 
@@ -67,11 +78,12 @@ class Player(object):
             next_obs[handle] = np.clip(np.array(next_obs[handle]) / norm, -1, 1)
 
         # Update replay buffer and train agent
-        for handle in self.env.get_agent_handles():
-            self.agent.step(self.obs[handle], self.action_dict[handle],
-                            all_rewards[handle], next_obs[handle], done[handle],
-                            train=False)
-            self.score += all_rewards[handle]
+        if False:  # disabled while there is no real agent
+            for handle in self.env.get_agent_handles():
+                self.agent.step(self.obs[handle], self.action_dict[handle],
+                                all_rewards[handle], next_obs[handle], done[handle],
+                                train=False)
+                self.score += all_rewards[handle]
 
         self.iFrame += 1
 
@@ -94,7 +106,50 @@ def max_lt(seq, val):
     return None
 
 
-def main(render=True, delay=0.0):
+def main(render=True, delay=0.0, n_trials=3, n_steps=50, sGL="QT"):
+    random.seed(1)
+    np.random.seed(1)
+
+    # Example: generate a random rail
+    env = RailEnv(width=15, height=15,
+                  rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12),
+                  number_of_agents=5)
+
+    if render:
+        # env_renderer = RenderTool(env, gl="QTSVG")
+        env_renderer = RenderTool(env, gl=sGL)
+
+    oPlayer = Player(env)
+
+    for trials in range(1, n_trials + 1):
+
+        # Reset environment
+        oPlayer.reset()
+        if render:
+            env_renderer.set_new_rail()
+        # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
+
+        # score = 0
+        # env_done = 0
+
+        # Run episode
+        for step in range(n_steps):
+            oPlayer.step()
+            if render:
+                env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step,
+                                       action_dict=oPlayer.action_dict)
+                # time.sleep(10)
+            if delay > 0:
+                time.sleep(delay)
+
+
+def main_old(render=True, delay=0.0):
+    ''' DEPRECATED main which drives the agent directly.
+        Please use the new main(), which creates a Player object; the Player is also used by the Editor.
+        Please fix any bugs in main() and Player rather than here.
+        This function will be deleted shortly.
+    '''
+
     random.seed(1)
     np.random.seed(1)
 
@@ -107,8 +162,6 @@ def main(render=True, delay=0.0):
     env_renderer = RenderTool(env, gl="QTSVG")
     # env_renderer = RenderTool(env, gl="QT")
 
-    state_size = 105
-    action_size = 4
     n_trials = 9999
     eps = 1.
     eps_end = 0.005
@@ -119,8 +172,11 @@ def main(render=True, delay=0.0):
     scores = []
     dones_list = []
     action_prob = [0] * 4
-    agent = Agent(state_size, action_size, "FC", 0)
 
+    # Real agent:
+    # state_size = 105
+    # action_size = 4
+    # agent = Agent(state_size, action_size, "FC", 0)
     # agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
 
     def max_lt(seq, val):
@@ -161,7 +217,7 @@ def main(render=True, delay=0.0):
             # print(step)
             # Action
             for a in range(env.get_num_agents()):
-                action = agent.act(np.array(obs[a]), eps=eps)
+                action = random.randint(0, 3)  # agent.act(np.array(obs[a]), eps=eps)
                 action_prob[action] += 1
                 action_dict.update({a: action})
 
@@ -174,13 +230,16 @@ def main(render=True, delay=0.0):
 
             # Environment step
             next_obs, all_rewards, done, _ = env.step(action_dict)
+
             for a in range(env.get_num_agents()):
                 norm = max(1, max_lt(next_obs[a], np.inf))
                 next_obs[a] = np.clip(np.array(next_obs[a]) / norm, -1, 1)
+
             # Update replay buffer and train agent
-            for a in range(env.get_num_agents()):
-                agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
-                score += all_rewards[a]
+            # (only needed for a "real" agent)
+            # for a in range(env.get_num_agents()):
+            #     agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
+            #     score += all_rewards[a]
 
             obs = next_obs.copy()
             if done['__all__']:
@@ -212,8 +271,8 @@ def main(render=True, delay=0.0):
                       np.mean(scores_window), 100 * np.mean(done_window), eps, rFps,
                       action_prob / np.sum(action_prob)))
-            torch.save(agent.qnetwork_local.state_dict(),
-                       '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
+            # torch.save(agent.qnetwork_local.state_dict(),
+            #            '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
             action_prob = [1] * 4
 
diff --git a/examples/tkplay.py b/examples/tkplay.py
new file mode 100644
index 0000000000000000000000000000000000000000..95842e3b430000169093d27c3c9de02ebe037de9
--- /dev/null
+++ b/examples/tkplay.py
@@ -0,0 +1,61 @@
+import time
+import tkinter as tk
+
+from PIL import ImageTk, Image
+
+from examples.play_model import Player
+from flatland.envs.generators import complex_rail_generator
+from flatland.envs.rail_env import RailEnv
+from flatland.utils.rendertools import RenderTool
+
+
+def tkmain(n_trials=2):
+    # This creates the main window of an application
+    window = tk.Tk()
+    window.title("Join")
+    window.configure(background='grey')
+
+    # Example: generate a random rail
+    env = RailEnv(width=15, height=15,
+                  rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12),
+                  number_of_agents=5)
+
+    env_renderer = RenderTool(env, gl="PIL")
+
+    oPlayer = Player(env)
+    n_steps = 20
+    delay = 0
+    for trials in range(1, n_trials + 1):
+
+        # Reset environment
+        oPlayer.reset()
+        env_renderer.set_new_rail()
+
+        first = True
+
+        for step in range(n_steps):
+            oPlayer.step()
+            env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step,
+                                   action_dict=oPlayer.action_dict)
+            img = env_renderer.getImage()
+            img = Image.fromarray(img)
+            tkimg = ImageTk.PhotoImage(img)
+
+            if first:
+                panel = tk.Label(window, image=tkimg)
+                panel.pack(side="bottom", fill="both", expand="yes")
+                # keep a reference so the PhotoImage is not garbage-collected
+                panel.image = tkimg
+            else:
+                # update the image in situ
+                panel.configure(image=tkimg)
+                panel.image = tkimg
+
+            window.update()
+            if delay > 0:
+                time.sleep(delay)
+            first = False
+
+
+if __name__ == "__main__":
+    tkmain()
diff --git a/flatland/utils/graphics_layer.py b/flatland/utils/graphics_layer.py
index 4cfcc64bffb82f91a0f36822188db297bc1ed37e..40bf319da737c8bb49e6c228c98b70604734ddc4 100644
--- a/flatland/utils/graphics_layer.py
+++ b/flatland/utils/graphics_layer.py
@@ -51,7 +51,7 @@ class GraphicsLayer(object):
         elif type(color) is tuple:
             if type(color[0]) is not int:
                 gcolor = array(color)
-                color = tuple((gcolor[:3] * 255).astype(int))
+                color = tuple((gcolor[:4] * 255).astype(int))
         else:
             color = self.tColGrid
diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py
index 41516fd94737556a4b8abbc7ccfce0fd503a3d6e..b66c8dc55f38c321d038306f933de66493a6e6b3 100644
--- a/flatland/utils/graphics_pil.py
+++ b/flatland/utils/graphics_pil.py
@@ -18,28 +18,32 @@ class PILGL(GraphicsLayer):
         # Total grid size at native scale
         self.widthPx = self.width * self.nPixCell + self.linewidth
         self.heightPx = self.height * self.nPixCell + self.linewidth
-        self.beginFrame()
+        self.layers = []
+        self.draws = []
 
         self.tColBg = (255, 255, 255)  # white background
         # self.tColBg = (220, 120, 40)  # background color
         self.tColRail = (0, 0, 0)  # black rails
         self.tColGrid = (230,) * 3  # light grey for grid
 
-    def plot(self, gX, gY, color=None, linewidth=3, **kwargs):
-        color = self.adaptColor(color)
+        self.beginFrame()
 
-        # print(gX, gY)
+    def plot(self, gX, gY, color=None, linewidth=3, layer=0, opacity=255, **kwargs):
+        color = self.adaptColor(color)
+        if len(color) == 3:
+            color += (opacity,)
+        elif len(color) == 4:
+            color = color[:3] + (opacity,)
         gPoints = np.stack([array(gX), -array(gY)]).T * self.nPixCell
         gPoints = list(gPoints.ravel())
-        # print(gPoints, color)
-        self.draw.line(gPoints, fill=color, width=self.linewidth)
+        self.draws[layer].line(gPoints, fill=color, width=self.linewidth)
 
-    def scatter(self, gX, gY, color=None, marker="o", s=50, *args, **kwargs):
+    def scatter(self, gX, gY, color=None, marker="o", s=50, layer=0, opacity=255, *args, **kwargs):
         color = self.adaptColor(color)
         r = np.sqrt(s)
         gPoints = np.stack([np.atleast_1d(gX), -np.atleast_1d(gY)]).T * self.nPixCell
         for x, y in gPoints:
-            self.draw.rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color)
+            self.draws[layer].rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color)
 
     def text(self, *args, **kwargs):
         pass
@@ -51,8 +55,8 @@ class PILGL(GraphicsLayer):
         pass
 
     def beginFrame(self):
-        self.img = Image.new("RGBA", (self.widthPx, self.heightPx), (255, 255, 255, 255))
-        self.draw = ImageDraw.Draw(self.img)
+        self.create_layer(0)
+        self.create_layer(1)
 
     def show(self, block=False):
         pass
@@ -62,5 +66,35 @@ class PILGL(GraphicsLayer):
         pass
         # plt.pause(seconds)
 
+    def alpha_composite_layers(self):
+        img = self.layers[0]
+        for img2 in self.layers[1:]:
+            img = Image.alpha_composite(img, img2)
+        return img
+
     def getImage(self):
-        return array(self.img)
+        """ Return a blended / alpha-composited image made up of all the layers,
+            with layer 0 at the "back".
+        """
+        img = self.alpha_composite_layers()
+        return array(img)
+
+    def create_image(self, opacity=255):
+        img = Image.new("RGBA", (self.widthPx, self.heightPx), (255, 255, 255, opacity))
+        return img
+
+    def create_layer(self, iLayer=0):
+        if len(self.layers) <= iLayer:
+            for i in range(len(self.layers), iLayer + 1):
+                if i == 0:
+                    opacity = 255  # the "bottom" layer is opaque (for rails)
+                else:
+                    opacity = 0  # subsequent layers are transparent
+                img = self.create_image(opacity)
+                self.layers.append(img)
+                self.draws.append(ImageDraw.Draw(img))
+        else:
+            opacity = 0 if iLayer > 0 else 255
+            self.layers[iLayer] = img = self.create_image(opacity)
+            self.draws[iLayer] = ImageDraw.Draw(img)
+
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 34f3e9fa6857e86f4d99d211784d983a2e2a1e75..26ec39d1f426691e7d4fe5a8b4b6aec3bcc7b1fd 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -15,12 +15,14 @@
 class MPLGL(GraphicsLayer):
-    def __init__(self, width, height):
+    def __init__(self, width, height, show=False):
         self.width = width
         self.height = height
         self.yxBase = array([6, 21])  # pixel offset
         self.nPixCell = 700 / width
         self.img = None
+        if show:
+            plt.figure(figsize=(10, 10))
 
     def plot(self, *args, **kwargs):
         plt.plot(*args, **kwargs)
@@ -70,6 +72,7 @@ class MPLGL(GraphicsLayer):
     def beginFrame(self):
         self.img = None
         plt.figure(figsize=(10, 10))
+        plt.clf()
         pass
 
     def endFrame(self):
@@ -115,7 +118,7 @@ class RenderTool(object):
     gTheta = np.linspace(0, np.pi / 2, 5)
     gArc = array([np.cos(gTheta), np.sin(gTheta)]).T  # from [1,0] to [0,1]
 
-    def __init__(self, env, gl="MPL"):
+    def __init__(self, env, gl="MPL", show=False):
         self.env = env
         self.iFrame = 0
         self.time1 = time.time()
@@ -123,7 +126,7 @@ class RenderTool(object):
         # self.gl = MPLGL()
 
         if gl == "MPL":
-            self.gl = MPLGL(env.width, env.height)
+            self.gl = MPLGL(env.width, env.height, show=show)
         elif gl == "QT":
             self.gl = QTGL(env.width, env.height)
         elif gl == "PIL":
@@ -219,17 +222,17 @@ class RenderTool(object):
         if static:
             color = self.gl.adaptColor(color, lighten=True)
 
         # print("Agent:", rcPos, iDir, rcDir, xyDir, xyPos)
-        self.gl.scatter(*xyPos, color=color, marker="o", s=100)  # agent location
+        self.gl.scatter(*xyPos, color=color, layer=1, marker="o", s=100)  # agent location
 
         xyDirLine = array([xyPos, xyPos + xyDir / 2]).T  # line for agent orient.
-        self.gl.plot(*xyDirLine, color=color, lw=5, ms=0, alpha=0.6)
+        self.gl.plot(*xyDirLine, color=color, layer=1, lw=5, ms=0, alpha=0.6)
 
         if selected:
             self._draw_square(xyPos, 1, color)
 
         if target is not None:
             rcTarget = array(target)
             xyTarget = np.matmul(rcTarget, rt.grc2xy) + rt.xyHalf
-            self._draw_square(xyTarget, 1 / 3, color)
+            self._draw_square(xyTarget, 1 / 3, color, layer=1)
 
     def plotTrans(self, rcPos, gTransRCAg, color="r", depth=None):
         """
@@ -397,6 +402,13 @@ class RenderTool(object):
             visit = visit.prev
             xyPrev = xy
 
+    def drawTrans(self, oFrom, oTo, sColor="gray"):
+        self.gl.plot(
+            [oFrom[0], oTo[0]],  # x
+            [oFrom[1], oTo[1]],  # y
+            color=sColor
+        )
+
     def drawTrans2(
             self, xyLine, xyCentre,
@@ -474,8 +486,8 @@ class RenderTool(object):
 
     def renderObs(self, agent_handles, observation_dict):
         """
-        Render the extent of the observation of each agent. All cells that appear in the agent obsrevation will be
-        highlighted.
+        Render the extent of the observation of each agent. All cells that appear in the agent
+        observation will be highlighted.
 
         :param agent_handles: List of agent indices to adapt color and get correct observation
         :param observation_dict: dictionary containing sets of cells of the agent observation
@@ -489,47 +501,13 @@ class RenderTool(object):
             for visited_cell in observation_dict[agent]:
                 cell_coord = array(visited_cell[:2])
                 cell_coord_trans = np.matmul(cell_coord, rt.grc2xy) + rt.xyHalf
-                self._draw_square(cell_coord_trans, 1 / 3, color)
+                self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100)
 
-    def renderEnv(
-        self, show=False, curves=True, spacing=False,
-        arrows=False, agents=True, obsrender=True, sRailColor="gray", frames=False, iEpisode=None, iStep=None,
-            iSelectedAgent=None, action_dict=None):
-        """
-        Draw the environment using matplotlib.
-        Draw into the figure if provided.
-
-        Call pyplot.show() if show==True.
-        (Use show=False from a Jupyter notebook with %matplotlib inline)
-        """
-
-        if not self.gl.is_raster():
-            self.renderEnv2(show, curves, spacing,
-                            arrows, agents, sRailColor,
-                            frames, iEpisode, iStep,
-                            iSelectedAgent, action_dict)
-            return
-
-        # cell_size is a bit pointless with matplotlib - it does not relate to pixels,
-        # so for now I've changed it to 1 (from 10)
-        cell_size = 1
-        self.gl.beginFrame()
-
-        # self.gl.clf()
-        # if oFigure is None:
-        #     oFigure = self.gl.figure()
-
-        def drawTrans(oFrom, oTo, sColor="gray"):
-            self.gl.plot(
-                [oFrom[0], oTo[0]],  # x
-                [oFrom[1], oTo[1]],  # y
-                color=sColor
-            )
+    def renderRail(self, spacing=False, sRailColor="gray", curves=True, arrows=False):
+        cell_size = 1  # TODO: remove cell_size
 
         env = self.env
 
-        # t1 = time.time()
-
         # Draw cells grid
         grid_color = [0.95, 0.95, 0.95]
         for r in range(env.height + 1):
@@ -613,7 +591,7 @@ class RenderTool(object):
                                 rotation, spacing=spacing, bArrow=arrows, sColor=sRailColor)
                         else:
-                            drawTrans(from_xy, to_xy, sRailColor)
+                            self.drawTrans(from_xy, to_xy, sRailColor)
 
                     if False:
                         print(
                             "rot:", rotation,
                         )
 
@@ -626,6 +604,42 @@ class RenderTool(object):
+    def renderEnv(
+            self, show=False, curves=True, spacing=False,
+            arrows=False, agents=True, obsrender=True, sRailColor="gray", frames=False,
+            iEpisode=None, iStep=None,
+            iSelectedAgent=None, action_dict=None):
+        """
+        Draw the environment using matplotlib.
+        Draw into the figure if provided.
+
+        Call pyplot.show() if show==True.
+        (Use show=False from a Jupyter notebook with %matplotlib inline)
+        """
+
+        if not self.gl.is_raster():
+            self.renderEnv2(show, curves, spacing,
+                            arrows, agents, sRailColor,
+                            frames, iEpisode, iStep,
+                            iSelectedAgent, action_dict)
+            return
+
+        if type(self.gl) in (QTGL, PILGL):
+            self.gl.beginFrame()
+
+        if type(self.gl) is MPLGL:
+            # self.gl.clf()
+            self.gl.beginFrame()
+
+        # self.gl.clf()
+        # if oFigure is None:
+        #     oFigure = self.gl.figure()
+
+        env = self.env
+
+        self.renderRail()
+
         # Draw each agent + its orientation + its target
         if agents:
             self.plotAgents(targets=True, iSelectedAgent=iSelectedAgent)
@@ -657,23 +671,26 @@ class RenderTool(object):
         # TODO: for MPL, we don't want to call clf (called by endframe)
         # for QT, we need to call endFrame()
         # if not show:
-        self.gl.endFrame()
+        if type(self.gl) is QTGL:
+            self.gl.endFrame()
+            if show:
+                self.gl.show(block=False)
 
-        # t2 = time.time()
-        # print(t2 - t1, "seconds")
+        if type(self.gl) is MPLGL:
+            if show:
+                self.gl.show(block=False)
+            # self.gl.endFrame()
 
-        if show:
-            self.gl.show(block=False)
-        self.gl.pause(0.00001)
+            self.gl.pause(0.00001)
 
         return
 
-    def _draw_square(self, center, size, color):
+    def _draw_square(self, center, size, color, opacity=255, layer=0):
         x0 = center[0] - size / 2
         x1 = center[0] + size / 2
         y0 = center[1] - size / 2
         y1 = center[1] + size / 2
-        self.gl.plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], color=color)
+        self.gl.plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], color=color, layer=layer, opacity=opacity)
 
     def getImage(self):
         return self.gl.getImage()
diff --git a/images/basic-env.npz b/images/basic-env.npz
index 356da5d70146b3b8081dd99c0fe5e6bd70646e53..8ffaf023e1116b0c92702212ddb04c71b82f0655 100644
Binary files a/images/basic-env.npz and b/images/basic-env.npz differ
diff --git a/tests/test_player.py b/tests/test_player.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b2745f2e372ca80cd2fb5cf9dcaa3db96fb910a
--- /dev/null
+++ b/tests/test_player.py
@@ -0,0 +1,8 @@
+
+# from examples.play_model import main
+from examples.tkplay import tkmain
+
+
+def test_main():
+    tkmain(n_trials=2)
+
diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py
index c7841df54022d0c6ea24e209f6442342514153bc..8204a305328df746a772d034f3c763c848cceb93 100644
--- a/tests/test_rendertools.py
+++ b/tests/test_rendertools.py
@@ -46,8 +46,8 @@ def test_render_env(save_new_images=False):
     )
     sfTestEnv = "env-data/tests/test1.npy"
     oEnv.rail.load_transition_map(sfTestEnv)
-    oRT = rt.RenderTool(oEnv)
-    oRT.renderEnv()
+    oRT = rt.RenderTool(oEnv, gl="PIL", show=False)
+    oRT.renderEnv(show=False)
 
     checkFrozenImage(oRT, "basic-env.npz", resave=save_new_images)
diff --git a/tox.ini b/tox.ini
index 71edb7b5fbe2652b00bf48fc63a35d523c791a7a..6dd011aadeb2e7ba802ff692278aa763fb665f10 100644
--- a/tox.ini
+++ b/tox.ini
@@ -8,7 +8,7 @@ python =
 
 [flake8]
 max-line-length = 120
-ignore = E121 E126 E123 E128 E133 E226 E241 E242 E704 W291 W293 W391 W503 W504 W505
+ignore = E121 E126 E123 E128 E133 E226 E241 E242 E704 W291 W293 W391 W503 W504 W505
 
 [testenv:flake8]
 basepython = python
@@ -23,12 +23,15 @@ commands = make docs
 
 [testenv:coverage]
 basepython = python
 whitelist_externals = make
-commands =
+commands =
     pip install -U pip
     pip install -r requirements_dev.txt
     make coverage
 
 [testenv]
+whitelist_externals = xvfb-run
+    sh
+    pip
 setenv =
     PYTHONPATH = {toxinidir}
 deps =
@@ -39,6 +42,7 @@ deps =
 commands =
     pip install -U pip
     pip install -r requirements_dev.txt
-    py.test --basetemp={envtmpdir}
+    sh -c 'echo DISPLAY: $DISPLAY'
+    xvfb-run -a py.test --basetemp={envtmpdir}
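
Note on the layered rendering introduced in graphics_pil.py above: PILGL now keeps one RGBA image per layer (layer 0 opaque white, for the rails; higher layers fully transparent, for agents and observation overlays) and alpha-composites them back to front in getImage(). Below is a minimal self-contained sketch of the same technique using only Pillow and NumPy; the names (base, overlay) and sizes are illustrative, not part of the flatland API:

    import numpy as np
    from PIL import Image, ImageDraw

    WIDTH, HEIGHT = 200, 200

    # Layer 0: opaque background; this is where PILGL draws the rails.
    base = Image.new("RGBA", (WIDTH, HEIGHT), (255, 255, 255, 255))
    ImageDraw.Draw(base).line([(0, 100), (200, 100)], fill=(0, 0, 0, 255), width=3)

    # Layer 1: starts fully transparent; agents / observation highlights go here.
    # A fill alpha of 100 matches the opacity=100 that renderObs passes above.
    overlay = Image.new("RGBA", (WIDTH, HEIGHT), (255, 255, 255, 0))
    ImageDraw.Draw(overlay).rectangle([(80, 80), (120, 120)], fill=(255, 0, 0, 100))

    # Composite back to front, as PILGL.alpha_composite_layers() does,
    # then convert to an array, as PILGL.getImage() does.
    frame = Image.alpha_composite(base, overlay)
    img = np.asarray(frame)  # shape (HEIGHT, WIDTH, 4), dtype uint8

This is also why create_layer() gives layer 0 an alpha of 255 and later layers an alpha of 0: an upper layer initialised as opaque would completely hide the rails beneath it once composited.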
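A second note, on the Player / renderer split in play_model.py and tkplay.py above: reset() stores a fresh observation on the Player, step() picks a (currently random) action per agent and advances the environment, and rendering is driven separately, which is what lets main(), tkmain() and the Editor share one episode loop. A rough usage sketch mirroring the code in this diff, with illustrative loop bounds and parameter values:

    from examples.play_model import Player
    from flatland.envs.generators import complex_rail_generator
    from flatland.envs.rail_env import RailEnv
    from flatland.utils.rendertools import RenderTool

    env = RailEnv(width=15, height=15,
                  rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12),
                  number_of_agents=5)
    renderer = RenderTool(env, gl="PIL")  # PIL backend, so no Qt window is required
    player = Player(env)

    for episode in range(1, 3):
        player.reset()           # new episode: store the fresh observations
        renderer.set_new_rail()  # tell the renderer the rail layout changed
        for step in range(20):
            player.step()        # random action per agent, then env.step()
            renderer.renderEnv(show=True, frames=True, iEpisode=episode, iStep=step,
                               action_dict=player.action_dict)

    frame = renderer.getImage()  # RGBA numpy array; tkplay feeds this to Tk

Under xvfb-run (see the Makefile and tox.ini changes above), the same loop runs headlessly, which is what tests/test_player.py relies on.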