Skip to content
Snippets Groups Projects
Commit 80189222 authored by Erik Nygren's avatar Erik Nygren
Browse files

Merge branch 'master' of gitlab.aicrowd.com:flatland/flatland

parents 14b63f9d 53422dfe
No related branches found
No related tags found
No related merge requests found
......@@ -2,6 +2,7 @@ import os
import random
from collections import deque
import time
import numpy as np
import torch
......@@ -165,6 +166,8 @@ class Demo:
for step in range(max_nbr_of_steps):
self.renderer.renderEnv(show=True)
time.sleep(.2)
# print(step)
# Action
for a in range(self.env.get_num_agents()):
......
import random
import time
from collections import deque
import numpy as np
......@@ -193,15 +192,18 @@ for trials in range(1, n_trials + 1):
scores.append(np.mean(scores_window))
dones_list.append((np.mean(done_window)))
print('\rTraining {} Agents.\t Episode {}\t Average Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
env.get_num_agents(),
trials,
np.mean(scores_window),
100 * np.mean(done_window),
eps, action_prob / np.sum(action_prob)), end=" ")
print('\rTraining {} Agents.\t Episode {}\t Average Score: {:.0f}\tDones: {:.2f}%' +
'\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
env.get_num_agents(),
trials,
np.mean(scores_window),
100 * np.mean(done_window),
eps, action_prob / np.sum(action_prob)), end=" ")
if trials % 100 == 0:
print('\rTraining {} Agents.\t Episode {}\t Average Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
print(
'\rTraining {} Agents.\t Episode {}\t Average Score: {:.0f}\tDones: {:.2f}%' +
'\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
env.get_num_agents(),
trials,
np.mean(scores_window),
......
from recordtype import recordtype
import time
from collections import deque
import numpy as np
from numpy import array
# import xarray as xr
import matplotlib.pyplot as plt
import time
from collections import deque
from flatland.utils.render_qt import QTGL, QTSVG
from flatland.utils.graphics_pil import PILGL
import numpy as np
from numpy import array
from recordtype import recordtype
from flatland.utils.graphics_layer import GraphicsLayer
from flatland.utils.graphics_pil import PILGL
from flatland.utils.render_qt import QTGL, QTSVG
# TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv!
......@@ -409,13 +410,12 @@ class RenderTool(object):
color=sColor
)
def drawTrans2(
self,
xyLine, xyCentre,
rotation, bDeadEnd=False,
sColor="gray",
bArrow=True,
spacing=0.1):
def drawTrans2(self,
xyLine, xyCentre,
rotation, bDeadEnd=False,
sColor="gray",
bArrow=True,
spacing=0.1):
"""
gLine is a numpy 2d array of points,
in the plotting space / coords.
......@@ -501,7 +501,7 @@ class RenderTool(object):
for visited_cell in observation_dict[agent]:
cell_coord = array(visited_cell[:2])
cell_coord_trans = np.matmul(cell_coord, rt.grc2xy) + rt.xyHalf
self._draw_square(cell_coord_trans, 1 / (agent+1.1), color, layer=1, opacity=100)
self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100)
def renderRail(self, spacing=False, sRailColor="gray", curves=True, arrows=False):
......@@ -604,11 +604,10 @@ class RenderTool(object):
"rot:", rotation,
)
def renderEnv(
self, show=False, curves=True, spacing=False,
arrows=False, agents=True, show_observations=True, sRailColor="gray", frames=False,
iEpisode=None, iStep=None,
iSelectedAgent=None, action_dict=None):
def renderEnv(self, show=False, curves=True, spacing=False,
arrows=False, agents=True, show_observations=True, sRailColor="gray", frames=False,
iEpisode=None, iStep=None,
iSelectedAgent=None, action_dict=None):
"""
Draw the environment using matplotlib.
Draw into the figure if provided.
......@@ -683,7 +682,6 @@ class RenderTool(object):
self.gl.pause(0.00001)
return
def _draw_square(self, center, size, color, opacity=255, layer=0):
......@@ -725,10 +723,9 @@ class RenderTool(object):
gP0 = array([gX1, gY1, gZ1])
def renderEnv2(
self, show=False, curves=True, spacing=False, arrows=False, agents=True, renderobs=True, sRailColor="gray",
frames=False, iEpisode=None, iStep=None, iSelectedAgent=None,
action_dict=dict()):
def renderEnv2(self, show=False, curves=True, spacing=False, arrows=False, agents=True, renderobs=True,
sRailColor="gray", frames=False, iEpisode=None, iStep=None, iSelectedAgent=None,
action_dict=dict()):
"""
Draw the environment using matplotlib.
Draw into the figure if provided.
......
......@@ -105,7 +105,7 @@ class Zug(object):
class Track(object):
def __init__(self):
dFiles = {
"": "Background_#91D1DD.svg",
"": "Background_#9CCB89.svg",
"WE": "Gleis_Deadend.svg",
"WW EE NN SS": "Gleis_Diamond_Crossing.svg",
"WW EE": "Gleis_horizontal.svg",
......@@ -132,7 +132,7 @@ class Track(object):
lDirs = list("NESW")
svgBG = SVG("./svg/Background_#91D1DD.svg")
svgBG = SVG("./svg/Background_#9CCB89.svg")
for sTrans, sFile in dFiles.items():
svg = SVG("./svg/" + sFile)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment