From da35282adcd7fdb8dcc735f41d50c5cd935c9e30 Mon Sep 17 00:00:00 2001 From: hagrid67 <jdhwatson@gmail.com> Date: Wed, 8 May 2019 08:27:29 +0100 Subject: [PATCH] added Pillow to requirements Added PIL graphics - 5-10 times faster than MPL move adaptColor into GraphicsLayer change scatter size arg to s like MPL move clf, get_cmap up into GraphicsLayer added boolean "train" arg to ddqn step() to make running faster in player / editor add plotTreeObs - not finished. improved pixel scaling in editor (I hope) Made the agent background steps update a progress bar. --- examples/play_model.py | 10 +++- flatland/baselines/dueling_double_dqn.py | 5 +- flatland/utils/editor.py | 25 ++++---- flatland/utils/graphics_layer.py | 20 +++++++ flatland/utils/graphics_pil.py | 66 +++++++++++++++++++++ flatland/utils/render_qt.py | 26 +------- flatland/utils/rendertools.py | 75 +++++++++++++++++++----- requirements_dev.txt | 1 + 8 files changed, 175 insertions(+), 53 deletions(-) create mode 100644 flatland/utils/graphics_pil.py diff --git a/examples/play_model.py b/examples/play_model.py index e6e81c9..a61954e 100644 --- a/examples/play_model.py +++ b/examples/play_model.py @@ -28,7 +28,8 @@ class Player(object): self.dones_list = [] self.action_prob = [0]*4 self.agent = Agent(self.state_size, self.action_size, "FC", 0) - self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth')) + # self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth')) + self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/flatland/baselines/Nets/avoid_checkpoint15000.pth')) self.iFrame = 0 self.tStart = time.time() @@ -48,12 +49,15 @@ class Player(object): def step(self): env = self.env + + # Pass the (stored) observation to the agent network and retrieve the action for a in range(env.number_of_agents): action = self.agent.act(np.array(self.obs[a]), eps=self.eps) self.action_prob[action] += 1 self.action_dict.update({a: action}) - # Environment step + # Environment step - pass the agent actions to the environment, + # retrieve the response - observations, rewards, dones next_obs, all_rewards, done, _ = self.env.step(self.action_dict) for a in range(env.number_of_agents): @@ -62,7 +66,7 @@ class Player(object): # Update replay buffer and train agent for a in range(self.env.number_of_agents): - self.agent.step(self.obs[a], self.action_dict[a], all_rewards[a], next_obs[a], done[a]) + self.agent.step(self.obs[a], self.action_dict[a], all_rewards[a], next_obs[a], done[a], train=False) self.score += all_rewards[a] self.iFrame += 1 diff --git a/flatland/baselines/dueling_double_dqn.py b/flatland/baselines/dueling_double_dqn.py index ee75a61..66fe3a3 100644 --- a/flatland/baselines/dueling_double_dqn.py +++ b/flatland/baselines/dueling_double_dqn.py @@ -64,7 +64,7 @@ class Agent: if os.path.exists(filename + ".target"): self.qnetwork_target.load_state_dict(torch.load(filename + ".target")) - def step(self, state, action, reward, next_state, done): + def step(self, state, action, reward, next_state, done, train=True): # Save experience in replay memory self.memory.add(state, action, reward, next_state, done) @@ -74,7 +74,8 @@ class Agent: # If enough samples are available in memory, get random subset and learn if len(self.memory) > BATCH_SIZE: experiences = self.memory.sample() - self.learn(experiences, GAMMA) + if train: + self.learn(experiences, GAMMA) def act(self, state, eps=0.): """Returns actions for given state as per current policy. diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py index 543b793..f97cb6d 100644 --- a/flatland/utils/editor.py +++ b/flatland/utils/editor.py @@ -27,34 +27,34 @@ import jpy_canvas class EditorMVC(object): - def __init__(self, env=None): - + def __init__(self, env=None, sGL="MPL"): if env is None: env = RailEnv(width=10, height=10, - rail_generator=random_rail_generator(cell_type_relative_proportion=[1, 1] + [0.5] * 6), + rail_generator=random_rail_generator(), number_of_agents=0, obs_builder_object=TreeObsForRailEnv(max_depth=2)) env.reset() self.editor = EditorModel(env) - self.editor.view = self.view = View(self.editor) + self.editor.view = self.view = View(self.editor, sGL=sGL) self.view.controller = self.editor.controller = self.controller = Controller(self.editor, self.view) self.view.init_canvas() self.view.init_widgets() # has to be done after controller class View(object): - def __init__(self, editor): + def __init__(self, editor, sGL="MPL"): self.editor = self.model = editor + self.sGL = sGL def display(self): self.wOutput.clear_output() return self.wMain def init_canvas(self): - self.oRT = rt.RenderTool(self.editor.env) + self.oRT = rt.RenderTool(self.editor.env, gl=self.sGL) plt.figure(figsize=(10, 10)) self.oRT.renderEnv(spacing=False, arrows=False, sRailColor="gray", show=False) img = self.oRT.getImage() @@ -66,8 +66,10 @@ class View(object): self.wImage.register_click(self.controller.on_click) # TODO: These are currently estimated values - self.yxBase = array([6, 21]) # pixel offset - self.nPixCell = 700 / self.model.env.rail.width # 35 + # self.yxBase = array([6, 21]) # pixel offset + # self.nPixCell = 700 / self.model.env.rail.width # 35 + self.yxBase = self.oRT.gl.yxBase + self.nPixCell = self.oRT.gl.nPixCell def init_widgets(self): # Radiobutton for drawmode - TODO: replace with shift/ctrl/alt keys @@ -151,7 +153,7 @@ class View(object): def drag_path_element(self, x, y): # Draw a black square on the in-memory copy of the image if x > 10 and x < self.yxSize[1] and y > 10 and y < self.yxSize[0]: - self.writableData[(y - 2):(y + 2), (x - 2):(x + 2), :] = 0 + self.writableData[(y - 2):(y + 2), (x - 2):(x + 2), :3] = 0 def xy_to_rc(self, x, y): rcCell = ((array([y, x]) - self.yxBase) / self.nPixCell).astype(int) @@ -549,17 +551,18 @@ class EditorModel(object): def start_run(self): if self.thread is None: - self.thread = threading.Thread(target=self.bg_updater, args=()) + self.thread = threading.Thread(target=self.bg_updater, args=(self.view.wProg_steps,)) self.thread.start() else: self.log("thread already present") - def bg_updater(self): + def bg_updater(self, wProg_steps): try: for i in range(20): # self.log("step ", i) self.step() time.sleep(0.2) + wProg_steps.value = i+1 # indicate progress on bar finally: self.thread = None diff --git a/flatland/utils/graphics_layer.py b/flatland/utils/graphics_layer.py index aa9257b..6268e84 100644 --- a/flatland/utils/graphics_layer.py +++ b/flatland/utils/graphics_layer.py @@ -1,4 +1,7 @@ +import matplotlib.pyplot as plt +from numpy import array + class GraphicsLayer(object): def __init__(self): @@ -33,3 +36,20 @@ class GraphicsLayer(object): def getImage(self): pass + + def adaptColor(self, color): + if color == "red" or color == "r": + color = (255, 0, 0) + elif color == "gray": + color = (128, 128, 128) + elif type(color) is list: + color = tuple((array(color) * 255).astype(int)) + elif type(color) is tuple: + gcolor = array(color) + color = tuple((gcolor[:3] * 255).astype(int)) + else: + color = self.tColGrid + return color + + def get_cmap(self, *args, **kwargs): + return plt.get_cmap(*args, **kwargs) diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py new file mode 100644 index 0000000..01cc5f0 --- /dev/null +++ b/flatland/utils/graphics_pil.py @@ -0,0 +1,66 @@ + +from flatland.utils.graphics_layer import GraphicsLayer +from PIL import Image, ImageDraw # , ImageFont +from numpy import array +import numpy as np + + +class PILGL(GraphicsLayer): + def __init__(self, width, height, nPixCell=60): + self.nPixCell = 60 + self.yxBase = (0, 0) + self.linewidth = 4 + # self.tile_size = self.nPixCell + + self.width = width + self.height = height + + # Total grid size at native scale + self.widthPx = self.width * self.nPixCell + self.linewidth + self.heightPx = self.height * self.nPixCell + self.linewidth + self.beginFrame() + + self.tColBg = (255, 255, 255) # white background + # self.tColBg = (220, 120, 40) # background color + self.tColRail = (0, 0, 0) # black rails + self.tColGrid = (230,) * 3 # light grey for grid + + def plot(self, gX, gY, color=None, linewidth=3, **kwargs): + color = self.adaptColor(color) + + # print(gX, gY) + gPoints = np.stack([array(gX), -array(gY)]).T * self.nPixCell + gPoints = list(gPoints.ravel()) + # print(gPoints, color) + self.draw.line(gPoints, fill=color, width=self.linewidth) + + def scatter(self, gX, gY, color=None, marker="o", s=50, *args, **kwargs): + color = self.adaptColor(color) + r = np.sqrt(s) + gPoints = np.stack([np.atleast_1d(gX), -np.atleast_1d(gY)]).T * self.nPixCell + for x, y in gPoints: + self.draw.rectangle([(x-r, y-r), (x+r, y+r)], fill=color, outline=color) + + def text(self, *args, **kwargs): + pass + + def prettify(self, *args, **kwargs): + pass + + def prettify2(self, width, height, cell_size): + pass + + def beginFrame(self): + self.img = Image.new("RGBA", (self.widthPx, self.heightPx), (255, 255, 255, 255)) + self.draw = ImageDraw.Draw(self.img) + + def show(self, block=False): + pass + # plt.show(block=block) + + def pause(self, seconds=0.00001): + pass + # plt.pause(seconds) + + def getImage(self): + return array(self.img) diff --git a/flatland/utils/render_qt.py b/flatland/utils/render_qt.py index 0804c9d..29fdec5 100644 --- a/flatland/utils/render_qt.py +++ b/flatland/utils/render_qt.py @@ -1,7 +1,7 @@ from flatland.utils.graphics_qt import QtRenderer from numpy import array from flatland.utils.graphics_layer import GraphicsLayer -from matplotlib import pyplot as plt +# from matplotlib import pyplot as plt import numpy as np @@ -36,20 +36,6 @@ class QTGL(GraphicsLayer): self.qtr.pop() self.qtr.endFrame() - def adaptColor(self, color): - if color == "red" or color == "r": - color = (255, 0, 0) - elif color == "gray": - color = (128, 128, 128) - elif type(color) is list: - color = array(color) * 255 - elif type(color) is tuple: - gcolor = array(color) - color = gcolor[:3] * 255 - else: - color = self.tColGrid - return color - def plot(self, gX, gY, color=None, linewidth=2, **kwargs): color = self.adaptColor(color) @@ -70,11 +56,11 @@ class QTGL(GraphicsLayer): gPoints = np.stack([array(gX), -array(gY)]).T * self.cell_pixels self.qtr.drawPolyline(gPoints) - def scatter(self, gX, gY, color=None, marker="o", size=5, *args, **kwargs): + def scatter(self, gX, gY, color=None, marker="o", s=50, *args, **kwargs): color = self.adaptColor(color) self.qtr.setColor(*color) self.qtr.setLineColor(*color) - r = np.sqrt(size) + r = np.sqrt(s) gPoints = np.stack([np.atleast_1d(gX), -np.atleast_1d(gY)]).T * self.cell_pixels for x, y in gPoints: self.qtr.drawCircle(x, y, r) @@ -94,12 +80,6 @@ class QTGL(GraphicsLayer): def pause(self, seconds=0.00001): pass - def clf(self): - pass - - def get_cmap(self, *args, **kwargs): - return plt.get_cmap(*args, **kwargs) - def beginFrame(self): self.qtr.beginFrame() self.qtr.push() diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index e84ff46..c54b2fe 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -2,19 +2,23 @@ from recordtype import recordtype import numpy as np from numpy import array -import xarray as xr +# import xarray as xr import matplotlib.pyplot as plt import time from collections import deque from flatland.utils.render_qt import QTGL +from flatland.utils.graphics_pil import PILGL from flatland.utils.graphics_layer import GraphicsLayer - # TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv! class MPLGL(GraphicsLayer): - def __init__(self): + def __init__(self, width, height): + self.width = width + self.height = height + self.yxBase = array([6, 21]) # pixel offset + self.nPixCell = 700 / width pass def plot(self, *args, **kwargs): @@ -62,9 +66,12 @@ class MPLGL(GraphicsLayer): return plt.get_cmap(*args, **kwargs) def beginFrame(self): + # plt.figure(figsize=(10, 10)) pass def endFrame(self): + # plt.clf() + # plt.close() pass def getImage(self): @@ -83,19 +90,19 @@ class RenderTool(object): lColors = list("brgcmyk") # \delta RC for NESW gTransRC = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]]) - nPixCell = 1 + nPixCell = 1 # misnomer... nPixHalf = nPixCell / 2 xyHalf = array([nPixHalf, -nPixHalf]) grc2xy = array([[0, -nPixCell], [nPixCell, 0]]) gGrid = array(np.meshgrid(np.arange(10), -np.arange(10))) * \ array([[[nPixCell]], [[nPixCell]]]) - xyPixHalf = xr.DataArray([nPixHalf, -nPixHalf], - dims="xy", - coords={"xy": ["x", "y"]}) - gCentres = xr.DataArray(gGrid, - dims=["xy", "p1", "p2"], - coords={"xy": ["x", "y"]}) + xyPixHalf - gTheta = np.linspace(0, np.pi / 2, 10) + # xyPixHalf = xr.DataArray([nPixHalf, -nPixHalf], + # dims="xy", + # coords={"xy": ["x", "y"]}) + # gCentres = xr.DataArray(gGrid, + # dims=["xy", "p1", "p2"], + # coords={"xy": ["x", "y"]}) + xyPixHalf + gTheta = np.linspace(0, np.pi / 2, 5) gArc = array([np.cos(gTheta), np.sin(gTheta)]).T # from [1,0] to [0,1] def __init__(self, env, gl="MPL"): @@ -105,7 +112,12 @@ class RenderTool(object): self.lTimes = deque() # self.gl = MPLGL() - self.gl = MPLGL() if gl == "MPL" else QTGL(env.width, env.height) + if gl == "MPL": + self.gl = MPLGL(env.width, env.height) + elif gl == "QT": + self.gl = QTGL(env.width, env.height) + elif gl == "PIL": + self.gl = PILGL(env.width, env.height) def plotTreeOnRail(self, lVisits, color="r"): """ @@ -489,16 +501,18 @@ class RenderTool(object): env = self.env + t1 = time.time() + # Draw cells grid grid_color = [0.95, 0.95, 0.95] for r in range(env.height + 1): self.gl.plot([0, (env.width + 1) * cell_size], [-r * cell_size, -r * cell_size], - color=grid_color) + color=grid_color, linewidth=2) for c in range(env.width + 1): self.gl.plot([c * cell_size, c * cell_size], [0, -(env.height + 1) * cell_size], - color=grid_color) + color=grid_color, linewidth=2) # Draw each cell independently for r in range(env.height): @@ -644,6 +658,9 @@ class RenderTool(object): self.gl.endFrame() + t2 = time.time() + print(t2 - t1, "seconds") + if show: self.gl.show(block=False) self.gl.pause(0.00001) @@ -659,3 +676,33 @@ class RenderTool(object): def getImage(self): return self.gl.getImage() + + def plotTreeObs(self, gObs): + nBranchFactor = 4 + + gP0 = array([[0, 0, 0]]).T + nDepth = 2 + for i in range(nDepth): + nDepthNodes = nBranchFactor**i + # rScale = nBranchFactor ** (nDepth - i) + rShrinkDepth = 1/(i+1) + # gX1 = np.linspace(-nDepthNodes / 2, nDepthNodes / 2, nDepthNodes) * rShrinkDepth + + gX1 = np.linspace(-(nDepthNodes-1), (nDepthNodes-1), nDepthNodes) * rShrinkDepth + gY1 = np.ones((nDepthNodes)) * i + gZ1 = np.zeros((nDepthNodes)) + + gP1 = array([gX1, gY1, gZ1]) + gP01 = np.append(gP0, gP1, axis=1) + + if nDepthNodes > 1: + nDepthNodesPrev = nDepthNodes / nBranchFactor + giP0 = np.repeat(np.arange(nDepthNodesPrev), nBranchFactor) + giP1 = np.arange(0, nDepthNodes) + nDepthNodesPrev + giLinePoints = np.stack([giP0, giP1]).ravel("F") + # print(gP01[:,:10]) + print(giLinePoints) + self.gl.plot(gP01[0], -gP01[1], lines=giLinePoints, color="gray") + + gP0 = array([gX1, gY1, gZ1]) + \ No newline at end of file diff --git a/requirements_dev.txt b/requirements_dev.txt index 0bc267e..40a6b7f 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -16,4 +16,5 @@ recordtype==1.3 xarray==0.11.3 matplotlib==3.0.2 PyQt5==5.12 +Pillow==5.4.1 -- GitLab