Skip to content
Snippets Groups Projects
Commit 9623a8d6 authored by hagrid67's avatar hagrid67
Browse files

made the env loading more forgiving (in persistence) to allow the loading of...

made the env loading more forgiving (in persistence) to allow the loading of old envs, and msgpack files with .pkl extension.  Also give default 100 max_episode_steps where this is missing.  Add test-saved-envs notebook to test loading all old envs.
parent 67433db1
No related branches found
No related tags found
No related merge requests found
...@@ -119,8 +119,20 @@ class EnvAgent: ...@@ -119,8 +119,20 @@ class EnvAgent:
def load_legacy_static_agent(cls, static_agents_data: Tuple): def load_legacy_static_agent(cls, static_agents_data: Tuple):
agents = [] agents = []
for i, static_agent in enumerate(static_agents_data): for i, static_agent in enumerate(static_agents_data):
agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1], if len(static_agent) >= 6:
direction=static_agent[1], target=static_agent[2], moving=static_agent[3], agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1],
speed_data=static_agent[4], malfunction_data=static_agent[5], handle=i) direction=static_agent[1], target=static_agent[2], moving=static_agent[3],
speed_data=static_agent[4], malfunction_data=static_agent[5], handle=i)
else:
agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1],
direction=static_agent[1], target=static_agent[2],
moving=False,
speed_data={"speed":1., "position_fraction":0., "transition_action_on_cell_exit":0.},
malfunction_data={
'malfunction': 0,
'nr_malfunctions': 0,
'moving_before_malfunction': False
},
handle=i)
agents.append(agent) agents.append(agent)
return agents return agents
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import pickle import pickle
import msgpack import msgpack
import msgpack_numpy
import numpy as np import numpy as np
from flatland.envs import rail_env from flatland.envs import rail_env
...@@ -22,6 +23,7 @@ from flatland.envs import malfunction_generators as mal_gen ...@@ -22,6 +23,7 @@ from flatland.envs import malfunction_generators as mal_gen
from flatland.envs import rail_generators as rail_gen from flatland.envs import rail_generators as rail_gen
from flatland.envs import schedule_generators as sched_gen from flatland.envs import schedule_generators as sched_gen
msgpack_numpy.patch()
class RailEnvPersister(object): class RailEnvPersister(object):
...@@ -111,9 +113,13 @@ class RailEnvPersister(object): ...@@ -111,9 +113,13 @@ class RailEnvPersister(object):
env_dict = cls.load_env_dict(filename, load_from_package=load_from_package) env_dict = cls.load_env_dict(filename, load_from_package=load_from_package)
llGrid = env_dict["grid"]
height = len(llGrid)
width = len(llGrid[0])
# TODO: inefficient - each one of these generators loads the complete env file. # TODO: inefficient - each one of these generators loads the complete env file.
env = rail_env.RailEnv(width=1, height=1, env = rail_env.RailEnv(#width=1, height=1,
width=width, height=height,
rail_generator=rail_gen.rail_from_file(filename, rail_generator=rail_gen.rail_from_file(filename,
load_from_package=load_from_package), load_from_package=load_from_package),
schedule_generator=sched_gen.schedule_from_file(filename, schedule_generator=sched_gen.schedule_from_file(filename,
...@@ -141,7 +147,11 @@ class RailEnvPersister(object): ...@@ -141,7 +147,11 @@ class RailEnvPersister(object):
if filename.endswith("mpk"): if filename.endswith("mpk"):
env_dict = msgpack.unpackb(load_data, use_list=False, encoding="utf-8") env_dict = msgpack.unpackb(load_data, use_list=False, encoding="utf-8")
elif filename.endswith("pkl"): elif filename.endswith("pkl"):
env_dict = pickle.loads(load_data) try:
env_dict = pickle.loads(load_data)
except ValueError:
print("pickle failed to load file:", filename, " trying msgpack (deprecated)...")
env_dict = msgpack.unpackb(load_data, use_list=False, encoding="utf-8")
else: else:
print(f"filename {filename} must end with either pkl or mpk") print(f"filename {filename} must end with either pkl or mpk")
env_dict = {} env_dict = {}
......
...@@ -323,6 +323,10 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator: ...@@ -323,6 +323,10 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator:
env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package) env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package)
max_episode_steps = env_dict.get("max_episode_steps", 0) max_episode_steps = env_dict.get("max_episode_steps", 0)
if (max_episode_steps==0):
print("This env file has no max_episode_steps (deprecated) - setting to 100")
max_episode_steps = 100
agents = env_dict["agents"] agents = env_dict["agents"]
#print("schedule generator from_file - agents: ", agents) #print("schedule generator from_file - agents: ", agents)
...@@ -335,6 +339,8 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator: ...@@ -335,6 +339,8 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator:
agents_direction = [a.initial_direction for a in agents] agents_direction = [a.initial_direction for a in agents]
agents_target = [a.target for a in agents] agents_target = [a.target for a in agents]
agents_speed = [a.speed_data['speed'] for a in agents] agents_speed = [a.speed_data['speed'] for a in agents]
# Malfunctions from here are not used. They have their own generator.
#agents_malfunction = [a.malfunction_data['malfunction_rate'] for a in agents] #agents_malfunction = [a.malfunction_data['malfunction_rate'] for a in agents]
return Schedule(agent_positions=agents_position, agent_directions=agents_direction, return Schedule(agent_positions=agents_position, agent_directions=agents_direction,
......
...@@ -70,9 +70,7 @@ def makeEnv2(nAg=2, shape=(20,10), llrcPaths=[], lrcStarts=[], lrcTargs=[], liDi ...@@ -70,9 +70,7 @@ def makeEnv2(nAg=2, shape=(20,10), llrcPaths=[], lrcStarts=[], lrcTargs=[], liDi
return env, envModel return env, envModel
def makeTestEnv(sName="single_alternative", nAg=2, bUCF=True): ddEnvSpecs = {
ddEnvSpecs = {
# opposing stations with single alternative path # opposing stations with single alternative path
"single_alternative":{ "single_alternative":{
"llrcPaths": [ "llrcPaths": [
...@@ -116,10 +114,34 @@ def makeTestEnv(sName="single_alternative", nAg=2, bUCF=True): ...@@ -116,10 +114,34 @@ def makeTestEnv(sName="single_alternative", nAg=2, bUCF=True):
"lrcStarts": [(1,3)], "lrcStarts": [(1,3)],
"lrcTargs": [(2,1)], "lrcTargs": [(2,1)],
"liDirs": [1] "liDirs": [1]
},
# two loops
"loop_with_loops": {
"llrcPaths": [
# big outer loop Row 1, 8; Col 1, 15
[(1,1), (1,15), (8, 15), (8,1), (1,1), (1,3)],
# alternative 1
[(1,3), (1,5), (3,5), (3,10), (1, 10), (1, 12)],
# alternative 2
[(8,3), (8,5), (6,5), (6,10), (8, 10), (8, 12)],
],
# list of row,col of agent start cells
"lrcStarts": [(1,3), (8, 3)],
# list of row,col of targets
"lrcTargs": [(8,2), (1,2)],
# list of initial directions
"liDirs": [1, 1],
} }
} }
def makeTestEnv(sName="single_alternative", nAg=2, bUCF=True):
global ddEnvSpecs
dSpec = ddEnvSpecs[sName] dSpec = ddEnvSpecs[sName]
return makeEnv2(nAg=nAg, bUCF=bUCF, **dSpec) return makeEnv2(nAg=nAg, bUCF=bUCF, **dSpec)
......
...@@ -11,6 +11,7 @@ from flatland.envs.agent_utils import EnvAgent ...@@ -11,6 +11,7 @@ from flatland.envs.agent_utils import EnvAgent
from flatland.envs import rail_generators as rail_gen from flatland.envs import rail_generators as rail_gen
from flatland.envs import agent_chains as ac from flatland.envs import agent_chains as ac
from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.persistence import RailEnvPersister from flatland.envs.persistence import RailEnvPersister
from flatland.utils.rendertools import RenderTool from flatland.utils.rendertools import RenderTool
from flatland.utils import env_edit_utils as eeu from flatland.utils import env_edit_utils as eeu
...@@ -71,6 +72,27 @@ class ForwardWithPause(Behaviour): ...@@ -71,6 +72,27 @@ class ForwardWithPause(Behaviour):
return dAction return dAction
class Deterministic(Behaviour):
def __init__(self, env, dAg_lActions):
super().__init__(env)
self.dAg_lActions = dAg_lActions
def getActions(self):
iStep = self.env._elapsed_steps
dAg_Action = {}
for iAg, lActions in self.dAg_lActions.items():
if iStep < len(lActions):
iAct = lActions[iStep]
else:
iAct = RailEnvActions.DO_NOTHING
dAg_Action[iAg] = iAct
#print(iStep, dAg_Action[0])
return dAg_Action
class EnvCanvas(): class EnvCanvas():
...@@ -86,7 +108,7 @@ class EnvCanvas(): ...@@ -86,7 +108,7 @@ class EnvCanvas():
self.render() self.render()
def render(self): def render(self):
self.oRT.render_env(show_rowcols=True, show_inactive_agents=True, show_observations=False) self.oRT.render_env(show_rowcols=True, show_inactive_agents=False, show_observations=False)
self.oCan.put_image_data(self.oRT.get_image()) self.oCan.put_image_data(self.oRT.get_image())
def step(self): def step(self):
......
...@@ -5,6 +5,7 @@ from subprocess import Popen, PIPE ...@@ -5,6 +5,7 @@ from subprocess import Popen, PIPE
import importlib_resources import importlib_resources
import pkg_resources import pkg_resources
from importlib_resources import path from importlib_resources import path
import importlib_resources as ir
from ipython_genutils.py3compat import string_types, bytes_to_str from ipython_genutils.py3compat import string_types, bytes_to_str
...@@ -38,17 +39,37 @@ def run_python(parameters, ignore_return_code=False, stdin=None): ...@@ -38,17 +39,37 @@ def run_python(parameters, ignore_return_code=False, stdin=None):
return stdout.decode('utf8', 'replace'), stderr.decode('utf8', 'replace') return stdout.decode('utf8', 'replace'), stderr.decode('utf8', 'replace')
for entry in [entry for entry in importlib_resources.contents('notebooks') if def main():
not pkg_resources.resource_isdir('notebooks', entry)
and entry.endswith(".ipynb") # If the file notebooks-list exists, use it as a definitive list of notebooks to run
]: # This in effect ignores any local notebooks you might be working on, so you can run tox
print("*****************************************************************") # without them causing the notebooks task / testenv to fail.
print("Converting and running {}".format(entry)) if importlib_resources.is_resource("notebooks", "notebook-list"):
print("*****************************************************************") print("Using the notebooks-list file to designate which notebooks to run")
lsNB = [
with path('notebooks', entry) as file_in: sLine for sLine in ir.read_text("notebooks", "notebook-list").split("\n")
out, err = run_python(" -m jupyter nbconvert --execute --to notebook --inplace " + str(file_in)) if len(sLine) > 3 and not sLine.startswith("#")
sys.stderr.write(err) ]
sys.stderr.flush() else:
sys.stdout.write(out) lsNB = [
sys.stdout.flush() entry for entry in importlib_resources.contents('notebooks') if
not pkg_resources.resource_isdir('notebooks', entry)
and entry.endswith(".ipynb")
]
print("Running notebooks:", " ".join(lsNB))
for entry in lsNB:
print("*****************************************************************")
print("Converting and running {}".format(entry))
print("*****************************************************************")
with path('notebooks', entry) as file_in:
out, err = run_python(" -m jupyter nbconvert --execute --to notebook --inplace " + str(file_in))
sys.stderr.write(err)
sys.stderr.flush()
sys.stdout.write(out)
sys.stdout.flush()
if __name__ == "__main__":
main()
\ No newline at end of file
%% Cell type:markdown id: tags:
## Load some (old) env files to check they work
This notebook just loads some old env files, renders them, and runs a few steps.
This is just a sanity check that these old envs will still load.
Many of them use deprecated data formats so it's just so that we can avoid deleting them for now, and so new participants are not confused by us shipping env files which don't work...
%% Cell type:code id: tags:
``` python
%load_ext autoreload
%autoreload 2
```
%% Cell type:code id: tags:
``` python
from IPython import display
display.display(display.HTML("<style>.container { width:95% !important; }</style>"))
```
%% Cell type:code id: tags:
``` python
import PIL
import glob
import pickle
import msgpack
```
%% Cell type:code id: tags:
``` python
from flatland.envs.persistence import RailEnvPersister
```
%% Cell type:code id: tags:
``` python
from flatland.utils import env_edit_utils as eeu
from flatland.utils import jupyter_utils as ju
```
%% Cell type:code id: tags:
``` python
lsDirs = [ "../env_data/railway", "../env_data/tests"]
```
%% Cell type:code id: tags:
``` python
lsFiles = []
for sDir in lsDirs:
for sExt in ["mpk", "pkl"]:
lsFiles += glob.glob(sDir + "/*" + sExt)
lsFiles
```
%% Cell type:code id: tags:
``` python
for sFile in lsFiles:
try:
with open(sFile, "rb") as fIn:
env_dict = pickle.load(fIn)
print("pickle:", sFile)
except ValueError as oErr:
try:
with open(sFile, "rb") as fIn:
env_dict = msgpack.load(fIn)
print("msgpack: ", sFile)
except ValueError as oErr:
print("msgpack failed: ", sFile)
```
%% Cell type:code id: tags:
``` python
for sFile in lsFiles:
print("Loading: ", sFile)
env, env_dict = RailEnvPersister.load_new(sFile)
env.reset()
oCanvas = ju.EnvCanvas(env, ju.AlwaysForward(env))
oCanvas.show()
for iStep in range(10):
oCanvas.step()
oCanvas.render()
```
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment