Skip to content
Snippets Groups Projects
Commit 0304a317 authored by u214892's avatar u214892
Browse files

use set() instead of list to speed up is_node_in_list (O(n) -> O(1))

parent 504c44d4
No related branches found
No related tags found
No related merge requests found
""" """
Definition of the RailEnv environment and related level-generation functions. Definition of the RailEnv environment and related level-generation functions.
Generator functions are functions that take width, height and num_resets as arguments and return Generator functions are functions that take width, height and num_resets as arguments and return
a GridTransitionMap object. a GridTransitionMap object.
""" """
import numpy as np import numpy as np
# from flatland.core.env import Environment # from flatland.core.env import Environment
# from flatland.envs.observations import TreeObsForRailEnv # from flatland.envs.observations import TreeObsForRailEnv
...@@ -90,6 +91,9 @@ class AStarNode(): ...@@ -90,6 +91,9 @@ class AStarNode():
def __eq__(self, other): def __eq__(self, other):
return self.pos == other.pos return self.pos == other.pos
def __hash__(self):
return hash(self.pos)
def update_if_better(self, other): def update_if_better(self, other):
if other.g < self.g: if other.g < self.g:
self.parent = other.parent self.parent = other.parent
...@@ -106,30 +110,29 @@ def a_star(rail_trans, rail_array, start, end): ...@@ -106,30 +110,29 @@ def a_star(rail_trans, rail_array, start, end):
rail_shape = rail_array.shape rail_shape = rail_array.shape
start_node = AStarNode(None, start) start_node = AStarNode(None, start)
end_node = AStarNode(None, end) end_node = AStarNode(None, end)
open_list = [] open_list = set()
closed_list = [] closed_list = set()
open_list.add(start_node)
open_list.append(start_node)
# this could be optimized # this could be optimized
def is_node_in_list(node, the_list): def is_node_in_list(node, the_list):
for o_node in the_list: if node in the_list:
if node == o_node: return node
return o_node
return None return None
while len(open_list) > 0: while len(open_list) > 0:
# get node with current shortest est. path (lowest f) # get node with current shortest est. path (lowest f)
current_node = open_list[0] current_node = None
current_index = 0
for index, item in enumerate(open_list): for index, item in enumerate(open_list):
if current_node is None:
current_node = item
continue
if item.f < current_node.f: if item.f < current_node.f:
current_node = item current_node = item
current_index = index
# pop current off open list, add to closed list # pop current off open list, add to closed list
open_list.pop(current_index) open_list.remove(current_node)
closed_list.append(current_node) closed_list.add(current_node)
# found the goal # found the goal
if current_node == end_node: if current_node == end_node:
...@@ -150,9 +153,9 @@ def a_star(rail_trans, rail_array, start, end): ...@@ -150,9 +153,9 @@ def a_star(rail_trans, rail_array, start, end):
for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]: for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]:
node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1]) node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1])
if node_pos[0] >= rail_shape[0] or \ if node_pos[0] >= rail_shape[0] or \
node_pos[0] < 0 or \ node_pos[0] < 0 or \
node_pos[1] >= rail_shape[1] or \ node_pos[1] >= rail_shape[1] or \
node_pos[1] < 0: node_pos[1] < 0:
continue continue
# validate positions # validate positions
...@@ -186,7 +189,7 @@ def a_star(rail_trans, rail_array, start, end): ...@@ -186,7 +189,7 @@ def a_star(rail_trans, rail_array, start, end):
continue continue
# add the child to the open list # add the child to the open list
open_list.append(child) open_list.add(child)
# no full path found # no full path found
if len(open_list) == 0: if len(open_list) == 0:
...@@ -324,7 +327,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents): ...@@ -324,7 +327,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
for m in valid_movements: for m in valid_movements:
new_position = get_new_position(agents_position[i], m[1]) new_position = get_new_position(agents_position[i], m[1])
if m[0] not in valid_starting_directions and \ if m[0] not in valid_starting_directions and \
_path_exists(rail, new_position, m[0], agents_target[i]): _path_exists(rail, new_position, m[0], agents_target[i]):
valid_starting_directions.append(m[0]) valid_starting_directions.append(m[0])
if len(valid_starting_directions) == 0: if len(valid_starting_directions) == 0:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment