diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index cb4da95a81758625f1f45ecb535e47d834fd5005..c141bb89c9a064026ea118b5bed1e85e65dcaae9 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -14,6 +14,7 @@ from flatland.core.env import Environment
 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
 from flatland.core.grid.grid4_utils import get_new_position
+from flatland.core.grid.grid_utils import IntVector2D
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, RailAgentStatus
 from flatland.envs.distance_map import DistanceMap
@@ -216,6 +217,7 @@ class RailEnv(Environment):
         # Reset environment
 
         self.valid_positions = None
+        self.agent_positions: np.ndarray = np.full((height, width), False)
 
     def _seed(self, seed=None):
         self.np_random, seed = seeding.np_random(seed)
@@ -242,7 +244,7 @@ class RailEnv(Environment):
         agent = self.agents[handle]
         if agent.status == RailAgentStatus.READY_TO_DEPART and self.cell_free(agent.initial_position):
             agent.status = RailAgentStatus.ACTIVE
-            agent.position = agent.initial_position
+            self._set_agent_to_initial_position(agent, agent.initial_position)
 
     def restart_agents(self):
         """ Reset the agents to their starting positions defined in agents_static
@@ -339,6 +341,8 @@ class RailEnv(Environment):
             else:
                 self._max_episode_steps = self.compute_max_episode_steps(width=self.width, height=self.height)
 
+        self.agent_positions = np.full((self.height, self.width), False)
+
         self.restart_agents()
 
         if activate_agents:
@@ -515,7 +519,7 @@ class RailEnv(Environment):
             if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT,
                           RailEnvActions.MOVE_FORWARD] and self.cell_free(agent.initial_position):
                 agent.status = RailAgentStatus.ACTIVE
-                agent.position = agent.initial_position
+                self._set_agent_to_initial_position(agent, agent.initial_position)
                 self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
                 return
             else:
@@ -610,7 +614,7 @@ class RailEnv(Environment):
                 assert new_cell_valid
                 assert transition_valid
                 if cell_free:
-                    agent.position = new_position
+                    self._move_agent_to_new_position(agent, new_position)
                     agent.direction = new_direction
                     agent.speed_data['position_fraction'] = 0.0
 
@@ -619,16 +623,28 @@ class RailEnv(Environment):
                 agent.status = RailAgentStatus.DONE
                 self.dones[i_agent] = True
                 agent.moving = False
-
-                if self.remove_agents_at_target:
-                    agent.position = None
-                    agent.status = RailAgentStatus.DONE_REMOVED
+                self._remove_agent_from_scene(agent)
             else:
                 self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
         else:
             # step penalty if not moving (stopped now or before)
             self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
 
+    def _set_agent_to_initial_position(self, agent: EnvAgent, new_position: IntVector2D):
+        agent.position = new_position
+        self.agent_positions[agent.position] = True
+
+    def _move_agent_to_new_position(self, agent: EnvAgent, new_position: IntVector2D):
+        agent.position = new_position
+        self.agent_positions[agent.old_position] = False
+        self.agent_positions[agent.position] = True
+
+    def _remove_agent_from_scene(self, agent: EnvAgent):
+        self.agent_positions[agent.position] = False
+        if self.remove_agents_at_target:
+            agent.position = None
+            agent.status = RailAgentStatus.DONE_REMOVED
+
     def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent):
         """
 
@@ -670,11 +686,8 @@ class RailEnv(Environment):
         cell_free = self.cell_free(new_position)
         return cell_free, new_cell_valid, new_direction, new_position, transition_valid
 
-    def cell_free(self, position):
-
-        agent_positions = [agent.position for agent in self.agents if agent.position is not None]
-        ret = len(agent_positions) == 0 or not np.any(np.equal(position, agent_positions).all(1))
-        return ret
+    def cell_free(self, position: IntVector2D) -> bool:
+        return not self.agent_positions[position]
 
     def check_action(self, agent: EnvAgent, action: RailEnvActions):
         """