From 4c8f4d40da4f70a8308799794b05a3d6fd3d55a3 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Tue, 24 Sep 2019 09:28:56 +0200
Subject: [PATCH] update baselines to master of flatland

---
 MANIFEST.in                                   |  2 +-
 .../observation_builders/observations.py      | 43 +++++++------------
 torch_training/predictors/predictions.py      |  5 ++-
 3 files changed, 19 insertions(+), 31 deletions(-)

diff --git a/MANIFEST.in b/MANIFEST.in
index 85154f6..c328629 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -12,4 +12,4 @@ recursive-include tests *
 recursive-exclude * __pycache__
 recursive-exclude * *.py[co]
 
-recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif
+recursive-include docs *.rst *.md conf.py Makefile make.bat *.jpg *.png *.gif
diff --git a/torch_training/observation_builders/observations.py b/torch_training/observation_builders/observations.py
index e3d52d3..b55a5bf 100644
--- a/torch_training/observation_builders/observations.py
+++ b/torch_training/observation_builders/observations.py
@@ -7,7 +7,7 @@ from collections import deque
 import numpy as np
 
 from flatland.core.env_observation_builder import ObservationBuilder
-from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.grid.grid_utils import coordinate_to_position
 
 
@@ -86,10 +86,10 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         # Fill in the (up to) 4 neighboring nodes
         # direction is the direction of movement, meaning that at least a possible orientation of an agent
-        # in cell (row,col) allows a movement in direction `direction'
+        # in cell (row,col) allows a movement in direction `direction`
         nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1))
 
-        # BFS from target `position' to all the reachable nodes in the grid
+        # BFS from target `position` to all the reachable nodes in the grid
         # Stop the search if the target position is re-visited, in any direction
         visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2),
                    (position[0], position[1], 3)}
@@ -125,12 +125,12 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         possible_directions = [0, 1, 2, 3]
         if enforce_target_direction >= 0:
-            # The agent must land into the current cell with orientation `enforce_target_direction'.
+            # The agent must land in the current cell with orientation `enforce_target_direction`.
             # This is only possible if the agent has arrived from the cell in the opposite direction!
             possible_directions = [(enforce_target_direction + 2) % 4]
 
         for neigh_direction in possible_directions:
-            new_cell = self._new_position(position, neigh_direction)
+            new_cell = get_new_position(position, neigh_direction)
 
             if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width:
 
@@ -138,7 +138,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
                 # Check all possible transitions in new_cell
                 for agent_orientation in range(4):
-                    # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
+                    # Is a transition along movement `desired_movement_from_new_cell` to the current cell possible?
                     is_valid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
                                                             desired_movement_from_new_cell)
 
@@ -156,23 +156,10 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         return neighbors
 
-    def _new_position(self, position, movement):
-        """
-        Utility function that converts a compass movement over a 2D grid to new positions (r, c).
-        """
-        if movement == Grid4TransitionsEnum.NORTH:
-            return (position[0] - 1, position[1])
-        elif movement == Grid4TransitionsEnum.EAST:
-            return (position[0], position[1] + 1)
-        elif movement == Grid4TransitionsEnum.SOUTH:
-            return (position[0] + 1, position[1])
-        elif movement == Grid4TransitionsEnum.WEST:
-            return (position[0], position[1] - 1)
-
     def get_many(self, handles=None):
         """
-        Called whenever an observation has to be computed for the `env' environment, for each agent with handle
-        in the `handles' list.
+        Called whenever an observation has to be computed for the `env` environment, for each agent whose handle
+        is in the `handles` list.
         """
 
         if handles is None:
@@ -200,7 +187,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
     def get(self, handle):
         """
-        Computes the current observation for agent `handle' in env
+        Computes the current observation for agent `handle` in env
 
         The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible
         movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv).
@@ -280,7 +267,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
             if possible_transitions[branch_direction]:
-                new_cell = self._new_position(agent.position, branch_direction)
+                new_cell = get_new_position(agent.position, branch_direction)
                 branch_observation, branch_visited = \
                     self._explore_branch(handle, new_cell, branch_direction, 1, 1)
                 observation = observation + branch_observation
@@ -428,11 +415,11 @@ class TreeObsForRailEnv(ObservationBuilder):
                     last_is_dead_end = True
 
                 if not last_is_dead_end:
-                    # Keep walking through the tree along `direction'
+                    # Keep walking through the tree along `direction`
                     exploring = True
                     # convert one-hot encoding to 0,1,2,3
                     direction = np.argmax(cell_transitions)
-                    position = self._new_position(position, direction)
+                    position = get_new_position(position, direction)
                     num_steps += 1
                     tot_dist += 1
             elif num_transitions > 0:
@@ -447,7 +434,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                 last_is_terminal = True
                 break
 
-        # `position' is either a terminal node or a switch
+        # `position` is either a terminal node or a switch
 
         # #############################
         # #############################
@@ -499,7 +486,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                                                                  (branch_direction + 2) % 4):
                 # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
                 # it back
-                new_cell = self._new_position(position, (branch_direction + 2) % 4)
+                new_cell = get_new_position(position, (branch_direction + 2) % 4)
                 branch_observation, branch_visited = self._explore_branch(handle,
                                                                           new_cell,
                                                                           (branch_direction + 2) % 4,
@@ -509,7 +496,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                 if len(branch_visited) != 0:
                     visited = visited.union(branch_visited)
             elif last_is_switch and possible_transitions[branch_direction]:
-                new_cell = self._new_position(position, branch_direction)
+                new_cell = get_new_position(position, branch_direction)
                 branch_observation, branch_visited = self._explore_branch(handle,
                                                                           new_cell,
                                                                           branch_direction,
diff --git a/torch_training/predictors/predictions.py b/torch_training/predictors/predictions.py
index 10abcef..4b816ab 100644
--- a/torch_training/predictors/predictions.py
+++ b/torch_training/predictors/predictions.py
@@ -8,6 +8,7 @@ from flatland.core.env_prediction_builder import PredictionBuilder
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.envs.rail_env import RailEnvActions
 
+
 class ShortestPathPredictorForRailEnv(PredictionBuilder):
     """
     ShortestPathPredictorForRailEnv object.
@@ -25,10 +26,10 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
         Requires distance_map to extract the shortest path.
 
         Parameters
-        -------
+        ----------
         custom_args: dict
             - distance_map : dict
-        handle : int (optional)
+        handle : int, optional
             Handle of the agent for which to compute the observation vector.
 
         Returns
-- 
GitLab