diff --git a/tests/test_flatland_core_transitions.py b/tests/test_flatland_core_transitions.py index c32b03857911295fa558957e2a660db5704ac783..47def83c8923ea44b884f1341954944901dd5db0 100644 --- a/tests/test_flatland_core_transitions.py +++ b/tests/test_flatland_core_transitions.py @@ -10,27 +10,76 @@ from flatland.envs.env_utils import validate_new_transition def test_rotate_railenv_transition(): rail_env_transitions = RailEnvTransitions() + + # remove whitespace in string; keep whitespace below for easier reading + def rw(s): + return s.replace(" ", "") + + # TODO test all cases transition_cycles = [ # empty cell - Case 0 - [int('0000000000000000', 2), int('0000000000000000', 2), int('0000000000000000', 2), int('0000000000000000', 2)], - # Case 1 - straight - [int('1000000000100000', 2), int('0000000100000100', 2)], + [int('0000000000000000', 2), int('0000000000000000', 2), int('0000000000000000', 2), + int('0000000000000000', 2)], + # Case 1 - straight + # | + # | + # | + [int(rw('1000 0000 0010 0000'), 2), int(rw('0000 0100 0000 0001'), 2)], + # Case 1b (8) - simple turn right + # _ + # | + # | + [ + int(rw('0100 0000 0000 0010'), 2), + int(rw('0001 0010 0000 0000'), 2), + int(rw('0000 1000 0001 0000'), 2), + int(rw('0000 0000 0100 1000'), 2), + ], + # Case 1c (9) - simple turn left + # _ + # | + # | + + # int('0001001000000000', 2), + + # Case 2 - simple left switch + # _ _| + # | + # | + [ + int(rw('1001 0010 0010 0000'), 2), + int(rw('0000 1100 0001 0001'), 2), + int(rw('1000 0000 0110 1000'), 2), + int(rw('0100 0100 0000 0011'), 2), + ], + # Case 2b (10) - simple right switch + # | + # | + # | + # int('1100000000100010', 2)] + # int('1000010000100001', 2), # Case 3 - diamond drossing + # int('1001011000100001', 2), # Case 4 - single slip + # int('1100110000110011', 2), # Case 5 - double slip + # int('0101001000000010', 2), # Case 6 - symmetrical + # int('0010000000000000', 2), # Case 7 - dead end + ] - for cycle in transition_cycles: + for index, cycle in enumerate(transition_cycles): for i in range(4): - assert rail_env_transitions.rotate_transition(cycle[0], i) == cycle[i % len(cycle)] - - # - # int('1001001000100000', 2), # Case 2 - simple switch - # int('1000010000100001', 2), # Case 3 - diamond drossing - # int('1001011000100001', 2), # Case 4 - single slip - # int('1100110000110011', 2), # Case 5 - double slip - # int('0101001000000010', 2), # Case 6 - symmetrical - # int('0010000000000000', 2), # Case 7 - dead end - # int('0100000000000010', 2), # Case 1b (8) - simple turn right - # int('0001001000000000', 2), # Case 1c (9) - simple turn left - # int('1100000000100010', 2)] # Case 2b (10) - simple switch mirrored + actual_transition = rail_env_transitions.rotate_transition(cycle[0], i * 90) + expected_transition = cycle[i % len(cycle)] + try: + assert actual_transition == expected_transition, \ + "Case {}: rotate_transition({}, {}) should equal {} but was {}." \ + .format(i, cycle[0], i, expected_transition, actual_transition) + except Exception as e: + print("expected:") + rail_env_transitions.print(expected_transition) + print("actual:") + rail_env_transitions.print(actual_transition) + + raise e def test_is_valid_railenv_transitions():