diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index ab0e14879a354993331519315d7f1fbc423f2af2..4cd1d2d4683e9e945318c914c1e90ea7f2abeb52 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -15,7 +15,7 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.grid.grid_utils import IntVector2D
 from flatland.core.transition_map import GridTransitionMap
-from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
+from flatland.envs.agent_utils import Agent, EnvAgent, RailAgentStatus
 from flatland.envs.distance_map import DistanceMap
 from flatland.envs.rail_env_action import RailEnvActions
 
@@ -475,6 +475,34 @@ class RailEnv(Environment):
 
         return
 
+    def _handle_end_reward(self, agent: EnvAgent) -> int:
+        '''
+        Handles end-of-episode reward for a particular agent.
+
+        Parameters
+        ----------
+        agent : EnvAgent
+        '''
+        reward = None
+        # agent done? (arrival_time is not None)
+        if agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED:
+            # if agent arrived earlier or on time = 0
+            # if agent arrived later = -ve reward based on how late
+            reward = min(agent.latest_arrival - agent.arrival_time, 0)
+
+        # Agents not done (arrival_time is None)
+        else:
+            # CANCELLED check (never departed)
+            if (agent.status == RailAgentStatus.READY_TO_DEPART):
+                reward = -1 * self.cancellation_factor * \
+                    (agent.get_travel_time_on_shortest_path(self.distance_map) + 0)  # 0 replaced with buffer
+
+            # Departed but never reached
+            if (agent.status == RailAgentStatus.ACTIVE):
+                reward = agent.get_current_delay(self._elapsed_steps, self.distance_map)
+        
+        return reward
+
     def step(self, action_dict_: Dict[int, RailEnvActions]):
         """
         Updates rewards for the agents at a step.
@@ -564,26 +592,8 @@ class RailEnv(Environment):
             for i_agent, agent in enumerate(self.agents):
                 
                 # agent done? (arrival_time is not None)
-                if agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED:
-                    
-                    # if agent arrived earlier or on time = 0
-                    # if agent arrived later = -ve reward based on how late
-                    reward = min(agent.latest_arrival - agent.arrival_time, 0)
-                    self.rewards_dict[i_agent] += reward
-                
-                # Agents not done (arrival_time is None)
-                else:
-                    
-                    # CANCELLED check (never departed)
-                    if (agent.status == RailAgentStatus.READY_TO_DEPART):
-                        reward = -1 * self.cancellation_factor * \
-                            (agent.get_travel_time_on_shortest_path(self.distance_map) + 0) # 0 replaced with buffer
-                        self.rewards_dict[i_agent] += reward
-
-                    # Departed but never reached
-                    if (agent.status == RailAgentStatus.ACTIVE):
-                        reward = agent.get_current_delay(self._elapsed_steps, self.distance_map)
-                        self.rewards_dict[i_agent] += reward
+                reward = self._handle_end_reward(agent)
+                self.rewards_dict[i_agent] += reward
                 
                 self.dones[i_agent] = True