diff --git a/utils/shortest_path_walker_heuristic_agent.py b/utils/shortest_path_walker_heuristic_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..eaa71e91a416b0c899519e690c4e29ad8147a48d
--- /dev/null
+++ b/utils/shortest_path_walker_heuristic_agent.py
@@ -0,0 +1,57 @@
+import numpy as np
+from flatland.envs.rail_env import RailEnvActions
+
+from reinforcement_learning.policy import Policy
+
+
+class ShortestPathWalkerHeuristicPolicy(Policy):
+    def step(self, state, action, reward, next_state, done):
+        pass
+
+    def act(self, node, eps=0.):
+
+        left_node = node.childs.get('L')
+        forward_node = node.childs.get('F')
+        right_node = node.childs.get('R')
+
+        dist_map = np.zeros(5)
+        dist_map[RailEnvActions.DO_NOTHING] = np.inf
+        dist_map[RailEnvActions.STOP_MOVING] = 100000
+        # left
+        if left_node == -np.inf:
+            dist_map[RailEnvActions.MOVE_LEFT] = np.inf
+        else:
+            if left_node.num_agents_opposite_direction == 0:
+                dist_map[RailEnvActions.MOVE_LEFT] = left_node.dist_min_to_target
+            else:
+                dist_map[RailEnvActions.MOVE_LEFT] = np.inf
+        # forward
+        if forward_node == -np.inf:
+            dist_map[RailEnvActions.MOVE_FORWARD] = np.inf
+        else:
+            if forward_node.num_agents_opposite_direction == 0:
+                dist_map[RailEnvActions.MOVE_FORWARD] = forward_node.dist_min_to_target
+            else:
+                dist_map[RailEnvActions.MOVE_FORWARD] = np.inf
+        # right
+        if right_node == -np.inf:
+            dist_map[RailEnvActions.MOVE_RIGHT] = np.inf
+        else:
+            if right_node.num_agents_opposite_direction == 0:
+                dist_map[RailEnvActions.MOVE_RIGHT] = right_node.dist_min_to_target
+            else:
+                dist_map[RailEnvActions.MOVE_RIGHT] = np.inf
+        return np.argmin(dist_map)
+
+    def save(self, filename):
+        pass
+
+    def load(self, filename):
+        pass
+
+
+policy = ShortestPathWalkerHeuristicPolicy()
+
+
+def normalize_observation(observation, tree_depth: int, observation_radius=0):
+    return observation