Skip to content
Snippets Groups Projects
Commit 69063b3b authored by maljx's avatar maljx
Browse files

new level gen, work in progress

parent 5ab5adfd
No related branches found
No related tags found
No related merge requests found
from flatland.envs.rail_env import RailEnv, random_rail_generator from flatland.envs.rail_env import RailEnv, random_rail_generator, complex_rail_generator
# from flatland.core.env_observation_builder import TreeObsForRailEnv # from flatland.core.env_observation_builder import TreeObsForRailEnv
from flatland.utils.rendertools import RenderTool from flatland.utils.rendertools import RenderTool
from flatland.baselines.dueling_double_dqn import Agent from flatland.baselines.dueling_double_dqn import Agent
...@@ -17,20 +17,19 @@ def main(render=True, delay=0.0): ...@@ -17,20 +17,19 @@ def main(render=True, delay=0.0):
# Example generate a rail given a manual specification, # Example generate a rail given a manual specification,
# a map of tuples (cell_type, rotation) # a map of tuples (cell_type, rotation)
transition_probability = [0.5, # empty cell - Case 0 #transition_probability = [0.5, # empty cell - Case 0
1.0, # Case 1 - straight # 1.0, # Case 1 - straight
1.0, # Case 2 - simple switch # 1.0, # Case 2 - simple switch
0.3, # Case 3 - diamond drossing # 0.3, # Case 3 - diamond drossing
0.5, # Case 4 - single slip # 0.5, # Case 4 - single slip
0.5, # Case 5 - double slip # 0.5, # Case 5 - double slip
0.2, # Case 6 - symmetrical # 0.2, # Case 6 - symmetrical
0.0] # Case 7 - dead end # 0.0] # Case 7 - dead end
# Example generate a random rail # Example generate a random rail
env = RailEnv(width=15, env = RailEnv(width=15, height=15,
height=15, rail_generator=complex_rail_generator(),
rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), number_of_agents=1)
number_of_agents=5)
if render: if render:
env_renderer = RenderTool(env, gl="QT") env_renderer = RenderTool(env, gl="QT")
......
...@@ -537,3 +537,27 @@ class RailEnvTransitions(Grid4Transitions): ...@@ -537,3 +537,27 @@ class RailEnvTransitions(Grid4Transitions):
super(RailEnvTransitions, self).__init__( super(RailEnvTransitions, self).__init__(
transitions=self.transition_list transitions=self.transition_list
) )
def is_valid(self, cell_transition):
"""
Checks if a cell transition is a valid cell setup.
Parameters
----------
cell_transition : int
64 bits used to encode the valid transitions for a cell.
Returns
-------
Boolean
True or False
"""
for trans in self.transitions:
if cell_transition == trans:
return True
for _ in range(3):
trans = self.rotate_transition(trans, rotation=90)
if cell_transition == trans:
return True
return False
...@@ -13,6 +13,254 @@ from flatland.core.transitions import Grid8Transitions, RailEnvTransitions ...@@ -13,6 +13,254 @@ from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap from flatland.core.transition_map import GridTransitionMap
class AStarNode():
"""A node class for A* Pathfinding"""
def __init__(self, parent=None, pos=None):
self.parent = parent
self.pos = pos
self.g = 0
self.h = 0
self.f = 0
def __eq__(self, other):
return self.pos == other.pos
def update_if_better(self, other):
if other.g < self.g:
self.parent = other.parent
self.g = other.g
self.h = other.h
self.f = other.f
def a_star(rail_array, start, end):
"""
Returns a list of tuples as a path from the given start to end.
If no path is found, returns path to closest point to end.
"""
rail_shape = rail_array.shape
start_node = AStarNode(None, start)
end_node = AStarNode(None, end)
open_list = []
closed_list = []
open_list.append(start_node)
# this could be optimized
def is_node_in_list(node, the_list):
for o_node in the_list:
if node == o_node:
return o_node
return None
while len(open_list) > 0:
# get node with current shortest est. path (lowest f)
current_node = open_list[0]
current_index = 0
for index, item in enumerate(open_list):
if item.f < current_node.f:
current_node = item
current_index = index
# pop current off open list, add to closed list
open_list.pop(current_index)
closed_list.append(current_node)
# print("a*:", current_node.pos)
# for cn in closed_list:
# print("closed:", cn.pos)
# found the goal
if current_node == end_node:
path = []
current = current_node
while current is not None:
path.append(current.pos)
current = current.parent
# return reversed path
return path[::-1]
# generate children
children = []
for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]:
node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1])
if node_pos[0] >= rail_shape[0] or \
node_pos[0] < 0 or \
node_pos[1] >= rail_shape[1] or \
node_pos[1] < 0:
continue
# validate positions
# debug: avoid all current rails
# if rail_array.item(node_pos) != 0:
# continue
# create new node
new_node = AStarNode(current_node, node_pos)
children.append(new_node)
# loop through children
for child in children:
# already in closed list?
closed_node = is_node_in_list(child, closed_list)
if closed_node is not None:
continue
# create the f, g, and h values
child.g = current_node.g + 1
# this heuristic favors diagonal paths
# child.h = ((child.pos[0] - end_node.pos[0]) ** 2) + \
# ((child.pos[1] - end_node.pos[1]) ** 2)
# this heuristic avoids diagonal paths
child.h = abs(child.pos[0] - end_node.pos[0]) + abs(child.pos[1] - end_node.pos[1])
child.f = child.g + child.h
# already in the open list?
open_node = is_node_in_list(child, open_list)
if open_node is not None:
open_node.update_if_better(child)
continue
# add the child to the open list
open_list.append(child)
# no full path found, return partial path
if len(open_list) == 0:
path = []
current = current_node
while current is not None:
path.append(current.pos)
current = current.parent
# return reversed path
return path[::-1]
def complex_rail_generator(nr_start_goal=10, min_dist=0, max_dist=99999, seed=0):
"""
Parameters
-------
width : int
The width (number of cells) of the grid to generate.
height : int
The height (number of cells) of the grid to generate.
Returns
-------
numpy.ndarray of type numpy.uint16
The matrix with the correct 16-bit bitmaps for each cell.
"""
def generator(width, height, num_resets=0):
rail_trans = RailEnvTransitions()
rail_array = np.zeros(shape=(width, height), dtype=np.uint16)
np.random.seed(seed)
# generate rail array
# step 1:
# - generate a list of start and goal positions
# - use a min/max distance allowed as input for this
# - validate that start/goals are not placed too close to other start/goals
#
# step 2: (optional)
# - place random elements on rails array
# - for instance "train station", etc.
#
# step 3:
# - iterate over all [start, goal] pairs:
# - [first X pairs]
# - draw a rail from [start,goal]
# - draw either vertical or horizontal part first (randomly)
# - if rail crosses existing rail then validate new connection
# - if new connection is invalid turn 90 degrees to left/right
# - possibility that this fails to create a path to goal
# - on failure goto step1 and retry with seed+1
# - [avoid crossing other start,goal positions] (optional)
#
# - [after X pairs]
# - find closest rail from start (Pa)
# - iterating outwards in a "circle" from start until an existing rail cell is hit
# - connect [start, Pa]
# - validate crossing rails
# - Do A* from Pa to find closest point on rail (Pb) to goal point
# - Basically normal A* but find point on rail which is closest to goal
# - since full path to goal is unlikely
# - connect [Pb, goal]
# - validate crossing rails
#
# step 4: (optional)
# - add more rails to map randomly
#
# step 5:
# - return transition map + list of [start, goal] points
#
start_goal = []
for _ in range(nr_start_goal):
start = (np.random.randint(0, width), np.random.randint(0, height))
goal = (np.random.randint(0, height), np.random.randint(0, height))
# TODO: validate closeness with existing points
# TODO: make sure min/max distance condition is met
start_goal.append([start, goal])
def get_direction(pos1, pos2):
diff_0 = pos2[0] - pos1[0]
diff_1 = pos2[1] - pos1[1]
if diff_0 < 0:
return 0
if diff_0 > 0:
return 2
if diff_1 > 0:
return 1
if diff_1 < 0:
return 3
return 0
def connect_two_cells(pos1, pos2):
# connect two adjacent cells
direction = get_direction(pos1, pos2)
rail_array[pos1] = rail_trans.set_transition(rail_array[pos1], direction, direction, 1)
o_dir = (direction + 2) % 4
rail_array[pos2] = rail_trans.set_transition(rail_array[pos2], o_dir, o_dir, 1)
def connect_rail(start, end):
# in the worst case we will need to do a A* search, so we might as well set that up
# TODO: need to check transitions in A* to see if new path is valid
path = a_star(rail_array, start, end)
print("connecting path", path)
if len(path) < 2:
return
if len(path) == 2:
connect_two_cells(path[0], path[1])
return
current_dir = get_direction(path[0], path[1])
for index in range(len(path)):
pos1 = path[index]
if index+1 < len(path):
new_dir = get_direction(pos1, path[index+1])
else:
new_dir = current_dir
cell_trans = rail_array[pos1]
if index != len(path)-1:
# set the forward path
cell_trans = rail_trans.set_transition(cell_trans, current_dir, new_dir, 1)
if index != 0:
# set the backwards path
cell_trans = rail_trans.set_transition(cell_trans, (new_dir+2) % 4, (current_dir+2) % 4, 1)
rail_array[pos1] = cell_trans
current_dir = new_dir
for sg in start_goal:
connect_rail(sg[0], sg[1])
return_rail = GridTransitionMap(width=width, height=height, transitions=rail_trans)
return_rail.grid = rail_array
return return_rail
return generator
def rail_from_manual_specifications_generator(rail_spec): def rail_from_manual_specifications_generator(rail_spec):
""" """
Utility to convert a rail given by manual specification as a map of tuples Utility to convert a rail given by manual specification as a map of tuples
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment