diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 0887c0ca59d1609d27b37aa318022750924f4ea3..22bd21625fd4cc0fb50c592bf4533e60ced00546 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -280,6 +280,27 @@ class RailEnv(Environment):
         alpha = 2
         return int(timedelay_factor * alpha * (width + height + ratio_nr_agents_to_nr_cities))
 
+    def action_required(self, agent):
+        """
+        Check if an agent needs to provide an action
+
+        Parameters
+        ----------
+        agent: RailEnvAgent
+            Agent we want to check
+
+        Returns
+        -------
+        True: Agent needs to provide an action
+        False: Agent does not need to provide an action (it is still moving within a cell or already done)
+        """
+        return (agent.status == RailAgentStatus.READY_TO_DEPART or (
+            agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0,
+                                                                  rtol=1e-03)))
+
     def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False,
               random_seed: bool = None) -> (Dict, Dict):
         """
@@ -339,8 +360,8 @@
             if agents_hints and 'city_orientations' in agents_hints:
                 ratio_nr_agents_to_nr_cities = self.get_num_agents() / len(agents_hints['city_orientations'])
                 self._max_episode_steps = self.compute_max_episode_steps(
-                                                    width=self.width, height=self.height,
-                                                    ratio_nr_agents_to_nr_cities=ratio_nr_agents_to_nr_cities)
+                    width=self.width, height=self.height,
+                    ratio_nr_agents_to_nr_cities=ratio_nr_agents_to_nr_cities)
             else:
                 self._max_episode_steps = self.compute_max_episode_steps(width=self.width, height=self.height)
 
@@ -377,10 +398,8 @@
         self.distance_map.reset(self.agents, self.rail)
 
         info_dict: Dict = {
-            'action_required': {
-                i: (agent.status == RailAgentStatus.READY_TO_DEPART or (
-                    agent.status == RailAgentStatus.ACTIVE and agent.speed_data['position_fraction'] == 0.0))
-                for i, agent in enumerate(self.agents)},
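+            # Delegate to the shared action_required() helper (same check as in step()).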
+            'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)},
             'malfunction': {
                 i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents())
             },
@@ -454,10 +473,10 @@
         if self.dones["__all__"]:
             self.rewards_dict = {}
             info_dict = {
-                "action_required" : {},
-                "malfunction" : {},
-                "speed" : {},
-                "status" : {},
+                "action_required": {},
+                "malfunction": {},
+                "speed": {},
+                "status": {},
             }
             for i_agent, agent in enumerate(self.agents):
                 self.rewards_dict[i_agent] = self.global_reward
@@ -471,12 +490,12 @@
         # Reset the step rewards
         self.rewards_dict = dict()
         info_dict = {
-            "action_required" : {},
-            "malfunction" : {},
-            "speed" : {},
-            "status" : {},
+            "action_required": {},
+            "malfunction": {},
+            "speed": {},
+            "status": {},
         }
-        have_all_agents_ended = True # boolean flag to check if all agents are done
+        have_all_agents_ended = True  # boolean flag to check if all agents are done
         for i_agent, agent in enumerate(self.agents):
             # Reset the step rewards
             self.rewards_dict[i_agent] = 0
@@ -488,10 +507,8 @@
             have_all_agents_ended &= (agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED])
 
             # Build info dict
-            info_dict["action_required"][i_agent] = \
-                (agent.status == RailAgentStatus.READY_TO_DEPART or (
-                agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0,
-                                                                        rtol=1e-03)))
+            info_dict["action_required"][i_agent] = self.action_required(agent)
             info_dict["malfunction"][i_agent] = agent.malfunction_data['malfunction']
             info_dict["speed"][i_agent] = agent.speed_data['speed']
             info_dict["status"][i_agent] = agent.status