diff --git a/examples/training_example.py b/examples/training_example.py
index df93479f5a5ee05abfcb1a98b07ef052bffc2bd4..78c0299d4cee8ae588bcf8e9e7559ff1c8364c26 100644
--- a/examples/training_example.py
+++ b/examples/training_example.py
@@ -16,7 +16,7 @@ TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictor
 LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2)
 env = RailEnv(width=20,
               height=20,
-              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
+              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=2, min_dist=8, max_dist=99999, seed=0),
               schedule_generator=complex_schedule_generator(),
               obs_builder_object=TreeObservation,
               number_of_agents=3)
diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py
index dfd9218f0e575fc8909994d635315e6cab6cd5b3..a332cd70b3c8055856332ba1214be9e1c0df55f8 100644
--- a/flatland/envs/grid4_generators_utils.py
+++ b/flatland/envs/grid4_generators_utils.py
@@ -19,10 +19,10 @@ def connect_rail(
     grid_map: GridTransitionMap,
     start: IntVector2D,
     end: IntVector2D,
+    a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance,
     flip_start_node_trans=False,
     flip_end_node_trans=False,
     nice=True,
-    a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance,
     forbidden_cells=None
 ) -> IntVector2DArray:
     """
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index ee31666c99082e77855764c7bdee72031034ad22..557f25e4ca956e541ae80be051e7f9fa7e93fc93 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -7,6 +7,7 @@ import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_direction, mirror
+from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
 from flatland.core.grid.grid_utils import distance_on_rail, direction_to_city, IntVector2DArray, IntVector2D, \
     Vec2dOperations
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
@@ -125,7 +126,8 @@ def complex_rail_generator(nr_start_goal=1,
                 # we might as well give up at this point
                 break
 
-            new_path = connect_rail(rail_trans, grid_map, start, goal, flip_start_node_trans=True,
+            new_path = connect_rail(rail_trans, grid_map, start, goal, Vec2d.get_chebyshev_distance,
+                                    flip_start_node_trans=True,
                                     flip_end_node_trans=True, nice=True,
                                     forbidden_cells=None)
             if len(new_path) >= 2:
@@ -152,7 +154,8 @@ def complex_rail_generator(nr_start_goal=1,
                     break
             if not all_ok:
                 break
-            new_path = connect_rail(rail_trans, grid_map, start, goal, flip_start_node_trans=True,
+            new_path = connect_rail(rail_trans, grid_map, start, goal, Vec2d.get_chebyshev_distance,
+                                    flip_start_node_trans=True,
                                     flip_end_node_trans=True, nice=True,
                                     forbidden_cells=None)