diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 355f5502992a34f8d58d4dbd80028eb4dd71cc48..2de26ddc4513db7ce30dbdd31c8f748d8d54a105 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -538,3 +538,126 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11): return return_rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) return generator + + +def realistic_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=99999, seed=0): + """ + Parameters + ------- + width : int + The width (number of cells) of the grid to generate. + height : int + The height (number of cells) of the grid to generate. + + Returns + ------- + numpy.ndarray of type numpy.uint16 + The matrix with the correct 16-bit bitmaps for each cell. + """ + + def generator(width, height, num_agents, num_resets=0): + + if num_agents > nr_start_goal: + num_agents = nr_start_goal + print("complex_rail_generator: num_agents > nr_start_goal, changing num_agents") + rail_trans = RailEnvTransitions() + grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) + rail_array = grid_map.grid + rail_array.fill(0) + + np.random.seed(seed + num_resets) + + # generate rail array + # step 1: + # - generate a start and goal position + # - validate min/max distance allowed + # - validate that start/goals are not placed too close to other start/goals + # - draw a rail from [start,goal] + # - if rail crosses existing rail then validate new connection + # - possibility that this fails to create a path to goal + # - on failure generate new start/goal + # + # step 2: + # - add more rails to map randomly between cells that have rails + # - validate all new rails, on failure don't add new rails + # + # step 3: + # - return transition map + list of [start_pos, start_dir, goal_pos] points + # + + start_goal = [] + start_dir = [] + nr_created = 0 + created_sanity = 0 + sanity_max = 9000 + while nr_created < nr_start_goal and created_sanity < sanity_max: + all_ok = False + for _ in range(sanity_max): + start = (np.random.randint(0, height), np.random.randint(0, width)) + goal = (np.random.randint(0, height), np.random.randint(0, width)) + + # check to make sure start,goal pos is empty? + if rail_array[goal] != 0 or rail_array[start] != 0: + continue + # check min/max distance + dist_sg = distance_on_rail(start, goal) + if dist_sg < min_dist: + continue + if dist_sg > max_dist: + continue + # check distance to existing points + sg_new = [start, goal] + + def check_all_dist(sg_new): + for sg in start_goal: + for i in range(2): + for j in range(2): + dist = distance_on_rail(sg_new[i], sg[j]) + if dist < 2: + return False + return True + + if check_all_dist(sg_new): + all_ok = True + break + + if not all_ok: + # we might as well give up at this point + break + + new_path = connect_rail(rail_trans, rail_array, start, goal) + if len(new_path) >= 2: + nr_created += 1 + start_goal.append([start, goal]) + start_dir.append(mirror(get_direction(new_path[0], new_path[1]))) + else: + # after too many failures we will give up + created_sanity += 1 + + # add extra connections between existing rail + created_sanity = 0 + nr_created = 0 + while nr_created < nr_extra and created_sanity < sanity_max: + all_ok = False + for _ in range(sanity_max): + start = (np.random.randint(0, height), np.random.randint(0, width)) + goal = (np.random.randint(0, height), np.random.randint(0, width)) + # check to make sure start,goal pos are not empty + if rail_array[goal] == 0 or rail_array[start] == 0: + continue + else: + all_ok = True + break + if not all_ok: + break + new_path = connect_rail(rail_trans, rail_array, start, goal) + if len(new_path) >= 2: + nr_created += 1 + + agents_position = [sg[0] for sg in start_goal[:num_agents]] + agents_target = [sg[1] for sg in start_goal[:num_agents]] + agents_direction = start_dir[:num_agents] + + return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + + return generator \ No newline at end of file