Commit 9623a8d6 authored by hagrid67's avatar hagrid67

made the env loading more forgiving (in persistence) to allow the loading of...

made the env loading more forgiving (in persistence) to allow the loading of old envs, and msgpack files with .pkl extension.  Also give default 100 max_episode_steps where this is missing.  Add test-saved-envs notebook to test loading all old envs.
parent 67433db1
Pipeline #5283 failed with stages
in 35 minutes and 11 seconds
......@@ -119,8 +119,20 @@ class EnvAgent:
def load_legacy_static_agent(cls, static_agents_data: Tuple):
agents = []
for i, static_agent in enumerate(static_agents_data):
agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1],
direction=static_agent[1], target=static_agent[2], moving=static_agent[3],
speed_data=static_agent[4], malfunction_data=static_agent[5], handle=i)
if len(static_agent) >= 6:
agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1],
direction=static_agent[1], target=static_agent[2], moving=static_agent[3],
speed_data=static_agent[4], malfunction_data=static_agent[5], handle=i)
else:
agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1],
direction=static_agent[1], target=static_agent[2],
moving=False,
speed_data={"speed":1., "position_fraction":0., "transition_action_on_cell_exit":0.},
malfunction_data={
'malfunction': 0,
'nr_malfunctions': 0,
'moving_before_malfunction': False
},
handle=i)
agents.append(agent)
return agents
......@@ -2,6 +2,7 @@
import pickle
import msgpack
import msgpack_numpy
import numpy as np
from flatland.envs import rail_env
......@@ -22,6 +23,7 @@ from flatland.envs import malfunction_generators as mal_gen
from flatland.envs import rail_generators as rail_gen
from flatland.envs import schedule_generators as sched_gen
msgpack_numpy.patch()
class RailEnvPersister(object):
......@@ -111,9 +113,13 @@ class RailEnvPersister(object):
env_dict = cls.load_env_dict(filename, load_from_package=load_from_package)
llGrid = env_dict["grid"]
height = len(llGrid)
width = len(llGrid[0])
# TODO: inefficient - each one of these generators loads the complete env file.
env = rail_env.RailEnv(width=1, height=1,
env = rail_env.RailEnv(#width=1, height=1,
width=width, height=height,
rail_generator=rail_gen.rail_from_file(filename,
load_from_package=load_from_package),
schedule_generator=sched_gen.schedule_from_file(filename,
......@@ -141,7 +147,11 @@ class RailEnvPersister(object):
if filename.endswith("mpk"):
env_dict = msgpack.unpackb(load_data, use_list=False, encoding="utf-8")
elif filename.endswith("pkl"):
env_dict = pickle.loads(load_data)
try:
env_dict = pickle.loads(load_data)
except ValueError:
print("pickle failed to load file:", filename, " trying msgpack (deprecated)...")
env_dict = msgpack.unpackb(load_data, use_list=False, encoding="utf-8")
else:
print(f"filename {filename} must end with either pkl or mpk")
env_dict = {}
......
......@@ -323,6 +323,10 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator:
env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package)
max_episode_steps = env_dict.get("max_episode_steps", 0)
if (max_episode_steps==0):
print("This env file has no max_episode_steps (deprecated) - setting to 100")
max_episode_steps = 100
agents = env_dict["agents"]
#print("schedule generator from_file - agents: ", agents)
......@@ -335,6 +339,8 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator:
agents_direction = [a.initial_direction for a in agents]
agents_target = [a.target for a in agents]
agents_speed = [a.speed_data['speed'] for a in agents]
# Malfunctions from here are not used. They have their own generator.
#agents_malfunction = [a.malfunction_data['malfunction_rate'] for a in agents]
return Schedule(agent_positions=agents_position, agent_directions=agents_direction,
......
......@@ -70,9 +70,7 @@ def makeEnv2(nAg=2, shape=(20,10), llrcPaths=[], lrcStarts=[], lrcTargs=[], liDi
return env, envModel
def makeTestEnv(sName="single_alternative", nAg=2, bUCF=True):
ddEnvSpecs = {
ddEnvSpecs = {
# opposing stations with single alternative path
"single_alternative":{
"llrcPaths": [
......@@ -116,10 +114,34 @@ def makeTestEnv(sName="single_alternative", nAg=2, bUCF=True):
"lrcStarts": [(1,3)],
"lrcTargs": [(2,1)],
"liDirs": [1]
},
# two loops
"loop_with_loops": {
"llrcPaths": [
# big outer loop Row 1, 8; Col 1, 15
[(1,1), (1,15), (8, 15), (8,1), (1,1), (1,3)],
# alternative 1
[(1,3), (1,5), (3,5), (3,10), (1, 10), (1, 12)],
# alternative 2
[(8,3), (8,5), (6,5), (6,10), (8, 10), (8, 12)],
],
# list of row,col of agent start cells
"lrcStarts": [(1,3), (8, 3)],
# list of row,col of targets
"lrcTargs": [(8,2), (1,2)],
# list of initial directions
"liDirs": [1, 1],
}
}
def makeTestEnv(sName="single_alternative", nAg=2, bUCF=True):
global ddEnvSpecs
dSpec = ddEnvSpecs[sName]
return makeEnv2(nAg=nAg, bUCF=bUCF, **dSpec)
......
......@@ -11,6 +11,7 @@ from flatland.envs.agent_utils import EnvAgent
from flatland.envs import rail_generators as rail_gen
from flatland.envs import agent_chains as ac
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.persistence import RailEnvPersister
from flatland.utils.rendertools import RenderTool
from flatland.utils import env_edit_utils as eeu
......@@ -71,6 +72,27 @@ class ForwardWithPause(Behaviour):
return dAction
class Deterministic(Behaviour):
def __init__(self, env, dAg_lActions):
super().__init__(env)
self.dAg_lActions = dAg_lActions
def getActions(self):
iStep = self.env._elapsed_steps
dAg_Action = {}
for iAg, lActions in self.dAg_lActions.items():
if iStep < len(lActions):
iAct = lActions[iStep]
else:
iAct = RailEnvActions.DO_NOTHING
dAg_Action[iAg] = iAct
#print(iStep, dAg_Action[0])
return dAg_Action
class EnvCanvas():
......@@ -86,7 +108,7 @@ class EnvCanvas():
self.render()
def render(self):
self.oRT.render_env(show_rowcols=True, show_inactive_agents=True, show_observations=False)
self.oRT.render_env(show_rowcols=True, show_inactive_agents=False, show_observations=False)
self.oCan.put_image_data(self.oRT.get_image())
def step(self):
......
......@@ -5,6 +5,7 @@ from subprocess import Popen, PIPE
import importlib_resources
import pkg_resources
from importlib_resources import path
import importlib_resources as ir
from ipython_genutils.py3compat import string_types, bytes_to_str
......@@ -38,17 +39,37 @@ def run_python(parameters, ignore_return_code=False, stdin=None):
return stdout.decode('utf8', 'replace'), stderr.decode('utf8', 'replace')
for entry in [entry for entry in importlib_resources.contents('notebooks') if
not pkg_resources.resource_isdir('notebooks', entry)
and entry.endswith(".ipynb")
]:
print("*****************************************************************")
print("Converting and running {}".format(entry))
print("*****************************************************************")
with path('notebooks', entry) as file_in:
out, err = run_python(" -m jupyter nbconvert --execute --to notebook --inplace " + str(file_in))
sys.stderr.write(err)
sys.stderr.flush()
sys.stdout.write(out)
sys.stdout.flush()
def main():
# If the file notebooks-list exists, use it as a definitive list of notebooks to run
# This in effect ignores any local notebooks you might be working on, so you can run tox
# without them causing the notebooks task / testenv to fail.
if importlib_resources.is_resource("notebooks", "notebook-list"):
print("Using the notebooks-list file to designate which notebooks to run")
lsNB = [
sLine for sLine in ir.read_text("notebooks", "notebook-list").split("\n")
if len(sLine) > 3 and not sLine.startswith("#")
]
else:
lsNB = [
entry for entry in importlib_resources.contents('notebooks') if
not pkg_resources.resource_isdir('notebooks', entry)
and entry.endswith(".ipynb")
]
print("Running notebooks:", " ".join(lsNB))
for entry in lsNB:
print("*****************************************************************")
print("Converting and running {}".format(entry))
print("*****************************************************************")
with path('notebooks', entry) as file_in:
out, err = run_python(" -m jupyter nbconvert --execute --to notebook --inplace " + str(file_in))
sys.stderr.write(err)
sys.stderr.flush()
sys.stdout.write(out)
sys.stdout.flush()
if __name__ == "__main__":
main()
\ No newline at end of file
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load some (old) env files to check they work\n",
"This notebook just loads some old env files, renders them, and runs a few steps.\n",
"This is just a sanity check that these old envs will still load.\n",
"Many of them use deprecated data formats so it's just so that we can avoid deleting them for now, and so new participants are not confused by us shipping env files which don't work..."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython import display\n",
"display.display(display.HTML(\"<style>.container { width:95% !important; }</style>\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import PIL\n",
"import glob\n",
"import pickle\n",
"import msgpack"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from flatland.envs.persistence import RailEnvPersister"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from flatland.utils import env_edit_utils as eeu\n",
"from flatland.utils import jupyter_utils as ju"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"lsDirs = [ \"../env_data/railway\", \"../env_data/tests\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"lsFiles = []\n",
"for sDir in lsDirs:\n",
" for sExt in [\"mpk\", \"pkl\"]:\n",
" lsFiles += glob.glob(sDir + \"/*\" + sExt)\n",
"lsFiles"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for sFile in lsFiles:\n",
" try:\n",
" with open(sFile, \"rb\") as fIn:\n",
" env_dict = pickle.load(fIn)\n",
" print(\"pickle:\", sFile)\n",
" except ValueError as oErr:\n",
" try:\n",
" with open(sFile, \"rb\") as fIn:\n",
" env_dict = msgpack.load(fIn)\n",
" print(\"msgpack: \", sFile)\n",
" except ValueError as oErr:\n",
" print(\"msgpack failed: \", sFile)\n",
" \n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"for sFile in lsFiles:\n",
" print(\"Loading: \", sFile)\n",
" env, env_dict = RailEnvPersister.load_new(sFile)\n",
" env.reset()\n",
" oCanvas = ju.EnvCanvas(env, ju.AlwaysForward(env))\n",
" oCanvas.show()\n",
" for iStep in range(10):\n",
" oCanvas.step()\n",
" oCanvas.render()"
]
}
],
"metadata": {
"hide_input": false,
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
},
"latex_envs": {
"LaTeX_envs_menu_present": true,
"autoclose": false,
"autocomplete": true,
"bibliofile": "biblio.bib",
"cite_by": "apalike",
"current_citInitial": 1,
"eqLabelWithNumbers": true,
"eqNumInitial": 1,
"hotkeys": {
"equation": "Ctrl-E",
"itemize": "Ctrl-I"
},
"labels_anchors": false,
"latex_user_defs": false,
"report_style_numbering": false,
"user_envs_cfg": false
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment