From 53d7dcc13422fe6e19a9704568d0cec6cf08d246 Mon Sep 17 00:00:00 2001
From: Dipam Chakraborty <dipam@aicrowd.com>
Date: Thu, 9 Sep 2021 20:07:44 +0530
Subject: [PATCH] minor refactors

---
 flatland/envs/rail_env.py                     | 51 +++++++++++--------
 .../envs/step_utils/action_preprocessing.py   |  2 +-
 flatland/envs/step_utils/state_machine.py     |  1 -
 flatland/envs/step_utils/transition_utils.py  |  4 +-
 4 files changed, 33 insertions(+), 25 deletions(-)

diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 60860d65..4181482b 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -443,6 +443,25 @@ class RailEnv(Environment):
         
         return reward
 
+    def preprocess_action(self, action, agent):
+        """
+        Preprocess the provided action
+            * Change the action to DO_NOTHING if it is illegal
+            * Block all actions while the agent is in a waiting state
+            * Check MOVE_LEFT/MOVE_RIGHT on the current position, else try MOVE_FORWARD
+        """
+        action = action_preprocessing.preprocess_raw_action(action, agent.state)
+        action = action_preprocessing.preprocess_action_when_waiting(action, agent.state)
+
+        # Check moving actions against the agent's position (initial position if not on the map yet)
+        current_position, current_direction = agent.position, agent.direction
+        if current_position is None: # Agent not added on map yet
+            current_position, current_direction = agent.initial_position, agent.initial_direction
+        
+        action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction)
+
+        return action
+
     def step(self, action_dict_: Dict[int, RailEnvActions]):
         """
         Updates rewards for the agents at a step.
@@ -475,59 +494,51 @@ class RailEnv(Environment):
             "speed": {},
             "status": {},
         }
-        have_all_agents_ended = True  # boolean flag to check if all agents are done
 
         self.motionCheck = ac.MotionCheck()  # reset the motion check
 
         temp_transition_data = {}
         
-        for i_agent, agent in enumerate(self.agents): # TODO: Important - Do not use i_agent like this, use agent.handle if needed
+        for agent in self.agents:
+            i_agent = agent.handle
             # Generate malfunction
             agent.malfunction_handler.generate_malfunction(self.malfunction_generator, self.np_random)
 
             # Get action for the agent
             action = action_dict_.get(i_agent, RailEnvActions.DO_NOTHING)
-            # TODO: Add the bottom stuff to separate function(s)
-
-            # Preprocess action
-            action = action_preprocessing.preprocess_raw_action(action, agent.state)
-            action = action_preprocessing.preprocess_action_when_waiting(action, agent.state)
 
-            # Try moving actions on current position
-            current_position, current_direction = agent.position, agent.direction
-            agent_not_on_map = current_position is None
-            if agent_not_on_map: # Agent not added on map yet
-                current_position, current_direction = agent.initial_position, agent.initial_direction
-            action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction)
+            preprocessed_action = self.preprocess_action(action, agent)
 
             # Save moving actions in not already saved
-            agent.action_saver.save_action_if_allowed(action, agent.state)
+            agent.action_saver.save_action_if_allowed(preprocessed_action, agent.state)
 
             # Calculate new position
             # Add agent to the map if not on it yet
-            if agent_not_on_map and agent.action_saver.is_action_saved:
+            if agent.position is None and agent.action_saver.is_action_saved:
                 new_position = agent.initial_position
                 new_direction = agent.initial_direction
-                preprocessed_action = action
                 
             # When cell exit occurs apply saved action independent of other agents
             elif agent.speed_counter.is_cell_exit and agent.action_saver.is_action_saved:
                 saved_action = agent.action_saver.saved_action
                 # Apply action independent of other agents and get temporary new position and direction
-                pd = self.apply_action_independent(saved_action, self.rail, agent.position, agent.direction)
-                new_position, new_direction = pd
+                new_position, new_direction  = self.apply_action_independent(saved_action, 
+                                                                             self.rail, 
+                                                                             agent.position, 
+                                                                             agent.direction)
                 preprocessed_action = saved_action
             else:
                 new_position, new_direction = agent.position, agent.direction
-                preprocessed_action = action
 
             temp_transition_data[i_agent] = AgentTransitionData(position=new_position,
                                                                 direction=new_direction,
                                                                 preprocessed_action=preprocessed_action)
+            
+            # Register the intended move so conflicts between agents targeting the same cell can be found
             self.motionCheck.addAgent(i_agent, agent.position, new_position)
 
         # Find conflicts
-        # TODO : Important - Modify conflicted positions and select one of them randomly to go to new position
+
         self.motionCheck.find_conflicts()
         
         for agent in self.agents:
diff --git a/flatland/envs/step_utils/action_preprocessing.py b/flatland/envs/step_utils/action_preprocessing.py
index 4da43c16..a397054c 100644
--- a/flatland/envs/step_utils/action_preprocessing.py
+++ b/flatland/envs/step_utils/action_preprocessing.py
@@ -55,6 +55,6 @@ def preprocess_moving_action(action, rail, position, direction):
     if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]:
         action = process_left_right(action, rail, position, direction)
 
-    if not check_valid_action(action, rail, position, direction): # TODO: Dipam - Not sure if this is needed
+    if not check_valid_action(action, rail, position, direction):
         action = RailEnvActions.STOP_MOVING
     return action
\ No newline at end of file
diff --git a/flatland/envs/step_utils/state_machine.py b/flatland/envs/step_utils/state_machine.py
index 6d0b9f40..47b553a8 100644
--- a/flatland/envs/step_utils/state_machine.py
+++ b/flatland/envs/step_utils/state_machine.py
@@ -1,4 +1,3 @@
-from attr import s
 from flatland.envs.step_utils.states import TrainState, StateTransitionSignals
 
 class TrainStateMachine:
diff --git a/flatland/envs/step_utils/transition_utils.py b/flatland/envs/step_utils/transition_utils.py
index 157db5ac..c84d6c59 100644
--- a/flatland/envs/step_utils/transition_utils.py
+++ b/flatland/envs/step_utils/transition_utils.py
@@ -71,9 +71,7 @@ def check_action_on_agent(action, rail, position, direction):
 
     # If transition validity hasn't been checked yet.
     if transition_valid is None:
-        transition_valid = rail.get_transition( # TODO: Dipam - Read this one
-            (*position, direction),
-            new_direction)
+        transition_valid = rail.get_transition( (*position, direction), new_direction)
 
     return new_cell_valid, new_direction, new_position, transition_valid
 
-- 
GitLab