diff --git a/examples/play_model.py b/examples/play_model.py index 4fe40f9a77d0d55aef652a6e9d9240c6c8153f6a..6a67397ea4ba8d8906ce62b1a8d21327c247a3e0 100644 --- a/examples/play_model.py +++ b/examples/play_model.py @@ -10,7 +10,7 @@ import matplotlib.pyplot as plt import time -def main(render=True, delay=2): +def main(render=True, delay=0.0): random.seed(1) np.random.seed(1) @@ -27,10 +27,10 @@ def main(render=True, delay=2): 0.0] # Case 7 - dead end # Example generate a random rail - env = RailEnv(width=7, - height=7, + env = RailEnv(width=15, + height=15, rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), - number_of_agents=1) + number_of_agents=5) if render: env_renderer = RenderTool(env, gl="QT") @@ -52,7 +52,7 @@ def main(render=True, delay=2): dones_list = [] action_prob = [0]*4 agent = Agent(state_size, action_size, "FC", 0) - agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth')) + # agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth')) def max_lt(seq, val): """ @@ -108,7 +108,7 @@ def main(render=True, delay=2): env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step) if delay > 0: time.sleep(delay) - + iFrame += 1 diff --git a/flatland/utils/render_qt.py b/flatland/utils/render_qt.py index 94439b2b442c99e34387767e3f36246ac3b354cf..e2334d3e1a8a66c31139b848b1b9bbb2f0e0a590 100644 --- a/flatland/utils/render_qt.py +++ b/flatland/utils/render_qt.py @@ -73,6 +73,7 @@ class QTGL(GraphicsLayer): def scatter(self, gX, gY, color=None, marker="o", size=5, *args, **kwargs): color = self.adaptColor(color) self.qtr.setColor(*color) + self.qtr.setLineColor(*color) r = np.sqrt(size) gPoints = np.stack([np.atleast_1d(gX), -np.atleast_1d(gY)]).T * self.cell_pixels for x, y in gPoints: diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 278a08f687dd78987f2a19f8321524aaf726e6ac..985f2b6dff207aad75412387e2f604fb3a80a1b2 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -134,7 +134,7 @@ class RenderTool(object): gTransRCAg = rt.gTransRC[giTrans] self.plotTrans(visit.rc, gTransRCAg, depth=str(visit.iDepth), color=color) - def plotAgents(self): + def plotAgents(self, targets=True): cmap = self.gl.get_cmap('hsv', lut=self.env.number_of_agents+1) for iAgent in range(self.env.number_of_agents): oColor = cmap(iAgent) @@ -142,7 +142,11 @@ class RenderTool(object): rcPos = self.env.agents_position[iAgent] iDir = self.env.agents_direction[iAgent] # agent direction index - self.plotAgent(rcPos, iDir, oColor) + if targets: + target = self.env.agents_target[iAgent] + else: + target = None + self.plotAgent(rcPos, iDir, oColor, target=target) # gTransRCAg = self.getTransRC(rcPos, iDir) # self.plotTrans(rcPos, gTransRCAg) @@ -183,7 +187,7 @@ class RenderTool(object): else: return gTransRCAg - def plotAgent(self, rcPos, iDir, sColor="r"): + def plotAgent(self, rcPos, iDir, color="r", target=None): """ Plot a simple agent. Assumes a working graphics layer context (cf a MPL figure). @@ -194,16 +198,21 @@ class RenderTool(object): xyDir = np.matmul(rcDir, rt.grc2xy) # agent direction in xy xyPos = np.matmul(rcPos - rcDir / 2, rt.grc2xy) + rt.xyHalf - self.gl.scatter(*xyPos, color=sColor, size=10) # agent location + self.gl.scatter(*xyPos, color=color, size=40) # agent location xyDirLine = array([xyPos, xyPos + xyDir/2]).T # line for agent orient. - self.gl.plot(*xyDirLine, color=sColor, lw=5, ms=0, alpha=0.6) + self.gl.plot(*xyDirLine, color=color, lw=5, ms=0, alpha=0.6) + + if target is not None: + rcTarget = array(target) + xyTarget = np.matmul(rcTarget, rt.grc2xy) + rt.xyHalf + self._draw_square(xyTarget, 1/3, color) if False: # mark the next cell we're heading into rcNext = rcPos + rcDir xyNext = np.matmul(rcNext, rt.grc2xy) + rt.xyHalf - self.gl.scatter(*xyNext, color=sColor) + self.gl.scatter(*xyNext, color=color) def plotTrans(self, rcPos, gTransRCAg, color="r", depth=None): """ @@ -571,7 +580,7 @@ class RenderTool(object): # Draw each agent + its orientation + its target if agents: cmap = self.gl.get_cmap('hsv', lut=env.number_of_agents+1) - self.plotAgents() + self.plotAgents(targets=True) if False: for i in range(env.number_of_agents):