Skip to content
Snippets Groups Projects
Commit 50d67dce authored by hagrid67's avatar hagrid67
Browse files

cleaned up agent_close_following notebook, added an assert on frozen agent states

parent ff1a2416
No related branches found
No related tags found
No related merge requests found
...@@ -4,6 +4,7 @@ import numpy as np ...@@ -4,6 +4,7 @@ import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from typing import List, Tuple from typing import List, Tuple
import graphviz as gv
class MotionCheck(object): class MotionCheck(object):
""" Class to find chains of agents which are "colliding" with a stopped agent. """ Class to find chains of agents which are "colliding" with a stopped agent.
...@@ -180,10 +181,15 @@ class MotionCheck(object): ...@@ -180,10 +181,15 @@ class MotionCheck(object):
def render(omc:MotionCheck): def render(omc:MotionCheck, horizontal=True):
oAG = nx.drawing.nx_agraph.to_agraph(omc.G) oAG = nx.drawing.nx_agraph.to_agraph(omc.G)
oAG.layout("dot") oAG.layout("dot")
return oAG.draw(format="png") sDot = oAG.to_string()
if horizontal:
sDot = sDot.replace('{', '{ rankdir="LR" ')
#return oAG.draw(format="png")
# This returns a graphviz object which implements __repr_svg
return gv.Source(sDot)
class ChainTestEnv(object): class ChainTestEnv(object):
""" Just for testing agent chains """ Just for testing agent chains
......
...@@ -19,30 +19,39 @@ from flatland.utils import env_edit_utils as eeu ...@@ -19,30 +19,39 @@ from flatland.utils import env_edit_utils as eeu
class Behaviour(): class Behaviour():
def __init__(self, env): def __init__(self, env):
self.env = env self.env = env
self.nAg = len(env.agents)
def getActions(self): def getActions(self):
return {} return {}
class AlwaysForward(Behaviour):
class AlwaysForward(): def getActions(self):
pass return { i:RailEnvActions.MOVE_FORWARD for i in range(self.nAg) }
class EnvCanvas(): class EnvCanvas():
def __init__(self, env): def __init__(self, env, behaviour:Behaviour=None):
self.env = env self.env = env
self.iStep = 0
if behaviour is None:
behaviour = AlwaysForward(env)
self.behaviour = behaviour
self.oRT = RenderTool(env, show_debug=True) self.oRT = RenderTool(env, show_debug=True)
self.render()
self.oCan = canvas.Canvas(size=(600,300)) self.oCan = canvas.Canvas(size=(600,300))
self.oCan.put_image_data(self.oRT.get_image()) self.render()
def render(self): def render(self):
self.oRT.render_env(show_rowcols=True, show_inactive_agents=True, show_observations=False) self.oRT.render_env(show_rowcols=True, show_inactive_agents=True, show_observations=False)
self.oCan.put_image_data(self.oRT.get_image())
def step(self):
dAction = self.behaviour.getActions()
self.env.step(dAction)
def show(self): def show(self):
self.render() self.render()
self.oCan.put_image_data(self.oRT.get_image())
display.display(self.oCan) display.display(self.oCan)
source diff could not be displayed: it is too large. Options to address this: view the blob.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment