diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 0467fcd6b0086981c3dcd30d4bb25072361e96cb..435571c8216026b02e7ec4633ace2dcd85c15e4e 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -360,8 +360,6 @@ class RailEnv(Environment): # Perform step on all agents for i_agent in range(self.get_num_agents()): - if self._elapsed_steps - 1 == 3 and i_agent == 0: - a = 5 self._step_agent(i_agent, action_dict_.get(i_agent)) # Check for end of episode + set global reward to all rewards! diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 473629e2115299dadbf2e8cea9d52c765ad1c32f..801c50629b6b13b7d3cc2c0e4b9739ce56e6cedf 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -9,7 +9,7 @@ from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator -from test_utils import Replay, ReplayConfig, run_replay_config +from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay class SingleAgentNavigationObs(TreeObsForRailEnv): @@ -187,7 +187,7 @@ def test_initial_malfunction(): number_of_agents=1, stochastic_data=stochastic_data, # Malfunction data generator ) - + set_penalties_for_replay(env) replay_config = ReplayConfig( replay=[ Replay( @@ -270,7 +270,7 @@ def test_initial_malfunction_stop_moving(): number_of_agents=1, stochastic_data=stochastic_data, # Malfunction data generator ) - + set_penalties_for_replay(env) replay_config = ReplayConfig( replay=[ Replay( @@ -363,7 +363,7 @@ def test_initial_malfunction_do_nothing(): number_of_agents=1, stochastic_data=stochastic_data, # Malfunction data generator ) - + set_penalties_for_replay(env) replay_config = ReplayConfig( replay=[Replay( position=(28, 
5), diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py index f2fd3613d922ecd5dc23bd6a0649936f7838fbd6..57c0430d630615d02785545d0b54f3e9980c54d4 100644 --- a/tests/test_multi_speed.py +++ b/tests/test_multi_speed.py @@ -6,9 +6,8 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import complex_rail_generator, rail_from_grid_transition_map from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator -from flatland.utils.rendertools import RenderTool from flatland.utils.simple_rail import make_simple_rail -from test_utils import ReplayConfig, Replay, run_replay_config +from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay np.random.seed(1) @@ -106,9 +105,7 @@ def test_multispeed_actions_no_malfunction_no_blocking(rendering=True): obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) - # env.start_penalty = 13 - # env.stop_penalty = 19 - + set_penalties_for_replay(env) test_config = ReplayConfig( replay=[ Replay( @@ -208,7 +205,7 @@ def test_multispeed_actions_no_malfunction_blocking(rendering=True): number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) - + set_penalties_for_replay(env) test_configs = [ ReplayConfig( replay=[ @@ -391,15 +388,7 @@ def test_multispeed_actions_malfunction_no_blocking(rendering=True): obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) - # initialize agents_static - env.reset() - - # reset to set agents from agents_static - env.reset(False, False) - - if rendering: - renderer = RenderTool(env, gl="PILSVG") - + set_penalties_for_replay(env) test_config = ReplayConfig( replay=[ Replay( diff --git a/tests/test_utils.py b/tests/test_utils.py index 
88d669fa220e8d65f8e667532c39e0f0e6ad7a69..5821cd05604aa79b0b55a73136161b49e847f01e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -26,6 +26,13 @@ class ReplayConfig(object): speed = attrib(type=float) +# ensure that env is working correctly with step/stop/invalid-action penalties different from 0 +def set_penalties_for_replay(env: RailEnv): + env.step_penalty = 13 + env.stop_penalty = 19 + env.invalid_action_penalty = 29 + + def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: bool = False): """ Runs the replay configs and checks assertions. @@ -54,6 +61,7 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: info_dict = { 'action_required': [True for _ in test_configs] } + for step in range(len(test_configs[0].replay)): if step == 0: for a, test_config in enumerate(test_configs):