diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index 9230659e500c0b292e3525eceb9280f77878c1a8..91c6d72faee99df7b2a6242de246f55bd5953f17 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -1,8 +1,6 @@
 from flatland.envs.rail_trainrun_data_structures import Waypoint
 import numpy as np
 
-from enum import IntEnum
-
 from itertools import starmap
 from typing import Tuple, Optional, NamedTuple, List
 
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index a140da06c29916d1f05b01cb2df0bb437f27a5d6..e859173469fa60f81e86c0f6fad74d82905e21ad 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -94,7 +94,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         self.location_has_agent_ready_to_depart = {}
 
         for _agent in self.env.agents:
-            if not TrainState.off_map_state(_agent.state) and \
+            if not _agent.state.is_off_map_state() and \
                 _agent.position:
                 self.location_has_agent[tuple(_agent.position)] = 1
                 self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction
@@ -103,7 +103,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                     'malfunction']
 
             # [NIMISH] WHAT IS THIS
-            if TrainState.off_map_state(_agent.state) and \
+            if _agent.state.is_off_map_state() and \
                 _agent.initial_position:
                     self.location_has_agent_ready_to_depart.setdefault(tuple(_agent.initial_position), 0)
                     self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] += 1
@@ -570,9 +570,9 @@ class GlobalObsForRailEnv(ObservationBuilder):
     def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray):
 
         agent = self.env.agents[handle]
-        if TrainState.off_map_state(agent.state):
+        if agent.state.is_off_map_state():
             agent_virtual_position = agent.initial_position
-        elif TrainState.on_map_state(agent.state):
+        elif agent.state.is_on_map_state():
             agent_virtual_position = agent.position
         elif agent.state == TrainState.DONE:
             agent_virtual_position = agent.target
@@ -608,7 +608,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
                 obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction']
                 obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed']
             # fifth channel: all ready to depart on this position
-            if TrainState.off_map_state(other_agent.state):
+            if other_agent.state.is_off_map_state():
                 obs_agents_state[other_agent.initial_position][4] += 1
         return self.rail_obs, obs_agents_state, obs_targets
 
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 68bb0de2677f64fa8ee2763455c3d5aeef33d203..2f19bddd8aeda0e8fe8bdc886fcf21db646a22c7 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -3,7 +3,6 @@ Definition of the RailEnv environment.
 """
 import random
 # TODO:  _ this is a global method --> utils or remove later
-from enum import IntEnum
 from typing import List, NamedTuple, Optional, Dict, Tuple
 
 import numpy as np
@@ -285,7 +284,7 @@ class RailEnv(Environment):
         False: Agent cannot provide an action
         """
         return agent.state == TrainState.READY_TO_DEPART or \
-               (TrainState.on_map_state(agent.state) and \
+               (agent.state.is_on_map_state() and \
                 fast_isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03) )
 
     def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, *,
@@ -406,7 +405,7 @@ class RailEnv(Environment):
         return observation_dict, info_dict
     
     def apply_action_independent(self, action, rail, position, direction):
-        if RailEnvActions.is_moving_action(action):
+        if action.is_moving_action():
             new_direction, _ = check_action(action, position, direction, rail)
             new_position = get_new_position(position, new_direction)
         else:
@@ -420,7 +419,7 @@ class RailEnv(Environment):
         st_signals['malfunction_counter_complete'] = agent.malfunction_handler.malfunction_counter_complete
         st_signals['earliest_departure_reached'] = self._elapsed_steps >= agent.earliest_departure
         st_signals['stop_action_given'] = (preprocessed_action == RailEnvActions.STOP_MOVING)
-        st_signals['valid_movement_action_given'] = RailEnvActions.is_moving_action(preprocessed_action)
+        st_signals['valid_movement_action_given'] = preprocessed_action.is_moving_action()
         st_signals['target_reached'] = fast_position_equal(agent.position, agent.target)
         st_signals['movement_conflict'] = (not movement_allowed) and agent.speed_counter.is_cell_exit # TODO: Modify motion check to provide proper conflict information
 
@@ -557,10 +556,6 @@ class RailEnv(Environment):
             else:
                 final_new_position = agent.position
                 final_new_direction = agent.direction
-            # if final_new_position and self.rail.grid[final_new_position] == 0:
-                # import pdb; pdb.set_trace()
-            # if final_new_position and not (final_new_position[0] >= 0 and final_new_position[1] >= 0 and final_new_position[0] < self.rail.height and final_new_position[1] < self.rail.width): # TODO: Remove this
-                # import pdb; pdb.set_trace()
             agent.position = final_new_position
             agent.direction = final_new_direction
 
diff --git a/flatland/envs/rail_env_action.py b/flatland/envs/rail_env_action.py
index a25cc8f0f37233f76b921ffc62c83818e8e7bb9b..f583eb72568f375e3a066cbf355771bb9b60e038 100644
--- a/flatland/envs/rail_env_action.py
+++ b/flatland/envs/rail_env_action.py
@@ -19,9 +19,12 @@ class RailEnvActions(IntEnum):
             4: 'S',
         }[a]
 
-    @staticmethod
-    def is_moving_action(action):
-        return action in [1,2,3]
+    @classmethod
+    def check_valid_action(cls, action):
+        return action in cls._value2member_map_
+
+    def is_moving_action(self):
+        return self.value in [self.MOVE_RIGHT, self.MOVE_LEFT, self.MOVE_FORWARD]
 
 
 RailEnvGridPos = NamedTuple('RailEnvGridPos', [('r', int), ('c', int)])
diff --git a/flatland/envs/rail_env_shortest_paths.py b/flatland/envs/rail_env_shortest_paths.py
index 7d0632823e46432fb9d13be4af9af06d3bf6f873..e844390f7d4927476525da45196db28893145f7a 100644
--- a/flatland/envs/rail_env_shortest_paths.py
+++ b/flatland/envs/rail_env_shortest_paths.py
@@ -227,9 +227,9 @@ def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = Non
     shortest_paths = dict()
 
     def _shortest_path_for_agent(agent):
-        if TrainState.off_map_state(agent.state):
+        if agent.state.is_off_map_state():
             position = agent.initial_position
-        elif TrainState.on_map_state(agent.state):
+        elif agent.state.is_on_map_state():
             position = agent.position
         elif agent.state == TrainState.DONE:
             position = agent.target
diff --git a/flatland/envs/step_utils/action_preprocessing.py b/flatland/envs/step_utils/action_preprocessing.py
index c777342d6d9cf20167d7f2a9df10a9a2a3a7922a..98c42b15961b515e8c07eb89b521453a581e82f2 100644
--- a/flatland/envs/step_utils/action_preprocessing.py
+++ b/flatland/envs/step_utils/action_preprocessing.py
@@ -5,11 +5,10 @@ from flatland.envs.step_utils.transition_utils import check_valid_action
 
 
 def process_illegal_action(action: RailEnvActions):
-	# TODO - Dipam : This check is kind of weird, change this
-	if action is None or action not in RailEnvActions._value2member_map_: 
+	if not RailEnvActions.check_valid_action(action): 
 		return RailEnvActions.DO_NOTHING
 	else:
-		return action
+		return RailEnvActions(action)
 
 
 def process_do_nothing(state: TrainState):
diff --git a/flatland/envs/step_utils/action_saver.py b/flatland/envs/step_utils/action_saver.py
index 56f7465af77de4a88ce6d010593bca92c8280759..a34778ed48a90a8dede28ba83181877c247deb96 100644
--- a/flatland/envs/step_utils/action_saver.py
+++ b/flatland/envs/step_utils/action_saver.py
@@ -15,8 +15,8 @@ class ActionSaver:
 
     def save_action_if_allowed(self, action, state):
         if not self.is_action_saved and \
-            RailEnvActions.is_moving_action(action) and \
-            not TrainState.is_malfunction_state(state):
+               action.is_moving_action() and \
+               not state.is_malfunction_state():
             self.saved_action = action
 
     def clear_saved_action(self):
diff --git a/flatland/envs/step_utils/states.py b/flatland/envs/step_utils/states.py
index b8040939419e27987a06636c205e2f2cfef45166..4e612ae842fd9da2e206f010a72bca2e0b9f5608 100644
--- a/flatland/envs/step_utils/states.py
+++ b/flatland/envs/step_utils/states.py
@@ -13,17 +13,14 @@ class TrainState(IntEnum):
     def check_valid_state(cls, state):
         return state in cls._value2member_map_
 
-    @staticmethod
-    def is_malfunction_state(state):
-        return state in [2, 5] # TODO: Can this be done with names instead?
+    def is_malfunction_state(self):
+        return self.value in [self.MALFUNCTION, self.MALFUNCTION_OFF_MAP]
 
-    @staticmethod
-    def off_map_state(state):
-        return state in [0, 1, 2]
+    def is_off_map_state(self):
+        return self.value in [self.WAITING, self.READY_TO_DEPART, self.MALFUNCTION_OFF_MAP]
     
-    @staticmethod    
-    def on_map_state(state):
-        return state in [3, 4, 5]
+    def is_on_map_state(self):
+        return self.value in [self.MOVING, self.STOPPED, self.MALFUNCTION]
 
     
 
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 9499c5c4542bcb7c16231c2fd88a8b445ff3bce0..cd765cd19ba0c9510d301ac77a0782bccd6bd6b4 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -743,7 +743,7 @@ class RenderLocal(RenderBase):
                     if show_inactive_agents:
                         show_this_agent = True
                     else:
-                        show_this_agent = TrainState.on_map_state(agent.state)
+                        show_this_agent = agent.state.is_on_map_state()
 
                     if show_this_agent:
                         self.gl.set_agent_at(agent_idx, *position, agent.direction, direction,