diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index e4b0abd4884136fcabbabf220e492f962ef0fb71..1dc24affd461ce585bceb6083df4feffe74c9b46 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -116,7 +116,7 @@ class RailEnv(Environment): number_of_agents=1, obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(), stochastic_data=None, - remove_agents_at_target=False, + remove_agents_at_target=True, random_seed=1 ): """ diff --git a/tests/test_flaltland_rail_agent_status.py b/tests/test_flaltland_rail_agent_status.py index 987be3967746c7859414e24a452baab7758ee46b..e70012f8e4e77017c7dde4c3f1287e4d3bf72278 100644 --- a/tests/test_flaltland_rail_agent_status.py +++ b/tests/test_flaltland_rail_agent_status.py @@ -22,6 +22,7 @@ def test_initial_status(): schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + remove_agents_at_target=False ) env.reset() set_penalties_for_replay(env) diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py index 8f91088a205c92ec1d79fc9b3909a4c80ca72db5..f425636467ec7cefa0169db006122999b862308a 100644 --- a/tests/test_flatland_envs_observations.py +++ b/tests/test_flatland_envs_observations.py @@ -182,6 +182,7 @@ def test_reward_function_waiting(rendering=False): schedule_generator=random_schedule_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + remove_agents_at_target=False ) obs_builder: TreeObsForRailEnv = env.obs_builder # initialize agents_static diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py index a07c1a59816ee1bb40ef747b4748f16ee472b684..6222cc2776d0e3741e894c61a95026bc6233e2a8 100644 --- a/tests/test_flatland_envs_sparse_rail_generator.py +++ b/tests/test_flatland_envs_sparse_rail_generator.py @@ -1343,7 +1343,8 @@ def test_rail_env_action_required_info(): ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=10, - obs_builder_object=GlobalObsForRailEnv()) + obs_builder_object=GlobalObsForRailEnv(), + remove_agents_at_target=False) env_always_action.reset() np.random.seed(0) random.seed(0) @@ -1358,7 +1359,8 @@ def test_rail_env_action_required_info(): ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=10, - obs_builder_object=GlobalObsForRailEnv()) + obs_builder_object=GlobalObsForRailEnv(), + remove_agents_at_target=False) env_only_if_action_required.reset() env_renderer = RenderTool(env_always_action, gl="PILSVG", )