From 2556ec1575caf2b4a51629a19602317104f118d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mattias=20Ljungstr=C3=B6m?= <ml@mljx.io> Date: Wed, 8 May 2019 17:18:35 +0200 Subject: [PATCH] bug fixes to complex rail gen --- examples/play_model.py | 2 +- flatland/envs/env_utils.py | 53 ++++++++++++++++--------------------- flatland/envs/generators.py | 9 ++++++- 3 files changed, 32 insertions(+), 32 deletions(-) diff --git a/examples/play_model.py b/examples/play_model.py index cba087bc..e69b312b 100644 --- a/examples/play_model.py +++ b/examples/play_model.py @@ -97,7 +97,7 @@ def main(render=True, delay=0.0): # Example generate a random rail env = RailEnv(width=15, height=15, - rail_generator=complex_rail_generator(nr_start_goal=15, min_dist=5), + rail_generator=complex_rail_generator(nr_start_goal=20, min_dist=12), number_of_agents=5) if render: diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py index 1ad6c6de..a1e46db6 100644 --- a/flatland/envs/env_utils.py +++ b/flatland/envs/env_utils.py @@ -14,27 +14,6 @@ import numpy as np # from flatland.core.transition_map import GridTransitionMap -class AStarNode(): - """A node class for A* Pathfinding""" - - def __init__(self, parent=None, pos=None): - self.parent = parent - self.pos = pos - self.g = 0 - self.h = 0 - self.f = 0 - - def __eq__(self, other): - return self.pos == other.pos - - def update_if_better(self, other): - if other.g < self.g: - self.parent = other.parent - self.g = other.g - self.h = other.h - self.f = other.f - - def get_direction(pos1, pos2): """ Assumes pos1 and pos2 are adjacent location on grid. @@ -98,6 +77,27 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p return rail_trans.is_valid(new_trans) +class AStarNode(): + """A node class for A* Pathfinding""" + + def __init__(self, parent=None, pos=None): + self.parent = parent + self.pos = pos + self.g = 0 + self.h = 0 + self.f = 0 + + def __eq__(self, other): + return self.pos == other.pos + + def update_if_better(self, other): + if other.g < self.g: + self.parent = other.parent + self.g = other.g + self.h = other.h + self.f = other.f + + def a_star(rail_trans, rail_array, start, end): """ Returns a list of tuples as a path from the given start to end. @@ -188,16 +188,9 @@ def a_star(rail_trans, rail_array, start, end): # add the child to the open list open_list.append(child) - # no full path found, return partial path + # no full path found if len(open_list) == 0: - path = [] - current = current_node - while current is not None: - path.append(current.pos) - current = current.parent - # return reversed path - # print("partial:", start, end, path[::-1]) - return path[::-1] + return [] def connect_rail(rail_trans, rail_array, start, end): diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 29f8c6f4..4f356e1c 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -77,6 +77,7 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0): created_sanity = 0 sanity_max = 9000 while nr_created < nr_start_goal and created_sanity < sanity_max: + all_ok = False for _ in range(sanity_max): start = (np.random.randint(0, width), np.random.randint(0, height)) goal = (np.random.randint(0, height), np.random.randint(0, height)) @@ -103,8 +104,14 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0): return True if check_all_dist(sg_new): + all_ok = True break + if not all_ok: + # we can might as well give up at this point + # print("\n> Complex Rail Gen: Sanity counter reached, giving up!") + break + new_path = connect_rail(rail_trans, rail_array, start, goal) if len(new_path) >= 2: nr_created += 1 @@ -116,7 +123,7 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0): # print("failed...") created_sanity += 1 - # print("Created #", len(start_goal), "pairs") + print("\n> Complex Rail Gen: Created #", len(start_goal), "pairs") # print(start_goal) agents_position = [sg[0] for sg in start_goal] -- GitLab