diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index fc73cd6ce481887477d50905c98d721032a08ff1..cb2fdb7e9e2bad19ce90f70a74e980d9bd7f75cf 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -421,11 +421,21 @@ class RailEnv(Environment):
                 agent.moving = agent.malfunction_data['moving_before_malfunction']
                 continue
 
-    def _draw_malfunctioning_agent(self):
+    def _draw_malfunctioning_agent(self, tries=0):
         # Select only from active agents
         breaking_agent_idx = self.np_random.choice(self.active_agents)
         breaking_agent = self.agents[breaking_agent_idx]
-        return breaking_agent
+        # `tries` counts the failed draws so far. We assume that at most half of the active agents
+        # are broken, so once `tries` exceeds half their number we give up and return None.
+        # The drawn agent is checked before the limit, so no random draw is thrown away.
+        max_tries = 0.5 * len(self.active_agents)
+        while breaking_agent.malfunction_data['malfunction'] > 0:
+            # The drawn agent is already broken, look for a new one
+            tries += 1
+            if tries > max_tries:
+                return None
+            breaking_agent = self.agents[self.np_random.choice(self.active_agents)]
+        return breaking_agent
 
     def _malfunction(self, rate):
         """
@@ -434,30 +444,17 @@ class RailEnv(Environment):
         """
         if self.np_random.rand() < self._malfunction_prob(rate, len(self.active_agents)):
 
-            breaking_agent = self._draw_malfunctioning_agent()
-            # We assume that less then half of the active agents should be broken at MOST.
-            # Therefore we only try that many times before ignoring the malfunction
-            tries = 0
-            max_tries = 0.5 * len(self.active_agents)
-            # Look for a functioning active agent
-            while breaking_agent.malfunction_data['malfunction'] > 0 and tries < max_tries:
-                breaking_agent = self._draw_malfunctioning_agent()
-                tries += 1
-
+            breaking_agent = self._draw_malfunctioning_agent(0)
             # If we did not manage to find a functioning agent among the active ones skip this malfunction
-            if tries < max_tries:
+            if breaking_agent is not None:
                 # Because we update agents in the same step as we break them we add one to the duration of the
                 # malfunction
                 num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken,
                                                           self.max_number_of_steps_broken + 1) + 1
                 breaking_agent.malfunction_data['malfunction'] = num_broken_steps
                 breaking_agent.malfunction_data['moving_before_malfunction'] = breaking_agent.moving
                 breaking_agent.malfunction_data['fixed'] = False
                 breaking_agent.malfunction_data['nr_malfunctions'] += 1
-
-                return
-
-            return
 
     def step(self, action_dict_: Dict[int, RailEnvActions]):
         """