From 43ffa6304159cde4b817ea78cf2e01b83473ae22 Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Thu, 3 Oct 2019 15:57:32 -0400 Subject: [PATCH] fixed failing connection tests --- flatland/envs/grid4_generators_utils.py | 20 ++++++----------- flatland/envs/rail_generators.py | 17 ++++++-------- ...est_flatland_core_grid4_generators_util.py | 22 ++++++++++++------- 3 files changed, 28 insertions(+), 31 deletions(-) diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index a332cd70..bc43f90e 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -14,17 +14,10 @@ from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d from flatland.core.transition_map import GridTransitionMap, RailEnvTransitions -def connect_rail( - rail_trans: RailEnvTransitions, - grid_map: GridTransitionMap, - start: IntVector2D, - end: IntVector2D, - a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, - flip_start_node_trans=False, - flip_end_node_trans=False, - nice=True, - forbidden_cells=None -) -> IntVector2DArray: +def connect_rail(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D, + a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, + flip_start_node_trans=False, flip_end_node_trans=False, respect_transition_validity=True, + forbidden_cells=None) -> IntVector2DArray: """ Creates a new path [start,end] in `grid_map.grid`, based on rail_trans, and returns the path created as a list of positions. @@ -34,14 +27,15 @@ def connect_rail( :param end: :param flip_start_node_trans: :param flip_end_node_trans: - :param nice: + :param respect_transition_validity: :param a_star_distance_function: :param forbidden_cells: :return: """ # in the worst case we will need to do a A* search, so we might as well set that up - path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function, nice, forbidden_cells) + path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function, respect_transition_validity, + forbidden_cells) if len(path) < 2: print("No path found", path) return [] diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 557f25e4..e6b9cf69 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -127,9 +127,8 @@ def complex_rail_generator(nr_start_goal=1, break new_path = connect_rail(rail_trans, grid_map, start, goal, Vec2d.get_chebyshev_distance, - flip_start_node_trans=True, - flip_end_node_trans=True, nice=True, - forbidden_cells=None) + flip_start_node_trans=True, flip_end_node_trans=True, + respect_transition_validity=True, forbidden_cells=None) if len(new_path) >= 2: nr_created += 1 start_goal.append([start, goal]) @@ -155,9 +154,8 @@ def complex_rail_generator(nr_start_goal=1, if not all_ok: break new_path = connect_rail(rail_trans, grid_map, start, goal, Vec2d.get_chebyshev_distance, - flip_start_node_trans=True, - flip_end_node_trans=True, nice=True, - forbidden_cells=None) + flip_start_node_trans=True, flip_end_node_trans=True, + respect_transition_validity=True, forbidden_cells=None) if len(new_path) >= 2: nr_created += 1 @@ -743,10 +741,9 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ min_connection_dist = tmp_dist neighb_connection_point = tmp_in_connection_point neighbour_direction = dir - new_line = connect_rail(rail_trans, grid_map, tmp_out_connection_point, - neighb_connection_point, flip_start_node_trans=False, - flip_end_node_trans=False, nice=True, - forbidden_cells=city_cells) + new_line = connect_rail(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point, + flip_start_node_trans=False, flip_end_node_trans=False, + respect_transition_validity=False, forbidden_cells=city_cells) all_paths.extend(new_line) return all_paths diff --git a/tests/test_flatland_core_grid4_generators_util.py b/tests/test_flatland_core_grid4_generators_util.py index ba71daac..9722063e 100644 --- a/tests/test_flatland_core_grid4_generators_util.py +++ b/tests/test_flatland_core_grid4_generators_util.py @@ -1,6 +1,5 @@ import numpy as np -from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap from flatland.envs.grid4_generators_utils import connect_rail @@ -14,23 +13,31 @@ def test_build_railway_infrastructure(): start_point = (2, 2) end_point = (8, 8) - connection_001 = connect_rail(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) + connection_001 = connect_rail(rail_trans, grid_map, start_point, end_point, + flip_start_node_trans=True, flip_end_node_trans=True, + respect_transition_validity=True, forbidden_cells=None) connection_001_expected = [(2, 2), (2, 3), (2, 4), (2, 5), (2, 6), (2, 7), (2, 8), (3, 8), (4, 8), (5, 8), (6, 8), (7, 8), (8, 8)] start_point = (1, 3) end_point = (1, 7) - connection_002 = connect_rail(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) + connection_002 = connect_rail(rail_trans, grid_map, start_point, end_point, + flip_start_node_trans=False, flip_end_node_trans=False, + respect_transition_validity=True, forbidden_cells=None) connection_002_expected = [(1, 3), (1, 4), (1, 5), (1, 6), (1, 7)] start_point = (6, 2) end_point = (6, 5) - connection_003 = connect_rail(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) + connection_003 = connect_rail(rail_trans, grid_map, start_point, end_point, + flip_start_node_trans=False, flip_end_node_trans=True, + respect_transition_validity=True, forbidden_cells=None) connection_003_expected = [(6, 2), (6, 3), (6, 4), (6, 5)] start_point = (7, 5) end_point = (8, 9) - connection_004 = connect_rail(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) + connection_004 = connect_rail(rail_trans, grid_map, start_point, end_point, + flip_start_node_trans=True, flip_end_node_trans=False, + respect_transition_validity=True, forbidden_cells=None) connection_004_expected = [(7, 5), (7, 6), (7, 7), (7, 8), (7, 9), (8, 9)] assert connection_001 == connection_001_expected, \ @@ -64,6 +71,5 @@ def test_build_railway_infrastructure(): [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] ] - - assert np.all(grid_map.grid == grid_map_grid_expected), \ - "actual={}, expected={}".format(grid_map.grid, grid_map_grid_expected) + for i in range(len(grid_map_grid_expected)): + assert np.all(grid_map.grid[i] == grid_map_grid_expected[i]) -- GitLab