From 2b292dcce88c71466f08938053acd8d18255a822 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Thu, 20 Jun 2019 15:14:28 +0200 Subject: [PATCH] #62 increase unit test coverage --- flatland/core/grid/grid4.py | 19 +++++++++ flatland/core/grid/rail_env_grid.py | 13 ------- tests/test_flatland_core_transitions.py | 52 ++++++++++++++++++++++--- 3 files changed, 65 insertions(+), 19 deletions(-) diff --git a/flatland/core/grid/grid4.py b/flatland/core/grid/grid4.py index 3febb52..5c09f0a 100644 --- a/flatland/core/grid/grid4.py +++ b/flatland/core/grid/grid4.py @@ -49,6 +49,9 @@ class Grid4Transitions(Transitions): # row,col delta for each direction self.gDir2dRC = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]]) + # These bits represent all the possible dead ends + self.maskDeadEnds = 0b0010000110000100 + def get_type(self): return np.uint16 @@ -210,3 +213,19 @@ class Grid4Transitions(Transitions): def get_direction_enum(self) -> IntEnum: return Grid4TransitionsEnum + + def has_deadend(self, cell_transition): + """ + Checks if one entry can only by exited by a turn-around. + """ + if cell_transition & self.maskDeadEnds > 0: + return True + else: + return False + + def remove_deadends(self, cell_transition): + """ + Remove all turn-arounds (e.g. N-S, S-N, E-W,...). + """ + cell_transition &= cell_transition & (~self.maskDeadEnds) & 0xffff + return cell_transition diff --git a/flatland/core/grid/rail_env_grid.py b/flatland/core/grid/rail_env_grid.py index efb5ea1..c043b42 100644 --- a/flatland/core/grid/rail_env_grid.py +++ b/flatland/core/grid/rail_env_grid.py @@ -43,9 +43,6 @@ class RailEnvTransitions(Grid4Transitions): transitions=self.transition_list ) - # These bits represent all the possible dead ends - self.maskDeadEnds = 0b0010000110000100 - # create this to make validation faster self.transitions_all = set() for index, trans in enumerate(self.transitions): @@ -112,13 +109,3 @@ class RailEnvTransitions(Grid4Transitions): True or False """ return cell_transition in self.transitions_all - - def has_deadend(self, cell_transition): - if cell_transition & self.maskDeadEnds > 0: - return True - else: - return False - - def remove_deadends(self, cell_transition): - cell_transition &= cell_transition & (~self.maskDeadEnds) & 0xffff - return cell_transition diff --git a/tests/test_flatland_core_transitions.py b/tests/test_flatland_core_transitions.py index 4e69acf..048520c 100644 --- a/tests/test_flatland_core_transitions.py +++ b/tests/test_flatland_core_transitions.py @@ -4,18 +4,20 @@ """Tests for `flatland` package.""" import numpy as np +from flatland.core.grid.grid4 import Grid4Transitions from flatland.core.grid.grid8 import Grid8Transitions from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.envs.env_utils import validate_new_transition +# remove whitespace in string; keep whitespace below for easier reading +def rw(s): + return s.replace(" ", "") + + def test_rotate_railenv_transition(): rail_env_transitions = RailEnvTransitions() - # remove whitespace in string; keep whitespace below for easier reading - def rw(s): - return s.replace(" ", "") - # TODO test all cases transition_cycles = [ # empty cell - Case 0 @@ -68,9 +70,17 @@ def test_rotate_railenv_transition(): # int('1100110000110011', 2), \ # noqa: E800 # Case 6 - symmetrical # int('0101001000000010', 2), \ # noqa: E800 - # Case 7 - dead end - # int('0010000000000000', 2), \ # noqa: E800 + # Case 7 - dead end + # + # + # | + [ + int(rw('0010 0000 0000 0000'), 2), + int(rw('0000 0001 0000 0000'), 2), + int(rw('0000 0000 1000 0000'), 2), + int(rw('0000 0000 0000 0100'), 2), + ], ] for index, cycle in enumerate(transition_cycles): @@ -206,3 +216,33 @@ def test_diagonal_transitions(): assert (diagonal_trans_env.rotate_transition( south_northeast_transition, 180) == north_southwest_transition) + + +def test_rail_env_has_deadend(): + deadends = set([int(rw('0010 0000 0000 0000'), 2), + int(rw('0000 0001 0000 0000'), 2), + int(rw('0000 0000 1000 0000'), 2), + int(rw('0000 0000 0000 0100'), 2)]) + ret = RailEnvTransitions() + transitions_all = ret.transitions_all + for t in transitions_all: + expected_has_deadend = t in deadends + actual_had_deadend = ret.has_deadend(t) + assert actual_had_deadend == expected_has_deadend, \ + "{} should be deadend = {}, actual = {}".format(t, ) + + +def test_rail_env_remove_deadend(): + ret = Grid4Transitions([]) + rail_env_deadends = set([int(rw('0010 0000 0000 0000'), 2), + int(rw('0000 0001 0000 0000'), 2), + int(rw('0000 0000 1000 0000'), 2), + int(rw('0000 0000 0000 0100'), 2)]) + for t in rail_env_deadends: + expected_has_deadend = 0 + actual_had_deadend = ret.remove_deadends(t) + assert actual_had_deadend == expected_has_deadend, \ + "{} should be deadend = {}, actual = {}".format(t, ) + + assert ret.remove_deadends(int(rw('0010 0001 1000 0100'), 2)) == 0 + assert ret.remove_deadends(int(rw('0010 0001 1000 0110'), 2)) == int(rw('0000 0000 0000 0010'), 2) -- GitLab