diff --git a/flatland/envs/malfunction_generators.py b/flatland/envs/malfunction_generators.py
index 0dfafb36447550bc275dadabc6943b50d07e2f4f..086fd9cef348ff8de6b6c358876926804dc673e5 100644
--- a/flatland/envs/malfunction_generators.py
+++ b/flatland/envs/malfunction_generators.py
@@ -253,7 +253,7 @@ def malfunction_from_params(parameters: MalfunctionParameters) -> Tuple[Malfunct
     min_number_of_steps_broken = parameters.min_duration
     max_number_of_steps_broken = parameters.max_duration
 
-    def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
+    def generator(np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
         """
         Generate malfunctions for agents
         Parameters
@@ -270,11 +270,10 @@ def malfunction_from_params(parameters: MalfunctionParameters) -> Tuple[Malfunct
         if reset:
             return Malfunction(0)
 
-        if agent.malfunction_data['malfunction'] < 1:
-            if np_random.rand() < _malfunction_prob(mean_malfunction_rate):
-                num_broken_steps = np_random.randint(min_number_of_steps_broken,
-                                                     max_number_of_steps_broken + 1) + 1
-                return Malfunction(num_broken_steps)
+        if np_random.rand() < _malfunction_prob(mean_malfunction_rate):
+            num_broken_steps = np_random.randint(min_number_of_steps_broken,
+                                                 max_number_of_steps_broken + 1)
+            return Malfunction(num_broken_steps)
         return Malfunction(0)
 
     return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
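With the `agent` argument gone, the generator above is a pure function of the supplied RNG, so its output distribution can be exercised directly without building an environment. Below is a minimal sketch of that, assuming the usual `MalfunctionParameters` fields (`malfunction_rate`, `min_duration`, `max_duration`) and that the per-draw malfunction probability is roughly `1 - exp(-malfunction_rate)`; the concrete numbers are illustrative, not taken from this patch.

```python
# Minimal sketch: drive the agent-independent generator directly.
# Assumptions: MalfunctionParameters field names and a per-draw malfunction
# probability of about 1 - exp(-malfunction_rate).
from numpy.random import RandomState

from flatland.envs.malfunction_generators import (MalfunctionParameters,
                                                  malfunction_from_params)

params = MalfunctionParameters(malfunction_rate=0.5,  # illustrative values
                               min_duration=10,
                               max_duration=10)
generator, process_data = malfunction_from_params(params)

rng = RandomState(42)
durations = [generator(rng).num_broken_steps for _ in range(10_000)]
# Most draws return Malfunction(0); the broken ones last 10 steps here, so the
# sample mean should approach (1 - exp(-0.5)) * 10, i.e. roughly 3.9.
print(sum(durations) / len(durations))
```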
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index a6a15531bf42a45b17f93fecf82dc5c1d2ef92af..f345648533c7ab554403b7031245139e39e51b75 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -549,7 +549,7 @@ class RailEnv(Environment):
             if agent.malfunction_handler.in_malfunction:
                 movement_allowed = False
             else:
-                movement_allowed = self.motionCheck.check_motion(i_agent, agent.position) # TODO: Remove final_new_postion from motioncheck
+                movement_allowed = self.motionCheck.check_motion(i_agent, agent.position)
 
             # Position can be changed only if other cell is empty
             # And either the speed counter completes or agent is being added to map
diff --git a/flatland/envs/step_utils/action_saver.py b/flatland/envs/step_utils/action_saver.py
index bf61076e1f1136b3eca28bd05a51ba614a9d0eb6..913e9576d923a7e67ff7a498237803df3d9d0a43 100644
--- a/flatland/envs/step_utils/action_saver.py
+++ b/flatland/envs/step_utils/action_saver.py
@@ -17,14 +17,10 @@ class ActionSaver:
         """
         Save the action if all conditions are met
             1. It is a movement based action -> Forward, Left, Right
-            2. Action is not already saved
-            3. Not in a malfunction state 
-            4. Agent is not already done
+            2. Action is not already saved 
+            3. Agent is not already done
         """
-        if action.is_moving_action() and \
-               not self.is_action_saved and \
-               not state.is_malfunction_state() and \
-               not state == TrainState.DONE:
+        if action.is_moving_action() and not self.is_action_saved and state != TrainState.DONE:
             self.saved_action = action
 
     def clear_saved_action(self):
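The save rule in `ActionSaver` is now independent of the malfunction state: a movement action issued while the train is broken is kept and can be applied once the malfunction clears, with malfunction handling left to the rest of the step pipeline. The snippet below restates the new predicate in isolation so its truth table can be spot-checked; `TrainState` here is a simplified stand-in for the flatland enum, not the real class.

```python
# Stand-alone restatement of the new save condition. TrainState is a
# simplified stand-in for the flatland TrainState enum.
from enum import IntEnum


class TrainState(IntEnum):
    MOVING = 0
    MALFUNCTION = 1
    DONE = 2


def should_save(is_moving_action: bool, is_action_saved: bool, state: TrainState) -> bool:
    return is_moving_action and not is_action_saved and state != TrainState.DONE


# A movement action issued during a malfunction is now saved ...
assert should_save(True, False, TrainState.MALFUNCTION)
# ... while a finished agent, or one with a pending action, still is not.
assert not should_save(True, False, TrainState.DONE)
assert not should_save(True, True, TrainState.MOVING)
```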
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index d633351ed3624499aa2e30df9f09031b0b4cf581..b4632f3e37bd5f879349488d8b24a74c4c5d9759 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -190,8 +190,9 @@ def test_malfunction_before_entry():
     # Test initial malfunction values for all agents
     # we want some agents to be malfunctioning already and some to be working
     # we want different next_malfunction values for the agents
-    assert env.agents[0].malfunction_data['malfunction'] == 0
-    assert env.agents[1].malfunction_data['malfunction'] == 10
+    malfunction_values = [env.malfunction_generator(env.np_random).num_broken_steps for _ in range(1000)]
+    expected_value = (1 - np.exp(-0.5)) * 10  # per-draw malfunction probability times the fixed 10-step duration
+    assert np.allclose(np.mean(malfunction_values), expected_value, rtol=0.1), "Mean malfunction duration does not match the malfunction rate"
 
 
 def test_malfunction_values_and_behavior():
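The rewritten check in `test_malfunction_before_entry` is statistical rather than per-agent: it samples the generator 1000 times and compares the sample mean with an analytic expectation. Assuming the test environment uses a fixed 10-step duration and a rate for which the per-draw malfunction probability is `1 - exp(-0.5)` (the parameters themselves are not visible in this diff), the arithmetic behind `expected_value` is:

```python
# Arithmetic behind expected_value, under the assumed test parameters
# (fixed 10-step duration, per-draw malfunction probability 1 - exp(-0.5)).
import numpy as np

p_break = 1 - np.exp(-0.5)       # probability that a single draw produces a malfunction
expected_value = p_break * 10    # 0 steps otherwise, exactly 10 steps when broken
print(round(expected_value, 2))  # ~3.93; rtol=0.1 absorbs sampling noise over 1000 draws
```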
@@ -257,7 +258,7 @@ def test_initial_malfunction():
     set_penalties_for_replay(env)
     replay_config = ReplayConfig(
         replay=[
-            Replay(
+            Replay( # 0
                 position=(3, 2),
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.MOVE_FORWARD,
@@ -265,7 +266,7 @@ def test_initial_malfunction():
                 malfunction=3,
                 reward=env.step_penalty  # full step penalty when malfunctioning
             ),
-            Replay(
+            Replay( # 1
                 position=(3, 2),
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.MOVE_FORWARD,
@@ -274,7 +275,7 @@ def test_initial_malfunction():
             ),
             # malfunction stops in the next step and we're still at the beginning of the cell
             # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
-            Replay(
+            Replay( # 2
                 position=(3, 2),
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.MOVE_FORWARD,
@@ -282,14 +283,14 @@ def test_initial_malfunction():
                 reward=env.step_penalty
 
             ),  # malfunctioning ends: starting and running at speed 1.0
-            Replay(
+            Replay( # 3
                 position=(3, 2),
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.MOVE_FORWARD,
                 malfunction=0,
                 reward=env.start_penalty + env.step_penalty * 1.0  # running at speed 1.0
             ),
-            Replay(
+            Replay( # 4
                 position=(3, 3),
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.MOVE_FORWARD,
@@ -420,7 +421,7 @@ def test_initial_malfunction_do_nothing():
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=2,
                 reward=env.step_penalty,  # full step penalty while malfunctioning
-                state=TrainState.ACTIVE
+                state=TrainState.MOVING
             ),
             # malfunction stops in the next step and we're still at the beginning of the cell
             # --> if we take action DO_NOTHING, agent should restart without moving
@@ -431,7 +432,7 @@ def test_initial_malfunction_do_nothing():
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=1,
                 reward=env.step_penalty,  # full step penalty while stopped
-                state=TrainState.ACTIVE
+                state=TrainState.MOVING
             ),
             # we haven't started moving yet --> stay here
             Replay(
@@ -440,7 +441,7 @@ def test_initial_malfunction_do_nothing():
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=0,
                 reward=env.step_penalty,  # full step penalty while stopped
-                state=TrainState.ACTIVE
+                state=TrainState.MOVING
             ),
 
             Replay(
@@ -449,7 +450,7 @@ def test_initial_malfunction_do_nothing():
                 action=RailEnvActions.MOVE_FORWARD,
                 malfunction=0,
                 reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
-                state=TrainState.ACTIVE
+                state=TrainState.MOVING
             ),  # we start to move forward --> should go to next cell now
             Replay(
                 position=(3, 3),
@@ -457,7 +458,7 @@ def test_initial_malfunction_do_nothing():
                 action=RailEnvActions.MOVE_FORWARD,
                 malfunction=0,
                 reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
-                state=TrainState.ACTIVE
+                state=TrainState.MOVING
             )
         ],
         speed=env.agents[0].speed_counter.speed,
@@ -546,7 +547,7 @@ def test_last_malfunction_step():
     env.reset(False, False)
     for a_idx in range(len(env.agents)):
         env.agents[a_idx].position = env.agents[a_idx].initial_position
-        env.agents[a_idx].state = TrainState.ACTIVE
+        env.agents[a_idx].state = TrainState.MOVING
     # Force malfunction to be off at beginning and next malfunction to happen in 2 steps
     env.agents[0].malfunction_data['next_malfunction'] = 2
     env.agents[0].malfunction_data['malfunction'] = 0
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 3469c9c82ed3f66ade78f2f335850a9c4809ddc6..391ba535a343c25df03df6fa18f8651543fab06f 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -107,7 +107,6 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
         for a, test_config in enumerate(test_configs):
             agent: EnvAgent = env.agents[a]
             replay = test_config.replay[step]
-            print(agent.position, replay.position, agent.state, agent.speed_counter)
             _assert(a, agent.position, replay.position, 'position')
             _assert(a, agent.direction, replay.direction, 'direction')
             if replay.state is not None: