Skip to content
Snippets Groups Projects
Commit 80189222 authored by Erik Nygren's avatar Erik Nygren
Browse files

Merge branch 'master' of gitlab.aicrowd.com:flatland/flatland

parents 14b63f9d 53422dfe
No related branches found
No related tags found
No related merge requests found
...@@ -2,6 +2,7 @@ import os ...@@ -2,6 +2,7 @@ import os
import random import random
from collections import deque from collections import deque
import time
import numpy as np import numpy as np
import torch import torch
...@@ -165,6 +166,8 @@ class Demo: ...@@ -165,6 +166,8 @@ class Demo:
for step in range(max_nbr_of_steps): for step in range(max_nbr_of_steps):
self.renderer.renderEnv(show=True) self.renderer.renderEnv(show=True)
time.sleep(.2)
# print(step) # print(step)
# Action # Action
for a in range(self.env.get_num_agents()): for a in range(self.env.get_num_agents()):
......
import random import random
import time
from collections import deque from collections import deque
import numpy as np import numpy as np
...@@ -193,15 +192,18 @@ for trials in range(1, n_trials + 1): ...@@ -193,15 +192,18 @@ for trials in range(1, n_trials + 1):
scores.append(np.mean(scores_window)) scores.append(np.mean(scores_window))
dones_list.append((np.mean(done_window))) dones_list.append((np.mean(done_window)))
print('\rTraining {} Agents.\t Episode {}\t Average Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format( print('\rTraining {} Agents.\t Episode {}\t Average Score: {:.0f}\tDones: {:.2f}%' +
env.get_num_agents(), '\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
trials, env.get_num_agents(),
np.mean(scores_window), trials,
100 * np.mean(done_window), np.mean(scores_window),
eps, action_prob / np.sum(action_prob)), end=" ") 100 * np.mean(done_window),
eps, action_prob / np.sum(action_prob)), end=" ")
if trials % 100 == 0: if trials % 100 == 0:
print('\rTraining {} Agents.\t Episode {}\t Average Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format( print(
'\rTraining {} Agents.\t Episode {}\t Average Score: {:.0f}\tDones: {:.2f}%' +
'\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
env.get_num_agents(), env.get_num_agents(),
trials, trials,
np.mean(scores_window), np.mean(scores_window),
......
from recordtype import recordtype import time
from collections import deque
import numpy as np
from numpy import array
# import xarray as xr # import xarray as xr
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import time import numpy as np
from collections import deque from numpy import array
from flatland.utils.render_qt import QTGL, QTSVG from recordtype import recordtype
from flatland.utils.graphics_pil import PILGL
from flatland.utils.graphics_layer import GraphicsLayer from flatland.utils.graphics_layer import GraphicsLayer
from flatland.utils.graphics_pil import PILGL
from flatland.utils.render_qt import QTGL, QTSVG
# TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv! # TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv!
...@@ -409,13 +410,12 @@ class RenderTool(object): ...@@ -409,13 +410,12 @@ class RenderTool(object):
color=sColor color=sColor
) )
def drawTrans2( def drawTrans2(self,
self, xyLine, xyCentre,
xyLine, xyCentre, rotation, bDeadEnd=False,
rotation, bDeadEnd=False, sColor="gray",
sColor="gray", bArrow=True,
bArrow=True, spacing=0.1):
spacing=0.1):
""" """
gLine is a numpy 2d array of points, gLine is a numpy 2d array of points,
in the plotting space / coords. in the plotting space / coords.
...@@ -501,7 +501,7 @@ class RenderTool(object): ...@@ -501,7 +501,7 @@ class RenderTool(object):
for visited_cell in observation_dict[agent]: for visited_cell in observation_dict[agent]:
cell_coord = array(visited_cell[:2]) cell_coord = array(visited_cell[:2])
cell_coord_trans = np.matmul(cell_coord, rt.grc2xy) + rt.xyHalf cell_coord_trans = np.matmul(cell_coord, rt.grc2xy) + rt.xyHalf
self._draw_square(cell_coord_trans, 1 / (agent+1.1), color, layer=1, opacity=100) self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100)
def renderRail(self, spacing=False, sRailColor="gray", curves=True, arrows=False): def renderRail(self, spacing=False, sRailColor="gray", curves=True, arrows=False):
...@@ -604,11 +604,10 @@ class RenderTool(object): ...@@ -604,11 +604,10 @@ class RenderTool(object):
"rot:", rotation, "rot:", rotation,
) )
def renderEnv( def renderEnv(self, show=False, curves=True, spacing=False,
self, show=False, curves=True, spacing=False, arrows=False, agents=True, show_observations=True, sRailColor="gray", frames=False,
arrows=False, agents=True, show_observations=True, sRailColor="gray", frames=False, iEpisode=None, iStep=None,
iEpisode=None, iStep=None, iSelectedAgent=None, action_dict=None):
iSelectedAgent=None, action_dict=None):
""" """
Draw the environment using matplotlib. Draw the environment using matplotlib.
Draw into the figure if provided. Draw into the figure if provided.
...@@ -683,7 +682,6 @@ class RenderTool(object): ...@@ -683,7 +682,6 @@ class RenderTool(object):
self.gl.pause(0.00001) self.gl.pause(0.00001)
return return
def _draw_square(self, center, size, color, opacity=255, layer=0): def _draw_square(self, center, size, color, opacity=255, layer=0):
...@@ -725,10 +723,9 @@ class RenderTool(object): ...@@ -725,10 +723,9 @@ class RenderTool(object):
gP0 = array([gX1, gY1, gZ1]) gP0 = array([gX1, gY1, gZ1])
def renderEnv2( def renderEnv2(self, show=False, curves=True, spacing=False, arrows=False, agents=True, renderobs=True,
self, show=False, curves=True, spacing=False, arrows=False, agents=True, renderobs=True, sRailColor="gray", sRailColor="gray", frames=False, iEpisode=None, iStep=None, iSelectedAgent=None,
frames=False, iEpisode=None, iStep=None, iSelectedAgent=None, action_dict=dict()):
action_dict=dict()):
""" """
Draw the environment using matplotlib. Draw the environment using matplotlib.
Draw into the figure if provided. Draw into the figure if provided.
......
...@@ -105,7 +105,7 @@ class Zug(object): ...@@ -105,7 +105,7 @@ class Zug(object):
class Track(object): class Track(object):
def __init__(self): def __init__(self):
dFiles = { dFiles = {
"": "Background_#91D1DD.svg", "": "Background_#9CCB89.svg",
"WE": "Gleis_Deadend.svg", "WE": "Gleis_Deadend.svg",
"WW EE NN SS": "Gleis_Diamond_Crossing.svg", "WW EE NN SS": "Gleis_Diamond_Crossing.svg",
"WW EE": "Gleis_horizontal.svg", "WW EE": "Gleis_horizontal.svg",
...@@ -132,7 +132,7 @@ class Track(object): ...@@ -132,7 +132,7 @@ class Track(object):
lDirs = list("NESW") lDirs = list("NESW")
svgBG = SVG("./svg/Background_#91D1DD.svg") svgBG = SVG("./svg/Background_#9CCB89.svg")
for sTrans, sFile in dFiles.items(): for sTrans, sFile in dFiles.items():
svg = SVG("./svg/" + sFile) svg = SVG("./svg/" + sFile)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment