diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index fffe7ff786a32a6796af9667f1dfb9a3eb92ce9c..632caeea7e416895d36ce845e19917c2cc94d76d 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -2,21 +2,19 @@ from flatland.envs.rail_trainrun_data_structures import Waypoint
 import numpy as np
 
 from enum import IntEnum
+from flatland.envs.step_utils.states import TrainState
 from itertools import starmap
 from typing import Tuple, Optional, NamedTuple, List
 
 from attr import attr, attrs, attrib, Factory
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.envs.timetable_utils import Line
 
-class RailAgentStatus(IntEnum):
-    WAITING = 0
-    READY_TO_DEPART = 1  # not in grid yet (position is None) -> prediction as if it were at initial position
-    ACTIVE = 2  # in grid (position is not None), not done -> prediction is remaining path
-    DONE = 3  # in grid (position is not None), but done -> prediction is stay at target forever
-    DONE_REMOVED = 4  # removed from grid (position is None) -> prediction is None
+from flatland.envs.schedule_utils import Schedule
 
+from flatland.envs.step_utils.action_saver import ActionSaver
+from flatland.envs.step_utils.speed_counter import SpeedCounter
+from flatland.envs.step_utils.state_machine import TrainStateMachine
+from flatland.envs.step_utils.malfunction_handler import MalfunctionHandler
 
 Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
                              ('initial_direction', Grid4TransitionsEnum),
@@ -28,11 +26,16 @@ Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
                              ('speed_data', dict),
                              ('malfunction_data', dict),
                              ('handle', int),
-                             ('status', RailAgentStatus),
                              ('position', Tuple[int, int]),
                              ('arrival_time', int),
                              ('old_direction', Grid4TransitionsEnum),
-                             ('old_position', Tuple[int, int])])
+                             ('old_position', Tuple[int, int]),
+                             ('speed_counter', SpeedCounter),
+                             ('action_saver', ActionSaver),
+                             ('state', TrainState),
+                             ('state_machine', TrainStateMachine),
+                             ('malfunction_handler', MalfunctionHandler),
+                             ])
 
 
 @attrs
@@ -65,7 +68,15 @@ class EnvAgent:
     handle = attrib(default=None)
     # INIT TILL HERE IN _from_line()
 
-    status = attrib(default=RailAgentStatus.WAITING, type=RailAgentStatus)
+    # Env step facelift
+    speed_counter = attrib(default=None, type=SpeedCounter)
+    action_saver = attrib(default=Factory(lambda: ActionSaver()), type=ActionSaver)
+    state_machine = attrib(default=Factory(lambda: TrainStateMachine(initial_state=TrainState.WAITING)),
+                           type=TrainStateMachine)
+    malfunction_handler = attrib(default=Factory(lambda: MalfunctionHandler()), type=MalfunctionHandler)
+
+    state = attrib(default=TrainState.WAITING, type=TrainState)
+
     position = attrib(default=None, type=Optional[Tuple[int, int]])
 
     # NEW : EnvAgent Reward Handling
@@ -75,6 +86,7 @@ class EnvAgent:
     old_direction = attrib(default=None)
     old_position = attrib(default=None)
 
     def reset(self):
         """
         Resets the agents to their initial values of the episode. Called after ScheduleTime generation.
@@ -82,14 +94,6 @@ class EnvAgent:
         self.position = None
         # TODO: set direction to None: https://gitlab.aicrowd.com/flatland/flatland/issues/280
         self.direction = self.initial_direction
-
-        if (self.earliest_departure == 0):
-            self.status = RailAgentStatus.READY_TO_DEPART
-        else:
-            self.status = RailAgentStatus.WAITING
-            
         self.arrival_time = None
 
         self.old_position = None
         self.old_direction = None
         self.moving = False
@@ -103,48 +107,42 @@ class EnvAgent:
         self.malfunction_data['nr_malfunctions'] = 0
         self.malfunction_data['moving_before_malfunction'] = False
 
-    # NEW : Callables
-    def get_shortest_path(self, distance_map) -> List[Waypoint]:
-        from flatland.envs.rail_env_shortest_paths import get_shortest_paths # Circular dep fix
-        return get_shortest_paths(distance_map=distance_map, agent_handle=self.handle)[self.handle]
-        
-    def get_travel_time_on_shortest_path(self, distance_map) -> int:
-        shortest_path = self.get_shortest_path(distance_map)
-        if shortest_path is not None:
-            distance = len(shortest_path)
-        else:
-            distance = 0
-        speed = self.speed_data['speed']
-        return int(np.ceil(distance / speed))
-
-    def get_time_remaining_until_latest_arrival(self, elapsed_steps: int) -> int:
-        return self.latest_arrival - elapsed_steps
-
-    def get_current_delay(self, elapsed_steps: int, distance_map) -> int:
-        '''
-        +ve if arrival time is projected before latest arrival
-        -ve if arrival time is projected after latest arrival
-        '''
-        return self.get_time_remaining_until_latest_arrival(elapsed_steps) - \
-               self.get_travel_time_on_shortest_path(distance_map)
+        self.action_saver.clear_saved_action()
+        self.speed_counter.reset_counter()
+        self.state_machine.reset()
 
     def to_agent(self) -> Agent:
-        return Agent(initial_position=self.initial_position, initial_direction=self.initial_direction, 
-                     direction=self.direction, target=self.target, moving=self.moving, earliest_departure=self.earliest_departure, 
-                     latest_arrival=self.latest_arrival, speed_data=self.speed_data, malfunction_data=self.malfunction_data, 
-                     handle=self.handle, status=self.status, position=self.position, arrival_time=self.arrival_time, 
-                     old_direction=self.old_direction, old_position=self.old_position)
+        return Agent(initial_position=self.initial_position, 
+                     initial_direction=self.initial_direction,
+                     direction=self.direction,
+                     target=self.target,
+                     moving=self.moving,
+                     earliest_departure=self.earliest_departure, 
+                     latest_arrival=self.latest_arrival, 
+                     speed_data=self.speed_data,
+                     malfunction_data=self.malfunction_data, 
+                     handle=self.handle, 
+                     state=self.state,
+                     position=self.position,
+                     arrival_time=self.arrival_time,
+                     old_direction=self.old_direction, 
+                     old_position=self.old_position,
+                     speed_counter=self.speed_counter,
+                     action_saver=self.action_saver,
+                     state_machine=self.state_machine,
+                     malfunction_handler=self.malfunction_handler)
 
     @classmethod
     def from_line(cls, line: Line):
         """ Create a list of EnvAgent from lists of positions, directions and targets
         """
         speed_datas = []
-
-        for i in range(len(line.agent_positions)):
+        speed_counters = []
+        for i in range(len(line.agent_positions)):
+            speed = line.agent_speeds[i] if line.agent_speeds is not None else 1.0
             speed_datas.append({'position_fraction': 0.0,
-                                'speed': line.agent_speeds[i] if line.agent_speeds is not None else 1.0,
+                                'speed': speed,
                                 'transition_action_on_cellexit': 0})
+            speed_counters.append(SpeedCounter(speed=speed))
 
         malfunction_datas = []
         for i in range(len(line.agent_positions)):
@@ -153,17 +151,19 @@ class EnvAgent:
                                           i] if line.agent_malfunction_rates is not None else 0.,
                                       'next_malfunction': 0,
                                       'nr_malfunctions': 0})
-
-        return list(starmap(EnvAgent, zip(line.agent_positions,
-                                          line.agent_directions,
-                                          line.agent_directions,
-                                          line.agent_targets, 
-                                          [False] * len(line.agent_positions), 
-                                          [None] * len(line.agent_positions), # earliest_departure
-                                          [None] * len(line.agent_positions), # latest_arrival
+
+        return list(starmap(EnvAgent, zip(line.agent_positions,  # TODO : Dipam - Really want to change this way of loading agents
+                                          line.agent_directions,
+                                          line.agent_directions,
+                                          line.agent_targets,
+                                          [False] * len(line.agent_positions),
+                                          [None] * len(line.agent_positions), # earliest_departure
+                                          [None] * len(line.agent_positions), # latest_arrival
                                           speed_datas,
                                           malfunction_datas,
-                                          range(len(line.agent_positions)))))
+                                          range(len(line.agent_positions)),
+                                          speed_counters,
+                                          )))
 
     @classmethod
     def load_legacy_static_agent(cls, static_agents_data: Tuple):
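
For orientation, a minimal usage sketch of the helper objects that replace the removed RailAgentStatus field, using only the APIs introduced in this patch (the speed value is illustrative):

    from flatland.envs.step_utils.action_saver import ActionSaver
    from flatland.envs.step_utils.speed_counter import SpeedCounter
    from flatland.envs.step_utils.state_machine import TrainStateMachine
    from flatland.envs.step_utils.states import TrainState

    # from_line() now attaches one of each of these to every EnvAgent
    speed_counter = SpeedCounter(speed=0.5)                  # fractional speed -> one cell every 2 steps
    speed_counter.reset_counter()
    action_saver = ActionSaver()                             # remembers the moving action chosen for the current cell
    state_machine = TrainStateMachine(initial_state=TrainState.WAITING)

    # Old checks such as `agent.status == RailAgentStatus.WAITING`
    # become state checks: `agent.state == TrainState.WAITING`
    print(state_machine.state, speed_counter.is_cell_entry, action_saver.is_action_saved)
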
diff --git a/flatland/envs/malfunction_generators.py b/flatland/envs/malfunction_generators.py
index 0d27913d6f27fb5df301960655d90baa42ef1ac0..2fecddf1a0abb954637992c371c1fc1053417a78 100644
--- a/flatland/envs/malfunction_generators.py
+++ b/flatland/envs/malfunction_generators.py
@@ -18,7 +18,7 @@ MalfunctionProcessData = NamedTuple('MalfunctionProcessData',
 Malfunction = NamedTuple('Malfunction', [('num_broken_steps', int)])
 
 # Why is the return value Optional?  We always return a Malfunction.
-MalfunctionGenerator = Callable[[EnvAgent, RandomState, bool], Optional[Malfunction]]
+MalfunctionGenerator = Callable[[RandomState], Malfunction]
 
 def _malfunction_prob(rate: float) -> float:
     """
@@ -42,21 +42,14 @@ class ParamMalfunctionGen(object):
         #self.max_number_of_steps_broken = parameters.max_duration
         self.MFP = parameters
 
-    def generate(self,
-        agent: EnvAgent = None,
-        np_random: RandomState = None,
-        reset=False) -> Optional[Malfunction]:
-      
-        # Dummy reset function as we don't implement specific seeding here
-        if reset:
-            return Malfunction(0)
+    def generate(self, np_random: RandomState) -> Malfunction:
 
-        if agent.malfunction_data['malfunction'] < 1:
-            if np_random.rand() < _malfunction_prob(self.MFP.malfunction_rate):
-                num_broken_steps = np_random.randint(self.MFP.min_duration,
-                                                     self.MFP.max_duration + 1) + 1
-                return Malfunction(num_broken_steps)
-        return Malfunction(0)
+        if np_random.rand() < _malfunction_prob(self.MFP.malfunction_rate):
+            num_broken_steps = np_random.randint(self.MFP.min_duration,
+                                                 self.MFP.max_duration + 1) + 1
+        else:
+            num_broken_steps = 0
+        return Malfunction(num_broken_steps)
 
     def get_process_data(self):
         return MalfunctionProcessData(*self.MFP)
@@ -103,7 +96,7 @@ def no_malfunction_generator() -> Tuple[MalfunctionGenerator, MalfunctionProcess
     min_number_of_steps_broken = 0
     max_number_of_steps_broken = 0
 
-    def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
+    def generator(np_random: RandomState = None) -> Malfunction:
         return Malfunction(0)
 
     return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
@@ -162,7 +155,7 @@ def single_malfunction_generator(earlierst_malfunction: int, malfunction_duratio
             malfunction_calls[agent.handle] = 1
 
         # Break an agent that is active at the time of the malfunction
-        if agent.status == RailAgentStatus.ACTIVE and malfunction_calls[agent.handle] >= earlierst_malfunction:
+        if agent.status == RailAgentStatus.ACTIVE and malfunction_calls[agent.handle] >= earlierst_malfunction:  # TODO: Dipam - is this still needed? RailAgentStatus and agent.status are removed by this change
             global_nr_malfunctions += 1
             return Malfunction(malfunction_duration)
         else:
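
With the agent/reset arguments dropped, a malfunction generator is now just a callable taking a RandomState and returning a Malfunction; the per-agent gating moves into MalfunctionHandler. A hedged sketch of a custom generator matching the new signature (rate and duration are illustrative):

    from numpy.random import RandomState
    from flatland.envs.malfunction_generators import Malfunction, _malfunction_prob

    def fixed_duration_malfunction_generator(np_random: RandomState) -> Malfunction:
        # Same trigger-probability helper as ParamMalfunctionGen, but with a fixed 5-step duration
        if np_random.rand() < _malfunction_prob(1 / 200):
            return Malfunction(5)
        return Malfunction(0)
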
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 3d4da7b93da1705678e310b6c81c9076542db089..1dc332d9480298020ff2d63c9677f5dd0631bf6b 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -7,13 +7,15 @@ from enum import IntEnum
 from typing import List, NamedTuple, Optional, Dict, Tuple
 
 import numpy as np
 
 
 from flatland.core.env import Environment
 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
 from flatland.core.grid.grid4_utils import get_new_position
-from flatland.core.grid.grid_utils import IntVector2D
+from flatland.core.grid.grid_utils import IntVector2D, position_to_coordinate
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.agent_utils import Agent, EnvAgent, RailAgentStatus
 from flatland.envs.distance_map import DistanceMap
@@ -37,37 +39,23 @@ from gym.utils import seeding
 # from flatland.envs.line_generators import random_line_generator, LineGenerator
 
 
+# NEW : Imports 
+from flatland.envs.schedule_time_generators import schedule_time_generator
+from flatland.envs.step_utils.states import TrainState
+from flatland.envs.step_utils.transition_utils import check_action
+
+# Env Step Facelift imports
+from flatland.envs.step_utils.action_preprocessing import preprocess_raw_action, preprocess_moving_action, preprocess_action_when_waiting
 
 # Adrian Egli performance fix (the fast methods brings more than 50%)
 def fast_isclose(a, b, rtol):
     return (a < (b + rtol)) or (a < (b - rtol))
 
-
-def fast_clip(position: (int, int), min_value: (int, int), max_value: (int, int)) -> bool:
-    return (
-        max(min_value[0], min(position[0], max_value[0])),
-        max(min_value[1], min(position[1], max_value[1]))
-    )
-
-
-def fast_argmax(possible_transitions: (int, int, int, int)) -> bool:
-    if possible_transitions[0] == 1:
-        return 0
-    if possible_transitions[1] == 1:
-        return 1
-    if possible_transitions[2] == 1:
-        return 2
-    return 3
-
-
 def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
-    return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]
-
-
-def fast_count_nonzero(possible_transitions: (int, int, int, int)):
-    return possible_transitions[0] + possible_transitions[1] + possible_transitions[2] + possible_transitions[3]
-
-
+    if pos_1 is None: # TODO: Dipam - Consider making default of agent.position as (-1, -1) instead of None
+        return False
+    else:
+        return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]
 
 class RailEnv(Environment):
     """
@@ -255,6 +243,8 @@ class RailEnv(Environment):
         self.close_following = close_following  # use close following logic
         self.motionCheck = ac.MotionCheck()
 
+        self.agent_helpers = {}
+
     def _seed(self, seed=None):
         self.np_random, seed = seeding.np_random(seed)
         random.seed(seed)
@@ -379,15 +369,18 @@ class RailEnv(Environment):
         # Reset agents to initial states
         self.reset_agents()
 
-        for agent in self.agents:
-            # Induce malfunctions
-            self._break_agent(agent)
+        # for agent in self.agents:
+        #     # Induce malfunctions
+        #     if activate_agents:
+        #         self.set_agent_active(agent)
 
-            if agent.malfunction_data["malfunction"] > 0:
-                agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING
+        #     self._break_agent(agent)
 
-            # Fix agents that finished their malfunction
-            self._fix_agent_after_malfunction(agent)
+        #     if agent.malfunction_data["malfunction"] > 0:
+        #         agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING
+
+        #     # Fix agents that finished their malfunction
+        #     self._fix_agent_after_malfunction(agent)
 
         self.num_resets += 1
         self._elapsed_steps = 0
@@ -398,12 +391,6 @@ class RailEnv(Environment):
         # Reset the state of the observation builder with the new environment
         self.obs_builder.reset()
 
-        # Reset the malfunction generator
-        if "generate" in dir(self.malfunction_generator):
-            self.malfunction_generator.generate(reset=True)
-        else:
-            self.malfunction_generator(reset=True)
-
         # Empty the episode store of agent positions
         self.cur_episode = []
 
@@ -418,52 +405,25 @@ class RailEnv(Environment):
         # Return the new observation vectors for each agent
         observation_dict: Dict = self._get_observations()
         return observation_dict, info_dict
-
-    def _fix_agent_after_malfunction(self, agent: EnvAgent):
-        """
-        Updates agent malfunction variables and fixes broken agents
-
-        Parameters
-        ----------
-        agent
-        """
-
-        # Ignore agents that are OK
-        if self._is_agent_ok(agent):
-            return
-
-        # Reduce number of malfunction steps left
-        if agent.malfunction_data['malfunction'] > 1:
-            agent.malfunction_data['malfunction'] -= 1
-            return
-
-        # Restart agents at the end of their malfunction
-        agent.malfunction_data['malfunction'] -= 1
-        if 'moving_before_malfunction' in agent.malfunction_data:
-            agent.moving = agent.malfunction_data['moving_before_malfunction']
-            return
-
-    def _break_agent(self, agent: EnvAgent):
-        """
-        Malfunction generator that breaks agents at a given rate.
-
-        Parameters
-        ----------
-        agent
-
-        """
-
-        if "generate" in dir(self.malfunction_generator):
-            malfunction: mal_gen.Malfunction = self.malfunction_generator.generate(agent, self.np_random)
+    
+    def apply_action_independent(self, action, rail, position, direction):
+        if RailEnvActions.is_moving_action(action):
+            new_direction, _ = check_action(action, position, direction, rail)
+            new_position = get_new_position(position, new_direction)
         else:
-            malfunction: mal_gen.Malfunction = self.malfunction_generator(agent, self.np_random)
-
-        if malfunction.num_broken_steps > 0:
-            agent.malfunction_data['malfunction'] = malfunction.num_broken_steps
-            agent.malfunction_data['moving_before_malfunction'] = agent.moving
-            agent.malfunction_data['nr_malfunctions'] += 1
-
-        return
+            new_position, new_direction = position, direction
+        return new_position, new_direction
+    
+    def generate_state_transition_signals(self, agent, preprocessed_action, movement_allowed):
+        st_signals = {}
+        
+        st_signals['malfunction_onset'] = agent.malfunction_handler.in_malfunction
+        st_signals['malfunction_counter_complete'] = agent.malfunction_handler.malfunction_counter_complete
+        st_signals['earliest_departure_reached'] = self._elapsed_steps >= agent.earliest_departure
+        st_signals['stop_action_given'] = (preprocessed_action == RailEnvActions.STOP_MOVING)
+        st_signals['valid_movement_action_given'] = RailEnvActions.is_moving_action(preprocessed_action)
+        st_signals['target_reached'] = fast_position_equal(agent.position, agent.target)
+        st_signals['movement_conflict'] = (not movement_allowed) and agent.speed_counter.is_cell_exit # TODO: Modify motion check to provide proper conflict information
+
+        return st_signals
 
     def _handle_end_reward(self, agent: EnvAgent) -> int:
         '''
@@ -497,16 +457,26 @@ class RailEnv(Environment):
         """
         Updates rewards for the agents at a step.
 
         Parameters
         ----------
         action_dict_ : Dict[int,RailEnvActions]
 
         """
         self._elapsed_steps += 1
 
         # If we're done, set reward and info_dict and step() is done.
-        if self.dones["__all__"]:
-            raise Exception("Episode is done, cannot call step()")
+        if self.dones["__all__"]: # TODO: Move boilerplate to different function
+            self.rewards_dict = {}
+            info_dict = {
+                "action_required": {},
+                "malfunction": {},
+                "speed": {},
+                "status": {},
+            }
+            for i_agent, agent in enumerate(self.agents):
+                self.rewards_dict[i_agent] = self.global_reward
+                info_dict["action_required"][i_agent] = False
+                info_dict["malfunction"][i_agent] = 0
+                info_dict["speed"][i_agent] = 0
+                info_dict["status"][i_agent] = agent.state
+
+            return self._get_observations(), self.rewards_dict, self.dones, info_dict
 
         # Reset the step rewards
         self.rewards_dict = dict()
@@ -520,407 +490,96 @@ class RailEnv(Environment):
 
         self.motionCheck = ac.MotionCheck()  # reset the motion check
 
-        if not self.close_following:
-            for i_agent, agent in enumerate(self.agents):
-                # Reset the step rewards
-                self.rewards_dict[i_agent] = 0
-
-                # Induce malfunction before we do a step, thus a broken agent can't move in this step
-                self._break_agent(agent)
-
-                # Perform step on the agent
-                self._step_agent(i_agent, action_dict_.get(i_agent))
-
-                # manage the boolean flag to check if all agents are indeed done (or done_removed)
-                have_all_agents_ended &= (agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED])
-
-                # Build info dict
-                info_dict["action_required"][i_agent] = self.action_required(agent)
-                info_dict["malfunction"][i_agent] = agent.malfunction_data['malfunction']
-                info_dict["speed"][i_agent] = agent.speed_data['speed']
-                info_dict["status"][i_agent] = agent.status
-
-                # Fix agents that finished their malfunction such that they can perform an action in the next step
-                self._fix_agent_after_malfunction(agent)
-
-
-        else:
-            for i_agent, agent in enumerate(self.agents):
-                # Reset the step rewards
-                self.rewards_dict[i_agent] = 0
-
-                # Induce malfunction before we do a step, thus a broken agent can't move in this step
-                self._break_agent(agent)
-
-                # Perform step on the agent
-                self._step_agent_cf(i_agent, action_dict_.get(i_agent))
-
-            # second loop: check for collisions / conflicts
-            self.motionCheck.find_conflicts()
-
-            # third loop: update positions
-            for i_agent, agent in enumerate(self.agents):
-                self._step_agent2_cf(i_agent)
-
-                # manage the boolean flag to check if all agents are indeed done (or done_removed)
-                have_all_agents_ended &= (agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED])
-
-                # Build info dict
-                info_dict["action_required"][i_agent] = self.action_required(agent)
-                info_dict["malfunction"][i_agent] = agent.malfunction_data['malfunction']
-                info_dict["speed"][i_agent] = agent.speed_data['speed']
-                info_dict["status"][i_agent] = agent.status
-
-                # Fix agents that finished their malfunction such that they can perform an action in the next step
-                self._fix_agent_after_malfunction(agent)
-
+        temp_saved_data = {} # TODO : Change name
         
-        # NEW : REW: (END)
-        if ((self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps)) \
-            or have_all_agents_ended :
-            
-            for i_agent, agent in enumerate(self.agents):
-                
-                reward = self._handle_end_reward(agent)
-                self.rewards_dict[i_agent] += reward
+        for i_agent, agent in enumerate(self.agents): # TODO: Important - Do not use i_agent like this, use agent.handle if needed
+            # Generate malfunction
+            agent.malfunction_handler.generate_malfunction(self.malfunction_generator, self.np_random)
+
+            # Get action for the agent
+            action = action_dict_.get(i_agent, RailEnvActions.DO_NOTHING)
+            # TODO: Add the bottom stuff to separate function(s)
+
+            # Preprocess action
+            action = preprocess_raw_action(action, agent.state)
+            action = preprocess_action_when_waiting(action, agent.state)
+
+            # Try moving actions on current position
+            current_position, current_direction = agent.position, agent.direction
+            agent_not_on_map = current_position is None
+            if agent_not_on_map: # Agent not added on map yet
+                current_position, current_direction = agent.initial_position, agent.initial_direction
+            action = preprocess_moving_action(action, agent.state, self.rail, current_position, current_direction)
+
+            # Save moving action if not already saved
+            agent.action_saver.save_action_if_allowed(action, agent.state)
+
+            # Calculate new position
+            # Add agent to the map if not on it yet
+            if agent_not_on_map and agent.action_saver.is_action_saved:
+                temp_new_position = agent.initial_position
+                temp_new_direction = agent.initial_direction
                 
-                self.dones[i_agent] = True
+            # When cell exit occurs apply saved action independent of other agents
+            elif agent.speed_counter.is_cell_exit and agent.action_saver.is_action_saved:
+                saved_action = agent.action_saver.saved_action
+                # Apply action independent of other agents and get temporary new position and direction
+                temp_pd = self.apply_action_independent(saved_action, self.rail, agent.position, agent.direction)
+                temp_new_position, temp_new_direction = temp_pd
+            else:
+                temp_new_position, temp_new_direction = agent.position, agent.direction
+
+            # TODO: Saving temporary positon shouldn't be needed if recheck of position is not needed later (see TAG#1)
+            temp_saved_data[i_agent] = temp_new_position, temp_new_direction, action
+            self.motionCheck.addAgent(i_agent, agent.position, temp_new_position)
 
-            self.dones["__all__"] = True
+        # Find conflicts
+        # TODO : Important - Modify conflicted positions and select one of them randomly to go to new position
+        self.motionCheck.find_conflicts()
         
+        for agent in self.agents:
+            i_agent = agent.handle
 
-        if self.record_steps:
-            self.record_timestep(action_dict_)
+            ## Update positions
+            movement_allowed, _ = self.motionCheck.check_motion(i_agent, agent.position) # TODO: Remove final_new_position from motioncheck
+            # TODO : Important : Original code rechecks the next position here again - not sure why? TAG#1
+            preprocessed_action = temp_saved_data[i_agent][2] # TODO : Important : Make this namedtuple or class
 
-        return self._get_observations(), self.rewards_dict, self.dones, info_dict
+            # TODO : Looks like a hacky condition, improve the handling
+            if agent.malfunction_handler.in_malfunction:
+                movement_allowed = False
 
-    def _step_agent(self, i_agent, action: Optional[RailEnvActions] = None):
-        """
-        Performs a step and step, start and stop penalty on a single agent in the following sub steps:
-        - malfunction
-        - action handling if at the beginning of cell
-        - movement
-
-        Parameters
-        ----------
-        i_agent : int
-        action_dict_ : Dict[int,RailEnvActions]
-
-        """
-        agent = self.agents[i_agent]
-        if agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]:  # this agent has already completed...
-            return
-
-        # agent gets active by a MOVE_* action and if c
-        if agent.status == RailAgentStatus.READY_TO_DEPART:
-            initial_cell_free = self.cell_free(agent.initial_position)
-            is_action_starting = action in [
-                RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_FORWARD]
-
-            if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT,
-                          RailEnvActions.MOVE_FORWARD] and self.cell_free(agent.initial_position):
-                agent.status = RailAgentStatus.ACTIVE
-                self._set_agent_to_initial_position(agent, agent.initial_position)
-                self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
-                return
+            if movement_allowed:
+                final_new_position, final_new_direction = temp_saved_data[i_agent][:2] # TODO : Important : Make this namedtuple or class
             else:
-                # TODO: Here we need to check for the departure time in future releases with full schedules
-                self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
-                return
-
-        agent.old_direction = agent.direction
-        agent.old_position = agent.position
-
-        # if agent is broken, actions are ignored and agent does not move.
-        # full step penalty in this case
-        if agent.malfunction_data['malfunction'] > 0:
-            self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
-            return
-
-        # Is the agent at the beginning of the cell? Then, it can take an action.
-        # As long as the agent is malfunctioning or stopped at the beginning of the cell,
-        # different actions may be taken!
-        if fast_isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03):
-            # No action has been supplied for this agent -> set DO_NOTHING as default
-            if action is None:
-                action = RailEnvActions.DO_NOTHING
-
-            if action < 0 or action > len(RailEnvActions):
-                print('ERROR: illegal action=', action,
-                      'for agent with index=', i_agent,
-                      '"DO NOTHING" will be executed instead')
-                action = RailEnvActions.DO_NOTHING
-
-            if action == RailEnvActions.DO_NOTHING and agent.moving:
-                # Keep moving
-                action = RailEnvActions.MOVE_FORWARD
-
-            if action == RailEnvActions.STOP_MOVING and agent.moving:
-                # Only allow halting an agent on entering new cells.
-                agent.moving = False
-                self.rewards_dict[i_agent] += self.stop_penalty
-
-            if not agent.moving and not (
-                action == RailEnvActions.DO_NOTHING or
-                action == RailEnvActions.STOP_MOVING):
-                # Allow agent to start with any forward or direction action
-                agent.moving = True
-                self.rewards_dict[i_agent] += self.start_penalty
-
-            # Store the action if action is moving
-            # If not moving, the action will be stored when the agent starts moving again.
-            if agent.moving:
-                _action_stored = False
-                _, new_cell_valid, new_direction, new_position, transition_valid = \
-                    self._check_action_on_agent(action, agent)
-
-                if all([new_cell_valid, transition_valid]):
-                    agent.speed_data['transition_action_on_cellexit'] = action
-                    _action_stored = True
-                else:
-                    # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving,
-                    # try to keep moving forward!
-                    if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT):
-                        _, new_cell_valid, new_direction, new_position, transition_valid = \
-                            self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)
-
-                        if all([new_cell_valid, transition_valid]):
-                            agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD
-                            _action_stored = True
-
-                if not _action_stored:
-                    # If the agent cannot move due to an invalid transition, we set its state to not moving
-                    self.rewards_dict[i_agent] += self.invalid_action_penalty
-                    self.rewards_dict[i_agent] += self.stop_penalty
-                    agent.moving = False
-
-        # Now perform a movement.
-        # If agent.moving, increment the position_fraction by the speed of the agent
-        # If the new position fraction is >= 1, reset to 0, and perform the stored
-        #   transition_action_on_cellexit if the cell is free.
-        if agent.moving:
-            agent.speed_data['position_fraction'] += agent.speed_data['speed']
-            if agent.speed_data['position_fraction'] > 1.0 or fast_isclose(agent.speed_data['position_fraction'], 1.0,
-                                                                           rtol=1e-03):
-                # Perform stored action to transition to the next cell as soon as cell is free
-                # Notice that we've already checked new_cell_valid and transition valid when we stored the action,
-                # so we only have to check cell_free now!
-
-                # Traditional check that next cell is free
-                # cell and transition validity was checked when we stored transition_action_on_cellexit!
-                cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent(
-                    agent.speed_data['transition_action_on_cellexit'], agent)
-
-                # N.B. validity of new_cell and transition should have been verified before the action was stored!
-                assert new_cell_valid
-                assert transition_valid
-                if cell_free:
-                    self._move_agent_to_new_position(agent, new_position)
-                    agent.direction = new_direction
-                    agent.speed_data['position_fraction'] = 0.0
-
-            # has the agent reached its target?
-            if np.equal(agent.position, agent.target).all():
-                agent.status = RailAgentStatus.DONE
-                self.dones[i_agent] = True
-                self.active_agents.remove(i_agent)
-                agent.moving = False
-                self._remove_agent_from_scene(agent)
-            else:
-                self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
-        else:
-            # step penalty if not moving (stopped now or before)
-            self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
-
-    def _step_agent_cf(self, i_agent, action: Optional[RailEnvActions] = None):
-        """ "close following" version of step_agent.
-        """
-        agent = self.agents[i_agent]
-        if agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]:  # this agent has already completed...
-            return
-
-        # NEW : STEP: WAITING > WAITING or WAITING > READY_TO_DEPART
-        if (agent.status == RailAgentStatus.WAITING):
-            if ( self._elapsed_steps >= agent.earliest_departure ):
-                agent.status = RailAgentStatus.READY_TO_DEPART
-            self.motionCheck.addAgent(i_agent, None, None)
-            return
-
-        # agent gets active by a MOVE_* action and if c
-        if agent.status == RailAgentStatus.READY_TO_DEPART:
-            is_action_starting = action in [
-                RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_FORWARD]
-
-            if is_action_starting:  # agent is trying to start
-                self.motionCheck.addAgent(i_agent, None, agent.initial_position)
-            else:  # agent wants to remain unstarted
-                self.motionCheck.addAgent(i_agent, None, None)
-            return
-
-        agent.old_direction = agent.direction
-        agent.old_position = agent.position
-
-        # if agent is broken, actions are ignored and agent does not move.
-        # full step penalty in this case
-        # TODO: this means that deadlocked agents which suffer a malfunction are marked as 
-        # stopped rather than deadlocked.
-        if agent.malfunction_data['malfunction'] > 0:
-            self.motionCheck.addAgent(i_agent, agent.position, agent.position)
-            # agent will get penalty in step_agent2_cf
-            # self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
-            return
-
-        # Is the agent at the beginning of the cell? Then, it can take an action.
-        # As long as the agent is malfunctioning or stopped at the beginning of the cell,
-        # different actions may be taken!
-        if np.isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03):
-            # No action has been supplied for this agent -> set DO_NOTHING as default
-            if action is None:
-                action = RailEnvActions.DO_NOTHING
-
-            if action < 0 or action > len(RailEnvActions):
-                print('ERROR: illegal action=', action,
-                      'for agent with index=', i_agent,
-                      '"DO NOTHING" will be executed instead')
-                action = RailEnvActions.DO_NOTHING
-
-            if action == RailEnvActions.DO_NOTHING and agent.moving:
-                # Keep moving
-                action = RailEnvActions.MOVE_FORWARD
-
-            if action == RailEnvActions.STOP_MOVING and agent.moving:
-                # Only allow halting an agent on entering new cells.
-                agent.moving = False
-                self.rewards_dict[i_agent] += self.stop_penalty
-
-            if not agent.moving and not (
-                action == RailEnvActions.DO_NOTHING or
-                action == RailEnvActions.STOP_MOVING):
-                # Allow agent to start with any forward or direction action
-                agent.moving = True
-                self.rewards_dict[i_agent] += self.start_penalty
-
-            # Store the action if action is moving
-            # If not moving, the action will be stored when the agent starts moving again.
-            new_position = None
-            if agent.moving:
-                _action_stored = False
-                _, new_cell_valid, new_direction, new_position, transition_valid = \
-                    self._check_action_on_agent(action, agent)
-
-                if all([new_cell_valid, transition_valid]):
-                    agent.speed_data['transition_action_on_cellexit'] = action
-                    _action_stored = True
-                else:
-                    # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving,
-                    # try to keep moving forward!
-                    if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT):
-                        _, new_cell_valid, new_direction, new_position, transition_valid = \
-                            self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)
-
-                        if all([new_cell_valid, transition_valid]):
-                            agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD
-                            _action_stored = True
-
-                if not _action_stored:
-                    # If the agent cannot move due to an invalid transition, we set its state to not moving
-                    self.rewards_dict[i_agent] += self.invalid_action_penalty
-                    self.rewards_dict[i_agent] += self.stop_penalty
-                    agent.moving = False
-                    self.motionCheck.addAgent(i_agent, agent.position, agent.position)
-                    return
-
-            if new_position is None:
-                self.motionCheck.addAgent(i_agent, agent.position, agent.position)
-                if agent.moving:
-                    print("Agent", i_agent, "new_pos none, but moving")
-
-        # Check the pos_frac position fraction
-        if agent.moving:
-            agent.speed_data['position_fraction'] += agent.speed_data['speed']
-            if agent.speed_data['position_fraction'] > 0.999:
-                stored_action = agent.speed_data["transition_action_on_cellexit"]
-
-                # find the next cell using the stored action
-                _, new_cell_valid, new_direction, new_position, transition_valid = \
-                    self._check_action_on_agent(stored_action, agent)
-
-                # if it's valid, record it as the new position
-                if all([new_cell_valid, transition_valid]):
-                    self.motionCheck.addAgent(i_agent, agent.position, new_position)
-                else:  # if the action wasn't valid then record the agent as stationary
-                    self.motionCheck.addAgent(i_agent, agent.position, agent.position)
-            else:  # This agent hasn't yet crossed the cell
-                self.motionCheck.addAgent(i_agent, agent.position, agent.position)
-
-    def _step_agent2_cf(self, i_agent):
-        agent = self.agents[i_agent]
-
-        # NEW : REW: (WAITING) no reward during WAITING...
-        if agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED, RailAgentStatus.WAITING]:
-            return
-
-        (move, rc_next) = self.motionCheck.check_motion(i_agent, agent.position)
-
-        if agent.position is not None:
-            sbTrans = format(self.rail.grid[agent.position], "016b")
-            trans_block = sbTrans[agent.direction * 4: agent.direction * 4 + 4]
-            if (trans_block == "0000"):
-                print (i_agent, agent.position, agent.direction, sbTrans, trans_block)
-
-        # if agent cannot enter env, then we should have move=False
-
-        if move:
-            if agent.position is None:  # agent is entering the env
-                # print(i_agent, "writing new pos ", rc_next, " into agent position (None)")
-                agent.position = rc_next
-                agent.status = RailAgentStatus.ACTIVE
-                agent.speed_data['position_fraction'] = 0.0
-
-            else:  # normal agent move
-                cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent(
-                    agent.speed_data['transition_action_on_cellexit'], agent)
-
-                if not all([transition_valid, new_cell_valid]):
-                    print(f"ERRROR: step_agent2 invalid transition ag {i_agent} dir {new_direction} pos {agent.position} next {rc_next}")
-
-                if new_position != rc_next:
-                    print(f"ERROR: agent {i_agent} new_pos {new_position} != rc_next {rc_next}  " + 
-                          f"pos {agent.position} dir {agent.direction} new_dir {new_direction}" +
-                          f"stored action: {agent.speed_data['transition_action_on_cellexit']}")
-
-                sbTrans = format(self.rail.grid[agent.position], "016b")
-                trans_block = sbTrans[agent.direction * 4: agent.direction * 4 + 4]
-                if (trans_block == "0000"):
-                    print ("ERROR: ", i_agent, agent.position, agent.direction, sbTrans, trans_block)
-
-                agent.position = rc_next
-                agent.direction = new_direction
-                agent.speed_data['position_fraction'] = 0.0
-
-            # NEW : STEP: Check DONE  before / after LA & Check if RUNNING before / after LA
-            # has the agent reached its target?
-            if np.equal(agent.position, agent.target).all():
-                # arrived before or after Latest Arrival
-                agent.status = RailAgentStatus.DONE
-                self.dones[i_agent] = True
-                self.active_agents.remove(i_agent)
-                agent.moving = False
-                agent.arrival_time = self._elapsed_steps
-                self._remove_agent_from_scene(agent)
-
-            else: # not reached its target and moving
-                # running before Latest Arrival
-                if (self._elapsed_steps <= agent.latest_arrival):
-                    self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
-                else: # running after Latest Arrival
-                    self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] # + # NEGATIVE REWARD? per step?
-        else:
-            # stopped (!move) before Latest Arrival
-            if (self._elapsed_steps <= agent.latest_arrival):
-                self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
-            else:  # stopped (!move) after Latest Arrival
-                self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']  # + # NEGATIVE REWARD? per step?
+                final_new_position = agent.position
+                final_new_direction = agent.direction
+            agent.position = final_new_position
+            agent.direction = final_new_direction
+
+            ## Update states
+            state_transition_signals = self.generate_state_transition_signals(agent, preprocessed_action, movement_allowed)
+            agent.state_machine.set_transition_signals(state_transition_signals)
+            agent.state_machine.step()
+            agent.state = agent.state_machine.state # TODO : Make this a property instead?
+
+            # TODO : Important : Looks like a hacky condition, improve the handling
+            if agent.state == TrainState.DONE:
+                agent.position = None
+
+            ## Update rewards
+            # self.update_rewards(i_agent, agent, rail)
+
+            ## Update counters (malfunction and speed)
+            agent.speed_counter.update_counter(agent.state)
+            agent.malfunction_handler.update_counter()
+
+            # Clear old action when starting in new cell
+            if agent.speed_counter.is_cell_entry:
+                agent.action_saver.clear_saved_action()
+        
+        self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))} # TODO : Remove this
+        return self._get_observations(), self.rewards_dict, self.dones, info_dict # TODO : Will need changes?
 
     def _set_agent_to_initial_position(self, agent: EnvAgent, new_position: IntVector2D):
         """
@@ -965,52 +624,6 @@ class RailEnv(Environment):
             agent.old_position = None
             agent.status = RailAgentStatus.DONE_REMOVED
 
-    def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent):
-        """
-
-        Parameters
-        ----------
-        action : RailEnvActions
-        agent : EnvAgent
-
-        Returns
-        -------
-        bool
-            Is it a legal move?
-            1) transition allows the new_direction in the cell,
-            2) the new cell is not empty (case 0),
-            3) the cell is free, i.e., no agent is currently in that cell
-
-
-        """
-        # compute number of possible transitions in the current
-        # cell used to check for invalid actions
-        new_direction, transition_valid = self.check_action(agent, action)
-        new_position = get_new_position(agent.position, new_direction)
-
-        new_cell_valid = (
-            fast_position_equal(  # Check the new position is still in the grid
-                new_position,
-                fast_clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
-            and  # check the new position has some transitions (ie is not an empty cell)
-            self.rail.get_full_transitions(*new_position) > 0)
-
-        # If transition validity hasn't been checked yet.
-        if transition_valid is None:
-            transition_valid = self.rail.get_transition(
-                (*agent.position, agent.direction),
-                new_direction)
-
-        # only call cell_free() if new cell is inside the scene
-        if new_cell_valid:
-            # Check the new position is not the same as any of the existing agent positions
-            # (including itself, for simplicity, since it is moving)
-            cell_free = self.cell_free(new_position)
-        else:
-            # if new cell is outside of scene -> cell_free is False
-            cell_free = False
-        return cell_free, new_cell_valid, new_direction, new_position, transition_valid
-
     def record_timestep(self, dActions):
         ''' Record the positions and orientations of all agents in memory, in the cur_episode
         '''
@@ -1034,62 +647,6 @@ class RailEnv(Environment):
         self.cur_episode.append(list_agents_state)
         self.list_actions.append(dActions)
 
-    def cell_free(self, position: IntVector2D) -> bool:
-        """
-        Utility to check if a cell is free
-
-        Parameters:
-        --------
-        position : Tuple[int, int]
-
-        Returns
-        -------
-        bool
-            is the cell free or not?
-
-        """
-        return self.agent_positions[position] == -1
-
-    def check_action(self, agent: EnvAgent, action: RailEnvActions):
-        """
-
-        Parameters
-        ----------
-        agent : EnvAgent
-        action : RailEnvActions
-
-        Returns
-        -------
-        Tuple[Grid4TransitionsEnum,Tuple[int,int]]
-
-
-
-        """
-        transition_valid = None
-        possible_transitions = self.rail.get_transitions(*agent.position, agent.direction)
-        num_transitions = fast_count_nonzero(possible_transitions)
-
-        new_direction = agent.direction
-        if action == RailEnvActions.MOVE_LEFT:
-            new_direction = agent.direction - 1
-            if num_transitions <= 1:
-                transition_valid = False
-
-        elif action == RailEnvActions.MOVE_RIGHT:
-            new_direction = agent.direction + 1
-            if num_transitions <= 1:
-                transition_valid = False
-
-        new_direction %= 4
-
-        if action == RailEnvActions.MOVE_FORWARD and num_transitions == 1:
-            # - dead-end, straight line or curved line;
-            # new_direction will be the only valid transition
-            # - take only available transition
-            new_direction = fast_argmax(possible_transitions)
-            transition_valid = True
-        return new_direction, transition_valid
-
     def _get_observations(self):
         """
         Utility which returns the observations for an agent with respect to environment
@@ -1140,7 +697,7 @@ class RailEnv(Environment):
         True if agent is ok, False otherwise
 
         """
-        return agent.malfunction_data['malfunction'] < 1
+        return not agent.malfunction_handler.in_malfunction
 
     def save(self, filename):
         print("deprecated call to env.save() - pls call RailEnvPersister.save()")
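
To make the new step flow easier to follow: RailEnv.step() now builds a dictionary of boolean transition signals per agent (generate_state_transition_signals above) and feeds it into the agent's TrainStateMachine. A small sketch of that hand-off in isolation, assuming the signal keys stay as listed above (the transition table itself lives in state_machine.py):

    from flatland.envs.step_utils.state_machine import TrainStateMachine
    from flatland.envs.step_utils.states import TrainState

    sm = TrainStateMachine(initial_state=TrainState.WAITING)
    signals = {                                  # same keys as generate_state_transition_signals()
        'malfunction_onset': False,
        'malfunction_counter_complete': True,
        'earliest_departure_reached': True,
        'stop_action_given': False,
        'valid_movement_action_given': True,
        'target_reached': False,
        'movement_conflict': False,
    }
    sm.set_transition_signals(signals)
    sm.step()                                    # expected to move the train out of WAITING
    print(sm.state)
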
diff --git a/flatland/envs/rail_env_action.py b/flatland/envs/rail_env_action.py
index 6fcc175e7f7f63653153f8841ec3ba398876d4a1..a25cc8f0f37233f76b921ffc62c83818e8e7bb9b 100644
--- a/flatland/envs/rail_env_action.py
+++ b/flatland/envs/rail_env_action.py
@@ -19,6 +19,10 @@ class RailEnvActions(IntEnum):
             4: 'S',
         }[a]
 
+    @staticmethod
+    def is_moving_action(action):
+        return action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_RIGHT]
+
 
 RailEnvGridPos = NamedTuple('RailEnvGridPos', [('r', int), ('c', int)])
 RailEnvNextAction = NamedTuple('RailEnvNextAction', [('action', RailEnvActions), ('next_position', RailEnvGridPos),
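
A quick sanity check of the new helper, relying on the enum's existing values (MOVE_LEFT/MOVE_FORWARD/MOVE_RIGHT are 1/2/3):

    from flatland.envs.rail_env_action import RailEnvActions

    assert RailEnvActions.is_moving_action(RailEnvActions.MOVE_FORWARD)
    assert not RailEnvActions.is_moving_action(RailEnvActions.DO_NOTHING)
    assert not RailEnvActions.is_moving_action(RailEnvActions.STOP_MOVING)
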
diff --git a/flatland/envs/step_utils/action_preprocessing.py b/flatland/envs/step_utils/action_preprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8ad1d797d47dc9495089c576fc16fc507548adf
--- /dev/null
+++ b/flatland/envs/step_utils/action_preprocessing.py
@@ -0,0 +1,61 @@
+from flatland.envs.step_utils.states import TrainState
+from flatland.envs.rail_env_action import RailEnvActions
+from flatland.envs.step_utils.transition_utils import check_valid_action
+
+
+def process_illegal_action(action: RailEnvActions):
+    # TODO - Dipam : This check is kind of weird, change this
+    if action is None or action not in RailEnvActions._value2member_map_:
+        return RailEnvActions.DO_NOTHING
+    else:
+        return action
+
+
+def process_do_nothing(state: TrainState):
+    if state == TrainState.MOVING:
+        action = RailEnvActions.MOVE_FORWARD
+    else:
+        action = RailEnvActions.STOP_MOVING
+    return action
+
+
+def process_left_right(action, rail, position, direction):
+    if not check_valid_action(action, rail, position, direction):
+        action = RailEnvActions.MOVE_FORWARD
+    return action
+
+
+def preprocess_action_when_waiting(action, state):
+    """
+    Set action to DO_NOTHING if in waiting state
+    """
+    if state == TrainState.WAITING:
+        action = RailEnvActions.DO_NOTHING
+    return action
+
+
+def preprocess_raw_action(action, state):
+    """
+    Preprocesses actions to handle different situations of usage of action based on context
+        - DO_NOTHING is converted to FORWARD if train is moving
+        - DO_NOTHING is converted to STOP_MOVING if train is not moving
+    """
+    action = process_illegal_action(action)
+
+    if action == RailEnvActions.DO_NOTHING:
+        action = process_do_nothing(state)
+
+    return action
+
+def preprocess_moving_action(action, state, rail, position, direction):
+    """
+    LEFT/RIGHT is converted to FORWARD if left/right is not available and train is moving
+    FORWARD is converted to STOP_MOVING if leading to dead end?
+    """
+    if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]:
+        action = process_left_right(action, rail, position, direction)
+
+    if not check_valid_action(action, rail, position, direction): # TODO: Dipam - Not sure if this is needed
+        action = RailEnvActions.STOP_MOVING
+    return action
\ No newline at end of file
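
Taken together, the preprocessing chain used in RailEnv.step() is: raw action -> preprocess_raw_action -> preprocess_action_when_waiting -> preprocess_moving_action. A hedged sketch of the first two stages, which need no rail grid:

    from flatland.envs.rail_env_action import RailEnvActions
    from flatland.envs.step_utils.states import TrainState
    from flatland.envs.step_utils.action_preprocessing import (
        preprocess_raw_action, preprocess_action_when_waiting)

    # A moving train that sends DO_NOTHING keeps rolling forward
    assert preprocess_raw_action(RailEnvActions.DO_NOTHING, TrainState.MOVING) == RailEnvActions.MOVE_FORWARD
    # Garbage input degrades to DO_NOTHING, which for a non-moving train becomes STOP_MOVING
    assert preprocess_raw_action(42, TrainState.WAITING) == RailEnvActions.STOP_MOVING
    # Any action issued while still WAITING is discarded
    assert preprocess_action_when_waiting(RailEnvActions.MOVE_FORWARD, TrainState.WAITING) == RailEnvActions.DO_NOTHING
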
diff --git a/flatland/envs/step_utils/action_saver.py b/flatland/envs/step_utils/action_saver.py
new file mode 100644
index 0000000000000000000000000000000000000000..56f7465af77de4a88ce6d010593bca92c8280759
--- /dev/null
+++ b/flatland/envs/step_utils/action_saver.py
@@ -0,0 +1,25 @@
+from flatland.envs.rail_env_action import RailEnvActions
+from flatland.envs.step_utils.states import TrainState
+
+class ActionSaver:
+    def __init__(self):
+        self.saved_action = None
+
+    @property
+    def is_action_saved(self):
+        return self.saved_action is not None
+    
+    def __repr__(self):
+        return f"is_action_saved: {self.is_action_saved}, saved_action: {self.saved_action}"
+
+
+    def save_action_if_allowed(self, action, state):
+        if not self.is_action_saved and \
+            RailEnvActions.is_moving_action(action) and \
+            not TrainState.is_malfunction_state(state):
+            self.saved_action = action
+
+    def clear_saved_action(self):
+        self.saved_action = None
+
+
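
A short illustration of the save-once semantics, assuming TrainState.is_malfunction_state() behaves as referenced above (i.e. MOVING is not a malfunction state):

    from flatland.envs.rail_env_action import RailEnvActions
    from flatland.envs.step_utils.action_saver import ActionSaver
    from flatland.envs.step_utils.states import TrainState

    saver = ActionSaver()
    saver.save_action_if_allowed(RailEnvActions.MOVE_FORWARD, TrainState.MOVING)
    assert saver.saved_action == RailEnvActions.MOVE_FORWARD

    # Further moving actions are ignored until the saved one is consumed on cell exit
    saver.save_action_if_allowed(RailEnvActions.MOVE_LEFT, TrainState.MOVING)
    assert saver.saved_action == RailEnvActions.MOVE_FORWARD

    saver.clear_saved_action()      # RailEnv.step() does this on cell entry
    assert not saver.is_action_saved
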
diff --git a/flatland/envs/step_utils/malfunction_handler.py b/flatland/envs/step_utils/malfunction_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d2d4169e0b0f46b172b358f84a26e5832749969
--- /dev/null
+++ b/flatland/envs/step_utils/malfunction_handler.py
@@ -0,0 +1,47 @@
+
+def get_number_of_steps_to_break(malfunction_generator, np_random):
+    if hasattr(malfunction_generator, "generate"):
+        malfunction = malfunction_generator.generate(np_random)
+    else:
+        malfunction = malfunction_generator(np_random)
+
+    return malfunction.num_broken_steps
+
+class MalfunctionHandler:
+    def __init__(self):
+        self._malfunction_down_counter = 0
+    
+    @property
+    def in_malfunction(self):
+        return self._malfunction_down_counter > 0
+    
+    @property
+    def malfunction_counter_complete(self):
+        return self._malfunction_down_counter == 0
+
+    @property
+    def malfunction_down_counter(self):
+        return self._malfunction_down_counter
+
+    @malfunction_down_counter.setter
+    def malfunction_down_counter(self, val):
+        self._set_malfunction_down_counter(val)
+
+    def _set_malfunction_down_counter(self, val):
+        if val < 0:
+            raise ValueError("Cannot set a negative value to malfunction down counter")
+        self._malfunction_down_counter = val
+
+    def generate_malfunction(self, malfunction_generator, np_random):
+        num_broken_steps = get_number_of_steps_to_break(malfunction_generator, np_random)
+        self._set_malfunction_down_counter(num_broken_steps)
+
+    def update_counter(self):
+        if self._malfunction_down_counter > 0:
+            self._malfunction_down_counter -= 1
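+
+
+if __name__ == "__main__":
+    # Illustrative sketch only: life cycle of the counter without a malfunction
+    # generator -- the counter is set directly instead of via generate_malfunction().
+    handler = MalfunctionHandler()
+    assert handler.malfunction_counter_complete and not handler.in_malfunction
+    handler.malfunction_down_counter = 2  # as if a 2-step malfunction had been generated
+    assert handler.in_malfunction
+    handler.update_counter()
+    handler.update_counter()
+    assert handler.malfunction_counter_complete and not handler.in_malfunction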
diff --git a/flatland/envs/step_utils/speed_counter.py b/flatland/envs/step_utils/speed_counter.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bde9c20f98b1b7ed26ad4a8ba3d5791786bd84f
--- /dev/null
+++ b/flatland/envs/step_utils/speed_counter.py
@@ -0,0 +1,31 @@
+import numpy as np
+from flatland.envs.step_utils.states import TrainState
+
+class SpeedCounter:
+    def __init__(self, speed):
+        self.speed = speed
+        self.max_count = int(1/speed)
+        self.counter = 0  # steps already spent in the current cell
+
+    def update_counter(self, state):
+        if state == TrainState.MOVING:
+            self.counter += 1
+            self.counter = self.counter % self.max_count
+
+    def __repr__(self):
+        return (f"speed: {self.speed} "
+                f"max_count: {self.max_count} "
+                f"is_cell_entry: {self.is_cell_entry} "
+                f"is_cell_exit: {self.is_cell_exit} "
+                f"counter: {self.counter}")
+
+    def reset_counter(self):
+        self.counter = 0
+
+    @property
+    def is_cell_entry(self):
+        return self.counter == 0
+
+    @property
+    def is_cell_exit(self):
+        return self.counter == self.max_count - 1
+
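+
+if __name__ == "__main__":
+    # Illustrative sketch only: a train with speed 1/2 spends two steps in every cell --
+    # the first step is the cell entry, the second the cell exit.
+    sc = SpeedCounter(speed=0.5)
+    assert sc.is_cell_entry and not sc.is_cell_exit
+    sc.update_counter(TrainState.MOVING)
+    assert sc.is_cell_exit
+    sc.update_counter(TrainState.MOVING)
+    assert sc.is_cell_entry  # wrapped around into the next cell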
diff --git a/flatland/envs/step_utils/state_machine.py b/flatland/envs/step_utils/state_machine.py
new file mode 100644
index 0000000000000000000000000000000000000000..e42a829d2018c3c540ddd0f0e8c249530333abef
--- /dev/null
+++ b/flatland/envs/step_utils/state_machine.py
@@ -0,0 +1,140 @@
+from flatland.envs.step_utils.states import TrainState
+
+class TrainStateMachine:
+    def __init__(self, initial_state=TrainState.WAITING):
+        self._initial_state = initial_state
+        self._state = initial_state
+        self.st_signals = {} # State Transition Signals # TODO: Make this namedtuple
+        self.next_state = None
+    
+    def _handle_waiting(self):
+        """" Waiting state goes to ready to depart when earliest departure is reached"""
+        # TODO: Important - The malfunction handling is not like proper state machine 
+        #                   Both transition signals can happen at the same time
+        #                   Atleast mention it in the diagram
+        if self.st_signals['malfunction_onset']:  
+            self.next_state = TrainState.MALFUNCTION_OFF_MAP
+        elif self.st_signals['earliest_departure_reached']:
+            self.next_state = TrainState.READY_TO_DEPART
+        else:
+            self.next_state = TrainState.WAITING
+
+    def _handle_ready_to_depart(self):
+        """ Can only go to MOVING if a valid action is provided """
+        if self.st_signals['malfunction_onset']:  
+            self.next_state = TrainState.MALFUNCTION_OFF_MAP
+        elif self.st_signals['valid_movement_action_given']:
+            self.next_state = TrainState.MOVING
+        else:
+            self.next_state = TrainState.READY_TO_DEPART
+    
+    def _handle_malfunction_off_map(self):
+        if self.st_signals['malfunction_counter_complete']:
+            if self.st_signals['earliest_departure_reached']:
+                self.next_state = TrainState.READY_TO_DEPART
+            else:
+                self.next_state = TrainState.WAITING
+        else:
+            self.next_state = TrainState.MALFUNCTION_OFF_MAP
+    
+    def _handle_moving(self):
+        if self.st_signals['malfunction_onset']:
+            self.next_state = TrainState.MALFUNCTION
+        elif self.st_signals['target_reached']:
+            self.next_state = TrainState.DONE
+        elif self.st_signals['stop_action_given'] or self.st_signals['movement_conflict']:
+            self.next_state = TrainState.STOPPED
+        else:
+            self.next_state = TrainState.MOVING
+    
+    def _handle_stopped(self):
+        if self.st_signals['malfunction_onset']:
+            self.next_state = TrainState.MALFUNCTION
+        elif self.st_signals['valid_movement_action_given']:
+            self.next_state = TrainState.MOVING
+        else:
+            self.next_state = TrainState.STOPPED
+    
+    def _handle_malfunction(self):
+        if self.st_signals['malfunction_counter_complete'] and \
+           self.st_signals['valid_movement_action_given']:
+            self.next_state = TrainState.MOVING
+        elif self.st_signals['malfunction_counter_complete'] and \
+             (self.st_signals['stop_action_given'] or self.st_signals['movement_conflict']):
+             self.next_state = TrainState.STOPPED
+        else:
+            self.next_state = TrainState.MALFUNCTION
+
+    def _handle_done(self):
+        """" Done state is terminal """
+        self.next_state = TrainState.DONE
+
+    def calculate_next_state(self, current_state):
+
+        # Handle the current state
+        if current_state == TrainState.WAITING:
+            self._handle_waiting()
+
+        elif current_state == TrainState.READY_TO_DEPART:
+            self._handle_ready_to_depart()
+        
+        elif current_state == TrainState.MALFUNCTION_OFF_MAP:
+            self._handle_malfunction_off_map()
+
+        elif current_state == TrainState.MOVING:
+            self._handle_moving()
+
+        elif current_state == TrainState.STOPPED:
+            self._handle_stopped()
+
+        elif current_state == TrainState.MALFUNCTION:
+            self._handle_malfunction()
+
+        elif current_state == TrainState.DONE:
+            self._handle_done()
+
+        else:
+            raise ValueError(f"Got unexpected state {current_state}")
+
+    def step(self):
+        """ Steps the state machine to the next state """
+
+        current_state = self._state
+
+        # Clear next state
+        self.clear_next_state()
+
+        # Handle current state to get next_state
+        self.calculate_next_state(current_state)
+
+        # Set next state
+        self.set_state(self.next_state)
+
+
+    def clear_next_state(self):
+        self.next_state = None
+
+    def set_state(self, state):
+        if not TrainState.check_valid_state(state):
+            raise ValueError(f"Cannot set invalid state {state}")
+        self._state = state
+
+    def reset(self):
+        self._state = self._initial_state
+        self.st_signals = {}
+        self.clear_next_state()
+
+    @property
+    def state(self):
+        return self._state
+    
+    @property
+    def state_transition_signals(self):
+        return self.st_signals
+    
+    def set_transition_signals(self, state_transition_signals):
+        self.st_signals = state_transition_signals # TODO: Important: Check all keys are present and if not raise error
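+
+
+if __name__ == "__main__":
+    # Illustrative sketch only: one transition of the machine. The signal names match
+    # the keys read by the handlers above; in the env they are computed every step.
+    sm = TrainStateMachine(initial_state=TrainState.WAITING)
+    sm.set_transition_signals({'malfunction_onset': False, 'earliest_departure_reached': True})
+    sm.step()
+    assert sm.state == TrainState.READY_TO_DEPART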
diff --git a/flatland/envs/step_utils/states.py b/flatland/envs/step_utils/states.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c991b667fdfcc90086492fe80d43ff4d45ddce1
--- /dev/null
+++ b/flatland/envs/step_utils/states.py
@@ -0,0 +1,21 @@
+from enum import IntEnum
+
+class TrainState(IntEnum):
+    WAITING = 0
+    READY_TO_DEPART = 1
+    MALFUNCTION_OFF_MAP = 2
+    MOVING = 3
+    STOPPED = 4
+    MALFUNCTION = 5
+    DONE = 6
+
+    @classmethod
+    def check_valid_state(cls, state):
+        return state in cls._value2member_map_
+
+    @classmethod
+    def is_malfunction_state(cls, state):
+        return state in (cls.MALFUNCTION_OFF_MAP, cls.MALFUNCTION)
+
+
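+if __name__ == "__main__":
+    # Illustrative sketch only: quick checks of the helpers above.
+    assert TrainState.check_valid_state(TrainState.MOVING)
+    assert not TrainState.check_valid_state(99)
+    assert TrainState.is_malfunction_state(TrainState.MALFUNCTION_OFF_MAP)
+    assert not TrainState.is_malfunction_state(TrainState.DONE)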
diff --git a/flatland/envs/step_utils/transition_utils.py b/flatland/envs/step_utils/transition_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d58d21e2b023087bd432d0df40703f33e69e797
--- /dev/null
+++ b/flatland/envs/step_utils/transition_utils.py
@@ -0,0 +1,101 @@
+from typing import Tuple
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.rail_env_action import RailEnvActions
+
+
+def check_action(action, position, direction, rail):
+    """
+
+    Parameters
+    ----------
+    agent : EnvAgent
+    action : RailEnvActions
+
+    Returns
+    -------
+    Tuple[Grid4TransitionsEnum,Tuple[int,int]]
+
+
+
+    """
+    transition_valid = None
+    possible_transitions = rail.get_transitions(*position, direction)
+    num_transitions = fast_count_nonzero(possible_transitions)
+
+    new_direction = direction
+    if action == RailEnvActions.MOVE_LEFT:
+        new_direction = direction - 1
+        if num_transitions <= 1:
+            transition_valid = False
+
+    elif action == RailEnvActions.MOVE_RIGHT:
+        new_direction = direction + 1
+        if num_transitions <= 1:
+            transition_valid = False
+
+    new_direction %= 4  # directions are Grid4 values (0..3), so turning left/right must wrap around
+
+    if action == RailEnvActions.MOVE_FORWARD and num_transitions == 1:
+        # - dead-end, straight line or curved line;
+        # new_direction will be the only valid transition
+        # - take only available transition
+        new_direction = fast_argmax(possible_transitions)
+        transition_valid = True
+    return new_direction, transition_valid
+
+
+def check_action_on_agent(action, rail, position, direction):
+    """
+    Parameters
+    ----------
+    action : RailEnvActions
+    agent : EnvAgent
+
+    Returns
+    -------
+    bool
+        Is it a legal move?
+        1) transition allows the new_direction in the cell,
+        2) the new cell is not empty (case 0),
+        3) the cell is free, i.e., no agent is currently in that cell
+
+
+    """
+    # compute number of possible transitions in the current
+    # cell used to check for invalid actions
+    new_direction, transition_valid = check_action(action, position, direction, rail)
+    new_position = get_new_position(position, new_direction)
+
+    cell_inside_grid = check_bounds(new_position, rail.height, rail.width)
+    cell_not_empty = rail.get_full_transitions(*new_position) > 0
+    new_cell_valid = cell_inside_grid and cell_not_empty
+
+    # If transition validity hasn't been checked yet.
+    if transition_valid is None:
+        transition_valid = rail.get_transition( # TODO: Dipam - Read this one
+            (*position, direction),
+            new_direction)
+
+    return new_cell_valid, new_direction, new_position, transition_valid
+
+
+def check_valid_action(action, rail, position, direction):
+    new_cell_valid, _, _, transition_valid = check_action_on_agent(action, rail, position, direction)
+    action_is_valid = new_cell_valid and transition_valid
+    return action_is_valid
+
+def fast_argmax(possible_transitions: Tuple[int, int, int, int]) -> int:
+    if possible_transitions[0] == 1:
+        return 0
+    if possible_transitions[1] == 1:
+        return 1
+    if possible_transitions[2] == 1:
+        return 2
+    return 3
+
+def fast_count_nonzero(possible_transitions: Tuple[int, int, int, int]):
+    return possible_transitions[0] + possible_transitions[1] + possible_transitions[2] + possible_transitions[3]
+
+def check_bounds(position, height, width):
+    return position[0] >= 0 and position[1] >= 0 and position[0] < height and position[1] < width
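+
+
+if __name__ == "__main__":
+    # Illustrative sketch only: sanity checks for the pure helpers above
+    # (no rail grid is needed for these).
+    assert fast_count_nonzero((1, 0, 1, 0)) == 2
+    assert fast_argmax((0, 0, 1, 0)) == 2
+    assert check_bounds((0, 0), height=5, width=5)
+    assert not check_bounds((5, 0), height=5, width=5)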
diff --git a/tests/test_env_step_utils.py b/tests/test_env_step_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..739d3d06e7271d2ce54ded07de54957d96c08022
--- /dev/null
+++ b/tests/test_env_step_utils.py
@@ -0,0 +1,61 @@
+import numpy as np
+import os
+
+from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters, ParamMalfunctionGen
+
+from flatland.envs.observations import GlobalObsForRailEnv
+# First of all we import the Flatland rail environment
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_env import RailEnvActions
+from flatland.envs.rail_generators import sparse_rail_generator
+#from flatland.envs.sparse_rail_gen import SparseRailGen
+from flatland.envs.schedule_generators import sparse_schedule_generator
+
+
+def get_small_two_agent_env():
+    """Generates a simple 2 city 2 train env returns it after reset"""
+    width = 30  # Width of map
+    height = 15  # Height of map
+    nr_trains = 2  # Number of trains that have an assigned task in the env
+    cities_in_map = 2 # Number of cities where agents can start or end
+    seed = 42  # Random seed
+    grid_distribution_of_cities = False  # Type of city distribution, if False cities are randomly placed
+    max_rails_between_cities = 2  # Max number of tracks allowed between cities. This is number of entry point to a city
+    max_rail_in_cities = 6  # Max number of parallel tracks within a city, representing a realistic train station
+
+    rail_generator = sparse_rail_generator(max_num_cities=cities_in_map,
+                                        seed=seed,
+                                        grid_mode=grid_distribution_of_cities,
+                                        max_rails_between_cities=max_rails_between_cities,
+                                        max_rail_pairs_in_city=max_rail_in_cities//2,
+                                        )
+    speed_ration_map = {1.: 0.25,  # Fast passenger train
+                    1. / 2.: 0.25,  # Fast freight train
+                    1. / 3.: 0.25,  # Slow commuter train
+                    1. / 4.: 0.25}  # Slow freight train
+
+    schedule_generator = sparse_schedule_generator(speed_ration_map)
+
+
+    stochastic_data = MalfunctionParameters(malfunction_rate=1/10000,  # Rate of malfunction occurrence
+                                        min_duration=15,  # Minimal duration of malfunction
+                                        max_duration=50  # Max duration of malfunction
+                                        )
+
+    observation_builder = GlobalObsForRailEnv()
+
+    env = RailEnv(width=width,
+                height=height,
+                rail_generator=rail_generator,
+                schedule_generator=schedule_generator,
+                number_of_agents=nr_trains,
+                obs_builder_object=observation_builder,
+                #malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
+                malfunction_generator=ParamMalfunctionGen(stochastic_data),
+                remove_agents_at_target=True,
+                random_seed=seed)
+
+    env.reset()
+
+    return env
\ No newline at end of file
diff --git a/tests/test_state_machine.py b/tests/test_state_machine.py
new file mode 100644
index 0000000000000000000000000000000000000000..266a8f86589b6033ea67523cab0b31b72ac9d32d
--- /dev/null
+++ b/tests/test_state_machine.py
@@ -0,0 +1,115 @@
+from test_env_step_utils import get_small_two_agent_env
+from flatland.envs.rail_env_action import RailEnvActions
+from flatland.envs.step_utils.states import TrainState
+from flatland.envs.malfunction_generators import Malfunction
+
+class NoMalfunctionGenerator:
+    def generate(self, np_random):
+        return Malfunction(0)
+
+class AlwaysThreeStepMalfunction:
+    def generate(self, np_random):
+        return Malfunction(3)
+
+def test_waiting_no_transition():
+    env = get_small_two_agent_env()
+    env.malfunction_generator = NoMalfunctionGenerator()
+    i_agent = 0
+    ed = env.agents[i_agent].earliest_departure
+    for _ in range(ed-1):
+        env.step({i_agent: RailEnvActions.MOVE_FORWARD})
+        assert env.agents[i_agent].state == TrainState.WAITING
+
+
+def test_waiting_to_ready_to_depart():
+    env = get_small_two_agent_env()
+    env.malfunction_generator = NoMalfunctionGenerator()
+    i_agent = 0
+    ed = env.agents[i_agent].earliest_departure
+    for _ in range(ed):
+        env.step({i_agent: RailEnvActions.DO_NOTHING})
+    assert env.agents[i_agent].state == TrainState.READY_TO_DEPART
+
+
+def test_ready_to_depart_to_moving():
+    env = get_small_two_agent_env()
+    env.malfunction_generator = NoMalfunctionGenerator()
+    i_agent = 0
+    ed = env.agents[i_agent].earliest_departure
+    for _ in range(ed):
+        env.step({i_agent: RailEnvActions.DO_NOTHING})
+
+    env.step({i_agent: RailEnvActions.MOVE_FORWARD})
+    assert env.agents[i_agent].state == TrainState.MOVING
+
+def test_moving_to_stopped():
+    env = get_small_two_agent_env()
+    env.malfunction_generator = NoMalfunctionGenerator()
+    i_agent = 0
+    ed = env.agents[i_agent].earliest_departure
+    for _ in range(ed):
+        env.step({i_agent: RailEnvActions.DO_NOTHING})
+
+    env.step({i_agent: RailEnvActions.MOVE_FORWARD})
+    env.step({i_agent: RailEnvActions.STOP_MOVING})
+    assert env.agents[i_agent].state == TrainState.STOPPED
+
+def test_stopped_to_moving():
+    env = get_small_two_agent_env()
+    env.malfunction_generator = NoMalfunctionGenerator()
+    i_agent = 0
+    ed = env.agents[i_agent].earliest_departure
+    for _ in range(ed):
+        env.step({i_agent: RailEnvActions.DO_NOTHING})
+
+    env.step({i_agent: RailEnvActions.MOVE_FORWARD})
+    env.step({i_agent: RailEnvActions.STOP_MOVING})
+    env.step({i_agent: RailEnvActions.MOVE_FORWARD})
+    assert env.agents[i_agent].state == TrainState.MOVING
+
+def test_moving_to_done():
+    env = get_small_two_agent_env()
+    env.malfunction_generator = NoMalfunctionGenerator()
+    i_agent = 1
+    ed = env.agents[i_agent].earliest_departure
+    for _ in range(ed):
+        env.step({i_agent: RailEnvActions.DO_NOTHING})
+
+    for _ in range(50):
+        env.step({i_agent: RailEnvActions.MOVE_FORWARD})
+    assert env.agents[i_agent].state == TrainState.DONE
+
+def test_waiting_to_malfunction_off_map():
+    env = get_small_two_agent_env()
+    env.malfunction_generator = AlwaysThreeStepMalfunction()
+    i_agent = 1
+    env.step({i_agent: RailEnvActions.DO_NOTHING})
+    assert env.agents[i_agent].state == TrainState.MALFUNCTION_OFF_MAP
+
+
+def test_ready_to_depart_to_malfunction_off_map():
+    env = get_small_two_agent_env()
+    env.malfunction_generator = NoMalfunctionGenerator()
+    i_agent = 1
+    env.step({i_agent: RailEnvActions.DO_NOTHING})
+    ed = env.agents[i_agent].earliest_departure
+    for _ in range(ed):
+        env.step({i_agent: RailEnvActions.DO_NOTHING}) # This should get into ready to depart
+        
+    env.malfunction_generator = AlwaysThreeStepMalfunction()
+    env.step({i_agent: RailEnvActions.DO_NOTHING})
+    assert env.agents[i_agent].state == TrainState.MALFUNCTION_OFF_MAP
+
+
+def test_malfunction_off_map_to_waiting():
+    env = get_small_two_agent_env()
+    env.malfunction_generator = AlwaysThreeStepMalfunction()
+    i_agent = 1
+    env.step({i_agent: RailEnvActions.DO_NOTHING})  # Malfunction onset while the agent is still WAITING
+    assert env.agents[i_agent].state == TrainState.MALFUNCTION_OFF_MAP
+
+    # Once the malfunction counter has run out (and the earliest departure is not yet
+    # reached), the agent should fall back to WAITING
+    env.malfunction_generator = NoMalfunctionGenerator()
+    for _ in range(3):
+        env.step({i_agent: RailEnvActions.DO_NOTHING})
+    assert env.agents[i_agent].state == TrainState.WAITING
\ No newline at end of file