From ec3d7d639a39b9bbe72ae0114e9b9af87c467deb Mon Sep 17 00:00:00 2001 From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch> Date: Tue, 22 Oct 2019 15:22:02 +0200 Subject: [PATCH] RailEnv default value in constructor remove_agents_at_target=True --- flatland/envs/rail_env.py | 2 +- tests/test_flaltland_rail_agent_status.py | 1 + tests/test_flatland_envs_observations.py | 1 + tests/test_flatland_envs_sparse_rail_generator.py | 6 ++++-- 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index e4b0abd4..1dc24aff 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -116,7 +116,7 @@ class RailEnv(Environment): number_of_agents=1, obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(), stochastic_data=None, - remove_agents_at_target=False, + remove_agents_at_target=True, random_seed=1 ): """ diff --git a/tests/test_flaltland_rail_agent_status.py b/tests/test_flaltland_rail_agent_status.py index 987be396..e70012f8 100644 --- a/tests/test_flaltland_rail_agent_status.py +++ b/tests/test_flaltland_rail_agent_status.py @@ -22,6 +22,7 @@ def test_initial_status(): schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + remove_agents_at_target=False ) env.reset() set_penalties_for_replay(env) diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py index 8f91088a..f4256364 100644 --- a/tests/test_flatland_envs_observations.py +++ b/tests/test_flatland_envs_observations.py @@ -182,6 +182,7 @@ def test_reward_function_waiting(rendering=False): schedule_generator=random_schedule_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + remove_agents_at_target=False ) obs_builder: TreeObsForRailEnv = env.obs_builder # initialize agents_static diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py index a07c1a59..6222cc27 100644 --- a/tests/test_flatland_envs_sparse_rail_generator.py +++ b/tests/test_flatland_envs_sparse_rail_generator.py @@ -1343,7 +1343,8 @@ def test_rail_env_action_required_info(): ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=10, - obs_builder_object=GlobalObsForRailEnv()) + obs_builder_object=GlobalObsForRailEnv(), + remove_agents_at_target=False) env_always_action.reset() np.random.seed(0) random.seed(0) @@ -1358,7 +1359,8 @@ def test_rail_env_action_required_info(): ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=10, - obs_builder_object=GlobalObsForRailEnv()) + obs_builder_object=GlobalObsForRailEnv(), + remove_agents_at_target=False) env_only_if_action_required.reset() env_renderer = RenderTool(env_always_action, gl="PILSVG", ) -- GitLab