Commit 5d8b904c authored by Erik Nygren's avatar Erik Nygren 🚅
Browse files

a_star mode can now avoid city centers when building paths

parent 214b08a1
...@@ -37,7 +37,7 @@ env = RailEnv(width=50, ...@@ -37,7 +37,7 @@ env = RailEnv(width=50,
min_node_dist=8, # Minimal distance of nodes min_node_dist=8, # Minimal distance of nodes
node_radius=3, # Proximity of stations to city center node_radius=3, # Proximity of stations to city center
seed=15, # Random seed seed=15, # Random seed
grid_mode=True, grid_mode=False,
max_connection_points_per_side=2, max_connection_points_per_side=2,
max_nr_connection_directions=4 max_nr_connection_directions=4
), ),
......
...@@ -91,17 +91,17 @@ def a_star(grid_map: GridTransitionMap, ...@@ -91,17 +91,17 @@ def a_star(grid_map: GridTransitionMap,
if node_pos[0] >= rail_shape[0] or node_pos[0] < 0 or node_pos[1] >= rail_shape[1] or node_pos[1] < 0: if node_pos[0] >= rail_shape[0] or node_pos[0] < 0 or node_pos[1] >= rail_shape[1] or node_pos[1] < 0:
continue continue
# Skip paths through forbidden regions.
if forbidden_cells is not None:
if node_pos in forbidden_cells and node_pos != start_node and node_pos != end_node:
continue
# validate positions # validate positions
# #
if not grid_map.validate_new_transition(prev_pos, current_node.pos, node_pos, end_node.pos) and nice: if not grid_map.validate_new_transition(prev_pos, current_node.pos, node_pos, end_node.pos) and nice:
continue continue
# create new node # create new node
new_node = AStarNode(node_pos, current_node) new_node = AStarNode(node_pos, current_node)
# Skip paths through forbidden regions.
if forbidden_cells is not None:
if node_pos in forbidden_cells and new_node != start_node and new_node != end_node:
continue
children.append(new_node) children.append(new_node)
# loop through children # loop through children
......
...@@ -575,7 +575,8 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n ...@@ -575,7 +575,8 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
if grid_mode: if grid_mode:
node_positions, city_cells = _generate_node_positions_grid_mode(nb_nodes, height, width) node_positions, city_cells = _generate_node_positions_grid_mode(nb_nodes, height, width)
else: else:
node_positions = _generate_node_positions_not_grid_mode(nb_nodes, height, width) node_positions, city_cells = _generate_node_positions_not_grid_mode(nb_nodes, height, width)
print(city_cells)
# reduce nb_nodes, _num_cities, _num_intersections if less were generated in not_grid_mode # reduce nb_nodes, _num_cities, _num_intersections if less were generated in not_grid_mode
nb_nodes = len(node_positions) nb_nodes = len(node_positions)
...@@ -586,7 +587,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n ...@@ -586,7 +587,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
max_nr_connection_directions) max_nr_connection_directions)
# Connect the cities through the connection points # Connect the cities through the connection points
_connect_cities(node_positions, connection_points, connection_info, rail_trans, grid_map) _connect_cities(node_positions, connection_points, connection_info, city_cells, rail_trans, grid_map)
# Build inner cities # Build inner cities
train_stations, built_num_trainstation = _build_cities(node_positions, connection_points, rail_trans, grid_map) train_stations, built_num_trainstation = _build_cities(node_positions, connection_points, rail_trans, grid_map)
...@@ -611,6 +612,8 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n ...@@ -611,6 +612,8 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
def _generate_node_positions_not_grid_mode(nb_nodes, height, width): def _generate_node_positions_not_grid_mode(nb_nodes, height, width):
node_positions = [] node_positions = []
city_cells = []
for node_idx in range(nb_nodes): for node_idx in range(nb_nodes):
to_close = True to_close = True
tries = 0 tries = 0
...@@ -627,6 +630,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n ...@@ -627,6 +630,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
if not to_close: if not to_close:
node_positions.append((x_tmp, y_tmp)) node_positions.append((x_tmp, y_tmp))
city_cells.extend(_city_cells(node_positions[-1], node_radius))
tries += 1 tries += 1
if tries > 100: if tries > 100:
...@@ -636,7 +640,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n ...@@ -636,7 +640,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
tries, nb_nodes)) tries, nb_nodes))
break break
return node_positions return node_positions, city_cells
def _generate_node_positions_grid_mode(nb_nodes, height, width): def _generate_node_positions_grid_mode(nb_nodes, height, width):
nodes_ratio = height / width nodes_ratio = height / width
...@@ -645,12 +649,13 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n ...@@ -645,12 +649,13 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
x_positions = np.linspace(node_radius, height - node_radius - 1, nodes_per_row, dtype=int) x_positions = np.linspace(node_radius, height - node_radius - 1, nodes_per_row, dtype=int)
y_positions = np.linspace(node_radius, width - node_radius - 1, nodes_per_col, dtype=int) y_positions = np.linspace(node_radius, width - node_radius - 1, nodes_per_col, dtype=int)
node_positions = [] node_positions = []
forbidden_cells = [] city_cells = []
for node_idx in range(nb_nodes): for node_idx in range(nb_nodes):
x_tmp = x_positions[node_idx % nodes_per_row] x_tmp = x_positions[node_idx % nodes_per_row]
y_tmp = y_positions[node_idx // nodes_per_row] y_tmp = y_positions[node_idx // nodes_per_row]
node_positions.append((x_tmp, y_tmp)) node_positions.append((x_tmp, y_tmp))
return node_positions, forbidden_cells city_cells.extend(_city_cells(node_positions[-1], node_radius))
return node_positions, city_cells
def _generate_node_connection_points(node_positions, node_size, max_nr_connection_points=2, def _generate_node_connection_points(node_positions, node_size, max_nr_connection_points=2,
max_nr_connection_directions=2): max_nr_connection_directions=2):
...@@ -673,7 +678,6 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n ...@@ -673,7 +678,6 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
idx = 1 idx = 1
while len(connection_sides_idx) < max_nr_connection_directions and idx < len(neighb_dist): while len(connection_sides_idx) < max_nr_connection_directions and idx < len(neighb_dist):
current_closest_direction = direction_to_point(node_position, node_positions[closest_neighb_idx[idx]]) current_closest_direction = direction_to_point(node_position, node_positions[closest_neighb_idx[idx]])
print(node_position)
if current_closest_direction not in connection_sides_idx: if current_closest_direction not in connection_sides_idx:
connection_sides_idx.append(current_closest_direction) connection_sides_idx.append(current_closest_direction)
idx += 1 idx += 1
...@@ -708,7 +712,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n ...@@ -708,7 +712,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
connection_info.append(connections_per_direction) connection_info.append(connections_per_direction)
return connection_points, connection_info return connection_points, connection_info
def _connect_cities(node_positions, connection_points, connection_info, rail_trans, grid_map): def _connect_cities(node_positions, connection_points, connection_info, city_cells, rail_trans, grid_map):
""" """
Function to connect the different cities through their connection points Function to connect the different cities through their connection points
:param node_positions: Positions of city centers :param node_positions: Positions of city centers
...@@ -724,8 +728,6 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n ...@@ -724,8 +728,6 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
for nbr_connection_points in connection_info[current_node]: for nbr_connection_points in connection_info[current_node]:
if nbr_connection_points > 0: if nbr_connection_points > 0:
neighb_idx = _closest_neigh_in_direction(current_node, direction, node_positions) neighb_idx = _closest_neigh_in_direction(current_node, direction, node_positions)
print(current_node, node_positions[current_node], direction, neighb_idx,
connection_info[current_node], connection_points[current_node])
else: else:
direction += 1 direction += 1
continue continue
...@@ -745,7 +747,8 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n ...@@ -745,7 +747,8 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
if tmp_dist < min_connection_dist: if tmp_dist < min_connection_dist:
min_connection_dist = tmp_dist min_connection_dist = tmp_dist
neighb_connection_point = tmp_in_connection_point neighb_connection_point = tmp_in_connection_point
connect_cities(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point, None) connect_cities(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point,
city_cells)
boarder_connections.add((tmp_out_connection_point, current_node)) boarder_connections.add((tmp_out_connection_point, current_node))
boarder_connections.add((neighb_connection_point, neighb_idx)) boarder_connections.add((neighb_connection_point, neighb_idx))
direction += 1 direction += 1
...@@ -924,7 +927,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n ...@@ -924,7 +927,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
city_cells = [] city_cells = []
for x in range(-radius, radius): for x in range(-radius, radius):
for y in range(-radius, radius): for y in range(-radius, radius):
city_cells.append(center[0] + x, center[1] + y) city_cells.append((center[0] + x, center[1] + y))
return city_cells return city_cells
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment