diff --git a/examples/demo.py b/examples/demo.py index 84c9cd2f8ef89ae7a16d32887a6fa3421a50cb01..2ee1d5c0e747b8de11b010ca69d24db465e8f4f1 100644 --- a/examples/demo.py +++ b/examples/demo.py @@ -2,6 +2,7 @@ import os import random from collections import deque +import time import numpy as np import torch @@ -165,6 +166,8 @@ class Demo: for step in range(max_nbr_of_steps): self.renderer.renderEnv(show=True) + time.sleep(.2) + # print(step) # Action for a in range(self.env.get_num_agents()): diff --git a/examples/training_navigation.py b/examples/training_navigation.py index 838be72cd66011b79eeaf54985d96026c6b8b6b8..497db9cab5e6716aa93448b399ece0804f6b8753 100644 --- a/examples/training_navigation.py +++ b/examples/training_navigation.py @@ -1,5 +1,4 @@ import random -import time from collections import deque import numpy as np @@ -193,15 +192,18 @@ for trials in range(1, n_trials + 1): scores.append(np.mean(scores_window)) dones_list.append((np.mean(done_window))) - print('\rTraining {} Agents.\t Episode {}\t Average Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format( - env.get_num_agents(), - trials, - np.mean(scores_window), - 100 * np.mean(done_window), - eps, action_prob / np.sum(action_prob)), end=" ") + print('\rTraining {} Agents.\t Episode {}\t Average Score: {:.0f}\tDones: {:.2f}%' + + '\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format( + env.get_num_agents(), + trials, + np.mean(scores_window), + 100 * np.mean(done_window), + eps, action_prob / np.sum(action_prob)), end=" ") if trials % 100 == 0: - print('\rTraining {} Agents.\t Episode {}\t Average Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format( + print( + '\rTraining {} Agents.\t Episode {}\t Average Score: {:.0f}\tDones: {:.2f}%' + + '\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format( env.get_num_agents(), trials, np.mean(scores_window), diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index d17bbef855e1cfcc9c62dea00d043018d8ec237d..9cc06d1d659236fb88af456b291ec933dbf4de66 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -1,14 +1,15 @@ -from recordtype import recordtype +import time +from collections import deque -import numpy as np -from numpy import array # import xarray as xr import matplotlib.pyplot as plt -import time -from collections import deque -from flatland.utils.render_qt import QTGL, QTSVG -from flatland.utils.graphics_pil import PILGL +import numpy as np +from numpy import array +from recordtype import recordtype + from flatland.utils.graphics_layer import GraphicsLayer +from flatland.utils.graphics_pil import PILGL +from flatland.utils.render_qt import QTGL, QTSVG # TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv! @@ -409,13 +410,12 @@ class RenderTool(object): color=sColor ) - def drawTrans2( - self, - xyLine, xyCentre, - rotation, bDeadEnd=False, - sColor="gray", - bArrow=True, - spacing=0.1): + def drawTrans2(self, + xyLine, xyCentre, + rotation, bDeadEnd=False, + sColor="gray", + bArrow=True, + spacing=0.1): """ gLine is a numpy 2d array of points, in the plotting space / coords. @@ -501,7 +501,7 @@ class RenderTool(object): for visited_cell in observation_dict[agent]: cell_coord = array(visited_cell[:2]) cell_coord_trans = np.matmul(cell_coord, rt.grc2xy) + rt.xyHalf - self._draw_square(cell_coord_trans, 1 / (agent+1.1), color, layer=1, opacity=100) + self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100) def renderRail(self, spacing=False, sRailColor="gray", curves=True, arrows=False): @@ -604,11 +604,10 @@ class RenderTool(object): "rot:", rotation, ) - def renderEnv( - self, show=False, curves=True, spacing=False, - arrows=False, agents=True, show_observations=True, sRailColor="gray", frames=False, - iEpisode=None, iStep=None, - iSelectedAgent=None, action_dict=None): + def renderEnv(self, show=False, curves=True, spacing=False, + arrows=False, agents=True, show_observations=True, sRailColor="gray", frames=False, + iEpisode=None, iStep=None, + iSelectedAgent=None, action_dict=None): """ Draw the environment using matplotlib. Draw into the figure if provided. @@ -683,7 +682,6 @@ class RenderTool(object): self.gl.pause(0.00001) - return def _draw_square(self, center, size, color, opacity=255, layer=0): @@ -725,10 +723,9 @@ class RenderTool(object): gP0 = array([gX1, gY1, gZ1]) - def renderEnv2( - self, show=False, curves=True, spacing=False, arrows=False, agents=True, renderobs=True, sRailColor="gray", - frames=False, iEpisode=None, iStep=None, iSelectedAgent=None, - action_dict=dict()): + def renderEnv2(self, show=False, curves=True, spacing=False, arrows=False, agents=True, renderobs=True, + sRailColor="gray", frames=False, iEpisode=None, iStep=None, iSelectedAgent=None, + action_dict=dict()): """ Draw the environment using matplotlib. Draw into the figure if provided. diff --git a/flatland/utils/svg.py b/flatland/utils/svg.py index 32d5631839f43964cf5ff7d94520aacf67db0488..e7c295f7b9d6ecda1b0788e3a6f9af3e02a7051a 100644 --- a/flatland/utils/svg.py +++ b/flatland/utils/svg.py @@ -105,7 +105,7 @@ class Zug(object): class Track(object): def __init__(self): dFiles = { - "": "Background_#91D1DD.svg", + "": "Background_#9CCB89.svg", "WE": "Gleis_Deadend.svg", "WW EE NN SS": "Gleis_Diamond_Crossing.svg", "WW EE": "Gleis_horizontal.svg", @@ -132,7 +132,7 @@ class Track(object): lDirs = list("NESW") - svgBG = SVG("./svg/Background_#91D1DD.svg") + svgBG = SVG("./svg/Background_#9CCB89.svg") for sTrans, sFile in dFiles.items(): svg = SVG("./svg/" + sFile)