From b8e41dc1a6721009be8ba26f46f6bd99a38d519f Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Sun, 1 Sep 2019 09:19:11 -0400 Subject: [PATCH] using new level generator for training and inference --- torch_training/dueling_double_dqn.py | 1 + torch_training/render_agent_behavior.py | 57 +++++++++++++++++++------ torch_training/training_navigation.py | 48 ++++++++++++++++----- 3 files changed, 82 insertions(+), 24 deletions(-) diff --git a/torch_training/dueling_double_dqn.py b/torch_training/dueling_double_dqn.py index 4c2f0fa..dd67b4f 100644 --- a/torch_training/dueling_double_dqn.py +++ b/torch_training/dueling_double_dqn.py @@ -20,6 +20,7 @@ double_dqn = True # If using double dqn algorithm input_channels = 5 # Number of Input channels device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +device = torch.device("cpu") print(device) diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py index 3882cc9..44c02ac 100644 --- a/torch_training/render_agent_behavior.py +++ b/torch_training/render_agent_behavior.py @@ -9,8 +9,8 @@ import torch_training.Nets from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv -from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import complex_schedule_generator +from flatland.envs.rail_generators import sparse_rail_generator +from flatland.envs.schedule_generators import sparse_schedule_generator from flatland.utils.rendertools import RenderTool from torch_training.dueling_double_dqn import Agent from utils.observation_utils import norm_obs_clip, split_tree @@ -27,20 +27,51 @@ x_dim = env.width y_dim = env.height """ -x_dim = np.random.randint(8, 20) -y_dim = np.random.randint(8, 20) -n_agents = np.random.randint(3, 8) -n_goals = n_agents + np.random.randint(0, 3) -min_dist = int(0.75 * min(x_dim, y_dim)) +# Parameters for the Environment +x_dim = 20 +y_dim = 20 +n_agents = 1 +n_goals = 5 +min_dist = 5 + +# We are training an Agent using the Tree Observation with depth 2 +observation_builder = TreeObsForRailEnv(max_depth=2) + +# Use a the malfunction generator to break agents from time to time +stochastic_data = {'prop_malfunction': 0.1, # Percentage of defective agents + 'malfunction_rate': 30, # Rate of malfunction occurence + 'min_duration': 3, # Minimal duration of malfunction + 'max_duration': 20 # Max duration of malfunction + } + +# Custom observation builder +TreeObservation = TreeObsForRailEnv(max_depth=2) + +# Different agent types (trains) with different speeds. +speed_ration_map = {1.: 0.25, # Fast passenger train + 1. / 2.: 0.25, # Fast freight train + 1. / 3.: 0.25, # Slow commuter train + 1. / 4.: 0.25} # Slow freight train env = RailEnv(width=x_dim, height=y_dim, - rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, - max_dist=99999, - seed=0), - schedule_generator=complex_schedule_generator(), - obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()), - number_of_agents=n_agents) + rail_generator=sparse_rail_generator(num_cities=5, + # Number of cities in map (where train stations are) + num_intersections=4, + # Number of intersections (no start / target) + num_trainstations=10, # Number of possible start/targets on map + min_node_dist=3, # Minimal distance of nodes + node_radius=2, # Proximity of stations to city center + num_neighb=3, + # Number of connections to other cities/intersections + seed=15, # Random seed + grid_mode=True, + enhance_intersection=False + ), + schedule_generator=sparse_schedule_generator(speed_ration_map), + number_of_agents=n_agents, + stochastic_data=stochastic_data, # Malfunction data generator + obs_builder_object=TreeObservation) env.reset(True, True) observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()) diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index fb03432..3110590 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -10,8 +10,8 @@ from dueling_double_dqn import Agent from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv -from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import complex_schedule_generator +from flatland.envs.rail_generators import sparse_rail_generator +from flatland.envs.schedule_generators import sparse_schedule_generator from flatland.utils.rendertools import RenderTool from utils.observation_utils import norm_obs_clip, split_tree @@ -30,8 +30,8 @@ def main(argv): np.random.seed(1) # Parameters for the Environment - x_dim = 10 - y_dim = 10 + x_dim = 20 + y_dim = 20 n_agents = 1 n_goals = 5 min_dist = 5 @@ -39,15 +39,41 @@ def main(argv): # We are training an Agent using the Tree Observation with depth 2 observation_builder = TreeObsForRailEnv(max_depth=2) - # Load the Environment + # Use a the malfunction generator to break agents from time to time + stochastic_data = {'prop_malfunction': 0.1, # Percentage of defective agents + 'malfunction_rate': 30, # Rate of malfunction occurence + 'min_duration': 3, # Minimal duration of malfunction + 'max_duration': 20 # Max duration of malfunction + } + + # Custom observation builder + TreeObservation = TreeObsForRailEnv(max_depth=2) + + # Different agent types (trains) with different speeds. + speed_ration_map = {1.: 0.25, # Fast passenger train + 1. / 2.: 0.25, # Fast freight train + 1. / 3.: 0.25, # Slow commuter train + 1. / 4.: 0.25} # Slow freight train + env = RailEnv(width=x_dim, height=y_dim, - rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, - max_dist=99999, - seed=0), - schedule_generator=complex_schedule_generator(), - obs_builder_object=observation_builder, - number_of_agents=n_agents) + rail_generator=sparse_rail_generator(num_cities=5, + # Number of cities in map (where train stations are) + num_intersections=4, + # Number of intersections (no start / target) + num_trainstations=10, # Number of possible start/targets on map + min_node_dist=3, # Minimal distance of nodes + node_radius=2, # Proximity of stations to city center + num_neighb=3, + # Number of connections to other cities/intersections + seed=15, # Random seed + grid_mode=True, + enhance_intersection=False + ), + schedule_generator=sparse_schedule_generator(speed_ration_map), + number_of_agents=n_agents, + stochastic_data=stochastic_data, # Malfunction data generator + obs_builder_object=TreeObservation) env.reset(True, True) # After training we want to render the results so we also load a renderer -- GitLab