diff --git a/env_data/tests/Test_2_Level_0.pkl b/env_data/tests/Test_2_Level_0.pkl new file mode 100644 index 0000000000000000000000000000000000000000..076496ad365c3556248de1c762525c18ea098fbf Binary files /dev/null and b/env_data/tests/Test_2_Level_0.pkl differ diff --git a/notebooks/render-episode.ipynb b/notebooks/render-episode.ipynb index 94eba8de9e7ca9a84d4a3ffa4739f9a1aac58796..3068d46d8daa155219c16ae497144879f5429675 100644 --- a/notebooks/render-episode.ipynb +++ b/notebooks/render-episode.ipynb @@ -45,7 +45,8 @@ "outputs": [], "source": [ "import pandas as pd\n", - "import numpy as np" + "import numpy as np\n", + "import matplotlib.pyplot as plt" ] }, { @@ -60,7 +61,23 @@ "source": [ "import PIL\n", "from flatland.utils.rendertools import RenderTool\n", - "import imageio" + "import imageio\n", + "import os" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "PU5GkH271guD" + }, + "outputs": [], + "source": [ + "from IPython.display import clear_output\n", + "from IPython.core import display \n", + "display.display(display.HTML(\"<style>.container { width:95% !important; }</style>\"))" ] }, { @@ -102,7 +119,9 @@ "outputs": [], "source": [ "# ENV FILE PATH\n", - "env_file = \"Test_20_Level_0.pkl\"\n" + "#env_file = \"Test_20_Level_0.pkl\"\n", + "#env_file = \"../../evaluation_visualization/round2/or-0827/Test_23/Level_1.pkl\"\n", + "#env_file = \"../../evaluation_visualization/round2/rl-0827/Test_23/Level_1.pkl\"" ] }, { @@ -111,7 +130,10 @@ "metadata": {}, "outputs": [], "source": [ - "%ls" + "if os.path.exists(\"../env_data\"):\n", + " env_file = \"../env_data/tests/Test_2_Level_0.pkl\"\n", + "else:\n", + " env_file = \"./env_data/tests/Test_2_Level_0.pkl\"" ] }, { @@ -124,9 +146,14 @@ }, "outputs": [], "source": [ - "from IPython.display import clear_output\n", - "from IPython.core import display \n", - "display.display(display.HTML(\"<style>.container { width:95% !important; }</style>\"))" + "import pickle\n", + "\n", + "from flatland.envs.rail_env import RailEnv\n", + "from flatland.envs.rail_generators import sparse_rail_generator\n", + "from flatland.envs.schedule_generators import sparse_schedule_generator\n", + "from flatland.envs.malfunction_generators import malfunction_from_file, no_malfunction_generator\n", + "from flatland.envs.rail_generators import rail_from_file\n", + "from flatland.envs.schedule_generators import schedule_from_file" ] }, { @@ -139,15 +166,6 @@ }, "outputs": [], "source": [ - "import pickle\n", - "\n", - "from flatland.envs.rail_env import RailEnv\n", - "from flatland.envs.rail_generators import sparse_rail_generator\n", - "from flatland.envs.schedule_generators import sparse_schedule_generator\n", - "from flatland.envs.malfunction_generators import malfunction_from_file, no_malfunction_generator\n", - "from flatland.envs.rail_generators import rail_from_file\n", - "from flatland.envs.schedule_generators import schedule_from_file\n", - "\n", "with open(env_file, \"rb\") as fIn:\n", " env_dict = pickle.load(fIn)\n", "\n", @@ -195,15 +213,6 @@ "PIL.Image.fromarray(aImg)\n" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#dummy= env.reset()" - ] - }, { "cell_type": "code", "execution_count": null, @@ -226,15 +235,6 @@ " for oAg in env.agents], columns=lCols)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#env_dict[\"episode\"]" - ] - }, { "cell_type": "code", "execution_count": null, @@ -244,33 +244,6 @@ "pd.DataFrame([ vars(oAg) for oAg in env.agents])" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "genSched = schedule_from_file(env_file)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "oSched = genSched(env.rail.grid, 3, )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "oSched" - ] - }, { "cell_type": "code", "execution_count": null, @@ -312,73 +285,6 @@ "episode_states = env_dict['episode']" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#[ (i, l) for i,l in enumerate(zip(episode_states, expert_actions)) ][:3]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 34 - }, - "colab_type": "code", - "id": "IXQxfUXF2U33", - "outputId": "2a94ffec-c6b7-4cc3-d779-8c430d66918d" - }, - "outputs": [], - "source": [ - "#dAct = expert_actions[1]\n", - "#print(dAct)\n", - "#env.step(dAct)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from gym.utils import seeding" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "np_random, seed2 = seeding.np_random(123)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for i in range(10000):\n", - " oMF = env.malfunction_generator(env.agents[0], np_random) \n", - " if oMF.num_broken_steps > 0:\n", - " print(i, oMF)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "\"generate\" in dir(env.malfunction_generator)" - ] - }, { "cell_type": "code", "execution_count": null, @@ -424,6 +330,7 @@ "\n", "step = 0\n", "all_done = False\n", + "print(\"Processing episode steps:\")\n", "while not all_done and step < max_steps:\n", " print(step, end=\", \")\n", " \"\"\"\n", @@ -462,20 +369,21 @@ "\n", " \n", " # Force agent states from the recorded states\n", - " for idx, agent in enumerate(env.agents):\n", - " #print(episode_states[step][idx])\n", - " rcPos = episode_states[step][idx][0:2]\n", - " #print(idx, rcPos)\n", - " if rcPos == [0,0]:\n", - " agent.position = None\n", - " else:\n", - " agent.position = (*rcPos,) # episode_states[step][idx][0], episode_states[step][idx][1]#, episode_states[step][idx][2]\n", - " \n", - " agent.malfunction_data[\"malfunction\"] = episode_states[step][idx][3]\n", - " agent.direction = int(episode_states[step][idx][2])\n", + " if False:\n", + " for idx, agent in enumerate(env.agents):\n", + " #print(episode_states[step][idx])\n", + " rcPos = episode_states[step][idx][0:2]\n", + " #print(idx, rcPos)\n", + " if rcPos == [0,0]:\n", + " agent.position = None\n", + " else:\n", + " agent.position = (*rcPos,) # episode_states[step][idx][0], episode_states[step][idx][1]#, episode_states[step][idx][2]\n", + "\n", + " agent.malfunction_data[\"malfunction\"] = episode_states[step][idx][3]\n", + " agent.direction = int(episode_states[step][idx][2])\n", "\n", - " agent.old_direction = int(episode_states[step-1][idx][2])\n", - " agent.old_position = episode_states[step-1][idx][:2]\n", + " agent.old_direction = int(episode_states[step-1][idx][2])\n", + " agent.old_position = episode_states[step-1][idx][:2]\n", "\n", " statuses = []\n", " for a in range(n_agents):\n", @@ -542,7 +450,7 @@ "outputs": [], "source": [ "sfImg = env_file.replace(\"pkl\", \"gif\")\n", - "imageio.mimsave(sfImg, [d[\"image\"] for d in frames])" + "imageio.mimsave(sfImg, [d[\"image\"] for d in frames], subrectangles=True)" ] }, { @@ -606,18 +514,19 @@ " display.display(frame['image'])\n", " #print(frame['statuses'])\n", "\n", - "slider = widgets.FloatSlider(value=0, min=0, max=max_steps, step=1)\n", - "interact(plot_func, frame_idx = slider)\n", + "if True:\n", + " slider = widgets.FloatSlider(value=0, min=0, max=max_steps, step=1)\n", + " interact(plot_func, frame_idx = slider)\n", "\n", - "play = Play(\n", - " max=max_steps,\n", - " value=0,\n", - " step=1,\n", - " interval=250\n", - ")\n", + " play = Play(\n", + " max=max_steps,\n", + " value=0,\n", + " step=1,\n", + " interval=250\n", + " )\n", "\n", - "widgets.link((play, 'value'), (slider, 'value'))\n", - "widgets.VBox([play])" + " widgets.link((play, 'value'), (slider, 'value'))\n", + " widgets.VBox([play])" ] }, { @@ -629,6 +538,34 @@ "id": "j9_J2f6K64Jb" }, "outputs": [], + "source": [ + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g3Ep = np.array(episode_states)\n", + "np.sum(g3Ep[:,:,3] > 0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.plot(np.sum(g3Ep[:,:,3]>0, axis=1))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [] } ],