diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 053229e56680284998fd2fb7896c7d5246cee30d..e11243046256a28f04913f40ef7ef29539e90c39 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -367,15 +367,16 @@ class RailEnv(Environment):
             for i_agent in range(self.get_num_agents()):
                 self.set_agent_active(i_agent)
 
-        # See if agents are already broken
+        # Induce malfunctions
         self._malfunction(self.mean_malfunction_rate)
 
-        for i_agent, agent in enumerate(self.agents):
-            initial_malfunction = self._agent_malfunction(i_agent)
-
-            if initial_malfunction:
+        for agent in self.agents:
+            if agent.malfunction_data["malfunction"] > 0:
                 agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING
 
+        # Fix agents that finished their malfunction
+        self._fix_agents()
+
         self.num_resets += 1
         self._elapsed_steps = 0
 
@@ -398,26 +399,26 @@
         observation_dict: Dict = self._get_observations()
         return observation_dict, info_dict
 
-    def _agent_malfunction(self, i_agent) -> bool:
+    def _fix_agents(self):
         """
-        Returns true if the agent enters into malfunction. (False, if not broken down or already broken down before).
+        Updates agent malfunction variables and fixes broken agents
         """
-        agent = self.agents[i_agent]
+        for agent in self.agents:
 
-        # Reduce number of malfunction steps left
-        if agent.malfunction_data['malfunction'] > 0:
-            agent.malfunction_data['malfunction'] -= 1
-            return True
+            # Ignore agents that are OK
+            if agent.malfunction_data['fixed']:
+                continue
 
-        # Ignore agents that OK
-        if agent.malfunction_data['fixed']:
-            return False
+            # Reduce number of malfunction steps left
+            if agent.malfunction_data['malfunction'] > 1:
+                agent.malfunction_data['malfunction'] -= 1
+                continue
 
-        # Restart agents at the end of their malfunction
-        agent.malfunction_data['fixed'] = True
-        if 'moving_before_malfunction' in agent.malfunction_data:
-            self.agents[i_agent].moving = agent.malfunction_data['moving_before_malfunction']
-        return False
+            # Restart agents at the end of their malfunction
+            agent.malfunction_data['malfunction'] -= 1
+            agent.malfunction_data['fixed'] = True
+            if 'moving_before_malfunction' in agent.malfunction_data:
+                agent.moving = agent.malfunction_data['moving_before_malfunction']
 
     def _malfunction(self, rate):
         """
@@ -434,7 +436,7 @@ class RailEnv(Environment):
             # TODO: Do we want to guarantee that we have the desired rate or are we happy with lower rates?
             if breaking_agent.malfunction_data['malfunction'] < 1:
                 num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken,
-                                                          self.max_number_of_steps_broken + 1)
+                                                          self.max_number_of_steps_broken + 1) + 1
                 breaking_agent.malfunction_data['malfunction'] = num_broken_steps
                 breaking_agent.malfunction_data['moving_before_malfunction'] = breaking_agent.moving
                 breaking_agent.malfunction_data['fixed'] = False
@@ -479,7 +481,7 @@ class RailEnv(Environment):
         }
         have_all_agents_ended = True  # boolean flag to check if all agents are done
 
-        # Evoke the malfunction generator
+        # Induce malfunctions
         self._malfunction(self.mean_malfunction_rate)
 
         for i_agent, agent in enumerate(self.agents):
@@ -498,6 +500,9 @@ class RailEnv(Environment):
             info_dict["speed"][i_agent] = agent.speed_data['speed']
             info_dict["status"][i_agent] = agent.status
 
+        # Fix agents that finished their malfunction
+        self._fix_agents()
+
         # Check for end of episode + set global reward to all rewards!
         if have_all_agents_ended:
             self.dones["__all__"] = True
@@ -542,12 +547,9 @@ class RailEnv(Environment):
         agent.old_direction = agent.direction
         agent.old_position = agent.position
 
-        # is the agent malfunctioning?
-        malfunction = self._agent_malfunction(i_agent)
-
         # if agent is broken, actions are ignored and agent does not move.
         # full step penalty in this case
-        if malfunction:
+        if agent.malfunction_data['malfunction'] > 0:
             self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
             return
 
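The hunks above replace the per-agent `_agent_malfunction()` check with a single `_fix_agents()` pass that runs after the agent loop. Below is a minimal sketch of the resulting counter behaviour, assuming the ordering shown in the diff (`_malfunction()` before the agent loop, `_fix_agents()` after it); `observed_countdown` is a hypothetical helper for illustration, not repository code:

def observed_countdown(duration: int) -> list:
    """Counter values a test would read after each env.step(), starting at the inducing step."""
    counter = duration + 1            # value written by _malfunction(): drawn duration + 1
    readings = []
    while counter > 0:
        counter -= 1                  # end-of-step decrement performed by _fix_agents()
        readings.append(counter)
    return readings

# The "+ 1" presumably compensates for the decrement applied in the same step the
# malfunction is induced, so the counter still reads the full drawn duration afterwards.
assert observed_countdown(5) == [5, 4, 3, 2, 1, 0]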
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index b783fe7ab6e18e1ac532eef2fc63e496d6c3bbb3..68cd6f495c2d9ff099f228e230ef92c1b5e700fa 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -110,7 +110,7 @@ def test_malfunction_process():
         total_down_time += env.agents[0].malfunction_data['malfunction']
 
     # Check that the appropriate number of malfunctions is achieved
-    assert env.agents[0].malfunction_data['nr_malfunctions'] == 28, "Actual {}".format(
+    assert env.agents[0].malfunction_data['nr_malfunctions'] == 22, "Actual {}".format(
         env.agents[0].malfunction_data['nr_malfunctions'])
 
     # Check that malfunctioning data was standing around
@@ -140,17 +140,17 @@ def test_malfunction_process_statistically():
 
     env.agents[0].target = (0, 0)
     # Next line only for test generation
-    #agent_malfunction_list = [[] for i in range(20)]
-    agent_malfunction_list = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 3, 2, 1, 0],
-     [0, 0, 0, 0, 0, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-     [0, 0, 5, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [], [], [], [], [], [], [], [], [], []]
+    #agent_malfunction_list = [[] for i in range(10)]
+    agent_malfunction_list = [[0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+     [0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3],
+     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1],
+     [0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+     [5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5],
+     [0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0],
+     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2],
+     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0],
+     [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
 
     for step in range(20):
         action_dict: Dict[int, RailEnvActions] = {}
@@ -188,17 +188,25 @@ def test_malfunction_before_entry():
     # Test initial malfunction values for all agents
     # we want some agents to be malfuncitoning already and some to be working
     # we want different next_malfunction values for the agents
-
-    for a in range(10):
-
-        print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a,env.agents[a].malfunction_data['malfunction']))
+    assert env.agents[0].malfunction_data['malfunction'] == 0
+    assert env.agents[1].malfunction_data['malfunction'] == 0
+    assert env.agents[2].malfunction_data['malfunction'] == 0
+    assert env.agents[3].malfunction_data['malfunction'] == 0
+    assert env.agents[4].malfunction_data['malfunction'] == 0
+    assert env.agents[5].malfunction_data['malfunction'] == 10
+    assert env.agents[6].malfunction_data['malfunction'] == 0
+    assert env.agents[7].malfunction_data['malfunction'] == 0
+    assert env.agents[8].malfunction_data['malfunction'] == 0
+    assert env.agents[9].malfunction_data['malfunction'] == 0
+    #for a in range(10):
+    #   print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a,env.agents[a].malfunction_data['malfunction']))
 
 
 
 
 def test_malfunction_values_and_behavior():
     """
-    Test that the next malfunction occurs when desired.
+    Test that the malfunction counter counts down as desired.
     Returns
     -------
 
@@ -207,7 +215,7 @@ def test_malfunction_values_and_behavior():
 
     rail, rail_map = make_simple_rail2()
     action_dict: Dict[int, RailEnvActions] = {}
-    stochastic_data = {'malfunction_rate': 0.01,
+    stochastic_data = {'malfunction_rate': 0.001,
                        'min_duration': 10,
                        'max_duration': 10}
     env = RailEnv(width=25,
@@ -223,7 +231,7 @@ def test_malfunction_values_and_behavior():
     env.reset(False, False, activate_agents=True, random_seed=10)
 
     # Assertions
-    assert_list = [8, 7, 6, 5, 4, 3, 2, 1, 0, 9, 8, 7, 6, 5, 4]
+    assert_list = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 10, 9, 8, 7, 6, 5]
     print("[")
     for time_step in range(15):
         # Move in the env
@@ -233,8 +241,7 @@ def test_malfunction_values_and_behavior():
 
 
 def test_initial_malfunction():
-    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
-                       'malfunction_rate': 100,  # Rate of malfunction occurence
+    stochastic_data = {'malfunction_rate': 1000,  # Rate of malfunction occurrence
                        'min_duration': 2,  # Minimal duration of malfunction
                        'max_duration': 5  # Max duration of malfunction
                        }
@@ -278,7 +285,7 @@ def test_initial_malfunction():
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.MOVE_FORWARD,
                 malfunction=1,
-                reward=env.step_penalty * 1.0
+                reward=env.step_penalty
 
             ),  # malfunctioning ends: starting and running at speed 1.0
             Replay(
@@ -293,7 +300,7 @@ def test_initial_malfunction():
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.MOVE_FORWARD,
                 malfunction=0,
-                reward=env.step_penalty * 1.0  # running at speed 1.0
+                reward=env.step_penalty  # running at speed 1.0
             )
         ],
         speed=env.agents[0].speed_data['speed'],
@@ -341,7 +348,7 @@ def test_initial_malfunction_stop_moving():
                 position=(3, 2),
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.DO_NOTHING,
-                malfunction=3,
+                malfunction=2,
                 reward=env.step_penalty,  # full step penalty when stopped
                 status=RailAgentStatus.ACTIVE
             ),
@@ -352,7 +359,7 @@ def test_initial_malfunction_stop_moving():
                 position=(3, 2),
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.STOP_MOVING,
-                malfunction=2,
+                malfunction=1,
                 reward=env.step_penalty,  # full step penalty while stopped
                 status=RailAgentStatus.ACTIVE
             ),
@@ -361,7 +368,7 @@ def test_initial_malfunction_stop_moving():
                 position=(3, 2),
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.DO_NOTHING,
-                malfunction=1,
+                malfunction=0,
                 reward=env.step_penalty,  # full step penalty while stopped
                 status=RailAgentStatus.ACTIVE
             ),
@@ -429,7 +436,7 @@ def test_initial_malfunction_do_nothing():
                 position=(3, 2),
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.DO_NOTHING,
-                malfunction=3,
+                malfunction=2,
                 reward=env.step_penalty,  # full step penalty while malfunctioning
                 status=RailAgentStatus.ACTIVE
             ),
@@ -440,7 +447,7 @@ def test_initial_malfunction_do_nothing():
                 position=(3, 2),
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.DO_NOTHING,
-                malfunction=2,
+                malfunction=1,
                 reward=env.step_penalty,  # full step penalty while stopped
                 status=RailAgentStatus.ACTIVE
             ),
@@ -449,7 +456,7 @@ def test_initial_malfunction_do_nothing():
                 position=(3, 2),
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.DO_NOTHING,
-                malfunction=1,
+                malfunction=0,
                 reward=env.step_penalty,  # full step penalty while stopped
                 status=RailAgentStatus.ACTIVE
             ),
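The `malfunction=` expectations in these Replay configs shift down by one, apparently because `_fix_agents()` now decrements the counter at the end of `env.step()`, so an injected counter of, say, 3 is read back as 2, 1, 0 on the following assertions. A tiny sanity check of that arithmetic (illustrative only, not part of the test suite):

remaining = 3                  # assumed value injected via replay.set_malfunction
readings = []
for _ in range(3):
    remaining -= 1             # end-of-step decrement in _fix_agents()
    readings.append(remaining)
assert readings == [2, 1, 0]   # matches the updated malfunction= expectations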
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 80656cbb490c39fc7352327faa9bf5859a1123e9..ff4948d629747a3394644b971105178ec1ac4523 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -119,8 +119,9 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
                 # We also set next malfunction to infitiy to avoid interference with our tests
                 agent.malfunction_data['malfunction'] = replay.set_malfunction
                 agent.malfunction_data['moving_before_malfunction'] = agent.moving
+                agent.malfunction_data['fixed'] = False
             _assert(a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
-        print(step)
+        print(step, agent.moving, agent.malfunction_data['fixed'], agent.malfunction_data['malfunction'])
         _, rewards_dict, _, info_dict = env.step(action_dict)
         if rendering:
             renderer.render_env(show=True, show_observations=True)
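For context on the last hunk: a sketch of why the harness now clears the 'fixed' flag when injecting a malfunction, assuming the `_fix_agents()` semantics from the rail_env.py changes above (`inject_malfunction` is a hypothetical helper used only for illustration):

def inject_malfunction(agent, remaining_steps: int) -> None:
    """Illustrative helper mirroring what run_replay_config does when a Replay sets a malfunction."""
    agent.malfunction_data['malfunction'] = remaining_steps
    agent.malfunction_data['moving_before_malfunction'] = agent.moving
    # _fix_agents() skips agents whose malfunction_data is marked 'fixed', so the
    # injected counter would never count down unless the flag is cleared here.
    agent.malfunction_data['fixed'] = False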