From cced4901185dbb9e440ab916ba611fa324c4e06f Mon Sep 17 00:00:00 2001
From: u229589 <christian.baumberger@sbb.ch>
Date: Thu, 17 Oct 2019 11:46:28 +0200
Subject: [PATCH] add env.reset() to examples and notebooks

---
 examples/Simple_Realistic_Railway_Generator.py      | 3 ++-
 examples/complex_rail_benchmark.py                  | 1 +
 examples/custom_observation_example_01_SimpleObs.py | 1 +
 examples/introduction_flatland_2_1.py               | 1 +
 examples/simple_example_3.py                        | 2 ++
 examples/training_example.py                        | 1 +
 notebooks/Simple_Rendering_Demo.ipynb               | 3 ++-
 notebooks/simple_example_3_manual_control.ipynb     | 1 +
 8 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py
index d4f013d5..a7f6a0bf 100644
--- a/examples/Simple_Realistic_Railway_Generator.py
+++ b/examples/Simple_Realistic_Railway_Generator.py
@@ -596,7 +596,8 @@ for itrials in range(100):
                   number_of_agents=1000,
                   obs_builder_object=GlobalObsForRailEnv())
 
-    # reset to initialize agents_static
+    env.reset()
+
     env_renderer = RenderTool(env, gl="PILSVG", screen_width=1400, screen_height=1000)
     cnt = 0
     while cnt < 10:
diff --git a/examples/complex_rail_benchmark.py b/examples/complex_rail_benchmark.py
index aaa5d286..ecbbe8b4 100644
--- a/examples/complex_rail_benchmark.py
+++ b/examples/complex_rail_benchmark.py
@@ -18,6 +18,7 @@ def run_benchmark():
                   rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12),
                   schedule_generator=complex_schedule_generator(),
                   number_of_agents=5)
+    env.reset()
 
     n_trials = 20
     action_dict = dict()
diff --git a/examples/custom_observation_example_01_SimpleObs.py b/examples/custom_observation_example_01_SimpleObs.py
index 8c128867..600b8f09 100644
--- a/examples/custom_observation_example_01_SimpleObs.py
+++ b/examples/custom_observation_example_01_SimpleObs.py
@@ -33,6 +33,7 @@ def main():
                   rail_generator=random_rail_generator(),
                   number_of_agents=3,
                   obs_builder_object=SimpleObs())
+    env.reset()
 
     # Print the observation vector for each agents
     obs, all_rewards, done, _ = env.step({0: 0})
diff --git a/examples/introduction_flatland_2_1.py b/examples/introduction_flatland_2_1.py
index df7a4e2d..2db53a40 100644
--- a/examples/introduction_flatland_2_1.py
+++ b/examples/introduction_flatland_2_1.py
@@ -80,6 +80,7 @@ env = RailEnv(width=width,
               obs_builder_object=observation_builder,
               remove_agents_at_target=True  # Removes agents at the end of their journey to make space for others
               )
+env.reset()
 
 # Initiate the renderer
 env_renderer = RenderTool(env, gl="PILSVG",
diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py
index f294279d..ccbe8682 100644
--- a/examples/simple_example_3.py
+++ b/examples/simple_example_3.py
@@ -18,6 +18,8 @@ env = RailEnv(width=7,
               number_of_agents=2,
               obs_builder_object=TreeObsForRailEnv(max_depth=2))
 
+env.reset()
+
 # Print the observation vector for agent 0
 obs, all_rewards, done, _ = env.step({0: 0})
 for i in range(env.get_num_agents()):
diff --git a/examples/training_example.py b/examples/training_example.py
index 9d663171..2ce2ad1a 100644
--- a/examples/training_example.py
+++ b/examples/training_example.py
@@ -20,6 +20,7 @@ env = RailEnv(width=20,
               schedule_generator=complex_schedule_generator(),
               obs_builder_object=TreeObservation,
               number_of_agents=3)
+env.reset()
 
 env_renderer = RenderTool(env, gl="PILSVG", )
 
diff --git a/notebooks/Simple_Rendering_Demo.ipynb b/notebooks/Simple_Rendering_Demo.ipynb
index 2084ee46..818b2130 100644
--- a/notebooks/Simple_Rendering_Demo.ipynb
+++ b/notebooks/Simple_Rendering_Demo.ipynb
@@ -81,7 +81,8 @@
     "                  height=10,\n",
     "                  rail_generator=fnMethod,\n",
     "                  number_of_agents=nAgents,\n",
-    "                  obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()))"
+    "                  obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()))\n",
+    "env.reset()"
    ]
   },
   {
diff --git a/notebooks/simple_example_3_manual_control.ipynb b/notebooks/simple_example_3_manual_control.ipynb
index cb2b3777..3b29a3e6 100644
--- a/notebooks/simple_example_3_manual_control.ipynb
+++ b/notebooks/simple_example_3_manual_control.ipynb
@@ -60,6 +60,7 @@
     "              rail_generator=random_rail_generator(),\n",
     "              number_of_agents=2,\n",
     "              obs_builder_object=TreeObsForRailEnv(max_depth=2))\n",
+    "env.reset()\n",
     "\n",
     "# Print the observation vector for agent 0\n",
     "obs, all_rewards, done, _ = env.step({0: 0})\n",
-- 
GitLab