diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/simple_example_city_railway_generator.py
similarity index 90%
rename from examples/Simple_Realistic_Railway_Generator.py
rename to examples/simple_example_city_railway_generator.py
index 082e9b679a221425516c547f02af0cf3ec38c4dc..7aa9f142631e1a07b77b416cdadd13a42ac83dc7 100644
--- a/examples/Simple_Realistic_Railway_Generator.py
+++ b/examples/simple_example_city_railway_generator.py
@@ -1,10 +1,12 @@
 import copy
 import os
 import warnings
+from typing import Sequence, Optional
 
 import numpy as np
 
-from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d, IntVector2DArrayType
+from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d, IntVector2DArray, IntVector2DDistance, \
+    IntVector2DArrayArray
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.grid4_generators_utils import connect_from_nodes, connect_nodes, connect_rail
@@ -17,19 +19,19 @@ from flatland.utils.rendertools import RenderTool, AgentRenderVariant
 FloatArrayType = []
 
 
-def realistic_rail_generator(num_cities=5,
-                             city_size=10,
-                             allowed_rotation_angles=None,
-                             max_number_of_station_tracks=4,
-                             nbr_of_switches_per_station_track=2,
-                             connect_max_nbr_of_shortes_city=4,
-                             do_random_connect_stations=False,
-                             seed=0,
-                             print_out_info=True) -> RailGenerator:
+def realistic_rail_generator(num_cities: int = 5,
+                             city_size: int = 10,
+                             allowed_rotation_angles: Optional[Sequence[float]] = None,
+                             max_number_of_station_tracks: int = 4,
+                             nbr_of_switches_per_station_track: int = 2,
+                             connect_max_nbr_of_shortes_city: int = 4,
+                             do_random_connect_stations: bool = False,
+                             a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance,
+                             seed: int = 0,
+                             print_out_info: bool = True) -> RailGenerator:
     """
     This is a level generator which generates a realistic rail configurations
 
-    :param print_out_info:
     :param num_cities: Number of city node
     :param city_size: Length of city measure in cells
     :param allowed_rotation_angles: Rotate the city (around center)
@@ -37,8 +39,9 @@ def realistic_rail_generator(num_cities=5,
     :param nbr_of_switches_per_station_track: number of switches per track (max)
     :param connect_max_nbr_of_shortes_city: max number of connecting track between stations
     :param do_random_connect_stations : if false connect the stations along the grid (top,left -> down,right), else rand
+    :param a_star_distance_function: Heuristic how the distance between two nodes get estimated in the "a-star" path
     :param seed: Random Seed
-    :print_out_info : print debug info
+    :param print_out_info: print debug info if True
     :return:
         -------
     numpy.ndarray of type numpy.uint16
@@ -48,7 +51,7 @@ def realistic_rail_generator(num_cities=5,
     def do_generate_city_locations(width: int,
                                    height: int,
                                    intern_city_size: int,
-                                   intern_max_number_of_station_tracks: int) -> (IntVector2DArrayType, int):
+                                   intern_max_number_of_station_tracks: int) -> (IntVector2DArray, int):
 
         X = int(np.floor(max(1, height - 2 * intern_max_number_of_station_tracks - 1) / intern_city_size))
         Y = int(np.floor(max(1, width - 2 * intern_max_number_of_station_tracks - 1) / intern_city_size))
@@ -68,7 +71,7 @@ def realistic_rail_generator(num_cities=5,
         generate_city_locations = [[(int(xs[i]), int(ys[i])), (int(xs[i]), int(ys[i]))] for i in range(len(xs))]
         return generate_city_locations, max_num_cities
 
-    def do_orient_cities(generate_city_locations: IntVector2DArrayType, intern_city_size: int,
+    def do_orient_cities(generate_city_locations: IntVector2DArrayArray, intern_city_size: int,
                          rotation_angles_set: FloatArrayType):
         for i in range(len(generate_city_locations)):
             # station main orientation  (horizontal or vertical
@@ -83,12 +86,12 @@ def realistic_rail_generator(num_cities=5,
 
     def create_stations_from_city_locations(rail_trans: RailEnvTransitions,
                                             grid_map: GridTransitionMap,
-                                            generate_city_locations: IntVector2DArrayType,
-                                            intern_max_number_of_station_tracks: int) -> (IntVector2DArrayType,
-                                                                                          IntVector2DArrayType,
-                                                                                          IntVector2DArrayType,
-                                                                                          IntVector2DArrayType,
-                                                                                          IntVector2DArrayType):
+                                            generate_city_locations: IntVector2DArray,
+                                            intern_max_number_of_station_tracks: int) -> (IntVector2DArray,
+                                                                                          IntVector2DArray,
+                                                                                          IntVector2DArray,
+                                                                                          IntVector2DArray,
+                                                                                          IntVector2DArray):
 
         nodes_added = []
         start_nodes_added = [[] for _ in range(len(generate_city_locations))]
@@ -115,7 +118,7 @@ def realistic_rail_generator(num_cities=5,
                 end_node = Vec2d.ceil(
                     Vec2d.add(org_end_node, Vec2d.scale(ortho_trans, s)))
 
-                connection = connect_from_nodes(rail_trans, grid_map, start_node, end_node)
+                connection = connect_from_nodes(rail_trans, grid_map, start_node, end_node, a_star_distance_function)
                 if len(connection) > 0:
                     nodes_added.append(start_node)
                     nodes_added.append(end_node)
@@ -142,9 +145,9 @@ def realistic_rail_generator(num_cities=5,
 
     def create_switches_at_stations(rail_trans: RailEnvTransitions,
                                     grid_map: GridTransitionMap,
-                                    station_tracks: IntVector2DArrayType,
-                                    nodes_added: IntVector2DArrayType,
-                                    intern_nbr_of_switches_per_station_track: int) -> IntVector2DArrayType:
+                                    station_tracks: IntVector2DArray,
+                                    nodes_added: IntVector2DArray,
+                                    intern_nbr_of_switches_per_station_track: int) -> IntVector2DArray:
 
         for k_loop in range(intern_nbr_of_switches_per_station_track):
             for city_loop in range(len(station_tracks)):
@@ -170,13 +173,14 @@ def realistic_rail_generator(num_cities=5,
                                     if x < 2:
                                         x = len(track) - 1
                                 end_node = track[x]
-                                connection = connect_rail(rail_trans, grid_map, start_node, end_node)
+                                connection = connect_rail(rail_trans, grid_map, start_node, end_node,
+                                                          a_star_distance_function)
                                 if len(connection) == 0:
                                     if print_out_info:
                                         print("create_switches_at_stations : connect_rail -> no path found")
                                     start_node = datas[i][0]
                                     end_node = datas[i - 1][0]
-                                    connect_rail(rail_trans, grid_map, start_node, end_node)
+                                    connect_rail(rail_trans, grid_map, start_node, end_node, a_star_distance_function)
 
                                 nodes_added.append(start_node)
                                 nodes_added.append(end_node)
@@ -226,10 +230,10 @@ def realistic_rail_generator(num_cities=5,
         return graph, np.unique(graph_ids).astype(int)
 
     def connect_sub_graphs(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
-                           org_s_nodes: IntVector2DArrayType,
-                           org_e_nodes: IntVector2DArrayType,
-                           city_edges: IntVector2DArrayType,
-                           nodes_added: IntVector2DArrayType):
+                           org_s_nodes: IntVector2DArray,
+                           org_e_nodes: IntVector2DArray,
+                           city_edges: IntVector2DArray,
+                           nodes_added: IntVector2DArray):
         _, graphids = calc_nbr_of_graphs(city_edges)
         if len(graphids) > 0:
             for i in range(len(graphids) - 1):
@@ -247,7 +251,7 @@ def realistic_rail_generator(num_cities=5,
                     # TODO : will be generated.
                     grid_map.grid[start_node] = 0
                     grid_map.grid[end_node] = 0
-                    connection = connect_rail(rail_trans, grid_map, start_node, end_node)
+                    connection = connect_rail(rail_trans, grid_map, start_node, end_node, a_star_distance_function)
                     if len(connection) > 0:
                         nodes_added.append(start_node)
                         nodes_added.append(end_node)
@@ -259,9 +263,9 @@ def realistic_rail_generator(num_cities=5,
 
     def connect_stations(rail_trans: RailEnvTransitions,
                          grid_map: GridTransitionMap,
-                         org_s_nodes: IntVector2DArrayType,
-                         org_e_nodes: IntVector2DArrayType,
-                         nodes_added: IntVector2DArrayType,
+                         org_s_nodes: IntVector2DArray,
+                         org_e_nodes: IntVector2DArray,
+                         nodes_added: IntVector2DArray,
                          intern_connect_max_nbr_of_shortes_city: int):
         city_edges = []
 
@@ -291,7 +295,7 @@ def realistic_rail_generator(num_cities=5,
                         tmp_trans_en = grid_map.grid[end_node]
                         grid_map.grid[start_node] = 0
                         grid_map.grid[end_node] = 0
-                        connection = connect_rail(rail_trans, grid_map, start_node, end_node)
+                        connection = connect_rail(rail_trans, grid_map, start_node, end_node, a_star_distance_function)
                         if len(connection) > 0:
                             s_nodes[city_loop].remove(start_node)
                             e_nodes[cl].remove(end_node)
@@ -313,9 +317,9 @@ def realistic_rail_generator(num_cities=5,
         connect_sub_graphs(rail_trans, grid_map, org_s_nodes, org_e_nodes, city_edges, nodes_added)
 
     def connect_random_stations(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
-                                start_nodes_added: IntVector2DArrayType,
-                                end_nodes_added: IntVector2DArrayType,
-                                nodes_added: IntVector2DArrayType,
+                                start_nodes_added: IntVector2DArray,
+                                end_nodes_added: IntVector2DArray,
+                                nodes_added: IntVector2DArray,
                                 intern_connect_max_nbr_of_shortes_city: int):
         if len(start_nodes_added) < 1:
             return
@@ -355,7 +359,7 @@ def realistic_rail_generator(num_cities=5,
                 end_node = e_nodes[idx_e_nodes[i]]
                 grid_map.grid[start_node] = 0
                 grid_map.grid[end_node] = 0
-                connection = connect_nodes(rail_trans, grid_map, start_node, end_node)
+                connection = connect_nodes(rail_trans, grid_map, start_node, end_node, a_star_distance_function)
                 if len(connection) > 0:
                     nodes_added.append(start_node)
                     nodes_added.append(end_node)
@@ -364,7 +368,7 @@ def realistic_rail_generator(num_cities=5,
                         print("connect_random_stations : connect_nodes -> no path found")
 
     def remove_switch_stations(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
-                               train_stations: IntVector2DArrayType):
+                               train_stations: IntVector2DArray):
         tmp_train_stations = copy.deepcopy(train_stations)
         for city_loop in range(len(train_stations)):
             for n in tmp_train_stations[city_loop]:
@@ -481,7 +485,7 @@ def realistic_rail_generator(num_cities=5,
                 if (tries + 1) % 10 == 0:
                     start_node = np.random.choice(avail_start_nodes)
                 if tries > 100:
-                    warnings.warn("Could not set trainstations, removing agent!")
+                    warnings.warn("Could not set train_stations, removing agent!")
                     found_agent_pair = False
                     break
             if found_agent_pair:
@@ -508,13 +512,13 @@ if os.path.exists("./../render_output/"):
                       height=40 + np.random.choice(100),
                       rail_generator=realistic_rail_generator(num_cities=5 + np.random.choice(10),
                                                               city_size=10 + np.random.choice(5),
-                                                              allowed_rotation_angles=np.arange(0, 360, 90),
-                                                              max_number_of_station_tracks=1 + np.random.choice(4),
+                                                              allowed_rotation_angles=np.arange(0, 360, 6),
+                                                              max_number_of_station_tracks=4 + np.random.choice(4),
                                                               nbr_of_switches_per_station_track=2 + np.random.choice(2),
                                                               connect_max_nbr_of_shortes_city=2 + np.random.choice(4),
                                                               do_random_connect_stations=itrials % 2 == 0,
-                                                              # Number of cities in map
-                                                              seed=itrials,  # Random seed
+                                                              a_star_distance_function=Vec2d.get_euclidean_distance,
+                                                              seed=itrials,
                                                               print_out_info=False
                                                               ),
                       schedule_generator=sparse_schedule_generator(),
diff --git a/flatland/core/grid/grid4_astar.py b/flatland/core/grid/grid4_astar.py
index 91fec9a95cfc2b01d119f6ea857778bba52f2295..9022329565a093a4c60336680a98ba4d8d7a0a83 100644
--- a/flatland/core/grid/grid4_astar.py
+++ b/flatland/core/grid/grid4_astar.py
@@ -1,24 +1,28 @@
 import numpy as np
-from matplotlib import pyplot as plt
 
-from flatland.core.grid.grid_utils import IntVector2D
-from flatland.core.grid.grid_utils import IntVector2DArrayType
+from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance
+from flatland.core.grid.grid_utils import IntVector2DArray
 from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
-from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
 
 
 class AStarNode:
     """A node class for A* Pathfinding"""
 
-    def __init__(self, parent: IntVector2D = None, pos: IntVector2D = None):
-        self.parent: IntVector2D = parent
+    def __init__(self, pos: IntVector2D, parent=None):
+        self.parent = parent
         self.pos: IntVector2D = pos
         self.g = 0.0
         self.h = 0.0
         self.f = 0.0
 
-    def __eq__(self, other: IntVector2D):
+    def __eq__(self, other):
+        """
+
+        Parameters
+        ----------
+        other : AStarNode
+        """
         return self.pos == other.pos
 
     def __hash__(self):
@@ -32,10 +36,9 @@ class AStarNode:
             self.f = other.f
 
 
-def a_star(rail_trans: RailEnvTransitions,
-           grid_map: GridTransitionMap,
+def a_star(grid_map: GridTransitionMap,
            start: IntVector2D, end: IntVector2D,
-           a_star_distance_function=Vec2d.get_manhattan_distance) -> IntVector2DArrayType:
+           a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
     """
     Returns a list of tuples as a path from the given start to end.
     If no path is found, returns path to closest point to end.
@@ -44,8 +47,8 @@ def a_star(rail_trans: RailEnvTransitions,
 
     tmp = np.zeros(rail_shape) - 10
 
-    start_node = AStarNode(None, start)
-    end_node = AStarNode(None, end)
+    start_node = AStarNode(start, None)
+    end_node = AStarNode(end, None)
     open_nodes = set()
     closed_nodes = set()
     open_nodes.add(start_node)
@@ -72,13 +75,6 @@ def a_star(rail_trans: RailEnvTransitions,
                 path.append(current.pos)
                 current = current.parent
 
-            if False:
-                plt.ion()
-                plt.clf()
-                plt.imshow(tmp, interpolation='nearest')
-                plt.draw()
-                plt.pause(1e-17)
-
             # return reversed path
             return path[::-1]
 
@@ -91,7 +87,7 @@ def a_star(rail_trans: RailEnvTransitions,
 
         for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]:
             # update the "current" pos
-            node_pos = Vec2d.add(current_node.pos, new_pos)
+            node_pos: IntVector2D = Vec2d.add(current_node.pos, new_pos)
 
             # is node_pos inside the grid?
             if node_pos[0] >= rail_shape[0] or node_pos[0] < 0 or node_pos[1] >= rail_shape[1] or node_pos[1] < 0:
@@ -102,7 +98,7 @@ def a_star(rail_trans: RailEnvTransitions,
                 continue
 
             # create new node
-            new_node = AStarNode(current_node, node_pos)
+            new_node = AStarNode(node_pos, current_node)
             children.append(new_node)
 
         # loop through children
diff --git a/flatland/core/grid/grid4_utils.py b/flatland/core/grid/grid4_utils.py
index 0a0ba6b8774b09f52562187a8d475f629e295c5a..98652459d7a7ac7f1694ac53fe1d0a12880ab8b2 100644
--- a/flatland/core/grid/grid4_utils.py
+++ b/flatland/core/grid/grid4_utils.py
@@ -1,8 +1,8 @@
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.core.grid.grid_utils import IntVector2DArrayType
+from flatland.core.grid.grid_utils import IntVector2DArray
 
 
-def get_direction(pos1: IntVector2DArrayType, pos2: IntVector2DArrayType) -> Grid4TransitionsEnum:
+def get_direction(pos1: IntVector2DArray, pos2: IntVector2DArray) -> Grid4TransitionsEnum:
     """
     Assumes pos1 and pos2 are adjacent location on grid.
     Returns direction (int) that can be used with transitions.
diff --git a/flatland/core/grid/grid_utils.py b/flatland/core/grid/grid_utils.py
index f9657bf340145e3a18eaa4203c2c9a163fde8620..3fda39b1d0e027c281c7b2b41545632ef0fcf3f9 100644
--- a/flatland/core/grid/grid_utils.py
+++ b/flatland/core/grid/grid_utils.py
@@ -1,11 +1,17 @@
-from typing import Tuple
+from typing import Tuple, Callable, List
 
 import numpy as np
 
 Vector2D = Tuple[float, float]
 IntVector2D = Tuple[int, int]
 
-IntVector2DArrayType = []
+IntVector2DArray = List[IntVector2D]
+IntVector2DArrayArray = List[List[IntVector2D]]
+
+Vector2DArray = List[Vector2D]
+Vector2DArrayArray = List[List[Vector2D]]
+
+IntVector2DDistance = Callable[[IntVector2D, IntVector2D], float]
 
 
 class Vec2dOperations:
@@ -73,42 +79,30 @@ class Vec2dOperations:
         """
         return np.sqrt(node[0] * node[0] + node[1] * node[1])
 
-
     @staticmethod
-    def get_manhattan_norm(node: Vector2D) -> float:
+    def get_euclidean_distance(node_a: Vector2D, node_b: Vector2D) -> float:
         """
         calculates the euclidean norm of the 2d vector
 
         :param node: tuple with coordinate (x,y) or 2d vector
         :return:
             -------
-        returns the manhatten norm
+        returns the euclidean distance
         """
-        return abs(node[0] * node[0]) + abs(node[1] * node[1])
-
-    @staticmethod
-    def get_euclidean_distance(node_a: Vector2D,node_b: Vector2D) -> float:
-        """
-        calculates the euclidean norm of the 2d vector
-
-        :param node: tuple with coordinate (x,y) or 2d vector
-        :return:
-            -------
-        returnss the manhatten distance
-        """
-        return Vec2dOperations.get_norm(Vec2dOperations.subtract(node_b,node_a))
+        return Vec2dOperations.get_norm(Vec2dOperations.subtract(node_b, node_a))
 
     @staticmethod
     def get_manhattan_distance(node_a: Vector2D, node_b: Vector2D) -> float:
         """
-        calculates the euclidean norm of the 2d vector
+        calculates the manhattan distance of the 2d vector
 
         :param node: tuple with coordinate (x,y) or 2d vector
         :return:
             -------
-        returnss the manhatten distance
+        returns the manhattan distance
         """
-        return Vec2dOperations.get_manhattan_norm(Vec2dOperations.subtract(node_b, node_a))
+        delta = (Vec2dOperations.subtract(node_b, node_a))
+        return np.abs(delta[0]) + np.abs(delta[1])
 
     @staticmethod
     def normalize(node: Vector2D) -> Tuple[float, float]:
diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py
index d8da9c561368922db0ebecf5098ef6ad04aff383..f860b77c6ecf5088ad8c47fd2b0665a88aeeef0c 100644
--- a/flatland/core/transition_map.py
+++ b/flatland/core/transition_map.py
@@ -8,7 +8,7 @@ from numpy import array
 
 from flatland.core.grid.grid4 import Grid4Transitions
 from flatland.core.grid.grid4_utils import get_new_position, get_direction
-from flatland.core.grid.grid_utils import IntVector2DArrayType
+from flatland.core.grid.grid_utils import IntVector2DArray, IntVector2D
 from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transitions import Transitions
@@ -302,7 +302,7 @@ class GridTransitionMap(TransitionMap):
         self.height = new_height
         self.grid = new_grid
 
-    def is_dead_end(self, rcPos: IntVector2DArrayType):
+    def is_dead_end(self, rcPos: IntVector2DArray):
         """
         Check if the cell is a dead-end.
 
@@ -322,7 +322,7 @@ class GridTransitionMap(TransitionMap):
             tmp = tmp >> 1
         return nbits == 1
 
-    def is_simple_turn(self, rcPos: IntVector2DArrayType):
+    def is_simple_turn(self, rcPos: IntVector2DArray):
         """
         Check if the cell is a left/right simple turn
 
@@ -349,7 +349,7 @@ class GridTransitionMap(TransitionMap):
 
         return is_simple_turn(tmp)
 
-    def check_path_exists(self, start: IntVector2DArrayType, direction: int, end: IntVector2DArrayType):
+    def check_path_exists(self, start: IntVector2DArray, direction: int, end: IntVector2DArray):
         # print("_path_exists({},{},{}".format(start, direction, end))
         # BFS - Check if a path exists between the 2 nodes
 
@@ -373,7 +373,7 @@ class GridTransitionMap(TransitionMap):
 
         return False
 
-    def cell_neighbours_valid(self, rcPos: IntVector2DArrayType, check_this_cell=False):
+    def cell_neighbours_valid(self, rcPos: IntVector2DArray, check_this_cell=False):
         """
         Check validity of cell at rcPos = tuple(row, column)
         Checks that:
@@ -425,7 +425,7 @@ class GridTransitionMap(TransitionMap):
 
         return True
 
-    def fix_neighbours(self, rcPos: IntVector2DArrayType, check_this_cell=False):
+    def fix_neighbours(self, rcPos: IntVector2DArray, check_this_cell=False):
         """
         Check validity of cell at rcPos = tuple(row, column)
         Checks that:
@@ -478,7 +478,7 @@ class GridTransitionMap(TransitionMap):
 
         return True
 
-    def fix_transitions(self, rcPos: IntVector2DArrayType):
+    def fix_transitions(self, rcPos: IntVector2DArray):
         """
         Fixes broken transitions
         """
@@ -543,8 +543,8 @@ class GridTransitionMap(TransitionMap):
             self.set_transitions((rcPos[0], rcPos[1]), transition)
         return True
 
-    def validate_new_transition(self, prev_pos: IntVector2DArrayType, current_pos: IntVector2DArrayType,
-                                new_pos: IntVector2DArrayType, end_pos: IntVector2DArrayType):
+    def validate_new_transition(self, prev_pos: IntVector2D, current_pos: IntVector2D,
+                                new_pos: IntVector2D, end_pos: IntVector2D):
 
         # start by getting direction used to get to current node
         # and direction from current node to possible child node
diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py
index 0ead334f4ac8a87f53c0f10bb74855050dc89e62..d6f47abfd85cfa1cc7e72e27aeb4f7ededa975dd 100644
--- a/flatland/envs/grid4_generators_utils.py
+++ b/flatland/envs/grid4_generators_utils.py
@@ -7,7 +7,8 @@ a GridTransitionMap object.
 
 from flatland.core.grid.grid4_astar import a_star
 from flatland.core.grid.grid4_utils import get_direction, mirror
-from flatland.core.grid.grid_utils import IntVector2D
+from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance
+from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
 from flatland.core.transition_map import GridTransitionMap, RailEnvTransitions
 
 
@@ -15,12 +16,13 @@ def connect_basic_operation(rail_trans: RailEnvTransitions, grid_map: GridTransi
                             start: IntVector2D,
                             end: IntVector2D,
                             flip_start_node_trans=False,
-                            flip_end_node_trans=False):
+                            flip_end_node_trans=False,
+                            a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance):
     """
     Creates a new path [start,end] in grid_map, based on rail_trans.
     """
     # in the worst case we will need to do a A* search, so we might as well set that up
-    path = a_star(rail_trans, grid_map, start, end)
+    path = a_star(grid_map, start, end, a_star_distance_function)
     if len(path) < 2:
         return []
     current_dir = get_direction(path[0], path[1])
@@ -67,18 +69,25 @@ def connect_basic_operation(rail_trans: RailEnvTransitions, grid_map: GridTransi
     return path
 
 
-def connect_rail(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D):
-    return connect_basic_operation(rail_trans, grid_map, start, end, True, True)
+def connect_rail(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
+                 start: IntVector2D, end: IntVector2D,
+                 a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance):
+    return connect_basic_operation(rail_trans, grid_map, start, end, True, True, a_star_distance_function)
 
 
-def connect_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D):
-    return connect_basic_operation(rail_trans, grid_map, start, end, False, False)
+def connect_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
+                  start: IntVector2D, end: IntVector2D,
+                  a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance):
+    return connect_basic_operation(rail_trans, grid_map, start, end, False, False, a_star_distance_function)
 
 
-def connect_from_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D,
-                       end: IntVector2D):
-    return connect_basic_operation(rail_trans, grid_map, start, end, False, True)
+def connect_from_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
+                       start: IntVector2D, end: IntVector2D,
+                       a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance):
+    return connect_basic_operation(rail_trans, grid_map, start, end, False, True, a_star_distance_function)
 
 
-def connect_to_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D):
-    return connect_basic_operation(rail_trans, grid_map, start, end, True, False)
+def connect_to_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
+                     start: IntVector2D, end: IntVector2D,
+                     a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance):
+    return connect_basic_operation(rail_trans, grid_map, start, end, True, False, a_star_distance_function)
diff --git a/flatland/envs/rail_generators_city_generator.py b/flatland/envs/rail_generators_city_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391