From af5ed9877178e83285001fed8baa44d90f8d601e Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Thu, 3 Oct 2019 15:35:15 -0400 Subject: [PATCH] refactoring of rail connection functions --- flatland/envs/grid4_generators_utils.py | 61 +++++-------------- flatland/envs/rail_generators.py | 18 ++++-- ...est_flatland_core_grid4_generators_util.py | 8 +-- 3 files changed, 30 insertions(+), 57 deletions(-) diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index 77b54e54..dfd9218f 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -14,7 +14,7 @@ from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d from flatland.core.transition_map import GridTransitionMap, RailEnvTransitions -def connect_basic_operation( +def connect_rail( rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D, @@ -26,9 +26,20 @@ def connect_basic_operation( forbidden_cells=None ) -> IntVector2DArray: """ - Creates a new path [start,end] in `grid_map.grid`, based on rail_trans, and + Creates a new path [start,end] in `grid_map.grid`, based on rail_trans, and returns the path created as a list of positions. + :param rail_trans: + :param grid_map: + :param start: + :param end: + :param flip_start_node_trans: + :param flip_end_node_trans: + :param nice: + :param a_star_distance_function: + :param forbidden_cells: + :return: """ + # in the worst case we will need to do a A* search, so we might as well set that up path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function, nice, forbidden_cells) if len(path) < 2: @@ -79,7 +90,7 @@ def connect_basic_operation( return path -def connect_line(rail_trans, grid_map, start, end, openend=False): +def connect_straigt_line(rail_trans, grid_map, start, end, openend=False): """ Generates a straight rail line from start cell to end cell. Diagonal lines are not allowed @@ -126,47 +137,3 @@ def connect_line(rail_trans, grid_map, start, end, openend=False): path.append(current_cell) return path -def connect_rail(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, - start: IntVector2D, end: IntVector2D, - a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray: - """ - - :param rail_trans: - :param grid_map: - :param start: - :param end: - :param a_star_distance_function: - :return: - """ - return connect_basic_operation(rail_trans, grid_map, start, end, True, True, True, a_star_distance_function) - - -def connect_cities(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, - start: IntVector2D, end: IntVector2D, - a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, - forbidden_cells=None) -> IntVector2DArray: - """ - - :param rail_trans: - :param grid_map: - :param start: - :param end: - :param forbidden_cells: - :param a_star_distance_function: - :return: - """ - return connect_basic_operation(rail_trans, grid_map, start, end, False, False, False, a_star_distance_function, - forbidden_cells) - -def connect_straigt_line(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D, - end: IntVector2D, openend=False) -> IntVector2DArray: - """ - - :param rail_trans: - :param grid_map: - :param start: - :param end: - :param openend: - :return: - """ - return connect_line(rail_trans, grid_map, start, end, openend) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 7d7174d3..ee31666c 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -11,7 +11,7 @@ from flatland.core.grid.grid_utils import distance_on_rail, direction_to_city, I Vec2dOperations from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap -from flatland.envs.grid4_generators_utils import connect_rail, connect_cities, connect_straigt_line +from flatland.envs.grid4_generators_utils import connect_rail, connect_straigt_line RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]] RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct] @@ -125,7 +125,9 @@ def complex_rail_generator(nr_start_goal=1, # we might as well give up at this point break - new_path = connect_rail(rail_trans, grid_map, start, goal) + new_path = connect_rail(rail_trans, grid_map, start, goal, flip_start_node_trans=True, + flip_end_node_trans=True, nice=True, + forbidden_cells=None) if len(new_path) >= 2: nr_created += 1 start_goal.append([start, goal]) @@ -150,7 +152,10 @@ def complex_rail_generator(nr_start_goal=1, break if not all_ok: break - new_path = connect_rail(rail_trans, grid_map, start, goal) + new_path = connect_rail(rail_trans, grid_map, start, goal, flip_start_node_trans=True, + flip_end_node_trans=True, nice=True, + forbidden_cells=None) + if len(new_path) >= 2: nr_created += 1 @@ -735,9 +740,10 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ min_connection_dist = tmp_dist neighb_connection_point = tmp_in_connection_point neighbour_direction = dir - new_line = connect_cities(rail_trans, grid_map, tmp_out_connection_point, - neighb_connection_point, - forbidden_cells=city_cells) + new_line = connect_rail(rail_trans, grid_map, tmp_out_connection_point, + neighb_connection_point, flip_start_node_trans=False, + flip_end_node_trans=False, nice=True, + forbidden_cells=city_cells) all_paths.extend(new_line) return all_paths diff --git a/tests/test_flatland_core_grid4_generators_util.py b/tests/test_flatland_core_grid4_generators_util.py index 64c03f3c..ba71daac 100644 --- a/tests/test_flatland_core_grid4_generators_util.py +++ b/tests/test_flatland_core_grid4_generators_util.py @@ -3,7 +3,7 @@ import numpy as np from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap -from flatland.envs.grid4_generators_utils import connect_rail, connect_cities +from flatland.envs.grid4_generators_utils import connect_rail def test_build_railway_infrastructure(): @@ -20,17 +20,17 @@ def test_build_railway_infrastructure(): start_point = (1, 3) end_point = (1, 7) - connection_002 = connect_cities(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) + connection_002 = connect_rail(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) connection_002_expected = [(1, 3), (1, 4), (1, 5), (1, 6), (1, 7)] start_point = (6, 2) end_point = (6, 5) - connection_003 = connect_cities(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) + connection_003 = connect_rail(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) connection_003_expected = [(6, 2), (6, 3), (6, 4), (6, 5)] start_point = (7, 5) end_point = (8, 9) - connection_004 = connect_cities(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) + connection_004 = connect_rail(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) connection_004_expected = [(7, 5), (7, 6), (7, 7), (7, 8), (7, 9), (8, 9)] assert connection_001 == connection_001_expected, \ -- GitLab