From 0304a317155d9ca2ffbeb45c331cbc5f2b3361a9 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Mon, 20 May 2019 14:40:41 +0200
Subject: [PATCH] use set() instead of list to speed up is_node_in_list (O(n)
 -> O(1))

---
 flatland/envs/env_utils.py | 39 ++++++++++++++++++++------------------
 1 file changed, 21 insertions(+), 18 deletions(-)

diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py
index b58604c6..3bbf4e1a 100644
--- a/flatland/envs/env_utils.py
+++ b/flatland/envs/env_utils.py
@@ -1,12 +1,13 @@
-
 """
 Definition of the RailEnv environment and related level-generation functions.
 
 Generator functions are functions that take width, height and num_resets as arguments and return
 a GridTransitionMap object.
 """
+
 import numpy as np
 
+
 # from flatland.core.env import Environment
 # from flatland.envs.observations import TreeObsForRailEnv
 
@@ -90,6 +91,9 @@ class AStarNode():
     def __eq__(self, other):
         return self.pos == other.pos
 
+    def __hash__(self):
+        return hash(self.pos)
+
     def update_if_better(self, other):
         if other.g < self.g:
             self.parent = other.parent
@@ -106,30 +110,29 @@ def a_star(rail_trans, rail_array, start, end):
     rail_shape = rail_array.shape
     start_node = AStarNode(None, start)
     end_node = AStarNode(None, end)
-    open_list = []
-    closed_list = []
-
-    open_list.append(start_node)
+    open_list = set()
+    closed_list = set()
+    open_list.add(start_node)
 
     # this could be optimized
     def is_node_in_list(node, the_list):
-        for o_node in the_list:
-            if node == o_node:
-                return o_node
+        if node in the_list:
+            return node
         return None
 
     while len(open_list) > 0:
         # get node with current shortest est. path (lowest f)
-        current_node = open_list[0]
-        current_index = 0
+        current_node = None
         for index, item in enumerate(open_list):
+            if current_node is None:
+                current_node = item
+                continue
             if item.f < current_node.f:
                 current_node = item
-                current_index = index
 
         # pop current off open list, add to closed list
-        open_list.pop(current_index)
-        closed_list.append(current_node)
+        open_list.remove(current_node)
+        closed_list.add(current_node)
 
         # found the goal
         if current_node == end_node:
@@ -150,9 +153,9 @@ def a_star(rail_trans, rail_array, start, end):
         for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]:
             node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1])
             if node_pos[0] >= rail_shape[0] or \
-                    node_pos[0] < 0 or \
-                    node_pos[1] >= rail_shape[1] or \
-                    node_pos[1] < 0:
+                node_pos[0] < 0 or \
+                node_pos[1] >= rail_shape[1] or \
+                node_pos[1] < 0:
                 continue
 
             # validate positions
@@ -186,7 +189,7 @@ def a_star(rail_trans, rail_array, start, end):
                 continue
 
             # add the child to the open list
-            open_list.append(child)
+            open_list.add(child)
 
         # no full path found
         if len(open_list) == 0:
@@ -324,7 +327,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
             for m in valid_movements:
                 new_position = get_new_position(agents_position[i], m[1])
                 if m[0] not in valid_starting_directions and \
-                   _path_exists(rail, new_position, m[0], agents_target[i]):
+                    _path_exists(rail, new_position, m[0], agents_target[i]):
                     valid_starting_directions.append(m[0])
 
             if len(valid_starting_directions) == 0:
-- 
GitLab