diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index 042e3c5f0dfc5d5f55824ee2f30f81d6d7f774bf..77b54e54a4bf9c5dda9d94a6e68deb83c8a66bc1 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -49,7 +49,7 @@ def connect_basic_operation( # need to flip direction because of how end points are defined new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) else: - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) # 0 + new_trans = 0 else: # into existing rail new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) @@ -69,7 +69,7 @@ def connect_basic_operation( if flip_end_node_trans: new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) else: - new_trans_e = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) #0 + new_trans_e = 0 else: # into existing rail new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) @@ -93,7 +93,7 @@ def connect_line(rail_trans, grid_map, start, end, openend=False): # Assert that a straight line is possible if not (start[0] == end[0] or start[1] == end[1]): - print("No line possible") + print("No straight line possible!") return [] current_cell = start path = [current_cell] @@ -129,34 +129,44 @@ def connect_line(rail_trans, grid_map, start, end, openend=False): def connect_rail(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D, a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray: - return connect_basic_operation(rail_trans, grid_map, start, end, True, True, a_star_distance_function) - + """ -def connect_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, - start: IntVector2D, end: IntVector2D, - a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray: - return connect_basic_operation(rail_trans, grid_map, start, end, False, False, False, a_star_distance_function) + :param rail_trans: + :param grid_map: + :param start: + :param end: + :param a_star_distance_function: + :return: + """ + return connect_basic_operation(rail_trans, grid_map, start, end, True, True, True, a_star_distance_function) def connect_cities(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, - start: IntVector2D, end: IntVector2D, forbidden_cells=None, - a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray: + start: IntVector2D, end: IntVector2D, + a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, + forbidden_cells=None) -> IntVector2DArray: + """ + + :param rail_trans: + :param grid_map: + :param start: + :param end: + :param forbidden_cells: + :param a_star_distance_function: + :return: + """ return connect_basic_operation(rail_trans, grid_map, start, end, False, False, False, a_star_distance_function, forbidden_cells) -def connect_from_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, - start: IntVector2D, end: IntVector2D, - a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance - ) -> IntVector2DArray: - return connect_basic_operation(rail_trans, grid_map, start, end, False, True, a_star_distance_function) - - -def connect_to_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, - start: IntVector2D, end: IntVector2D, - a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray: - return connect_basic_operation(rail_trans, grid_map, start, end, True, False, a_star_distance_function) - - def connect_straigt_line(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D, openend=False) -> IntVector2DArray: + """ + + :param rail_trans: + :param grid_map: + :param start: + :param end: + :param openend: + :return: + """ return connect_line(rail_trans, grid_map, start, end, openend) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 7bd22cbfa8b43f4b42efb35c0da4ce5f7dc2e7aa..7d7174d3d95334002c890fb7ee0176745266010a 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -737,7 +737,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ neighbour_direction = dir new_line = connect_cities(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point, - city_cells) + forbidden_cells=city_cells) all_paths.extend(new_line) return all_paths diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index 3a29b64859b6de6d26b9a9deae8e17095d9d10c1..377e1ccb65956a8f70e7223ad37779f43e2ccee4 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -8,9 +8,6 @@ from PIL import Image, ImageDraw, ImageTk, ImageFont from numpy import array from pkg_resources import resource_string as resource_bytes -from flatland.core.grid.grid_utils import Vec2dOperations -from flatland.core.transition_map import GridTransitionMap -from flatland.envs.grid4_generators_utils import connect_nodes from flatland.utils.graphics_layer import GraphicsLayer diff --git a/tests/test_flatland_core_grid4_generators_util.py b/tests/test_flatland_core_grid4_generators_util.py index 72deddc66eacc71a5aa840b49225e5ba056a8b84..64c03f3cf5516c17de8f6051fe731ce27b5c5b65 100644 --- a/tests/test_flatland_core_grid4_generators_util.py +++ b/tests/test_flatland_core_grid4_generators_util.py @@ -3,7 +3,7 @@ import numpy as np from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap -from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes, connect_from_nodes, connect_to_nodes +from flatland.envs.grid4_generators_utils import connect_rail, connect_cities def test_build_railway_infrastructure(): @@ -20,17 +20,17 @@ def test_build_railway_infrastructure(): start_point = (1, 3) end_point = (1, 7) - connection_002 = connect_nodes(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) + connection_002 = connect_cities(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) connection_002_expected = [(1, 3), (1, 4), (1, 5), (1, 6), (1, 7)] start_point = (6, 2) end_point = (6, 5) - connection_003 = connect_from_nodes(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) + connection_003 = connect_cities(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) connection_003_expected = [(6, 2), (6, 3), (6, 4), (6, 5)] start_point = (7, 5) end_point = (8, 9) - connection_004 = connect_to_nodes(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) + connection_004 = connect_cities(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) connection_004_expected = [(7, 5), (7, 6), (7, 7), (7, 8), (7, 9), (8, 9)] assert connection_001 == connection_001_expected, \