From 50a24136360c3dd8080f0c2e8fb66da90b41dcff Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Tue, 24 Sep 2019 09:25:49 +0200 Subject: [PATCH] utility for getting entry directions of a cell --- flatland/core/grid/grid4.py | 3 +++ tests/test_flatland_core_transition_map.py | 27 ++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/flatland/core/grid/grid4.py b/flatland/core/grid/grid4.py index da721dd9..2d7284dc 100644 --- a/flatland/core/grid/grid4.py +++ b/flatland/core/grid/grid4.py @@ -237,3 +237,6 @@ class Grid4Transitions(Transitions): """ cell_transition &= cell_transition & (~self.maskDeadEnds) & 0xffff return cell_transition + + def get_entry_directions(self, cell_transition): + return [(cell_transition >> ((3 - orientation) * 4)) & 15 > 0 for orientation in range(4)] diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py index 1137b881..4da1da4d 100644 --- a/tests/test_flatland_core_transition_map.py +++ b/tests/test_flatland_core_transition_map.py @@ -1,5 +1,6 @@ from flatland.core.grid.grid4 import Grid4Transitions, Grid4TransitionsEnum from flatland.core.grid.grid8 import Grid8Transitions, Grid8TransitionsEnum +from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv @@ -158,3 +159,29 @@ def test_path_not_exists(rendering=False): renderer = RenderTool(env, gl="PILSVG") renderer.render_env(show=True, show_observations=False) input("Continue?") + + +def test_get_entry_directions(): + transitions = RailEnvTransitions() + cells = transitions.transition_list + vertical_line = cells[1] + south_symmetrical_switch = cells[6] + north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180) + + # Simple turn not in the base transitions ? + south_east_turn = int('0100000000000010', 2) + south_west_turn = transitions.rotate_transition(south_east_turn, 90) + north_east_turn = transitions.rotate_transition(south_east_turn, 270) + north_west_turn = transitions.rotate_transition(south_east_turn, 180) + + def _assert(transition, expected): + actual = transitions.get_entry_directions(transition) + assert actual == expected, "Found {}, expected {}.".format(actual, expected) + + _assert(south_east_turn, [True, False, False, True]) + _assert(south_west_turn, [True, True, False, False]) + _assert(north_east_turn, [False, False, True, True]) + _assert(north_west_turn, [False, True, True, False]) + _assert(vertical_line, [True, False, True, False]) + _assert(south_symmetrical_switch, [True, True, False, True]) + _assert(north_symmetrical_switch, [False, True, True, True]) -- GitLab