From ef0b7175d48f8f727679ff467238990cbd1a8f4c Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Tue, 17 Sep 2019 13:19:20 +0200
Subject: [PATCH] refactoring and clean up

---
 .../Simple_Realistic_Railway_Generator.py     |  4 +-
 flatland/core/grid/grid4_astar.py             | 10 ++---
 flatland/core/grid/grid4_utils.py             | 42 +-----------------
 flatland/core/transition_map.py               | 43 ++++++++++++++++++-
 tests/test_flatland_core_transitions.py       | 32 +++++++-------
 5 files changed, 66 insertions(+), 65 deletions(-)

diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py
index 6e5e6a56..f935d3e3 100644
--- a/examples/Simple_Realistic_Railway_Generator.py
+++ b/examples/Simple_Realistic_Railway_Generator.py
@@ -451,7 +451,7 @@ def realistic_rail_generator(num_cities=5,
 if os.path.exists("./../render_output/"):
     for itrials in range(1000):
         print(itrials, "generate new city")
-        np.random.seed(0 * int(time.time()))
+        np.random.seed(itrials)
         env = RailEnv(width=40 + np.random.choice(100),
                       height=40 + np.random.choice(100),
                       rail_generator=realistic_rail_generator(num_cities=2 + np.random.choice(10),
@@ -462,7 +462,7 @@ if os.path.exists("./../render_output/"):
                                                               connect_max_nbr_of_shortes_city=2,
                                                               do_random_connect_stations=False,
                                                               # Number of cities in map
-                                                              seed=0*int(time.time()),  # Random seed
+                                                              seed=itrials,  # Random seed
                                                               print_out_info=True
                                                               ),
                       schedule_generator=sparse_schedule_generator(),
diff --git a/flatland/core/grid/grid4_astar.py b/flatland/core/grid/grid4_astar.py
index c04d71dd..d0b2cc97 100644
--- a/flatland/core/grid/grid4_astar.py
+++ b/flatland/core/grid/grid4_astar.py
@@ -1,4 +1,4 @@
-from flatland.core.grid.grid4_utils import validate_new_transition
+
 from flatland.core.grid.grid_utils import IntVector2D
 from flatland.core.grid.grid_utils import IntVector2DArrayType
 from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
@@ -12,9 +12,9 @@ class AStarNode:
     def __init__(self, parent: IntVector2D = None, pos: IntVector2D = None):
         self.parent: IntVector2D = parent
         self.pos: IntVector2D = pos
-        self.g: float = 0.0
-        self.h: float = 0.0
-        self.f: float = 0.0
+        self.g = 0.0
+        self.h = 0.0
+        self.f = 0.0
 
     def __eq__(self, other: IntVector2D):
         return self.pos == other.pos
@@ -80,7 +80,7 @@ def a_star(rail_trans: RailEnvTransitions,
                 continue
 
             # validate positions
-            if not validate_new_transition(rail_trans, grid_map.grid, prev_pos, current_node.pos, node_pos,
+            if not grid_map.validate_new_transition(rail_trans, prev_pos, current_node.pos, node_pos,
                                            end_node.pos):
                 continue
 
diff --git a/flatland/core/grid/grid4_utils.py b/flatland/core/grid/grid4_utils.py
index d64a160b..0a0ba6b8 100644
--- a/flatland/core/grid/grid4_utils.py
+++ b/flatland/core/grid/grid4_utils.py
@@ -1,7 +1,8 @@
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.core.grid.grid_utils import IntVector2DArrayType
 
 
-def get_direction(pos1, pos2) -> Grid4TransitionsEnum:
+def get_direction(pos1: IntVector2DArrayType, pos2: IntVector2DArrayType) -> Grid4TransitionsEnum:
     """
     Assumes pos1 and pos2 are adjacent location on grid.
     Returns direction (int) that can be used with transitions.
@@ -23,45 +24,6 @@ def mirror(dir):
     return (dir + 2) % 4
 
 
-def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_pos, end_pos):
-    # start by getting direction used to get to current node
-    # and direction from current node to possible child node
-    new_dir = get_direction(current_pos, new_pos)
-    if prev_pos is not None:
-        current_dir = get_direction(prev_pos, current_pos)
-    else:
-        current_dir = new_dir
-    # create new transition that would go to child
-    new_trans = rail_array[current_pos]
-    if prev_pos is None:
-        if new_trans == 0:
-            # need to flip direction because of how end points are defined
-            new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
-        else:
-            # check if matches existing layout
-            new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
-    else:
-        # set the forward path
-        new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
-        # set the backwards path
-        new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
-    if new_pos == end_pos:
-        # need to validate end pos setup as well
-        new_trans_e = rail_array[end_pos]
-        if new_trans_e == 0:
-            # need to flip direction because of how end points are defined
-            new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
-        else:
-            # check if matches existing layout
-            new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
-
-        if not rail_trans.is_valid(new_trans_e):
-            return False
-
-    # is transition is valid?
-    return rail_trans.is_valid(new_trans)
-
-
 def get_new_position(position, movement):
     """ Utility function that converts a compass movement over a 2D grid to new positions (r, c). """
     if movement == Grid4TransitionsEnum.NORTH:
diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py
index 232d6fda..2b06714b 100644
--- a/flatland/core/transition_map.py
+++ b/flatland/core/transition_map.py
@@ -7,7 +7,8 @@ from importlib_resources import path
 from numpy import array
 
 from flatland.core.grid.grid4 import Grid4Transitions
-from flatland.core.grid.grid4_utils import get_new_position
+from flatland.core.grid.grid4_utils import get_new_position, get_direction
+from flatland.core.grid.grid_utils import IntVector2DArrayType
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transitions import Transitions
 
@@ -540,6 +541,46 @@ class GridTransitionMap(TransitionMap):
             self.set_transitions((rcPos[0], rcPos[1]), transition)
         return True
 
+    def validate_new_transition(self, rail_trans: RailEnvTransitions,
+                                prev_pos: IntVector2DArrayType, current_pos: IntVector2DArrayType,
+                                new_pos: IntVector2DArrayType, end_pos: IntVector2DArrayType):
+        # start by getting direction used to get to current node
+        # and direction from current node to possible child node
+        new_dir = get_direction(current_pos, new_pos)
+        if prev_pos is not None:
+            current_dir = get_direction(prev_pos, current_pos)
+        else:
+            current_dir = new_dir
+        # create new transition that would go to child
+        new_trans = self.grid[current_pos]
+        if prev_pos is None:
+            if new_trans == 0:
+                # need to flip direction because of how end points are defined
+                new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
+            else:
+                # check if matches existing layout
+                new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
+        else:
+            # set the forward path
+            new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
+            # set the backwards path
+            new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
+        if new_pos == end_pos:
+            # need to validate end pos setup as well
+            new_trans_e = self.grid[end_pos]
+            if new_trans_e == 0:
+                # need to flip direction because of how end points are defined
+                new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
+            else:
+                # check if matches existing layout
+                new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
+
+            if not rail_trans.is_valid(new_trans_e):
+                return False
+
+        # is transition is valid?
+        return rail_trans.is_valid(new_trans)
+
 
 def mirror(dir):
     return (dir + 2) % 4
diff --git a/tests/test_flatland_core_transitions.py b/tests/test_flatland_core_transitions.py
index 9d4a72c0..9c01be4b 100644
--- a/tests/test_flatland_core_transitions.py
+++ b/tests/test_flatland_core_transitions.py
@@ -2,12 +2,10 @@
 # -*- coding: utf-8 -*-
 
 """Tests for `flatland` package."""
-import numpy as np
-
 from flatland.core.grid.grid4 import Grid4Transitions
 from flatland.core.grid.grid8 import Grid8Transitions
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
-from flatland.core.grid.grid4_utils import validate_new_transition
+from flatland.core.transition_map import GridTransitionMap
 
 
 # remove whitespace in string; keep whitespace below for easier reading
@@ -117,35 +115,35 @@ def test_is_valid_railenv_transitions():
 
 def test_adding_new_valid_transition():
     rail_trans = RailEnvTransitions()
-    rail_array = np.zeros(shape=(15, 15), dtype=np.uint16)
+    grid_map = GridTransitionMap(width=15, height=15, transitions=rail_trans)
 
     # adding straight
-    assert (validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (6, 5), (10, 10)) is True)
+    assert (grid_map.validate_new_transition(rail_trans, (4, 5), (5, 5), (6, 5), (10, 10)) is True)
 
     # adding valid right turn
-    assert (validate_new_transition(rail_trans, rail_array, (5, 4), (5, 5), (5, 6), (10, 10)) is True)
+    assert (grid_map.validate_new_transition(rail_trans, (5, 4), (5, 5), (5, 6), (10, 10)) is True)
     # adding valid left turn
-    assert (validate_new_transition(rail_trans, rail_array, (5, 6), (5, 5), (5, 6), (10, 10)) is True)
+    assert (grid_map.validate_new_transition(rail_trans, (5, 6), (5, 5), (5, 6), (10, 10)) is True)
 
     # adding invalid turn
-    rail_array[(5, 5)] = rail_trans.transitions[2]
-    assert (validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is False)
+    grid_map.grid[(5, 5)] = rail_trans.transitions[2]
+    assert (grid_map.validate_new_transition(rail_trans, (4, 5), (5, 5), (5, 6), (10, 10)) is False)
 
     # should create #4 -> valid
-    rail_array[(5, 5)] = rail_trans.transitions[3]
-    assert (validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is True)
+    grid_map.grid[(5, 5)] = rail_trans.transitions[3]
+    assert (grid_map.validate_new_transition(rail_trans, (4, 5), (5, 5), (5, 6), (10, 10)) is True)
 
     # adding invalid turn
-    rail_array[(5, 5)] = rail_trans.transitions[7]
-    assert (validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is False)
+    grid_map.grid[(5, 5)] = rail_trans.transitions[7]
+    assert (grid_map.validate_new_transition(rail_trans, (4, 5), (5, 5), (5, 6), (10, 10)) is False)
 
     # test path start condition
-    rail_array[(5, 5)] = rail_trans.transitions[0]
-    assert (validate_new_transition(rail_trans, rail_array, None, (5, 5), (5, 6), (10, 10)) is True)
+    grid_map.grid[(5, 5)] = rail_trans.transitions[0]
+    assert (grid_map.validate_new_transition(rail_trans, None, (5, 5), (5, 6), (10, 10)) is True)
 
     # test path end condition
-    rail_array[(5, 5)] = rail_trans.transitions[0]
-    assert (validate_new_transition(rail_trans, rail_array, (5, 4), (5, 5), (6, 5), (6, 5)) is True)
+    grid_map.grid[(5, 5)] = rail_trans.transitions[0]
+    assert (grid_map.validate_new_transition(rail_trans, (5, 4), (5, 5), (6, 5), (6, 5)) is True)
 
 
 def test_valid_railenv_transitions():
-- 
GitLab