Skip to content
Snippets Groups Projects
Commit 9c4beda0 authored by hagrid67's avatar hagrid67
Browse files

merged; changed to plotAgent with circle line; added delay and render flag to play

parent 401888d3
No related branches found
No related tags found
No related merge requests found
Pipeline #357 passed
from flatland.envs.rail_env import RailEnv, random_rail_generator
# from flatland.core.env_observation_builder import TreeObsForRailEnv
from flatland.utils.rendertools import RenderTool
from flatland.utils.render_qt import QtRailRender
from flatland.baselines.dueling_double_dqn import Agent
from collections import deque
import torch
import random
import numpy as np
import matplotlib.pyplot as plt
import redis
import time
def main():
def main(render=True, delay=2):
random.seed(1)
np.random.seed(1)
......@@ -32,8 +31,9 @@ def main():
height=7,
rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
number_of_agents=1)
env_renderer = RenderTool(env, gl="QT")
#env_renderer = QtRailRender(env)
if render:
env_renderer = RenderTool(env, gl="QT")
plt.figure(figsize=(5,5))
# fRedis = redis.Redis()
......@@ -67,6 +67,8 @@ def main():
idx -= 1
return None
iFrame = 0
tStart = time.time()
for trials in range(1, n_trials + 1):
# Reset environment
......@@ -102,7 +104,13 @@ def main():
agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
score += all_rewards[a]
env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step)
if render:
env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step)
if delay > 0:
time.sleep(delay)
iFrame += 1
obs = next_obs.copy()
if done['__all__']:
......@@ -116,8 +124,8 @@ def main():
scores.append(np.mean(scores_window))
dones_list.append((np.mean(done_window)))
print('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' +
'\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
print(('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' +
'\tEpsilon: {:.2f} \t Action Probabilities: \t {}').format(
env.number_of_agents,
trials,
np.mean(scores_window),
......@@ -125,16 +133,15 @@ def main():
eps, action_prob/np.sum(action_prob)),
end=" ")
if trials % 100 == 0:
print(
'\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
tNow = time.time()
rFps = iFrame / (tNow - tStart)
print(('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' +
'\tEpsilon: {:.2f} fps: {:.2f} \t Action Probabilities: \t {}').format(
env.number_of_agents,
trials,
np.mean(
scores_window),
100 * np.mean(
done_window),
eps, action_prob / np.sum(action_prob)))
np.mean(scores_window),
100 * np.mean(done_window),
eps, rFps, action_prob / np.sum(action_prob)))
torch.save(agent.qnetwork_local.state_dict(),
'../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
action_prob = [1]*4
......
......@@ -123,7 +123,7 @@ class QtRenderer(object):
def beginFrame(self):
self.painter.begin(self.img)
self.painter.setRenderHint(QPainter.Antialiasing, False)
# self.painter.setRenderHint(QPainter.Antialiasing, False)
# Clear the background
self.painter.setBrush(QColor(0, 0, 0))
......
......@@ -2,6 +2,7 @@ from flatland.utils.graphics_qt import QtRenderer
from numpy import array
from flatland.utils.graphics_layer import GraphicsLayer
from matplotlib import pyplot as plt
import numpy as np
class QTGL(GraphicsLayer):
......@@ -35,8 +36,7 @@ class QTGL(GraphicsLayer):
self.qtr.pop()
self.qtr.endFrame()
def plot(self, gX, gY, color=None, linewidth=2, **kwargs):
def adaptColor(self, color):
if color == "red" or color == "r":
color = (255, 0, 0)
elif color == "gray":
......@@ -48,20 +48,35 @@ class QTGL(GraphicsLayer):
color = gcolor[:3] * 255
else:
color = self.tColGrid
return color
def plot(self, gX, gY, color=None, linewidth=2, **kwargs):
color = self.adaptColor(color)
self.qtr.setLineColor(*color)
lastx = lasty = None
for x, y in zip(gX, gY):
if lastx is not None:
# print("line", lastx, lasty, x, y)
self.qtr.drawLine(
lastx*self.cell_pixels, -lasty*self.cell_pixels,
x*self.cell_pixels, -y*self.cell_pixels)
lastx = x
lasty = y
def scatter(self, *args, **kwargs):
print("scatter not yet implemented in ", self.__class__)
if False:
for x, y in zip(gX, gY):
if lastx is not None:
# print("line", lastx, lasty, x, y)
self.qtr.drawLine(
lastx*self.cell_pixels, -lasty*self.cell_pixels,
x*self.cell_pixels, -y*self.cell_pixels)
lastx = x
lasty = y
else:
# print(gX, gY)
gPoints = np.stack([array(gX), -array(gY)]).T * self.cell_pixels
self.qtr.drawPolyline(gPoints)
def scatter(self, gX, gY, color=None, marker="o", size=5, *args, **kwargs):
color = self.adaptColor(color)
self.qtr.setColor(*color)
r = np.sqrt(size)
gPoints = np.stack([np.atleast_1d(gX), -np.atleast_1d(gY)]).T * self.cell_pixels
for x, y in gPoints:
self.qtr.drawCircle(x, y, r)
def text(self, x, y, sText):
self.qtr.drawText(x*self.cell_pixels, -y*self.cell_pixels, sText)
......
......@@ -116,8 +116,8 @@ class RenderTool(object):
self.plotAgent(rcPos, iDir, sColor)
gTransRCAg = self.getTransRC(rcPos, iDir)
self.plotTrans(rcPos, gTransRCAg, color=color)
# gTransRCAg = self.getTransRC(rcPos, iDir)
# self.plotTrans(rcPos, gTransRCAg, color=color)
if False:
# TODO: this was `rcDir' but it was undefined
......@@ -135,20 +135,17 @@ class RenderTool(object):
self.plotTrans(visit.rc, gTransRCAg, depth=str(visit.iDepth), color=color)
def plotAgents(self):
rt = self.__class__
# plt.scatter(*rt.gCentres, s=5, color="r")
cmap = self.gl.get_cmap('hsv', lut=self.env.number_of_agents+1)
for iAgent in range(self.env.number_of_agents):
sColor = rt.lColors[iAgent]
oColor = cmap(iAgent)
rcPos = self.env.agents_position[iAgent]
iDir = self.env.agents_direction[iAgent] # agent direction index
self.plotAgent(rcPos, iDir, sColor)
self.plotAgent(rcPos, iDir, oColor)
gTransRCAg = self.getTransRC(rcPos, iDir)
self.plotTrans(rcPos, gTransRCAg)
# gTransRCAg = self.getTransRC(rcPos, iDir)
# self.plotTrans(rcPos, gTransRCAg)
def getTransRC(self, rcPos, iDir, bgiTrans=False):
"""
......@@ -189,21 +186,24 @@ class RenderTool(object):
def plotAgent(self, rcPos, iDir, sColor="r"):
"""
Plot a simple agent.
Assumes a working matplotlib context.
Assumes a working graphics layer context (cf a MPL figure).
"""
rt = self.__class__
xyPos = np.matmul(rcPos, rt.grc2xy) + rt.xyHalf
self.gl.scatter(*xyPos, color=sColor) # agent location
rcDir = rt.gTransRC[iDir] # agent direction in RC
xyDir = np.matmul(rcDir, rt.grc2xy) # agent direction in xy
xyDirLine = array([xyPos, xyPos+xyDir/2]).T # line for agent orient.
xyPos = np.matmul(rcPos - rcDir / 2, rt.grc2xy) + rt.xyHalf
self.gl.scatter(*xyPos, color=sColor, size=10) # agent location
xyDirLine = array([xyPos, xyPos + xyDir/2]).T # line for agent orient.
self.gl.plot(*xyDirLine, color=sColor, lw=5, ms=0, alpha=0.6)
# just mark the next cell we're heading into
rcNext = rcPos + rcDir
xyNext = np.matmul(rcNext, rt.grc2xy) + rt.xyHalf
self.gl.scatter(*xyNext, color=sColor)
if False:
# mark the next cell we're heading into
rcNext = rcPos + rcDir
xyNext = np.matmul(rcNext, rt.grc2xy) + rt.xyHalf
self.gl.scatter(*xyNext, color=sColor)
def plotTrans(self, rcPos, gTransRCAg, color="r", depth=None):
"""
......@@ -571,6 +571,9 @@ class RenderTool(object):
# Draw each agent + its orientation + its target
if agents:
cmap = self.gl.get_cmap('hsv', lut=env.number_of_agents+1)
self.plotAgents()
if False:
for i in range(env.number_of_agents):
self._draw_square((
env.agents_position[i][1] *
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment