From af5ed9877178e83285001fed8baa44d90f8d601e Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Thu, 3 Oct 2019 15:35:15 -0400
Subject: [PATCH] refactoring of rail connection functions

---
 flatland/envs/grid4_generators_utils.py       | 61 +++++--------------
 flatland/envs/rail_generators.py              | 18 ++++--
 ...est_flatland_core_grid4_generators_util.py |  8 +--
 3 files changed, 30 insertions(+), 57 deletions(-)

diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py
index 77b54e54..dfd9218f 100644
--- a/flatland/envs/grid4_generators_utils.py
+++ b/flatland/envs/grid4_generators_utils.py
@@ -14,7 +14,7 @@ from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
 from flatland.core.transition_map import GridTransitionMap, RailEnvTransitions
 
 
-def connect_basic_operation(
+def connect_rail(
     rail_trans: RailEnvTransitions,
     grid_map: GridTransitionMap,
     start: IntVector2D,
@@ -26,9 +26,20 @@ def connect_basic_operation(
     forbidden_cells=None
 ) -> IntVector2DArray:
     """
-    Creates a new path [start,end] in `grid_map.grid`, based on rail_trans, and
+        Creates a new path [start,end] in `grid_map.grid`, based on rail_trans, and
     returns the path created as a list of positions.
+    :param rail_trans:
+    :param grid_map:
+    :param start:
+    :param end:
+    :param flip_start_node_trans:
+    :param flip_end_node_trans:
+    :param nice:
+    :param a_star_distance_function:
+    :param forbidden_cells:
+    :return:
     """
+
     # in the worst case we will need to do a A* search, so we might as well set that up
     path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function, nice, forbidden_cells)
     if len(path) < 2:
@@ -79,7 +90,7 @@ def connect_basic_operation(
     return path
 
 
-def connect_line(rail_trans, grid_map, start, end, openend=False):
+def connect_straigt_line(rail_trans, grid_map, start, end, openend=False):
     """
     Generates a straight rail line from start cell to end cell.
     Diagonal lines are not allowed
@@ -126,47 +137,3 @@ def connect_line(rail_trans, grid_map, start, end, openend=False):
         path.append(current_cell)
     return path
 
-def connect_rail(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
-                 start: IntVector2D, end: IntVector2D,
-                 a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
-    """
-
-    :param rail_trans:
-    :param grid_map:
-    :param start:
-    :param end:
-    :param a_star_distance_function:
-    :return:
-    """
-    return connect_basic_operation(rail_trans, grid_map, start, end, True, True, True, a_star_distance_function)
-
-
-def connect_cities(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
-                   start: IntVector2D, end: IntVector2D,
-                   a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance,
-                   forbidden_cells=None) -> IntVector2DArray:
-    """
-
-    :param rail_trans:
-    :param grid_map:
-    :param start:
-    :param end:
-    :param forbidden_cells:
-    :param a_star_distance_function:
-    :return:
-    """
-    return connect_basic_operation(rail_trans, grid_map, start, end, False, False, False, a_star_distance_function,
-                                   forbidden_cells)
-
-def connect_straigt_line(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D,
-                         end: IntVector2D, openend=False) -> IntVector2DArray:
-    """
-
-    :param rail_trans:
-    :param grid_map:
-    :param start:
-    :param end:
-    :param openend:
-    :return:
-    """
-    return connect_line(rail_trans, grid_map, start, end, openend)
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index 7d7174d3..ee31666c 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -11,7 +11,7 @@ from flatland.core.grid.grid_utils import distance_on_rail, direction_to_city, I
     Vec2dOperations
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
-from flatland.envs.grid4_generators_utils import connect_rail, connect_cities, connect_straigt_line
+from flatland.envs.grid4_generators_utils import connect_rail, connect_straigt_line
 
 RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]]
 RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
@@ -125,7 +125,9 @@ def complex_rail_generator(nr_start_goal=1,
                 # we might as well give up at this point
                 break
 
-            new_path = connect_rail(rail_trans, grid_map, start, goal)
+            new_path = connect_rail(rail_trans, grid_map, start, goal, flip_start_node_trans=True,
+                                    flip_end_node_trans=True, nice=True,
+                                    forbidden_cells=None)
             if len(new_path) >= 2:
                 nr_created += 1
                 start_goal.append([start, goal])
@@ -150,7 +152,10 @@ def complex_rail_generator(nr_start_goal=1,
                     break
             if not all_ok:
                 break
-            new_path = connect_rail(rail_trans, grid_map, start, goal)
+            new_path = connect_rail(rail_trans, grid_map, start, goal, flip_start_node_trans=True,
+                                    flip_end_node_trans=True, nice=True,
+                                    forbidden_cells=None)
+
             if len(new_path) >= 2:
                 nr_created += 1
 
@@ -735,9 +740,10 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
                                 min_connection_dist = tmp_dist
                                 neighb_connection_point = tmp_in_connection_point
                                 neighbour_direction = dir
-                    new_line = connect_cities(rail_trans, grid_map, tmp_out_connection_point,
-                                              neighb_connection_point,
-                                              forbidden_cells=city_cells)
+                    new_line = connect_rail(rail_trans, grid_map, tmp_out_connection_point,
+                                            neighb_connection_point, flip_start_node_trans=False,
+                                            flip_end_node_trans=False, nice=True,
+                                            forbidden_cells=city_cells)
                     all_paths.extend(new_line)
 
         return all_paths
diff --git a/tests/test_flatland_core_grid4_generators_util.py b/tests/test_flatland_core_grid4_generators_util.py
index 64c03f3c..ba71daac 100644
--- a/tests/test_flatland_core_grid4_generators_util.py
+++ b/tests/test_flatland_core_grid4_generators_util.py
@@ -3,7 +3,7 @@ import numpy as np
 from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
-from flatland.envs.grid4_generators_utils import connect_rail, connect_cities
+from flatland.envs.grid4_generators_utils import connect_rail
 
 
 def test_build_railway_infrastructure():
@@ -20,17 +20,17 @@ def test_build_railway_infrastructure():
 
     start_point = (1, 3)
     end_point = (1, 7)
-    connection_002 = connect_cities(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance)
+    connection_002 = connect_rail(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance)
     connection_002_expected = [(1, 3), (1, 4), (1, 5), (1, 6), (1, 7)]
 
     start_point = (6, 2)
     end_point = (6, 5)
-    connection_003 = connect_cities(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance)
+    connection_003 = connect_rail(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance)
     connection_003_expected = [(6, 2), (6, 3), (6, 4), (6, 5)]
 
     start_point = (7, 5)
     end_point = (8, 9)
-    connection_004 = connect_cities(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance)
+    connection_004 = connect_rail(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance)
     connection_004_expected = [(7, 5), (7, 6), (7, 7), (7, 8), (7, 9), (8, 9)]
 
     assert connection_001 == connection_001_expected, \
-- 
GitLab