diff --git a/examples/training_example.py b/examples/training_example.py
index 17484ad7422e327ad5b50200f6e1726f19a43594..6910461327c778ff52824165032641ece019cf7a 100644
--- a/examples/training_example.py
+++ b/examples/training_example.py
@@ -75,13 +75,11 @@ for trials in range(1, n_trials + 1):
     score = 0
 
     # Run episode
-    mean_malfunction_interval = []
     for step in range(100):
         # Chose an action for each agent in the environment
         for a in range(env.get_num_agents()):
             action = agent.act(obs[a])
             action_dict.update({a: action})
-
         # Environment step which returns the observations for all agents, their corresponding
         # reward and whether their are done
         next_obs, all_rewards, done, _ = env.step(action_dict)
@@ -95,5 +93,4 @@ for trials in range(1, n_trials + 1):
         obs = next_obs.copy()
         if done['__all__']:
             break
-    print(np.mean(mean_malfunction_interval))
     print('Episode Nr. {}\t Score = {}'.format(trials, score))
diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index 27a7a380a23d0a1bd6352e06b00e8e63deafb71a..4c4070088c59499f885c16db68976c163ec91001 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -26,7 +26,8 @@ class EnvAgentStatic(object):
     # if broken>0, the agent's actions are ignored for 'broken' steps
     # number of time the agent had to stop, since the last time it broke down
     malfunction_data = attrib(
-        default=Factory(lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0})))
+        default=Factory(
+            lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0})))
 
     @classmethod
     def from_lists(cls, positions, directions, targets, speeds=None):
@@ -40,18 +41,19 @@ class EnvAgentStatic(object):
                                 'transition_action_on_cellexit': 0})
 
         # TODO: on initialization, all agents are re-set as non-broken. Perhaps it may be desirable to set
        # some as broken?
-        broken_datas = []
+        malfunction_datas = []
         for i in range(len(positions)):
-            broken_datas.append({'malfunction': 0,
+            malfunction_datas.append({'malfunction': 0,
                                  'malfunction_rate': 0,
-                                 'next_malfunction': 0})
+                                 'next_malfunction': 0,
+                                 'nr_malfunctions': 0})
 
         return list(starmap(EnvAgentStatic, zip(positions,
                                                 directions,
                                                 targets,
                                                 [False] * len(positions),
                                                 speed_datas,
-                                                broken_datas)))
+                                                malfunction_datas)))
 
     def to_list(self):
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index cfd8dad4e80f1d4ce3783a594de3c890191ce0f6..d62c689aa9f6c788b54474b884d4101d98fb4ff0 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -94,7 +94,8 @@ class RailEnv(Environment):
                  rail_generator=random_rail_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2),
-                 max_episode_steps=None
+                 max_episode_steps=None,
+                 stochastic_data=None
                  ):
         """
         Environment init.
@@ -158,12 +159,26 @@ class RailEnv(Environment):
         self.observation_space = self.obs_builder.observation_space  # updated on resets?
 
         # Stochastic train malfunctioning parameters
-        self.proportion_malfunctioning_trains = 0.1  # percentage of malfunctioning trains
-        self.mean_malfunction_rate = 5  # Average malfunction in number of stops
+        if stochastic_data is not None:
+            prop_malfunction = stochastic_data['prop_malfunction']
+            mean_malfunction_rate = stochastic_data['malfunction_rate']
+            malfunction_min_duration = stochastic_data['min_duration']
+            malfunction_max_duration = stochastic_data['max_duration']
+        else:
+            prop_malfunction = 0.
+            mean_malfunction_rate = 0.
+            malfunction_min_duration = 0.
+            malfunction_max_duration = 0.
+
+        # percentage of malfunctioning trains
+        self.proportion_malfunctioning_trains = prop_malfunction
+
+        # Mean malfunction in number of stops
+        self.mean_malfunction_rate = mean_malfunction_rate
 
         # Uniform distribution parameters for malfunction duration
-        self.min_number_of_steps_broken = 4
-        self.max_number_of_steps_broken = 10
+        self.min_number_of_steps_broken = malfunction_min_duration
+        self.max_number_of_steps_broken = malfunction_max_duration
 
         # Rest environment
         self.reset()
@@ -217,8 +232,9 @@ class RailEnv(Environment):
             agent = self.agents[i_agent]
 
             # A proportion of agent in the environment will receive a positive malfunction rate
-            if np.random.random() >= self.proportion_malfunctioning_trains:
+            if np.random.random() < self.proportion_malfunctioning_trains:
                 agent.malfunction_data['malfunction_rate'] = self.mean_malfunction_rate
+
             agent.speed_data['position_fraction'] = 0.0
             agent.malfunction_data['malfunction'] = 0
 
@@ -236,21 +252,23 @@ class RailEnv(Environment):
 
         return self._get_observations()
 
     def _agent_stopped(self, i_agent):
-        # Make sure agent is stopped
-        self.agents[i_agent].moving = False
+        # Decrease counter for next event
+        self.agents[i_agent].malfunction_data['next_malfunction'] -= 1
 
         # Only agents that have a positive rate for malfunctions are considered
-        if self.agents[i_agent].malfunction_data['malfunction_rate'] > 0:
-
-            # Decrease counter for next event
-            self.agents[i_agent].malfunction_data['next_malfunction'] -= 1
-
-            # If counter has come to zero, set next malfunction time and duration of current malfunction
+        if self.agents[i_agent].malfunction_data['malfunction_rate'] > 0 >= self.agents[i_agent].malfunction_data[
+            'malfunction']:
+            # If counter has come to zero --> Agent has malfunction
+            # set next malfunction time and duration of current malfunction
             if self.agents[i_agent].malfunction_data['next_malfunction'] <= 0:
+                # Increase number of malfunctions
+                self.agents[i_agent].malfunction_data['nr_malfunctions'] += 1
+
                 # Next malfunction in number of stops
-                self.agents[i_agent].malfunction_data['next_malfunction'] = int(np.random.exponential(
-                    scale=self.agents[i_agent].malfunction_data['malfunction_rate']))
+                next_breakdown = int(
+                    np.random.exponential(scale=self.agents[i_agent].malfunction_data['malfunction_rate']))
+                self.agents[i_agent].malfunction_data['next_malfunction'] = next_breakdown
 
                 # Duration of current malfunction
                 num_broken_steps = np.random.randint(self.min_number_of_steps_broken,
@@ -286,9 +304,6 @@ class RailEnv(Environment):
                 agent.old_direction = agent.direction
                 agent.old_position = agent.position
 
-            if agent.malfunction_data['malfunction'] > 0:
-                agent.malfunction_data['malfunction'] -= 1
-
             if self.dones[i_agent]:  # this agent has already completed...
                 continue
 
@@ -298,8 +313,16 @@ class RailEnv(Environment):
 
             # The train is broken
             if agent.malfunction_data['malfunction'] > 0:
+                agent.malfunction_data['malfunction'] -= 1
+
+                # Broken agents are stopped
+                self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
+                self.agents[i_agent].moving = False
                 action_dict[i_agent] = RailEnvActions.DO_NOTHING
 
+                # Nothing left to do with broken agent
+                continue
+
             if action_dict[i_agent] < 0 or action_dict[i_agent] > len(RailEnvActions):
                 print('ERROR: illegal action=', action_dict[i_agent],
                       'for agent with index=', i_agent,
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
new file mode 100644
index 0000000000000000000000000000000000000000..91c551db60f9d71d7aa0774ea8b6aaf42af3e35b
--- /dev/null
+++ b/tests/test_flatland_malfunction.py
@@ -0,0 +1,88 @@
+import numpy as np
+
+from flatland.envs.generators import complex_rail_generator
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.rail_env import RailEnv
+
+
+class SingleAgentNavigationObs(TreeObsForRailEnv):
+    """
+    We derive our observation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
+    the minimum distances from each grid node to each agent's target.
+
+    We then build a representation vector with 3 binary components, indicating which of the 3 available directions
+    for each agent (Left, Forward, Right) lead to the shortest path to its target.
+    E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
+    will be [1, 0, 0].
+    """
+
+    def __init__(self):
+        super().__init__(max_depth=0)
+        self.observation_space = [3]
+
+    def reset(self):
+        # Recompute the distance map, if the environment has changed.
+        super().reset()
+
+    def get(self, handle):
+        agent = self.env.agents[handle]
+
+        possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
+        num_transitions = np.count_nonzero(possible_transitions)
+
+        # Start from the current orientation, and see which transitions are available;
+        # organize them as [left, forward, right], relative to the current orientation
+        # If only one transition is possible, the forward branch is aligned with it.
+        if num_transitions == 1:
+            observation = [0, 1, 0]
+        else:
+            min_distances = []
+            for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
+                if possible_transitions[direction]:
+                    new_position = self._new_position(agent.position, direction)
+                    min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction])
+                else:
+                    min_distances.append(np.inf)
+
+            observation = [0, 0, 0]
+            observation[np.argmin(min_distances)] = 1
+
+        return observation
+
+
+def test_malfunction_process():
+    stochastic_data = {'prop_malfunction': 1.,
+                       'malfunction_rate': 5,
+                       'min_duration': 3,
+                       'max_duration': 10}
+    np.random.seed(5)
+
+    env = RailEnv(width=14,
+                  height=14,
+                  rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999,
+                                                        seed=0),
+                  number_of_agents=2,
+                  obs_builder_object=SingleAgentNavigationObs(),
+                  stochastic_data=stochastic_data)
+
+    obs = env.reset()
+    agent_halts = 0
+    for step in range(100):
+        actions = {}
+        for i in range(len(obs)):
+            actions[i] = np.argmax(obs[i]) + 1
+
+        if step % 5 == 0:
+            actions[0] = 4
+            agent_halts += 1
+
+        obs, all_rewards, done, _ = env.step(actions)
+
+        if done["__all__"]:
+            break
+
+    # Check that the agent breaks twice
+    assert env.agents[0].malfunction_data['nr_malfunctions'] == 2
+
+    # Check that 7 stops were performed
+    assert agent_halts == 7
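
As a standalone illustration of the malfunction model this patch introduces: every stop decrements next_malfunction; when it reaches zero while the agent is healthy, nr_malfunctions is incremented, the gap to the next breakdown is drawn from an exponential distribution with scale malfunction_rate, and the breakdown lasts a uniformly drawn number of steps between the configured min and max durations. The sketch below replays only that bookkeeping with plain numpy. It is a hypothetical helper, not flatland code: it folds the per-step duration countdown from RailEnv.step into the same loop as the counters from _agent_stopped, and the inclusive upper bound on the duration draw is an assumption, since the np.random.randint call is truncated in the hunk above.

import numpy as np


def simulate_stops(n_stops, malfunction_rate=5, min_duration=3, max_duration=10, seed=5):
    # Hypothetical helper, not part of the flatland API.
    rng = np.random.RandomState(seed)
    data = {'malfunction': 0, 'malfunction_rate': malfunction_rate,
            'next_malfunction': 0, 'nr_malfunctions': 0}
    for _ in range(n_stops):
        # Decrease counter for the next event (mirrors _agent_stopped)
        data['next_malfunction'] -= 1
        if data['malfunction'] > 0:
            # Currently broken: sit out one step (mirrors the step() hunk)
            data['malfunction'] -= 1
        elif data['malfunction_rate'] > 0 and data['next_malfunction'] <= 0:
            # Counter reached zero: record a breakdown and draw the next one
            data['nr_malfunctions'] += 1
            data['next_malfunction'] = int(rng.exponential(scale=data['malfunction_rate']))
            # Duration drawn uniformly; inclusive upper bound is an assumption
            data['malfunction'] = rng.randint(min_duration, max_duration + 1)
    return data['nr_malfunctions']


print(simulate_stops(100))

Running this gives a feel for how malfunction_rate trades off against episode length: with scale 5 and durations of roughly 3 to 10 steps, an always-stopping agent completes one breakdown cycle about every 8 to 12 stops, which is why the short test episode above only accumulates two malfunctions.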