diff --git a/examples/demo.py b/examples/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..c77dc814b23e4402cc8cd52fc7e20c9528f4bd6b --- /dev/null +++ b/examples/demo.py @@ -0,0 +1,214 @@ +import os +import random +from collections import deque + +import numpy as np +import torch + +from flatland.baselines.dueling_double_dqn import Agent +from flatland.envs.generators import complex_rail_generator +# from flatland.envs.generators import rail_from_list_of_saved_GridTransitionMap_generator +from flatland.envs.generators import random_rail_generator +from flatland.envs.rail_env import RailEnv +from flatland.utils.rendertools import RenderTool + +# ensure that every demo run behave constantly equal +random.seed(1) +np.random.seed(1) + + +class Scenario_Generator: + @staticmethod + def generate_random_scenario(number_of_agents=3): + # Example generate a rail given a manual specification, + # a map of tuples (cell_type, rotation) + transition_probability = [15, # empty cell - Case 0 + 5, # Case 1 - straight + 5, # Case 2 - simple switch + 1, # Case 3 - diamond crossing + 1, # Case 4 - single slip + 1, # Case 5 - double slip + 1, # Case 6 - symmetrical + 0, # Case 7 - dead end + 1, # Case 1b (8) - simple turn right + 1, # Case 1c (9) - simple turn left + 1] # Case 2b (10) - simple switch mirrored + + # Example generate a random rail + + env = RailEnv(width=20, + height=20, + rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), + number_of_agents=number_of_agents) + + return env + + @staticmethod + def generate_complex_scenario(number_of_agents=3): + env = RailEnv(width=15, + height=15, + rail_generator=complex_rail_generator(nr_start_goal=6, nr_extra=30, min_dist=10, max_dist=99999, seed=0), + number_of_agents=number_of_agents) + + return env + + @staticmethod + def load_scenario(filename, number_of_agents=3): + env = RailEnv(width=2 * (1 + number_of_agents), + height=1 + number_of_agents) + + """ + env = RailEnv(width=20, + height=20, + rail_generator=rail_from_list_of_saved_GridTransitionMap_generator( + [filename]), + number_of_agents=number_of_agents) + """ + if os.path.exists(filename): + print("load file: ", filename) + env.load(filename) + env.reset(False, False) + else: + print("File does not exist:", filename, " Working directory: ", os.getcwd()) + + return env + + +def max_lt(seq, val): + """ + Return greatest item in seq for which item < val applies. + None is returned if seq was empty or all items in seq were >= val. + """ + max = 0 + idx = len(seq) - 1 + while idx >= 0: + if seq[idx] < val and seq[idx] >= 0 and seq[idx] > max: + max = seq[idx] + idx -= 1 + return max + + +def min_lt(seq, val): + """ + Return smallest item in seq for which item > val applies. + None is returned if seq was empty or all items in seq were >= val. + """ + min = np.inf + idx = len(seq) - 1 + while idx >= 0: + if seq[idx] > val and seq[idx] < min: + min = seq[idx] + idx -= 1 + return min + + +def norm_obs_clip(obs, clip_min=-1, clip_max=1): + """ + This function returns the difference between min and max value of an observation + :param obs: Observation that should be normalized + :param clip_min: min value where observation will be clipped + :param clip_max: max value where observation will be clipped + :return: returnes normalized and clipped observatoin + """ + max_obs = max(1, max_lt(obs, 1000)) + min_obs = max(0, min_lt(obs, 0)) + if max_obs == min_obs: + return np.clip(np.array(obs) / max_obs, clip_min, clip_max) + norm = np.abs(max_obs - min_obs) + if norm == 0: + norm = 1. + return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max) + + +class Demo: + + def __init__(self, env): + self.env = env + self.create_renderer() + self.load_agent() + + def load_agent(self): + self.state_size = 105 * 2 + self.action_size = 4 + self.agent = Agent(self.state_size, self.action_size, "FC", 0) + self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint15000.pth')) + + def create_renderer(self): + self.renderer = RenderTool(self.env, gl="QT") + handle = self.env.get_agent_handles() + return handle + + def run_demo(self, max_nbr_of_steps=100): + action_dict = dict() + time_obs = deque(maxlen=2) + action_prob = [0] * 4 + agent_obs = [None] * self.env.get_num_agents() + agent_next_obs = [None] * self.env.get_num_agents() + + # Reset environment + obs = self.env.reset(False, False) + + for a in range(self.env.get_num_agents()): + data, distance = self.env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=5, current_depth=0) + + data = norm_obs_clip(data) + distance = norm_obs_clip(distance) + obs[a] = np.concatenate((data, distance)) + + for i in range(2): + time_obs.append(obs) + + # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5) + for a in range(self.env.get_num_agents()): + agent_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a])) + + for step in range(max_nbr_of_steps): + self.renderer.renderEnv(show=True) + + # print(step) + # Action + for a in range(self.env.get_num_agents()): + action = self.agent.act(agent_obs[a]) + action_prob[action] += 1 + action_dict.update({a: action}) + + # Environment step + next_obs, all_rewards, done, _ = self.env.step(action_dict) + for a in range(self.env.get_num_agents()): + data, distance = self.env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=5, + current_depth=0) + data = norm_obs_clip(data) + distance = norm_obs_clip(distance) + next_obs[a] = np.concatenate((data, distance)) + + # Update replay buffer and train agent + for a in range(self.env.get_num_agents()): + agent_next_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a])) + + time_obs.append(next_obs) + + agent_obs = agent_next_obs.copy() + if done['__all__']: + break + + +if False: + demo_000 = Demo(Scenario_Generator.generate_random_scenario()) + demo_000.run_demo() + demo_000 = None + + demo_001 = Demo(Scenario_Generator.generate_complex_scenario()) + demo_001.run_demo() + demo_001 = None + +demo_000 = Demo(Scenario_Generator.load_scenario('../env-data/railway/example_network_000.pkl')) +demo_000.run_demo() +demo_000 = None + +demo_001 = Demo(Scenario_Generator.load_scenario('../env-data/railway/example_network_001.pkl')) +demo_001.run_demo() +demo_001 = None + +demo_002 = Demo(Scenario_Generator.load_scenario('../env-data/railway/example_network_002.pkl')) +demo_002.run_demo() +demo_002 = None diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 118ebf4d3ea122519a09c8b7b5a00c53964063a1..74e7526caa93ca8a1821eb5b2a47576231eb95c3 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -353,6 +353,8 @@ class RailEnv(Environment): self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]] # setup with loaded data self.height, self.width = self.rail.grid.shape + self.rail.height = self.height + self.rail.width = self.width # self.agents = [None] * self.get_num_agents() self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) diff --git a/notebooks/Editor2.ipynb b/notebooks/Editor2.ipynb index c15cf09255ba4b9eff04b45803258b9e3fd0f71d..5dcfd5595bbc7d15e876397e01f9c369abb91c48 100644 --- a/notebooks/Editor2.ipynb +++ b/notebooks/Editor2.ipynb @@ -9,9 +9,18 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 25, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], "source": [ "%load_ext autoreload\n", "%autoreload 2" @@ -19,7 +28,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ @@ -32,7 +41,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 27, "metadata": {}, "outputs": [ { @@ -54,24 +63,16 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 28, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "cpu\n" - ] - } - ], + "outputs": [], "source": [ "from flatland.utils.editor import EditorMVC, EditorModel, View, Controller" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 29, "metadata": {}, "outputs": [ { @@ -105,7 +106,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 30, "metadata": { "scrolled": false }, @@ -113,7 +114,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "2f565b9a339d4be6915d9f6de3a5ef91", + "model_id": "47af532101994c36a053e16a9b31dcd6", "version_major": 2, "version_minor": 0 }, @@ -131,7 +132,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 31, "metadata": { "scrolled": false }, @@ -139,7 +140,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "3e78f1eb4f1f45468e94ef4eeb307d47", + "model_id": "949dc7440647445e82dd1ca0f250e5ca", "version_major": 2, "version_minor": 0 }, @@ -158,7 +159,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 32, "metadata": {}, "outputs": [ { @@ -167,7 +168,7 @@ "(0, 0)" ] }, - "execution_count": 8, + "execution_count": 32, "metadata": {}, "output_type": "execute_result" } @@ -194,7 +195,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.5" + "version": "3.6.8" }, "latex_envs": { "LaTeX_envs_menu_present": true,