From 3e8a81946f099776f2e7e4df8493be07920129cc Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Wed, 28 Aug 2019 13:41:20 +0200
Subject: [PATCH] merge #141 renamed agent_generator* to schedule_generator*

---
 RLLib_training/RailEnvRLLibWrapper.py                | 9 ++++++++-
 scoring/utils/misc_utils.py                          | 4 ++--
 sequential_agent/run_test.py                         | 4 ++--
 torch_training/multi_agent_inference.py              | 2 +-
 torch_training/multi_agent_training.py               | 8 ++++----
 torch_training/multi_agent_two_time_step_training.py | 6 +++---
 torch_training/render_agent_behavior.py              | 4 ++--
 torch_training/training_navigation.py                | 4 ++--
 utils/misc_utils.py                                  | 4 ++--
 9 files changed, 26 insertions(+), 19 deletions(-)

diff --git a/RLLib_training/RailEnvRLLibWrapper.py b/RLLib_training/RailEnvRLLibWrapper.py
index eb9bbd6..f82cd42 100644
--- a/RLLib_training/RailEnvRLLibWrapper.py
+++ b/RLLib_training/RailEnvRLLibWrapper.py
@@ -4,6 +4,7 @@ from ray.rllib.utils.seed import seed as set_seed
 
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator, random_rail_generator
+from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator
 
 
 class RailEnvRLLibWrapper(MultiAgentEnv):
@@ -25,19 +26,25 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
                                                          min_dist=config['min_dist'],
                                                          nr_extra=config['nr_extra'],
                                                          seed=config['seed'] * (1 + vector_index))
+            self.schedule_generator = complex_schedule_generator()
 
         elif config['rail_generator'] == "random_rail_generator":
             self.rail_generator = random_rail_generator()
+            self.schedule_generator = random_schedule_generator()
         elif config['rail_generator'] == "load_env":
             self.predefined_env = True
             self.rail_generator = random_rail_generator()
+            self.schedule_generator = random_schedule_generator()
         else:
             raise (ValueError, f'Unknown rail generator: {config["rail_generator"]}')
 
         set_seed(config['seed'] * (1 + vector_index))
         self.env = RailEnv(width=config["width"], height=config["height"],
                            number_of_agents=config["number_of_agents"],
-                           obs_builder_object=config['obs_builder'], rail_generator=self.rail_generator)
+                           obs_builder_object=config['obs_builder'],
+                           rail_generator=self.rail_generator,
+                           schedule_generator=self.schedule_generator
+                           )
 
         if self.predefined_env:
             self.env.load_resource('torch_training.railway', 'complex_scene.pkl')
diff --git a/scoring/utils/misc_utils.py b/scoring/utils/misc_utils.py
index e6f9195..dee5f47 100644
--- a/scoring/utils/misc_utils.py
+++ b/scoring/utils/misc_utils.py
@@ -7,7 +7,7 @@ from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator, rail_from_file
-from flatland.envs.schedule_generators import complex_rail_generator_agents_placer
+from flatland.envs.schedule_generators import complex_schedule_generator
 from flatland.utils.rendertools import RenderTool
 
 # Time factor to test the max time allowed for an env.
@@ -113,7 +113,7 @@ def create_testfiles(parameters, test_nr=0, nr_trials_per_test=100):
                   rail_generator=complex_rail_generator(nr_start_goal=nr_paths, nr_extra=5, min_dist=min_dist,
                                                         max_dist=99999,
                                                         seed=parameters[3]),
-                  agent_generator=complex_rail_generator_agents_placer(),
+                  schedule_generator=complex_schedule_generator(),
                   obs_builder_object=TreeObsForRailEnv(max_depth=2),
                   number_of_agents=parameters[2])
     printProgressBar(0, nr_trials_per_test, prefix='Progress:', suffix='Complete', length=20)
diff --git a/sequential_agent/run_test.py b/sequential_agent/run_test.py
index 31b21e8..ecba34a 100644
--- a/sequential_agent/run_test.py
+++ b/sequential_agent/run_test.py
@@ -4,7 +4,7 @@ from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
-from flatland.envs.schedule_generators import complex_rail_generator_agents_placer
+from flatland.envs.schedule_generators import complex_schedule_generator
 from flatland.utils.rendertools import RenderTool
 from sequential_agent.simple_order_agent import OrderedAgent
 
@@ -31,7 +31,7 @@ env = RailEnv(width=x_dim,
               rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
                                                     max_dist=99999,
                                                     seed=0),
-              agent_generator=complex_rail_generator_agents_placer(),
+              schedule_generator=complex_schedule_generator(),
               obs_builder_object=TreeObsForRailEnv(max_depth=1, predictor=ShortestPathPredictorForRailEnv()),
               number_of_agents=n_agents)
 env.reset(True, True)
diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py
index 6e62687..66a37ad 100644
--- a/torch_training/multi_agent_inference.py
+++ b/torch_training/multi_agent_inference.py
@@ -41,7 +41,7 @@ env = RailEnv(width=x_dim,
               rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist,
                                                     max_dist=99999,
                                                     seed=0),
-              agent_generator=complex_rail_generator_agents_placer(),
+              schedule_generator=complex_schedule_generator(),
               obs_builder_object=observation_helper,
               number_of_agents=n_agents)
 env.reset(True, True)
diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
index 1e545ec..3a4c431 100644
--- a/torch_training/multi_agent_training.py
+++ b/torch_training/multi_agent_training.py
@@ -16,7 +16,7 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
 # Import Flatland/ Observations and Predictors
-from flatland.envs.schedule_generators import complex_rail_generator_agents_placer
+from flatland.envs.schedule_generators import complex_schedule_generator
 from torch_training.dueling_double_dqn import Agent
 from utils.observation_utils import normalize_observation
 
@@ -57,7 +57,7 @@ def main(argv):
                   rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist,
                                                         max_dist=99999,
                                                         seed=0),
-                  agent_generator=complex_rail_generator_agents_placer(),
+                  schedule_generator=complex_schedule_generator(),
                   obs_builder_object=observation_helper,
                   number_of_agents=n_agents)
     env.reset(True, True)
@@ -115,9 +115,9 @@ def main(argv):
                           rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist,
                                                                 max_dist=99999,
                                                                 seed=0),
-                          agent_generator=complex_rail_generator_agents_placer(),
+                          schedule_generator=complex_schedule_generator(),
                           obs_builder_object=observation_helper,
-                          number_of_agents=n_agents)
+                          number_of_agents=n_agents)f
 
             # Adjust the parameters according to the new env.
             max_steps = int((env.height + env.width))
diff --git a/torch_training/multi_agent_two_time_step_training.py b/torch_training/multi_agent_two_time_step_training.py
index d0aafa3..08cd84c 100644
--- a/torch_training/multi_agent_two_time_step_training.py
+++ b/torch_training/multi_agent_two_time_step_training.py
@@ -16,7 +16,7 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
 # Import Flatland/ Observations and Predictors
-from flatland.envs.schedule_generators import complex_rail_generator_agents_placer
+from flatland.envs.schedule_generators import complex_schedule_generator
 from torch_training.dueling_double_dqn import Agent
 from utils.observation_utils import norm_obs_clip, split_tree
 
@@ -53,7 +53,7 @@ def main(argv):
                   rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
                                                         max_dist=99999,
                                                         seed=0),
-                  agent_generator=complex_rail_generator_agents_placer(),
+                  schedule_generator=complex_schedule_generator(),
                   obs_builder_object=observation_helper,
                   number_of_agents=n_agents)
     env.reset(True, True)
@@ -111,7 +111,7 @@ def main(argv):
                           rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
                                                                 max_dist=99999,
                                                                 seed=0),
-                          agent_generator=complex_rail_generator_agents_placer(),
+                          schedule_generator=complex_schedule_generator(),
                           obs_builder_object=TreeObsForRailEnv(max_depth=3,
                                                                predictor=ShortestPathPredictorForRailEnv()),
                           number_of_agents=n_agents)
diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py
index 3b48d6f..3882cc9 100644
--- a/torch_training/render_agent_behavior.py
+++ b/torch_training/render_agent_behavior.py
@@ -10,7 +10,7 @@ from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
-from flatland.envs.schedule_generators import complex_rail_generator_agents_placer
+from flatland.envs.schedule_generators import complex_schedule_generator
 from flatland.utils.rendertools import RenderTool
 from torch_training.dueling_double_dqn import Agent
 from utils.observation_utils import norm_obs_clip, split_tree
@@ -38,7 +38,7 @@ env = RailEnv(width=x_dim,
               rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
                                                     max_dist=99999,
                                                     seed=0),
-              agent_generator=complex_rail_generator_agents_placer(),
+              schedule_generator=complex_schedule_generator(),
               obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()),
               number_of_agents=n_agents)
 env.reset(True, True)
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 623beb5..fb03432 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -11,7 +11,7 @@ from dueling_double_dqn import Agent
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
-from flatland.envs.schedule_generators import complex_rail_generator_agents_placer
+from flatland.envs.schedule_generators import complex_schedule_generator
 from flatland.utils.rendertools import RenderTool
 from utils.observation_utils import norm_obs_clip, split_tree
 
@@ -45,7 +45,7 @@ def main(argv):
                   rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
                                                         max_dist=99999,
                                                         seed=0),
-                  agent_generator=complex_rail_generator_agents_placer(),
+                  schedule_generator=complex_schedule_generator(),
                   obs_builder_object=observation_builder,
                   number_of_agents=n_agents)
     env.reset(True, True)
diff --git a/utils/misc_utils.py b/utils/misc_utils.py
index 19c704b..09b315c 100644
--- a/utils/misc_utils.py
+++ b/utils/misc_utils.py
@@ -8,7 +8,7 @@ from line_profiler import LineProfiler
 from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
-from flatland.envs.schedule_generators import complex_rail_generator_agents_placer
+from flatland.envs.schedule_generators import complex_schedule_generator
 from utils.observation_utils import norm_obs_clip, split_tree
 
 
@@ -89,7 +89,7 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3):
                   rail_generator=complex_rail_generator(nr_start_goal=nr_paths, nr_extra=5, min_dist=min_dist,
                                                         max_dist=99999,
                                                         seed=parameters[3]),
-                  agent_generator=complex_rail_generator_agents_placer(),
+                  schedule_generator=complex_schedule_generator(),
                   obs_builder_object=GlobalObsForRailEnv(),
                   number_of_agents=parameters[2])
     max_steps = int(3 * (env.height + env.width))
-- 
GitLab