diff --git a/examples/custom_observation_example.py b/examples/custom_observation_example.py
index 03b0ff17f23d4a1046237082bfdcb946de2f87fb..36e5305d5a5a78c600cee0ea6845c341ac2b2e6d 100644
--- a/examples/custom_observation_example.py
+++ b/examples/custom_observation_example.py
@@ -4,6 +4,7 @@ import time
 import numpy as np
 
 from flatland.core.env_observation_builder import ObservationBuilder
+from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.grid.grid_utils import coordinate_to_position
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
@@ -79,7 +80,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
             min_distances = []
             for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                 if possible_transitions[direction]:
-                    new_position = self.new_position(agent.position, direction)
+                    new_position = get_new_position(agent.position, direction)
                     min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction])
                 else:
                     min_distances.append(np.inf)
diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py
index e3683d893f4feb9979d86f3ca3507100d86d1813..1f0f89dee63026c08e6ab2b0adb7e33d069ae652 100644
--- a/examples/debugging_example_DELETE.py
+++ b/examples/debugging_example_DELETE.py
@@ -3,6 +3,7 @@ import time
 
 import numpy as np
 
+from flatland.core.grid.grid4_utils import get_new_position
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
@@ -47,7 +48,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
             min_distances = []
             for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                 if possible_transitions[direction]:
-                    new_position = self.new_position(agent.position, direction)
+                    new_position = get_new_position(agent.position, direction)
                     min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction])
                 else:
                     min_distances.append(np.inf)
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index f479e07890fc9c53538cc4f95acac9559202354f..baa378f1eaa148fb6546e7b207a12ce350403dfe 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -7,6 +7,7 @@ import numpy as np
 
 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.grid.grid_utils import coordinate_to_position
 
 
@@ -58,19 +59,6 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         self.agents_previous_reset = agents
 
-    def new_position(self, position, movement):
-        """
-        Utility function that converts a compass movement over a 2D grid to new positions (r, c).
-        """
-        if movement == Grid4TransitionsEnum.NORTH:
-            return (position[0] - 1, position[1])
-        elif movement == Grid4TransitionsEnum.EAST:
-            return (position[0], position[1] + 1)
-        elif movement == Grid4TransitionsEnum.SOUTH:
-            return (position[0] + 1, position[1])
-        elif movement == Grid4TransitionsEnum.WEST:
-            return (position[0], position[1] - 1)
-
     def get_many(self, handles=None):
         """
         Called whenever an observation has to be computed for the `env' environment, for each agent with handle
@@ -194,7 +182,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
             if possible_transitions[branch_direction]:
-                new_cell = self.new_position(agent.position, branch_direction)
+                new_cell = get_new_position(agent.position, branch_direction)
                 branch_observation, branch_visited = \
                     self._explore_branch(handle, new_cell, branch_direction, 1, 1)
                 observation = observation + branch_observation
@@ -367,7 +355,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                     exploring = True
                     # convert one-hot encoding to 0,1,2,3
                     direction = np.argmax(cell_transitions)
-                    position = self.new_position(position, direction)
+                    position = get_new_position(position, direction)
                     num_steps += 1
                     tot_dist += 1
             elif num_transitions > 0:
@@ -440,7 +428,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                                                                  (branch_direction + 2) % 4):
                 # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
                 # it back
-                new_cell = self.new_position(position, (branch_direction + 2) % 4)
+                new_cell = get_new_position(position, (branch_direction + 2) % 4)
                 branch_observation, branch_visited = self._explore_branch(handle,
                                                                           new_cell,
                                                                           (branch_direction + 2) % 4,
@@ -450,7 +438,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                 if len(branch_visited) != 0:
                     visited = visited.union(branch_visited)
             elif last_is_switch and possible_transitions[branch_direction]:
-                new_cell = self.new_position(position, branch_direction)
+                new_cell = get_new_position(position, branch_direction)
                 branch_observation, branch_visited = self._explore_branch(handle,
                                                                           new_cell,
                                                                           branch_direction,
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 10770210b073bf86e1c813895948be3354bb291a..f6c0c818a2a2fce3a8a422220643ae67b0797bf9 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -704,7 +704,7 @@ class RailEnv(Environment):
             possible_directions = [(enforce_target_direction + 2) % 4]
 
         for neigh_direction in possible_directions:
-            new_cell = self.obs_builder.new_position(position, neigh_direction)
+            new_cell = get_new_position(position, neigh_direction)
 
             if new_cell[0] >= 0 and new_cell[0] < self.height and new_cell[1] >= 0 and new_cell[1] < self.width:
 
diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py
index eb056012c59f3e136b583911f86b3952d4435eae..5d35bfb1f45edbc04748c6a8bbab0ac3e0fbea2a 100644
--- a/tests/test_flatland_envs_observations.py
+++ b/tests/test_flatland_envs_observations.py
@@ -4,6 +4,7 @@
 import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.core.grid.grid4_utils import get_new_position
 from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
@@ -51,7 +52,7 @@ def _step_along_shortest_path(env, obs_builder, rail):
         shortest_distance = np.inf
 
         for exit_direction in range(4):
-            neighbour = obs_builder.new_position(agent.position, exit_direction)
+            neighbour = get_new_position(agent.position, exit_direction)
 
             if neighbour[0] >= 0 and neighbour[0] < env.height and neighbour[1] >= 0 and neighbour[1] < env.width:
                 desired_movement_from_new_cell = (exit_direction + 2) % 4
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index a0f97c3e7d3248b22cbc5228ddf49417b146bc67..d1f487ce5e68066ba9db4f791a84ec8bc13c8416 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -1,5 +1,6 @@
 import numpy as np
 
+from flatland.core.grid.grid4_utils import get_new_position
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
@@ -40,7 +41,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
             min_distances = []
             for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                 if possible_transitions[direction]:
-                    new_position = self.new_position(agent.position, direction)
+                    new_position = get_new_position(agent.position, direction)
                     min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction])
                 else:
                     min_distances.append(np.inf)