From ec3d7d639a39b9bbe72ae0114e9b9af87c467deb Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Tue, 22 Oct 2019 15:22:02 +0200
Subject: [PATCH] RailEnv default value in constructor
 remove_agents_at_target=True

---
 flatland/envs/rail_env.py                         | 2 +-
 tests/test_flaltland_rail_agent_status.py         | 1 +
 tests/test_flatland_envs_observations.py          | 1 +
 tests/test_flatland_envs_sparse_rail_generator.py | 6 ++++--
 4 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index e4b0abd4..1dc24aff 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -116,7 +116,7 @@ class RailEnv(Environment):
                  number_of_agents=1,
                  obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(),
                  stochastic_data=None,
-                 remove_agents_at_target=False,
+                 remove_agents_at_target=True,
                  random_seed=1
                  ):
         """
diff --git a/tests/test_flaltland_rail_agent_status.py b/tests/test_flaltland_rail_agent_status.py
index 987be396..e70012f8 100644
--- a/tests/test_flaltland_rail_agent_status.py
+++ b/tests/test_flaltland_rail_agent_status.py
@@ -22,6 +22,7 @@ def test_initial_status():
                   schedule_generator=random_schedule_generator(),
                   number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  remove_agents_at_target=False
                   )
     env.reset()
     set_penalties_for_replay(env)
diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py
index 8f91088a..f4256364 100644
--- a/tests/test_flatland_envs_observations.py
+++ b/tests/test_flatland_envs_observations.py
@@ -182,6 +182,7 @@ def test_reward_function_waiting(rendering=False):
                   schedule_generator=random_schedule_generator(),
                   number_of_agents=2,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  remove_agents_at_target=False
                   )
     obs_builder: TreeObsForRailEnv = env.obs_builder
     # initialize agents_static
diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py
index a07c1a59..6222cc27 100644
--- a/tests/test_flatland_envs_sparse_rail_generator.py
+++ b/tests/test_flatland_envs_sparse_rail_generator.py
@@ -1343,7 +1343,8 @@ def test_rail_env_action_required_info():
                                 ),
                                 schedule_generator=sparse_schedule_generator(speed_ration_map),
                                 number_of_agents=10,
-                                obs_builder_object=GlobalObsForRailEnv())
+                                obs_builder_object=GlobalObsForRailEnv(),
+                                remove_agents_at_target=False)
     env_always_action.reset()
     np.random.seed(0)
     random.seed(0)
@@ -1358,7 +1359,8 @@ def test_rail_env_action_required_info():
                                           ),
                                           schedule_generator=sparse_schedule_generator(speed_ration_map),
                                           number_of_agents=10,
-                                          obs_builder_object=GlobalObsForRailEnv())
+                                          obs_builder_object=GlobalObsForRailEnv(),
+                                          remove_agents_at_target=False)
     env_only_if_action_required.reset()
     env_renderer = RenderTool(env_always_action, gl="PILSVG", )
 
-- 
GitLab