Skip to content
Snippets Groups Projects
Commit c361c45b authored by hagrid67's avatar hagrid67
Browse files

plot targets. play with larger grid 15x15, 5 agents. delay=0.

parent 0a43b421
No related branches found
No related tags found
No related merge requests found
......@@ -10,7 +10,7 @@ import matplotlib.pyplot as plt
import time
def main(render=True, delay=2):
def main(render=True, delay=0.0):
random.seed(1)
np.random.seed(1)
......@@ -27,10 +27,10 @@ def main(render=True, delay=2):
0.0] # Case 7 - dead end
# Example generate a random rail
env = RailEnv(width=7,
height=7,
env = RailEnv(width=15,
height=15,
rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
number_of_agents=1)
number_of_agents=5)
if render:
env_renderer = RenderTool(env, gl="QT")
......@@ -52,7 +52,7 @@ def main(render=True, delay=2):
dones_list = []
action_prob = [0]*4
agent = Agent(state_size, action_size, "FC", 0)
agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
# agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
def max_lt(seq, val):
"""
......@@ -108,7 +108,7 @@ def main(render=True, delay=2):
env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step)
if delay > 0:
time.sleep(delay)
iFrame += 1
......
......@@ -73,6 +73,7 @@ class QTGL(GraphicsLayer):
def scatter(self, gX, gY, color=None, marker="o", size=5, *args, **kwargs):
color = self.adaptColor(color)
self.qtr.setColor(*color)
self.qtr.setLineColor(*color)
r = np.sqrt(size)
gPoints = np.stack([np.atleast_1d(gX), -np.atleast_1d(gY)]).T * self.cell_pixels
for x, y in gPoints:
......
......@@ -134,7 +134,7 @@ class RenderTool(object):
gTransRCAg = rt.gTransRC[giTrans]
self.plotTrans(visit.rc, gTransRCAg, depth=str(visit.iDepth), color=color)
def plotAgents(self):
def plotAgents(self, targets=True):
cmap = self.gl.get_cmap('hsv', lut=self.env.number_of_agents+1)
for iAgent in range(self.env.number_of_agents):
oColor = cmap(iAgent)
......@@ -142,7 +142,11 @@ class RenderTool(object):
rcPos = self.env.agents_position[iAgent]
iDir = self.env.agents_direction[iAgent] # agent direction index
self.plotAgent(rcPos, iDir, oColor)
if targets:
target = self.env.agents_target[iAgent]
else:
target = None
self.plotAgent(rcPos, iDir, oColor, target=target)
# gTransRCAg = self.getTransRC(rcPos, iDir)
# self.plotTrans(rcPos, gTransRCAg)
......@@ -183,7 +187,7 @@ class RenderTool(object):
else:
return gTransRCAg
def plotAgent(self, rcPos, iDir, sColor="r"):
def plotAgent(self, rcPos, iDir, color="r", target=None):
"""
Plot a simple agent.
Assumes a working graphics layer context (cf a MPL figure).
......@@ -194,16 +198,21 @@ class RenderTool(object):
xyDir = np.matmul(rcDir, rt.grc2xy) # agent direction in xy
xyPos = np.matmul(rcPos - rcDir / 2, rt.grc2xy) + rt.xyHalf
self.gl.scatter(*xyPos, color=sColor, size=10) # agent location
self.gl.scatter(*xyPos, color=color, size=40) # agent location
xyDirLine = array([xyPos, xyPos + xyDir/2]).T # line for agent orient.
self.gl.plot(*xyDirLine, color=sColor, lw=5, ms=0, alpha=0.6)
self.gl.plot(*xyDirLine, color=color, lw=5, ms=0, alpha=0.6)
if target is not None:
rcTarget = array(target)
xyTarget = np.matmul(rcTarget, rt.grc2xy) + rt.xyHalf
self._draw_square(xyTarget, 1/3, color)
if False:
# mark the next cell we're heading into
rcNext = rcPos + rcDir
xyNext = np.matmul(rcNext, rt.grc2xy) + rt.xyHalf
self.gl.scatter(*xyNext, color=sColor)
self.gl.scatter(*xyNext, color=color)
def plotTrans(self, rcPos, gTransRCAg, color="r", depth=None):
"""
......@@ -571,7 +580,7 @@ class RenderTool(object):
# Draw each agent + its orientation + its target
if agents:
cmap = self.gl.get_cmap('hsv', lut=env.number_of_agents+1)
self.plotAgents()
self.plotAgents(targets=True)
if False:
for i in range(env.number_of_agents):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment