diff --git a/flatland/core/grid/grid4_utils.py b/flatland/core/grid/grid4_utils.py index 98652459d7a7ac7f1694ac53fe1d0a12880ab8b2..75cef7b4d3aea783140a5c08c3498a0bc321fb62 100644 --- a/flatland/core/grid/grid4_utils.py +++ b/flatland/core/grid/grid4_utils.py @@ -1,8 +1,8 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum -from flatland.core.grid.grid_utils import IntVector2DArray +from flatland.core.grid.grid_utils import IntVector2D -def get_direction(pos1: IntVector2DArray, pos2: IntVector2DArray) -> Grid4TransitionsEnum: +def get_direction(pos1: IntVector2D, pos2: IntVector2D) -> Grid4TransitionsEnum: """ Assumes pos1 and pos2 are adjacent location on grid. Returns direction (int) that can be used with transitions. @@ -10,13 +10,13 @@ def get_direction(pos1: IntVector2DArray, pos2: IntVector2DArray) -> Grid4Transi diff_0 = pos2[0] - pos1[0] diff_1 = pos2[1] - pos1[1] if diff_0 < 0: - return 0 + return Grid4TransitionsEnum.NORTH if diff_0 > 0: - return 2 + return Grid4TransitionsEnum.SOUTH if diff_1 > 0: - return 1 + return Grid4TransitionsEnum.EAST if diff_1 < 0: - return 3 + return Grid4TransitionsEnum.WEST raise Exception("Could not determine direction {}->{}".format(pos1, pos2)) diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index d6f47abfd85cfa1cc7e72e27aeb4f7ededa975dd..1dcd2a31f0f24686621c9d114cc32547326e72df 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -7,22 +7,24 @@ a GridTransitionMap object. from flatland.core.grid.grid4_astar import a_star from flatland.core.grid.grid4_utils import get_direction, mirror -from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance +from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance, IntVector2DArray from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d from flatland.core.transition_map import GridTransitionMap, RailEnvTransitions -def connect_basic_operation(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, - start: IntVector2D, - end: IntVector2D, - flip_start_node_trans=False, - flip_end_node_trans=False, - a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance): +def connect_basic_operation( + rail_trans: RailEnvTransitions, + grid_map: GridTransitionMap, + start: IntVector2D, + end: IntVector2D, + flip_start_node_trans=False, + flip_end_node_trans=False, + a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray: """ - Creates a new path [start,end] in grid_map, based on rail_trans. + Creates a new path [start,end] in `grid_map.grid`, based on rail_trans. """ # in the worst case we will need to do a A* search, so we might as well set that up - path = a_star(grid_map, start, end, a_star_distance_function) + path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function) if len(path) < 2: return [] current_dir = get_direction(path[0], path[1]) @@ -71,23 +73,24 @@ def connect_basic_operation(rail_trans: RailEnvTransitions, grid_map: GridTransi def connect_rail(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D, - a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance): + a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray: return connect_basic_operation(rail_trans, grid_map, start, end, True, True, a_star_distance_function) def connect_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D, - a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance): + a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray: return connect_basic_operation(rail_trans, grid_map, start, end, False, False, a_star_distance_function) def connect_from_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D, - a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance): + a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance + ) -> IntVector2DArray: return connect_basic_operation(rail_trans, grid_map, start, end, False, True, a_star_distance_function) def connect_to_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D, - a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance): + a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray: return connect_basic_operation(rail_trans, grid_map, start, end, True, False, a_star_distance_function) diff --git a/tests/test_flatland_envs_env_utils.py b/tests/test_flatland_envs_env_utils.py index b95922cf67febdaa0aad396459bc446bc31adfea..cf5c8592708eef237bcf29308032df49753860bd 100644 --- a/tests/test_flatland_envs_env_utils.py +++ b/tests/test_flatland_envs_env_utils.py @@ -2,8 +2,8 @@ import numpy as np import pytest from flatland.core.grid.grid4 import Grid4TransitionsEnum -from flatland.core.grid.grid_utils import position_to_coordinate, coordinate_to_position from flatland.core.grid.grid4_utils import get_direction +from flatland.core.grid.grid_utils import position_to_coordinate, coordinate_to_position depth_to_test = 5 positions_to_test = [0, 5, 1, 6, 20, 30] @@ -31,4 +31,4 @@ def test_get_direction(): assert get_direction((1, 0), (0, 0)) == Grid4TransitionsEnum.NORTH assert get_direction((1, 0), (0, 0)) == Grid4TransitionsEnum.NORTH with pytest.raises(Exception, match="Could not determine direction"): - get_direction((0, 0), (0, 0)) == Grid4TransitionsEnum.NORTH + get_direction((0, 0), (0, 0))