# test_flatland_core_transition_map.py
# Tests for GridTransitionMap transition get/set and rail path-existence checks.
from flatland.core.grid.grid4 import Grid4Transitions, Grid4TransitionsEnum
from flatland.core.grid.grid8 import Grid8Transitions, Grid8TransitionsEnum
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail, make_simple_rail_unconnected
def test_grid4_get_transitions():
    """Check single-transition get/set on a 4-direction ``GridTransitionMap``.

    Starting from a map built with ``Grid4Transitions([])``, the transitions of
    cell (0, 0) are changed one at a time. After every change the test checks
    the per-direction tuples returned by ``get_transitions`` for all four entry
    directions as well as the packed integer from ``get_full_transitions``.
    """
    grid4_map = GridTransitionMap(2, 2, Grid4Transitions([]))

    # freshly created map: nothing is set for any entry direction
    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.NORTH) == (0, 0, 0, 0)
    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.EAST) == (0, 0, 0, 0)
    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.SOUTH) == (0, 0, 0, 0)
    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.WEST) == (0, 0, 0, 0)
    assert grid4_map.get_full_transitions(0, 0) == 0

    # enable NORTH -> NORTH
    grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.NORTH, 1)
    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.NORTH) == (1, 0, 0, 0)
    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.EAST) == (0, 0, 0, 0)
    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.SOUTH) == (0, 0, 0, 0)
    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.WEST) == (0, 0, 0, 0)
    assert grid4_map.get_full_transitions(0, 0) == pow(2, 15)  # the most significant bit is on

    # additionally enable NORTH -> WEST
    grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.WEST, 1)
    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.NORTH) == (1, 0, 0, 1)
    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.EAST) == (0, 0, 0, 0)
    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.SOUTH) == (0, 0, 0, 0)
    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.WEST) == (0, 0, 0, 0)
    # the most significant and the fourth most significant bits are on
    assert grid4_map.get_full_transitions(0, 0) == pow(2, 15) + pow(2, 12)

    # disable NORTH -> NORTH again; NORTH -> WEST must stay set
    grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.NORTH, 0)
    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.NORTH) == (0, 0, 0, 1)
    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.EAST) == (0, 0, 0, 0)
    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.SOUTH) == (0, 0, 0, 0)
    assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.WEST) == (0, 0, 0, 0)
    # only the fourth most significant bit is on
    assert grid4_map.get_full_transitions(0, 0) == pow(2, 12)


def test_grid8_set_transitions():
    """Check that a single transition can be set and cleared on a Grid8 map.

    On a map built with ``Grid8Transitions([])`` the NORTH -> NORTH transition
    of cell (0, 0) is switched on and off again; ``get_transitions`` must
    report the 8-tuple matching each state.
    """
    grid8_map = GridTransitionMap(2, 2, Grid8Transitions([]))

    # freshly created map: nothing is set when entering from NORTH
    assert grid8_map.get_transitions(0, 0, Grid8TransitionsEnum.NORTH) == (0, 0, 0, 0, 0, 0, 0, 0)

    # enable NORTH -> NORTH
    grid8_map.set_transition((0, 0, Grid8TransitionsEnum.NORTH), Grid8TransitionsEnum.NORTH, 1)
    assert grid8_map.get_transitions(0, 0, Grid8TransitionsEnum.NORTH) == (1, 0, 0, 0, 0, 0, 0, 0)

    # disable it again: the cell is back to its initial state
    grid8_map.set_transition((0, 0, Grid8TransitionsEnum.NORTH), Grid8TransitionsEnum.NORTH, 0)
    assert grid8_map.get_transitions(0, 0, Grid8TransitionsEnum.NORTH) == (0, 0, 0, 0, 0, 0, 0, 0)

def check_path(env, rail, position, direction, target, expected, rendering=False):
    """Place the first static agent and assert whether its target is reachable.

    The first entry of ``env.agents_static`` is overwritten in place with the
    given ``position``, ``direction`` and ``target`` and flagged as moving.
    No ``env.reset()`` is performed here. ``rail._path_exists`` is then queried
    with the agent's values and its result must equal ``expected``.

    Args:
        env: environment whose ``agents_static[0]`` is modified.
        rail: object providing ``_path_exists(position, direction, target)``.
        position: start position assigned to the agent.
        direction: start direction assigned to the agent.
        target: target position assigned to the agent.
        expected: the value ``_path_exists`` is expected to return.
        rendering: when true, draw the environment with ``RenderTool`` and
            block on a console prompt before running the assertion.

    Raises:
        AssertionError: if ``_path_exists`` does not return ``expected``.
    """
    first_agent = env.agents_static[0]
    first_agent.position, first_agent.direction, first_agent.target = position, direction, target
    first_agent.moving = True

    if rendering:
        env_renderer = RenderTool(env, gl="PILSVG")
        env_renderer.render_env(show=True, show_observations=False)
        input("Continue?")

    # query with the values stored on the agent, not the raw arguments
    path_found = rail._path_exists(first_agent.position, first_agent.direction, first_agent.target)
    assert path_found == expected


def test_path_exists(rendering=False):
    """Check that paths are found between several cells of the simple rail.

    The environment is built from the transition map returned by
    ``make_simple_rail`` and reset once so that ``agents_static`` is
    populated. Every case below is then passed to ``check_path`` with an
    expected result of ``True``.

    Args:
        rendering: forwarded to ``check_path``; when true each case is
            rendered and waits for console input before it is asserted.
    """
    rail, rail_map = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                  )

    # reset to initialize agents_static
    env.reset()

    # Each case is (start position, start direction, target position).
    # Directions used here: 0 = north, 2 = south, 3 = west.
    # NOTE(review): the cell descriptions follow the labels used in this file;
    # confirm them against the layout produced by make_simple_rail.
    connected_cases = (
        # north of south dead-end, heading north -> east dead-end
        ((5, 6), 0, (3, 9)),
        # south dead-end, heading south -> east dead-end
        ((6, 6), 2, (3, 9)),
        # (3, 0) is presumably the west dead-end, heading west -> north dead-end
        ((3, 0), 3, (0, 3)),
        # north of south dead-end, heading north -> (1, 3), presumably just
        # south of the north dead-end
        ((5, 6), 0, (1, 3)),
        # (1, 3) heading south -> (3, 3), presumably the junction cell
        ((1, 3), 2, (3, 3)),
        # (1, 3) heading north -> (3, 3)
        ((1, 3), 0, (3, 3)),
    )

    for position, direction, target in connected_cases:
        check_path(env, rail, position, direction, target, True, rendering=rendering)


def test_path_not_exists(rendering=False):
    """Check that no path is reported on the unconnected simple rail.

    The environment is built from the transition map returned by
    ``make_simple_rail_unconnected`` and reset once so that ``agents_static``
    is populated. ``check_path`` is then called for a single start/target
    pair with an expected result of ``False``.

    Args:
        rendering: when true, the environment is rendered after the check and
            the test waits for console input.
    """
    rail, rail_map = make_simple_rail_unconnected()

    grid_width = rail_map.shape[1]
    grid_height = rail_map.shape[0]
    rail_generator = rail_from_grid_transition_map(rail)
    schedule_generator = random_schedule_generator()
    tree_observations = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())

    env = RailEnv(
        width=grid_width,
        height=grid_height,
        rail_generator=rail_generator,
        schedule_generator=schedule_generator,
        number_of_agents=1,
        obs_builder_object=tree_observations,
    )

    # reset to initialize agents_static
    env.reset()

    # start at (5, 6) heading north (0); the north dead-end (0, 3) must be unreachable
    check_path(env, rail, position=(5, 6), direction=0, target=(0, 3), expected=False)

    if not rendering:
        return

    env_renderer = RenderTool(env, gl="PILSVG")
    env_renderer.render_env(show=True, show_observations=False)
    input("Continue?")