Skip to content
Snippets Groups Projects
test_flatland_envs_predictions.py 11.2 KiB
Newer Older
u214892's avatar
u214892 committed
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import pprint
u214892's avatar
u214892 committed

import numpy as np

from flatland.core.grid.grid4 import Grid4TransitionsEnum
u214892's avatar
u214892 committed
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions, RailEnvNextAction
u214892's avatar
u214892 committed
from flatland.envs.rail_env_utils import get_shortest_paths
u214892's avatar
u214892 committed
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.rendertools import RenderTool
u214892's avatar
u214892 committed
from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail
u214892's avatar
u214892 committed

u214892's avatar
u214892 committed
"""Test predictions for `flatland` package."""
u214892's avatar
u214892 committed


def test_dummy_predictor(rendering=False):
    """Check the per-step forecast produced by `DummyPredictorForRailEnv`.

    One agent is placed at (5, 6) with direction 0 and target (3, 0) on the rail
    from `make_simple_rail2`. The predictor (max_depth=10) is then queried and its
    output for agent 0 is compared row by row against hard-coded expectations.

    Each prediction row is read as
    ``[time_offset, position_row, position_col, direction, action]``.

    :param rendering: if True, render the environment and wait for keyboard input
        before running the assertions (manual debugging aid).
    """
    rail, rail_map = make_simple_rail2()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)),
                  )
    # reset to initialize agents_static
    env.reset()

    # set initial position and direction for testing...
    env.agents_static[0].position = (5, 6)
    env.agents_static[0].direction = 0
    env.agents_static[0].target = (3, 0)

    # reset to set agents from agents_static
    env.reset(False, False)

    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.render_env(show=True, show_observations=False)
        input("Continue?")

    # split the predictions of agent 0 into one column array per quantity
    agent_predictions = env.obs_builder.predictor.get(None)[0]
    positions = np.array([[*prediction[1:3]] for prediction in agent_predictions])
    directions = np.array([[prediction[3]] for prediction in agent_predictions])
    time_offsets = np.array([[prediction[0]] for prediction in agent_predictions])
    actions = np.array([[prediction[4]] for prediction in agent_predictions])

    # compare against expected values (11 rows each: time offsets 0..10)
    expected_positions = np.array([[5., 6.],
                                   [4., 6.],
                                   [3., 6.],
                                   [3., 5.],
                                   [3., 4.],
                                   [3., 3.],
                                   [3., 2.],
                                   [3., 1.],
                                   # at target (3,0): stay in this position from here on
                                   [3., 0.],
                                   [3., 0.],
                                   [3., 0.],
                                   ])
    expected_directions = np.array([[0.],
                                    [0.],
                                    [0.],
                                    [3.],
                                    [3.],
                                    [3.],
                                    [3.],
                                    [3.],
                                    # at target (3,0): stay in this position from here on
                                    [3.],
                                    [3.],
                                    [3.],
                                    ])
    expected_time_offsets = np.array([[0.],
                                      [1.],
                                      [2.],
                                      [3.],
                                      [4.],
                                      [5.],
                                      [6.],
                                      [7.],
                                      [8.],
                                      [9.],
                                      [10.],
                                      ])
    expected_actions = np.array([[0.],
                                 [2.],
                                 [2.],
                                 [2.],
                                 [2.],
                                 [2.],
                                 [2.],
                                 [2.],
                                 # reaching target by straight
                                 [2.],
                                 # at target: stopped moving
                                 [4.],
                                 [4.],
                                 ])

    assert np.array_equal(positions, expected_positions), \
        "positions {}, expected {}".format(positions, expected_positions)
    assert np.array_equal(directions, expected_directions), \
        "directions {}, expected {}".format(directions, expected_directions)
    assert np.array_equal(time_offsets, expected_time_offsets), \
        "time_offsets {}, expected {}".format(time_offsets, expected_time_offsets)
    assert np.array_equal(actions, expected_actions), \
        "actions {}, expected {}".format(actions, expected_actions)
u214892's avatar
u214892 committed


def test_shortest_path_predictor(rendering=False):
    """Check distance map, shortest path and forecast of `ShortestPathPredictorForRailEnv`.

    One agent is placed at (5, 6) heading north with target (3, 9) on the rail from
    `make_simple_rail`. The test asserts that

    * the distance map reports 5.0 for the agent's start cell and direction,
    * `get_shortest_paths` yields the expected five forward moves, and
    * the predictions stored on the observation builder (21 rows, time offsets
      0..20) follow that path and then stay on the target cell.

    Each prediction row is read as
    ``[time_offset, position_row, position_col, direction, ...]``.

    :param rendering: if True, render the environment before running the assertions.
    """
    rail, rail_map = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                  )

    # reset to initialize agents_static
    env.reset()

    # set the initial position
    agent = env.agents_static[0]
    agent.position = (5, 6)  # south dead-end
    agent.direction = 0  # north
    agent.target = (3, 9)  # east dead-end
    agent.moving = True

    # reset to set agents from agents_static
    env.reset(False, False)

    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.render_env(show=True, show_observations=False)

    # compute the observations and predictions
    distance_map = env.distance_map.get()
    # Index 0 is this test's only agent. The looked-up value is reused in the
    # failure message so the report shows exactly what was compared.
    start_row, start_col = agent.position[0], agent.position[1]
    start_distance = distance_map[0, start_row, start_col, agent.direction]
    assert start_distance == 5.0, "found {} instead of {}".format(start_distance, 5.0)

    paths = get_shortest_paths(env.distance_map)[0]
    assert paths == [
        RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(4, 6), next_direction=0),
        RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 6), next_direction=0),
        RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 7), next_direction=1),
        RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 8), next_direction=1),
        RailEnvNextAction(action=RailEnvActions.MOVE_FORWARD, next_position=(3, 9), next_direction=1)]

    # extract the data
    agent_predictions = env.obs_builder.predictions[0]
    positions = np.array([[*prediction[1:3]] for prediction in agent_predictions])
    directions = np.array([[prediction[3]] for prediction in agent_predictions])
    time_offsets = np.array([[prediction[0]] for prediction in agent_predictions])

    # test if data meets expectations
    steps_at_target = 16  # rows 5..20: the agent stays on the target cell
    expected_positions = [
        [5, 6],
        [4, 6],
        [3, 6],
        [3, 7],
        [3, 8],
    ] + [[3, 9]] * steps_at_target
    expected_directions = [
        [Grid4TransitionsEnum.NORTH],  # next is [5,6] heading north
        [Grid4TransitionsEnum.NORTH],  # next is [4,6] heading north
        [Grid4TransitionsEnum.NORTH],  # next is [3,6] heading north
        [Grid4TransitionsEnum.EAST],  # next is [3,7] heading east
        [Grid4TransitionsEnum.EAST],  # next is [3,8] heading east
    ] + [[Grid4TransitionsEnum.EAST]] * steps_at_target
    expected_time_offsets = np.array([[float(offset)] for offset in range(21)])

    assert np.array_equal(positions, expected_positions), \
        "positions {}, expected {}".format(positions, expected_positions)
    assert np.array_equal(directions, expected_directions), \
        "directions {}, expected {}".format(directions, expected_directions)
    assert np.array_equal(time_offsets, expected_time_offsets), \
        "time_offsets {}, expected {}".format(time_offsets, expected_time_offsets)


def test_shortest_path_predictor_conflicts(rendering=False):
    """Check which branches of each agent's observation tree report a conflict.

    Two agents are placed on the rail from `make_invalid_simple_rail` with a
    `ShortestPathPredictorForRailEnv`. After the environment is reset, each
    agent's observation is unfolded into a tree and handed to
    `_check_expected_conflicts`, expecting a conflict only on branch ('F', 'R')
    for agent 0 and only on branch ('F', 'L') for agent 1.

    :param rendering: if True, render the environment and wait for keyboard input
        before running the assertions.
    """
    rail, rail_map = make_invalid_simple_rail()
    height, width = rail_map.shape[0], rail_map.shape[1]
    tree_obs = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
    env = RailEnv(width=width,
                  height=height,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=2,
                  obs_builder_object=tree_obs,
                  )
    # initialize agents_static
    env.reset()

    # (position, direction, target) per agent handle; direction 0 = north, 3 = west
    initial_setup = (
        ((5, 6), 0, (3, 9)),
        ((3, 8), 3, (6, 6)),
    )
    for handle, (position, direction, target) in enumerate(initial_setup):
        static_agent = env.agents_static[handle]
        static_agent.position = position
        static_agent.direction = direction
        static_agent.target = target
        static_agent.moving = True

    # reset to set agents from agents_static
    observations = env.reset(False, False)

    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.render_env(show=True, show_observations=False)
        input("Continue?")

    # get the trees to test
    obs_builder: TreeObsForRailEnv = env.obs_builder
    tree_0 = obs_builder.unfold_observation_tree(observations[0])
    tree_1 = obs_builder.unfold_observation_tree(observations[1])
    # dump agent 0's tree to help diagnose a failing expectation
    pprint.pprint(tree_0, indent=4)

    # check the expectations
    _check_expected_conflicts([('F', 'R')], obs_builder, tree_0, "agent[0]: ")
    _check_expected_conflicts([('F', 'L')], obs_builder, tree_1, "agent[1]: ")


def _check_expected_conflicts(expected_conflicts, obs_builder, tree_0, prompt=''):
u214892's avatar
u214892 committed
    assert (tree_0[''][8] > 0) == (() in expected_conflicts), "{}[]".format(prompt)
    for a_1 in obs_builder.tree_explorted_actions_char:
u214892's avatar
u214892 committed
        conflict = tree_0[a_1][''][8]
        assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
        for a_2 in obs_builder.tree_explorted_actions_char:
u214892's avatar
u214892 committed
            conflict = tree_0[a_1][a_2][''][8]
            assert (conflict > 0) == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)