diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 0a642a4c5b6063a259dee4b8adda3f6948234e6b..6859497f129022877d2174b95a51da076ff3d725 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -372,7 +372,7 @@ class RailEnv(Environment):
         """ Generate State Transitions Signals used in the state machine """
         st_signals = StateTransitionSignals()
         
-        # Malfunction onset - Malfunction starts
+        # Malfunction starts when in_malfunction is set to true
         st_signals.malfunction_onset = agent.malfunction_handler.in_malfunction
 
         # Malfunction counter complete - Malfunction ends next timestep
@@ -563,7 +563,8 @@ class RailEnv(Environment):
             agent.state_machine.step()
 
             if agent.state.is_on_map_state() and agent.position is None:
-                import pdb; pdb.set_trace()
+                raise ValueError("Agent ID {} Agent State {} not matching with Agent Position {} ".format(
+                                    agent.handle, str(agent.state), str(agent.position) ))
 
             # Handle done state actions, optionally remove agents
             self.handle_done_state(agent)
diff --git a/flatland/envs/step_utils/malfunction_handler.py b/flatland/envs/step_utils/malfunction_handler.py
index a45aa02499983ed8d6444ef4cab11dbd8650f1ed..2ba72643f7c28729b99facf702528f6ffd174add 100644
--- a/flatland/envs/step_utils/malfunction_handler.py
+++ b/flatland/envs/step_utils/malfunction_handler.py
@@ -30,7 +30,9 @@ class MalfunctionHandler:
     def _set_malfunction_down_counter(self, val):
         if val < 0:
             raise ValueError("Cannot set a negative value to malfunction down counter")
-        self._malfunction_down_counter = val
+        # Only set new malfunction value if old malfunction is completed
+        if self._malfunction_down_counter == 0:
+            self._malfunction_down_counter = val
 
     def generate_malfunction(self, malfunction_generator, np_random):
         num_broken_steps = get_number_of_steps_to_break(malfunction_generator, np_random)
@@ -40,6 +42,10 @@ class MalfunctionHandler:
         if self._malfunction_down_counter > 0:
             self._malfunction_down_counter -= 1
 
+    def __repr__(self):
+        return f"malfunction_down_counter: {self._malfunction_down_counter} \
+                in_malfunction: {self.in_malfunction}"
+
     def to_dict(self):
         return {"malfunction_down_counter": self._malfunction_down_counter}
     
diff --git a/flatland/envs/step_utils/state_machine.py b/flatland/envs/step_utils/state_machine.py
index d1938f4f5157497c75969b01ea15a072d6a3a31c..78c9883fb85b3a4f967356515cddfc95235346f9 100644
--- a/flatland/envs/step_utils/state_machine.py
+++ b/flatland/envs/step_utils/state_machine.py
@@ -13,7 +13,7 @@ class TrainStateMachine:
         # TODO: Important - The malfunction handling is not like proper state machine 
         #                   Both transition signals can happen at the same time
         #                   Atleast mention it in the diagram
-        if self.st_signals.malfunction_onset:  
+        if self.st_signals.in_malfunction:  
             self.next_state = TrainState.MALFUNCTION_OFF_MAP
         elif self.st_signals.earliest_departure_reached:
             self.next_state = TrainState.READY_TO_DEPART
@@ -22,7 +22,7 @@ class TrainStateMachine:
 
     def _handle_ready_to_depart(self):
         """ Can only go to MOVING if a valid action is provided """
-        if self.st_signals.malfunction_onset:  
+        if self.st_signals.in_malfunction:  
             self.next_state = TrainState.MALFUNCTION_OFF_MAP
         elif self.st_signals.valid_movement_action_given:
             self.next_state = TrainState.MOVING
@@ -39,7 +39,7 @@ class TrainStateMachine:
             self.next_state = TrainState.WAITING
     
     def _handle_moving(self):
-        if self.st_signals.malfunction_onset:
+        if self.st_signals.in_malfunction:
             self.next_state = TrainState.MALFUNCTION
         elif self.st_signals.target_reached:
             self.next_state = TrainState.DONE
@@ -49,7 +49,7 @@ class TrainStateMachine:
             self.next_state = TrainState.MOVING
     
     def _handle_stopped(self):
-        if self.st_signals.malfunction_onset:
+        if self.st_signals.in_malfunction:
             self.next_state = TrainState.MALFUNCTION
         elif self.st_signals.valid_movement_action_given:
             self.next_state = TrainState.MOVING
diff --git a/flatland/envs/step_utils/states.py b/flatland/envs/step_utils/states.py
index d5d85b1bf30c52b770a337a183af52cb7333989b..806113e524112e7aa0a0704ddffce1b8d2db5ffa 100644
--- a/flatland/envs/step_utils/states.py
+++ b/flatland/envs/step_utils/states.py
@@ -27,7 +27,7 @@ class TrainState(IntEnum):
 
 @dataclass(repr=True)
 class StateTransitionSignals:
-    malfunction_onset : bool = False
+    in_malfunction : bool = False
     malfunction_counter_complete : bool = False
     earliest_departure_reached : bool = False
     stop_action_given : bool = False
diff --git a/requirements_dev.txt b/requirements_dev.txt
index 93414562b79e3c0d5e1a77e42b967dc0ea4028fe..51473c19d41ddfbc14507c643758aece381e62e2 100644
--- a/requirements_dev.txt
+++ b/requirements_dev.txt
@@ -23,3 +23,4 @@ networkx
 ipycanvas
 graphviz
 imageio
+dataclasses
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 56b4befc9b6d594750dbf87f5db7d8f8a03aa24a..3469c9c82ed3f66ade78f2f335850a9c4809ddc6 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -107,11 +107,9 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
         for a, test_config in enumerate(test_configs):
             agent: EnvAgent = env.agents[a]
             replay = test_config.replay[step]
-
             print(agent.position, replay.position, agent.state, agent.speed_counter)
-            # import pdb; pdb.set_trace()
-            # _assert(a, agent.position, replay.position, 'position')
-            # _assert(a, agent.direction, replay.direction, 'direction')
+            _assert(a, agent.position, replay.position, 'position')
+            _assert(a, agent.direction, replay.direction, 'direction')
             if replay.state is not None:
                 _assert(a, agent.state, replay.state, 'state')
 
@@ -129,10 +127,8 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
                 # As we force malfunctions on the agents we have to set a positive rate that the env
                 # recognizes the agent as potentially malfuncitoning
                 # We also set next malfunction to infitiy to avoid interference with our tests
-                agent.malfunction_data['malfunction'] = replay.set_malfunction
-                agent.malfunction_data['moving_before_malfunction'] = agent.moving
-                agent.malfunction_data['fixed'] = False
-            # _assert(a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
+                env.agents[a].malfunction_handler._set_malfunction_down_counter(replay.set_malfunction)
+            _assert(a, agent.malfunction_handler.malfunction_down_counter, replay.malfunction, 'malfunction')
         print(step)
         _, rewards_dict, _, info_dict = env.step(action_dict)
         if rendering:
@@ -143,8 +139,6 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
 
             if not skip_reward_check:
                 _assert(a, rewards_dict[a], replay.reward, 'reward')
-    assert False
-
 
 def create_and_save_env(file_name: str, line_generator: LineGenerator, rail_generator: RailGenerator):
     stochastic_data = MalfunctionParameters(malfunction_rate=1000,  # Rate of malfunction occurence