# Forked from Flatland / Flatland
# (1169 commits behind the upstream repository).
# File: test_speed_classes.py (1.45 KiB)
"""Test speed initialization by a map of speeds and their corresponding ratios."""
import numpy as np
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import speed_initialization_helper, complex_schedule_generator
def test_speed_initialization_helper():
    """Check the speeds produced by ``speed_initialization_helper`` under a fixed seed."""
    # With the global NumPy seed pinned, the helper is expected to be
    # deterministic, so the exact sequence of generated speeds can be compared.
    np.random.seed(1)

    ratios_by_speed = {1: 0.3, 2: 0.4, 3: 0.3}
    expected_speeds = [2, 3, 1, 2, 1, 1, 1, 2, 2, 2]

    generated_speeds = speed_initialization_helper(10, ratios_by_speed)

    assert generated_speeds == expected_speeds
def test_rail_env_speed_intializer():
    """Check that agents of a ``RailEnv`` only get speeds listed in the speed ratio map.

    The map is passed to ``complex_schedule_generator``; previously it was built
    but never handed to the generator, so the environment was created without it
    and the assertions below did not exercise the map-based initialization.
    """
    speed_ratio_map = {1: 0.3, 2: 0.4, 3: 0.1, 5: 0.2}

    # NOTE(review): assumes complex_schedule_generator takes the speed ratio map
    # as its first positional argument -- confirm against the installed version.
    env = RailEnv(width=50,
                  height=50,
                  rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8,
                                                        max_dist=99999, seed=0),
                  schedule_generator=complex_schedule_generator(speed_ratio_map),
                  number_of_agents=10)
    env.reset()

    actual_speeds = [agent.speed_data['speed'] for agent in env.agents]
    expected_speed_set = set(speed_ratio_map.keys())

    # check that the number of speeds generated is correct
    assert len(actual_speeds) == env.get_num_agents()

    # check that only the speeds defined are generated
    assert all(actual_speed in expected_speed_set for actual_speed in actual_speeds)