diff --git a/examples/introduction_flatland_2_1.py b/examples/introduction_flatland_2_1.py
index b832717c76735f7b89768e31df5013af77874c33..5fd0498cee4b89bb85edd5831b6735a118046474 100644
--- a/examples/introduction_flatland_2_1.py
+++ b/examples/introduction_flatland_2_1.py
@@ -59,7 +59,7 @@ schedule_generator = sparse_schedule_generator(speed_ration_map)
 # during an episode.
 
 stochastic_data = {'prop_malfunction': 0.3,  # Percentage of defective agents
-                   'malfunction_rate': 30,  # Rate of malfunction occurence
+                   'malfunction_rate': 50,  # Rate of malfunction occurrence
                    'min_duration': 3,  # Minimal duration of malfunction
                    'max_duration': 20  # Max duration of malfunction
                    }
@@ -204,9 +204,8 @@ print("========================================")
 
 for agent_idx, agent in enumerate(env.agents):
     print(
-        "Agent {} will malfunction = {} at a rate of {}, the next malfunction will occur in {} step. Agent OK = {}".format(
-            agent_idx, agent.malfunction_data['malfunction_rate'] > 0, agent.malfunction_data['malfunction_rate'],
-            agent.malfunction_data['next_malfunction'], agent.malfunction_data['malfunction'] < 1))
+        "Agent {} is OK = {}".format(
+            agent_idx, agent.malfunction_data['malfunction'] < 1))  # OK while no broken steps remain
 
 # Now that you have seen these novel concepts that were introduced you will realize that agents don't need to take
 # an action at every time step as it will only change the outcome when actions are chosen at cell entry.
diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index 13d0fd8989206f2086a358fb0bcc69f369aada3e..ef2d4855d5ba58bbb863d181f5e96415ad08a070 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -39,8 +39,8 @@ class EnvAgentStatic(object):
     # number of time the agent had to stop, since the last time it broke down
     malfunction_data = attrib(
         default=Factory(
-            lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0,
-                          'moving_before_malfunction': False, 'fixed': False})))
+            lambda: dict({'malfunction': 0, 'nr_malfunctions': 0,
+                          'moving_before_malfunction': False, 'fixed': True})))
 
     status = attrib(default=RailAgentStatus.READY_TO_DEPART, type=RailAgentStatus)
     position = attrib(default=None, type=Optional[Tuple[int, int]])
@@ -62,11 +62,8 @@ class EnvAgentStatic(object):
         malfunction_datas = []
         for i in range(len(schedule.agent_positions)):
             malfunction_datas.append({'malfunction': 0,
-                                      'malfunction_rate': schedule.agent_malfunction_rates[
-                                          i] if schedule.agent_malfunction_rates is not None else 0.,
-                                      'next_malfunction': 0,
                                       'nr_malfunctions': 0,
-                                      'fixed': False})
+                                      'fixed': True})
 
         return list(starmap(EnvAgentStatic, zip(schedule.agent_positions,
                                                 schedule.agent_directions,
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 9a25f95456c029b7f502b5fe7be7ffe34c7f06e1..3e17a4a4c95aee7a51a0c614e2ff852e3cb5e353 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -8,6 +8,7 @@ from typing import List, NamedTuple, Optional, Dict
 import msgpack
 import msgpack_numpy as m
 import numpy as np
+import random
 from gym.utils import seeding
 
 from flatland.core.env import Environment
@@ -194,20 +195,15 @@ class RailEnv(Environment):
 
         # Stochastic train malfunctioning parameters
         if stochastic_data is not None:
-            prop_malfunction = stochastic_data['prop_malfunction']
             mean_malfunction_rate = stochastic_data['malfunction_rate']
             malfunction_min_duration = stochastic_data['min_duration']
             malfunction_max_duration = stochastic_data['max_duration']
         else:
-            prop_malfunction = 0.
             mean_malfunction_rate = 0.
             malfunction_min_duration = 0.
             malfunction_max_duration = 0.
 
-        # percentage of malfunctioning trains
-        self.proportion_malfunctioning_trains = prop_malfunction
-
-        # Mean malfunction in number of stops
+        # Mean interval between malfunctions, in number of time steps
         self.mean_malfunction_rate = mean_malfunction_rate
 
         # Uniform distribution parameters for malfunction duration
@@ -219,6 +215,7 @@ class RailEnv(Environment):
 
     def _seed(self, seed=None):
         self.np_random, seed = seeding.np_random(seed)
+        random.seed(seed)  # also seed the stdlib RNG, which the malfunction generator's random.choice draws from
         return [seed]
 
     # no more agent_handles
@@ -344,16 +341,8 @@ class RailEnv(Environment):
         if activate_agents:
             for i_agent in range(self.get_num_agents()):
                 self.set_agent_active(i_agent)
-
+        self._malfunction(self.mean_malfunction_rate)
         for i_agent, agent in enumerate(self.agents):
-            # A proportion of agent in the environment will receive a positive malfunction rate
-            if self.np_random.rand() < self.proportion_malfunctioning_trains:
-                agent.malfunction_data['malfunction_rate'] = self.mean_malfunction_rate
-                next_breakdown = int(
-                    self._exp_distirbution_synced(rate=agent.malfunction_data['malfunction_rate']))
-                agent.malfunction_data['next_malfunction'] = next_breakdown
-            agent.malfunction_data['malfunction'] = 0
-
             initial_malfunction = self._agent_malfunction(i_agent)
 
             if initial_malfunction:
@@ -390,45 +379,39 @@ class RailEnv(Environment):
         """
         agent = self.agents[i_agent]
 
-        # Ignore agents that dont have positive malfunction rate
-        if agent.malfunction_data['malfunction_rate'] < 1:
-            return False
-
-        # Update malfunctioning agents
+        # Reduce number of malfunction steps left
         if agent.malfunction_data['malfunction'] > 0:
             agent.malfunction_data['malfunction'] -= 1
             return True
 
-        if agent.malfunction_data['malfunction'] < 1 and agent.malfunction_data['next_malfunction'] > 0:
-            # Restart fixed agents
-            if not agent.malfunction_data['fixed']:
-                agent.malfunction_data['next_malfunction'] -= 1
-                agent.malfunction_data['fixed'] = True
-                if 'moving_before_malfunction' in agent.malfunction_data:
-                    self.agents[i_agent].moving = agent.malfunction_data['moving_before_malfunction']
-                return False
-            else:
-                # Agent has been running smoothly
-                agent.malfunction_data['next_malfunction'] -= 1
-                return False
-
-        # Break agents that have next_malfunction
-        if agent.malfunction_data['malfunction'] < 1 and agent.malfunction_data['next_malfunction'] < 1:
-            # Increase number of malfunctions
-            agent.malfunction_data['nr_malfunctions'] += 1
-            agent.malfunction_data['fixed'] = False
-
-            # Next malfunction in number of stops
-            next_breakdown = int(
-                self._exp_distirbution_synced(rate=agent.malfunction_data['malfunction_rate']))
-            agent.malfunction_data['next_malfunction'] = max(next_breakdown, 1)
-            # Duration of current malfunction
+        # Ignore agents that are OK
+        if agent.malfunction_data['fixed']:
+            return False
+
+        # Restart agents at the end of their malfunction
+        agent.malfunction_data['fixed'] = True
+        if 'moving_before_malfunction' in agent.malfunction_data:
+            self.agents[i_agent].moving = agent.malfunction_data['moving_before_malfunction']
+        return False
+
+
+
+
+    def _malfunction(self, rate) -> None:
+        """
+        Break at most one agent per call: with probability `self._malfunction_prob(rate)` a randomly chosen agent
+        whose status is not DONE_REMOVED starts a malfunction of random duration. Nothing is returned.
+        """
+        # Filter first: retrying random.choice until a non-removed agent turns up never ends when none is left.
+        candidates = [a for a in self.agents if a.status != RailAgentStatus.DONE_REMOVED]
+        # Use the environment's own RNG (as the duration draw below does), not the global np.random state.
+        if candidates and self.np_random.rand() < self._malfunction_prob(rate):
+            breaking_agent = random.choice(candidates)
             num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken,
                                                       self.max_number_of_steps_broken + 1)
-            agent.malfunction_data['malfunction'] = num_broken_steps
-            agent.malfunction_data['moving_before_malfunction'] = agent.moving
-
-            return True
+            breaking_agent.malfunction_data['malfunction'] = num_broken_steps
+            breaking_agent.malfunction_data['moving_before_malfunction'] = breaking_agent.moving
+            breaking_agent.malfunction_data['fixed'] = False
 
 
 
@@ -463,6 +446,9 @@ class RailEnv(Environment):
             "status" : {},
         }
         have_all_agents_ended = True # boolean flag to check if all agents are done
+
+        # Invoke the malfunction generator
+        self._malfunction(self.mean_malfunction_rate)
         for i_agent, agent in enumerate(self.agents):
             # Reset the step rewards
             self.rewards_dict[i_agent] = 0
@@ -824,3 +810,14 @@ class RailEnv(Environment):
         u = self.np_random.rand()
         x = - np.log(1 - u) * rate
         return x
+
+    def _malfunction_prob(self, rate):
+        """
+        Exponential cumulative distribution with scale `rate`, evaluated at x = 1.
+        :param rate: scale of the exponential distribution; a non-positive value gives probability zero
+        :return: 1 - exp(-1 / rate) for rate > 0, otherwise 0.
+        """
+        if rate <= 0:
+            return 0.
+        # Positive rate: the larger it is, the smaller the returned probability.
+        return 1 - np.exp(-(1 / rate))