From be3c58b0f4bc15c03689f09a23051780d508e9cf Mon Sep 17 00:00:00 2001 From: hagrid67 <jdhwatson@gmail.com> Date: Sat, 15 Aug 2020 18:01:50 +0100 Subject: [PATCH] updated test-collision to use jupyter_utils for env, agent behaviour, etc --- flatland/utils/jupyter_utils.py | 45 ++++++++++++++- notebooks/test-collision.ipynb | 99 +++++---------------------------- 2 files changed, 59 insertions(+), 85 deletions(-) diff --git a/flatland/utils/jupyter_utils.py b/flatland/utils/jupyter_utils.py index f7538fb4..6dd47501 100644 --- a/flatland/utils/jupyter_utils.py +++ b/flatland/utils/jupyter_utils.py @@ -14,7 +14,7 @@ from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.persistence import RailEnvPersister from flatland.utils.rendertools import RenderTool from flatland.utils import env_edit_utils as eeu - +from typing import List, NamedTuple class Behaviour(): def __init__(self, env): @@ -28,6 +28,49 @@ class AlwaysForward(Behaviour): def getActions(self): return { i:RailEnvActions.MOVE_FORWARD for i in range(self.nAg) } +AgentPause = NamedTuple("AgentPause", + [ + ("iAg", int), + ("iPauseAt", int), + ("iPauseFor", int) + ]) + +class ForwardWithPause(Behaviour): + def __init__(self, env, lPauses:List[AgentPause]): + self.env = env + self.nAg = len(env.agents) + self.lPauses = lPauses + self.dAgPaused = {} + + def getActions(self): + iStep = self.env._elapsed_steps + 1 # add one because this is called before step() + + # new pauses starting this step + lNewPauses = [ tPause for tPause in self.lPauses if tPause.iPauseAt == iStep ] + + # copy across the agent index and pause length + for pause in lNewPauses: + self.dAgPaused[pause.iAg] = pause.iPauseFor + + # default action is move forward + dAction = { i:RailEnvActions.MOVE_FORWARD for i in range(self.nAg) } + + # overwrite paused agents with stop + for iAg in self.dAgPaused: + dAction[iAg] = RailEnvActions.STOP_MOVING + + # decrement the counters for each pause, and remove any expired pauses. + lFinished = [] + for iAg in self.dAgPaused: + self.dAgPaused[iAg] -= 1 + if self.dAgPaused[iAg] <= 0: + lFinished.append(iAg) + + for iAg in lFinished: + self.dAgPaused.pop(iAg, None) + + return dAction + class EnvCanvas(): diff --git a/notebooks/test-collision.ipynb b/notebooks/test-collision.ipynb index 6a0e0502..e72b3b42 100644 --- a/notebooks/test-collision.ipynb +++ b/notebooks/test-collision.ipynb @@ -67,89 +67,11 @@ "cell_type": "code", "execution_count": 4, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "UCF: False\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "05c64c7cfb3e4dba9057c0d0787223d9", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Canvas(layout=Layout(height='300px', width='600px'), size=(600, 300))" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "nAg=2\n", - "bUCF=False\n", - "bReverseStart=False\n", - "env, envModel = eeu.makeTestEnv(sName=\"merging_spurs\", nAg=nAg, bUCF=bUCF)\n", - "oRT = RenderTool(env, show_debug=True)\n", - "oRT.render_env(show_rowcols=True, show_inactive_agents=True, show_observations=False)\n", - "oCan = canvas.Canvas(size=(600,300))\n", - "oCan.put_image_data(oRT.get_image())\n", - "print(f\"UCF: {bUCF}\")\n", - "display.display(oCan)\n", - "\n", - "\n", - "iPauseStep=10\n", - "iPauseLen = 5\n", - "\n", - "for iStep in range(25):\n", - " \n", - " if bReverseStart:\n", - " iAgentStart = max((nAg - 2 - iStep*2), 0)\n", - " else:\n", - " iAgentStart = 0\n", - " dActions = { i:int(RailEnvActions.MOVE_FORWARD) for i in range(iAgentStart, len(env.agents)) }\n", - " \n", - " if (iStep >= iPauseStep) and (iStep < iPauseStep + iPauseLen):\n", - " dActions[0] = RailEnvActions.STOP_MOVING\n", - " \n", - " #print(dActions)\n", - " \n", - " env.step(dActions)\n", - " \n", - " oRT.render_env(show_rowcols=True, show_inactive_agents=True, show_observations=False)\n", - " \n", - " aImg = oRT.get_image()\n", - " pilImg = PIL.Image.fromarray(aImg)\n", - " oCan.put_image_data(aImg)\n", - " #with open(f\"tmp-images/img-{iStep:03d}.png\", \"wb\") as fImg:\n", - " # pilImg.save(fImg)\n", - " \n", - " time.sleep(0.05)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "#!ffmpeg -i ./tmp-images/img-%03d.png -vcodec mpeg4 -filter:v \"setpts=8.0*PTS\" -y movie_nucf_stop.mp4" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "47677058c61e4c71bf23d56e80bc88d3", + "model_id": "a21b1fd47f49490c97b8972281a325f4", "version_major": 2, "version_minor": 0 }, @@ -162,19 +84,28 @@ } ], "source": [ - "oEC = ju.EnvCanvas(env)\n", + "env, envModel = eeu.makeTestEnv(\"merging_spurs\", nAg=2, bUCF=True)\n", + "behaviour = ju.ForwardWithPause(env, [ju.AgentPause(0, 10, 5)])\n", + "oEC = ju.EnvCanvas(env, behaviour)\n", "env.reset(regenerate_rail=False)\n", - "oEC.show()" + "oEC.show()\n", + "for i in range(25):\n", + " oEC.step()\n", + " oEC.render()\n", + " time.sleep(0.1)" ] }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ - "oEC.step()\n", - "oEC.render()" + "dAgStateExpected = {0: (7, 15, 2), 1: (6, 15, 2)}\n", + "dAgState={}\n", + "for iAg, ag in enumerate(env.agents):\n", + " dAgState[iAg] = (*ag.position, ag.direction)\n", + "assert dAgState == dAgStateExpected" ] } ], -- GitLab