diff --git a/examples/introduction_flatland_2_1.py b/examples/introduction_flatland_2_1.py
index 7a89add9d3628bc2b509c862c86c2cb9110ce66d..1e80021c4c3c2f06e6d5c897692b814a5d063025 100644
--- a/examples/introduction_flatland_2_1.py
+++ b/examples/introduction_flatland_2_1.py
@@ -1,6 +1,8 @@
 # In Flatland you can use custom observation builders and predicitors
 # Observation builders generate the observation needed by the controller
 # Preditctors can be used to do short time prediction which can help in avoiding conflicts in the network
+import time
+
 from flatland.envs.observations import GlobalObsForRailEnv
 # First of all we import the Flatland rail environment
 from flatland.envs.rail_env import RailEnv
@@ -26,8 +28,8 @@ from flatland.utils.rendertools import RenderTool, AgentRenderVariant
 # Here we use the sparse_rail_generator with the following parameters
 
 width = 100  # With of map
-height = 100  # Height of ap
-nr_trains = 10  # Number of trains that have an assigned task in the env
+height = 100  # Height of map
+nr_trains = 50  # Number of trains that have an assigned task in the env
 cities_in_map = 20  # Number of cities where agents can start or end
 seed = 14  # Random seed
 grid_distribution_of_cities = False  # Type of city distribution, if False cities are randomly placed
@@ -151,14 +153,14 @@ for agent_idx, agent in enumerate(env.agents):
 # If multiple agents want to enter the same cell at the same time the lower index agent will enter first.
 
 # Let's check if there are any agents with the same start location
-agents_with_same_start = []
+agents_with_same_start = set()
 print("\n The following agents have the same initial position:")
 print("=====================================================")
 for agent_idx, agent in enumerate(env.agents):
     for agent_2_idx, agent2 in enumerate(env.agents):
         if agent_idx != agent_2_idx and agent.initial_position == agent2.initial_position:
             print("Agent {} as the same initial position as agent {}".format(agent_idx, agent_2_idx))
-            agents_with_same_start.append(agent_idx)
+            agents_with_same_start.add(agent_idx)
 
 # Lets try to enter with all of these agents at the same time
 action_dict = dict()
@@ -246,8 +248,11 @@ for step in range(500):
 
     # Environment step which returns the observations for all agents, their corresponding
     # reward and whether their are done
+    start_time = time.time()
     next_obs, all_rewards, done, _ = env.step(action_dict)
-    env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
+    end_time = time.time()
+    print(end_time - start_time)
+    # env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
     frame_step += 1
     # Update replay buffer and train agent
     for a in range(env.get_num_agents()):
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 87285812b877416e64aea4370b96ab649df7d6a2..d576641632bfe1d0ae20baf485afdb2aca1d640b 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -308,7 +308,9 @@ class RailEnv(Environment):
             # A proportion of agent in the environment will receive a positive malfunction rate
             if self.np_random.rand() < self.proportion_malfunctioning_trains:
                 agent.malfunction_data['malfunction_rate'] = self.mean_malfunction_rate
-
+                next_breakdown = int(
+                    self._exp_distirbution_synced(rate=agent.malfunction_data['malfunction_rate']))
+                agent.malfunction_data['next_malfunction'] = next_breakdown
             agent.malfunction_data['malfunction'] = 0
 
             initial_malfunction = self._agent_malfunction(i_agent)
@@ -346,7 +348,7 @@ class RailEnv(Environment):
         """
         agent = self.agents[i_agent]
 
-        # Decrease counter for next event only if agent is currently not broken
+        # Decrease counter for next event only if agent is currently not broken and agent has a malfunction rate
         if agent.malfunction_data['malfunction_rate'] >= 1 and agent.malfunction_data['next_malfunction'] > 0 and \
             agent.malfunction_data['malfunction'] < 1:
             agent.malfunction_data['next_malfunction'] -= 1
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index 8008e6e2ea4f5aabe98da7a5bff833714361c66a..e9b5a15dade4fd30d6718886eed02a483172f159 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -126,7 +126,7 @@ def test_malfunction_process():
         env.agents[0].malfunction_data['nr_malfunctions'])
 
     # Check that 20 stops where performed
-    assert agent_halts == 20
+    assert agent_halts == 21
 
     # Check that malfunctioning data was standing around
     assert total_down_time > 0
@@ -155,16 +155,16 @@ def test_malfunction_process_statistically():
 
     env.agents[0].target = (0, 0)
     nb_malfunction = 0
-    agent_malfunction_list = [[6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0],
-                              [6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1],
-                              [6, 6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3],
-                              [6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0],
-                              [6, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 6],
-                              [6, 6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6],
-                              [6, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2],
-                              [6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5],
-                              [6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0],
-                              [6, 6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3]]
+    agent_malfunction_list = [[0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0],
+                              [0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 6, 5],
+                              [0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4],
+                              [0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 6, 5, 4],
+                              [6, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0],
+                              [6, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4, 3],
+                              [0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5],
+                              [0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4, 3, 2, 1, 0],
+                              [0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1],
+                              [6, 6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0]]
 
     for step in range(20):
         action_dict: Dict[int, RailEnvActions] = {}
@@ -175,6 +175,7 @@ def test_malfunction_process_statistically():
             # agent_malfunction_list[agent_idx].append(env.agents[agent_idx].malfunction_data['malfunction'])
             assert env.agents[agent_idx].malfunction_data['malfunction'] == agent_malfunction_list[agent_idx][step]
         env.step(action_dict)
+    # print(agent_malfunction_list)
 
 
 def test_malfunction_before_entry():
@@ -230,14 +231,13 @@ def test_malfunction_before_entry():
     assert env.agents[8].malfunction_data['malfunction'] == 2
     assert env.agents[9].malfunction_data['malfunction'] == 2
 
-    #for a in range(env.get_num_agents()):
+    # for a in range(env.get_num_agents()):
     #    print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a,
     #                                                                               env.agents[a].malfunction_data[
     #                                                                                   'malfunction']))
 
 
 def test_initial_malfunction():
-
     stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                        'malfunction_rate': 100,  # Rate of malfunction occurence
                        'min_duration': 2,  # Minimal duration of malfunction
@@ -410,7 +410,6 @@ def test_initial_malfunction_do_nothing():
 
     rail, rail_map = make_simple_rail2()
 
-
     env = RailEnv(width=25,
                   height=30,
                   rail_generator=rail_from_grid_transition_map(rail),