From 9c4beda08efeae698461d00559b9de2ac857761c Mon Sep 17 00:00:00 2001 From: hagrid67 <jdhwatson@gmail.com> Date: Wed, 24 Apr 2019 11:53:14 +0100 Subject: [PATCH] merged; changed to plotAgent with circle line; added delay and render flag to play --- examples/play_model.py | 39 +++++++++++++++++++-------------- flatland/utils/graphics_qt.py | 2 +- flatland/utils/render_qt.py | 41 ++++++++++++++++++++++++----------- flatland/utils/rendertools.py | 39 ++++++++++++++++++--------------- 4 files changed, 73 insertions(+), 48 deletions(-) diff --git a/examples/play_model.py b/examples/play_model.py index d54decd8..4fe40f9a 100644 --- a/examples/play_model.py +++ b/examples/play_model.py @@ -1,17 +1,16 @@ from flatland.envs.rail_env import RailEnv, random_rail_generator # from flatland.core.env_observation_builder import TreeObsForRailEnv from flatland.utils.rendertools import RenderTool -from flatland.utils.render_qt import QtRailRender from flatland.baselines.dueling_double_dqn import Agent from collections import deque import torch import random import numpy as np import matplotlib.pyplot as plt -import redis +import time -def main(): +def main(render=True, delay=2): random.seed(1) np.random.seed(1) @@ -32,8 +31,9 @@ def main(): height=7, rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), number_of_agents=1) - env_renderer = RenderTool(env, gl="QT") - #env_renderer = QtRailRender(env) + + if render: + env_renderer = RenderTool(env, gl="QT") plt.figure(figsize=(5,5)) # fRedis = redis.Redis() @@ -67,6 +67,8 @@ def main(): idx -= 1 return None + iFrame = 0 + tStart = time.time() for trials in range(1, n_trials + 1): # Reset environment @@ -102,7 +104,13 @@ def main(): agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]) score += all_rewards[a] - env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step) + if render: + env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step) + if delay > 0: + time.sleep(delay) + + iFrame += 1 + obs = next_obs.copy() if done['__all__']: @@ -116,8 +124,8 @@ def main(): scores.append(np.mean(scores_window)) dones_list.append((np.mean(done_window))) - print('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' + - '\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format( + print(('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' + + '\tEpsilon: {:.2f} \t Action Probabilities: \t {}').format( env.number_of_agents, trials, np.mean(scores_window), @@ -125,16 +133,15 @@ def main(): eps, action_prob/np.sum(action_prob)), end=" ") if trials % 100 == 0: - - print( - '\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format( + tNow = time.time() + rFps = iFrame / (tNow - tStart) + print(('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' + + '\tEpsilon: {:.2f} fps: {:.2f} \t Action Probabilities: \t {}').format( env.number_of_agents, trials, - np.mean( - scores_window), - 100 * np.mean( - done_window), - eps, action_prob / np.sum(action_prob))) + np.mean(scores_window), + 100 * np.mean(done_window), + eps, rFps, action_prob / np.sum(action_prob))) torch.save(agent.qnetwork_local.state_dict(), '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth') action_prob = [1]*4 diff --git a/flatland/utils/graphics_qt.py b/flatland/utils/graphics_qt.py index 6571f4d7..a4abb578 100644 --- a/flatland/utils/graphics_qt.py +++ b/flatland/utils/graphics_qt.py @@ -123,7 +123,7 @@ class QtRenderer(object): def beginFrame(self): self.painter.begin(self.img) - self.painter.setRenderHint(QPainter.Antialiasing, False) + # self.painter.setRenderHint(QPainter.Antialiasing, False) # Clear the background self.painter.setBrush(QColor(0, 0, 0)) diff --git a/flatland/utils/render_qt.py b/flatland/utils/render_qt.py index 60b8d295..94439b2b 100644 --- a/flatland/utils/render_qt.py +++ b/flatland/utils/render_qt.py @@ -2,6 +2,7 @@ from flatland.utils.graphics_qt import QtRenderer from numpy import array from flatland.utils.graphics_layer import GraphicsLayer from matplotlib import pyplot as plt +import numpy as np class QTGL(GraphicsLayer): @@ -35,8 +36,7 @@ class QTGL(GraphicsLayer): self.qtr.pop() self.qtr.endFrame() - def plot(self, gX, gY, color=None, linewidth=2, **kwargs): - + def adaptColor(self, color): if color == "red" or color == "r": color = (255, 0, 0) elif color == "gray": @@ -48,20 +48,35 @@ class QTGL(GraphicsLayer): color = gcolor[:3] * 255 else: color = self.tColGrid + return color + + def plot(self, gX, gY, color=None, linewidth=2, **kwargs): + color = self.adaptColor(color) self.qtr.setLineColor(*color) lastx = lasty = None - for x, y in zip(gX, gY): - if lastx is not None: - # print("line", lastx, lasty, x, y) - self.qtr.drawLine( - lastx*self.cell_pixels, -lasty*self.cell_pixels, - x*self.cell_pixels, -y*self.cell_pixels) - lastx = x - lasty = y - - def scatter(self, *args, **kwargs): - print("scatter not yet implemented in ", self.__class__) + + if False: + for x, y in zip(gX, gY): + if lastx is not None: + # print("line", lastx, lasty, x, y) + self.qtr.drawLine( + lastx*self.cell_pixels, -lasty*self.cell_pixels, + x*self.cell_pixels, -y*self.cell_pixels) + lastx = x + lasty = y + else: + # print(gX, gY) + gPoints = np.stack([array(gX), -array(gY)]).T * self.cell_pixels + self.qtr.drawPolyline(gPoints) + + def scatter(self, gX, gY, color=None, marker="o", size=5, *args, **kwargs): + color = self.adaptColor(color) + self.qtr.setColor(*color) + r = np.sqrt(size) + gPoints = np.stack([np.atleast_1d(gX), -np.atleast_1d(gY)]).T * self.cell_pixels + for x, y in gPoints: + self.qtr.drawCircle(x, y, r) def text(self, x, y, sText): self.qtr.drawText(x*self.cell_pixels, -y*self.cell_pixels, sText) diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index f9ebf532..278a08f6 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -116,8 +116,8 @@ class RenderTool(object): self.plotAgent(rcPos, iDir, sColor) - gTransRCAg = self.getTransRC(rcPos, iDir) - self.plotTrans(rcPos, gTransRCAg, color=color) + # gTransRCAg = self.getTransRC(rcPos, iDir) + # self.plotTrans(rcPos, gTransRCAg, color=color) if False: # TODO: this was `rcDir' but it was undefined @@ -135,20 +135,17 @@ class RenderTool(object): self.plotTrans(visit.rc, gTransRCAg, depth=str(visit.iDepth), color=color) def plotAgents(self): - rt = self.__class__ - - # plt.scatter(*rt.gCentres, s=5, color="r") - + cmap = self.gl.get_cmap('hsv', lut=self.env.number_of_agents+1) for iAgent in range(self.env.number_of_agents): - sColor = rt.lColors[iAgent] + oColor = cmap(iAgent) rcPos = self.env.agents_position[iAgent] iDir = self.env.agents_direction[iAgent] # agent direction index - self.plotAgent(rcPos, iDir, sColor) + self.plotAgent(rcPos, iDir, oColor) - gTransRCAg = self.getTransRC(rcPos, iDir) - self.plotTrans(rcPos, gTransRCAg) + # gTransRCAg = self.getTransRC(rcPos, iDir) + # self.plotTrans(rcPos, gTransRCAg) def getTransRC(self, rcPos, iDir, bgiTrans=False): """ @@ -189,21 +186,24 @@ class RenderTool(object): def plotAgent(self, rcPos, iDir, sColor="r"): """ Plot a simple agent. - Assumes a working matplotlib context. + Assumes a working graphics layer context (cf a MPL figure). """ rt = self.__class__ - xyPos = np.matmul(rcPos, rt.grc2xy) + rt.xyHalf - self.gl.scatter(*xyPos, color=sColor) # agent location rcDir = rt.gTransRC[iDir] # agent direction in RC xyDir = np.matmul(rcDir, rt.grc2xy) # agent direction in xy - xyDirLine = array([xyPos, xyPos+xyDir/2]).T # line for agent orient. + + xyPos = np.matmul(rcPos - rcDir / 2, rt.grc2xy) + rt.xyHalf + self.gl.scatter(*xyPos, color=sColor, size=10) # agent location + + xyDirLine = array([xyPos, xyPos + xyDir/2]).T # line for agent orient. self.gl.plot(*xyDirLine, color=sColor, lw=5, ms=0, alpha=0.6) - # just mark the next cell we're heading into - rcNext = rcPos + rcDir - xyNext = np.matmul(rcNext, rt.grc2xy) + rt.xyHalf - self.gl.scatter(*xyNext, color=sColor) + if False: + # mark the next cell we're heading into + rcNext = rcPos + rcDir + xyNext = np.matmul(rcNext, rt.grc2xy) + rt.xyHalf + self.gl.scatter(*xyNext, color=sColor) def plotTrans(self, rcPos, gTransRCAg, color="r", depth=None): """ @@ -571,6 +571,9 @@ class RenderTool(object): # Draw each agent + its orientation + its target if agents: cmap = self.gl.get_cmap('hsv', lut=env.number_of_agents+1) + self.plotAgents() + + if False: for i in range(env.number_of_agents): self._draw_square(( env.agents_position[i][1] * -- GitLab