From 0304a317155d9ca2ffbeb45c331cbc5f2b3361a9 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Mon, 20 May 2019 14:40:41 +0200 Subject: [PATCH] use set() instead of list to speed up is_node_in_list (O(n) -> O(1)) --- flatland/envs/env_utils.py | 39 ++++++++++++++++++++------------------ 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py index b58604c6..3bbf4e1a 100644 --- a/flatland/envs/env_utils.py +++ b/flatland/envs/env_utils.py @@ -1,12 +1,13 @@ - """ Definition of the RailEnv environment and related level-generation functions. Generator functions are functions that take width, height and num_resets as arguments and return a GridTransitionMap object. """ + import numpy as np + # from flatland.core.env import Environment # from flatland.envs.observations import TreeObsForRailEnv @@ -90,6 +91,9 @@ class AStarNode(): def __eq__(self, other): return self.pos == other.pos + def __hash__(self): + return hash(self.pos) + def update_if_better(self, other): if other.g < self.g: self.parent = other.parent @@ -106,30 +110,29 @@ def a_star(rail_trans, rail_array, start, end): rail_shape = rail_array.shape start_node = AStarNode(None, start) end_node = AStarNode(None, end) - open_list = [] - closed_list = [] - - open_list.append(start_node) + open_list = set() + closed_list = set() + open_list.add(start_node) # this could be optimized def is_node_in_list(node, the_list): - for o_node in the_list: - if node == o_node: - return o_node + if node in the_list: + return node return None while len(open_list) > 0: # get node with current shortest est. path (lowest f) - current_node = open_list[0] - current_index = 0 + current_node = None for index, item in enumerate(open_list): + if current_node is None: + current_node = item + continue if item.f < current_node.f: current_node = item - current_index = index # pop current off open list, add to closed list - open_list.pop(current_index) - closed_list.append(current_node) + open_list.remove(current_node) + closed_list.add(current_node) # found the goal if current_node == end_node: @@ -150,9 +153,9 @@ def a_star(rail_trans, rail_array, start, end): for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]: node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1]) if node_pos[0] >= rail_shape[0] or \ - node_pos[0] < 0 or \ - node_pos[1] >= rail_shape[1] or \ - node_pos[1] < 0: + node_pos[0] < 0 or \ + node_pos[1] >= rail_shape[1] or \ + node_pos[1] < 0: continue # validate positions @@ -186,7 +189,7 @@ def a_star(rail_trans, rail_array, start, end): continue # add the child to the open list - open_list.append(child) + open_list.add(child) # no full path found if len(open_list) == 0: @@ -324,7 +327,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents): for m in valid_movements: new_position = get_new_position(agents_position[i], m[1]) if m[0] not in valid_starting_directions and \ - _path_exists(rail, new_position, m[0], agents_target[i]): + _path_exists(rail, new_position, m[0], agents_target[i]): valid_starting_directions.append(m[0]) if len(valid_starting_directions) == 0: -- GitLab