From 722b0928d9d1f9d29cf36e2e004e74e3b6ccacb5 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Thu, 19 Sep 2019 16:43:29 +0200 Subject: [PATCH] #168 #163 multispeed and penalty testing --- flatland/envs/rail_env.py | 2 -- tests/test_flatland_malfunction.py | 8 ++++---- tests/test_multi_speed.py | 19 ++++--------------- tests/test_utils.py | 8 ++++++++ 4 files changed, 16 insertions(+), 21 deletions(-) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 0467fcd6..435571c8 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -360,8 +360,6 @@ class RailEnv(Environment): # Perform step on all agents for i_agent in range(self.get_num_agents()): - if self._elapsed_steps - 1 == 3 and i_agent == 0: - a = 5 self._step_agent(i_agent, action_dict_.get(i_agent)) # Check for end of episode + set global reward to all rewards! diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 473629e2..801c5062 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -9,7 +9,7 @@ from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator -from test_utils import Replay, ReplayConfig, run_replay_config +from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay class SingleAgentNavigationObs(TreeObsForRailEnv): @@ -187,7 +187,7 @@ def test_initial_malfunction(): number_of_agents=1, stochastic_data=stochastic_data, # Malfunction data generator ) - + set_penalties_for_replay(env) replay_config = ReplayConfig( replay=[ Replay( @@ -270,7 +270,7 @@ def test_initial_malfunction_stop_moving(): number_of_agents=1, stochastic_data=stochastic_data, # Malfunction data generator ) - + 
set_penalties_for_replay(env) replay_config = ReplayConfig( replay=[ Replay( @@ -363,7 +363,7 @@ def test_initial_malfunction_do_nothing(): number_of_agents=1, stochastic_data=stochastic_data, # Malfunction data generator ) - + set_penalties_for_replay(env) replay_config = ReplayConfig( replay=[Replay( position=(28, 5), diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py index f2fd3613..57c0430d 100644 --- a/tests/test_multi_speed.py +++ b/tests/test_multi_speed.py @@ -6,9 +6,8 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import complex_rail_generator, rail_from_grid_transition_map from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator -from flatland.utils.rendertools import RenderTool from flatland.utils.simple_rail import make_simple_rail -from test_utils import ReplayConfig, Replay, run_replay_config +from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay np.random.seed(1) @@ -106,9 +105,7 @@ def test_multispeed_actions_no_malfunction_no_blocking(rendering=True): obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) - # env.start_penalty = 13 - # env.stop_penalty = 19 - + set_penalties_for_replay(env) test_config = ReplayConfig( replay=[ Replay( @@ -208,7 +205,7 @@ def test_multispeed_actions_no_malfunction_blocking(rendering=True): number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) - + set_penalties_for_replay(env) test_configs = [ ReplayConfig( replay=[ @@ -391,15 +388,7 @@ def test_multispeed_actions_malfunction_no_blocking(rendering=True): obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) - # initialize agents_static - env.reset() - - # reset to set agents from agents_static - 
env.reset(False, False) - - if rendering: - renderer = RenderTool(env, gl="PILSVG") - + set_penalties_for_replay(env) test_config = ReplayConfig( replay=[ Replay( diff --git a/tests/test_utils.py b/tests/test_utils.py index 88d669fa..5821cd05 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -26,6 +26,13 @@ class ReplayConfig(object): speed = attrib(type=float) +# ensure that env is working correctly with non-zero start/stop/invalid-action penalties (NOTE(review): the function below sets step_penalty, not start_penalty; confirm which is intended) +def set_penalties_for_replay(env: RailEnv): + env.step_penalty = 13 + env.stop_penalty = 19 + env.invalid_action_penalty = 29 + + def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: bool = False): """ Runs the replay configs and checks assertions. @@ -54,6 +61,7 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: info_dict = { 'action_required': [True for _ in test_configs] } + for step in range(len(test_configs[0].replay)): if step == 0: for a, test_config in enumerate(test_configs): -- GitLab