Skip to content
Snippets Groups Projects
Commit 0304a317 authored by u214892's avatar u214892
Browse files

use set() instead of list to speed up is_node_in_list (O(n) -> O(1))

parent 504c44d4
No related branches found
No related tags found
No related merge requests found
"""
Definition of the RailEnv environment and related level-generation functions.
Generator functions are functions that take width, height and num_resets as arguments and return
a GridTransitionMap object.
"""
import numpy as np
# from flatland.core.env import Environment
# from flatland.envs.observations import TreeObsForRailEnv
......@@ -90,6 +91,9 @@ class AStarNode():
def __eq__(self, other):
return self.pos == other.pos
def __hash__(self):
return hash(self.pos)
def update_if_better(self, other):
if other.g < self.g:
self.parent = other.parent
......@@ -106,30 +110,29 @@ def a_star(rail_trans, rail_array, start, end):
rail_shape = rail_array.shape
start_node = AStarNode(None, start)
end_node = AStarNode(None, end)
open_list = []
closed_list = []
open_list.append(start_node)
open_list = set()
closed_list = set()
open_list.add(start_node)
# this could be optimized
def is_node_in_list(node, the_list):
for o_node in the_list:
if node == o_node:
return o_node
if node in the_list:
return node
return None
while len(open_list) > 0:
# get node with current shortest est. path (lowest f)
current_node = open_list[0]
current_index = 0
current_node = None
for index, item in enumerate(open_list):
if current_node is None:
current_node = item
continue
if item.f < current_node.f:
current_node = item
current_index = index
# pop current off open list, add to closed list
open_list.pop(current_index)
closed_list.append(current_node)
open_list.remove(current_node)
closed_list.add(current_node)
# found the goal
if current_node == end_node:
......@@ -150,9 +153,9 @@ def a_star(rail_trans, rail_array, start, end):
for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]:
node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1])
if node_pos[0] >= rail_shape[0] or \
node_pos[0] < 0 or \
node_pos[1] >= rail_shape[1] or \
node_pos[1] < 0:
node_pos[0] < 0 or \
node_pos[1] >= rail_shape[1] or \
node_pos[1] < 0:
continue
# validate positions
......@@ -186,7 +189,7 @@ def a_star(rail_trans, rail_array, start, end):
continue
# add the child to the open list
open_list.append(child)
open_list.add(child)
# no full path found
if len(open_list) == 0:
......@@ -324,7 +327,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
for m in valid_movements:
new_position = get_new_position(agents_position[i], m[1])
if m[0] not in valid_starting_directions and \
_path_exists(rail, new_position, m[0], agents_target[i]):
_path_exists(rail, new_position, m[0], agents_target[i]):
valid_starting_directions.append(m[0])
if len(valid_starting_directions) == 0:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment