From 4eb2d5ffa10c0ec8ad53f3c25d86ada62cdc01cc Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Thu, 23 May 2019 09:25:12 +0200
Subject: [PATCH] demo.py added to show real-evn

---
 env-data/railway/example_network_000.pkl | Bin 172 -> 180 bytes
 env-data/railway/example_network_001.pkl | Bin 210 -> 218 bytes
 env-data/railway/example_network_002.pkl | Bin 274 -> 282 bytes
 examples/demo.py                         | 214 +++++++++++++++++++++++
 flatland/envs/rail_env.py                |   2 +
 notebooks/Editor2.ipynb                  |  33 ++--
 6 files changed, 237 insertions(+), 12 deletions(-)
 create mode 100644 examples/demo.py

diff --git a/env-data/railway/example_network_000.pkl b/env-data/railway/example_network_000.pkl
index dbf868829b5aea9d046a3e69622edadaf6dd5ed1..280688c2629331621ab2ea80b4b096226464e653 100644
GIT binary patch
delta 24
gcmZ3(xP@`TV!?@%C$TUvOk(9+mYANJS5iCy0Bef~xc~qF

delta 16
YcmdnOxQ21UV$KQ664O)jN{S}{064=2=Kufz

diff --git a/env-data/railway/example_network_001.pkl b/env-data/railway/example_network_001.pkl
index e9c396fa45565f646a6aca5735ffb769d9db26ee..801f95149dec6eb4d47fd14e36d30f2541480188 100644
GIT binary patch
delta 24
gcmcb_c#CnuNx_MeCowWGOyXczmYANJS5iCy0Cdy|A^-pY

delta 16
Ycmcb`c!_bsNzMt&64O)jN{S}{06y6VSpWb4

diff --git a/env-data/railway/example_network_002.pkl b/env-data/railway/example_network_002.pkl
index a598ca94bcd1193778ba1111be4371a14b3cae7c..898d54ebeb823e48790d4661ffe75a6940cd0712 100644
GIT binary patch
delta 25
hcmbQlG>d5iC!^rR$&)xa8746>EK5vJ%_}LM003Z82o3-M

delta 17
YcmbQmG>K^gCnM*CWr^vjc_qaY05YluJOBUy

diff --git a/examples/demo.py b/examples/demo.py
new file mode 100644
index 00000000..8f1638e6
--- /dev/null
+++ b/examples/demo.py
@@ -0,0 +1,214 @@
+import os
+import random
+from collections import deque
+
+import numpy as np
+import torch
+
+from flatland.baselines.dueling_double_dqn import Agent
+from flatland.envs.generators import complex_rail_generator
+# from flatland.envs.generators import rail_from_list_of_saved_GridTransitionMap_generator
+from flatland.envs.generators import random_rail_generator
+from flatland.envs.rail_env import RailEnv
+from flatland.utils.rendertools import RenderTool
+
+# ensure that every demo run behave constantly equal
+random.seed(1)
+np.random.seed(1)
+
+
+class Scenario_Generator:
+    @staticmethod
+    def generate_random_scenario(number_of_agents=3):
+        # Example generate a rail given a manual specification,
+        # a map of tuples (cell_type, rotation)
+        transition_probability = [15,  # empty cell - Case 0
+                                  5,  # Case 1 - straight
+                                  5,  # Case 2 - simple switch
+                                  1,  # Case 3 - diamond crossing
+                                  1,  # Case 4 - single slip
+                                  1,  # Case 5 - double slip
+                                  1,  # Case 6 - symmetrical
+                                  0,  # Case 7 - dead end
+                                  1,  # Case 1b (8)  - simple turn right
+                                  1,  # Case 1c (9)  - simple turn left
+                                  1]  # Case 2b (10) - simple switch mirrored
+
+        # Example generate a random rail
+
+        env = RailEnv(width=20,
+                      height=20,
+                      rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
+                      number_of_agents=number_of_agents)
+
+        return env
+
+    @staticmethod
+    def generate_complex_scenario(number_of_agents=3):
+        env = RailEnv(width=15,
+                      height=15,
+                      rail_generator=complex_rail_generator(nr_start_goal=6, nr_extra=30, min_dist=10, max_dist=99999, seed=0),
+                      number_of_agents=number_of_agents)
+
+        return env
+
+    @staticmethod
+    def load_scenario(filename, number_of_agents=3):
+        env = RailEnv(width=2 * (1 + number_of_agents),
+                      height=1 + number_of_agents)
+
+        """
+        env = RailEnv(width=20,
+                      height=20,
+                      rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(
+                          [filename]),
+                      number_of_agents=number_of_agents)
+        """
+        if os.path.exists(filename):
+            print("load file: ", filename)
+            env.load(filename)
+            env.reset(False, False)
+        else:
+            print("File does not exist:", filename, " Working directory: ", os.getcwd())
+
+        return env
+
+
+def max_lt(seq, val):
+    """
+    Return greatest item in seq for which item < val applies.
+    None is returned if seq was empty or all items in seq were >= val.
+    """
+    max = 0
+    idx = len(seq) - 1
+    while idx >= 0:
+        if seq[idx] < val and seq[idx] >= 0 and seq[idx] > max:
+            max = seq[idx]
+        idx -= 1
+    return max
+
+
+def min_lt(seq, val):
+    """
+    Return smallest item in seq for which item > val applies.
+    None is returned if seq was empty or all items in seq were >= val.
+    """
+    min = np.inf
+    idx = len(seq) - 1
+    while idx >= 0:
+        if seq[idx] > val and seq[idx] < min:
+            min = seq[idx]
+        idx -= 1
+    return min
+
+
+def norm_obs_clip(obs, clip_min=-1, clip_max=1):
+    """
+    This function returns the difference between min and max value of an observation
+    :param obs: Observation that should be normalized
+    :param clip_min: min value where observation will be clipped
+    :param clip_max: max value where observation will be clipped
+    :return: returnes normalized and clipped observatoin
+    """
+    max_obs = max(1, max_lt(obs, 1000))
+    min_obs = max(0, min_lt(obs, 0))
+    if max_obs == min_obs:
+        return np.clip(np.array(obs) / max_obs, clip_min, clip_max)
+    norm = np.abs(max_obs - min_obs)
+    if norm == 0:
+        norm = 1.
+    return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)
+
+
+class Demo:
+
+    def __init__(self, env):
+        self.env = env
+        self.create_renderer()
+        self.load_agent()
+
+    def load_agent(self):
+        self.state_size = 105 * 2
+        self.action_size = 4
+        self.agent = Agent(self.state_size, self.action_size, "FC", 0)
+        self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint15000.pth'))
+
+    def create_renderer(self):
+        self.renderer = RenderTool(self.env, gl="QT")
+        handle = self.env.get_agent_handles()
+        return handle
+
+    def run_demo(self, max_nbr_of_steps=100):
+        action_dict = dict()
+        time_obs = deque(maxlen=2)
+        action_prob = [0] * 4
+        agent_obs = [None] * self.env.get_num_agents()
+        agent_next_obs = [None] * self.env.get_num_agents()
+
+        # Reset environment
+        obs = self.env.reset(False, False)
+
+        for a in range(self.env.get_num_agents()):
+            data, distance = self.env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=5, current_depth=0)
+
+            data = norm_obs_clip(data)
+            distance = norm_obs_clip(distance)
+            obs[a] = np.concatenate((data, distance))
+
+        for i in range(2):
+            time_obs.append(obs)
+
+        # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
+        for a in range(self.env.get_num_agents()):
+            agent_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
+
+        for step in range(max_nbr_of_steps):
+            self.renderer.renderEnv(show=True)
+
+            # print(step)
+            # Action
+            for a in range(self.env.get_num_agents()):
+                action = self.agent.act(agent_obs[a])
+                action_prob[action] += 1
+                action_dict.update({a: action})
+
+            # Environment step
+            next_obs, all_rewards, done, _ = self.env.step(action_dict)
+            for a in range(self.env.get_num_agents()):
+                data, distance = self.env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=5,
+                                                                 current_depth=0)
+                data = norm_obs_clip(data)
+                distance = norm_obs_clip(distance)
+                next_obs[a] = np.concatenate((data, distance))
+
+            # Update replay buffer and train agent
+            for a in range(self.env.get_num_agents()):
+                agent_next_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
+
+            time_obs.append(next_obs)
+
+            agent_obs = agent_next_obs.copy()
+            if done['__all__']:
+                break
+
+
+if False:
+    demo_000 = Demo(Scenario_Generator.generate_random_scenario())
+    demo_000.run_demo()
+    demo_000 = None
+
+    demo_001 = Demo(Scenario_Generator.generate_complex_scenario())
+    demo_001.run_demo()
+    demo_001 = None
+
+demo_001 = Demo(Scenario_Generator.load_scenario('../env-data/railway/example_network_001.pkl'))
+demo_001.run_demo()
+demo_001 = None
+
+demo_002 = Demo(Scenario_Generator.load_scenario('../env-data/railway/example_network_002.pkl'))
+demo_002.run_demo()
+demo_002 = None
+
+demo_003 = Demo(Scenario_Generator.load_scenario('../env-data/railway/example_network_003.pkl'))
+demo_003.run_demo()
+demo_003 = None
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 118ebf4d..74e7526c 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -353,6 +353,8 @@ class RailEnv(Environment):
         self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]]
         # setup with loaded data
         self.height, self.width = self.rail.grid.shape
+        self.rail.height = self.height
+        self.rail.width = self.width
         # self.agents = [None] * self.get_num_agents()
         self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
 
diff --git a/notebooks/Editor2.ipynb b/notebooks/Editor2.ipynb
index 4e9a3d76..5dcfd559 100644
--- a/notebooks/Editor2.ipynb
+++ b/notebooks/Editor2.ipynb
@@ -9,9 +9,18 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 25,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "The autoreload extension is already loaded. To reload it, use:\n",
+      "  %reload_ext autoreload\n"
+     ]
+    }
+   ],
    "source": [
     "%load_ext autoreload\n",
     "%autoreload 2"
@@ -19,7 +28,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 26,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -32,7 +41,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 27,
    "metadata": {},
    "outputs": [
     {
@@ -54,7 +63,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 28,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -63,7 +72,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 29,
    "metadata": {},
    "outputs": [
     {
@@ -97,7 +106,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 30,
    "metadata": {
     "scrolled": false
    },
@@ -105,7 +114,7 @@
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "7c89d2a7999f41e0b2ee1f79b4fa3df0",
+       "model_id": "47af532101994c36a053e16a9b31dcd6",
        "version_major": 2,
        "version_minor": 0
       },
@@ -123,7 +132,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 31,
    "metadata": {
     "scrolled": false
    },
@@ -131,7 +140,7 @@
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "2d0119cf2c704437bec328b1d19dd741",
+       "model_id": "949dc7440647445e82dd1ca0f250e5ca",
        "version_major": 2,
        "version_minor": 0
       },
@@ -150,7 +159,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 32,
    "metadata": {},
    "outputs": [
     {
@@ -159,7 +168,7 @@
        "(0, 0)"
       ]
      },
-     "execution_count": 8,
+     "execution_count": 32,
      "metadata": {},
      "output_type": "execute_result"
     }
-- 
GitLab