From 87a34e74f1e95c491156fcf19297553dc75d6997 Mon Sep 17 00:00:00 2001 From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch> Date: Wed, 14 Aug 2019 17:18:49 +0200 Subject: [PATCH] first draft --- flatland/envs/generators.py | 193 ++++++++++++++++++++---------------- 1 file changed, 107 insertions(+), 86 deletions(-) diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 2de26ddc..4cb8b8cc 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -1,3 +1,5 @@ +from enum import IntEnum + import msgpack import numpy as np @@ -555,6 +557,82 @@ def realistic_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dis The matrix with the correct 16-bit bitmaps for each cell. """ + def add_rail(grid_map, pt_from, pt_via, pt_to, bAddRemove=True): + gRCTrans = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]]) # NESW in RC + + lrcStroke = [[pt_from[0], pt_from[1]], [pt_via[0], pt_via[1]], [pt_to[0], pt_to[1]]] + + rc3Cells = np.array(lrcStroke[:3]) # the 3 cells + rcMiddle = rc3Cells[1] # the middle cell which we will update + bDeadend = np.all(lrcStroke[0] == lrcStroke[2]) # deadend means cell 0 == cell 2 + + # get the 2 row, col deltas between the 3 cells, eg [[-1,0],[0,1]] = North, East + rc2Trans = np.diff(rc3Cells, axis=0) + + # get the direction index for the 2 transitions + liTrans = [] + for rcTrans in rc2Trans: + # gRCTrans - rcTrans gives an array of vector differences between our rcTrans + # and the 4 directions stored in gRCTrans. + # Where the vector difference is zero, we have a match... + # np.all detects where the whole row,col vector is zero. + # argwhere gives the index of the zero vector, ie the direction index + iTrans = np.argwhere(np.all(gRCTrans - rcTrans == 0, axis=1)) + if len(iTrans) > 0: + iTrans = iTrans[0][0] + liTrans.append(iTrans) + + # check that we have two transitions + if len(liTrans) == 2: + # Set the transition + # If this transition spans 3 cells, it is not a deadend, so remove any deadends. + # The user will need to resolve any conflicts. + grid_map.set_transition((*rcMiddle, liTrans[0]), + liTrans[1], + bAddRemove, + remove_deadends=not bDeadend) + + # Also set the reverse transition + # use the reversed outbound transition for inbound + # and the reversed inbound transition for outbound + grid_map.set_transition((*rcMiddle, mirror(liTrans[1])), + mirror(liTrans[0]), bAddRemove, remove_deadends=not bDeadend) + + def make_switch_w_e(grid_map, center): + # e -> w + start = (center[0]+1, center[1]-1) + via = (center[0], center[1] - 1) + goal = (center[0], center[1]) + add_rail(grid_map, start, via, goal) + start = (center[0], center[1]-1) + via = (center[0]+1, center[1]-1) + goal = (center[0]+1, center[1]-2) + add_rail(grid_map, start, via, goal) + + def make_switch_e_w(grid_map, center): + # e -> w + start = (center[0] + 1, center[1]) + via = (center[0] + 1, center[1] - 1) + goal = (center[0], center[1] - 1) + add_rail(grid_map, start, via, goal) + start = (center[0] + 1, center[1] - 1) + via = (center[0], center[1] - 1) + goal = (center[0], center[1] - 2) + add_rail(grid_map, start, via, goal) + + class Grid4TransitionsEnum(IntEnum): + NORTH = 0 + EAST = 1 + SOUTH = 2 + WEST = 3 + + @staticmethod + def to_char(int: int): + return {0: 'N', + 1: 'E', + 2: 'S', + 3: 'W'}[int] + def generator(width, height, num_agents, num_resets=0): if num_agents > nr_start_goal: @@ -567,97 +645,40 @@ def realistic_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dis np.random.seed(seed + num_resets) - # generate rail array - # step 1: - # - generate a start and goal position - # - validate min/max distance allowed - # - validate that start/goals are not placed too close to other start/goals - # - draw a rail from [start,goal] - # - if rail crosses existing rail then validate new connection - # - possibility that this fails to create a path to goal - # - on failure generate new start/goal - # - # step 2: - # - add more rails to map randomly between cells that have rails - # - validate all new rails, on failure don't add new rails - # - # step 3: - # - return transition map + list of [start_pos, start_dir, goal_pos] points - # - start_goal = [] - start_dir = [] - nr_created = 0 - created_sanity = 0 - sanity_max = 9000 - while nr_created < nr_start_goal and created_sanity < sanity_max: - all_ok = False - for _ in range(sanity_max): - start = (np.random.randint(0, height), np.random.randint(0, width)) - goal = (np.random.randint(0, height), np.random.randint(0, width)) - - # check to make sure start,goal pos is empty? - if rail_array[goal] != 0 or rail_array[start] != 0: - continue - # check min/max distance - dist_sg = distance_on_rail(start, goal) - if dist_sg < min_dist: - continue - if dist_sg > max_dist: - continue - # check distance to existing points - sg_new = [start, goal] - - def check_all_dist(sg_new): - for sg in start_goal: - for i in range(2): - for j in range(2): - dist = distance_on_rail(sg_new[i], sg[j]) - if dist < 2: - return False - return True - - if check_all_dist(sg_new): - all_ok = True - break - - if not all_ok: - # we might as well give up at this point - break + max_n_track_seg=4 + x_offsets = np.arange(0,height,max_n_track_seg).astype(int) + for off_set in x_offsets: + # second track + data = np.arange(int((width - max_n_track_seg) / max_n_track_seg)) * max_n_track_seg + 2 + n_track_seg = np.random.choice(max_n_track_seg) + 1 + # track one (full track : left right) + start = (off_set, 0) + goal = (off_set, width - 1) new_path = connect_rail(rail_trans, rail_array, start, goal) - if len(new_path) >= 2: - nr_created += 1 - start_goal.append([start, goal]) - start_dir.append(mirror(get_direction(new_path[0], new_path[1]))) - else: - # after too many failures we will give up - created_sanity += 1 - # add extra connections between existing rail - created_sanity = 0 - nr_created = 0 - while nr_created < nr_extra and created_sanity < sanity_max: - all_ok = False - for _ in range(sanity_max): - start = (np.random.randint(0, height), np.random.randint(0, width)) - goal = (np.random.randint(0, height), np.random.randint(0, width)) - # check to make sure start,goal pos are not empty - if rail_array[goal] == 0 or rail_array[start] == 0: - continue - else: - all_ok = True - break - if not all_ok: - break - new_path = connect_rail(rail_trans, rail_array, start, goal) - if len(new_path) >= 2: - nr_created += 1 + agents_position = [new_path[0]] + agents_target = [new_path[1]] # len(new_path) - 1]] + agents_direction = [3] - agents_position = [sg[0] for sg in start_goal[:num_agents]] - agents_target = [sg[1] for sg in start_goal[:num_agents]] - agents_direction = start_dir[:num_agents] + for nbr_track_loop in range(height-1): + if len(data) < 2*n_track_seg+1: + break + x = np.sort(np.random.choice(data, 2 * n_track_seg, False)).astype(int) + data = [] + for x_loop in range(int(len(x) / 2)): + start = (off_set+nbr_track_loop+1, x[2 * x_loop]) + goal = (off_set+nbr_track_loop+1, x[2 * x_loop + 1]) + d = np.arange(x[2 * x_loop]+1,x[2 * x_loop+1]-1,2) + data.extend(d) + new_path = connect_rail(rail_trans, rail_array, start, goal) + if len(new_path) >0: + c = (off_set+nbr_track_loop, x[2 * x_loop] + 1) + make_switch_e_w(grid_map, c) + c = (off_set+nbr_track_loop, x[2 * x_loop + 1]+1) + make_switch_w_e(grid_map, c) return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) - return generator \ No newline at end of file + return generator -- GitLab