From 275e66912c38863432e26abcaffda990d264a515 Mon Sep 17 00:00:00 2001 From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch> Date: Thu, 19 Sep 2019 07:48:35 +0200 Subject: [PATCH] test infrastructure generator --- ...est_flatland_core_grid4_generators_util.py | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 tests/test_flatland_core_grid4_generators_util.py diff --git a/tests/test_flatland_core_grid4_generators_util.py b/tests/test_flatland_core_grid4_generators_util.py new file mode 100644 index 00000000..9e85d522 --- /dev/null +++ b/tests/test_flatland_core_grid4_generators_util.py @@ -0,0 +1,62 @@ +import numpy as np + +from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d +from flatland.core.grid.rail_env_grid import RailEnvTransitions +from flatland.core.transition_map import GridTransitionMap +from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes, connect_from_nodes, connect_to_nodes + + +def test_build_railway_infrastructure(): + rail_trans = RailEnvTransitions() + grid_map = GridTransitionMap(width=20, height=20, transitions=rail_trans) + grid_map.grid.fill(0) + np.random.seed(0) + + start_point = (2, 2) + end_point = (8, 8) + connection_001 = connect_rail(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) + connection_001_expected = [(2, 2), (3, 2), (3, 3), (4, 3), (4, 4), (5, 4), (5, 5), (5, 6), (6, 6), (6, 7), (7, 7), + (8, 7), (8, 8)] + + start_point = (1, 3) + end_point = (1, 7) + connection_002 = connect_nodes(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) + connection_002_expected = [(1, 3), (1, 4), (1, 5), (1, 6), (1, 7)] + + start_point = (6, 2) + end_point = (6, 5) + connection_003 = connect_from_nodes(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) + connection_003_expected = [(6, 2), (6, 3), (6, 4), (6, 5)] + + start_point = (7, 5) + end_point = (8, 9) + connection_004 = connect_to_nodes(rail_trans, grid_map, start_point, end_point, Vec2d.get_manhattan_distance) + connection_004_expected = [(7, 5), (7, 6), (7, 7), (7, 8), (7, 9), (8, 9)] + + assert connection_001 == connection_001_expected + assert connection_002 == connection_002_expected + assert connection_003 == connection_003_expected + assert connection_004 == connection_004_expected + + grid_map_grid_expected = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1025, 1025, 1025, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 8192, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 72, 4608, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 72, 4608, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 72, 1025, 4608, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1025, 1025, 256, 72, 4608, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 4, 1025, 33825, 1025, 4608, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 72, 256, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] + + assert np.all(grid_map.grid == grid_map_grid_expected) -- GitLab