From 4c8f4d40da4f70a8308799794b05a3d6fd3d55a3 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Tue, 24 Sep 2019 09:28:56 +0200 Subject: [PATCH] update baselines to master of flatland --- MANIFEST.in | 2 +- .../observation_builders/observations.py | 43 +++++++------------ torch_training/predictors/predictions.py | 5 ++- 3 files changed, 19 insertions(+), 31 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index 85154f6..c328629 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -12,4 +12,4 @@ recursive-include tests * recursive-exclude * __pycache__ recursive-exclude * *.py[co] -recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif +recursive-include docs *.rst *.md conf.py Makefile make.bat *.jpg *.png *.gif diff --git a/torch_training/observation_builders/observations.py b/torch_training/observation_builders/observations.py index e3d52d3..b55a5bf 100644 --- a/torch_training/observation_builders/observations.py +++ b/torch_training/observation_builders/observations.py @@ -7,7 +7,7 @@ from collections import deque import numpy as np from flatland.core.env_observation_builder import ObservationBuilder -from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.core.grid.grid4_utils import get_new_position from flatland.core.grid.grid_utils import coordinate_to_position @@ -86,10 +86,10 @@ class TreeObsForRailEnv(ObservationBuilder): # Fill in the (up to) 4 neighboring nodes # direction is the direction of movement, meaning that at least a possible orientation of an agent - # in cell (row,col) allows a movement in direction `direction' + # in cell (row,col) allows a movement in direction `direction` nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1)) - # BFS from target `position' to all the reachable nodes in the grid + # BFS from target `position` to all the reachable nodes in the grid # Stop the search if the target position is re-visited, in any direction visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2), (position[0], position[1], 3)} @@ -125,12 +125,12 @@ class TreeObsForRailEnv(ObservationBuilder): possible_directions = [0, 1, 2, 3] if enforce_target_direction >= 0: - # The agent must land into the current cell with orientation `enforce_target_direction'. + # The agent must land into the current cell with orientation `enforce_target_direction`. # This is only possible if the agent has arrived from the cell in the opposite direction! possible_directions = [(enforce_target_direction + 2) % 4] for neigh_direction in possible_directions: - new_cell = self._new_position(position, neigh_direction) + new_cell = get_new_position(position, neigh_direction) if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width: @@ -138,7 +138,7 @@ class TreeObsForRailEnv(ObservationBuilder): # Check all possible transitions in new_cell for agent_orientation in range(4): - # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible? + # Is a transition along movement `desired_movement_from_new_cell` to the current cell possible? is_valid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation), desired_movement_from_new_cell) @@ -156,23 +156,10 @@ class TreeObsForRailEnv(ObservationBuilder): return neighbors - def _new_position(self, position, movement): - """ - Utility function that converts a compass movement over a 2D grid to new positions (r, c). - """ - if movement == Grid4TransitionsEnum.NORTH: - return (position[0] - 1, position[1]) - elif movement == Grid4TransitionsEnum.EAST: - return (position[0], position[1] + 1) - elif movement == Grid4TransitionsEnum.SOUTH: - return (position[0] + 1, position[1]) - elif movement == Grid4TransitionsEnum.WEST: - return (position[0], position[1] - 1) - def get_many(self, handles=None): """ - Called whenever an observation has to be computed for the `env' environment, for each agent with handle - in the `handles' list. + Called whenever an observation has to be computed for the `env` environment, for each agent with handle + in the `handles` list. """ if handles is None: @@ -200,7 +187,7 @@ class TreeObsForRailEnv(ObservationBuilder): def get(self, handle): """ - Computes the current observation for agent `handle' in env + Computes the current observation for agent `handle` in env The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv). @@ -280,7 +267,7 @@ class TreeObsForRailEnv(ObservationBuilder): for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]: if possible_transitions[branch_direction]: - new_cell = self._new_position(agent.position, branch_direction) + new_cell = get_new_position(agent.position, branch_direction) branch_observation, branch_visited = \ self._explore_branch(handle, new_cell, branch_direction, 1, 1) observation = observation + branch_observation @@ -428,11 +415,11 @@ class TreeObsForRailEnv(ObservationBuilder): last_is_dead_end = True if not last_is_dead_end: - # Keep walking through the tree along `direction' + # Keep walking through the tree along `direction` exploring = True # convert one-hot encoding to 0,1,2,3 direction = np.argmax(cell_transitions) - position = self._new_position(position, direction) + position = get_new_position(position, direction) num_steps += 1 tot_dist += 1 elif num_transitions > 0: @@ -447,7 +434,7 @@ class TreeObsForRailEnv(ObservationBuilder): last_is_terminal = True break - # `position' is either a terminal node or a switch + # `position` is either a terminal node or a switch # ############################# # ############################# @@ -499,7 +486,7 @@ class TreeObsForRailEnv(ObservationBuilder): (branch_direction + 2) % 4): # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes # it back - new_cell = self._new_position(position, (branch_direction + 2) % 4) + new_cell = get_new_position(position, (branch_direction + 2) % 4) branch_observation, branch_visited = self._explore_branch(handle, new_cell, (branch_direction + 2) % 4, @@ -509,7 +496,7 @@ class TreeObsForRailEnv(ObservationBuilder): if len(branch_visited) != 0: visited = visited.union(branch_visited) elif last_is_switch and possible_transitions[branch_direction]: - new_cell = self._new_position(position, branch_direction) + new_cell = get_new_position(position, branch_direction) branch_observation, branch_visited = self._explore_branch(handle, new_cell, branch_direction, diff --git a/torch_training/predictors/predictions.py b/torch_training/predictors/predictions.py index 10abcef..4b816ab 100644 --- a/torch_training/predictors/predictions.py +++ b/torch_training/predictors/predictions.py @@ -8,6 +8,7 @@ from flatland.core.env_prediction_builder import PredictionBuilder from flatland.core.grid.grid4_utils import get_new_position from flatland.envs.rail_env import RailEnvActions + class ShortestPathPredictorForRailEnv(PredictionBuilder): """ ShortestPathPredictorForRailEnv object. @@ -25,10 +26,10 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): Requires distance_map to extract the shortest path. Parameters - ------- + ---------- custom_args: dict - distance_map : dict - handle : int (optional) + handle : int, optional Handle of the agent for which to compute the observation vector. Returns -- GitLab