From 43ffa6304159cde4b817ea78cf2e01b83473ae22 Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Thu, 3 Oct 2019 15:57:32 -0400
Subject: [PATCH] fixed failing connection tests

---
 flatland/envs/grid4_generators_utils.py       | 20 ++++++-----------
 flatland/envs/rail_generators.py              | 17 ++++++--------
 ...est_flatland_core_grid4_generators_util.py | 22 ++++++++++++-------
 3 files changed, 28 insertions(+), 31 deletions(-)

diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py
index a332cd70..bc43f90e 100644
--- a/flatland/envs/grid4_generators_utils.py
+++ b/flatland/envs/grid4_generators_utils.py
@@ -14,17 +14,10 @@ from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
 from flatland.core.transition_map import GridTransitionMap, RailEnvTransitions
 
 
-def connect_rail(
-    rail_trans: RailEnvTransitions,
-    grid_map: GridTransitionMap,
-    start: IntVector2D,
-    end: IntVector2D,
-    a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance,
-    flip_start_node_trans=False,
-    flip_end_node_trans=False,
-    nice=True,
-    forbidden_cells=None
-) -> IntVector2DArray:
+def connect_rail(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D,
+                 a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance,
+                 flip_start_node_trans=False, flip_end_node_trans=False, respect_transition_validity=True,
+                 forbidden_cells=None) -> IntVector2DArray:
     """
         Creates a new path [start,end] in `grid_map.grid`, based on rail_trans, and
     returns the path created as a list of positions.
@@ -34,14 +27,15 @@ def connect_rail(
     :param end:
     :param flip_start_node_trans:
     :param flip_end_node_trans:
-    :param nice:
+    :param respect_transition_validity:
     :param a_star_distance_function:
     :param forbidden_cells:
     :return:
     """
 
     # in the worst case we will need to do a A* search, so we might as well set that up
-    path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function, nice, forbidden_cells)
+    path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function, respect_transition_validity,
+                                    forbidden_cells)
     if len(path) < 2:
         print("No path found", path)
         return []
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index 557f25e4..e6b9cf69 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -127,9 +127,8 @@ def complex_rail_generator(nr_start_goal=1,
                 break
 
             new_path = connect_rail(rail_trans, grid_map, start, goal, Vec2d.get_chebyshev_distance,
-                                    flip_start_node_trans=True,
-                                    flip_end_node_trans=True, nice=True,
-                                    forbidden_cells=None)
+                                    flip_start_node_trans=True, flip_end_node_trans=True,
+                                    respect_transition_validity=True, forbidden_cells=None)
             if len(new_path) >= 2:
                 nr_created += 1
                 start_goal.append([start, goal])
@@ -155,9 +154,8 @@ def complex_rail_generator(nr_start_goal=1,
             if not all_ok:
                 break
             new_path = connect_rail(rail_trans, grid_map, start, goal, Vec2d.get_chebyshev_distance,
-                                    flip_start_node_trans=True,
-                                    flip_end_node_trans=True, nice=True,
-                                    forbidden_cells=None)
+                                    flip_start_node_trans=True, flip_end_node_trans=True,
+                                    respect_transition_validity=True, forbidden_cells=None)
 
             if len(new_path) >= 2:
                 nr_created += 1
@@ -743,10 +741,9 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
                                 min_connection_dist = tmp_dist
                                 neighb_connection_point = tmp_in_connection_point
                                 neighbour_direction = dir
-                    new_line = connect_rail(rail_trans, grid_map, tmp_out_connection_point,
-                                            neighb_connection_point, flip_start_node_trans=False,
-                                            flip_end_node_trans=False, nice=True,
-                                            forbidden_cells=city_cells)
+                    new_line = connect_rail(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point,
+                                            flip_start_node_trans=False, flip_end_node_trans=False,
+                                            respect_transition_validity=False, forbidden_cells=city_cells)
                     all_paths.extend(new_line)
 
         return all_paths
diff --git a/tests/test_flatland_core_grid4_generators_util.py b/tests/test_flatland_core_grid4_generators_util.py
index ba71daac..9722063e 100644
--- a/tests/test_flatland_core_grid4_generators_util.py
+++ b/tests/test_flatland_core_grid4_generators_util.py
@@ -1,6 +1,5 @@
 import numpy as np
 
-from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.grid4_generators_utils import connect_rail
@@ -14,23 +13,31 @@ def test_build_railway_infrastructure():
 
     start_point = (2, 2)
     end_point = (8, 8)
-    connection_001 = connect_rail(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance)
+    connection_001 = connect_rail(rail_trans, grid_map, start_point, end_point,
+                                  flip_start_node_trans=True, flip_end_node_trans=True,
+                                  respect_transition_validity=True, forbidden_cells=None)
     connection_001_expected = [(2, 2), (2, 3), (2, 4), (2, 5), (2, 6), (2, 7), (2, 8), (3, 8), (4, 8), (5, 8), (6, 8),
                                (7, 8), (8, 8)]
 
     start_point = (1, 3)
     end_point = (1, 7)
-    connection_002 = connect_rail(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance)
+    connection_002 = connect_rail(rail_trans, grid_map, start_point, end_point,
+                                  flip_start_node_trans=False, flip_end_node_trans=False,
+                                  respect_transition_validity=True, forbidden_cells=None)
     connection_002_expected = [(1, 3), (1, 4), (1, 5), (1, 6), (1, 7)]
 
     start_point = (6, 2)
     end_point = (6, 5)
-    connection_003 = connect_rail(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance)
+    connection_003 = connect_rail(rail_trans, grid_map, start_point, end_point,
+                                  flip_start_node_trans=False, flip_end_node_trans=True,
+                                  respect_transition_validity=True, forbidden_cells=None)
     connection_003_expected = [(6, 2), (6, 3), (6, 4), (6, 5)]
 
     start_point = (7, 5)
     end_point = (8, 9)
-    connection_004 = connect_rail(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance)
+    connection_004 = connect_rail(rail_trans, grid_map, start_point, end_point,
+                                  flip_start_node_trans=True, flip_end_node_trans=False,
+                                  respect_transition_validity=True, forbidden_cells=None)
     connection_004_expected = [(7, 5), (7, 6), (7, 7), (7, 8), (7, 9), (8, 9)]
 
     assert connection_001 == connection_001_expected, \
@@ -64,6 +71,5 @@ def test_build_railway_infrastructure():
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
     ]
-
-    assert np.all(grid_map.grid == grid_map_grid_expected), \
-        "actual={}, expected={}".format(grid_map.grid, grid_map_grid_expected)
+    for i in range(len(grid_map_grid_expected)):
+        assert np.all(grid_map.grid[i] == grid_map_grid_expected[i])
-- 
GitLab