diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index ad160e4e8770b9994052563c4584a7211e8da3bc..0887c0ca59d1609d27b37aa318022750924f4ea3 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -14,6 +14,7 @@ from flatland.core.env import Environment
 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
 from flatland.core.grid.grid4_utils import get_new_position
+from flatland.core.grid.grid_utils import IntVector2D
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, RailAgentStatus
 from flatland.envs.distance_map import DistanceMap
@@ -218,6 +219,9 @@ class RailEnv(Environment):
 
         self.valid_positions = None
 
+        # global numpy array of agents position, True means that there is an agent at that cell
+        self.agent_positions: np.ndarray = np.full((height, width), False)
+
     def _seed(self, seed=None):
         self.np_random, seed = seeding.np_random(seed)
         return [seed]
@@ -243,7 +247,7 @@ class RailEnv(Environment):
         agent = self.agents[handle]
         if agent.status == RailAgentStatus.READY_TO_DEPART and self.cell_free(agent.initial_position):
             agent.status = RailAgentStatus.ACTIVE
-            agent.position = agent.initial_position
+            self._set_agent_to_initial_position(agent, agent.initial_position)
 
     def restart_agents(self):
         """ Reset the agents to their starting positions defined in agents_static
@@ -340,6 +344,8 @@ class RailEnv(Environment):
             else:
                 self._max_episode_steps = self.compute_max_episode_steps(width=self.width, height=self.height)
 
+        self.agent_positions = np.full((self.height, self.width), False)
+
         self.restart_agents()
 
         if activate_agents:
@@ -523,7 +529,7 @@ class RailEnv(Environment):
             if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT,
                           RailEnvActions.MOVE_FORWARD] and self.cell_free(agent.initial_position):
                 agent.status = RailAgentStatus.ACTIVE
-                agent.position = agent.initial_position
+                self._set_agent_to_initial_position(agent, agent.initial_position)
                 self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
                 return
             else:
@@ -618,7 +624,7 @@ class RailEnv(Environment):
                 assert new_cell_valid
                 assert transition_valid
                 if cell_free:
-                    agent.position = new_position
+                    self._move_agent_to_new_position(agent, new_position)
                     agent.direction = new_direction
                     agent.speed_data['position_fraction'] = 0.0
 
@@ -627,16 +633,54 @@ class RailEnv(Environment):
                 agent.status = RailAgentStatus.DONE
                 self.dones[i_agent] = True
                 agent.moving = False
-
-                if self.remove_agents_at_target:
-                    agent.position = None
-                    agent.status = RailAgentStatus.DONE_REMOVED
+                self._remove_agent_from_scene(agent)
             else:
                 self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
         else:
             # step penalty if not moving (stopped now or before)
             self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
 
+    def _set_agent_to_initial_position(self, agent: EnvAgent, new_position: IntVector2D):
+        """
+        Sets the agent to its initial position. Updates the agent object and the position
+        of the agent inside the global agent_position numpy array
+
+        Parameters
+        -------
+        agent: EnvAgent object
+        new_position: IntVector2D
+        """
+        agent.position = new_position
+        self.agent_positions[agent.position] = True
+
+    def _move_agent_to_new_position(self, agent: EnvAgent, new_position: IntVector2D):
+        """
+        Move the agent to the a new position. Updates the agent object and the position
+        of the agent inside the global agent_position numpy array
+
+        Parameters
+        -------
+        agent: EnvAgent object
+        new_position: IntVector2D
+        """
+        agent.position = new_position
+        self.agent_positions[agent.old_position] = False
+        self.agent_positions[agent.position] = True
+
+    def _remove_agent_from_scene(self, agent: EnvAgent):
+        """
+        Remove the agent from the scene. Updates the agent object and the position
+        of the agent inside the global agent_position numpy array
+
+        Parameters
+        -------
+        agent: EnvAgent object
+        """
+        self.agent_positions[agent.position] = False
+        if self.remove_agents_at_target:
+            agent.position = None
+            agent.status = RailAgentStatus.DONE_REMOVED
+
     def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent):
         """
 
@@ -673,12 +717,18 @@ class RailEnv(Environment):
                 (*agent.position, agent.direction),
                 new_direction)
 
-        # Check the new position is not the same as any of the existing agent positions
-        # (including itself, for simplicity, since it is moving)
-        cell_free = self.cell_free(new_position)
+
+        # only call cell_free() if new cell is inside the scene
+        if new_cell_valid:
+            # Check the new position is not the same as any of the existing agent positions
+            # (including itself, for simplicity, since it is moving)
+            cell_free = self.cell_free(new_position)
+        else:
+            # if new cell is outside of scene -> cell_free is False
+            cell_free = False
         return cell_free, new_cell_valid, new_direction, new_position, transition_valid
 
-    def cell_free(self, position):
+    def cell_free(self, position: IntVector2D) -> bool:
         """
         Utility to check if a cell is free
 
@@ -692,9 +742,7 @@ class RailEnv(Environment):
             is the cell free or not?
 
         """
-        agent_positions = [agent.position for agent in self.agents if agent.position is not None]
-        ret = len(agent_positions) == 0 or not np.any(np.equal(position, agent_positions).all(1))
-        return ret
+        return not self.agent_positions[position]
 
     def check_action(self, agent: EnvAgent, action: RailEnvActions):
         """
@@ -790,7 +838,7 @@ class RailEnv(Environment):
     def set_full_state_msg(self, msg_data):
         """
         Sets environment state with msgdata object passed as argument
-        
+
         Parameters
         -------
         msg_data: msgpack object
@@ -809,7 +857,7 @@ class RailEnv(Environment):
     def set_full_state_dist_msg(self, msg_data):
         """
         Sets environment grid state and distance map with msgdata object passed as argument
-        
+
         Parameters
         -------
         msg_data: msgpack object