Skip to content
Snippets Groups Projects
Forked from Flatland / Flatland
1169 commits behind the upstream repository.
test_speed_classes.py 1.45 KiB
"""Test speed initialization by a map of speeds and their corresponding ratios."""
import numpy as np

from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import speed_initialization_helper, complex_schedule_generator


def test_speed_initialization_helper():
    """Seeded sampling from a speed-ratio map yields a reproducible list of speeds."""
    # Fixing NumPy's global seed makes speed_initialization_helper deterministic,
    # so the exact sampled sequence can be pinned below.
    np.random.seed(1)
    expected_speeds = [2, 3, 1, 2, 1, 1, 1, 2, 2, 2]
    ratios_by_speed = {1: 0.3, 2: 0.4, 3: 0.3}

    # Ask for exactly as many speeds as the pinned sequence contains (10 agents).
    sampled_speeds = speed_initialization_helper(len(expected_speeds), ratios_by_speed)

    assert sampled_speeds == expected_speeds


def test_rail_env_speed_intializer():
    """Check that agents of a reset RailEnv only carry speeds listed in the speed-ratio map.

    The map is handed to the schedule generator. If it were not, the map defined
    here would be unused, and the membership check below would not exercise it.
    """
    speed_ratio_map = {1: 0.3, 2: 0.4, 3: 0.1, 5: 0.2}

    env = RailEnv(width=50,
                  height=50,
                  rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999,
                                                        seed=0),
                  # Pass the map through so the agents' speeds are drawn from it.
                  schedule_generator=complex_schedule_generator(speed_ratio_map),
                  number_of_agents=10)
    env.reset()
    actual_speeds = [agent.speed_data['speed'] for agent in env.agents]

    expected_speed_set = set(speed_ratio_map)

    # check that the number of speeds generated is correct
    assert len(actual_speeds) == env.get_num_agents()

    # check that only the speeds defined in the map are generated
    assert all(actual_speed in expected_speed_set for actual_speed in actual_speeds)