diff --git a/flatland/core/grid/grid4_utils.py b/flatland/core/grid/grid4_utils.py
index 98652459d7a7ac7f1694ac53fe1d0a12880ab8b2..75cef7b4d3aea783140a5c08c3498a0bc321fb62 100644
--- a/flatland/core/grid/grid4_utils.py
+++ b/flatland/core/grid/grid4_utils.py
@@ -1,8 +1,8 @@
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.core.grid.grid_utils import IntVector2DArray
+from flatland.core.grid.grid_utils import IntVector2D
 
 
-def get_direction(pos1: IntVector2DArray, pos2: IntVector2DArray) -> Grid4TransitionsEnum:
+def get_direction(pos1: IntVector2D, pos2: IntVector2D) -> Grid4TransitionsEnum:
     """
     Assumes pos1 and pos2 are adjacent location on grid.
     Returns direction (int) that can be used with transitions.
@@ -10,13 +10,13 @@ def get_direction(pos1: IntVector2DArray, pos2: IntVector2DArray) -> Grid4Transi
     diff_0 = pos2[0] - pos1[0]
     diff_1 = pos2[1] - pos1[1]
     if diff_0 < 0:
-        return 0
+        return Grid4TransitionsEnum.NORTH
     if diff_0 > 0:
-        return 2
+        return Grid4TransitionsEnum.SOUTH
     if diff_1 > 0:
-        return 1
+        return Grid4TransitionsEnum.EAST
     if diff_1 < 0:
-        return 3
+        return Grid4TransitionsEnum.WEST
     raise Exception("Could not determine direction {}->{}".format(pos1, pos2))
 
 
diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py
index d6f47abfd85cfa1cc7e72e27aeb4f7ededa975dd..1dcd2a31f0f24686621c9d114cc32547326e72df 100644
--- a/flatland/envs/grid4_generators_utils.py
+++ b/flatland/envs/grid4_generators_utils.py
@@ -7,22 +7,24 @@ a GridTransitionMap object.
 
 from flatland.core.grid.grid4_astar import a_star
 from flatland.core.grid.grid4_utils import get_direction, mirror
-from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance
+from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance, IntVector2DArray
 from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
 from flatland.core.transition_map import GridTransitionMap, RailEnvTransitions
 
 
-def connect_basic_operation(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
-                            start: IntVector2D,
-                            end: IntVector2D,
-                            flip_start_node_trans=False,
-                            flip_end_node_trans=False,
-                            a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance):
+def connect_basic_operation(
+    rail_trans: RailEnvTransitions,
+    grid_map: GridTransitionMap,
+    start: IntVector2D,
+    end: IntVector2D,
+    flip_start_node_trans=False,
+    flip_end_node_trans=False,
+    a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
     """
-    Creates a new path [start,end] in grid_map, based on rail_trans.
+    Creates a new path [start,end] in `grid_map.grid`, based on rail_trans.
     """
     # in the worst case we will need to do a A* search, so we might as well set that up
-    path = a_star(grid_map, start, end, a_star_distance_function)
+    path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function)
     if len(path) < 2:
         return []
     current_dir = get_direction(path[0], path[1])
@@ -71,23 +73,24 @@ def connect_basic_operation(rail_trans: RailEnvTransitions, grid_map: GridTransi
 
 def connect_rail(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
                  start: IntVector2D, end: IntVector2D,
-                 a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance):
+                 a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
     return connect_basic_operation(rail_trans, grid_map, start, end, True, True, a_star_distance_function)
 
 
 def connect_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
                   start: IntVector2D, end: IntVector2D,
-                  a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance):
+                  a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
     return connect_basic_operation(rail_trans, grid_map, start, end, False, False, a_star_distance_function)
 
 
 def connect_from_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
                        start: IntVector2D, end: IntVector2D,
-                       a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance):
+                       a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance
+                       ) -> IntVector2DArray:
     return connect_basic_operation(rail_trans, grid_map, start, end, False, True, a_star_distance_function)
 
 
 def connect_to_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
                      start: IntVector2D, end: IntVector2D,
-                     a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance):
+                     a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
     return connect_basic_operation(rail_trans, grid_map, start, end, True, False, a_star_distance_function)
diff --git a/tests/test_flatland_envs_env_utils.py b/tests/test_flatland_envs_env_utils.py
index b95922cf67febdaa0aad396459bc446bc31adfea..cf5c8592708eef237bcf29308032df49753860bd 100644
--- a/tests/test_flatland_envs_env_utils.py
+++ b/tests/test_flatland_envs_env_utils.py
@@ -2,8 +2,8 @@ import numpy as np
 import pytest
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.core.grid.grid_utils import position_to_coordinate, coordinate_to_position
 from flatland.core.grid.grid4_utils import get_direction
+from flatland.core.grid.grid_utils import position_to_coordinate, coordinate_to_position
 
 depth_to_test = 5
 positions_to_test = [0, 5, 1, 6, 20, 30]
@@ -31,4 +31,4 @@ def test_get_direction():
     assert get_direction((1, 0), (0, 0)) == Grid4TransitionsEnum.NORTH
     assert get_direction((1, 0), (0, 0)) == Grid4TransitionsEnum.NORTH
     with pytest.raises(Exception, match="Could not determine direction"):
-        get_direction((0, 0), (0, 0)) == Grid4TransitionsEnum.NORTH
+        get_direction((0, 0), (0, 0))