diff --git a/flatland/core/grid/grid4.py b/flatland/core/grid/grid4.py index da721dd99fbc0ea46124dcd036425df3206aa339..2d7284dc74cf2411ee2f0466d0d653bffb2132dc 100644 --- a/flatland/core/grid/grid4.py +++ b/flatland/core/grid/grid4.py @@ -237,3 +237,6 @@ class Grid4Transitions(Transitions): """ cell_transition &= cell_transition & (~self.maskDeadEnds) & 0xffff return cell_transition + + def get_entry_directions(self, cell_transition): + return [(cell_transition >> ((3 - orientation) * 4)) & 15 > 0 for orientation in range(4)] diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py index 1137b8816973c12601029543c221810c9acd157c..4da1da4d23cc98ec7032530ee51820f29b637c17 100644 --- a/tests/test_flatland_core_transition_map.py +++ b/tests/test_flatland_core_transition_map.py @@ -1,5 +1,6 @@ from flatland.core.grid.grid4 import Grid4Transitions, Grid4TransitionsEnum from flatland.core.grid.grid8 import Grid8Transitions, Grid8TransitionsEnum +from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv @@ -158,3 +159,29 @@ def test_path_not_exists(rendering=False): renderer = RenderTool(env, gl="PILSVG") renderer.render_env(show=True, show_observations=False) input("Continue?") + + +def test_get_entry_directions(): + transitions = RailEnvTransitions() + cells = transitions.transition_list + vertical_line = cells[1] + south_symmetrical_switch = cells[6] + north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180) + + # Simple turn not in the base transitions ? + south_east_turn = int('0100000000000010', 2) + south_west_turn = transitions.rotate_transition(south_east_turn, 90) + north_east_turn = transitions.rotate_transition(south_east_turn, 270) + north_west_turn = transitions.rotate_transition(south_east_turn, 180) + + def _assert(transition, expected): + actual = transitions.get_entry_directions(transition) + assert actual == expected, "Found {}, expected {}.".format(actual, expected) + + _assert(south_east_turn, [True, False, False, True]) + _assert(south_west_turn, [True, True, False, False]) + _assert(north_east_turn, [False, False, True, True]) + _assert(north_west_turn, [False, True, True, False]) + _assert(vertical_line, [True, False, True, False]) + _assert(south_symmetrical_switch, [True, True, False, True]) + _assert(north_symmetrical_switch, [False, True, True, True])