Skip to content
Snippets Groups Projects
test_flatland_core_transitions.py 7.22 KiB
Newer Older
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""Tests for `flatland` package."""
from flatland.core.transitions import RailEnvTransitions, Grid8Transitions
from flatland.envs.env_utils import validate_new_transition
def test_rotate_railenv_transition():
    rail_env_transitions = RailEnvTransitions()

    # remove whitespace in string; keep whitespace below for easier reading
    def rw(s):
        return s.replace(" ", "")

    # TODO test all cases
    transition_cycles = [
        # empty cell - Case 0
        [int('0000000000000000', 2), int('0000000000000000', 2), int('0000000000000000', 2),
         int('0000000000000000', 2)],
        # Case 1 - straight
        #     |
        #     |
        #     |
        [int(rw('1000 0000 0010 0000'), 2), int(rw('0000 0100 0000 0001'), 2)],
        # Case 1b (8)  - simple turn right
        #      _
        #     |
        #     |
        [
            int(rw('0100 0000 0000 0010'), 2),
            int(rw('0001 0010 0000 0000'), 2),
            int(rw('0000 1000 0001 0000'), 2),
            int(rw('0000 0000 0100 1000'), 2),
        ],
        # Case 1c (9)  - simple turn left
        #    _
        #     |
        #     |

        #                int('0001001000000000', 2),

        # Case 2 - simple left switch
        #  _ _|
        #     |
        #     |
        [
            int(rw('1001 0010 0010 0000'), 2),
            int(rw('0000 1100 0001 0001'), 2),
            int(rw('1000 0000 0110 1000'), 2),
            int(rw('0100 0100 0000 0011'), 2),
        ],
        # Case 2b (10) - simple right switch
        #     |
        #     |
        #     |
        #                int('1100000000100010', 2)]
        #                int('1000010000100001', 2),  # Case 3 - diamond drossing
        #                int('1001011000100001', 2),  # Case 4 - single slip
        #                int('1100110000110011', 2),  # Case 5 - double slip
        #                int('0101001000000010', 2),  # Case 6 - symmetrical
        #                int('0010000000000000', 2),  # Case 7 - dead end

    for index, cycle in enumerate(transition_cycles):
        for i in range(4):
            actual_transition = rail_env_transitions.rotate_transition(cycle[0], i * 90)
            expected_transition = cycle[i % len(cycle)]
            try:
                assert actual_transition == expected_transition, \
                    "Case {}: rotate_transition({}, {}) should equal {} but was {}." \
                        .format(i, cycle[0], i, expected_transition, actual_transition)
            except Exception as e:
                print("expected:")
                rail_env_transitions.print(expected_transition)
                print("actual:")
                rail_env_transitions.print(actual_transition)

                raise e
def test_is_valid_railenv_transitions():
    rail_env_trans = RailEnvTransitions()
    transition_list = rail_env_trans.transitions

    for t in transition_list:
        assert (rail_env_trans.is_valid(t) is True)
        for i in range(3):
            rot_trans = rail_env_trans.rotate_transition(t, 90 * i)
            assert (rail_env_trans.is_valid(rot_trans) is True)
    assert (rail_env_trans.is_valid(int('1111111111110010', 2)) is False)
    assert (rail_env_trans.is_valid(int('1001111111110010', 2)) is False)
    assert (rail_env_trans.is_valid(int('1001111001110110', 2)) is False)


def test_adding_new_valid_transition():
    rail_trans = RailEnvTransitions()
    rail_array = np.zeros(shape=(15, 15), dtype=np.uint16)

    # adding straight
    assert (validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (6, 5), (10, 10)) is True)

    # adding valid right turn
    assert (validate_new_transition(rail_trans, rail_array, (5, 4), (5, 5), (5, 6), (10, 10)) is True)
    # adding valid left turn
    assert (validate_new_transition(rail_trans, rail_array, (5, 6), (5, 5), (5, 6), (10, 10)) is True)
    rail_array[(5, 5)] = rail_trans.transitions[2]
    assert (validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is False)

    # should create #4 -> valid
    rail_array[(5, 5)] = rail_trans.transitions[3]
    assert (validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is True)
    rail_array[(5, 5)] = rail_trans.transitions[7]
    assert (validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is False)

    # test path start condition
    rail_array[(5, 5)] = rail_trans.transitions[0]
    assert (validate_new_transition(rail_trans, rail_array, None, (5, 5), (5, 6), (10, 10)) is True)

    # test path end condition
    rail_array[(5, 5)] = rail_trans.transitions[0]
    assert (validate_new_transition(rail_trans, rail_array, (5, 4), (5, 5), (6, 5), (6, 5)) is True)


def test_valid_railenv_transitions():
    rail_env_trans = RailEnvTransitions()

    # directions:
    #            'N': 0
    #            'E': 1
    #            'S': 2
    #            'W': 3

    for i in range(2):
        assert (rail_env_trans.get_transitions(
            int('1100110000110011', 2), i) == (1, 1, 0, 0))
        assert (rail_env_trans.get_transitions(
            int('1100110000110011', 2), 2 + i) == (0, 0, 1, 1))

    no_transition_cell = int('0000000000000000', 2)

    for i in range(4):
        assert (rail_env_trans.get_transitions(
            no_transition_cell, i) == (0, 0, 0, 0))

    # Facing south, going south
spiglerg's avatar
spiglerg committed
    north_south_transition = rail_env_trans.set_transitions(no_transition_cell, 2, (0, 0, 1, 0))
    assert (rail_env_trans.set_transition(
        north_south_transition, 2, 2, 0) == no_transition_cell)
    assert (rail_env_trans.get_transition(
        north_south_transition, 2, 2))

    # Facing north, going east
    south_east_transition = \
spiglerg's avatar
spiglerg committed
        rail_env_trans.set_transition(no_transition_cell, 0, 1, 1)
    assert (rail_env_trans.get_transition(
        south_east_transition, 0, 1))

    # The opposite transitions are not feasible
    assert (not rail_env_trans.get_transition(
        north_south_transition, 2, 0))
    assert (not rail_env_trans.get_transition(
        south_east_transition, 2, 1))
spiglerg's avatar
spiglerg committed
    east_west_transition = rail_env_trans.rotate_transition(north_south_transition, 90)
    north_west_transition = rail_env_trans.rotate_transition(south_east_transition, 180)

    # Facing west, going west
    assert (rail_env_trans.get_transition(
        east_west_transition, 3, 3))
    # Facing south, going west
    assert (rail_env_trans.get_transition(
        north_west_transition, 2, 3))
    assert (south_east_transition == rail_env_trans.rotate_transition(
        south_east_transition, 360))


def test_diagonal_transitions():
    diagonal_trans_env = Grid8Transitions([])

    # Facing north, going north-east
spiglerg's avatar
spiglerg committed
    south_northeast_transition = int('01000000' + '0' * 8 * 7, 2)
    assert (diagonal_trans_env.get_transitions(
        south_northeast_transition, 0) == (0, 1, 0, 0, 0, 0, 0, 0))
    # Allowing transition from north to southwest: Facing south, going SW
    north_southwest_transition = \
spiglerg's avatar
spiglerg committed
        diagonal_trans_env.set_transitions(int('0' * 64, 2), 4, (0, 0, 0, 0, 0, 1, 0, 0))
    assert (diagonal_trans_env.rotate_transition(
        south_northeast_transition, 180) == north_southwest_transition)