From 9623a8d66b413d364d752ada3c2b3387ebde7f00 Mon Sep 17 00:00:00 2001
From: hagrid67 <jdhwatson@gmail.com>
Date: Sat, 5 Sep 2020 19:36:46 +0100
Subject: [PATCH] made the env loading more forgiving (in persistence) to allow
 the loading of old envs, and msgpack files with .pkl extension.  Also give
 default 100 max_episode_steps where this is missing.  Add test-saved-envs
 notebook to test loading all old envs.

---
 flatland/envs/agent_utils.py         |  18 ++-
 flatland/envs/persistence.py         |  14 +-
 flatland/envs/schedule_generators.py |   6 +
 flatland/utils/env_edit_utils.py     |  28 +++-
 flatland/utils/jupyter_utils.py      |  24 ++-
 notebooks/run_all_notebooks.py       |  49 +++++--
 notebooks/test-saved-envs.ipynb      | 210 +++++++++++++++++++++++++++
 7 files changed, 326 insertions(+), 23 deletions(-)
 create mode 100644 notebooks/test-saved-envs.ipynb

diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index d0f9d941..ef0701ac 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -119,8 +119,20 @@ class EnvAgent:
     def load_legacy_static_agent(cls, static_agents_data: Tuple):
         agents = []
         for i, static_agent in enumerate(static_agents_data):
-            agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1],
-                             direction=static_agent[1], target=static_agent[2], moving=static_agent[3],
-                             speed_data=static_agent[4], malfunction_data=static_agent[5], handle=i)
+            if len(static_agent) >= 6:
+                agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1],
+                                direction=static_agent[1], target=static_agent[2], moving=static_agent[3],
+                                speed_data=static_agent[4], malfunction_data=static_agent[5], handle=i)
+            else:
+                agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1],
+                                direction=static_agent[1], target=static_agent[2], 
+                                moving=False,
+                                speed_data={"speed":1., "position_fraction":0., "transition_action_on_cell_exit":0.},
+                                malfunction_data={
+                                            'malfunction': 0,
+                                            'nr_malfunctions': 0,
+                                            'moving_before_malfunction': False
+                                        },
+                                handle=i)
             agents.append(agent)
         return agents
diff --git a/flatland/envs/persistence.py b/flatland/envs/persistence.py
index 8ed7ec8a..1b0f05f1 100644
--- a/flatland/envs/persistence.py
+++ b/flatland/envs/persistence.py
@@ -2,6 +2,7 @@
 
 import pickle
 import msgpack
+import msgpack_numpy
 import numpy as np
 
 from flatland.envs import rail_env 
@@ -22,6 +23,7 @@ from flatland.envs import malfunction_generators as mal_gen
 from flatland.envs import rail_generators as rail_gen
 from flatland.envs import schedule_generators as sched_gen
 
+msgpack_numpy.patch()
 
 class RailEnvPersister(object):
 
@@ -111,9 +113,13 @@ class RailEnvPersister(object):
 
         env_dict = cls.load_env_dict(filename, load_from_package=load_from_package)
 
+        llGrid = env_dict["grid"]
+        height = len(llGrid)
+        width = len(llGrid[0])
 
         # TODO: inefficient - each one of these generators loads the complete env file.
-        env = rail_env.RailEnv(width=1, height=1,
+        env = rail_env.RailEnv(#width=1, height=1,
+                width=width, height=height,
                 rail_generator=rail_gen.rail_from_file(filename, 
                     load_from_package=load_from_package),
                 schedule_generator=sched_gen.schedule_from_file(filename,
@@ -141,7 +147,11 @@ class RailEnvPersister(object):
         if filename.endswith("mpk"):
             env_dict = msgpack.unpackb(load_data, use_list=False, encoding="utf-8")
         elif filename.endswith("pkl"):
-            env_dict = pickle.loads(load_data)
+            try:
+                env_dict = pickle.loads(load_data)
+            except ValueError:
+                print("pickle failed to load file:", filename, " trying msgpack (deprecated)...")
+                env_dict = msgpack.unpackb(load_data, use_list=False, encoding="utf-8")
         else:
             print(f"filename {filename} must end with either pkl or mpk")
             env_dict = {}
diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py
index 7d5656dd..c1789435 100644
--- a/flatland/envs/schedule_generators.py
+++ b/flatland/envs/schedule_generators.py
@@ -323,6 +323,10 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator:
         env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package)
 
         max_episode_steps = env_dict.get("max_episode_steps", 0)
+        if (max_episode_steps==0):
+            print("This env file has no max_episode_steps (deprecated) - setting to 100")
+            max_episode_steps = 100
+            
         agents = env_dict["agents"]
 
         #print("schedule generator from_file - agents: ", agents)
@@ -335,6 +339,8 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator:
         agents_direction = [a.initial_direction for a in agents]
         agents_target = [a.target for a in agents]
         agents_speed = [a.speed_data['speed'] for a in agents]
+
+        # Malfunctions from here are not used.  They have their own generator.
         #agents_malfunction = [a.malfunction_data['malfunction_rate'] for a in agents]
 
         return Schedule(agent_positions=agents_position, agent_directions=agents_direction,
diff --git a/flatland/utils/env_edit_utils.py b/flatland/utils/env_edit_utils.py
index b1a40174..bf6aa32a 100644
--- a/flatland/utils/env_edit_utils.py
+++ b/flatland/utils/env_edit_utils.py
@@ -70,9 +70,7 @@ def makeEnv2(nAg=2, shape=(20,10), llrcPaths=[], lrcStarts=[], lrcTargs=[], liDi
     return env, envModel
 
 
-def makeTestEnv(sName="single_alternative", nAg=2, bUCF=True):
-
-    ddEnvSpecs = {
+ddEnvSpecs = {
         # opposing stations with single alternative path
         "single_alternative":{
             "llrcPaths":  [
@@ -116,10 +114,34 @@ def makeTestEnv(sName="single_alternative", nAg=2, bUCF=True):
             "lrcStarts": [(1,3)],
             "lrcTargs": [(2,1)],
             "liDirs":  [1]
+            },
+
+        # two loops
+        "loop_with_loops": {
+            "llrcPaths": [
+                # big outer loop Row 1, 8; Col 1, 15
+                [(1,1), (1,15), (8, 15), (8,1), (1,1), (1,3)],
+                # alternative 1
+                [(1,3), (1,5), (3,5), (3,10), (1, 10), (1, 12)],
+                # alternative 2
+                [(8,3), (8,5), (6,5), (6,10), (8, 10), (8, 12)],
+                
+                ],
+            
+            # list of row,col of agent start cells
+            "lrcStarts": [(1,3), (8, 3)],
+            # list of row,col of targets
+            "lrcTargs": [(8,2), (1,2)],
+            # list of initial directions
+            "liDirs":  [1, 1], 
             }
 
         }
     
+
+def makeTestEnv(sName="single_alternative", nAg=2, bUCF=True):
+    global ddEnvSpecs
+    
     dSpec = ddEnvSpecs[sName]
 
     return makeEnv2(nAg=nAg, bUCF=bUCF, **dSpec)
diff --git a/flatland/utils/jupyter_utils.py b/flatland/utils/jupyter_utils.py
index 6dd47501..3b7bc3e0 100644
--- a/flatland/utils/jupyter_utils.py
+++ b/flatland/utils/jupyter_utils.py
@@ -11,6 +11,7 @@ from flatland.envs.agent_utils import EnvAgent
 from flatland.envs import rail_generators as rail_gen
 from flatland.envs import agent_chains as ac
 from flatland.envs.rail_env import RailEnv, RailEnvActions
+
 from flatland.envs.persistence import RailEnvPersister
 from flatland.utils.rendertools import RenderTool
 from flatland.utils import env_edit_utils as eeu
@@ -71,6 +72,27 @@ class ForwardWithPause(Behaviour):
         
         return dAction
 
+class Deterministic(Behaviour):
+    def __init__(self, env, dAg_lActions):
+        super().__init__(env)
+        self.dAg_lActions = dAg_lActions
+    
+    def getActions(self):
+        iStep = self.env._elapsed_steps
+        
+        dAg_Action = {}
+        for iAg, lActions in self.dAg_lActions.items():
+            if iStep < len(lActions):
+                iAct = lActions[iStep]
+            else:
+                iAct = RailEnvActions.DO_NOTHING
+            dAg_Action[iAg] = iAct
+        #print(iStep, dAg_Action[0])
+        return dAg_Action
+
+
+
+
 
 class EnvCanvas():
 
@@ -86,7 +108,7 @@ class EnvCanvas():
         self.render()
 
     def render(self):
-        self.oRT.render_env(show_rowcols=True,  show_inactive_agents=True, show_observations=False)
+        self.oRT.render_env(show_rowcols=True,  show_inactive_agents=False, show_observations=False)
         self.oCan.put_image_data(self.oRT.get_image())
 
     def step(self):
diff --git a/notebooks/run_all_notebooks.py b/notebooks/run_all_notebooks.py
index 6facfa1a..98ed6a06 100644
--- a/notebooks/run_all_notebooks.py
+++ b/notebooks/run_all_notebooks.py
@@ -5,6 +5,7 @@ from subprocess import Popen, PIPE
 import importlib_resources
 import pkg_resources
 from importlib_resources import path
+import importlib_resources as ir
 from ipython_genutils.py3compat import string_types, bytes_to_str
 
 
@@ -38,17 +39,37 @@ def run_python(parameters, ignore_return_code=False, stdin=None):
     return stdout.decode('utf8', 'replace'), stderr.decode('utf8', 'replace')
 
 
-for entry in [entry for entry in importlib_resources.contents('notebooks') if
-              not pkg_resources.resource_isdir('notebooks', entry)
-              and entry.endswith(".ipynb")
-              ]:
-    print("*****************************************************************")
-    print("Converting and running {}".format(entry))
-    print("*****************************************************************")
-
-    with path('notebooks', entry) as file_in:
-        out, err = run_python(" -m jupyter nbconvert --execute --to notebook --inplace " + str(file_in))
-        sys.stderr.write(err)
-        sys.stderr.flush()
-        sys.stdout.write(out)
-        sys.stdout.flush()
+def main():
+
+    # If the file notebooks-list exists, use it as a definitive list of notebooks to run
+    # This in effect ignores any local notebooks you might be working on, so you can run tox
+    # without them causing the notebooks task / testenv to fail.
+    if importlib_resources.is_resource("notebooks", "notebook-list"):
+        print("Using the notebooks-list file to designate which notebooks to run")
+        lsNB = [
+            sLine for sLine in ir.read_text("notebooks", "notebook-list").split("\n") 
+            if len(sLine) > 3 and not sLine.startswith("#")
+            ]
+    else:
+        lsNB = [
+            entry for entry in importlib_resources.contents('notebooks') if
+                not pkg_resources.resource_isdir('notebooks', entry)
+                and entry.endswith(".ipynb")
+                ]
+
+    print("Running notebooks:", " ".join(lsNB))
+
+    for entry in lsNB:
+        print("*****************************************************************")
+        print("Converting and running {}".format(entry))
+        print("*****************************************************************")
+
+        with path('notebooks', entry) as file_in:
+            out, err = run_python(" -m jupyter nbconvert --execute --to notebook --inplace " + str(file_in))
+            sys.stderr.write(err)
+            sys.stderr.flush()
+            sys.stdout.write(out)
+            sys.stdout.flush()
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/notebooks/test-saved-envs.ipynb b/notebooks/test-saved-envs.ipynb
new file mode 100644
index 00000000..7815981f
--- /dev/null
+++ b/notebooks/test-saved-envs.ipynb
@@ -0,0 +1,210 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Load some (old) env files to check they work\n",
+    "This notebook just loads some old env files, renders them, and runs a few steps.\n",
+    "This is just a sanity check that these old envs will still load.\n",
+    "Many of them use deprecated data formats so it's just so that we can avoid deleting them for now, and so new participants are not confused by us shipping env files which don't work..."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%load_ext autoreload\n",
+    "%autoreload 2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from IPython import display\n",
+    "display.display(display.HTML(\"<style>.container { width:95% !important; }</style>\"))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import PIL\n",
+    "import glob\n",
+    "import pickle\n",
+    "import msgpack"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from flatland.envs.persistence import RailEnvPersister"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from flatland.utils import env_edit_utils as eeu\n",
+    "from flatland.utils import jupyter_utils as ju"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "lsDirs = [ \"../env_data/railway\", \"../env_data/tests\"]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "lsFiles = []\n",
+    "for sDir in lsDirs:\n",
+    "    for sExt in [\"mpk\", \"pkl\"]:\n",
+    "        lsFiles += glob.glob(sDir + \"/*\" + sExt)\n",
+    "lsFiles"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for sFile in lsFiles:\n",
+    "    try:\n",
+    "        with open(sFile, \"rb\") as fIn:\n",
+    "            env_dict = pickle.load(fIn)\n",
+    "        print(\"pickle:\", sFile)\n",
+    "    except ValueError as oErr:\n",
+    "        try:\n",
+    "            with open(sFile, \"rb\") as fIn:\n",
+    "                env_dict = msgpack.load(fIn)\n",
+    "            print(\"msgpack: \", sFile)\n",
+    "        except ValueError as oErr:\n",
+    "            print(\"msgpack failed: \", sFile)\n",
+    "            \n",
+    "        "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [],
+   "source": [
+    "for sFile in lsFiles:\n",
+    "    print(\"Loading: \", sFile)\n",
+    "    env, env_dict = RailEnvPersister.load_new(sFile)\n",
+    "    env.reset()\n",
+    "    oCanvas = ju.EnvCanvas(env, ju.AlwaysForward(env))\n",
+    "    oCanvas.show()\n",
+    "    for iStep in range(10):\n",
+    "        oCanvas.step()\n",
+    "        oCanvas.render()"
+   ]
+  }
+ ],
+ "metadata": {
+  "hide_input": false,
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.7"
+  },
+  "latex_envs": {
+   "LaTeX_envs_menu_present": true,
+   "autoclose": false,
+   "autocomplete": true,
+   "bibliofile": "biblio.bib",
+   "cite_by": "apalike",
+   "current_citInitial": 1,
+   "eqLabelWithNumbers": true,
+   "eqNumInitial": 1,
+   "hotkeys": {
+    "equation": "Ctrl-E",
+    "itemize": "Ctrl-I"
+   },
+   "labels_anchors": false,
+   "latex_user_defs": false,
+   "report_style_numbering": false,
+   "user_envs_cfg": false
+  },
+  "toc": {
+   "base_numbering": 1,
+   "nav_menu": {},
+   "number_sections": true,
+   "sideBar": true,
+   "skip_h1_title": false,
+   "title_cell": "Table of Contents",
+   "title_sidebar": "Contents",
+   "toc_cell": false,
+   "toc_position": {},
+   "toc_section_display": true,
+   "toc_window_display": false
+  },
+  "varInspector": {
+   "cols": {
+    "lenName": 16,
+    "lenType": 16,
+    "lenVar": 40
+   },
+   "kernels_config": {
+    "python": {
+     "delete_cmd_postfix": "",
+     "delete_cmd_prefix": "del ",
+     "library": "var_list.py",
+     "varRefreshCmd": "print(var_dic_list())"
+    },
+    "r": {
+     "delete_cmd_postfix": ") ",
+     "delete_cmd_prefix": "rm(",
+     "library": "var_list.r",
+     "varRefreshCmd": "cat(var_dic_list()) "
+    }
+   },
+   "types_to_exclude": [
+    "module",
+    "function",
+    "builtin_function_or_method",
+    "instance",
+    "_Feature"
+   ],
+   "window_display": false
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
-- 
GitLab