Skip to content
Snippets Groups Projects
Commit 78658902 authored by Egli Adrian (IT-SCI-API-PFI)'s avatar Egli Adrian (IT-SCI-API-PFI)
Browse files
parents 9131c24f 76bbf514
No related branches found
No related tags found
No related merge requests found
import random import random
import time
from collections import deque from collections import deque
import numpy as np import numpy as np
...@@ -193,15 +192,18 @@ for trials in range(1, n_trials + 1): ...@@ -193,15 +192,18 @@ for trials in range(1, n_trials + 1):
scores.append(np.mean(scores_window)) scores.append(np.mean(scores_window))
dones_list.append((np.mean(done_window))) dones_list.append((np.mean(done_window)))
print('\rTraining {} Agents.\t Episode {}\t Average Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format( print('\rTraining {} Agents.\t Episode {}\t Average Score: {:.0f}\tDones: {:.2f}%' +
env.get_num_agents(), '\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
trials, env.get_num_agents(),
np.mean(scores_window), trials,
100 * np.mean(done_window), np.mean(scores_window),
eps, action_prob / np.sum(action_prob)), end=" ") 100 * np.mean(done_window),
eps, action_prob / np.sum(action_prob)), end=" ")
if trials % 100 == 0: if trials % 100 == 0:
print('\rTraining {} Agents.\t Episode {}\t Average Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format( print(
'\rTraining {} Agents.\t Episode {}\t Average Score: {:.0f}\tDones: {:.2f}%' +
'\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
env.get_num_agents(), env.get_num_agents(),
trials, trials,
np.mean(scores_window), np.mean(scores_window),
......
from recordtype import recordtype import time
from collections import deque
import numpy as np
from numpy import array
# import xarray as xr # import xarray as xr
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import time import numpy as np
from collections import deque from numpy import array
from flatland.utils.render_qt import QTGL, QTSVG from recordtype import recordtype
from flatland.utils.graphics_pil import PILGL
from flatland.utils.graphics_layer import GraphicsLayer from flatland.utils.graphics_layer import GraphicsLayer
from flatland.utils.graphics_pil import PILGL
from flatland.utils.render_qt import QTGL, QTSVG
# TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv! # TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv!
...@@ -409,13 +410,12 @@ class RenderTool(object): ...@@ -409,13 +410,12 @@ class RenderTool(object):
color=sColor color=sColor
) )
def drawTrans2( def drawTrans2(self,
self, xyLine, xyCentre,
xyLine, xyCentre, rotation, bDeadEnd=False,
rotation, bDeadEnd=False, sColor="gray",
sColor="gray", bArrow=True,
bArrow=True, spacing=0.1):
spacing=0.1):
""" """
gLine is a numpy 2d array of points, gLine is a numpy 2d array of points,
in the plotting space / coords. in the plotting space / coords.
...@@ -501,7 +501,7 @@ class RenderTool(object): ...@@ -501,7 +501,7 @@ class RenderTool(object):
for visited_cell in observation_dict[agent]: for visited_cell in observation_dict[agent]:
cell_coord = array(visited_cell[:2]) cell_coord = array(visited_cell[:2])
cell_coord_trans = np.matmul(cell_coord, rt.grc2xy) + rt.xyHalf cell_coord_trans = np.matmul(cell_coord, rt.grc2xy) + rt.xyHalf
self._draw_square(cell_coord_trans, 1 / (agent+1.1), color, layer=1, opacity=100) self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100)
def renderRail(self, spacing=False, sRailColor="gray", curves=True, arrows=False): def renderRail(self, spacing=False, sRailColor="gray", curves=True, arrows=False):
...@@ -604,11 +604,10 @@ class RenderTool(object): ...@@ -604,11 +604,10 @@ class RenderTool(object):
"rot:", rotation, "rot:", rotation,
) )
def renderEnv( def renderEnv(self, show=False, curves=True, spacing=False,
self, show=False, curves=True, spacing=False, arrows=False, agents=True, show_observations=True, sRailColor="gray", frames=False,
arrows=False, agents=True, show_observations=True, sRailColor="gray", frames=False, iEpisode=None, iStep=None,
iEpisode=None, iStep=None, iSelectedAgent=None, action_dict=None):
iSelectedAgent=None, action_dict=None):
""" """
Draw the environment using matplotlib. Draw the environment using matplotlib.
Draw into the figure if provided. Draw into the figure if provided.
...@@ -683,7 +682,6 @@ class RenderTool(object): ...@@ -683,7 +682,6 @@ class RenderTool(object):
self.gl.pause(0.00001) self.gl.pause(0.00001)
return return
def _draw_square(self, center, size, color, opacity=255, layer=0): def _draw_square(self, center, size, color, opacity=255, layer=0):
...@@ -725,10 +723,9 @@ class RenderTool(object): ...@@ -725,10 +723,9 @@ class RenderTool(object):
gP0 = array([gX1, gY1, gZ1]) gP0 = array([gX1, gY1, gZ1])
def renderEnv2( def renderEnv2(self, show=False, curves=True, spacing=False, arrows=False, agents=True, renderobs=True,
self, show=False, curves=True, spacing=False, arrows=False, agents=True, renderobs=True, sRailColor="gray", sRailColor="gray", frames=False, iEpisode=None, iStep=None, iSelectedAgent=None,
frames=False, iEpisode=None, iStep=None, iSelectedAgent=None, action_dict=dict()):
action_dict=dict()):
""" """
Draw the environment using matplotlib. Draw the environment using matplotlib.
Draw into the figure if provided. Draw into the figure if provided.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment