Skip to content
Snippets Groups Projects
Commit 9134ddfb authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

Update generators.py

parent fc34b470
No related branches found
No related tags found
No related merge requests found
......@@ -538,3 +538,126 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11):
return return_rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
return generator
def realistic_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=99999, seed=0):
"""
Parameters
-------
width : int
The width (number of cells) of the grid to generate.
height : int
The height (number of cells) of the grid to generate.
Returns
-------
numpy.ndarray of type numpy.uint16
The matrix with the correct 16-bit bitmaps for each cell.
"""
def generator(width, height, num_agents, num_resets=0):
if num_agents > nr_start_goal:
num_agents = nr_start_goal
print("complex_rail_generator: num_agents > nr_start_goal, changing num_agents")
rail_trans = RailEnvTransitions()
grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
rail_array = grid_map.grid
rail_array.fill(0)
np.random.seed(seed + num_resets)
# generate rail array
# step 1:
# - generate a start and goal position
# - validate min/max distance allowed
# - validate that start/goals are not placed too close to other start/goals
# - draw a rail from [start,goal]
# - if rail crosses existing rail then validate new connection
# - possibility that this fails to create a path to goal
# - on failure generate new start/goal
#
# step 2:
# - add more rails to map randomly between cells that have rails
# - validate all new rails, on failure don't add new rails
#
# step 3:
# - return transition map + list of [start_pos, start_dir, goal_pos] points
#
start_goal = []
start_dir = []
nr_created = 0
created_sanity = 0
sanity_max = 9000
while nr_created < nr_start_goal and created_sanity < sanity_max:
all_ok = False
for _ in range(sanity_max):
start = (np.random.randint(0, height), np.random.randint(0, width))
goal = (np.random.randint(0, height), np.random.randint(0, width))
# check to make sure start,goal pos is empty?
if rail_array[goal] != 0 or rail_array[start] != 0:
continue
# check min/max distance
dist_sg = distance_on_rail(start, goal)
if dist_sg < min_dist:
continue
if dist_sg > max_dist:
continue
# check distance to existing points
sg_new = [start, goal]
def check_all_dist(sg_new):
for sg in start_goal:
for i in range(2):
for j in range(2):
dist = distance_on_rail(sg_new[i], sg[j])
if dist < 2:
return False
return True
if check_all_dist(sg_new):
all_ok = True
break
if not all_ok:
# we might as well give up at this point
break
new_path = connect_rail(rail_trans, rail_array, start, goal)
if len(new_path) >= 2:
nr_created += 1
start_goal.append([start, goal])
start_dir.append(mirror(get_direction(new_path[0], new_path[1])))
else:
# after too many failures we will give up
created_sanity += 1
# add extra connections between existing rail
created_sanity = 0
nr_created = 0
while nr_created < nr_extra and created_sanity < sanity_max:
all_ok = False
for _ in range(sanity_max):
start = (np.random.randint(0, height), np.random.randint(0, width))
goal = (np.random.randint(0, height), np.random.randint(0, width))
# check to make sure start,goal pos are not empty
if rail_array[goal] == 0 or rail_array[start] == 0:
continue
else:
all_ok = True
break
if not all_ok:
break
new_path = connect_rail(rail_trans, rail_array, start, goal)
if len(new_path) >= 2:
nr_created += 1
agents_position = [sg[0] for sg in start_goal[:num_agents]]
agents_target = [sg[1] for sg in start_goal[:num_agents]]
agents_direction = start_dir[:num_agents]
return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
return generator
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment