diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 4f016f2b3d4ee3e1918fbf47c9492f2c84228999..2281282977d8c9d972f13526efa9d96abaf84a52 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -238,7 +238,7 @@ class RailEnv(Environment):
             agent.speed_data['position_fraction'] = 0.0
             agent.malfunction_data['malfunction'] = 0
 
-            self._agent_stopped(i_agent)
+            self._agent_malfunction(agent)
 
         self.num_resets += 1
         self._elapsed_steps = 0
@@ -253,29 +253,29 @@ class RailEnv(Environment):
         # Return the new observation vectors for each agent
         return self._get_observations()
 
-    def _agent_stopped(self, i_agent):
+    def _agent_malfunction(self, agent):
         # Decrease counter for next event
-        self.agents[i_agent].malfunction_data['next_malfunction'] -= 1
+        agent.malfunction_data['next_malfunction'] -= 1
 
-        # Only agents that have a positive rate for malfunctions are considered
-        if self.agents[i_agent].malfunction_data['malfunction_rate'] > 0 >= self.agents[i_agent].malfunction_data[
+        # Only agents that have a positive rate for malfunctions and are not currently broken are considered
+        if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data[
             'malfunction']:
 
             # If counter has come to zero --> Agent has malfunction
             # set next malfunction time and duration of current malfunction
-            if self.agents[i_agent].malfunction_data['next_malfunction'] <= 0:
+            if agent.malfunction_data['next_malfunction'] <= 0:
                 # Increase number of malfunctions
-                self.agents[i_agent].malfunction_data['nr_malfunctions'] += 1
+                agent.malfunction_data['nr_malfunctions'] += 1
 
                 # Next malfunction in number of stops
                 next_breakdown = int(
-                    np.random.exponential(scale=self.agents[i_agent].malfunction_data['malfunction_rate']))
-                self.agents[i_agent].malfunction_data['next_malfunction'] = next_breakdown
+                    np.random.exponential(scale=agent.malfunction_data['malfunction_rate']))
+                agent.malfunction_data['next_malfunction'] = next_breakdown
 
                 # Duration of current malfunction
                 num_broken_steps = np.random.randint(self.min_number_of_steps_broken,
                                                      self.max_number_of_steps_broken + 1) + 1
-                self.agents[i_agent].malfunction_data['malfunction'] = num_broken_steps
+                agent.malfunction_data['malfunction'] = num_broken_steps
 
     def step(self, action_dict_):
         self._elapsed_steps += 1
@@ -306,6 +306,9 @@ class RailEnv(Environment):
             agent.old_direction = agent.direction
             agent.old_position = agent.position
 
+            # Update the malfunction countdown and possibly trigger a new malfunction for this agent
+            self._agent_malfunction(agent)
+
             if self.dones[i_agent]:  # this agent has already completed...
                 continue
 
@@ -341,7 +344,6 @@ class RailEnv(Environment):
                 # Only allow halting an agent on entering new cells.
                 agent.moving = False
                 self.rewards_dict[i_agent] += stop_penalty
-                self._agent_stopped(i_agent)
 
             if not agent.moving and not (action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
                 # Allow agent to start with any forward or direction action
@@ -385,8 +387,6 @@ class RailEnv(Environment):
                                 self.rewards_dict[i_agent] += invalid_action_penalty
                                 self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
                                 self.rewards_dict[i_agent] += stop_penalty
-                                if agent.moving:
-                                    self._agent_stopped(i_agent)
                                 agent.moving = False
                                 continue
                         else:
@@ -394,8 +394,6 @@ class RailEnv(Environment):
                             self.rewards_dict[i_agent] += invalid_action_penalty
                             self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
                             self.rewards_dict[i_agent] += stop_penalty
-                            if agent.moving:
-                                self._agent_stopped(i_agent)
                             agent.moving = False
                             continue
 
@@ -416,14 +414,11 @@ class RailEnv(Environment):
                     agent.speed_data['position_fraction'] = 0.0
                 else:
                     # If the agent cannot move due to any reason, we set its state to not moving
-                    if agent.moving:
-                        self._agent_stopped(i_agent)
                     agent.moving = False
 
             if np.equal(agent.position, agent.target).all():
                 self.dones[i_agent] = True
                 agent.moving = False
-                # Do not call self._agent_stopped, as the agent has terminated its task
             else:
                 self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
 
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index 2e82d212687b7ef21832ea778ec2dd1552db6b36..67dcd25c0769e542fd9a03502c2a8c1b29333b2b 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -53,13 +53,13 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
 def test_malfunction_process():
     # Set fixed malfunction duration for this test
     stochastic_data = {'prop_malfunction': 1.,
-                       'malfunction_rate': 5,
+                       'malfunction_rate': 1000,
                        'min_duration': 3,
                        'max_duration': 3}
     np.random.seed(5)
 
-    env = RailEnv(width=14,
-                  height=14,
+    env = RailEnv(width=20,
+                  height=20,
                   rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999,
                                                         seed=0),
                   number_of_agents=2,
@@ -82,17 +82,17 @@ def test_malfunction_process():
 
         if step % 5 == 0:
             # Stop the agent and set it to be malfunctioning
-            actions[0] = 4
+            env.agents[0].malfunction_data['malfunction'] = -1  # clear any active malfunction so a new one triggers next step
             env.agents[0].malfunction_data['next_malfunction'] = 0
             agent_halts += 1
 
+        obs, all_rewards, done, _ = env.step(actions)
+
         if env.agents[0].malfunction_data['malfunction'] > 0:
             agent_malfunctioning = True
         else:
             agent_malfunctioning = False
 
-        obs, all_rewards, done, _ = env.step(actions)
-
         if agent_malfunctioning:
             # Check that agent is not moving while malfunctioning
             assert agent_old_position == env.agents[0].position
@@ -101,7 +101,7 @@ def test_malfunction_process():
         total_down_time += env.agents[0].malfunction_data['malfunction']
 
     # Check that the appropriate number of malfunctions is achieved
-    assert env.agents[0].malfunction_data['nr_malfunctions'] == 5
+    assert env.agents[0].malfunction_data['nr_malfunctions'] == 21
 
     # Check that 20 stops were performed
     assert agent_halts == 20
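
For reference, the per-step scheduling that _agent_malfunction now performs amounts to the following standalone sketch. It is illustrative only and not part of the patch; update_malfunction, min_steps_broken and max_steps_broken are hypothetical names standing in for the agent's malfunction_data handling and the env's min_number_of_steps_broken / max_number_of_steps_broken attributes.

    import numpy as np

    def update_malfunction(malfunction_data, min_steps_broken, max_steps_broken):
        """Decrement the breakdown countdown; when it expires and the agent is
        not already broken, schedule the next breakdown and draw its duration."""
        malfunction_data['next_malfunction'] -= 1
        if malfunction_data['malfunction_rate'] > 0 >= malfunction_data['malfunction']:
            if malfunction_data['next_malfunction'] <= 0:
                malfunction_data['nr_malfunctions'] += 1
                # Steps until the next breakdown: exponential draw whose scale is
                # the configured malfunction_rate, as in the patched code above.
                malfunction_data['next_malfunction'] = int(
                    np.random.exponential(scale=malfunction_data['malfunction_rate']))
                # Duration of the current breakdown.
                malfunction_data['malfunction'] = np.random.randint(
                    min_steps_broken, max_steps_broken + 1) + 1

    data = {'malfunction_rate': 5, 'next_malfunction': 1, 'malfunction': 0,
            'nr_malfunctions': 0}
    update_malfunction(data, min_steps_broken=3, max_steps_broken=3)
    print(data)  # e.g. next_malfunction drawn from Exp(scale=5), malfunction == 4

Because the rate is used as the exponential's scale, a larger malfunction_rate means a longer expected gap between spontaneous breakdowns; this is presumably why the updated test raises malfunction_rate to 1000 and instead forces breakdowns by resetting next_malfunction to 0 every five steps.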