diff --git a/examples/custom_observation_example.py b/examples/custom_observation_example.py index 03b0ff17f23d4a1046237082bfdcb946de2f87fb..36e5305d5a5a78c600cee0ea6845c341ac2b2e6d 100644 --- a/examples/custom_observation_example.py +++ b/examples/custom_observation_example.py @@ -4,6 +4,7 @@ import time import numpy as np from flatland.core.env_observation_builder import ObservationBuilder +from flatland.core.grid.grid4_utils import get_new_position from flatland.core.grid.grid_utils import coordinate_to_position from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv @@ -79,7 +80,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): min_distances = [] for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]: if possible_transitions[direction]: - new_position = self.new_position(agent.position, direction) + new_position = get_new_position(agent.position, direction) min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction]) else: min_distances.append(np.inf) diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py index e3683d893f4feb9979d86f3ca3507100d86d1813..1f0f89dee63026c08e6ab2b0adb7e33d069ae652 100644 --- a/examples/debugging_example_DELETE.py +++ b/examples/debugging_example_DELETE.py @@ -3,6 +3,7 @@ import time import numpy as np +from flatland.core.grid.grid4_utils import get_new_position from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator @@ -47,7 +48,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): min_distances = [] for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]: if possible_transitions[direction]: - new_position = self.new_position(agent.position, direction) + new_position = get_new_position(agent.position, direction) min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction]) else: min_distances.append(np.inf) diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index f479e07890fc9c53538cc4f95acac9559202354f..baa378f1eaa148fb6546e7b207a12ce350403dfe 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -7,6 +7,7 @@ import numpy as np from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.core.grid.grid4_utils import get_new_position from flatland.core.grid.grid_utils import coordinate_to_position @@ -58,19 +59,6 @@ class TreeObsForRailEnv(ObservationBuilder): self.agents_previous_reset = agents - def new_position(self, position, movement): - """ - Utility function that converts a compass movement over a 2D grid to new positions (r, c). - """ - if movement == Grid4TransitionsEnum.NORTH: - return (position[0] - 1, position[1]) - elif movement == Grid4TransitionsEnum.EAST: - return (position[0], position[1] + 1) - elif movement == Grid4TransitionsEnum.SOUTH: - return (position[0] + 1, position[1]) - elif movement == Grid4TransitionsEnum.WEST: - return (position[0], position[1] - 1) - def get_many(self, handles=None): """ Called whenever an observation has to be computed for the `env' environment, for each agent with handle @@ -194,7 +182,7 @@ class TreeObsForRailEnv(ObservationBuilder): for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]: if possible_transitions[branch_direction]: - new_cell = self.new_position(agent.position, branch_direction) + new_cell = get_new_position(agent.position, branch_direction) branch_observation, branch_visited = \ self._explore_branch(handle, new_cell, branch_direction, 1, 1) observation = observation + branch_observation @@ -367,7 +355,7 @@ class TreeObsForRailEnv(ObservationBuilder): exploring = True # convert one-hot encoding to 0,1,2,3 direction = np.argmax(cell_transitions) - position = self.new_position(position, direction) + position = get_new_position(position, direction) num_steps += 1 tot_dist += 1 elif num_transitions > 0: @@ -440,7 +428,7 @@ class TreeObsForRailEnv(ObservationBuilder): (branch_direction + 2) % 4): # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes # it back - new_cell = self.new_position(position, (branch_direction + 2) % 4) + new_cell = get_new_position(position, (branch_direction + 2) % 4) branch_observation, branch_visited = self._explore_branch(handle, new_cell, (branch_direction + 2) % 4, @@ -450,7 +438,7 @@ class TreeObsForRailEnv(ObservationBuilder): if len(branch_visited) != 0: visited = visited.union(branch_visited) elif last_is_switch and possible_transitions[branch_direction]: - new_cell = self.new_position(position, branch_direction) + new_cell = get_new_position(position, branch_direction) branch_observation, branch_visited = self._explore_branch(handle, new_cell, branch_direction, diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 10770210b073bf86e1c813895948be3354bb291a..f6c0c818a2a2fce3a8a422220643ae67b0797bf9 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -704,7 +704,7 @@ class RailEnv(Environment): possible_directions = [(enforce_target_direction + 2) % 4] for neigh_direction in possible_directions: - new_cell = self.obs_builder.new_position(position, neigh_direction) + new_cell = get_new_position(position, neigh_direction) if new_cell[0] >= 0 and new_cell[0] < self.height and new_cell[1] >= 0 and new_cell[1] < self.width: diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py index eb056012c59f3e136b583911f86b3952d4435eae..5d35bfb1f45edbc04748c6a8bbab0ac3e0fbea2a 100644 --- a/tests/test_flatland_envs_observations.py +++ b/tests/test_flatland_envs_observations.py @@ -4,6 +4,7 @@ import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.core.grid.grid4_utils import get_new_position from flatland.envs.agent_utils import EnvAgent from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv @@ -51,7 +52,7 @@ def _step_along_shortest_path(env, obs_builder, rail): shortest_distance = np.inf for exit_direction in range(4): - neighbour = obs_builder.new_position(agent.position, exit_direction) + neighbour = get_new_position(agent.position, exit_direction) if neighbour[0] >= 0 and neighbour[0] < env.height and neighbour[1] >= 0 and neighbour[1] < env.width: desired_movement_from_new_cell = (exit_direction + 2) % 4 diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index a0f97c3e7d3248b22cbc5228ddf49417b146bc67..d1f487ce5e68066ba9db4f791a84ec8bc13c8416 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -1,5 +1,6 @@ import numpy as np +from flatland.core.grid.grid4_utils import get_new_position from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator @@ -40,7 +41,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): min_distances = [] for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]: if possible_transitions[direction]: - new_position = self.new_position(agent.position, direction) + new_position = get_new_position(agent.position, direction) min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction]) else: min_distances.append(np.inf)