diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index f396164615baad167adac5860201af3359d8efbc..9f2dfee3e89b88009d8489faaa6fb0870e01204b 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -40,6 +40,9 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist= """ def generator(width, height, num_agents, num_resets=0): + if num_agents > nr_start_goal: + num_agents = nr_start_goal + print("complex_rail_generator: num_agents > nr_start_goal, changing num_agents") rail_trans = RailEnvTransitions() grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) rail_array = grid_map.grid @@ -49,41 +52,20 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist= # generate rail array # step 1: - # - generate a list of start and goal positions - # - use a min/max distance allowed as input for this - # - validate that start/goals are not placed too close to other start/goals - # - # step 2: (optional) - # - place random elements on rails array - # - for instance "train station", etc. - # - # step 3: - # - iterate over all [start, goal] pairs: - # - [first X pairs] - # - draw a rail from [start,goal] - # - draw either vertical or horizontal part first (randomly) + # - generate a start and goal position + # - validate min/max distance allowed + # - validate that start/goals are not placed too close to other start/goals + # - draw a rail from [start,goal] # - if rail crosses existing rail then validate new connection - # - if new connection is invalid turn 90 degrees to left/right - # - possibility that this fails to create a path to goal - # - on failure goto step1 and retry with seed+1 - # - [avoid crossing other start,goal positions] (optional) + # - possibility that this fails to create a path to goal + # - on failure generate new start/goal # - # - [after X pairs] - # - find closest rail from start (Pa) - # - iterating outwards in a "circle" from start until an existing rail cell is hit - # - connect [start, Pa] - # - validate crossing rails - # - Do A* from Pa to find closest point on rail (Pb) to goal point - # - Basically normal A* but find point on rail which is closest to goal - # - since full path to goal is unlikely - # - connect [Pb, goal] - # - validate crossing rails + # step 2: + # - add more rails to map randomly between cells that have rails + # - validate all new rails, on failure don't add new rails # - # step 4: (optional) - # - add more rails to map randomly - # - # step 5: - # - return transition map + list of [start, goal] points + # step 3: + # - return transition map + list of [start_pos, start_dir, goal_pos] points # start_goal = [] @@ -161,9 +143,9 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist= # print("\n> Complex Rail Gen: Created #", len(start_goal), "pairs and #", nr_created, "extra connections") # print(start_goal) - agents_position = [sg[0] for sg in start_goal] - agents_target = [sg[1] for sg in start_goal] - agents_direction = start_dir + agents_position = [sg[0] for sg in start_goal[:num_agents]] + agents_target = [sg[1] for sg in start_goal[:num_agents]] + agents_direction = start_dir[:num_agents] return grid_map, agents_position, agents_direction, agents_target