From be3c58b0f4bc15c03689f09a23051780d508e9cf Mon Sep 17 00:00:00 2001
From: hagrid67 <jdhwatson@gmail.com>
Date: Sat, 15 Aug 2020 18:01:50 +0100
Subject: [PATCH] updated test-collision to use jupyter_utils for env, agent
 behaviour, etc

---
 flatland/utils/jupyter_utils.py | 45 ++++++++++++++-
 notebooks/test-collision.ipynb  | 99 +++++----------------------------
 2 files changed, 59 insertions(+), 85 deletions(-)

diff --git a/flatland/utils/jupyter_utils.py b/flatland/utils/jupyter_utils.py
index f7538fb4..6dd47501 100644
--- a/flatland/utils/jupyter_utils.py
+++ b/flatland/utils/jupyter_utils.py
@@ -14,7 +14,7 @@ from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.persistence import RailEnvPersister
 from flatland.utils.rendertools import RenderTool
 from flatland.utils import env_edit_utils as eeu
-
+from typing import List, NamedTuple
 
 class Behaviour():
     def __init__(self, env):
@@ -28,6 +28,49 @@ class AlwaysForward(Behaviour):
     def getActions(self):
         return { i:RailEnvActions.MOVE_FORWARD for i in range(self.nAg) }
 
+AgentPause = NamedTuple("AgentPause", 
+    [
+        ("iAg", int),
+        ("iPauseAt", int),
+        ("iPauseFor", int)
+    ])
+
+class ForwardWithPause(Behaviour):
+    def __init__(self, env, lPauses:List[AgentPause]):
+        self.env = env
+        self.nAg = len(env.agents)
+        self.lPauses = lPauses
+        self.dAgPaused = {}
+
+    def getActions(self):
+        iStep = self.env._elapsed_steps + 1  # add one because this is called before step()
+
+        # new pauses starting this step
+        lNewPauses = [ tPause for tPause in self.lPauses if tPause.iPauseAt == iStep ]
+
+        # copy across the agent index and pause length
+        for pause in lNewPauses:
+            self.dAgPaused[pause.iAg] = pause.iPauseFor
+
+        # default action is move forward
+        dAction = { i:RailEnvActions.MOVE_FORWARD for i in range(self.nAg) }
+
+        # overwrite paused agents with stop
+        for iAg in self.dAgPaused:
+            dAction[iAg] = RailEnvActions.STOP_MOVING
+        
+        # decrement the counters for each pause, and remove any expired pauses.
+        lFinished = []
+        for iAg in self.dAgPaused:
+            self.dAgPaused[iAg] -= 1
+            if self.dAgPaused[iAg] <= 0:
+                lFinished.append(iAg)
+        
+        for iAg in lFinished:
+            self.dAgPaused.pop(iAg, None)
+        
+        return dAction
+
 
 class EnvCanvas():
 
diff --git a/notebooks/test-collision.ipynb b/notebooks/test-collision.ipynb
index 6a0e0502..e72b3b42 100644
--- a/notebooks/test-collision.ipynb
+++ b/notebooks/test-collision.ipynb
@@ -67,89 +67,11 @@
    "cell_type": "code",
    "execution_count": 4,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "UCF: False\n"
-     ]
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "05c64c7cfb3e4dba9057c0d0787223d9",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Canvas(layout=Layout(height='300px', width='600px'), size=(600, 300))"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
-   "source": [
-    "nAg=2\n",
-    "bUCF=False\n",
-    "bReverseStart=False\n",
-    "env, envModel = eeu.makeTestEnv(sName=\"merging_spurs\", nAg=nAg, bUCF=bUCF)\n",
-    "oRT = RenderTool(env, show_debug=True)\n",
-    "oRT.render_env(show_rowcols=True,  show_inactive_agents=True, show_observations=False)\n",
-    "oCan = canvas.Canvas(size=(600,300))\n",
-    "oCan.put_image_data(oRT.get_image())\n",
-    "print(f\"UCF: {bUCF}\")\n",
-    "display.display(oCan)\n",
-    "\n",
-    "\n",
-    "iPauseStep=10\n",
-    "iPauseLen = 5\n",
-    "\n",
-    "for iStep in range(25):\n",
-    "    \n",
-    "    if bReverseStart:\n",
-    "        iAgentStart = max((nAg - 2 - iStep*2), 0)\n",
-    "    else:\n",
-    "        iAgentStart = 0\n",
-    "    dActions = { i:int(RailEnvActions.MOVE_FORWARD) for i in range(iAgentStart, len(env.agents)) }\n",
-    "    \n",
-    "    if (iStep >= iPauseStep) and (iStep < iPauseStep + iPauseLen):\n",
-    "        dActions[0] = RailEnvActions.STOP_MOVING\n",
-    "        \n",
-    "    #print(dActions)\n",
-    "    \n",
-    "    env.step(dActions)\n",
-    "    \n",
-    "    oRT.render_env(show_rowcols=True,  show_inactive_agents=True, show_observations=False)\n",
-    "    \n",
-    "    aImg = oRT.get_image()\n",
-    "    pilImg = PIL.Image.fromarray(aImg)\n",
-    "    oCan.put_image_data(aImg)\n",
-    "    #with open(f\"tmp-images/img-{iStep:03d}.png\", \"wb\") as fImg:\n",
-    "    #    pilImg.save(fImg)\n",
-    "    \n",
-    "    time.sleep(0.05)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 5,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "#!ffmpeg -i ./tmp-images/img-%03d.png -vcodec mpeg4 -filter:v \"setpts=8.0*PTS\" -y movie_nucf_stop.mp4"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 8,
-   "metadata": {},
    "outputs": [
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "47677058c61e4c71bf23d56e80bc88d3",
+       "model_id": "a21b1fd47f49490c97b8972281a325f4",
        "version_major": 2,
        "version_minor": 0
       },
@@ -162,19 +84,28 @@
     }
    ],
    "source": [
-    "oEC = ju.EnvCanvas(env)\n",
+    "env, envModel = eeu.makeTestEnv(\"merging_spurs\", nAg=2, bUCF=True)\n",
+    "behaviour = ju.ForwardWithPause(env, [ju.AgentPause(0, 10, 5)])\n",
+    "oEC = ju.EnvCanvas(env, behaviour)\n",
     "env.reset(regenerate_rail=False)\n",
-    "oEC.show()"
+    "oEC.show()\n",
+    "for i in range(25):\n",
+    "    oEC.step()\n",
+    "    oEC.render()\n",
+    "    time.sleep(0.1)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 41,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
-    "oEC.step()\n",
-    "oEC.render()"
+    "dAgStateExpected = {0: (7, 15, 2), 1: (6, 15, 2)}\n",
+    "dAgState={}\n",
+    "for iAg, ag in enumerate(env.agents):\n",
+    "    dAgState[iAg] = (*ag.position, ag.direction)\n",
+    "assert dAgState == dAgStateExpected"
    ]
   }
  ],
-- 
GitLab