Skip to content
Snippets Groups Projects
Commit af5ed987 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

refactoring of rail connection functions

parent 4a59c1af
No related branches found
No related tags found
No related merge requests found
...@@ -14,7 +14,7 @@ from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d ...@@ -14,7 +14,7 @@ from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.core.transition_map import GridTransitionMap, RailEnvTransitions from flatland.core.transition_map import GridTransitionMap, RailEnvTransitions
def connect_basic_operation( def connect_rail(
rail_trans: RailEnvTransitions, rail_trans: RailEnvTransitions,
grid_map: GridTransitionMap, grid_map: GridTransitionMap,
start: IntVector2D, start: IntVector2D,
...@@ -26,9 +26,20 @@ def connect_basic_operation( ...@@ -26,9 +26,20 @@ def connect_basic_operation(
forbidden_cells=None forbidden_cells=None
) -> IntVector2DArray: ) -> IntVector2DArray:
""" """
Creates a new path [start,end] in `grid_map.grid`, based on rail_trans, and Creates a new path [start,end] in `grid_map.grid`, based on rail_trans, and
returns the path created as a list of positions. returns the path created as a list of positions.
:param rail_trans:
:param grid_map:
:param start:
:param end:
:param flip_start_node_trans:
:param flip_end_node_trans:
:param nice:
:param a_star_distance_function:
:param forbidden_cells:
:return:
""" """
# in the worst case we will need to do a A* search, so we might as well set that up # in the worst case we will need to do a A* search, so we might as well set that up
path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function, nice, forbidden_cells) path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function, nice, forbidden_cells)
if len(path) < 2: if len(path) < 2:
...@@ -79,7 +90,7 @@ def connect_basic_operation( ...@@ -79,7 +90,7 @@ def connect_basic_operation(
return path return path
def connect_line(rail_trans, grid_map, start, end, openend=False): def connect_straigt_line(rail_trans, grid_map, start, end, openend=False):
""" """
Generates a straight rail line from start cell to end cell. Generates a straight rail line from start cell to end cell.
Diagonal lines are not allowed Diagonal lines are not allowed
...@@ -126,47 +137,3 @@ def connect_line(rail_trans, grid_map, start, end, openend=False): ...@@ -126,47 +137,3 @@ def connect_line(rail_trans, grid_map, start, end, openend=False):
path.append(current_cell) path.append(current_cell)
return path return path
def connect_rail(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
start: IntVector2D, end: IntVector2D,
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
"""
:param rail_trans:
:param grid_map:
:param start:
:param end:
:param a_star_distance_function:
:return:
"""
return connect_basic_operation(rail_trans, grid_map, start, end, True, True, True, a_star_distance_function)
def connect_cities(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
start: IntVector2D, end: IntVector2D,
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance,
forbidden_cells=None) -> IntVector2DArray:
"""
:param rail_trans:
:param grid_map:
:param start:
:param end:
:param forbidden_cells:
:param a_star_distance_function:
:return:
"""
return connect_basic_operation(rail_trans, grid_map, start, end, False, False, False, a_star_distance_function,
forbidden_cells)
def connect_straigt_line(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D,
end: IntVector2D, openend=False) -> IntVector2DArray:
"""
:param rail_trans:
:param grid_map:
:param start:
:param end:
:param openend:
:return:
"""
return connect_line(rail_trans, grid_map, start, end, openend)
...@@ -11,7 +11,7 @@ from flatland.core.grid.grid_utils import distance_on_rail, direction_to_city, I ...@@ -11,7 +11,7 @@ from flatland.core.grid.grid_utils import distance_on_rail, direction_to_city, I
Vec2dOperations Vec2dOperations
from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap from flatland.core.transition_map import GridTransitionMap
from flatland.envs.grid4_generators_utils import connect_rail, connect_cities, connect_straigt_line from flatland.envs.grid4_generators_utils import connect_rail, connect_straigt_line
RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]] RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]]
RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct] RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
...@@ -125,7 +125,9 @@ def complex_rail_generator(nr_start_goal=1, ...@@ -125,7 +125,9 @@ def complex_rail_generator(nr_start_goal=1,
# we might as well give up at this point # we might as well give up at this point
break break
new_path = connect_rail(rail_trans, grid_map, start, goal) new_path = connect_rail(rail_trans, grid_map, start, goal, flip_start_node_trans=True,
flip_end_node_trans=True, nice=True,
forbidden_cells=None)
if len(new_path) >= 2: if len(new_path) >= 2:
nr_created += 1 nr_created += 1
start_goal.append([start, goal]) start_goal.append([start, goal])
...@@ -150,7 +152,10 @@ def complex_rail_generator(nr_start_goal=1, ...@@ -150,7 +152,10 @@ def complex_rail_generator(nr_start_goal=1,
break break
if not all_ok: if not all_ok:
break break
new_path = connect_rail(rail_trans, grid_map, start, goal) new_path = connect_rail(rail_trans, grid_map, start, goal, flip_start_node_trans=True,
flip_end_node_trans=True, nice=True,
forbidden_cells=None)
if len(new_path) >= 2: if len(new_path) >= 2:
nr_created += 1 nr_created += 1
...@@ -735,9 +740,10 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ ...@@ -735,9 +740,10 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
min_connection_dist = tmp_dist min_connection_dist = tmp_dist
neighb_connection_point = tmp_in_connection_point neighb_connection_point = tmp_in_connection_point
neighbour_direction = dir neighbour_direction = dir
new_line = connect_cities(rail_trans, grid_map, tmp_out_connection_point, new_line = connect_rail(rail_trans, grid_map, tmp_out_connection_point,
neighb_connection_point, neighb_connection_point, flip_start_node_trans=False,
forbidden_cells=city_cells) flip_end_node_trans=False, nice=True,
forbidden_cells=city_cells)
all_paths.extend(new_line) all_paths.extend(new_line)
return all_paths return all_paths
......
...@@ -3,7 +3,7 @@ import numpy as np ...@@ -3,7 +3,7 @@ import numpy as np
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap from flatland.core.transition_map import GridTransitionMap
from flatland.envs.grid4_generators_utils import connect_rail, connect_cities from flatland.envs.grid4_generators_utils import connect_rail
def test_build_railway_infrastructure(): def test_build_railway_infrastructure():
...@@ -20,17 +20,17 @@ def test_build_railway_infrastructure(): ...@@ -20,17 +20,17 @@ def test_build_railway_infrastructure():
start_point = (1, 3) start_point = (1, 3)
end_point = (1, 7) end_point = (1, 7)
connection_002 = connect_cities(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) connection_002 = connect_rail(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance)
connection_002_expected = [(1, 3), (1, 4), (1, 5), (1, 6), (1, 7)] connection_002_expected = [(1, 3), (1, 4), (1, 5), (1, 6), (1, 7)]
start_point = (6, 2) start_point = (6, 2)
end_point = (6, 5) end_point = (6, 5)
connection_003 = connect_cities(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) connection_003 = connect_rail(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance)
connection_003_expected = [(6, 2), (6, 3), (6, 4), (6, 5)] connection_003_expected = [(6, 2), (6, 3), (6, 4), (6, 5)]
start_point = (7, 5) start_point = (7, 5)
end_point = (8, 9) end_point = (8, 9)
connection_004 = connect_cities(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) connection_004 = connect_rail(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance)
connection_004_expected = [(7, 5), (7, 6), (7, 7), (7, 8), (7, 9), (8, 9)] connection_004_expected = [(7, 5), (7, 6), (7, 7), (7, 8), (7, 9), (8, 9)]
assert connection_001 == connection_001_expected, \ assert connection_001 == connection_001_expected, \
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment