From b8e41dc1a6721009be8ba26f46f6bd99a38d519f Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Sun, 1 Sep 2019 09:19:11 -0400
Subject: [PATCH] using new level generator for training and inference

---
 torch_training/dueling_double_dqn.py    |  1 +
 torch_training/render_agent_behavior.py | 57 +++++++++++++++++++------
 torch_training/training_navigation.py   | 48 ++++++++++++++++-----
 3 files changed, 82 insertions(+), 24 deletions(-)

diff --git a/torch_training/dueling_double_dqn.py b/torch_training/dueling_double_dqn.py
index 4c2f0fa..dd67b4f 100644
--- a/torch_training/dueling_double_dqn.py
+++ b/torch_training/dueling_double_dqn.py
@@ -20,6 +20,7 @@ double_dqn = True  # If using double dqn algorithm
 input_channels = 5  # Number of Input channels
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+device = torch.device("cpu")
 print(device)
 
 
diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py
index 3882cc9..44c02ac 100644
--- a/torch_training/render_agent_behavior.py
+++ b/torch_training/render_agent_behavior.py
@@ -9,8 +9,8 @@ import torch_training.Nets
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
-from flatland.envs.rail_generators import complex_rail_generator
-from flatland.envs.schedule_generators import complex_schedule_generator
+from flatland.envs.rail_generators import sparse_rail_generator
+from flatland.envs.schedule_generators import sparse_schedule_generator
 from flatland.utils.rendertools import RenderTool
 from torch_training.dueling_double_dqn import Agent
 from utils.observation_utils import norm_obs_clip, split_tree
@@ -27,20 +27,51 @@ x_dim = env.width
 y_dim = env.height
 """
 
-x_dim = np.random.randint(8, 20)
-y_dim = np.random.randint(8, 20)
-n_agents = np.random.randint(3, 8)
-n_goals = n_agents + np.random.randint(0, 3)
-min_dist = int(0.75 * min(x_dim, y_dim))
+# Parameters for the Environment
+x_dim = 20
+y_dim = 20
+n_agents = 1
+n_goals = 5
+min_dist = 5
+
+# We are training an Agent using the Tree Observation with depth 2
+observation_builder = TreeObsForRailEnv(max_depth=2)
+
+# Use a the malfunction generator to break agents from time to time
+stochastic_data = {'prop_malfunction': 0.1,  # Percentage of defective agents
+                   'malfunction_rate': 30,  # Rate of malfunction occurence
+                   'min_duration': 3,  # Minimal duration of malfunction
+                   'max_duration': 20  # Max duration of malfunction
+                   }
+
+# Custom observation builder
+TreeObservation = TreeObsForRailEnv(max_depth=2)
+
+# Different agent types (trains) with different speeds.
+speed_ration_map = {1.: 0.25,  # Fast passenger train
+                    1. / 2.: 0.25,  # Fast freight train
+                    1. / 3.: 0.25,  # Slow commuter train
+                    1. / 4.: 0.25}  # Slow freight train
 
 env = RailEnv(width=x_dim,
               height=y_dim,
-              rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
-                                                    max_dist=99999,
-                                                    seed=0),
-              schedule_generator=complex_schedule_generator(),
-              obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()),
-              number_of_agents=n_agents)
+              rail_generator=sparse_rail_generator(num_cities=5,
+                                                   # Number of cities in map (where train stations are)
+                                                   num_intersections=4,
+                                                   # Number of intersections (no start / target)
+                                                   num_trainstations=10,  # Number of possible start/targets on map
+                                                   min_node_dist=3,  # Minimal distance of nodes
+                                                   node_radius=2,  # Proximity of stations to city center
+                                                   num_neighb=3,
+                                                   # Number of connections to other cities/intersections
+                                                   seed=15,  # Random seed
+                                                   grid_mode=True,
+                                                   enhance_intersection=False
+                                                   ),
+              schedule_generator=sparse_schedule_generator(speed_ration_map),
+              number_of_agents=n_agents,
+              stochastic_data=stochastic_data,  # Malfunction data generator
+              obs_builder_object=TreeObservation)
 env.reset(True, True)
 
 observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index fb03432..3110590 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -10,8 +10,8 @@ from dueling_double_dqn import Agent
 
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
-from flatland.envs.rail_generators import complex_rail_generator
-from flatland.envs.schedule_generators import complex_schedule_generator
+from flatland.envs.rail_generators import sparse_rail_generator
+from flatland.envs.schedule_generators import sparse_schedule_generator
 from flatland.utils.rendertools import RenderTool
 from utils.observation_utils import norm_obs_clip, split_tree
 
@@ -30,8 +30,8 @@ def main(argv):
     np.random.seed(1)
 
     # Parameters for the Environment
-    x_dim = 10
-    y_dim = 10
+    x_dim = 20
+    y_dim = 20
     n_agents = 1
     n_goals = 5
     min_dist = 5
@@ -39,15 +39,41 @@ def main(argv):
     # We are training an Agent using the Tree Observation with depth 2
     observation_builder = TreeObsForRailEnv(max_depth=2)
 
-    # Load the Environment
+    # Use a the malfunction generator to break agents from time to time
+    stochastic_data = {'prop_malfunction': 0.1,  # Percentage of defective agents
+                       'malfunction_rate': 30,  # Rate of malfunction occurence
+                       'min_duration': 3,  # Minimal duration of malfunction
+                       'max_duration': 20  # Max duration of malfunction
+                       }
+
+    # Custom observation builder
+    TreeObservation = TreeObsForRailEnv(max_depth=2)
+
+    # Different agent types (trains) with different speeds.
+    speed_ration_map = {1.: 0.25,  # Fast passenger train
+                        1. / 2.: 0.25,  # Fast freight train
+                        1. / 3.: 0.25,  # Slow commuter train
+                        1. / 4.: 0.25}  # Slow freight train
+
     env = RailEnv(width=x_dim,
                   height=y_dim,
-                  rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
-                                                        max_dist=99999,
-                                                        seed=0),
-                  schedule_generator=complex_schedule_generator(),
-                  obs_builder_object=observation_builder,
-                  number_of_agents=n_agents)
+                  rail_generator=sparse_rail_generator(num_cities=5,
+                                                       # Number of cities in map (where train stations are)
+                                                       num_intersections=4,
+                                                       # Number of intersections (no start / target)
+                                                       num_trainstations=10,  # Number of possible start/targets on map
+                                                       min_node_dist=3,  # Minimal distance of nodes
+                                                       node_radius=2,  # Proximity of stations to city center
+                                                       num_neighb=3,
+                                                       # Number of connections to other cities/intersections
+                                                       seed=15,  # Random seed
+                                                       grid_mode=True,
+                                                       enhance_intersection=False
+                                                       ),
+                  schedule_generator=sparse_schedule_generator(speed_ration_map),
+                  number_of_agents=n_agents,
+                  stochastic_data=stochastic_data,  # Malfunction data generator
+                  obs_builder_object=TreeObservation)
     env.reset(True, True)
 
     # After training we want to render the results so we also load a renderer
-- 
GitLab