diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 75b10e8f71b45b078017f0e1ff9061249c594c75..1410556e9701b56c6b2f86e545b3baa7c240022f 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -461,39 +461,32 @@ class RailEnv(Environment):
         action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction)
 
         return action
+
+    def clear_rewards_dict(self):
+        """Reset the per-step rewards (clears ``rewards_dict``)."""
 
-    def step(self, action_dict_: Dict[int, RailEnvActions]):
-        """
-        Updates rewards for the agents at a step.
-        """
-        self._elapsed_steps += 1
-
-        # If we're done, set reward and info_dict and step() is done.
-        if self.dones["__all__"]: # TODO: Move boilerplate to different function
-            self.rewards_dict = {}
-            info_dict = {
-                "action_required": {},
-                "malfunction": {},
-                "speed": {},
-                "status": {},
-            }
-            for i_agent, agent in enumerate(self.agents):
-                self.rewards_dict[i_agent] = self.global_reward
-                info_dict["action_required"][i_agent] = False
-                info_dict["malfunction"][i_agent] = 0
-                info_dict["speed"][i_agent] = 0
-                info_dict["status"][i_agent] = agent.status
-
-            return self._get_observations(), self.rewards_dict, self.dones, info_dict
-
-        # Reset the step rewards
         self.rewards_dict = dict()
+
+    def get_info_dict(self):  # TODO(important): populate per-agent info instead of returning empty dicts
         info_dict = {
             "action_required": {},
             "malfunction": {},
             "speed": {},
             "status": {},
         }
+        return info_dict
+
+    def step(self, action_dict_: Dict[int, RailEnvActions]):
+        """
+        Advance the environment by one time step and update the agents' rewards.
+        """
+        self._elapsed_steps += 1
+
+        # Not allowed to step further once done
+        if self.dones["__all__"]:
+            raise Exception("Episode is done, cannot call step()")
+
+        self.clear_rewards_dict()
 
         self.motionCheck = ac.MotionCheck()  # reset the motion check
 
@@ -578,7 +571,7 @@ class RailEnv(Environment):
                 agent.action_saver.clear_saved_action()
         
         self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))} # TODO : Rewards - Remove this
-        return self._get_observations(), self.rewards_dict, self.dones, info_dict # TODO : Will need changes?
+        return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict()
 
     def record_timestep(self, dActions):
         ''' Record the positions and orientations of all agents in memory, in the cur_episode