diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 2f19bddd8aeda0e8fe8bdc886fcf21db646a22c7..60860d65dddf84f42ee6225f67ff585213a7171f 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -2,25 +2,22 @@
 Definition of the RailEnv environment.
 """
 import random
-# TODO:  _ this is a global method --> utils or remove later
-from typing import List, NamedTuple, Optional, Dict, Tuple
 
-import numpy as np
-from numpy.lib.shape_base import vsplit
-from numpy.testing._private.utils import import_nose
+from typing import List, Optional, Dict, Tuple
 
+import numpy as np
+from gym.utils import seeding
+from dataclasses import dataclass
 
 from flatland.core.env import Environment
 from flatland.core.env_observation_builder import ObservationBuilder
-from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
+from flatland.core.grid.grid4 import Grid4Transitions
 from flatland.core.grid.grid4_utils import get_new_position
-from flatland.core.grid.grid_utils import IntVector2D, position_to_coordinate
 from flatland.core.transition_map import GridTransitionMap
-from flatland.envs.agent_utils import Agent, EnvAgent
+from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.distance_map import DistanceMap
 from flatland.envs.rail_env_action import RailEnvActions
 
-# Need to use circular imports for persistence.
 from flatland.envs import malfunction_generators as mal_gen
 from flatland.envs import rail_generators as rail_gen
 from flatland.envs import line_generators as line_gen
@@ -29,31 +26,11 @@ from flatland.envs import persistence
 from flatland.envs import agent_chains as ac
 
 from flatland.envs.observations import GlobalObsForRailEnv
-from gym.utils import seeding
-
-# Direct import of objects / classes does not work with circular imports.
-# from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionProcessData
-# from flatland.envs.observations import GlobalObsForRailEnv
-# from flatland.envs.rail_generators import random_rail_generator, RailGenerator
-# from flatland.envs.line_generators import random_line_generator, LineGenerator
-
 
 from flatland.envs.timetable_generators import timetable_generator
-from flatland.envs.step_utils.states import TrainState
-from flatland.envs.step_utils.transition_utils import check_action
-
-# Env Step Facelift imports
-from flatland.envs.step_utils.action_preprocessing import preprocess_raw_action, preprocess_moving_action, preprocess_action_when_waiting
-
-# Adrian Egli performance fix (the fast methods brings more than 50%)
-def fast_isclose(a, b, rtol):
-    return (a < (b + rtol)) or (a < (b - rtol))
-
-def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
-    if pos_1 is None: # TODO: Dipam - Consider making default of agent.position as (-1, -1) instead of None
-        return False
-    else:
-        return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]
+from flatland.envs.step_utils.states import TrainState, StateTransitionSignals
+from flatland.envs.step_utils import transition_utils
+from flatland.envs.step_utils import action_preprocessing
 
 class RailEnv(Environment):
     """
@@ -406,22 +383,35 @@ class RailEnv(Environment):
     
     def apply_action_independent(self, action, rail, position, direction):
         if action.is_moving_action():
-            new_direction, _ = check_action(action, position, direction, rail)
+            new_direction, _ = transition_utils.check_action(action, position, direction, rail)
             new_position = get_new_position(position, new_direction)
         else:
             new_position, new_direction = position, direction
         return new_position, direction
     
     def generate_state_transition_signals(self, agent, preprocessed_action, movement_allowed):
-        st_signals = {}
+        st_signals = StateTransitionSignals()
         
-        st_signals['malfunction_onset'] = agent.malfunction_handler.in_malfunction
-        st_signals['malfunction_counter_complete'] = agent.malfunction_handler.malfunction_counter_complete
-        st_signals['earliest_departure_reached'] = self._elapsed_steps >= agent.earliest_departure
-        st_signals['stop_action_given'] = (preprocessed_action == RailEnvActions.STOP_MOVING)
-        st_signals['valid_movement_action_given'] = preprocessed_action.is_moving_action()
-        st_signals['target_reached'] = fast_position_equal(agent.position, agent.target)
-        st_signals['movement_conflict'] = (not movement_allowed) and agent.speed_counter.is_cell_exit # TODO: Modify motion check to provide proper conflict information
+        # Malfunction onset - Malfunction starts
+        st_signals.malfunction_onset = agent.malfunction_handler.in_malfunction
+
+        # Malfunction counter complete - Malfunction ends next timestep
+        st_signals.malfunction_counter_complete = agent.malfunction_handler.malfunction_counter_complete
+
+        # Earliest departure reached - Train is allowed to move now
+        st_signals.earliest_departure_reached = self._elapsed_steps >= agent.earliest_departure
+
+        # Stop Action Given
+        st_signals.stop_action_given = (preprocessed_action == RailEnvActions.STOP_MOVING)
+
+        # Valid Movement action Given
+        st_signals.valid_movement_action_given = preprocessed_action.is_moving_action()
+
+        # Target Reached
+        st_signals.target_reached = fast_position_equal(agent.position, agent.target)
+
+        # Movement conflict - Multiple trains trying to move into same cell
+        st_signals.movement_conflict = (not movement_allowed) and agent.speed_counter.is_cell_exit # TODO: Modify motion check to provide proper conflict information
 
         return st_signals
 
@@ -489,7 +479,7 @@ class RailEnv(Environment):
 
         self.motionCheck = ac.MotionCheck()  # reset the motion check
 
-        temp_saved_data = {} # TODO : Change name
+        temp_transition_data = {}
         
         for i_agent, agent in enumerate(self.agents): # TODO: Important - Do not use i_agent like this, use agent.handle if needed
             # Generate malfunction
@@ -500,15 +490,15 @@ class RailEnv(Environment):
             # TODO: Add the bottom stuff to separate function(s)
 
             # Preprocess action
-            action = preprocess_raw_action(action, agent.state)
-            action = preprocess_action_when_waiting(action, agent.state)
+            action = action_preprocessing.preprocess_raw_action(action, agent.state)
+            action = action_preprocessing.preprocess_action_when_waiting(action, agent.state)
 
             # Try moving actions on current position
             current_position, current_direction = agent.position, agent.direction
             agent_not_on_map = current_position is None
             if agent_not_on_map: # Agent not added on map yet
                 current_position, current_direction = agent.initial_position, agent.initial_direction
-            action = preprocess_moving_action(action, self.rail, current_position, current_direction)
+            action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction)
 
             # Save moving actions in not already saved
             agent.action_saver.save_action_if_allowed(action, agent.state)
@@ -516,24 +506,25 @@ class RailEnv(Environment):
             # Calculate new position
             # Add agent to the map if not on it yet
             if agent_not_on_map and agent.action_saver.is_action_saved:
-                temp_new_position = agent.initial_position
-                temp_new_direction = agent.initial_direction
+                new_position = agent.initial_position
+                new_direction = agent.initial_direction
                 preprocessed_action = action
                 
             # When cell exit occurs apply saved action independent of other agents
             elif agent.speed_counter.is_cell_exit and agent.action_saver.is_action_saved:
                 saved_action = agent.action_saver.saved_action
                 # Apply action independent of other agents and get temporary new position and direction
-                temp_pd = self.apply_action_independent(saved_action, self.rail, agent.position, agent.direction)
-                temp_new_position, temp_new_direction = temp_pd
+                pd = self.apply_action_independent(saved_action, self.rail, agent.position, agent.direction)
+                new_position, new_direction = pd
                 preprocessed_action = saved_action
             else:
-                temp_new_position, temp_new_direction = agent.position, agent.direction
+                new_position, new_direction = agent.position, agent.direction
                 preprocessed_action = action
 
-            # TODO: Saving temporary positon shouldn't be needed if recheck of position is not needed later (see TAG#1)
-            temp_saved_data[i_agent] = temp_new_position, temp_new_direction, preprocessed_action
-            self.motionCheck.addAgent(i_agent, agent.position, temp_new_position)
+            temp_transition_data[i_agent] = AgentTransitionData(position=new_position,
+                                                                direction=new_direction,
+                                                                preprocessed_action=preprocessed_action)
+            self.motionCheck.addAgent(i_agent, agent.position, new_position)
 
         # Find conflicts
         # TODO : Important - Modify conflicted positions and select one of them randomly to go to new position
@@ -541,23 +532,19 @@ class RailEnv(Environment):
         
         for agent in self.agents:
             i_agent = agent.handle
+            agent_transition_data = temp_transition_data[i_agent]
 
             ## Update positions
-            movement_allowed, _ = self.motionCheck.check_motion(i_agent, agent.position) # TODO: Remove final_new_postion from motioncheck
-            # TODO : Important : Original code rechecks the next position here again - not sure why? TAG#1
-            preprocessed_action = temp_saved_data[i_agent][2] # TODO : Important : Make this namedtuple or class
-
-            # TODO : Looks like a hacky conditionm, improve the handling
             if agent.malfunction_handler.in_malfunction:
                 movement_allowed = False
+            else:
+                movement_allowed, _ = self.motionCheck.check_motion(i_agent, agent.position) # TODO: Remove final_new_postion from motioncheck
 
             if movement_allowed:
-                final_new_position, final_new_direction = temp_saved_data[i_agent][:2] # TODO : Important : Make this namedtuple or class
-            else:
-                final_new_position = agent.position
-                final_new_direction = agent.direction
-            agent.position = final_new_position
-            agent.direction = final_new_direction
+                agent.position = agent_transition_data.position
+                agent.direction = agent_transition_data.direction
+            
+            preprocessed_action = agent_transition_data.preprocessed_action
 
             ## Update states
             state_transition_signals = self.generate_state_transition_signals(agent, preprocessed_action, movement_allowed)
@@ -565,8 +552,8 @@ class RailEnv(Environment):
             agent.state_machine.step()
             agent.state = agent.state_machine.state # TODO : Make this a property instead?
 
-            # TODO : Important : Looks like a hacky condiition, improve the handling
-            if agent.state == TrainState.DONE:
+            # Remove agent from the map if required (remove_agents_at_target)
+            if self.remove_agents_at_target and agent.state == TrainState.DONE:
                 agent.position = None
 
             ## Update rewards
@@ -661,3 +648,21 @@ class RailEnv(Environment):
     def save(self, filename):
         print("deprecated call to env.save() - pls call RailEnvPersister.save()")
         persistence.RailEnvPersister.save(self, filename)
+
+@dataclass(repr=True)
+class AgentTransitionData:
+    """ Class for keeping track of temporary agent data for position update """
+    position : Tuple[int, int]
+    direction : Grid4Transitions
+    preprocessed_action : RailEnvActions
+
+
+# Adrian Egli performance fix (the fast methods bring more than 50%)
+def fast_isclose(a, b, rtol):
+    return (a < (b + rtol)) or (a < (b - rtol))
+
+def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
+    if pos_1 is None: # TODO: Dipam - Consider making default of agent.position as (-1, -1) instead of None
+        return False
+    else:
+        return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]
diff --git a/flatland/envs/rail_env_action.py b/flatland/envs/rail_env_action.py
index f583eb72568f375e3a066cbf355771bb9b60e038..8665897f949294a9a1bf50fdc624de7907eca714 100644
--- a/flatland/envs/rail_env_action.py
+++ b/flatland/envs/rail_env_action.py
@@ -20,7 +20,7 @@ class RailEnvActions(IntEnum):
         }[a]
 
     @classmethod
-    def check_valid_action(cls, action):
+    def is_action_valid(cls, action):
         return action in cls._value2member_map_
 
     def is_moving_action(self):
diff --git a/flatland/envs/step_utils/action_preprocessing.py b/flatland/envs/step_utils/action_preprocessing.py
index 98c42b15961b515e8c07eb89b521453a581e82f2..4da43c1695607785d3e779f1fe119064545fd575 100644
--- a/flatland/envs/step_utils/action_preprocessing.py
+++ b/flatland/envs/step_utils/action_preprocessing.py
@@ -5,7 +5,7 @@ from flatland.envs.step_utils.transition_utils import check_valid_action
 
 
 def process_illegal_action(action: RailEnvActions):
-	if not RailEnvActions.check_valid_action(action): 
+	if not RailEnvActions.is_action_valid(action): 
 		return RailEnvActions.DO_NOTHING
 	else:
 		return RailEnvActions(action)
diff --git a/flatland/envs/step_utils/state_machine.py b/flatland/envs/step_utils/state_machine.py
index e42a829d2018c3c540ddd0f0e8c249530333abef..6d0b9f406e437c5215e00d983ae9484458ee4455 100644
--- a/flatland/envs/step_utils/state_machine.py
+++ b/flatland/envs/step_utils/state_machine.py
@@ -1,11 +1,11 @@
 from attr import s
-from flatland.envs.step_utils.states import TrainState
+from flatland.envs.step_utils.states import TrainState, StateTransitionSignals
 
 class TrainStateMachine:
     def __init__(self, initial_state=TrainState.WAITING):
         self._initial_state = initial_state
         self._state = initial_state
-        self.st_signals = {} # State Transition Signals # TODO: Make this namedtuple
+        self.st_signals = StateTransitionSignals()
         self.next_state = None
     
     def _handle_waiting(self):
@@ -13,25 +13,25 @@ class TrainStateMachine:
         # TODO: Important - The malfunction handling is not like proper state machine 
         #                   Both transition signals can happen at the same time
         #                   Atleast mention it in the diagram
-        if self.st_signals['malfunction_onset']:  
+        if self.st_signals.malfunction_onset:  
             self.next_state = TrainState.MALFUNCTION_OFF_MAP
-        elif self.st_signals['earliest_departure_reached']:
+        elif self.st_signals.earliest_departure_reached:
             self.next_state = TrainState.READY_TO_DEPART
         else:
             self.next_state = TrainState.WAITING
 
     def _handle_ready_to_depart(self):
         """ Can only go to MOVING if a valid action is provided """
-        if self.st_signals['malfunction_onset']:  
+        if self.st_signals.malfunction_onset:  
             self.next_state = TrainState.MALFUNCTION_OFF_MAP
-        elif self.st_signals['valid_movement_action_given']:
+        elif self.st_signals.valid_movement_action_given:
             self.next_state = TrainState.MOVING
         else:
             self.next_state = TrainState.READY_TO_DEPART
     
     def _handle_malfunction_off_map(self):
-        if self.st_signals['malfunction_counter_complete']:
-            if self.st_signals['earliest_departure_reached']:
+        if self.st_signals.malfunction_counter_complete:
+            if self.st_signals.earliest_departure_reached:
                 self.next_state = TrainState.READY_TO_DEPART
             else:
                 self.next_state = TrainState.STOPPED
@@ -39,29 +39,29 @@ class TrainStateMachine:
             self.next_state = TrainState.WAITING
     
     def _handle_moving(self):
-        if self.st_signals['malfunction_onset']:
+        if self.st_signals.malfunction_onset:
             self.next_state = TrainState.MALFUNCTION
-        elif self.st_signals['target_reached']:
+        elif self.st_signals.target_reached:
             self.next_state = TrainState.DONE
-        elif self.st_signals['stop_action_given'] or self.st_signals['movement_conflict']:
+        elif self.st_signals.stop_action_given or self.st_signals.movement_conflict:
             self.next_state = TrainState.STOPPED
         else:
             self.next_state = TrainState.MOVING
     
     def _handle_stopped(self):
-        if self.st_signals['malfunction_onset']:
+        if self.st_signals.malfunction_onset:
             self.next_state = TrainState.MALFUNCTION
-        elif self.st_signals['valid_movement_action_given']:
+        elif self.st_signals.valid_movement_action_given:
             self.next_state = TrainState.MOVING
         else:
             self.next_state = TrainState.STOPPED
     
     def _handle_malfunction(self):
-        if self.st_signals['malfunction_counter_complete'] and \
-           self.st_signals['valid_movement_action_given']:
+        if self.st_signals.malfunction_counter_complete and \
+           self.st_signals.valid_movement_action_given:
             self.next_state = TrainState.MOVING
-        elif self.st_signals['malfunction_counter_complete'] and \
-             (self.st_signals['stop_action_given'] or self.st_signals['movement_conflict']):
+        elif self.st_signals.malfunction_counter_complete and \
+             (self.st_signals.stop_action_given or self.st_signals.movement_conflict):
              self.next_state = TrainState.STOPPED
         else:
             self.next_state = TrainState.MALFUNCTION
@@ -134,7 +134,7 @@ class TrainStateMachine:
         return self.st_signals
     
     def set_transition_signals(self, state_transition_signals):
-        self.st_signals = state_transition_signals # TODO: Important: Check all keys are present and if not raise error
+        self.st_signals = state_transition_signals
 
 
         
diff --git a/flatland/envs/step_utils/states.py b/flatland/envs/step_utils/states.py
index 4e612ae842fd9da2e206f010a72bca2e0b9f5608..d7cfd2b24a41ccfffc1178b0d0625b5e1d617752 100644
--- a/flatland/envs/step_utils/states.py
+++ b/flatland/envs/step_utils/states.py
@@ -1,5 +1,5 @@
 from enum import IntEnum
-
+from dataclasses import dataclass
 class TrainState(IntEnum):
     WAITING = 0
     READY_TO_DEPART = 1
@@ -22,6 +22,13 @@ class TrainState(IntEnum):
     def is_on_map_state(self):
         return self.value in [self.MOVING, self.STOPPED, self.MALFUNCTION]
 
-    
-
+@dataclass(repr=True)
+class StateTransitionSignals:
+    malfunction_onset : bool = False
+    malfunction_counter_complete : bool = False
+    earliest_departure_reached : bool = False
+    stop_action_given : bool = False
+    valid_movement_action_given : bool = False
+    target_reached : bool = False
+    movement_conflict : bool = False