Skip to content
Snippets Groups Projects
Commit 28c23e66 authored by u229589's avatar u229589
Browse files

fix return types and set parameter types of methods

parent 8e7d6bc9
No related branches found
No related tags found
No related merge requests found
......@@ -38,7 +38,7 @@ class AStarNode:
def a_star(grid_map: GridTransitionMap,
start: IntVector2D, end: IntVector2D,
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, nice=True,
forbidden_cells=None) -> IntVector2DArray:
forbidden_cells: IntVector2DArray = None) -> IntVector2DArray:
"""
Returns a list of tuples as a path from the given start to end.
If no path is found, returns path to closest point to end.
......
......@@ -60,7 +60,7 @@ def direction_to_point(pos1: IntVector2D, pos2: IntVector2D) -> Grid4Transitions
return Grid4TransitionsEnum.EAST
def directions_of_vector(pos1: IntVector2D, pos2: IntVector2D) -> Grid4TransitionsEnum:
def directions_of_vector(pos1: IntVector2D, pos2: IntVector2D) -> (Grid4TransitionsEnum, Grid4TransitionsEnum):
"""
Returns the closest direction orientation of position 2 relative to position 1
:param pos1: position we are interested in
......
......@@ -17,8 +17,8 @@ from flatland.core.transition_map import GridTransitionMap, RailEnvTransitions
def connect_rail(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D,
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance,
flip_start_node_trans=False, flip_end_node_trans=False, respect_transition_validity=True,
forbidden_cells=None) -> IntVector2DArray:
flip_start_node_trans: bool = False, flip_end_node_trans: bool = False,
respect_transition_validity: bool = True, forbidden_cells: IntVector2DArray = None) -> IntVector2DArray:
"""
Creates a new path [start,end] in `grid_map.grid`, based on rail_trans, and
returns the path created as a list of positions.
......@@ -37,7 +37,7 @@ def connect_rail(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, st
# in the worst case we will need to do a A* search, so we might as well set that up
path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function, respect_transition_validity,
forbidden_cells)
# path: IntVector2DArray = quick_path(grid_map, start, end, forbidden_cells=forbidden_cells, openend=False)
# path: IntVector2DArray = quick_path(grid_map, start, end, forbidden_cells=forbidden_cells)
if len(path) < 2:
print("No path found", path)
return []
......@@ -86,7 +86,8 @@ def connect_rail(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, st
return path
def connect_straigt_line(rail_trans, grid_map, start, end, openend=False):
def connect_straigt_line(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D,
end: IntVector2D, openend: bool = False) -> IntVector2DArray:
"""
Generates a straight rail line from start cell to end cell.
Diagonal lines are not allowed
......@@ -102,8 +103,8 @@ def connect_straigt_line(rail_trans, grid_map, start, end, openend=False):
if not (start[0] == end[0] or start[1] == end[1]):
print("No straight line possible!")
return []
current_cell = start
path = [current_cell]
current_cell: IntVector2D = start
path: IntVector2DArray = [current_cell]
new_trans = grid_map.grid[current_cell]
direction = (np.clip(end[0] - start[0], -1, 1), np.clip(end[1] - start[1], -1, 1))
if direction[0] == 0:
......@@ -134,7 +135,8 @@ def connect_straigt_line(rail_trans, grid_map, start, end, openend=False):
return path
def quick_path(grid_map, start, end, forbidden_cells=[], openend=False):
def quick_path(grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D,
forbidden_cells: IntVector2DArray = None) -> IntVector2DArray:
"""
Quick path connecting algorithm with simple heuristic to allways follow largest value of vector towards target.
When obstacle is encountereed second direction of vector is chosen.
......@@ -160,9 +162,8 @@ def quick_path(grid_map, start, end, forbidden_cells=[], openend=False):
if next_position == target:
return next_position, closest_direction
while (not np.array_equal(next_position, np.clip(next_position, [0, 0],
[height - 1,
width - 1])) or next_position in forbidden_cells):
while (not np.array_equal(next_position, np.clip(next_position, [0, 0], [height - 1, width - 1])) or
(forbidden_cells is not None and next_position in forbidden_cells)):
if direction_tries > 1:
closest_direction = (closest_direction + 1) % 4
......
......@@ -587,13 +587,13 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
rail_trans,
grid_map)
# Populate cities
train_stations, built_num_trainstation = _set_trainstation_positions(city_positions, city_radius, free_rails)
train_stations = _set_trainstation_positions(city_positions, city_radius, free_rails)
# Fix all transition elements
_fix_transitions(city_cells, inter_city_lines, grid_map)
# Generate start target pairs
agent_start_targets_cities, num_agents = _generate_start_target_pairs(num_agents, num_cities, train_stations,
agent_start_targets_cities = _generate_start_target_pairs(num_agents, num_cities, train_stations,
city_orientations)
return grid_map, {'agents_hints': {
......@@ -708,8 +708,8 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
connection_info.append(connections_per_direction)
return inner_connection_points, outer_connection_points, connection_info, city_orientations
def _connect_cities(city_positions: IntVector2DArray, connection_points, city_cells: IntVector2DArray,
rail_trans, grid_map: GridTransitionMap):
def _connect_cities(city_positions: IntVector2DArray, connection_points: List[List[List[IntVector2D]]], city_cells: IntVector2DArray,
rail_trans: RailEnvTransitions, grid_map: GridTransitionMap) -> List[IntVector2DArray]:
"""
Function to connect the different cities through their connection points
:param city_positions: Positions of city centers
......@@ -718,7 +718,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
:param grid_map: Grid map
:return:
"""
all_paths = []
all_paths: List[IntVector2DArray] = []
for current_city_idx in np.arange(len(city_positions)):
neighbours = _closest_neighbour_in_direction(current_city_idx, city_positions)
......@@ -748,8 +748,9 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
return all_paths
def _build_inner_cities(city_positions, inner_connection_points, outer_connection_points, rail_trans,
grid_map: GridTransitionMap):
def _build_inner_cities(city_positions: IntVector2DArray, inner_connection_points: List[List[List[IntVector2D]]],
outer_connection_points: List[List[List[IntVector2D]]], rail_trans: RailEnvTransitions,
grid_map: GridTransitionMap) -> (List[IntVector2DArray], List[List[List[IntVector2D]]]):
"""
Builds inner city tracks. This current version connects all incoming connections to all outgoing connections
:param city_positions: Positions of the cities
......@@ -759,8 +760,8 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
:param grid_map:
:return: Returns the cells of the through path which cannot be occupied by trainstations
"""
through_path_cells = [[] for i in range(len(city_positions))]
free_tracks = [[] for i in range(len(city_positions))]
through_path_cells: List[IntVector2DArray] = [[] for i in range(len(city_positions))]
free_rails: List[List[List[IntVector2D]]] = [[] for i in range(len(city_positions))]
for current_city in range(len(city_positions)):
all_outer_connection_points = [item for sublist in outer_connection_points[current_city] for item in
sublist]
......@@ -784,37 +785,26 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
source = inner_connection_points[current_city][boarder][track_id]
target = inner_connection_points[current_city][opposite_boarder][track_id]
current_track = connect_straigt_line(rail_trans, grid_map, source, target, False)
if target in all_outer_connection_points and source in \
all_outer_connection_points and len(through_path_cells[current_city]) < 1:
if target in all_outer_connection_points and source in all_outer_connection_points and len(through_path_cells[current_city]) < 1:
through_path_cells[current_city].extend(current_track)
else:
free_tracks[current_city].append(current_track)
return through_path_cells, free_tracks
free_rails[current_city].append(current_track)
return through_path_cells, free_rails
def _set_trainstation_positions(city_positions: IntVector2DArray, city_radius: int, free_rails):
"""
:param city_positions:
:param num_trainstations:
:return:
"""
def _set_trainstation_positions(city_positions: IntVector2DArray, city_radius: int,
free_rails: List[List[List[IntVector2D]]]) -> List[List[Tuple[IntVector2D, int]]]:
num_cities = len(city_positions)
train_stations = [[] for i in range(num_cities)]
built_num_trainstations = 0
for current_city in range(len(city_positions)):
for track_nbr in range(len(free_rails[current_city])):
possible_location = free_rails[current_city][track_nbr][city_radius]
train_stations[current_city].append((possible_location, track_nbr))
return train_stations, built_num_trainstations
return train_stations
def _generate_start_target_pairs(num_agents, num_cities, train_stations, city_orientation):
"""
Fill the trainstation positions with targets and goals
:param num_agents:
:param num_cities:
:param train_stations:
:return:
"""
def _generate_start_target_pairs(num_agents: int, num_cities: int,
train_stations: List[List[Tuple[IntVector2D, int]]],
city_orientation: List[Grid4TransitionsEnum]) -> List[Tuple[int, int,
Grid4TransitionsEnum]]:
# Generate start and target city directory for all agents.
# Assure that start and target are not in the same city
agent_start_targets_cities = []
......@@ -839,9 +829,10 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
start_city = start_target_tuple[0]
target_city = start_target_tuple[1]
agent_start_targets_cities.append((start_city, target_city, city_orientation[start_city]))
return agent_start_targets_cities, num_agents
return agent_start_targets_cities
def _fix_transitions(city_cells, inter_city_lines, grid_map: GridTransitionMap):
def _fix_transitions(city_cells: IntVector2DArray, inter_city_lines: List[IntVector2DArray],
grid_map: GridTransitionMap):
"""
Function to fix all transition elements in environment
"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment