From 1664efd1a4d9d60341358741db949257d5c4cadd Mon Sep 17 00:00:00 2001 From: Mattias Ljungstrom <mattias.ljungstrom@gmail.com> Date: Sun, 28 Apr 2019 20:19:54 +0200 Subject: [PATCH] level gen: bug fixes --- examples/play_model.py | 2 +- flatland/core/transitions.py | 3 +- flatland/envs/rail_env.py | 59 ++++++++++++++++++++++++++++-------- 3 files changed, 49 insertions(+), 15 deletions(-) diff --git a/examples/play_model.py b/examples/play_model.py index 82458e3..1f654c1 100644 --- a/examples/play_model.py +++ b/examples/play_model.py @@ -103,7 +103,7 @@ def main(render=True, delay=0.0): # Example generate a random rail env = RailEnv(width=15, height=15, - rail_generator=complex_rail_generator(nr_start_goal=20, min_dist=5), + rail_generator=complex_rail_generator(nr_start_goal=15, min_dist=5), number_of_agents=1) if render: diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py index 9262089..0afec98 100644 --- a/flatland/core/transitions.py +++ b/flatland/core/transitions.py @@ -533,7 +533,8 @@ class RailEnvTransitions(Grid4Transitions): int('0101001000000010', 2), # Case 6 - symmetrical int('0010000000000000', 2), # Case 7 - dead end int('0100000000000010', 2), # Case 1b - simple turn right - int('0001001000000000', 2)] # Case 1c - simple turn left + int('0001001000000000', 2), # Case 1c - simple turn left + int('1100000000100010', 2)] # Case 2b - simple switch mirrored def __init__(self): super(RailEnvTransitions, self).__init__( diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index a7af40e..edcd069 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -67,8 +67,14 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p # create new transition that would go to child new_trans = rail_array[current_pos] if prev_pos is None: - # need to flip direction because of how end points are defined - new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) + if new_trans == 0: + # need to flip direction because of how end points are defined + new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) + else: + # check if matches existing layout + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) + # rail_trans.print(new_trans) else: # set the forward path new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) @@ -77,11 +83,24 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p if new_pos == end_pos: # need to validate end pos setup as well new_trans_e = rail_array[end_pos] - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) + if new_trans_e == 0: + # need to flip direction because of how end points are defined + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) + else: + # check if matches existing layout + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) + new_trans_e = rail_trans.set_transition(new_trans_e, mirror(new_dir), mirror(new_dir), 1) + # print("end:", end_pos, current_pos) + # rail_trans.print(new_trans_e) + # print("========> end trans") # rail_trans.print(new_trans_e) if not rail_trans.is_valid(new_trans_e): + # print("end failed", end_pos, current_pos) return False + # else: + # print("end ok!", end_pos, current_pos) + # is transition is valid? # print("=======> trans") # rail_trans.print(new_trans) @@ -196,6 +215,7 @@ def a_star(rail_trans, rail_array, start, end): path.append(current.pos) current = current.parent # return reversed path + print("partial:", start, end, path[::-1]) return path[::-1] @@ -217,8 +237,14 @@ def connect_rail(rail_trans, rail_array, start, end): new_trans = rail_array[current_pos] if index == 0: - # need to flip direction because of how end points are defined - new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) + if new_trans == 0: + # end-point + # need to flip direction because of how end points are defined + new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) + else: + # into existing rail + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) else: # set the forward path new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) @@ -227,9 +253,15 @@ def connect_rail(rail_trans, rail_array, start, end): rail_array[current_pos] = new_trans if new_pos == end_pos: - # need to validate end pos setup as well + # setup end pos setup new_trans_e = rail_array[end_pos] - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) + if new_trans_e == 0: + # end-point + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) + else: + # into existing rail + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) + new_trans_e = rail_trans.set_transition(new_trans_e, mirror(new_dir), mirror(new_dir), 1) rail_array[end_pos] = new_trans_e current_dir = new_dir @@ -299,13 +331,15 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0): # - return transition map + list of [start, goal] points # - # step 1: start_goal = [] for _ in range(nr_start_goal): sanity_max = 9000 for _ in range(sanity_max): start = (np.random.randint(0, width), np.random.randint(0, height)) goal = (np.random.randint(0, height), np.random.randint(0, height)) + # check to make sure start,goal pos is empty? + # if rail_array[goal] != 0: # or rail_array[start] != 0: + # continue # check min/max distance dist_sg = distance_on_rail(start, goal) if dist_sg < min_dist: @@ -327,11 +361,10 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0): if check_all_dist(sg_new): break start_goal.append([start, goal]) - print("Created #", len(start_goal), "pairs") + connect_rail(rail_trans, rail_array, start, goal) - # step 3: - for sg in start_goal: - connect_rail(rail_trans, rail_array, sg[0], sg[1]) + print("Created #", len(start_goal), "pairs") + # print(start_goal) return_rail = GridTransitionMap(width=width, height=height, transitions=rail_trans) return_rail.grid = rail_array @@ -476,7 +509,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): transitions_templates_ = [] transition_probabilities = [] - for i in range(len(t_utils.transitions) - 3): # don't include dead-ends + for i in range(len(t_utils.transitions) - 4): # don't include dead-ends all_transitions = 0 for dir_ in range(4): trans = t_utils.get_transitions(t_utils.transitions[i], dir_) -- GitLab