From 94feb03955515de836cbb75e90b690ac4c2ca7bd Mon Sep 17 00:00:00 2001
From: Dipam Chakraborty <dipam@aicrowd.com>
Date: Tue, 14 Sep 2021 17:31:26 +0530
Subject: [PATCH] position update allowed on cell exit and stopped state

---
 flatland/envs/rail_env.py                     | 25 +++++++++++++------
 .../envs/step_utils/action_preprocessing.py   | 10 ++++----
 .../envs/step_utils/malfunction_handler.py    | 12 ++++++---
 flatland/envs/step_utils/state_machine.py     | 17 ++++++++++---
 4 files changed, 44 insertions(+), 20 deletions(-)

diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index e7f736f5..ce5a09b3 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -431,7 +431,7 @@ class RailEnv(Environment):
             * Block all actions when in waiting state
             * Check MOVE_LEFT/MOVE_RIGHT actions on current position else try MOVE_FORWARD
         """
-        action = action_preprocessing.preprocess_raw_action(action, agent.state)
+        action = action_preprocessing.preprocess_raw_action(action, agent.state, agent.action_saver.saved_action)
         action = action_preprocessing.preprocess_action_when_waiting(action, agent.state)
 
         # Try moving actions on current position
@@ -440,7 +440,6 @@ class RailEnv(Environment):
             current_position, current_direction = agent.initial_position, agent.initial_direction
         
         action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction)
-
         return action
     
     def clear_rewards_dict(self):
@@ -513,6 +512,9 @@ class RailEnv(Environment):
             # Save moving actions in not already saved
             agent.action_saver.save_action_if_allowed(preprocessed_action, agent.state)
 
+            # Train's next position can change if current stopped in a fractional speed or train is at cell's exit
+            position_update_allowed = (agent.speed_counter.is_cell_exit or agent.state == TrainState.STOPPED)
+
             # Calculate new position
             # Add agent to the map if not on it yet
             if agent.position is None and agent.action_saver.is_action_saved:
@@ -520,7 +522,7 @@ class RailEnv(Environment):
                 new_direction = agent.initial_direction
                 
             # If movement is allowed apply saved action independent of other agents
-            elif agent.action_saver.is_action_saved:
+            elif agent.action_saver.is_action_saved and position_update_allowed:
                 saved_action = agent.action_saver.saved_action
                 # Apply action independent of other agents and get temporary new position and direction
                 new_position, new_direction  = self.apply_action_independent(saved_action, 
@@ -557,7 +559,7 @@ class RailEnv(Environment):
                (agent.speed_counter.is_cell_exit or agent.position is None):
                 agent.position = agent_transition_data.position
                 agent.direction = agent_transition_data.direction
-                
+
             preprocessed_action = agent_transition_data.preprocessed_action
 
             ## Update states
@@ -565,9 +567,8 @@ class RailEnv(Environment):
             agent.state_machine.set_transition_signals(state_transition_signals)
             agent.state_machine.step()
 
-            if agent.state.is_on_map_state() and agent.position is None:
-                raise ValueError("Agent ID {} Agent State {} not matching with Agent Position {} ".format(
-                                    agent.handle, str(agent.state), str(agent.position) ))
+            # Off map or on map state and position should match
+            state_position_sync_check(agent.state, agent.position, agent.handle)
 
             # Handle done state actions, optionally remove agents
             self.handle_done_state(agent)
@@ -583,7 +584,7 @@ class RailEnv(Environment):
             agent.malfunction_handler.update_counter()
 
             # Clear old action when starting in new cell
-            if agent.speed_counter.is_cell_entry:
+            if agent.speed_counter.is_cell_entry and agent.position is not None:
                 agent.action_saver.clear_saved_action()
         
         # Check if episode has ended and update rewards and dones
@@ -687,3 +688,11 @@ def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
         return False
     else:
         return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]
+
+def state_position_sync_check(state, position, i_agent):
+    if state.is_on_map_state() and position is None:
+        raise ValueError("Agent ID {} Agent State {} is on map Agent Position {} if off map ".format(
+                        i_agent, str(state), str(position) ))
+    elif state.is_off_map_state() and position is not None:
+        raise ValueError("Agent ID {} Agent State {} is off map Agent Position {} if on map ".format(
+                        i_agent, str(state), str(position) ))
diff --git a/flatland/envs/step_utils/action_preprocessing.py b/flatland/envs/step_utils/action_preprocessing.py
index a397054c..47f06e2c 100644
--- a/flatland/envs/step_utils/action_preprocessing.py
+++ b/flatland/envs/step_utils/action_preprocessing.py
@@ -11,9 +11,11 @@ def process_illegal_action(action: RailEnvActions):
 		return RailEnvActions(action)
 
 
-def process_do_nothing(state: TrainState):
+def process_do_nothing(state: TrainState, saved_action: RailEnvActions):
     if state == TrainState.MOVING:
         action = RailEnvActions.MOVE_FORWARD
+    elif saved_action:
+        action = saved_action
     else:
         action = RailEnvActions.STOP_MOVING
     return action
@@ -34,7 +36,7 @@ def preprocess_action_when_waiting(action, state):
     return action
 
 
-def preprocess_raw_action(action, state):
+def preprocess_raw_action(action, state, saved_action):
     """
     Preprocesses actions to handle different situations of usage of action based on context
         - DO_NOTHING is converted to FORWARD if train is moving
@@ -43,7 +45,7 @@ def preprocess_raw_action(action, state):
     action = process_illegal_action(action)
 
     if action == RailEnvActions.DO_NOTHING:
-        action = process_do_nothing(state)
+        action = process_do_nothing(state, saved_action)
 
     return action
 
@@ -55,6 +57,4 @@ def preprocess_moving_action(action, rail, position, direction):
     if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]:
         action = process_left_right(action, rail, position, direction)
 
-    if not check_valid_action(action, rail, position, direction):
-        action = RailEnvActions.STOP_MOVING
     return action
\ No newline at end of file
diff --git a/flatland/envs/step_utils/malfunction_handler.py b/flatland/envs/step_utils/malfunction_handler.py
index 2ba72643..bf1f188f 100644
--- a/flatland/envs/step_utils/malfunction_handler.py
+++ b/flatland/envs/step_utils/malfunction_handler.py
@@ -10,6 +10,7 @@ def get_number_of_steps_to_break(malfunction_generator, np_random):
 class MalfunctionHandler:
     def __init__(self):
         self._malfunction_down_counter = 0
+        self.num_malfunctions = 0
     
     @property
     def in_malfunction(self):
@@ -33,6 +34,7 @@ class MalfunctionHandler:
         # Only set new malfunction value if old malfunction is completed
         if self._malfunction_down_counter == 0:
             self._malfunction_down_counter = val
+            self.num_malfunctions += 1
 
     def generate_malfunction(self, malfunction_generator, np_random):
         num_broken_steps = get_number_of_steps_to_break(malfunction_generator, np_random)
@@ -44,16 +46,20 @@ class MalfunctionHandler:
 
     def __repr__(self):
         return f"malfunction_down_counter: {self._malfunction_down_counter} \
-                in_malfunction: {self.in_malfunction}"
+                in_malfunction: {self.in_malfunction} \
+                num_malfunctions: {self.num_malfunctions}"
 
     def to_dict(self):
-        return {"malfunction_down_counter": self._malfunction_down_counter}
+        return {"malfunction_down_counter": self._malfunction_down_counter,
+                "num_malfunctions": self.num_malfunctions}
     
     def from_dict(self, load_dict):
         self._malfunction_down_counter = load_dict['malfunction_down_counter']
+        self.num_malfunctions = load_dict['num_malfunctions']
 
     def __eq__(self, other):
-        return self._malfunction_down_counter == other._malfunction_down_counter
+        return self._malfunction_down_counter == other._malfunction_down_counter and \
+               self.num_malfunctions == other.num_malfunctions
 
     
 
diff --git a/flatland/envs/step_utils/state_machine.py b/flatland/envs/step_utils/state_machine.py
index 78c9883f..58b028b6 100644
--- a/flatland/envs/step_utils/state_machine.py
+++ b/flatland/envs/step_utils/state_machine.py
@@ -31,12 +31,21 @@ class TrainStateMachine:
     
     def _handle_malfunction_off_map(self):
         if self.st_signals.malfunction_counter_complete:
+            
             if self.st_signals.earliest_departure_reached:
-                self.next_state = TrainState.READY_TO_DEPART
+
+                if self.st_signals.valid_movement_action_given:
+                    self.next_state = TrainState.MOVING
+                elif self.st_signals.stop_action_given:
+                    self.next_state = TrainState.STOPPED
+                else:
+                    self.next_state = TrainState.READY_TO_DEPART
+                    
             else:
-                self.next_state = TrainState.STOPPED
+                self.next_state = TrainState.WAITING
+
         else:
-            self.next_state = TrainState.WAITING
+            self.next_state = TrainState.MALFUNCTION_OFF_MAP
     
     def _handle_moving(self):
         if self.st_signals.in_malfunction:
@@ -61,7 +70,7 @@ class TrainStateMachine:
            self.st_signals.valid_movement_action_given:
             self.next_state = TrainState.MOVING
         elif self.st_signals.malfunction_counter_complete and \
-             (self.st_signals.stop_action_given or self.st_signals.movement_conflict):
+                (self.st_signals.stop_action_given or self.st_signals.movement_conflict):
              self.next_state = TrainState.STOPPED
         else:
             self.next_state = TrainState.MALFUNCTION
-- 
GitLab