From 5f80e96d7f79a96c1964a5005076c6e2c1e0533d Mon Sep 17 00:00:00 2001
From: Dipam Chakraborty <dipam@aicrowd.com>
Date: Wed, 27 Oct 2021 15:26:55 +0530
Subject: [PATCH] add new tests for malfunction off map

---
 flatland/envs/rail_env.py                |  2 +-
 tests/test_flatland_states_edge_cases.py | 83 ++++++++++++++++++++++++
 2 files changed, 84 insertions(+), 1 deletion(-)
 create mode 100644 tests/test_flatland_states_edge_cases.py

diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 1a064c99..43797883 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -440,7 +440,7 @@ class RailEnv(Environment):
         action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction)
 
         # Check transitions, bounts for executing the action in the given position and directon
-        if not check_valid_action(action, self.rail, current_position, current_direction):
+        if action.is_moving_action() and not check_valid_action(action, self.rail, current_position, current_direction):
             action = RailEnvActions.STOP_MOVING
 
         return action
diff --git a/tests/test_flatland_states_edge_cases.py b/tests/test_flatland_states_edge_cases.py
new file mode 100644
index 00000000..c0d53838
--- /dev/null
+++ b/tests/test_flatland_states_edge_cases.py
@@ -0,0 +1,83 @@
+from flatland.core.env_observation_builder import ObservationBuilder
+from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+from flatland.envs.rail_generators import rail_from_grid_transition_map
+from flatland.envs.line_generators import sparse_line_generator
+from flatland.utils.simple_rail import make_simple_rail
+from flatland.envs.step_utils.states import TrainState
+
+def test_return_to_ready_to_depart():
+    """
+    When going from ready to depart to malfunction off map, if do nothing is provided, should return to ready to depart
+    """
+    stochastic_data = MalfunctionParameters(malfunction_rate=0,  # Rate of malfunction occurence
+                                        min_duration=0,  # Minimal duration of malfunction
+                                        max_duration=0  # Max duration of malfunction
+                                        )
+
+    rail, _, optionals = make_simple_rail()
+
+    env = RailEnv(width=25,
+                  height=30,
+                  rail_generator=rail_from_grid_transition_map(rail, optionals),
+                  line_generator=sparse_line_generator(seed=10),
+                  number_of_agents=1,
+                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
+                  )
+    
+    env.reset(False, False, random_seed=10)
+    env._max_episode_steps = 100
+
+    for _ in range(3):
+        env.step({0: RailEnvActions.DO_NOTHING})
+
+    env.agents[0].malfunction_handler._set_malfunction_down_counter(2)
+    env.step({0: RailEnvActions.DO_NOTHING})
+
+    assert env.agents[0].state == TrainState.MALFUNCTION_OFF_MAP
+
+    for _ in range(2):
+        env.step({0: RailEnvActions.DO_NOTHING})
+
+    
+    assert env.agents[0].state == TrainState.READY_TO_DEPART
+
+def test_ready_to_depart_to_stopped():
+    """
+    When going from ready to depart to malfunction off map, if stopped is provided, should go to stopped
+    """
+    stochastic_data = MalfunctionParameters(malfunction_rate=0,  # Rate of malfunction occurence
+                                        min_duration=0,  # Minimal duration of malfunction
+                                        max_duration=0  # Max duration of malfunction
+                                        )
+
+    rail, _, optionals = make_simple_rail()
+
+    env = RailEnv(width=25,
+                  height=30,
+                  rail_generator=rail_from_grid_transition_map(rail, optionals),
+                  line_generator=sparse_line_generator(seed=10),
+                  number_of_agents=1,
+                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
+                  )
+    
+    env.reset(False, False, random_seed=10)
+    env._max_episode_steps = 100
+
+    for _ in range(3):
+        env.step({0: RailEnvActions.STOP_MOVING})
+
+    assert env.agents[0].state == TrainState.READY_TO_DEPART
+
+    env.agents[0].malfunction_handler._set_malfunction_down_counter(2)
+    env.step({0: RailEnvActions.STOP_MOVING})
+
+    assert env.agents[0].state == TrainState.MALFUNCTION_OFF_MAP
+
+    for _ in range(2):
+        env.step({0: RailEnvActions.STOP_MOVING})
+
+    
+    assert env.agents[0].state == TrainState.STOPPED
\ No newline at end of file
-- 
GitLab