Skip to content
Snippets Groups Projects
Commit be3c58b0 authored by hagrid67's avatar hagrid67
Browse files

updated test-collision to use jupyter_utils for env, agent behaviour, etc

parent 296aecf7
No related branches found
No related tags found
No related merge requests found
...@@ -14,7 +14,7 @@ from flatland.envs.rail_env import RailEnv, RailEnvActions ...@@ -14,7 +14,7 @@ from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.persistence import RailEnvPersister from flatland.envs.persistence import RailEnvPersister
from flatland.utils.rendertools import RenderTool from flatland.utils.rendertools import RenderTool
from flatland.utils import env_edit_utils as eeu from flatland.utils import env_edit_utils as eeu
from typing import List, NamedTuple
class Behaviour(): class Behaviour():
def __init__(self, env): def __init__(self, env):
...@@ -28,6 +28,49 @@ class AlwaysForward(Behaviour): ...@@ -28,6 +28,49 @@ class AlwaysForward(Behaviour):
def getActions(self): def getActions(self):
return { i:RailEnvActions.MOVE_FORWARD for i in range(self.nAg) } return { i:RailEnvActions.MOVE_FORWARD for i in range(self.nAg) }
AgentPause = NamedTuple("AgentPause",
[
("iAg", int),
("iPauseAt", int),
("iPauseFor", int)
])
class ForwardWithPause(Behaviour):
def __init__(self, env, lPauses:List[AgentPause]):
self.env = env
self.nAg = len(env.agents)
self.lPauses = lPauses
self.dAgPaused = {}
def getActions(self):
iStep = self.env._elapsed_steps + 1 # add one because this is called before step()
# new pauses starting this step
lNewPauses = [ tPause for tPause in self.lPauses if tPause.iPauseAt == iStep ]
# copy across the agent index and pause length
for pause in lNewPauses:
self.dAgPaused[pause.iAg] = pause.iPauseFor
# default action is move forward
dAction = { i:RailEnvActions.MOVE_FORWARD for i in range(self.nAg) }
# overwrite paused agents with stop
for iAg in self.dAgPaused:
dAction[iAg] = RailEnvActions.STOP_MOVING
# decrement the counters for each pause, and remove any expired pauses.
lFinished = []
for iAg in self.dAgPaused:
self.dAgPaused[iAg] -= 1
if self.dAgPaused[iAg] <= 0:
lFinished.append(iAg)
for iAg in lFinished:
self.dAgPaused.pop(iAg, None)
return dAction
class EnvCanvas(): class EnvCanvas():
......
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
# Test Collisions # Test Collisions
A visual test of a "rear-shunt" collision, to ensure that the agent does not get marked as collided permananently. A visual test of a "rear-shunt" collision, to ensure that the agent does not get marked as collided permananently.
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
%load_ext autoreload %load_ext autoreload
%autoreload 2 %autoreload 2
from IPython.core import display from IPython.core import display
display.display(display.HTML("<style>.container { width:95% !important; }</style>")) display.display(display.HTML("<style>.container { width:95% !important; }</style>"))
``` ```
%% Output %% Output
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
import PIL import PIL
from IPython import display from IPython import display
from ipycanvas import canvas from ipycanvas import canvas
import time import time
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
from flatland.envs import malfunction_generators as malgen from flatland.envs import malfunction_generators as malgen
from flatland.envs.agent_utils import EnvAgent from flatland.envs.agent_utils import EnvAgent
#from flatland.envs import sparse_rail_gen as spgen #from flatland.envs import sparse_rail_gen as spgen
from flatland.envs import rail_generators as rail_gen from flatland.envs import rail_generators as rail_gen
from flatland.envs import agent_chains as ac from flatland.envs import agent_chains as ac
from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.persistence import RailEnvPersister from flatland.envs.persistence import RailEnvPersister
from flatland.utils.rendertools import RenderTool from flatland.utils.rendertools import RenderTool
from flatland.utils import env_edit_utils as eeu from flatland.utils import env_edit_utils as eeu
from flatland.utils import jupyter_utils as ju from flatland.utils import jupyter_utils as ju
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
nAg=2 env, envModel = eeu.makeTestEnv("merging_spurs", nAg=2, bUCF=True)
bUCF=False behaviour = ju.ForwardWithPause(env, [ju.AgentPause(0, 10, 5)])
bReverseStart=False oEC = ju.EnvCanvas(env, behaviour)
env, envModel = eeu.makeTestEnv(sName="merging_spurs", nAg=nAg, bUCF=bUCF)
oRT = RenderTool(env, show_debug=True)
oRT.render_env(show_rowcols=True, show_inactive_agents=True, show_observations=False)
oCan = canvas.Canvas(size=(600,300))
oCan.put_image_data(oRT.get_image())
print(f"UCF: {bUCF}")
display.display(oCan)
iPauseStep=10
iPauseLen = 5
for iStep in range(25):
if bReverseStart:
iAgentStart = max((nAg - 2 - iStep*2), 0)
else:
iAgentStart = 0
dActions = { i:int(RailEnvActions.MOVE_FORWARD) for i in range(iAgentStart, len(env.agents)) }
if (iStep >= iPauseStep) and (iStep < iPauseStep + iPauseLen):
dActions[0] = RailEnvActions.STOP_MOVING
#print(dActions)
env.step(dActions)
oRT.render_env(show_rowcols=True, show_inactive_agents=True, show_observations=False)
aImg = oRT.get_image()
pilImg = PIL.Image.fromarray(aImg)
oCan.put_image_data(aImg)
#with open(f"tmp-images/img-{iStep:03d}.png", "wb") as fImg:
# pilImg.save(fImg)
time.sleep(0.05)
```
%% Output
UCF: False
%% Cell type:code id: tags:
``` python
#!ffmpeg -i ./tmp-images/img-%03d.png -vcodec mpeg4 -filter:v "setpts=8.0*PTS" -y movie_nucf_stop.mp4
```
%% Cell type:code id: tags:
``` python
oEC = ju.EnvCanvas(env)
env.reset(regenerate_rail=False) env.reset(regenerate_rail=False)
oEC.show() oEC.show()
for i in range(25):
oEC.step()
oEC.render()
time.sleep(0.1)
``` ```
%% Output %% Output
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
oEC.step() dAgStateExpected = {0: (7, 15, 2), 1: (6, 15, 2)}
oEC.render() dAgState={}
for iAg, ag in enumerate(env.agents):
dAgState[iAg] = (*ag.position, ag.direction)
assert dAgState == dAgStateExpected
``` ```
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment