Commit 214b08a1 authored by Erik Nygren's avatar Erik Nygren 🚅
Browse files

code cleanup and added city cells in order to avoid drawing paths through cities

parent 3504c01f
Pipeline #2207 failed with stages
in 60 minutes
......@@ -33,7 +33,7 @@ speed_ration_map = {1.: 0.25, # Fast passenger train
env = RailEnv(width=50,
height=50,
rail_generator=sparse_rail_generator(num_cities=9, # Number of cities in map (where train stations are)
num_trainstations=0, # Number of possible start/targets on map
num_trainstations=50, # Number of possible start/targets on map
min_node_dist=8, # Minimal distance of nodes
node_radius=3, # Proximity of stations to city center
seed=15, # Random seed
......
......@@ -37,7 +37,8 @@ class AStarNode:
def a_star(grid_map: GridTransitionMap,
start: IntVector2D, end: IntVector2D,
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, nice=True) -> IntVector2DArray:
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, nice=True,
forbidden_cells=None) -> IntVector2DArray:
"""
Returns a list of tuples as a path from the given start to end.
If no path is found, returns path to closest point to end.
......@@ -90,11 +91,15 @@ def a_star(grid_map: GridTransitionMap,
if node_pos[0] >= rail_shape[0] or node_pos[0] < 0 or node_pos[1] >= rail_shape[1] or node_pos[1] < 0:
continue
# Skip paths through forbidden regions.
if forbidden_cells is not None:
if node_pos in forbidden_cells and node_pos != start_node and node_pos != end_node:
continue
# validate positions
#
if not grid_map.validate_new_transition(prev_pos, current_node.pos, node_pos, end_node.pos) and nice:
continue
# create new node
new_node = AStarNode(node_pos, current_node)
children.append(new_node)
......
......@@ -20,13 +20,15 @@ def connect_basic_operation(
flip_start_node_trans=False,
flip_end_node_trans=False,
nice=True,
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance,
forbidden_cells=None
) -> IntVector2DArray:
"""
Creates a new path [start,end] in `grid_map.grid`, based on rail_trans, and
returns the path created as a list of positions.
"""
# in the worst case we will need to do a A* search, so we might as well set that up
path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function, nice)
path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function, nice, forbidden_cells)
if len(path) < 2:
print("No path found", path)
return []
......@@ -87,6 +89,12 @@ def connect_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
return connect_basic_operation(rail_trans, grid_map, start, end, False, False, False, a_star_distance_function)
def connect_cities(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
start: IntVector2D, end: IntVector2D, forbidden_cells,
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
return connect_basic_operation(rail_trans, grid_map, start, end, False, False, False, a_star_distance_function,
forbidden_cells)
def connect_from_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
start: IntVector2D, end: IntVector2D,
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance
......
......@@ -9,7 +9,7 @@ from flatland.core.grid.grid4_utils import get_direction, mirror
from flatland.core.grid.grid_utils import distance_on_rail, direction_to_point
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes
from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes, connect_cities
RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]]
RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
......@@ -573,25 +573,9 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
node_positions: List[Any] = None
nb_nodes = num_cities
if grid_mode:
nodes_ratio = height / width
nodes_per_row = int(np.ceil(np.sqrt(nb_nodes * nodes_ratio)))
nodes_per_col = int(np.ceil(nb_nodes / nodes_per_row))
x_positions = np.linspace(node_radius, height - node_radius - 1, nodes_per_row, dtype=int)
y_positions = np.linspace(node_radius, width - node_radius - 1, nodes_per_col, dtype=int)
city_idx = np.random.choice(np.arange(nb_nodes), num_cities, False)
node_positions = _generate_node_positions_grid_mode(city_idx, city_positions, intersection_positions,
nb_nodes,
nodes_per_row, x_positions,
y_positions)
node_positions, city_cells = _generate_node_positions_grid_mode(nb_nodes, height, width)
else:
node_positions = _generate_node_positions_not_grid_mode(city_positions, height,
intersection_positions,
nb_nodes, width)
node_positions = _generate_node_positions_not_grid_mode(nb_nodes, height, width)
# reduce nb_nodes, _num_cities, _num_intersections if less were generated in not_grid_mode
nb_nodes = len(node_positions)
......@@ -624,8 +608,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
'train_stations': train_stations
}}
def _generate_node_positions_not_grid_mode(city_positions, height, intersection_positions, nb_nodes,
width):
def _generate_node_positions_not_grid_mode(nb_nodes, height, width):
node_positions = []
for node_idx in range(nb_nodes):
......@@ -637,22 +620,14 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
y_tmp = node_radius + np.random.randint(width - 2 * node_radius - 1)
to_close = False
# Check distance to cities
for node_pos in city_positions:
if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
to_close = True
# Check distance to intersections
for node_pos in intersection_positions:
# Check distance to nodes
for node_pos in node_positions:
if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
to_close = True
if not to_close:
node_positions.append((x_tmp, y_tmp))
if node_idx < num_cities:
city_positions.append((x_tmp, y_tmp))
else:
intersection_positions.append((x_tmp, y_tmp))
tries += 1
if tries > 100:
warnings.warn(
......@@ -661,23 +636,21 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
tries, nb_nodes))
break
node_positions = city_positions + intersection_positions
return node_positions
def _generate_node_positions_grid_mode(city_idx, city_positions, intersection_positions, nb_nodes,
nodes_per_row, x_positions, y_positions):
def _generate_node_positions_grid_mode(nb_nodes, height, width):
nodes_ratio = height / width
nodes_per_row = int(np.ceil(np.sqrt(nb_nodes * nodes_ratio)))
nodes_per_col = int(np.ceil(nb_nodes / nodes_per_row))
x_positions = np.linspace(node_radius, height - node_radius - 1, nodes_per_row, dtype=int)
y_positions = np.linspace(node_radius, width - node_radius - 1, nodes_per_col, dtype=int)
node_positions = []
forbidden_cells = []
for node_idx in range(nb_nodes):
x_tmp = x_positions[node_idx % nodes_per_row]
y_tmp = y_positions[node_idx // nodes_per_row]
if node_idx in city_idx:
city_positions.append((x_tmp, y_tmp))
else:
intersection_positions.append((x_tmp, y_tmp))
node_positions = city_positions + intersection_positions
return node_positions
node_positions.append((x_tmp, y_tmp))
return node_positions, forbidden_cells
def _generate_node_connection_points(node_positions, node_size, max_nr_connection_points=2,
max_nr_connection_directions=2):
......@@ -698,8 +671,6 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
# Store the directions to these neighbours
connection_sides_idx = []
idx = 1
# TODO: Change the way this code works! Check that we get sufficient direction.
# TODO: Check if this works as expected
while len(connection_sides_idx) < max_nr_connection_directions and idx < len(neighb_dist):
current_closest_direction = direction_to_point(node_position, node_positions[closest_neighb_idx[idx]])
print(node_position)
......@@ -707,12 +678,11 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
connection_sides_idx.append(current_closest_direction)
idx += 1
# set the number of connection points for each direction
connections_per_direction = np.zeros(4, dtype=int)
for idx in connection_sides_idx:
nr_of_connection_points = max_nr_connection_points # np.random.randint(1, max_nr_connection_points + 1)
nr_of_connection_points = np.random.randint(1, max_nr_connection_points + 1)
connections_per_direction[idx] = nr_of_connection_points
connection_points_coordinates = []
......@@ -775,7 +745,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
if tmp_dist < min_connection_dist:
min_connection_dist = tmp_dist
neighb_connection_point = tmp_in_connection_point
connect_nodes(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point)
connect_cities(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point, None)
boarder_connections.add((tmp_out_connection_point, current_node))
boarder_connections.add((neighb_connection_point, neighb_idx))
direction += 1
......@@ -944,4 +914,18 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
# http://stackoverflow.com/questions/3071415/efficient-method-to-calculate-the-rank-vector-of-a-list-in-python
return sorted(range(len(seq)), key=seq.__getitem__)
def _city_cells(center, radius):
"""
Function to return all cells within a city
:param center: center coordinates of city
:param radius: radius of city (it is a square)
:return: returns flat list of all cell coordinates in the city
"""
city_cells = []
for x in range(-radius, radius):
for y in range(-radius, radius):
city_cells.append(center[0] + x, center[1] + y)
return city_cells
return generator
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment