diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 1a064c992803a6e0a3fad509b0632187624d8f83..437978833627874f79e2f42107c4f91c065ff413 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -440,7 +440,7 @@ class RailEnv(Environment): action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction) # Check transitions, bounts for executing the action in the given position and directon - if not check_valid_action(action, self.rail, current_position, current_direction): + if action.is_moving_action() and not check_valid_action(action, self.rail, current_position, current_direction): action = RailEnvActions.STOP_MOVING return action diff --git a/tests/test_flatland_states_edge_cases.py b/tests/test_flatland_states_edge_cases.py new file mode 100644 index 0000000000000000000000000000000000000000..c0d5383880dc1df3c2e1e02eca6c535aa3de0c00 --- /dev/null +++ b/tests/test_flatland_states_edge_cases.py @@ -0,0 +1,83 @@ +from flatland.core.env_observation_builder import ObservationBuilder +from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.core.grid.grid4_utils import get_new_position +from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters +from flatland.envs.rail_env import RailEnv, RailEnvActions +from flatland.envs.rail_generators import rail_from_grid_transition_map +from flatland.envs.line_generators import sparse_line_generator +from flatland.utils.simple_rail import make_simple_rail +from flatland.envs.step_utils.states import TrainState + +def test_return_to_ready_to_depart(): + """ + When going from ready to depart to malfunction off map, if do nothing is provided, should return to ready to depart + """ + stochastic_data = MalfunctionParameters(malfunction_rate=0, # Rate of malfunction occurence + min_duration=0, # Minimal duration of malfunction + max_duration=0 # Max duration of malfunction + ) + + rail, _, optionals = make_simple_rail() + + env = RailEnv(width=25, + height=30, + rail_generator=rail_from_grid_transition_map(rail, optionals), + line_generator=sparse_line_generator(seed=10), + number_of_agents=1, + malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), + ) + + env.reset(False, False, random_seed=10) + env._max_episode_steps = 100 + + for _ in range(3): + env.step({0: RailEnvActions.DO_NOTHING}) + + env.agents[0].malfunction_handler._set_malfunction_down_counter(2) + env.step({0: RailEnvActions.DO_NOTHING}) + + assert env.agents[0].state == TrainState.MALFUNCTION_OFF_MAP + + for _ in range(2): + env.step({0: RailEnvActions.DO_NOTHING}) + + + assert env.agents[0].state == TrainState.READY_TO_DEPART + +def test_ready_to_depart_to_stopped(): + """ + When going from ready to depart to malfunction off map, if stopped is provided, should go to stopped + """ + stochastic_data = MalfunctionParameters(malfunction_rate=0, # Rate of malfunction occurence + min_duration=0, # Minimal duration of malfunction + max_duration=0 # Max duration of malfunction + ) + + rail, _, optionals = make_simple_rail() + + env = RailEnv(width=25, + height=30, + rail_generator=rail_from_grid_transition_map(rail, optionals), + line_generator=sparse_line_generator(seed=10), + number_of_agents=1, + malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), + ) + + env.reset(False, False, random_seed=10) + env._max_episode_steps = 100 + + for _ in range(3): + env.step({0: RailEnvActions.STOP_MOVING}) + + assert env.agents[0].state == TrainState.READY_TO_DEPART + + env.agents[0].malfunction_handler._set_malfunction_down_counter(2) + env.step({0: RailEnvActions.STOP_MOVING}) + + assert env.agents[0].state == TrainState.MALFUNCTION_OFF_MAP + + for _ in range(2): + env.step({0: RailEnvActions.STOP_MOVING}) + + + assert env.agents[0].state == TrainState.STOPPED \ No newline at end of file