diff --git a/examples/training_navigation.py b/examples/training_navigation.py index b35c9cf772378d7f9be8995f401e5aab74ca7f6f..15b8ddde15b287418ae41229a7bf4edeed92772c 100644 --- a/examples/training_navigation.py +++ b/examples/training_navigation.py @@ -1,5 +1,4 @@ import random -import time from collections import deque import numpy as np @@ -193,15 +192,18 @@ for trials in range(1, n_trials + 1): scores.append(np.mean(scores_window)) dones_list.append((np.mean(done_window))) - print('\rTraining {} Agents.\t Episode {}\t Average Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format( - env.get_num_agents(), - trials, - np.mean(scores_window), - 100 * np.mean(done_window), - eps, action_prob / np.sum(action_prob)), end=" ") + print('\rTraining {} Agents.\t Episode {}\t Average Score: {:.0f}\tDones: {:.2f}%' + + '\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format( + env.get_num_agents(), + trials, + np.mean(scores_window), + 100 * np.mean(done_window), + eps, action_prob / np.sum(action_prob)), end=" ") if trials % 100 == 0: - print('\rTraining {} Agents.\t Episode {}\t Average Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format( + print( + '\rTraining {} Agents.\t Episode {}\t Average Score: {:.0f}\tDones: {:.2f}%' + + '\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format( env.get_num_agents(), trials, np.mean(scores_window), diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index d17bbef855e1cfcc9c62dea00d043018d8ec237d..9cc06d1d659236fb88af456b291ec933dbf4de66 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -1,14 +1,15 @@ -from recordtype import recordtype +import time +from collections import deque -import numpy as np -from numpy import array # import xarray as xr import matplotlib.pyplot as plt -import time -from collections import deque -from flatland.utils.render_qt import QTGL, QTSVG -from flatland.utils.graphics_pil import PILGL +import numpy as np +from numpy import array +from recordtype import recordtype + from flatland.utils.graphics_layer import GraphicsLayer +from flatland.utils.graphics_pil import PILGL +from flatland.utils.render_qt import QTGL, QTSVG # TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv! @@ -409,13 +410,12 @@ class RenderTool(object): color=sColor ) - def drawTrans2( - self, - xyLine, xyCentre, - rotation, bDeadEnd=False, - sColor="gray", - bArrow=True, - spacing=0.1): + def drawTrans2(self, + xyLine, xyCentre, + rotation, bDeadEnd=False, + sColor="gray", + bArrow=True, + spacing=0.1): """ gLine is a numpy 2d array of points, in the plotting space / coords. @@ -501,7 +501,7 @@ class RenderTool(object): for visited_cell in observation_dict[agent]: cell_coord = array(visited_cell[:2]) cell_coord_trans = np.matmul(cell_coord, rt.grc2xy) + rt.xyHalf - self._draw_square(cell_coord_trans, 1 / (agent+1.1), color, layer=1, opacity=100) + self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100) def renderRail(self, spacing=False, sRailColor="gray", curves=True, arrows=False): @@ -604,11 +604,10 @@ class RenderTool(object): "rot:", rotation, ) - def renderEnv( - self, show=False, curves=True, spacing=False, - arrows=False, agents=True, show_observations=True, sRailColor="gray", frames=False, - iEpisode=None, iStep=None, - iSelectedAgent=None, action_dict=None): + def renderEnv(self, show=False, curves=True, spacing=False, + arrows=False, agents=True, show_observations=True, sRailColor="gray", frames=False, + iEpisode=None, iStep=None, + iSelectedAgent=None, action_dict=None): """ Draw the environment using matplotlib. Draw into the figure if provided. @@ -683,7 +682,6 @@ class RenderTool(object): self.gl.pause(0.00001) - return def _draw_square(self, center, size, color, opacity=255, layer=0): @@ -725,10 +723,9 @@ class RenderTool(object): gP0 = array([gX1, gY1, gZ1]) - def renderEnv2( - self, show=False, curves=True, spacing=False, arrows=False, agents=True, renderobs=True, sRailColor="gray", - frames=False, iEpisode=None, iStep=None, iSelectedAgent=None, - action_dict=dict()): + def renderEnv2(self, show=False, curves=True, spacing=False, arrows=False, agents=True, renderobs=True, + sRailColor="gray", frames=False, iEpisode=None, iStep=None, iSelectedAgent=None, + action_dict=dict()): """ Draw the environment using matplotlib. Draw into the figure if provided.