From cced4901185dbb9e440ab916ba611fa324c4e06f Mon Sep 17 00:00:00 2001 From: u229589 <christian.baumberger@sbb.ch> Date: Thu, 17 Oct 2019 11:46:28 +0200 Subject: [PATCH] add env.reset() to examples and notebooks --- examples/Simple_Realistic_Railway_Generator.py | 3 ++- examples/complex_rail_benchmark.py | 1 + examples/custom_observation_example_01_SimpleObs.py | 1 + examples/introduction_flatland_2_1.py | 1 + examples/simple_example_3.py | 2 ++ examples/training_example.py | 1 + notebooks/Simple_Rendering_Demo.ipynb | 3 ++- notebooks/simple_example_3_manual_control.ipynb | 1 + 8 files changed, 11 insertions(+), 2 deletions(-) diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py index d4f013d5..a7f6a0bf 100644 --- a/examples/Simple_Realistic_Railway_Generator.py +++ b/examples/Simple_Realistic_Railway_Generator.py @@ -596,7 +596,8 @@ for itrials in range(100): number_of_agents=1000, obs_builder_object=GlobalObsForRailEnv()) - # reset to initialize agents_static + env.reset() + env_renderer = RenderTool(env, gl="PILSVG", screen_width=1400, screen_height=1000) cnt = 0 while cnt < 10: diff --git a/examples/complex_rail_benchmark.py b/examples/complex_rail_benchmark.py index aaa5d286..ecbbe8b4 100644 --- a/examples/complex_rail_benchmark.py +++ b/examples/complex_rail_benchmark.py @@ -18,6 +18,7 @@ def run_benchmark(): rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12), schedule_generator=complex_schedule_generator(), number_of_agents=5) + env.reset() n_trials = 20 action_dict = dict() diff --git a/examples/custom_observation_example_01_SimpleObs.py b/examples/custom_observation_example_01_SimpleObs.py index 8c128867..600b8f09 100644 --- a/examples/custom_observation_example_01_SimpleObs.py +++ b/examples/custom_observation_example_01_SimpleObs.py @@ -33,6 +33,7 @@ def main(): rail_generator=random_rail_generator(), number_of_agents=3, obs_builder_object=SimpleObs()) + env.reset() # Print the observation vector for each agents obs, all_rewards, done, _ = env.step({0: 0}) diff --git a/examples/introduction_flatland_2_1.py b/examples/introduction_flatland_2_1.py index df7a4e2d..2db53a40 100644 --- a/examples/introduction_flatland_2_1.py +++ b/examples/introduction_flatland_2_1.py @@ -80,6 +80,7 @@ env = RailEnv(width=width, obs_builder_object=observation_builder, remove_agents_at_target=True # Removes agents at the end of their journey to make space for others ) +env.reset() # Initiate the renderer env_renderer = RenderTool(env, gl="PILSVG", diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py index f294279d..ccbe8682 100644 --- a/examples/simple_example_3.py +++ b/examples/simple_example_3.py @@ -18,6 +18,8 @@ env = RailEnv(width=7, number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2)) +env.reset() + # Print the observation vector for agent 0 obs, all_rewards, done, _ = env.step({0: 0}) for i in range(env.get_num_agents()): diff --git a/examples/training_example.py b/examples/training_example.py index 9d663171..2ce2ad1a 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -20,6 +20,7 @@ env = RailEnv(width=20, schedule_generator=complex_schedule_generator(), obs_builder_object=TreeObservation, number_of_agents=3) +env.reset() env_renderer = RenderTool(env, gl="PILSVG", ) diff --git a/notebooks/Simple_Rendering_Demo.ipynb b/notebooks/Simple_Rendering_Demo.ipynb index 2084ee46..818b2130 100644 --- a/notebooks/Simple_Rendering_Demo.ipynb +++ b/notebooks/Simple_Rendering_Demo.ipynb @@ -81,7 +81,8 @@ " height=10,\n", " rail_generator=fnMethod,\n", " number_of_agents=nAgents,\n", - " obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()))" + " obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()))\n", + "env.reset()" ] }, { diff --git a/notebooks/simple_example_3_manual_control.ipynb b/notebooks/simple_example_3_manual_control.ipynb index cb2b3777..3b29a3e6 100644 --- a/notebooks/simple_example_3_manual_control.ipynb +++ b/notebooks/simple_example_3_manual_control.ipynb @@ -60,6 +60,7 @@ " rail_generator=random_rail_generator(),\n", " number_of_agents=2,\n", " obs_builder_object=TreeObsForRailEnv(max_depth=2))\n", + "env.reset()\n", "\n", "# Print the observation vector for agent 0\n", "obs, all_rewards, done, _ = env.step({0: 0})\n", -- GitLab