From 4eb2d5ffa10c0ec8ad53f3c25d86ada62cdc01cc Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Thu, 23 May 2019 09:25:12 +0200
Subject: [PATCH] demo.py added to show real-env

---
 env-data/railway/example_network_000.pkl | Bin 172 -> 180 bytes
 env-data/railway/example_network_001.pkl | Bin 210 -> 218 bytes
 env-data/railway/example_network_002.pkl | Bin 274 -> 282 bytes
 examples/demo.py                         | 214 +++++++++++++++++++++++
 flatland/envs/rail_env.py                |   2 +
 notebooks/Editor2.ipynb                  |  33 ++--
 6 files changed, 237 insertions(+), 12 deletions(-)
 create mode 100644 examples/demo.py

diff --git a/env-data/railway/example_network_000.pkl b/env-data/railway/example_network_000.pkl
index dbf868829b5aea9d046a3e69622edadaf6dd5ed1..280688c2629331621ab2ea80b4b096226464e653 100644
GIT binary patch
delta 24
gcmZ3(xP@`TV!?@%C$TUvOk(9+mYANJS5iCy0Bef~xc~qF

delta 16
YcmdnOxQ21UV$KQ664O)jN{S}{064=2=Kufz

diff --git a/env-data/railway/example_network_001.pkl b/env-data/railway/example_network_001.pkl
index e9c396fa45565f646a6aca5735ffb769d9db26ee..801f95149dec6eb4d47fd14e36d30f2541480188 100644
GIT binary patch
delta 24
gcmcb_c#CnuNx_MeCowWGOyXczmYANJS5iCy0Cdy|A^-pY

delta 16
Ycmcb`c!_bsNzMt&64O)jN{S}{06y6VSpWb4

diff --git a/env-data/railway/example_network_002.pkl b/env-data/railway/example_network_002.pkl
index a598ca94bcd1193778ba1111be4371a14b3cae7c..898d54ebeb823e48790d4661ffe75a6940cd0712 100644
GIT binary patch
delta 25
hcmbQlG>d5iC!^rR$&)xa8746>EK5vJ%_}LM003Z82o3-M

delta 17
YcmbQmG>K^gCnM*CWr^vjc_qaY05YluJOBUy

diff --git a/examples/demo.py b/examples/demo.py
new file mode 100644
index 0000000..8f1638e
--- /dev/null
+++ b/examples/demo.py
@@ -0,0 +1,214 @@
+import os
+import random
+from collections import deque
+
+import numpy as np
+import torch
+
+from flatland.baselines.dueling_double_dqn import Agent
+from flatland.envs.generators import complex_rail_generator
+# from flatland.envs.generators import rail_from_list_of_saved_GridTransitionMap_generator
+from flatland.envs.generators import random_rail_generator
+from flatland.envs.rail_env import RailEnv
+from flatland.utils.rendertools import RenderTool
+
+# ensure that every demo run behaves identically
+random.seed(1)
+np.random.seed(1)
+
+
+class Scenario_Generator:
+    @staticmethod
+    def generate_random_scenario(number_of_agents=3):
+        # Relative proportions of the cell types (straights, switches,
+        # crossings, ...) handed to the random rail generator
+        transition_probability = [15,  # empty cell - Case 0
+                                  5,  # Case 1 - straight
+                                  5,  # Case 2 - simple switch
+                                  1,  # Case 3 - diamond crossing
+                                  1,  # Case 4 - single slip
+                                  1,  # Case 5 - double slip
+                                  1,  # Case 6 - symmetrical
+                                  0,  # Case 7 - dead end
+                                  1,  # Case 1b (8)  - simple turn right
+                                  1,  # Case 1c (9)  - simple turn left
+                                  1]  # Case 2b (10) - simple switch mirrored
+
+        # Generate a random rail network
+
+        env = RailEnv(width=20,
+                      height=20,
+                      rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
+                      number_of_agents=number_of_agents)
+
+        return env
+
+    @staticmethod
+    def generate_complex_scenario(number_of_agents=3):
+        env = RailEnv(width=15,
+                      height=15,
+                      rail_generator=complex_rail_generator(nr_start_goal=6, nr_extra=30, min_dist=10, max_dist=99999, seed=0),
+                      number_of_agents=number_of_agents)
+
+        return env
+
+    @staticmethod
+    def load_scenario(filename, number_of_agents=3):
+        env = RailEnv(width=2 * (1 + number_of_agents),
+                      height=1 + number_of_agents)
+
+        """
+        env = RailEnv(width=20,
+                      height=20,
+                      rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(
+                          [filename]),
+                      number_of_agents=number_of_agents)
+        """
+        if os.path.exists(filename):
+            print("load file: ", filename)
+            env.load(filename)
+            env.reset(False, False)
+        else:
+            print("File does not exist:", filename, " Working directory: ", os.getcwd())
+
+        return env
+
+
+def max_lt(seq, val):
+    """
+    Return the greatest non-negative item in seq that is strictly smaller than val.
+    0 is returned if seq is empty or no item satisfies the condition.
+    """
+    max = 0
+    idx = len(seq) - 1
+    while idx >= 0:
+        if seq[idx] < val and seq[idx] >= 0 and seq[idx] > max:
+            max = seq[idx]
+        idx -= 1
+    return max
+
+
+def min_lt(seq, val):
+    """
+    Return the smallest item in seq that is strictly greater than val.
+    np.inf is returned if seq is empty or no item satisfies the condition.
+    """
+    min = np.inf
+    idx = len(seq) - 1
+    while idx >= 0:
+        if seq[idx] > val and seq[idx] < min:
+            min = seq[idx]
+        idx -= 1
+    return min
+
+
+def norm_obs_clip(obs, clip_min=-1, clip_max=1):
+    """
+    Normalize an observation by the spread between its relevant min and max values and clip the result.
+    :param obs: observation that should be normalized
+    :param clip_min: min value to which the observation is clipped
+    :param clip_max: max value to which the observation is clipped
+    :return: normalized and clipped observation
+    """
+    max_obs = max(1, max_lt(obs, 1000))
+    min_obs = max(0, min_lt(obs, 0))
+    if max_obs == min_obs:
+        return np.clip(np.array(obs) / max_obs, clip_min, clip_max)
+    norm = np.abs(max_obs - min_obs)
+    if norm == 0:
+        norm = 1.
+    return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)
+
+
+class Demo:
+
+    def __init__(self, env):
+        self.env = env
+        self.create_renderer()
+        self.load_agent()
+
+    def load_agent(self):
+        self.state_size = 105 * 2
+        self.action_size = 4
+        self.agent = Agent(self.state_size, self.action_size, "FC", 0)
+        self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint15000.pth'))
+
+    def create_renderer(self):
+        self.renderer = RenderTool(self.env, gl="QT")
+        handle = self.env.get_agent_handles()
+        return handle
+
+    def run_demo(self, max_nbr_of_steps=100):
+        action_dict = dict()
+        time_obs = deque(maxlen=2)
+        action_prob = [0] * 4
+        agent_obs = [None] * self.env.get_num_agents()
+        agent_next_obs = [None] * self.env.get_num_agents()
+
+        # Reset environment
+        obs = self.env.reset(False, False)
+
+        for a in range(self.env.get_num_agents()):
+            data, distance = self.env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=5, current_depth=0)
+
+            data = norm_obs_clip(data)
+            distance = norm_obs_clip(distance)
+            obs[a] = np.concatenate((data, distance))
+
+        for i in range(2):
+            time_obs.append(obs)
+
+        # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
+        for a in range(self.env.get_num_agents()):
+            agent_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
+
+        for step in range(max_nbr_of_steps):
+            self.renderer.renderEnv(show=True)
+
+            # print(step)
+            # Action
+            for a in range(self.env.get_num_agents()):
+                action = self.agent.act(agent_obs[a])
+                action_prob[action] += 1
+                action_dict.update({a: action})
+
+            # Environment step
+            next_obs, all_rewards, done, _ = self.env.step(action_dict)
+            for a in range(self.env.get_num_agents()):
+                data, distance = self.env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=5,
+                                                                 current_depth=0)
+                data = norm_obs_clip(data)
+                distance = norm_obs_clip(distance)
+                next_obs[a] = np.concatenate((data, distance))
+
+            # Append the new observation and build the two-frame observation for the next step
+            time_obs.append(next_obs)
+
+            for a in range(self.env.get_num_agents()):
+                agent_next_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
+
+            agent_obs = agent_next_obs.copy()
+            if done['__all__']:
+                break
+
+
+if False:
+    demo_000 = Demo(Scenario_Generator.generate_random_scenario())
+    demo_000.run_demo()
+    demo_000 = None
+
+    demo_001 = Demo(Scenario_Generator.generate_complex_scenario())
+    demo_001.run_demo()
+    demo_001 = None
+
+demo_001 = Demo(Scenario_Generator.load_scenario('../env-data/railway/example_network_001.pkl'))
+demo_001.run_demo()
+demo_001 = None
+
+demo_002 = Demo(Scenario_Generator.load_scenario('../env-data/railway/example_network_002.pkl'))
+demo_002.run_demo()
+demo_002 = None
+
+demo_003 = Demo(Scenario_Generator.load_scenario('../env-data/railway/example_network_003.pkl'))
+demo_003.run_demo()
+demo_003 = None
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 118ebf4..74e7526 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -353,6 +353,8 @@ class RailEnv(Environment):
         self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]]
         # setup with loaded data
         self.height, self.width = self.rail.grid.shape
+        self.rail.height = self.height
+        self.rail.width = self.width
         # self.agents = [None] * self.get_num_agents()
         self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
 
diff --git a/notebooks/Editor2.ipynb b/notebooks/Editor2.ipynb
index 4e9a3d7..5dcfd55 100644
--- a/notebooks/Editor2.ipynb
+++ b/notebooks/Editor2.ipynb
@@ -9,9 +9,18 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 25,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "The autoreload extension is already loaded. To reload it, use:\n",
+      "  %reload_ext autoreload\n"
+     ]
+    }
+   ],
    "source": [
     "%load_ext autoreload\n",
     "%autoreload 2"
@@ -19,7 +28,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 26,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -32,7 +41,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 27,
    "metadata": {},
    "outputs": [
     {
@@ -54,7 +63,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 28,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -63,7 +72,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 29,
    "metadata": {},
    "outputs": [
     {
@@ -97,7 +106,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 30,
    "metadata": {
     "scrolled": false
    },
@@ -105,7 +114,7 @@
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "7c89d2a7999f41e0b2ee1f79b4fa3df0",
+       "model_id": "47af532101994c36a053e16a9b31dcd6",
        "version_major": 2,
        "version_minor": 0
       },
@@ -123,7 +132,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 31,
    "metadata": {
     "scrolled": false
    },
@@ -131,7 +140,7 @@
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "2d0119cf2c704437bec328b1d19dd741",
+       "model_id": "949dc7440647445e82dd1ca0f250e5ca",
        "version_major": 2,
        "version_minor": 0
       },
@@ -150,7 +159,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 32,
    "metadata": {},
    "outputs": [
     {
@@ -159,7 +168,7 @@
        "(0, 0)"
       ]
      },
-     "execution_count": 8,
+     "execution_count": 32,
      "metadata": {},
      "output_type": "execute_result"
     }
--
GitLab