From 5f80e96d7f79a96c1964a5005076c6e2c1e0533d Mon Sep 17 00:00:00 2001 From: Dipam Chakraborty <dipam@aicrowd.com> Date: Wed, 27 Oct 2021 15:26:55 +0530 Subject: [PATCH] add new tests for malfunction off map --- flatland/envs/rail_env.py | 2 +- tests/test_flatland_states_edge_cases.py | 83 ++++++++++++++++++++++++ 2 files changed, 84 insertions(+), 1 deletion(-) create mode 100644 tests/test_flatland_states_edge_cases.py diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 1a064c99..43797883 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -440,7 +440,7 @@ class RailEnv(Environment): action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction) # Check transitions, bounts for executing the action in the given position and directon - if not check_valid_action(action, self.rail, current_position, current_direction): + if action.is_moving_action() and not check_valid_action(action, self.rail, current_position, current_direction): action = RailEnvActions.STOP_MOVING return action diff --git a/tests/test_flatland_states_edge_cases.py b/tests/test_flatland_states_edge_cases.py new file mode 100644 index 00000000..c0d53838 --- /dev/null +++ b/tests/test_flatland_states_edge_cases.py @@ -0,0 +1,83 @@ +from flatland.core.env_observation_builder import ObservationBuilder +from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.core.grid.grid4_utils import get_new_position +from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters +from flatland.envs.rail_env import RailEnv, RailEnvActions +from flatland.envs.rail_generators import rail_from_grid_transition_map +from flatland.envs.line_generators import sparse_line_generator +from flatland.utils.simple_rail import make_simple_rail +from flatland.envs.step_utils.states import TrainState + +def test_return_to_ready_to_depart(): + """ + When going from ready to depart to malfunction off map, if do nothing is provided, should return to ready to depart + """ + stochastic_data = MalfunctionParameters(malfunction_rate=0, # Rate of malfunction occurence + min_duration=0, # Minimal duration of malfunction + max_duration=0 # Max duration of malfunction + ) + + rail, _, optionals = make_simple_rail() + + env = RailEnv(width=25, + height=30, + rail_generator=rail_from_grid_transition_map(rail, optionals), + line_generator=sparse_line_generator(seed=10), + number_of_agents=1, + malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), + ) + + env.reset(False, False, random_seed=10) + env._max_episode_steps = 100 + + for _ in range(3): + env.step({0: RailEnvActions.DO_NOTHING}) + + env.agents[0].malfunction_handler._set_malfunction_down_counter(2) + env.step({0: RailEnvActions.DO_NOTHING}) + + assert env.agents[0].state == TrainState.MALFUNCTION_OFF_MAP + + for _ in range(2): + env.step({0: RailEnvActions.DO_NOTHING}) + + + assert env.agents[0].state == TrainState.READY_TO_DEPART + +def test_ready_to_depart_to_stopped(): + """ + When going from ready to depart to malfunction off map, if stopped is provided, should go to stopped + """ + stochastic_data = MalfunctionParameters(malfunction_rate=0, # Rate of malfunction occurence + min_duration=0, # Minimal duration of malfunction + max_duration=0 # Max duration of malfunction + ) + + rail, _, optionals = make_simple_rail() + + env = RailEnv(width=25, + height=30, + rail_generator=rail_from_grid_transition_map(rail, optionals), + line_generator=sparse_line_generator(seed=10), + number_of_agents=1, + malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), + ) + + env.reset(False, False, random_seed=10) + env._max_episode_steps = 100 + + for _ in range(3): + env.step({0: RailEnvActions.STOP_MOVING}) + + assert env.agents[0].state == TrainState.READY_TO_DEPART + + env.agents[0].malfunction_handler._set_malfunction_down_counter(2) + env.step({0: RailEnvActions.STOP_MOVING}) + + assert env.agents[0].state == TrainState.MALFUNCTION_OFF_MAP + + for _ in range(2): + env.step({0: RailEnvActions.STOP_MOVING}) + + + assert env.agents[0].state == TrainState.STOPPED \ No newline at end of file -- GitLab