From 9635e71cce55085d823cfd2a5809b9dbddb94f7c Mon Sep 17 00:00:00 2001 From: Mattias Ljungstrom <mattias.ljungstrom@gmail.com> Date: Sun, 28 Apr 2019 17:57:59 +0200 Subject: [PATCH] level gen: enforce min/max distances --- flatland/envs/rail_env.py | 36 +++++++++++++++++++++++++++++++----- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 86582f5..8132d93 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -235,7 +235,11 @@ def connect_rail(rail_trans, rail_array, start, end): current_dir = new_dir -def complex_rail_generator(nr_start_goal=1, min_dist=0, max_dist=99999, seed=0): +def distance_on_rail(pos1, pos2): + return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1]) + + +def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0): """ Parameters ------- @@ -295,19 +299,41 @@ def complex_rail_generator(nr_start_goal=1, min_dist=0, max_dist=99999, seed=0): # - return transition map + list of [start, goal] points # + # step 1: start_goal = [] for _ in range(nr_start_goal): - start = (np.random.randint(0, width), np.random.randint(0, height)) - goal = (np.random.randint(0, height), np.random.randint(0, height)) - # TODO: validate closeness with existing points - # TODO: make sure min/max distance condition is met + sanity_max = 9000 + for _ in range(sanity_max): + start = (np.random.randint(0, width), np.random.randint(0, height)) + goal = (np.random.randint(0, height), np.random.randint(0, height)) + # check min/max distance + dist_sg = distance_on_rail(start, goal) + if dist_sg < min_dist: + continue + if dist_sg > max_dist: + continue + # check distance to existing points + sg_new = [start, goal] + def check_all_dist(sg_new): + for sg in start_goal: + for i in range(2): + for j in range(2): + dist = distance_on_rail(sg_new[i], sg[j]) + if dist < 2: + # print("too close:", dist, sg_new[i], sg[j]) + return False + return True + if check_all_dist(sg_new): + break start_goal.append([start, goal]) + # step 3: for sg in start_goal: connect_rail(rail_trans, rail_array, sg[0], sg[1]) return_rail = GridTransitionMap(width=width, height=height, transitions=rail_trans) return_rail.grid = rail_array + # TODO: return start_goal return return_rail return generator -- GitLab