Skip to content
Snippets Groups Projects
Commit c361c45b authored by hagrid67's avatar hagrid67
Browse files

plot targets. play with larger grid 15x15, 5 agents. delay=0.

parent 0a43b421
No related branches found
No related tags found
No related merge requests found
...@@ -10,7 +10,7 @@ import matplotlib.pyplot as plt ...@@ -10,7 +10,7 @@ import matplotlib.pyplot as plt
import time import time
def main(render=True, delay=2): def main(render=True, delay=0.0):
random.seed(1) random.seed(1)
np.random.seed(1) np.random.seed(1)
...@@ -27,10 +27,10 @@ def main(render=True, delay=2): ...@@ -27,10 +27,10 @@ def main(render=True, delay=2):
0.0] # Case 7 - dead end 0.0] # Case 7 - dead end
# Example generate a random rail # Example generate a random rail
env = RailEnv(width=7, env = RailEnv(width=15,
height=7, height=15,
rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
number_of_agents=1) number_of_agents=5)
if render: if render:
env_renderer = RenderTool(env, gl="QT") env_renderer = RenderTool(env, gl="QT")
...@@ -52,7 +52,7 @@ def main(render=True, delay=2): ...@@ -52,7 +52,7 @@ def main(render=True, delay=2):
dones_list = [] dones_list = []
action_prob = [0]*4 action_prob = [0]*4
agent = Agent(state_size, action_size, "FC", 0) agent = Agent(state_size, action_size, "FC", 0)
agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth')) # agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
def max_lt(seq, val): def max_lt(seq, val):
""" """
...@@ -108,7 +108,7 @@ def main(render=True, delay=2): ...@@ -108,7 +108,7 @@ def main(render=True, delay=2):
env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step) env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step)
if delay > 0: if delay > 0:
time.sleep(delay) time.sleep(delay)
iFrame += 1 iFrame += 1
......
...@@ -73,6 +73,7 @@ class QTGL(GraphicsLayer): ...@@ -73,6 +73,7 @@ class QTGL(GraphicsLayer):
def scatter(self, gX, gY, color=None, marker="o", size=5, *args, **kwargs): def scatter(self, gX, gY, color=None, marker="o", size=5, *args, **kwargs):
color = self.adaptColor(color) color = self.adaptColor(color)
self.qtr.setColor(*color) self.qtr.setColor(*color)
self.qtr.setLineColor(*color)
r = np.sqrt(size) r = np.sqrt(size)
gPoints = np.stack([np.atleast_1d(gX), -np.atleast_1d(gY)]).T * self.cell_pixels gPoints = np.stack([np.atleast_1d(gX), -np.atleast_1d(gY)]).T * self.cell_pixels
for x, y in gPoints: for x, y in gPoints:
......
...@@ -134,7 +134,7 @@ class RenderTool(object): ...@@ -134,7 +134,7 @@ class RenderTool(object):
gTransRCAg = rt.gTransRC[giTrans] gTransRCAg = rt.gTransRC[giTrans]
self.plotTrans(visit.rc, gTransRCAg, depth=str(visit.iDepth), color=color) self.plotTrans(visit.rc, gTransRCAg, depth=str(visit.iDepth), color=color)
def plotAgents(self): def plotAgents(self, targets=True):
cmap = self.gl.get_cmap('hsv', lut=self.env.number_of_agents+1) cmap = self.gl.get_cmap('hsv', lut=self.env.number_of_agents+1)
for iAgent in range(self.env.number_of_agents): for iAgent in range(self.env.number_of_agents):
oColor = cmap(iAgent) oColor = cmap(iAgent)
...@@ -142,7 +142,11 @@ class RenderTool(object): ...@@ -142,7 +142,11 @@ class RenderTool(object):
rcPos = self.env.agents_position[iAgent] rcPos = self.env.agents_position[iAgent]
iDir = self.env.agents_direction[iAgent] # agent direction index iDir = self.env.agents_direction[iAgent] # agent direction index
self.plotAgent(rcPos, iDir, oColor) if targets:
target = self.env.agents_target[iAgent]
else:
target = None
self.plotAgent(rcPos, iDir, oColor, target=target)
# gTransRCAg = self.getTransRC(rcPos, iDir) # gTransRCAg = self.getTransRC(rcPos, iDir)
# self.plotTrans(rcPos, gTransRCAg) # self.plotTrans(rcPos, gTransRCAg)
...@@ -183,7 +187,7 @@ class RenderTool(object): ...@@ -183,7 +187,7 @@ class RenderTool(object):
else: else:
return gTransRCAg return gTransRCAg
def plotAgent(self, rcPos, iDir, sColor="r"): def plotAgent(self, rcPos, iDir, color="r", target=None):
""" """
Plot a simple agent. Plot a simple agent.
Assumes a working graphics layer context (cf a MPL figure). Assumes a working graphics layer context (cf a MPL figure).
...@@ -194,16 +198,21 @@ class RenderTool(object): ...@@ -194,16 +198,21 @@ class RenderTool(object):
xyDir = np.matmul(rcDir, rt.grc2xy) # agent direction in xy xyDir = np.matmul(rcDir, rt.grc2xy) # agent direction in xy
xyPos = np.matmul(rcPos - rcDir / 2, rt.grc2xy) + rt.xyHalf xyPos = np.matmul(rcPos - rcDir / 2, rt.grc2xy) + rt.xyHalf
self.gl.scatter(*xyPos, color=sColor, size=10) # agent location self.gl.scatter(*xyPos, color=color, size=40) # agent location
xyDirLine = array([xyPos, xyPos + xyDir/2]).T # line for agent orient. xyDirLine = array([xyPos, xyPos + xyDir/2]).T # line for agent orient.
self.gl.plot(*xyDirLine, color=sColor, lw=5, ms=0, alpha=0.6) self.gl.plot(*xyDirLine, color=color, lw=5, ms=0, alpha=0.6)
if target is not None:
rcTarget = array(target)
xyTarget = np.matmul(rcTarget, rt.grc2xy) + rt.xyHalf
self._draw_square(xyTarget, 1/3, color)
if False: if False:
# mark the next cell we're heading into # mark the next cell we're heading into
rcNext = rcPos + rcDir rcNext = rcPos + rcDir
xyNext = np.matmul(rcNext, rt.grc2xy) + rt.xyHalf xyNext = np.matmul(rcNext, rt.grc2xy) + rt.xyHalf
self.gl.scatter(*xyNext, color=sColor) self.gl.scatter(*xyNext, color=color)
def plotTrans(self, rcPos, gTransRCAg, color="r", depth=None): def plotTrans(self, rcPos, gTransRCAg, color="r", depth=None):
""" """
...@@ -571,7 +580,7 @@ class RenderTool(object): ...@@ -571,7 +580,7 @@ class RenderTool(object):
# Draw each agent + its orientation + its target # Draw each agent + its orientation + its target
if agents: if agents:
cmap = self.gl.get_cmap('hsv', lut=env.number_of_agents+1) cmap = self.gl.get_cmap('hsv', lut=env.number_of_agents+1)
self.plotAgents() self.plotAgents(targets=True)
if False: if False:
for i in range(env.number_of_agents): for i in range(env.number_of_agents):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment