From 60b98c1dd7b2efe94b65db4d2500e5f713df5de2 Mon Sep 17 00:00:00 2001 From: u229589 <christian.baumberger@sbb.ch> Date: Mon, 16 Sep 2019 14:25:21 +0200 Subject: [PATCH] Refactoring: delete new_position from TreeObsForRailEnv and use flatland.core.grid.grid4utils.get_new_position instead --- examples/custom_observation_example.py | 3 ++- examples/debugging_example_DELETE.py | 3 ++- flatland/envs/observations.py | 22 +++++----------------- flatland/envs/rail_env.py | 2 +- tests/test_flatland_envs_observations.py | 3 ++- tests/test_flatland_malfunction.py | 3 ++- 6 files changed, 14 insertions(+), 22 deletions(-) diff --git a/examples/custom_observation_example.py b/examples/custom_observation_example.py index 03b0ff17..36e5305d 100644 --- a/examples/custom_observation_example.py +++ b/examples/custom_observation_example.py @@ -4,6 +4,7 @@ import time import numpy as np from flatland.core.env_observation_builder import ObservationBuilder +from flatland.core.grid.grid4_utils import get_new_position from flatland.core.grid.grid_utils import coordinate_to_position from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv @@ -79,7 +80,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): min_distances = [] for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]: if possible_transitions[direction]: - new_position = self.new_position(agent.position, direction) + new_position = get_new_position(agent.position, direction) min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction]) else: min_distances.append(np.inf) diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py index e3683d89..1f0f89de 100644 --- a/examples/debugging_example_DELETE.py +++ b/examples/debugging_example_DELETE.py @@ -3,6 +3,7 @@ import time import numpy as np +from flatland.core.grid.grid4_utils import get_new_position from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator @@ -47,7 +48,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): min_distances = [] for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]: if possible_transitions[direction]: - new_position = self.new_position(agent.position, direction) + new_position = get_new_position(agent.position, direction) min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction]) else: min_distances.append(np.inf) diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index f479e078..baa378f1 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -7,6 +7,7 @@ import numpy as np from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.core.grid.grid4_utils import get_new_position from flatland.core.grid.grid_utils import coordinate_to_position @@ -58,19 +59,6 @@ class TreeObsForRailEnv(ObservationBuilder): self.agents_previous_reset = agents - def new_position(self, position, movement): - """ - Utility function that converts a compass movement over a 2D grid to new positions (r, c). - """ - if movement == Grid4TransitionsEnum.NORTH: - return (position[0] - 1, position[1]) - elif movement == Grid4TransitionsEnum.EAST: - return (position[0], position[1] + 1) - elif movement == Grid4TransitionsEnum.SOUTH: - return (position[0] + 1, position[1]) - elif movement == Grid4TransitionsEnum.WEST: - return (position[0], position[1] - 1) - def get_many(self, handles=None): """ Called whenever an observation has to be computed for the `env' environment, for each agent with handle @@ -194,7 +182,7 @@ class TreeObsForRailEnv(ObservationBuilder): for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]: if possible_transitions[branch_direction]: - new_cell = self.new_position(agent.position, branch_direction) + new_cell = get_new_position(agent.position, branch_direction) branch_observation, branch_visited = \ self._explore_branch(handle, new_cell, branch_direction, 1, 1) observation = observation + branch_observation @@ -367,7 +355,7 @@ class TreeObsForRailEnv(ObservationBuilder): exploring = True # convert one-hot encoding to 0,1,2,3 direction = np.argmax(cell_transitions) - position = self.new_position(position, direction) + position = get_new_position(position, direction) num_steps += 1 tot_dist += 1 elif num_transitions > 0: @@ -440,7 +428,7 @@ class TreeObsForRailEnv(ObservationBuilder): (branch_direction + 2) % 4): # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes # it back - new_cell = self.new_position(position, (branch_direction + 2) % 4) + new_cell = get_new_position(position, (branch_direction + 2) % 4) branch_observation, branch_visited = self._explore_branch(handle, new_cell, (branch_direction + 2) % 4, @@ -450,7 +438,7 @@ class TreeObsForRailEnv(ObservationBuilder): if len(branch_visited) != 0: visited = visited.union(branch_visited) elif last_is_switch and possible_transitions[branch_direction]: - new_cell = self.new_position(position, branch_direction) + new_cell = get_new_position(position, branch_direction) branch_observation, branch_visited = self._explore_branch(handle, new_cell, branch_direction, diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 10770210..f6c0c818 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -704,7 +704,7 @@ class RailEnv(Environment): possible_directions = [(enforce_target_direction + 2) % 4] for neigh_direction in possible_directions: - new_cell = self.obs_builder.new_position(position, neigh_direction) + new_cell = get_new_position(position, neigh_direction) if new_cell[0] >= 0 and new_cell[0] < self.height and new_cell[1] >= 0 and new_cell[1] < self.width: diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py index eb056012..5d35bfb1 100644 --- a/tests/test_flatland_envs_observations.py +++ b/tests/test_flatland_envs_observations.py @@ -4,6 +4,7 @@ import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.core.grid.grid4_utils import get_new_position from flatland.envs.agent_utils import EnvAgent from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv @@ -51,7 +52,7 @@ def _step_along_shortest_path(env, obs_builder, rail): shortest_distance = np.inf for exit_direction in range(4): - neighbour = obs_builder.new_position(agent.position, exit_direction) + neighbour = get_new_position(agent.position, exit_direction) if neighbour[0] >= 0 and neighbour[0] < env.height and neighbour[1] >= 0 and neighbour[1] < env.width: desired_movement_from_new_cell = (exit_direction + 2) % 4 diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index a0f97c3e..d1f487ce 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -1,5 +1,6 @@ import numpy as np +from flatland.core.grid.grid4_utils import get_new_position from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator @@ -40,7 +41,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): min_distances = [] for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]: if possible_transitions[direction]: - new_position = self.new_position(agent.position, direction) + new_position = get_new_position(agent.position, direction) min_distances.append(self.env.distance_map[handle, new_position[0], new_position[1], direction]) else: min_distances.append(np.inf) -- GitLab