Skip to content
Snippets Groups Projects
render_episode.ipynb 228 KiB
Newer Older
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Render Episode\n",
    "Render a stored episode. The env file needs to have \"episode\" and \"action\" keys.\n",
    "- creates an animated GIF of the episode\n",
    "- displays the episode in a widget with a slider for the time steps."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "baXcVq3ii0Cb"
   },
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 153
    },
    "colab_type": "code",
    "id": "eKL0JthzupFg",
    "outputId": "2ec78745-cb78-4426-ee9d-b8acac185910"
   },
   "outputs": [],
   "source": [
    "# Uncomment to install the graphviz system packages and flatland from its git repository (e.g. on Colab).\n",
    "#!apt -qq install graphviz libgraphviz-dev pkg-config\n",
    "#!pip install -qq git+https://gitlab.aicrowd.com/flatland/flatland.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload imported modules automatically before each cell runs, so edits to their source are picked up.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "eSHpLxdt1jmE"
   },
   "outputs": [],
   "source": [
    "import PIL\n",
    "from flatland.utils.rendertools import RenderTool\n",
    "import imageio\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "PU5GkH271guD"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>.container { width:95% !important; }</style>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from IPython.display import clear_output\n",
    "# Bind the public IPython.display module (not IPython.core.display, whose `display` export is deprecated);\n",
    "# it still provides display.display and display.HTML for later cells.\n",
    "from IPython import display\n",
    "# Widen the notebook's .container to 95% so wide renders and tables fit.\n",
    "display.display(display.HTML(\"<style>.container { width:95% !important; }</style>\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "eSHpLxdt1jmE"
   },
   "outputs": [],
   "source": [
    "def render_env(env_renderer):\n",
    "    \"\"\"Render the current state of ``env_renderer.env`` and return it as a PIL image.\n",
    "\n",
    "    Args:\n",
    "        env_renderer: presumably a flatland ``RenderTool``; it must provide\n",
    "            ``render_env(show_rowcols=..., return_image=...)`` returning an image array.\n",
    "\n",
    "    Returns:\n",
    "        A ``PIL.Image.Image`` of the rendered grid, with row/column labels drawn.\n",
    "    \"\"\"\n",
    "    aImage = env_renderer.render_env(show_rowcols=True, return_image=True)\n",
    "    return PIL.Image.fromarray(aImage)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "UeX1h4c0i5e6"
   },
   "source": [
    "# Experiments\n",
    "\n",
    "Envs are now mostly loaded as package resources via `importlib_resources`; this notebook loads them from the `env_data.railway` package."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "PU5GkH271guD"
   },
adrian_egli2's avatar
adrian_egli2 committed
   "outputs": [],
   "source": [
    "\n",
    "from flatland.envs.rail_env import RailEnv\n",
    "from flatland.envs.rail_generators import sparse_rail_generator\n",
adrian_egli2's avatar
adrian_egli2 committed
    "from flatland.envs.line_generators import sparse_line_generator\n",
    "from flatland.envs.malfunction_generators import malfunction_from_file, no_malfunction_generator\n",
    "from flatland.envs.rail_generators import rail_from_file\n",
adrian_egli2's avatar
adrian_egli2 committed
    "from flatland.envs.rail_env import RailEnvActions\n",
    "from flatland.envs.step_utils.states import TrainState"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from flatland.envs.persistence import RailEnvPersister"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pickle failed to load file: complex_scene_2.pkl  trying msgpack (deprecated)...\n",
      "pickle failed to load file: complex_scene_2.pkl  trying msgpack (deprecated)...\n",
      "pickle failed to load file: complex_scene_2.pkl  trying msgpack (deprecated)...\n",
      "This env file has no max_episode_steps (deprecated) - setting to 100\n"
     ]
    }
   ],
   "source": [
    "# Load the stored env from the installed env_data package rather than a local path.\n",
    "env, env_dict = RailEnvPersister.load_new(\"complex_scene_2.pkl\", load_from_package=\"env_data.railway\")\n",
    "# Don't bind the result to `_`: assigning `_` stops IPython from updating its last-output variable.\n",
    "env.reset()\n",
    "# This env file has no max_episode_steps (see the loader message), so pin the episode length explicitly.\n",
    "env._max_episode_steps = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\u216993\\.conda\\envs\\flatland3-rl\\lib\\site-packages\\flatland\\utils\\rendertools.py:399: UserWarning: Predictor did not provide any predicted cells to render.                 Observation builder needs to populate: env.dev_obs_dict\n",
      "  Observation builder needs to populate: env.dev_obs_dict\")\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<PIL.Image.Image image mode=RGBA size=704x284 at 0x1FFFCD7CE48>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): env.reset() is called without a seed; if malfunctions must match the recorded\n",
    "# episode, reset with the seed that was used for recording -- confirm.\n",
    "oRT = RenderTool(env, show_debug=True)\n",
    "aImg = oRT.render_env(\n",
    "    show_rowcols=True,\n",
    "    show_inactive_agents=True,\n",
    "    return_image=True,\n",
    ")\n",
    "print(env._max_episode_steps)\n",
    "PIL.Image.fromarray(aImg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>initial_direction</th>\n",
       "      <th>direction</th>\n",
       "      <th>initial_position</th>\n",
       "      <th>position</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>(2, 1)</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>(1, 1)</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>(10, 4)</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>(11, 4)</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>(12, 4)</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>(13, 4)</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>(14, 4)</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>(18, 48)</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>(14, 48)</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>(13, 5)</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>(13, 48)</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    initial_direction  direction initial_position position\n",
       "0                   1          1           (2, 1)     None\n",
       "1                   1          1           (1, 1)     None\n",
       "2                   1          1          (10, 4)     None\n",
       "3                   1          1          (11, 4)     None\n",
       "4                   1          1          (12, 4)     None\n",
       "5                   1          1          (13, 4)     None\n",
       "6                   1          1          (14, 4)     None\n",
       "7                   3          3         (18, 48)     None\n",
       "8                   3          3         (14, 48)     None\n",
       "9                   1          1          (13, 5)     None\n",
       "10                  3          3         (13, 48)     None"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Start and current pose of each agent as stored in the env file.\n",
    "loAgs = env_dict[\"agents\"]\n",
    "lCols = [\"initial_direction\", \"direction\", \"initial_position\", \"position\"]\n",
    "pd.DataFrame(\n",
    "    [[getattr(oAg, sCol) for sCol in lCols] for oAg in loAgs],\n",
    "    columns=lCols,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>initial_direction</th>\n",
       "      <th>direction</th>\n",
       "      <th>initial_position</th>\n",
       "      <th>position</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>(2, 1)</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>(1, 1)</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>(10, 4)</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>(11, 4)</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>(12, 4)</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>(13, 4)</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>(14, 4)</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>(18, 48)</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>(14, 48)</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>(13, 5)</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>(13, 48)</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    initial_direction  direction initial_position position\n",
       "0                   1          1           (2, 1)     None\n",
       "1                   1          1           (1, 1)     None\n",
       "2                   1          1          (10, 4)     None\n",
       "3                   1          1          (11, 4)     None\n",
       "4                   1          1          (12, 4)     None\n",
       "5                   1          1          (13, 4)     None\n",
       "6                   1          1          (14, 4)     None\n",
       "7                   3          3         (18, 48)     None\n",
       "8                   3          3         (14, 48)     None\n",
       "9                   1          1          (13, 5)     None\n",
       "10                  3          3         (13, 48)     None"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same columns for the agents of the reset env, for comparison with the stored agents.\n",
    "pd.DataFrame(\n",
    "    [[getattr(oAg, sCol) for sCol in lCols] for oAg in env.agents],\n",
    "    columns=lCols,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>initial_position</th>\n",
       "      <th>initial_direction</th>\n",
       "      <th>direction</th>\n",
       "      <th>target</th>\n",
       "      <th>moving</th>\n",
       "      <th>earliest_departure</th>\n",
       "      <th>latest_arrival</th>\n",
       "      <th>handle</th>\n",
       "      <th>speed_counter</th>\n",
       "      <th>action_saver</th>\n",
       "      <th>state_machine</th>\n",
       "      <th>malfunction_handler</th>\n",
       "      <th>position</th>\n",
       "      <th>arrival_time</th>\n",
       "      <th>old_direction</th>\n",
       "      <th>old_position</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>(2, 1)</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>(10, 12)</td>\n",
       "      <td>False</td>\n",
       "      <td>64</td>\n",
       "      <td>191</td>\n",
       "      <td>0</td>\n",
       "      <td>speed: 1.0                  max_count: 0      ...</td>\n",
       "      <td>is_action_saved: False, saved_action: None</td>\n",
       "      <td>\\n                  state: TrainState.WAITING ...</td>\n",
       "      <td>malfunction_down_counter: 0                 in...</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>(1, 1)</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>(19, 48)</td>\n",
       "      <td>False</td>\n",
       "      <td>10</td>\n",
       "      <td>210</td>\n",
       "      <td>1</td>\n",
       "      <td>speed: 1.0                  max_count: 0      ...</td>\n",
       "      <td>is_action_saved: False, saved_action: None</td>\n",
       "      <td>\\n                  state: TrainState.WAITING ...</td>\n",
       "      <td>malfunction_down_counter: 0                 in...</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>(10, 4)</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>(13, 46)</td>\n",
       "      <td>False</td>\n",
       "      <td>121</td>\n",
       "      <td>196</td>\n",
       "      <td>2</td>\n",
       "      <td>speed: 1.0                  max_count: 0      ...</td>\n",
       "      <td>is_action_saved: False, saved_action: None</td>\n",
       "      <td>\\n                  state: TrainState.WAITING ...</td>\n",
       "      <td>malfunction_down_counter: 0                 in...</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>(11, 4)</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>(14, 46)</td>\n",
       "      <td>False</td>\n",
       "      <td>121</td>\n",
       "      <td>193</td>\n",
       "      <td>3</td>\n",
       "      <td>speed: 1.0                  max_count: 0      ...</td>\n",
       "      <td>is_action_saved: False, saved_action: None</td>\n",
       "      <td>\\n                  state: TrainState.WAITING ...</td>\n",
       "      <td>malfunction_down_counter: 0                 in...</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>(12, 4)</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>(12, 42)</td>\n",
       "      <td>False</td>\n",
       "      <td>10</td>\n",
       "      <td>78</td>\n",
       "      <td>4</td>\n",
       "      <td>speed: 1.0                  max_count: 0      ...</td>\n",
       "      <td>is_action_saved: False, saved_action: None</td>\n",
       "      <td>\\n                  state: TrainState.WAITING ...</td>\n",
       "      <td>malfunction_down_counter: 0                 in...</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>(13, 4)</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>(11, 42)</td>\n",
       "      <td>False</td>\n",
       "      <td>99</td>\n",
       "      <td>167</td>\n",
       "      <td>5</td>\n",
       "      <td>speed: 1.0                  max_count: 0      ...</td>\n",
       "      <td>is_action_saved: False, saved_action: None</td>\n",
       "      <td>\\n                  state: TrainState.WAITING ...</td>\n",
       "      <td>malfunction_down_counter: 0                 in...</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>(14, 4)</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>(13, 15)</td>\n",
       "      <td>False</td>\n",
       "      <td>31</td>\n",
       "      <td>60</td>\n",
       "      <td>6</td>\n",
       "      <td>speed: 1.0                  max_count: 0      ...</td>\n",
       "      <td>is_action_saved: False, saved_action: None</td>\n",
       "      <td>\\n                  state: TrainState.WAITING ...</td>\n",
       "      <td>malfunction_down_counter: 0                 in...</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>(18, 48)</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>(1, 2)</td>\n",
       "      <td>False</td>\n",
       "      <td>2</td>\n",
       "      <td>199</td>\n",
       "      <td>7</td>\n",
       "      <td>speed: 1.0                  max_count: 0      ...</td>\n",
       "      <td>is_action_saved: False, saved_action: None</td>\n",
       "      <td>\\n                  state: TrainState.WAITING ...</td>\n",
       "      <td>malfunction_down_counter: 0                 in...</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>(14, 48)</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>(11, 12)</td>\n",
       "      <td>False</td>\n",
       "      <td>83</td>\n",
       "      <td>147</td>\n",
       "      <td>8</td>\n",
       "      <td>speed: 1.0                  max_count: 0      ...</td>\n",
       "      <td>is_action_saved: False, saved_action: None</td>\n",
       "      <td>\\n                  state: TrainState.WAITING ...</td>\n",
       "      <td>malfunction_down_counter: 0                 in...</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>(13, 5)</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>(13, 13)</td>\n",
       "      <td>False</td>\n",
       "      <td>28</td>\n",
       "      <td>55</td>\n",
       "      <td>9</td>\n",
       "      <td>speed: 1.0                  max_count: 0      ...</td>\n",
       "      <td>is_action_saved: False, saved_action: None</td>\n",
       "      <td>\\n                  state: TrainState.WAITING ...</td>\n",
       "      <td>malfunction_down_counter: 0                 in...</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>(13, 48)</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>(12, 12)</td>\n",
       "      <td>False</td>\n",
       "      <td>107</td>\n",
       "      <td>169</td>\n",
       "      <td>10</td>\n",
       "      <td>speed: 1.0                  max_count: 0      ...</td>\n",
       "      <td>is_action_saved: False, saved_action: None</td>\n",
       "      <td>\\n                  state: TrainState.WAITING ...</td>\n",
       "      <td>malfunction_down_counter: 0                 in...</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   initial_position  initial_direction  direction    target  moving  \\\n",
       "0            (2, 1)                  1          1  (10, 12)   False   \n",
       "1            (1, 1)                  1          1  (19, 48)   False   \n",
       "2           (10, 4)                  1          1  (13, 46)   False   \n",
       "3           (11, 4)                  1          1  (14, 46)   False   \n",
       "4           (12, 4)                  1          1  (12, 42)   False   \n",
       "5           (13, 4)                  1          1  (11, 42)   False   \n",
       "6           (14, 4)                  1          1  (13, 15)   False   \n",
       "7          (18, 48)                  3          3    (1, 2)   False   \n",
       "8          (14, 48)                  3          3  (11, 12)   False   \n",
       "9           (13, 5)                  1          1  (13, 13)   False   \n",
       "10         (13, 48)                  3          3  (12, 12)   False   \n",
       "\n",
       "    earliest_departure  latest_arrival  handle  \\\n",
       "0                   64             191       0   \n",
       "1                   10             210       1   \n",
       "2                  121             196       2   \n",
       "3                  121             193       3   \n",
       "4                   10              78       4   \n",
       "5                   99             167       5   \n",
       "6                   31              60       6   \n",
       "7                    2             199       7   \n",
       "8                   83             147       8   \n",
       "9                   28              55       9   \n",
       "10                 107             169      10   \n",
       "\n",
       "                                        speed_counter  \\\n",
       "0   speed: 1.0                  max_count: 0      ...   \n",
       "1   speed: 1.0                  max_count: 0      ...   \n",
       "2   speed: 1.0                  max_count: 0      ...   \n",
       "3   speed: 1.0                  max_count: 0      ...   \n",
       "4   speed: 1.0                  max_count: 0      ...   \n",
       "5   speed: 1.0                  max_count: 0      ...   \n",
       "6   speed: 1.0                  max_count: 0      ...   \n",
       "7   speed: 1.0                  max_count: 0      ...   \n",
       "8   speed: 1.0                  max_count: 0      ...   \n",
       "9   speed: 1.0                  max_count: 0      ...   \n",
       "10  speed: 1.0                  max_count: 0      ...   \n",
       "\n",
       "                                  action_saver  \\\n",
       "0   is_action_saved: False, saved_action: None   \n",
       "1   is_action_saved: False, saved_action: None   \n",
       "2   is_action_saved: False, saved_action: None   \n",
       "3   is_action_saved: False, saved_action: None   \n",
       "4   is_action_saved: False, saved_action: None   \n",
       "5   is_action_saved: False, saved_action: None   \n",
       "6   is_action_saved: False, saved_action: None   \n",
       "7   is_action_saved: False, saved_action: None   \n",
       "8   is_action_saved: False, saved_action: None   \n",
       "9   is_action_saved: False, saved_action: None   \n",
       "10  is_action_saved: False, saved_action: None   \n",
       "\n",
       "                                        state_machine  \\\n",
       "0   \\n                  state: TrainState.WAITING ...   \n",
       "1   \\n                  state: TrainState.WAITING ...   \n",
       "2   \\n                  state: TrainState.WAITING ...   \n",
       "3   \\n                  state: TrainState.WAITING ...   \n",
       "4   \\n                  state: TrainState.WAITING ...   \n",
       "5   \\n                  state: TrainState.WAITING ...   \n",
       "6   \\n                  state: TrainState.WAITING ...   \n",
       "7   \\n                  state: TrainState.WAITING ...   \n",
       "8   \\n                  state: TrainState.WAITING ...   \n",
       "9   \\n                  state: TrainState.WAITING ...   \n",
       "10  \\n                  state: TrainState.WAITING ...   \n",
       "\n",
       "                                  malfunction_handler position arrival_time  \\\n",
       "0   malfunction_down_counter: 0                 in...     None         None   \n",
       "1   malfunction_down_counter: 0                 in...     None         None   \n",
       "2   malfunction_down_counter: 0                 in...     None         None   \n",
       "3   malfunction_down_counter: 0                 in...     None         None   \n",
       "4   malfunction_down_counter: 0                 in...     None         None   \n",
       "5   malfunction_down_counter: 0                 in...     None         None   \n",
       "6   malfunction_down_counter: 0                 in...     None         None   \n",
       "7   malfunction_down_counter: 0                 in...     None         None   \n",
       "8   malfunction_down_counter: 0                 in...     None         None   \n",
       "9   malfunction_down_counter: 0                 in...     None         None   \n",
       "10  malfunction_down_counter: 0                 in...     None         None   \n",
       "\n",
       "   old_direction old_position  \n",
       "0           None         None  \n",
       "1           None         None  \n",
       "2           None         None  \n",
       "3           None         None  \n",
       "4           None         None  \n",
       "5           None         None  \n",
       "6           None         None  \n",
       "7           None         None  \n",
       "8           None         None  \n",
       "9           None         None  \n",
       "10          None         None  "
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One row per agent, one column per instance attribute (from vars()).\n",
    "pd.DataFrame(list(map(vars, env.agents)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "kkejL06T8xyU"
   },
   "outputs": [],
   "source": [
    "# Adapted from persistence.py\n",
    "def get_agent_state(env):\n",
    "    \"\"\"Snapshot every agent of ``env`` as ``[row, col, direction, malfunction_handler]``.\n",
    "\n",
    "    Agents that are not on the grid yet (``position is None``, possible in env v2\n",
    "    before departure) are reported at ``(0, 0)``. Coordinates and direction are cast\n",
    "    to ``int`` so numpy scalar types don't leak into the result (they may cause\n",
    "    problems with msgpack).\n",
    "\n",
    "    NOTE(review): ``malfunction_handler`` is appended as-is (an object, not an int),\n",
    "    so the result is presumably not msgpack-serialisable -- confirm this is intended.\n",
    "\n",
    "    Returns:\n",
    "        A list with one ``[row, col, direction, malfunction_handler]`` entry per agent,\n",
    "        in ``env.agents`` order.\n",
    "    \"\"\"\n",
    "    list_agents_state = []\n",
    "    for oAg in env.agents:\n",
    "        if oAg.position is None:\n",
    "            pos = (0, 0)\n",
    "        else:\n",
    "            pos = (int(oAg.position[0]), int(oAg.position[1]))\n",
    "        list_agents_state.append([*pos, int(oAg.direction), oAg.malfunction_handler])\n",
    "    return list_agents_state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
adrian_egli2's avatar
adrian_egli2 committed
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>initial_position</th>\n",
       "      <th>initial_direction</th>\n",
       "      <th>direction</th>\n",
       "      <th>target</th>\n",
       "      <th>moving</th>\n",
       "      <th>earliest_departure</th>\n",
       "      <th>latest_arrival</th>\n",
       "      <th>handle</th>\n",
       "      <th>speed_counter</th>\n",
       "      <th>action_saver</th>\n",
       "      <th>state_machine</th>\n",
       "      <th>malfunction_handler</th>\n",
       "      <th>position</th>\n",
       "      <th>arrival_time</th>\n",
       "      <th>old_direction</th>\n",
       "      <th>old_position</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>(2, 1)</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>(10, 12)</td>\n",
       "      <td>False</td>\n",
       "      <td>64</td>\n",
       "      <td>191</td>\n",
       "      <td>0</td>\n",
       "      <td>speed: 1.0                  max_count: 0      ...</td>\n",
       "      <td>is_action_saved: False, saved_action: None</td>\n",
       "      <td>\\n                  state: TrainState.WAITING ...</td>\n",
       "      <td>malfunction_down_counter: 0                 in...</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>(1, 1)</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>(19, 48)</td>\n",
       "      <td>False</td>\n",
       "      <td>10</td>\n",
       "      <td>210</td>\n",
       "      <td>1</td>\n",
       "      <td>speed: 1.0                  max_count: 0      ...</td>\n",
       "      <td>is_action_saved: False, saved_action: None</td>\n",
       "      <td>\\n                  state: TrainState.WAITING ...</td>\n",
       "      <td>malfunction_down_counter: 0                 in...</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>(10, 4)</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>(13, 46)</td>\n",
       "      <td>False</td>\n",
       "      <td>121</td>\n",
       "      <td>196</td>\n",
       "      <td>2</td>\n",
       "      <td>speed: 1.0                  max_count: 0      ...</td>\n",
       "      <td>is_action_saved: False, saved_action: None</td>\n",
       "      <td>\\n                  state: TrainState.WAITING ...</td>\n",
       "      <td>malfunction_down_counter: 0                 in...</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>(11, 4)</td>\n",
       "      <td>1</td>\n",