diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py
index 11abf7e2da57db7b87dd20ddfa7f8d32d0806c3c..7f155cf2c72f14231993c644df29ae215bd29903 100644
--- a/examples/Simple_Realistic_Railway_Generator.py
+++ b/examples/Simple_Realistic_Railway_Generator.py
@@ -5,7 +5,7 @@ import warnings
 
 import numpy as np
 
-from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
+from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d, IntVector2DArrayType
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.grid4_generators_utils import connect_from_nodes, connect_nodes, connect_rail
@@ -15,7 +15,6 @@ from flatland.envs.rail_generators import RailGenerator, RailGeneratorProduct
 from flatland.envs.schedule_generators import sparse_schedule_generator
 from flatland.utils.rendertools import AgentRenderVariant, RenderTool
 
-IntVector2DArrayType = []
 FloatArrayType = []
 
 
diff --git a/flatland/core/grid/grid4_astar.py b/flatland/core/grid/grid4_astar.py
index feb72313f21b9ecc989688d63ba02ccf3a458107..3b4e69380a0cb3f3e69d2b6a08ab145f7a20ce04 100644
--- a/flatland/core/grid/grid4_astar.py
+++ b/flatland/core/grid/grid4_astar.py
@@ -1,7 +1,11 @@
 from flatland.core.grid.grid4_utils import validate_new_transition
+from flatland.core.grid.grid_utils import IntVector2D
+from flatland.core.grid.grid_utils import IntVector2DArrayType
+from flatland.core.grid.rail_env_grid import RailEnvTransitions
+from flatland.core.transition_map import GridTransitionMap
 
 
-class AStarNode():
+class AStarNode:
     """A node class for A* Pathfinding"""
 
     def __init__(self, parent=None, pos=None):
@@ -25,12 +29,13 @@ class AStarNode():
             self.f = other.f
 
 
-def a_star(rail_trans, rail_array, start, end):
+def a_star(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D) -> \
+    (IntVector2DArrayType):
     """
     Returns a list of tuples as a path from the given start to end.
     If no path is found, returns path to closest point to end.
     """
-    rail_shape = rail_array.shape
+    rail_shape = grid_map.grid.shape
     start_node = AStarNode(None, start)
     end_node = AStarNode(None, end)
     open_nodes = set()
@@ -73,7 +78,8 @@ def a_star(rail_trans, rail_array, start, end):
                 continue
 
             # validate positions
-            if not validate_new_transition(rail_trans, rail_array, prev_pos, current_node.pos, node_pos, end_node.pos):
+            if not validate_new_transition(rail_trans, grid_map.grid, prev_pos, current_node.pos, node_pos,
+                                           end_node.pos):
                 continue
 
             # create new node
diff --git a/flatland/core/grid/grid_utils.py b/flatland/core/grid/grid_utils.py
index b7bf988fb754b30d4784e0feef36176fa5aa28df..ffbf79a7b99e1628d52814638c753811ad9f7282 100644
--- a/flatland/core/grid/grid_utils.py
+++ b/flatland/core/grid/grid_utils.py
@@ -5,6 +5,8 @@ import numpy as np
 Vector2D = Tuple[float, float]
 IntVector2D = Tuple[int, int]
 
+IntVector2DArrayType = []
+
 
 class Vec2dOperations:
 
diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py
index c1bdedebd3edef2c22ef971a538fa55c5c1482a6..0ead334f4ac8a87f53c0f10bb74855050dc89e62 100644
--- a/flatland/envs/grid4_generators_utils.py
+++ b/flatland/envs/grid4_generators_utils.py
@@ -20,7 +20,7 @@ def connect_basic_operation(rail_trans: RailEnvTransitions, grid_map: GridTransi
     Creates a new path [start,end] in grid_map, based on rail_trans.
     """
     # in the worst case we will need to do a A* search, so we might as well set that up
-    path = a_star(rail_trans, grid_map.grid, start, end)
+    path = a_star(rail_trans, grid_map, start, end)
     if len(path) < 2:
         return []
     current_dir = get_direction(path[0], path[1])