Skip to content
Snippets Groups Projects
Commit b044ae56 authored by adrian_egli2's avatar adrian_egli2
Browse files

Jupyter notebooks are ready for flatland3.

They still need rework, though: a single example per use case would be much clearer. TODO - cleanup
parent f4bc62a8
No related branches found
No related tags found
No related merge requests found
from typing import List, NamedTuple
import PIL
from IPython import display from IPython import display
from ipycanvas import canvas from ipycanvas import canvas
import time from flatland.envs.rail_env import RailEnvActions
from flatland.envs import malfunction_generators as malgen
from flatland.envs.agent_utils import EnvAgent
#from flatland.envs import sparse_rail_gen as spgen
from flatland.envs import rail_generators as rail_gen
from flatland.envs import agent_chains as ac
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.persistence import RailEnvPersister
from flatland.utils.rendertools import RenderTool from flatland.utils.rendertools import RenderTool
from flatland.utils import env_edit_utils as eeu
from typing import List, NamedTuple
class Behaviour(): class Behaviour():
def __init__(self, env): def __init__(self, env):
...@@ -24,10 +13,12 @@ class Behaviour(): ...@@ -24,10 +13,12 @@ class Behaviour():
def getActions(self): def getActions(self):
return {} return {}
class AlwaysForward(Behaviour):
    """Behaviour that issues MOVE_FORWARD to every agent on every step."""

    def getActions(self):
        """Return a dict mapping each agent handle (0..nAg-1) to MOVE_FORWARD."""
        dActions = {}
        for hAgent in range(self.nAg):
            dActions[hAgent] = RailEnvActions.MOVE_FORWARD
        return dActions
class DelayedStartForward(AlwaysForward): class DelayedStartForward(AlwaysForward):
def __init__(self, env, nStartDelay=2): def __init__(self, env, nStartDelay=2):
...@@ -37,17 +28,19 @@ class DelayedStartForward(AlwaysForward): ...@@ -37,17 +28,19 @@ class DelayedStartForward(AlwaysForward):
def getActions(self): def getActions(self):
iStep = self.env._elapsed_steps + 1 iStep = self.env._elapsed_steps + 1
nAgentsMoving = min(self.nAg, iStep // self.nStartDelay) nAgentsMoving = min(self.nAg, iStep // self.nStartDelay)
return { i:RailEnvActions.MOVE_FORWARD for i in range(nAgentsMoving) } return {i: RailEnvActions.MOVE_FORWARD for i in range(nAgentsMoving)}
# A single scheduled pause: agent `iAg` stops at step `iPauseAt` and stays
# stopped for `iPauseFor` steps.
AgentPause = NamedTuple(
    "AgentPause",
    [
        ("iAg", int),       # agent handle
        ("iPauseAt", int),  # step at which the pause begins
        ("iPauseFor", int), # duration of the pause in steps
    ],
)
class ForwardWithPause(Behaviour): class ForwardWithPause(Behaviour):
def __init__(self, env, lPauses:List[AgentPause]): def __init__(self, env, lPauses: List[AgentPause]):
self.env = env self.env = env
self.nAg = len(env.agents) self.nAg = len(env.agents)
self.lPauses = lPauses self.lPauses = lPauses
...@@ -57,39 +50,40 @@ class ForwardWithPause(Behaviour): ...@@ -57,39 +50,40 @@ class ForwardWithPause(Behaviour):
iStep = self.env._elapsed_steps + 1 # add one because this is called before step() iStep = self.env._elapsed_steps + 1 # add one because this is called before step()
# new pauses starting this step # new pauses starting this step
lNewPauses = [ tPause for tPause in self.lPauses if tPause.iPauseAt == iStep ] lNewPauses = [tPause for tPause in self.lPauses if tPause.iPauseAt == iStep]
# copy across the agent index and pause length # copy across the agent index and pause length
for pause in lNewPauses: for pause in lNewPauses:
self.dAgPaused[pause.iAg] = pause.iPauseFor self.dAgPaused[pause.iAg] = pause.iPauseFor
# default action is move forward # default action is move forward
dAction = { i:RailEnvActions.MOVE_FORWARD for i in range(self.nAg) } dAction = {i: RailEnvActions.MOVE_FORWARD for i in range(self.nAg)}
# overwrite paused agents with stop # overwrite paused agents with stop
for iAg in self.dAgPaused: for iAg in self.dAgPaused:
dAction[iAg] = RailEnvActions.STOP_MOVING dAction[iAg] = RailEnvActions.STOP_MOVING
# decrement the counters for each pause, and remove any expired pauses. # decrement the counters for each pause, and remove any expired pauses.
lFinished = [] lFinished = []
for iAg in self.dAgPaused: for iAg in self.dAgPaused:
self.dAgPaused[iAg] -= 1 self.dAgPaused[iAg] -= 1
if self.dAgPaused[iAg] <= 0: if self.dAgPaused[iAg] <= 0:
lFinished.append(iAg) lFinished.append(iAg)
for iAg in lFinished: for iAg in lFinished:
self.dAgPaused.pop(iAg, None) self.dAgPaused.pop(iAg, None)
return dAction return dAction
class Deterministic(Behaviour): class Deterministic(Behaviour):
def __init__(self, env, dAg_lActions): def __init__(self, env, dAg_lActions):
super().__init__(env) super().__init__(env)
self.dAg_lActions = dAg_lActions self.dAg_lActions = dAg_lActions
def getActions(self): def getActions(self):
iStep = self.env._elapsed_steps iStep = self.env._elapsed_steps
dAg_Action = {} dAg_Action = {}
for iAg, lActions in self.dAg_lActions.items(): for iAg, lActions in self.dAg_lActions.items():
if iStep < len(lActions): if iStep < len(lActions):
...@@ -97,16 +91,13 @@ class Deterministic(Behaviour): ...@@ -97,16 +91,13 @@ class Deterministic(Behaviour):
else: else:
iAct = RailEnvActions.DO_NOTHING iAct = RailEnvActions.DO_NOTHING
dAg_Action[iAg] = iAct dAg_Action[iAg] = iAct
#print(iStep, dAg_Action[0]) # print(iStep, dAg_Action[0])
return dAg_Action return dAg_Action
class EnvCanvas(): class EnvCanvas():
def __init__(self, env, behaviour:Behaviour=None): def __init__(self, env, behaviour: Behaviour = None):
self.env = env self.env = env
self.iStep = 0 self.iStep = 0
if behaviour is None: if behaviour is None:
...@@ -114,11 +105,11 @@ class EnvCanvas(): ...@@ -114,11 +105,11 @@ class EnvCanvas():
self.behaviour = behaviour self.behaviour = behaviour
self.oRT = RenderTool(env, show_debug=True) self.oRT = RenderTool(env, show_debug=True)
self.oCan = canvas.Canvas(size=(600,300)) self.oCan = canvas.Canvas(size=(600, 300))
self.render() self.render()
def render(self): def render(self):
self.oRT.render_env(show_rowcols=True, show_inactive_agents=False, show_observations=False) self.oRT.render_env(show_rowcols=True, show_inactive_agents=False, show_observations=False)
self.oCan.put_image_data(self.oRT.get_image()) self.oCan.put_image_data(self.oRT.get_image())
def step(self): def step(self):
...@@ -128,5 +119,3 @@ class EnvCanvas(): ...@@ -128,5 +119,3 @@ class EnvCanvas():
def show(self): def show(self):
self.render() self.render()
display.display(self.oCan) display.display(self.oCan)
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
# Simple Animation Demo # Simple Animation Demo
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
%load_ext autoreload %load_ext autoreload
%autoreload 2 %autoreload 2
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
from flatland.utils.rendertools import RenderTool
import ipycanvas
import time import time
from IPython import display
from ipycanvas import canvas
from flatland.utils.rendertools import RenderTool
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env import RailEnvActions as rea from flatland.envs.rail_env import RailEnvActions as rea
from flatland.envs.persistence import RailEnvPersister from flatland.envs.persistence import RailEnvPersister
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
env, env_dict = RailEnvPersister.load_new("complex_scene_2.pkl", load_from_package="env_data.railway") env, env_dict = RailEnvPersister.load_new("complex_scene_2.pkl", load_from_package="env_data.railway")
_ = env.reset() _ = env.reset()
env._max_episode_steps = 100 env._max_episode_steps = 100
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
oRT = RenderTool(env, gl="PILSVG") oRT = RenderTool(env, gl="PILSVG", jupyter=False, show_debug=True)
oRT.render_env(show_observations=False,show_predictions=False)
```
%% Cell type:code id: tags:
``` python
image_arr = oRT.get_image() image_arr = oRT.get_image()
canvas = ipycanvas.Canvas() oCanvas = canvas.Canvas()
canvas.put_image_data(image_arr) oCanvas.put_image_data(image_arr)
display(canvas) display.display(oCanvas)
done={"__all__":False} done={"__all__":False}
while not done["__all__"]: while not done["__all__"]:
actions = {} actions = {}
for agent_handle, agents in enumerate(env.agents): for agent_handle, agents in enumerate(env.agents):
actions.update({agent_handle:rea.MOVE_FORWARD}) actions.update({agent_handle:rea.MOVE_FORWARD})
obs, rew, done, info = env.step(actions) obs, rew, done, info = env.step(actions)
oRT.render_env(show_observations=False,show_predictions=False) oRT.render_env(show_observations=False,show_predictions=False)
gIm = oRT.get_image() gIm = oRT.get_image()
canvas.put_image_data(gIm) oCanvas.put_image_data(gIm)
time.sleep(0.1) time.sleep(0.1)
``` ```
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment