diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index f79ab504b4d02dc6e58e52f9ca238be62f5dcb6b..b64716f58502c6b4dc13b246f93f83ba5e88d41c 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -24,7 +24,6 @@ from flatland.envs.schedule_generators import random_schedule_generator, Schedul
 
 m.patch()
 
-DEPOT_POSITION = (-10, -10)
 
 class RailEnvActions(IntEnum):
     DO_NOTHING = 0  # implies change of direction in a dead-end!
@@ -110,6 +109,10 @@ class RailEnv(Environment):
     stop_penalty = 0  # penalty for stopping a moving agent
     start_penalty = 0  # penalty for starting a stopped agent
 
+    # Where the agent will be placed to when the reach their target destination
+    # (Remove the agents to free the cell)
+    DEPOT_POSITION = (-10, -10)
+
     def __init__(self,
                  width,
                  height,
@@ -118,7 +121,8 @@ class RailEnv(Environment):
                  number_of_agents=1,
                  obs_builder_object: ObservationBuilder = TreeObsForRailEnv(max_depth=2),
                  max_episode_steps=None,
-                 stochastic_data=None
+                 stochastic_data=None,
+                 remove_agents_at_target=False
                  ):
         """
         Environment init.
@@ -149,6 +153,9 @@ class RailEnv(Environment):
             ObservationBuilder-derived object that takes builds observation
             vectors for each agent.
         max_episode_steps : int or None
+        remove_agents_at_target : bool
+            If remove_agents_at_target is set to true then the agents will be removed by placing to
+            RailEnv.DEPOT_POSITION when the agent has reach it's target position.
         """
         super().__init__()
 
@@ -159,6 +166,8 @@ class RailEnv(Environment):
         self.width = width
         self.height = height
 
+        self.remove_agents_at_target = remove_agents_at_target
+
         self.rewards = [0] * number_of_agents
         self.done = False
         self.obs_builder = obs_builder_object
@@ -503,14 +512,16 @@ class RailEnv(Environment):
             if np.equal(agent.position, agent.target).all():
                 self.dones[i_agent] = True
                 agent.moving = False
-                # TODO: Moving agents to arbitrary position
-                agent.position = DEPOT_POSITION
-                agent.target = DEPOT_POSITION
+
+                if self.remove_agents_at_target:
+                    agent.position = RailEnv.DEPOT_POSITION
+                    agent.target = RailEnv.DEPOT_POSITION
             else:
                 self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
         else:
             # step penalty if not moving (stopped now or before)
             self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
+
     def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent):
         """
 
@@ -549,15 +560,7 @@ class RailEnv(Environment):
 
         # Check the new position is not the same as any of the existing agent positions
         # (including itself, for simplicity, since it is moving)
-        # TODO: Revert to earlier version
-        cell_free = True
-        for agent2 in self.agents:
-            if Vec2dOperations.is_equal(new_position, agent2.position) and not Vec2dOperations.is_equal(agent2.target,
-                                                                                                        agent2.position):
-                cell_free = False
-                break
-
-
+        cell_free = not np.any(np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
         return cell_free, new_cell_valid, new_direction, new_position, transition_valid
 
     def check_action(self, agent: EnvAgent, action: RailEnvActions):
@@ -605,7 +608,7 @@ class RailEnv(Environment):
         return self.obs_dict
 
     def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]:
-        return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row,col))
+        return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col))
 
     def get_full_state_msg(self):
         grid_data = self.rail.grid.tolist()