From 9babe92826827a697ef488776cc562b8d89a3a8d Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Tue, 17 Sep 2019 13:35:53 +0200
Subject: [PATCH] refactoring

---
 .../Simple_Realistic_Railway_Generator.py     |  2 -
 flatland/core/grid/grid4_astar.py             |  9 +++--
 flatland/core/grid/grid_utils.py              | 13 +++++++
 flatland/core/transition_map.py               | 38 ++++++++++---------
 tests/test_flatland_core_transitions.py       | 16 ++++----
 5 files changed, 46 insertions(+), 32 deletions(-)

diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py
index f935d3e3..09f3b538 100644
--- a/examples/Simple_Realistic_Railway_Generator.py
+++ b/examples/Simple_Realistic_Railway_Generator.py
@@ -1,6 +1,5 @@
 import copy
 import os
-import time
 import warnings
 
 import numpy as np
@@ -278,7 +277,6 @@ def realistic_rail_generator(num_cities=5,
                             grid_map.grid[start_node] = tmp_trans_sn
                             grid_map.grid[end_node] = tmp_trans_en
 
-
         connect_sub_graphs(rail_trans, grid_map, org_s_nodes, org_e_nodes, city_edges, nodes_added)
 
     def connect_random_stations(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
diff --git a/flatland/core/grid/grid4_astar.py b/flatland/core/grid/grid4_astar.py
index d0b2cc97..779ee984 100644
--- a/flatland/core/grid/grid4_astar.py
+++ b/flatland/core/grid/grid4_astar.py
@@ -1,4 +1,3 @@
-
 from flatland.core.grid.grid_utils import IntVector2D
 from flatland.core.grid.grid_utils import IntVector2DArrayType
 from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
@@ -75,13 +74,15 @@ def a_star(rail_trans: RailEnvTransitions,
         else:
             prev_pos = None
         for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]:
-            node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1])
+            # update the "current" pos
+            node_pos = Vec2d.add(current_node.pos, new_pos)
+
+            # is node_pos inside the grid?
             if node_pos[0] >= rail_shape[0] or node_pos[0] < 0 or node_pos[1] >= rail_shape[1] or node_pos[1] < 0:
                 continue
 
             # validate positions
-            if not grid_map.validate_new_transition(rail_trans, prev_pos, current_node.pos, node_pos,
-                                           end_node.pos):
+            if not grid_map.validate_new_transition(prev_pos, current_node.pos, node_pos, end_node.pos):
                 continue
 
             # create new node
diff --git a/flatland/core/grid/grid_utils.py b/flatland/core/grid/grid_utils.py
index a9d6ffaa..f9657bf3 100644
--- a/flatland/core/grid/grid_utils.py
+++ b/flatland/core/grid/grid_utils.py
@@ -10,6 +10,19 @@ IntVector2DArrayType = []
 
 class Vec2dOperations:
 
+    @staticmethod
+    def is_equal(node_a: Vector2D, node_b: Vector2D) -> bool:
+        """
+        vector operation : node_a + node_b
+
+        :param node_a: tuple with coordinate (x,y) or 2d vector
+        :param node_b: tuple with coordinate (x,y) or 2d vector
+        :return:
+            -------
+        check if node_a and nobe_b are equal
+        """
+        return node_a[0] == node_b[0] and node_a[1] == node_b[1]
+
     @staticmethod
     def subtract(node_a: Vector2D, node_b: Vector2D) -> Vector2D:
         """
diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py
index 2b06714b..d8da9c56 100644
--- a/flatland/core/transition_map.py
+++ b/flatland/core/transition_map.py
@@ -9,6 +9,7 @@ from numpy import array
 from flatland.core.grid.grid4 import Grid4Transitions
 from flatland.core.grid.grid4_utils import get_new_position, get_direction
 from flatland.core.grid.grid_utils import IntVector2DArrayType
+from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transitions import Transitions
 
@@ -301,7 +302,7 @@ class GridTransitionMap(TransitionMap):
         self.height = new_height
         self.grid = new_grid
 
-    def is_dead_end(self, rcPos):
+    def is_dead_end(self, rcPos: IntVector2DArrayType):
         """
         Check if the cell is a dead-end.
 
@@ -321,7 +322,7 @@ class GridTransitionMap(TransitionMap):
             tmp = tmp >> 1
         return nbits == 1
 
-    def is_simple_turn(self, rcPos):
+    def is_simple_turn(self, rcPos: IntVector2DArrayType):
         """
         Check if the cell is a left/right simple turn
 
@@ -348,7 +349,7 @@ class GridTransitionMap(TransitionMap):
 
         return is_simple_turn(tmp)
 
-    def check_path_exists(self, start, direction, end):
+    def check_path_exists(self, start: IntVector2DArrayType, direction: int, end: IntVector2DArrayType):
         # print("_path_exists({},{},{}".format(start, direction, end))
         # BFS - Check if a path exists between the 2 nodes
 
@@ -358,7 +359,8 @@ class GridTransitionMap(TransitionMap):
             node = stack.pop()
             node_position = node[0]
             node_direction = node[1]
-            if node_position[0] == end[0] and node_position[1] == end[1]:
+
+            if Vec2d.is_equal(node_position, end):
                 return True
             if node not in visited:
                 visited.add(node)
@@ -371,7 +373,7 @@ class GridTransitionMap(TransitionMap):
 
         return False
 
-    def cell_neighbours_valid(self, rcPos, check_this_cell=False):
+    def cell_neighbours_valid(self, rcPos: IntVector2DArrayType, check_this_cell=False):
         """
         Check validity of cell at rcPos = tuple(row, column)
         Checks that:
@@ -423,7 +425,7 @@ class GridTransitionMap(TransitionMap):
 
         return True
 
-    def fix_neighbours(self, rcPos, check_this_cell=False):
+    def fix_neighbours(self, rcPos: IntVector2DArrayType, check_this_cell=False):
         """
         Check validity of cell at rcPos = tuple(row, column)
         Checks that:
@@ -476,7 +478,7 @@ class GridTransitionMap(TransitionMap):
 
         return True
 
-    def fix_transitions(self, rcPos):
+    def fix_transitions(self, rcPos: IntVector2DArrayType):
         """
         Fixes broken transitions
         """
@@ -541,9 +543,9 @@ class GridTransitionMap(TransitionMap):
             self.set_transitions((rcPos[0], rcPos[1]), transition)
         return True
 
-    def validate_new_transition(self, rail_trans: RailEnvTransitions,
-                                prev_pos: IntVector2DArrayType, current_pos: IntVector2DArrayType,
+    def validate_new_transition(self, prev_pos: IntVector2DArrayType, current_pos: IntVector2DArrayType,
                                 new_pos: IntVector2DArrayType, end_pos: IntVector2DArrayType):
+
         # start by getting direction used to get to current node
         # and direction from current node to possible child node
         new_dir = get_direction(current_pos, new_pos)
@@ -556,30 +558,30 @@ class GridTransitionMap(TransitionMap):
         if prev_pos is None:
             if new_trans == 0:
                 # need to flip direction because of how end points are defined
-                new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
+                new_trans = self.transitions.set_transition(new_trans, mirror(current_dir), new_dir, 1)
             else:
                 # check if matches existing layout
-                new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
+                new_trans = self.transitions.set_transition(new_trans, current_dir, new_dir, 1)
         else:
             # set the forward path
-            new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
+            new_trans = self.transitions.set_transition(new_trans, current_dir, new_dir, 1)
             # set the backwards path
-            new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
-        if new_pos == end_pos:
+            new_trans = self.transitions.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
+        if Vec2d.is_equal(new_pos, end_pos):
             # need to validate end pos setup as well
             new_trans_e = self.grid[end_pos]
             if new_trans_e == 0:
                 # need to flip direction because of how end points are defined
-                new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
+                new_trans_e = self.transitions.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
             else:
                 # check if matches existing layout
-                new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
+                new_trans_e = self.transitions.set_transition(new_trans_e, new_dir, new_dir, 1)
 
-            if not rail_trans.is_valid(new_trans_e):
+            if not self.transitions.is_valid(new_trans_e):
                 return False
 
         # is transition is valid?
-        return rail_trans.is_valid(new_trans)
+        return self.transitions.is_valid(new_trans)
 
 
 def mirror(dir):
diff --git a/tests/test_flatland_core_transitions.py b/tests/test_flatland_core_transitions.py
index 9c01be4b..b4c268a7 100644
--- a/tests/test_flatland_core_transitions.py
+++ b/tests/test_flatland_core_transitions.py
@@ -118,32 +118,32 @@ def test_adding_new_valid_transition():
     grid_map = GridTransitionMap(width=15, height=15, transitions=rail_trans)
 
     # adding straight
-    assert (grid_map.validate_new_transition(rail_trans, (4, 5), (5, 5), (6, 5), (10, 10)) is True)
+    assert (grid_map.validate_new_transition((4, 5), (5, 5), (6, 5), (10, 10)) is True)
 
     # adding valid right turn
-    assert (grid_map.validate_new_transition(rail_trans, (5, 4), (5, 5), (5, 6), (10, 10)) is True)
+    assert (grid_map.validate_new_transition((5, 4), (5, 5), (5, 6), (10, 10)) is True)
     # adding valid left turn
-    assert (grid_map.validate_new_transition(rail_trans, (5, 6), (5, 5), (5, 6), (10, 10)) is True)
+    assert (grid_map.validate_new_transition((5, 6), (5, 5), (5, 6), (10, 10)) is True)
 
     # adding invalid turn
     grid_map.grid[(5, 5)] = rail_trans.transitions[2]
-    assert (grid_map.validate_new_transition(rail_trans, (4, 5), (5, 5), (5, 6), (10, 10)) is False)
+    assert (grid_map.validate_new_transition((4, 5), (5, 5), (5, 6), (10, 10)) is False)
 
     # should create #4 -> valid
     grid_map.grid[(5, 5)] = rail_trans.transitions[3]
-    assert (grid_map.validate_new_transition(rail_trans, (4, 5), (5, 5), (5, 6), (10, 10)) is True)
+    assert (grid_map.validate_new_transition((4, 5), (5, 5), (5, 6), (10, 10)) is True)
 
     # adding invalid turn
     grid_map.grid[(5, 5)] = rail_trans.transitions[7]
-    assert (grid_map.validate_new_transition(rail_trans, (4, 5), (5, 5), (5, 6), (10, 10)) is False)
+    assert (grid_map.validate_new_transition((4, 5), (5, 5), (5, 6), (10, 10)) is False)
 
     # test path start condition
     grid_map.grid[(5, 5)] = rail_trans.transitions[0]
-    assert (grid_map.validate_new_transition(rail_trans, None, (5, 5), (5, 6), (10, 10)) is True)
+    assert (grid_map.validate_new_transition(None, (5, 5), (5, 6), (10, 10)) is True)
 
     # test path end condition
     grid_map.grid[(5, 5)] = rail_trans.transitions[0]
-    assert (grid_map.validate_new_transition(rail_trans, (5, 4), (5, 5), (6, 5), (6, 5)) is True)
+    assert (grid_map.validate_new_transition((5, 4), (5, 5), (6, 5), (6, 5)) is True)
 
 
 def test_valid_railenv_transitions():
-- 
GitLab