diff --git a/flatland/core/grid/grid4.py b/flatland/core/grid/grid4.py index 3febb521d336f0668d4bc3843dcea93837b9699a..5c09f0ac8ba86ed7987aefcf92a541f2ea5d1de4 100644 --- a/flatland/core/grid/grid4.py +++ b/flatland/core/grid/grid4.py @@ -49,6 +49,9 @@ class Grid4Transitions(Transitions): # row,col delta for each direction self.gDir2dRC = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]]) + # These bits represent all the possible dead ends + self.maskDeadEnds = 0b0010000110000100 + def get_type(self): return np.uint16 @@ -210,3 +213,19 @@ class Grid4Transitions(Transitions): def get_direction_enum(self) -> IntEnum: return Grid4TransitionsEnum + + def has_deadend(self, cell_transition): + """ + Checks if one entry can only by exited by a turn-around. + """ + if cell_transition & self.maskDeadEnds > 0: + return True + else: + return False + + def remove_deadends(self, cell_transition): + """ + Remove all turn-arounds (e.g. N-S, S-N, E-W,...). + """ + cell_transition &= cell_transition & (~self.maskDeadEnds) & 0xffff + return cell_transition diff --git a/flatland/core/grid/rail_env_grid.py b/flatland/core/grid/rail_env_grid.py index efb5ea15744928aa3f3d5543524ccc3c07a18626..c043b42f1ca84ba9d0f7a68f5e18a192ff374d7a 100644 --- a/flatland/core/grid/rail_env_grid.py +++ b/flatland/core/grid/rail_env_grid.py @@ -43,9 +43,6 @@ class RailEnvTransitions(Grid4Transitions): transitions=self.transition_list ) - # These bits represent all the possible dead ends - self.maskDeadEnds = 0b0010000110000100 - # create this to make validation faster self.transitions_all = set() for index, trans in enumerate(self.transitions): @@ -112,13 +109,3 @@ class RailEnvTransitions(Grid4Transitions): True or False """ return cell_transition in self.transitions_all - - def has_deadend(self, cell_transition): - if cell_transition & self.maskDeadEnds > 0: - return True - else: - return False - - def remove_deadends(self, cell_transition): - cell_transition &= cell_transition & (~self.maskDeadEnds) & 0xffff - return cell_transition diff --git a/tests/test_flatland_core_transitions.py b/tests/test_flatland_core_transitions.py index 4e69acfd5a10fb37bd1d4fad86662c3bbb1f0d57..048520c17eeddf4d1f4a4c6beeb49427887b77f4 100644 --- a/tests/test_flatland_core_transitions.py +++ b/tests/test_flatland_core_transitions.py @@ -4,18 +4,20 @@ """Tests for `flatland` package.""" import numpy as np +from flatland.core.grid.grid4 import Grid4Transitions from flatland.core.grid.grid8 import Grid8Transitions from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.envs.env_utils import validate_new_transition +# remove whitespace in string; keep whitespace below for easier reading +def rw(s): + return s.replace(" ", "") + + def test_rotate_railenv_transition(): rail_env_transitions = RailEnvTransitions() - # remove whitespace in string; keep whitespace below for easier reading - def rw(s): - return s.replace(" ", "") - # TODO test all cases transition_cycles = [ # empty cell - Case 0 @@ -68,9 +70,17 @@ def test_rotate_railenv_transition(): # int('1100110000110011', 2), \ # noqa: E800 # Case 6 - symmetrical # int('0101001000000010', 2), \ # noqa: E800 - # Case 7 - dead end - # int('0010000000000000', 2), \ # noqa: E800 + # Case 7 - dead end + # + # + # | + [ + int(rw('0010 0000 0000 0000'), 2), + int(rw('0000 0001 0000 0000'), 2), + int(rw('0000 0000 1000 0000'), 2), + int(rw('0000 0000 0000 0100'), 2), + ], ] for index, cycle in enumerate(transition_cycles): @@ -206,3 +216,33 @@ def test_diagonal_transitions(): assert (diagonal_trans_env.rotate_transition( south_northeast_transition, 180) == north_southwest_transition) + + +def test_rail_env_has_deadend(): + deadends = set([int(rw('0010 0000 0000 0000'), 2), + int(rw('0000 0001 0000 0000'), 2), + int(rw('0000 0000 1000 0000'), 2), + int(rw('0000 0000 0000 0100'), 2)]) + ret = RailEnvTransitions() + transitions_all = ret.transitions_all + for t in transitions_all: + expected_has_deadend = t in deadends + actual_had_deadend = ret.has_deadend(t) + assert actual_had_deadend == expected_has_deadend, \ + "{} should be deadend = {}, actual = {}".format(t, ) + + +def test_rail_env_remove_deadend(): + ret = Grid4Transitions([]) + rail_env_deadends = set([int(rw('0010 0000 0000 0000'), 2), + int(rw('0000 0001 0000 0000'), 2), + int(rw('0000 0000 1000 0000'), 2), + int(rw('0000 0000 0000 0100'), 2)]) + for t in rail_env_deadends: + expected_has_deadend = 0 + actual_had_deadend = ret.remove_deadends(t) + assert actual_had_deadend == expected_has_deadend, \ + "{} should be deadend = {}, actual = {}".format(t, ) + + assert ret.remove_deadends(int(rw('0010 0001 1000 0100'), 2)) == 0 + assert ret.remove_deadends(int(rw('0010 0001 1000 0110'), 2)) == int(rw('0000 0000 0000 0010'), 2)