diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index 07fdd8d2b26485de888af81240f9c9f8c6d0533d..ee586e2071f7df26e7e17f553fbe9dc6867597e4 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -2,7 +2,7 @@ import time
 
 import numpy as np
 
-from flatland.envs.malfunction_generators import malfunction_from_params
+from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
 from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
@@ -16,12 +16,10 @@ np.random.seed(1)
 # Training on simple small tasks is the best way to get familiar with the environment
 
 # Use a the malfunction generator to break agents from time to time
-stochastic_data = {'prop_malfunction': 0.3,  # Percentage of defective agents
-                   'malfunction_rate': 30,  # Rate of malfunction occurence
-                   'min_duration': 3,  # Minimal duration of malfunction
-                   'max_duration': 20  # Max duration of malfunction
-                   }
-
+stochastic_data = MalfunctionParameters(malfunction_rate=30,  # Rate of malfunction occurence
+                                        min_duration=3,  # Minimal duration of malfunction
+                                        max_duration=20  # Max duration of malfunction
+                                        )
 # Custom observation builder
 TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
 
diff --git a/examples/introduction_flatland_2_1.py b/examples/introduction_flatland_2_1.py
index 6c3313d801e31fb9fa32605c59c331198d1865d8..fa48acb23c9bf7e5f7c57df54abc733bc75ed2e7 100644
--- a/examples/introduction_flatland_2_1.py
+++ b/examples/introduction_flatland_2_1.py
@@ -3,7 +3,7 @@ import numpy as np
 # In Flatland you can use custom observation builders and predicitors
 # Observation builders generate the observation needed by the controller
 # Preditctors can be used to do short time prediction which can help in avoiding conflicts in the network
-from flatland.envs.malfunction_generators import malfunction_from_params
+from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
 from flatland.envs.observations import GlobalObsForRailEnv
 # First of all we import the Flatland rail environment
 from flatland.envs.rail_env import RailEnv
@@ -62,11 +62,10 @@ schedule_generator = sparse_schedule_generator(speed_ration_map)
 # We can furthermore pass stochastic data to the RailEnv constructor which will allow for stochastic malfunctions
 # during an episode.
 
-stochastic_data = {'malfunction_rate': 12000,  # Rate of malfunction occurence of single agent
-                   'min_duration': 15,  # Minimal duration of malfunction
-                   'max_duration': 50  # Max duration of malfunction
-                   }
-
+stochastic_data = MalfunctionParameters(malfunction_rate=10000,  # Rate of malfunction occurence
+                                        min_duration=15,  # Minimal duration of malfunction
+                                        max_duration=50  # Max duration of malfunction
+                                        )
 # Custom observation builder without predictor
 observation_builder = GlobalObsForRailEnv()
 
@@ -256,7 +255,7 @@ for step in range(500):
     next_obs, all_rewards, done, _ = env.step(action_dict)
 
     env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
-    # env_renderer.gl.save_image('./misc/Fames2/flatland_frame_{:04d}.png'.format(step))
+    env_renderer.gl.save_image('./misc/Fames2/flatland_frame_{:04d}.png'.format(step))
     frame_step += 1
     # Update replay buffer and train agent
     for a in range(env.get_num_agents()):
diff --git a/flatland/envs/malfunction_generators.py b/flatland/envs/malfunction_generators.py
index c877f9824526889a7d355d198f0422956324ad72..f6d0c78f9f37b9d179a2c42776d1b7411600d263 100644
--- a/flatland/envs/malfunction_generators.py
+++ b/flatland/envs/malfunction_generators.py
@@ -6,10 +6,12 @@ import msgpack
 import numpy as np
 from numpy.random.mtrand import RandomState
 
-from flatland.envs.agent_utils import EnvAgent
+from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
 
 Malfunction = NamedTuple('Malfunction', [('num_broken_steps', int)])
-MalfunctionGenerator = Callable[[EnvAgent], Optional[Malfunction]]
+MalfunctionParameters = NamedTuple('MalfunctionParameters',
+                                   [('malfunction_rate', float), ('min_duration', int), ('max_duration', int)])
+MalfunctionGenerator = Callable[[EnvAgent, RandomState, bool], Optional[Malfunction]]
 MalfunctionProcessData = NamedTuple('MalfunctionProcessData',
                                     [('malfunction_rate', float), ('min_duration', int), ('max_duration', int)])
 
@@ -36,7 +38,7 @@ def malfunction_from_file(filename: str) -> Tuple[MalfunctionGenerator, Malfunct
 
     Returns
     -------
-    Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
+    generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
     """
     with open(filename, "rb") as file_in:
         load_data = file_in.read()
@@ -57,7 +59,7 @@ def malfunction_from_file(filename: str) -> Tuple[MalfunctionGenerator, Malfunct
         min_number_of_steps_broken = 0
         max_number_of_steps_broken = 0
 
-    def generator(agent: EnvAgent, np_random: RandomState) -> Optional[Malfunction]:
+    def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
         """
         Generate malfunctions for agents
         Parameters
@@ -69,6 +71,11 @@ def malfunction_from_file(filename: str) -> Tuple[MalfunctionGenerator, Malfunct
         -------
         int: Number of time steps an agent is broken
         """
+
+        # Dummy reset function as we don't implement specific seeding here
+        if reset:
+            return Malfunction(0)
+
         if agent.malfunction_data['malfunction'] < 1:
             if np_random.rand() < _malfunction_prob(mean_malfunction_rate):
                 num_broken_steps = np_random.randint(min_number_of_steps_broken,
@@ -80,26 +87,27 @@ def malfunction_from_file(filename: str) -> Tuple[MalfunctionGenerator, Malfunct
                                              max_number_of_steps_broken)
 
 
-def malfunction_from_params(parameters: dict) -> Tuple[MalfunctionGenerator, MalfunctionProcessData]:
+def malfunction_from_params(parameters: MalfunctionParameters) -> Tuple[MalfunctionGenerator, MalfunctionProcessData]:
     """
     Utility to load malfunction from parameters
 
     Parameters
     ----------
-    parameters containing
-    malfunction_rate : float how many time steps it takes for a sinlge agent befor it breaks
-    min_duration : int minimal duration of a failure
-    max_number_of_steps_broken : int maximal duration of a failure
+
+    parameters : contains all the parameters of the malfunction
+        malfunction_rate : float how many time steps it takes for a sinlge agent befor it breaks
+        min_duration : int minimal duration of a failure
+        max_number_of_steps_broken : int maximal duration of a failure
 
     Returns
     -------
-    Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
+    generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
     """
-    mean_malfunction_rate = parameters['malfunction_rate']
-    min_number_of_steps_broken = parameters['min_duration']
-    max_number_of_steps_broken = parameters['max_duration']
+    mean_malfunction_rate = parameters.malfunction_rate
+    min_number_of_steps_broken = parameters.min_duration
+    max_number_of_steps_broken = parameters.max_duration
 
-    def generator(agent: EnvAgent, np_random: RandomState) -> Optional[Malfunction]:
+    def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
         """
         Generate malfunctions for agents
         Parameters
@@ -111,6 +119,11 @@ def malfunction_from_params(parameters: dict) -> Tuple[MalfunctionGenerator, Mal
         -------
         int: Number of time steps an agent is broken
         """
+
+        # Dummy reset function as we don't implement specific seeding here
+        if reset:
+            return Malfunction(0)
+
         if agent.malfunction_data['malfunction'] < 1:
             if np_random.rand() < _malfunction_prob(mean_malfunction_rate):
                 num_broken_steps = np_random.randint(min_number_of_steps_broken,
@@ -124,15 +137,15 @@ def malfunction_from_params(parameters: dict) -> Tuple[MalfunctionGenerator, Mal
 
 def no_malfunction_generator() -> Tuple[MalfunctionGenerator, MalfunctionProcessData]:
     """
-    Utility to load malfunction from parameters
+    Malfunction generator which generates no malfunctions
 
     Parameters
     ----------
-    input_file : Pickle file generated by env.save() or editor
+    Nothing
 
     Returns
     -------
-    Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
+    generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
     """
     # Mean malfunction in number of time steps
     mean_malfunction_rate = 0.
@@ -141,8 +154,68 @@ def no_malfunction_generator() -> Tuple[MalfunctionGenerator, MalfunctionProcess
     min_number_of_steps_broken = 0
     max_number_of_steps_broken = 0
 
-    def generator(agent: EnvAgent, np_random: RandomState) -> Optional[Malfunction]:
+    def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
         return Malfunction(0)
 
     return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
                                              max_number_of_steps_broken)
+
+
+def single_malfunction_generator(earlierst_malfunction: int, malfunction_duration: int) -> Tuple[
+    MalfunctionGenerator, MalfunctionProcessData]:
+    """
+    Malfunction generator which guarantees exactly one malfunction during an episode of an ACTIVE agent.
+
+    Parameters
+    ----------
+    malfunction_duration: The duration of the single malfunction
+
+    Returns
+    -------
+    generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
+    """
+    # Mean malfunction in number of time steps
+    mean_malfunction_rate = 0.
+
+    # Uniform distribution parameters for malfunction duration
+    min_number_of_steps_broken = 0
+    max_number_of_steps_broken = 0
+
+    # Keep track of the total number of malfunctions in the env
+    global_nr_malfunctions = 0
+
+    # Malfunction calls per agent
+    malfunction_calls = dict()
+
+    def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
+        # We use the global variable to assure only a single malfunction in the env
+        nonlocal global_nr_malfunctions
+        nonlocal malfunction_calls
+
+        # Reset malfunciton generator
+        if reset:
+            nonlocal global_nr_malfunctions
+            nonlocal malfunction_calls
+            global_nr_malfunctions = 0
+            malfunction_calls = dict()
+            return Malfunction(0)
+
+        # No more malfunctions if we already had one, ignore all updates
+        if global_nr_malfunctions > 0:
+            return Malfunction(0)
+
+        # Update number of calls per agent
+        if agent.handle in malfunction_calls:
+            malfunction_calls[agent.handle] += 1
+        else:
+            malfunction_calls[agent.handle] = 1
+
+        # Break an agent that is active at the time of the malfunction
+        if agent.status == RailAgentStatus.ACTIVE and malfunction_calls[agent.handle] >= earlierst_malfunction:
+            global_nr_malfunctions += 1
+            return Malfunction(malfunction_duration)
+        else:
+            return Malfunction(0)
+
+    return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
+                                             max_number_of_steps_broken)
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index cab96ab0a6ac35db74ff74499c334867980d0a0a..5ec4db3c27749327e539b94543758aa31c69c707 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -370,6 +370,9 @@ class RailEnv(Environment):
         self.obs_builder.reset()
         self.distance_map.reset(self.agents, self.rail)
 
+        # Reset the malfunction generator
+        self.malfunction_generator(reset=True)
+
         info_dict: Dict = {
             'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)},
             'malfunction': {
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index df398a21583c7c24bc4cb5e1a08e7a517ea3483c..2d3fbd42d353e57b8d510c5bc3e0ef8118bdecaf 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -8,7 +8,7 @@ from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.envs.agent_utils import RailAgentStatus
-from flatland.envs.malfunction_generators import malfunction_from_params
+from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import rail_from_grid_transition_map
 from flatland.envs.schedule_generators import random_schedule_generator
@@ -67,9 +67,10 @@ class SingleAgentNavigationObs(ObservationBuilder):
 
 def test_malfunction_process():
     # Set fixed malfunction duration for this test
-    stochastic_data = {'malfunction_rate': 1,
-                       'min_duration': 3,
-                       'max_duration': 3}
+    stochastic_data = MalfunctionParameters(malfunction_rate=1,  # Rate of malfunction occurence
+                                            min_duration=3,  # Minimal duration of malfunction
+                                            max_duration=3  # Max duration of malfunction
+                                            )
 
     rail, rail_map = make_simple_rail2()
 
@@ -120,9 +121,10 @@ def test_malfunction_process():
 def test_malfunction_process_statistically():
     """Tests hat malfunctions are produced by stochastic_data!"""
     # Set fixed malfunction duration for this test
-    stochastic_data = {'malfunction_rate': 5,
-                       'min_duration': 5,
-                       'max_duration': 5}
+    stochastic_data = MalfunctionParameters(malfunction_rate=5,  # Rate of malfunction occurence
+                                            min_duration=5,  # Minimal duration of malfunction
+                                            max_duration=5  # Max duration of malfunction
+                                            )
 
     rail, rail_map = make_simple_rail2()
 
@@ -166,9 +168,10 @@ def test_malfunction_process_statistically():
 def test_malfunction_before_entry():
     """Tests that malfunctions are working properly for agents before entering the environment!"""
     # Set fixed malfunction duration for this test
-    stochastic_data = {'malfunction_rate': 2,
-                       'min_duration': 10,
-                       'max_duration': 10}
+    stochastic_data = MalfunctionParameters(malfunction_rate=2,  # Rate of malfunction occurence
+                                            min_duration=10,  # Minimal duration of malfunction
+                                            max_duration=10  # Max duration of malfunction
+                                            )
 
     rail, rail_map = make_simple_rail2()
 
@@ -212,9 +215,10 @@ def test_malfunction_values_and_behavior():
 
     rail, rail_map = make_simple_rail2()
     action_dict: Dict[int, RailEnvActions] = {}
-    stochastic_data = {'malfunction_rate': 0.001,
-                       'min_duration': 10,
-                       'max_duration': 10}
+    stochastic_data = MalfunctionParameters(malfunction_rate=0.001,  # Rate of malfunction occurence
+                                            min_duration=10,  # Minimal duration of malfunction
+                                            max_duration=10  # Max duration of malfunction
+                                            )
     env = RailEnv(width=25,
                   height=30,
                   rail_generator=rail_from_grid_transition_map(rail),
@@ -237,10 +241,10 @@ def test_malfunction_values_and_behavior():
 
 
 def test_initial_malfunction():
-    stochastic_data = {'malfunction_rate': 1000,  # Rate of malfunction occurence
-                       'min_duration': 2,  # Minimal duration of malfunction
-                       'max_duration': 5  # Max duration of malfunction
-                       }
+    stochastic_data = MalfunctionParameters(malfunction_rate=1000,  # Rate of malfunction occurence
+                                            min_duration=2,  # Minimal duration of malfunction
+                                            max_duration=5  # Max duration of malfunction
+                                            )
 
     rail, rail_map = make_simple_rail2()
 
@@ -308,12 +312,6 @@ def test_initial_malfunction():
 
 
 def test_initial_malfunction_stop_moving():
-    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
-                       'malfunction_rate': 70,  # Rate of malfunction occurence
-                       'min_duration': 2,  # Minimal duration of malfunction
-                       'max_duration': 5  # Max duration of malfunction
-                       }
-
     rail, rail_map = make_simple_rail2()
 
     env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
@@ -394,11 +392,10 @@ def test_initial_malfunction_do_nothing():
     random.seed(0)
     np.random.seed(0)
 
-    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
-                       'malfunction_rate': 70,  # Rate of malfunction occurence
-                       'min_duration': 2,  # Minimal duration of malfunction
-                       'max_duration': 5  # Max duration of malfunction
-                       }
+    stochastic_data = MalfunctionParameters(malfunction_rate=70,  # Rate of malfunction occurence
+                                            min_duration=2,  # Minimal duration of malfunction
+                                            max_duration=5  # Max duration of malfunction
+                                            )
 
     rail, rail_map = make_simple_rail2()
 
@@ -479,10 +476,6 @@ def test_initial_malfunction_do_nothing():
 def tests_random_interference_from_outside():
     """Tests that malfunctions are produced by stochastic_data!"""
     # Set fixed malfunction duration for this test
-    stochastic_data = {'malfunction_rate': 1,
-                       'min_duration': 10,
-                       'max_duration': 10}
-
     rail, rail_map = make_simple_rail2()
     env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                   schedule_generator=random_schedule_generator(seed=2), number_of_agents=1, random_seed=1)
@@ -537,9 +530,6 @@ def test_last_malfunction_step():
     """
 
     # Set fixed malfunction duration for this test
-    stochastic_data = {'malfunction_rate': 5,
-                       'min_duration': 4,
-                       'max_duration': 4}
 
     rail, rail_map = make_simple_rail2()
 
diff --git a/tests/test_malfunction_generators.py b/tests/test_malfunction_generators.py
index 075edc139b6786933a32c915998c0fe56cb7a76c..51839babe563943a609492bcad64243d36105b5c 100644
--- a/tests/test_malfunction_generators.py
+++ b/tests/test_malfunction_generators.py
@@ -1,8 +1,5 @@
-import numpy as np
-
-from flatland.core.env_observation_builder import ObservationBuilder
-from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.malfunction_generators import malfunction_from_params, malfunction_from_file
+from flatland.envs.malfunction_generators import malfunction_from_params, malfunction_from_file, \
+    single_malfunction_generator, MalfunctionParameters
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import rail_from_grid_transition_map
 from flatland.envs.schedule_generators import random_schedule_generator
@@ -16,11 +13,10 @@ def test_malfanction_from_params():
     -------
 
     """
-    stochastic_data = {'malfunction_rate': 1000,  # Rate of malfunction occurence
-                       'min_duration': 2,  # Minimal duration of malfunction
-                       'max_duration': 5  # Max duration of malfunction
-                       }
-
+    stochastic_data = MalfunctionParameters(malfunction_rate=1000,  # Rate of malfunction occurence
+                                            min_duration=2,  # Minimal duration of malfunction
+                                            max_duration=5  # Max duration of malfunction
+                                            )
     rail, rail_map = make_simple_rail2()
 
     env = RailEnv(width=25,
@@ -43,10 +39,10 @@ def test_malfanction_to_and_from_file():
     -------
 
     """
-    stochastic_data = {'malfunction_rate': 1000,  # Rate of malfunction occurence
-                       'min_duration': 2,  # Minimal duration of malfunction
-                       'max_duration': 5  # Max duration of malfunction
-                       }
+    stochastic_data = MalfunctionParameters(malfunction_rate=1000,  # Rate of malfunction occurence
+                                            min_duration=2,  # Minimal duration of malfunction
+                                            max_duration=5  # Max duration of malfunction
+                                            )
 
     rail, rail_map = make_simple_rail2()
 
@@ -62,17 +58,50 @@ def test_malfanction_to_and_from_file():
 
     malfunction_generator, malfunction_process_data = malfunction_from_file("./malfunction_saving_loading_tests.pkl")
     env2 = RailEnv(width=25,
-                  height=30,
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(),
-                  number_of_agents=10,
-                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data)
-                  )
+                   height=30,
+                   rail_generator=rail_from_grid_transition_map(rail),
+                   schedule_generator=random_schedule_generator(),
+                   number_of_agents=10,
+                   malfunction_generator_and_process_data=malfunction_from_params(stochastic_data)
+                   )
 
     env2.reset()
 
-    assert env2.malfunction_process_data ==  env.malfunction_process_data
+    assert env2.malfunction_process_data == env.malfunction_process_data
     assert env2.malfunction_process_data.malfunction_rate == 1000
     assert env2.malfunction_process_data.min_duration == 2
     assert env2.malfunction_process_data.max_duration == 5
 
+
+def test_single_malfunction_generator():
+    """
+    Test single malfunction generator
+    Returns
+    -------
+
+    """
+
+    rail, rail_map = make_simple_rail2()
+    env = RailEnv(width=25,
+                  height=30,
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
+                  number_of_agents=10,
+                  malfunction_generator_and_process_data=single_malfunction_generator(earlierst_malfunction=10,
+                                                                                      malfunction_duration=5)
+                  )
+    for test in range(10):
+        env.reset()
+        action_dict = dict()
+        tot_malfunctions = 0
+        print(test)
+        for i in range(10):
+            for agent in env.agents:
+                # Go forward all the time
+                action_dict[agent.handle] = RailEnvActions(2)
+
+            env.step(action_dict)
+        for agent in env.agents:
+            # Go forward all the time
+            tot_malfunctions += agent.malfunction_data['nr_malfunctions']
+        assert tot_malfunctions == 1