diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 3bbc1ab092a92d479f91934a871cbbe2aaa30583..3cb38c16c3ea1cf223f9b6c32a06e5faf96c0d52 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -436,11 +436,6 @@ class RailEnv(Environment):
 
         self._elapsed_steps += 1
 
-        # Reset the step rewards
-        self.rewards_dict = dict()
-        for i_agent in range(self.get_num_agents()):
-            self.rewards_dict[i_agent] = 0
-
         # If we're done, set reward and info_dict and step() is done.
         if self.dones["__all__"]:
             self.rewards_dict = {}
@@ -450,39 +445,50 @@ class RailEnv(Environment):
                 "speed" : {},
                 "status" : {},
             }
-            for i_agent in range(self.get_num_agents()):
+            for i_agent, agent in enumerate(self.agents):
                 self.rewards_dict[i_agent] = self.global_reward
                 info_dict["action_required"][i_agent] = False
                 info_dict["malfunction"][i_agent] = 0
                 info_dict["speed"][i_agent] = 0
-                info_dict["status"][i_agent] = self.agents[i_agent].status
+                info_dict["status"][i_agent] = agent.status
 
             return self._get_observations(), self.rewards_dict, self.dones, info_dict
 
-        # Perform step on all agents
-        for i_agent in range(self.get_num_agents()):
+        # Reset the step rewards
+        self.rewards_dict = dict()
+        info_dict = {
+            "action_required" : {},
+            "malfunction" : {},
+            "speed" : {},
+            "status" : {},
+        }
+        have_all_agents_ended = True # boolean flag to check if all agents are done
+        for i_agent, agent in enumerate(self.agents):
+            # Reset the step rewards
+            self.rewards_dict[i_agent] = 0
+
+            # Perform step on the agent
             self._step_agent(i_agent, action_dict_.get(i_agent))
 
+            # manage the boolean flag to check if all agents are indeed done (or done_removed)
+            have_all_agents_ended &= (agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED])
+
+            # Build info dict
+            info_dict["action_required"][i_agent] = \
+                (agent.status == RailAgentStatus.READY_TO_DEPART or (
+                agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0,
+                                                                        rtol=1e-03)))
+            info_dict["malfunction"][i_agent] = agent.malfunction_data['malfunction']
+            info_dict["speed"][i_agent] = agent.speed_data['speed']
+            info_dict["status"][i_agent] = agent.status
+
         # Check for end of episode + set global reward to all rewards!
-        if np.all([agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED] for agent in self.agents]):
+        if have_all_agents_ended:
             self.dones["__all__"] = True
             self.rewards_dict = {i: self.global_reward for i in range(self.get_num_agents())}
         if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps):
             self.dones["__all__"] = True
 
-        info_dict = {
-            'action_required': {
-                i: (agent.status == RailAgentStatus.READY_TO_DEPART or (
-                    agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0,
-                                                                          rtol=1e-03)))
-                for i, agent in enumerate(self.agents)},
-            'malfunction': {
-                i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents())
-            },
-            'speed': {i: self.agents[i].speed_data['speed'] for i in range(self.get_num_agents())},
-            'status': {i: agent.status for i, agent in enumerate(self.agents)}
-        }
-
         return self._get_observations(), self.rewards_dict, self.dones, info_dict
 
     def _step_agent(self, i_agent, action: Optional[RailEnvActions] = None):