From 60b98c1dd7b2efe94b65db4d2500e5f713df5de2 Mon Sep 17 00:00:00 2001
From: u229589 <christian.baumberger@sbb.ch>
Date: Mon, 16 Sep 2019 14:25:21 +0200
Subject: [PATCH] Refactoring: delete new_position from TreeObsForRailEnv and
 use flatland.core.grid.grid4utils.get_new_position instead

---
 examples/custom_observation_example.py   |  3 ++-
 examples/debugging_example_DELETE.py     |  3 ++-
 flatland/envs/observations.py            | 22 +++++-----------------
 flatland/envs/rail_env.py                |  2 +-
 tests/test_flatland_envs_observations.py |  3 ++-
 tests/test_flatland_malfunction.py       |  3 ++-
 6 files changed, 14 insertions(+), 22 deletions(-)

diff --git a/examples/custom_observation_example.py b/examples/custom_observation_example.py
index 03b0ff17..36e5305d 100644
--- a/examples/custom_observation_example.py
+++ b/examples/custom_observation_example.py
@@ -4,6 +4,7 @@ import time
 import numpy as np
 
 from flatland.core.env_observation_builder import ObservationBuilder
+from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.grid.grid_utils import coordinate_to_position
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
@@ -79,7 +80,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
             min_distances = []
             for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                 if possible_transitions[direction]:
-                    new_position = self.new_position(agent.position, direction)
+                    new_position = get_new_position(agent.position, direction)
                     min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction])
                 else:
                     min_distances.append(np.inf)
diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py
index e3683d89..1f0f89de 100644
--- a/examples/debugging_example_DELETE.py
+++ b/examples/debugging_example_DELETE.py
@@ -3,6 +3,7 @@ import time
 
 import numpy as np
 
+from flatland.core.grid.grid4_utils import get_new_position
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
@@ -47,7 +48,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
             min_distances = []
             for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                 if possible_transitions[direction]:
-                    new_position = self.new_position(agent.position, direction)
+                    new_position = get_new_position(agent.position, direction)
                     min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction])
                 else:
                     min_distances.append(np.inf)
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index f479e078..baa378f1 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -7,6 +7,7 @@ import numpy as np
 
 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.grid.grid_utils import coordinate_to_position
 
 
@@ -58,19 +59,6 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         self.agents_previous_reset = agents
 
-    def new_position(self, position, movement):
-        """
-        Utility function that converts a compass movement over a 2D grid to new positions (r, c).
-        """
-        if movement == Grid4TransitionsEnum.NORTH:
-            return (position[0] - 1, position[1])
-        elif movement == Grid4TransitionsEnum.EAST:
-            return (position[0], position[1] + 1)
-        elif movement == Grid4TransitionsEnum.SOUTH:
-            return (position[0] + 1, position[1])
-        elif movement == Grid4TransitionsEnum.WEST:
-            return (position[0], position[1] - 1)
-
     def get_many(self, handles=None):
         """
         Called whenever an observation has to be computed for the `env' environment, for each agent with handle
@@ -194,7 +182,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
             if possible_transitions[branch_direction]:
-                new_cell = self.new_position(agent.position, branch_direction)
+                new_cell = get_new_position(agent.position, branch_direction)
                 branch_observation, branch_visited = \
                     self._explore_branch(handle, new_cell, branch_direction, 1, 1)
                 observation = observation + branch_observation
@@ -367,7 +355,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                     exploring = True
                     # convert one-hot encoding to 0,1,2,3
                     direction = np.argmax(cell_transitions)
-                    position = self.new_position(position, direction)
+                    position = get_new_position(position, direction)
                     num_steps += 1
                     tot_dist += 1
             elif num_transitions > 0:
@@ -440,7 +428,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                                                                  (branch_direction + 2) % 4):
                 # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
                 # it back
-                new_cell = self.new_position(position, (branch_direction + 2) % 4)
+                new_cell = get_new_position(position, (branch_direction + 2) % 4)
                 branch_observation, branch_visited = self._explore_branch(handle,
                                                                           new_cell,
                                                                           (branch_direction + 2) % 4,
@@ -450,7 +438,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                 if len(branch_visited) != 0:
                     visited = visited.union(branch_visited)
             elif last_is_switch and possible_transitions[branch_direction]:
-                new_cell = self.new_position(position, branch_direction)
+                new_cell = get_new_position(position, branch_direction)
                 branch_observation, branch_visited = self._explore_branch(handle,
                                                                           new_cell,
                                                                           branch_direction,
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 10770210..f6c0c818 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -704,7 +704,7 @@ class RailEnv(Environment):
             possible_directions = [(enforce_target_direction + 2) % 4]
 
         for neigh_direction in possible_directions:
-            new_cell = self.obs_builder.new_position(position, neigh_direction)
+            new_cell = get_new_position(position, neigh_direction)
 
             if new_cell[0] >= 0 and new_cell[0] < self.height and new_cell[1] >= 0 and new_cell[1] < self.width:
 
diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py
index eb056012..5d35bfb1 100644
--- a/tests/test_flatland_envs_observations.py
+++ b/tests/test_flatland_envs_observations.py
@@ -4,6 +4,7 @@
 import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.core.grid.grid4_utils import get_new_position
 from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
@@ -51,7 +52,7 @@ def _step_along_shortest_path(env, obs_builder, rail):
         shortest_distance = np.inf
 
         for exit_direction in range(4):
-            neighbour = obs_builder.new_position(agent.position, exit_direction)
+            neighbour = get_new_position(agent.position, exit_direction)
 
             if neighbour[0] >= 0 and neighbour[0] < env.height and neighbour[1] >= 0 and neighbour[1] < env.width:
                 desired_movement_from_new_cell = (exit_direction + 2) % 4
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index a0f97c3e..d1f487ce 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -1,5 +1,6 @@
 import numpy as np
 
+from flatland.core.grid.grid4_utils import get_new_position
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
@@ -40,7 +41,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
             min_distances = []
             for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                 if possible_transitions[direction]:
-                    new_position = self.new_position(agent.position, direction)
+                    new_position = get_new_position(agent.position, direction)
                     min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction])
                 else:
                     min_distances.append(np.inf)
-- 
GitLab