diff --git a/examples/introduction_flatland_2_1.py b/examples/introduction_flatland_2_1.py
index dd17c285519cb35ee5d11a3ed1731f3e33a45c33..29fafe6436de2e04a780b34bd3eddba8a4533355 100644
--- a/examples/introduction_flatland_2_1.py
+++ b/examples/introduction_flatland_2_1.py
@@ -7,6 +7,7 @@ from flatland.envs.malfunction_generators import malfunction_from_params
 from flatland.envs.observations import GlobalObsForRailEnv
 # First of all we import the Flatland rail environment
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_env import RailEnvActions
 from flatland.envs.rail_generators import sparse_rail_generator
 from flatland.envs.schedule_generators import sparse_schedule_generator
 # We also include a renderer because we want to visualize what is going on in the environment
@@ -28,8 +29,8 @@ from flatland.utils.rendertools import RenderTool, AgentRenderVariant
 # The railway infrastructure can be build using any of the provided generators in env/rail_generators.py
 # Here we use the sparse_rail_generator with the following parameters
 
-width = 16*7  # With of map
-height = 9*7  # Height of map
+width = 16 * 7  # With of map
+height = 9 * 7  # Height of map
 nr_trains = 20  # Number of trains that have an assigned task in the env
 cities_in_map = 20  # Number of cities where agents can start or end
 seed = 14  # Random seed
@@ -104,7 +105,8 @@ class RandomAgent:
         :param state: input is the observation of the agent
         :return: returns an action
         """
-        return np.random.choice([1, 2, 3, 4]) # [Left, Forward, Right, Stop]
+        return np.random.choice([RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_LEFT,
+                                 RailEnvActions.STOP_MOVING])
 
     def step(self, memories):
         """
diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index f5b84a7d11910e1bf11c93a4bef5a955ad734ad3..2bb9677aab020560aeb28aad97edfa23efebe9bf 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -40,7 +40,7 @@ class EnvAgentStatic(object):
     malfunction_data = attrib(
         default=Factory(
             lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0,
-                          'moving_before_malfunction': False, 'fixed': True})))
+                          'moving_before_malfunction': False})))
 
     status = attrib(default=RailAgentStatus.READY_TO_DEPART, type=RailAgentStatus)
     position = attrib(default=None, type=Optional[Tuple[int, int]])
@@ -65,8 +65,7 @@ class EnvAgentStatic(object):
                                       'malfunction_rate': schedule.agent_malfunction_rates[
                                           i] if schedule.agent_malfunction_rates is not None else 0.,
                                       'next_malfunction': 0,
-                                      'nr_malfunctions': 0,
-                                      'fixed': True})
+                                      'nr_malfunctions': 0})
 
         return list(starmap(EnvAgentStatic, zip(schedule.agent_positions,
                                                 schedule.agent_directions,
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 6da778ed40e917833c9460e9e08a8e4a516e5611..29638c6346f34ad39dbe1ad21f2bc7299a5e43ea 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -362,15 +362,17 @@ class RailEnv(Environment):
             for i_agent in range(self.get_num_agents()):
                 self.set_agent_active(i_agent)
 
-        # Induce malfunctions
-        self._malfunction(self.mean_malfunction_rate)
+
 
         for agent in self.agents:
+            # Induce malfunctions
+            self._break_agent(self.mean_malfunction_rate, agent)
+
             if agent.malfunction_data["malfunction"] > 0:
                 agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING
 
-        # Fix agents that finished their malfunciton
-        self._fix_agents()
+            # Fix agents that finished their malfunction
+            self._fix_agent(agent)
 
         self.num_resets += 1
         self._elapsed_steps = 0
@@ -394,64 +396,48 @@ class RailEnv(Environment):
         observation_dict: Dict = self._get_observations()
         return observation_dict, info_dict
 
-    def _fix_agents(self):
+    def _fix_agent(self, agent):
         """
         Updates agent malfunction variables and fixes broken agents
-        """
-        for agent in self.agents:
 
-            # Ignore agents that OK
-            if agent.malfunction_data['fixed']:
-                continue
+        Parameters
+        ----------
+        agent
+        """
 
-            # Reduce number of malfunction steps left
-            if agent.malfunction_data['malfunction'] > 1:
-                agent.malfunction_data['malfunction'] -= 1
-                continue
+        # Ignore agents that are OK
+        if self._is_ok(agent):
+            return
 
-            # Restart agents at the end of their malfunction
+        # Reduce number of malfunction steps left
+        if agent.malfunction_data['malfunction'] > 1:
             agent.malfunction_data['malfunction'] -= 1
-            agent.malfunction_data['fixed'] = True
-            if 'moving_before_malfunction' in agent.malfunction_data:
-                agent.moving = agent.malfunction_data['moving_before_malfunction']
-                continue
+            return
 
-    def _malfunction(self, rate):
-        """
-        Malfunction generator that breaks agents at a given rate. It does randomly chose agent to break during the run
+        # Restart agents at the end of their malfunction
+        agent.malfunction_data['malfunction'] -= 1
+        if 'moving_before_malfunction' in agent.malfunction_data:
+            agent.moving = agent.malfunction_data['moving_before_malfunction']
+            return
 
+    def _break_agent(self, rate, agent):
         """
-        if self.np_random.rand() < self._malfunction_prob(rate, len(self.active_agents)):
-            # Select only from agents that are not done yet
-            breaking_agent_idx = self.np_random.choice(self.active_agents)
-            breaking_agent = self.agents[breaking_agent_idx]
+        Malfunction generator that breaks agents at a given rate.
 
-            # We assume that less then half of the active agents should be broken at MOST.
-            # Therefore we only try that many times before ignoring the malfunction
-
-            tries = 0
-            max_tries = 0.5 * len(self.active_agents)
-
-            # Look for a functioning active agent
-            while breaking_agent.malfunction_data['malfunction'] > 0 and tries < max_tries:
-                breaking_agent_idx = self.np_random.choice(self.active_agents)
-                breaking_agent = self.agents[breaking_agent_idx]
-                tries += 1
+        Parameters
+        ----------
+        agent
 
-            # If we did not manage to find a functioning agent among the active ones skip this malfunction
-            if tries < max_tries:
-                # Because we update agents in the same step as we break them we add one to the duration of the
-                # malfunction
+        """
+        if agent.malfunction_data['malfunction'] < 1:
+            if self.np_random.rand() < self._malfunction_prob(rate):
                 num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken,
-                                                          self.max_number_of_steps_broken + 1) + 1
-                breaking_agent.malfunction_data['malfunction'] = num_broken_steps
-                breaking_agent.malfunction_data['moving_before_malfunction'] = breaking_agent.moving
-                breaking_agent.malfunction_data['fixed'] = False
-                breaking_agent.malfunction_data['nr_malfunctions'] += 1
+                                                              self.max_number_of_steps_broken + 1) + 1
+                agent.malfunction_data['malfunction'] = num_broken_steps
+                agent.malfunction_data['moving_before_malfunction'] = agent.moving
+                agent.malfunction_data['nr_malfunctions'] += 1
+        return
 
-                return
-
-            return
 
     def step(self, action_dict_: Dict[int, RailEnvActions]):
         """
@@ -492,13 +478,15 @@ class RailEnv(Environment):
         }
         have_all_agents_ended = True  # boolean flag to check if all agents are done
 
-        # Induce malfunctions
-        self._malfunction(self.mean_malfunction_rate)
+
 
         for i_agent, agent in enumerate(self.agents):
             # Reset the step rewards
             self.rewards_dict[i_agent] = 0
 
+            # Induce malfunction before we do a step, thus a broken agent can't move in this step
+            self._break_agent(self.mean_malfunction_rate, agent)
+
             # Perform step on the agent
             self._step_agent(i_agent, action_dict_.get(i_agent))
 
@@ -511,8 +499,8 @@ class RailEnv(Environment):
             info_dict["speed"][i_agent] = agent.speed_data['speed']
             info_dict["status"][i_agent] = agent.status
 
-        # Fix agents that finished their malfunction
-        self._fix_agents()
+            # Fix agents that finished their malfunction such that they can perfom an action in the next step
+            self._fix_agent(agent)
 
         # Check for end of episode + set global reward to all rewards!
         if have_all_agents_ended:
@@ -986,7 +974,7 @@ class RailEnv(Environment):
         x = - np.log(1 - u) * rate
         return x
 
-    def _malfunction_prob(self, rate, n_agents):
+    def _malfunction_prob(self, rate):
         """
         Probability that an agent break given the number of agents an the probability of a sinlge agent to break
         :param rate:
@@ -995,4 +983,48 @@ class RailEnv(Environment):
         if rate <= 0:
             return 0.
         else:
-            return 1 - np.exp(- (1 / rate) * (n_agents))
+            return 1 - np.exp(- (1 / rate))
+
+    def _draw_malfunctioning_agent(self, tries):
+        """
+        Function to determin what agent will be breaking.
+        It only looks at active and non-broken agents.
+        After a number of steps it gives up the search after breaking agents and ignores malfunciton
+
+        Parameters
+        ----------
+        tries: How many times we tried to find an agent
+
+        Returns
+        -------
+        agent that is breaking
+        """
+        # Select only from active agents
+        breaking_agent_idx = self.np_random.choice(self.active_agents)
+        breaking_agent = self.agents[breaking_agent_idx]
+        # We assume that at least half of the agents should still be working
+        if tries > 0.5 * len(self.active_agents):
+            return None
+
+        # If agent is already broken look for a new one
+        elif breaking_agent.malfunction_data['malfunction'] > 0:
+            return self._draw_malfunctioning_agent(tries + 1)
+
+        # Return agent to be broken
+        else:
+            return breaking_agent
+
+    def _is_ok(self, agent):
+        """
+        Check if an agent is ok, meaning it can move and is not malfuncitoinig
+        Parameters
+        ----------
+        agent
+
+        Returns
+        -------
+        True if agent is ok, False otherwise
+
+        """
+        return agent.malfunction_data['malfunction'] < 1
+
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index 14f9b6c0295306448d27a83444e3dd6496485cf1..b2c1ca1162e476ff6e2f4fc3f8489428af23535e 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -106,7 +106,7 @@ def test_malfunction_process():
         total_down_time += env.agents[0].malfunction_data['malfunction']
 
     # Check that the appropriate number of malfunctions is achieved
-    assert env.agents[0].malfunction_data['nr_malfunctions'] == 22, "Actual {}".format(
+    assert env.agents[0].malfunction_data['nr_malfunctions'] == 23, "Actual {}".format(
         env.agents[0].malfunction_data['nr_malfunctions'])
 
     # Check that malfunctioning data was standing around
@@ -132,16 +132,16 @@ def test_malfunction_process_statistically():
     env.agents[0].target = (0, 0)
     # Next line only for test generation
     #agent_malfunction_list = [[] for i in range(10)]
-    agent_malfunction_list = [[0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1],
-     [0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0],
-     [0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0],
-     [0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 5, 4],
-     [0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-     [5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0],
-     [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 5, 4, 3, 2],
-     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5],
-     [0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 5, 4, 3],
-     [0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
+    agent_malfunction_list = [[0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1],
+     [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 0],
+     [5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0],
+     [0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1],
+     [0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0],
+     [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+     [0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1],
+     [0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2],
+     [5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+     [5, 4, 3, 2, 1, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0]]
 
     for step in range(20):
         action_dict: Dict[int, RailEnvActions] = {}
@@ -158,7 +158,7 @@ def test_malfunction_process_statistically():
 def test_malfunction_before_entry():
     """Tests that malfunctions are working properly for agents before entering the environment!"""
     # Set fixed malfunction duration for this test
-    stochastic_data = {'malfunction_rate': 0.0001,
+    stochastic_data = {'malfunction_rate': 2,
                        'min_duration': 10,
                        'max_duration': 10}
 
@@ -176,16 +176,17 @@ def test_malfunction_before_entry():
     # we want different next_malfunction values for the agents
     assert env.agents[0].malfunction_data['malfunction'] == 0
     assert env.agents[1].malfunction_data['malfunction'] == 0
-    assert env.agents[2].malfunction_data['malfunction'] == 0
+    assert env.agents[2].malfunction_data['malfunction'] == 10
     assert env.agents[3].malfunction_data['malfunction'] == 0
     assert env.agents[4].malfunction_data['malfunction'] == 0
-    assert env.agents[5].malfunction_data['malfunction'] == 10
+    assert env.agents[5].malfunction_data['malfunction'] == 0
     assert env.agents[6].malfunction_data['malfunction'] == 0
     assert env.agents[7].malfunction_data['malfunction'] == 0
-    assert env.agents[8].malfunction_data['malfunction'] == 0
-    assert env.agents[9].malfunction_data['malfunction'] == 0
-    # for a in range(10):
-    #   print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a,env.agents[a].malfunction_data['malfunction']))
+    assert env.agents[8].malfunction_data['malfunction'] == 10
+    assert env.agents[9].malfunction_data['malfunction'] == 10
+
+    #for a in range(10):
+    #  print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a,env.agents[a].malfunction_data['malfunction']))
 
 
 def test_malfunction_values_and_behavior():