From 7f351228c6e6d9a3c93fd71cce98542fef558cf6 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Mon, 26 Aug 2019 12:12:32 +0200 Subject: [PATCH] #141 speed generator first implementation --- flatland/envs/generators.py | 38 +++++++++++++++++++++++++++++++++++-- tests/test_speed_classes.py | 37 ++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 2 deletions(-) create mode 100644 tests/test_speed_classes.py diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 355f5502..79e0ac7d 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -1,3 +1,5 @@ +from typing import Mapping, Tuple, List, Callable + import msgpack import numpy as np @@ -27,7 +29,12 @@ def empty_rail_generator(): return generator -def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=99999, seed=0): +def complex_rail_generator(nr_start_goal=1, + nr_extra=100, + min_dist=20, + max_dist=99999, + seed=0, + speed_initializer: Callable[[int], List[float]] = None): """ Parameters ------- @@ -35,6 +42,8 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist= The width (number of cells) of the grid to generate. height : int The height (number of cells) of the grid to generate. + speed_initializer : Callable[[int], List[float]] + Function that returns a list of speeds for the numer of agents given as argument. Returns ------- @@ -145,7 +154,11 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist= agents_target = [sg[1] for sg in start_goal[:num_agents]] agents_direction = start_dir[:num_agents] - return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + if speed_initializer: + speeds = speed_initializer(num_agents) + else: + speeds = [1.0] * len(agents_position) + return grid_map, agents_position, agents_direction, agents_target, speeds return generator @@ -538,3 +551,24 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11): return return_rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) return generator + + +def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float, float]) -> List[float]: + """ + Parameters + ------- + nb_agents : int + The number of agents to generate a speed for + speed_ratio_map : Mapping[float,float] + A map of speeds mappint to their ratio of appearance. The ratios must sum up to 1. + + Returns + ------- + List[float] + A list of size nb_agents of speeds with the corresponding probabilistic ratios. + """ + nb_classes = len(speed_ratio_map.keys()) + speed_ratio_map_as_list: List[Tuple[float, float]] = list(speed_ratio_map.items()) + speed_ratios = list(map(lambda t: t[1], speed_ratio_map_as_list)) + speeds = list(map(lambda t: t[0], speed_ratio_map_as_list)) + return list(map(lambda index: speeds[index], np.random.choice(nb_classes, nb_agents, p=speed_ratios))) diff --git a/tests/test_speed_classes.py b/tests/test_speed_classes.py new file mode 100644 index 00000000..6ef600d9 --- /dev/null +++ b/tests/test_speed_classes.py @@ -0,0 +1,37 @@ +"""Test speed initialization by a map of speeds and their corresponding ratios.""" +import numpy as np + +from flatland.envs.generators import speed_initialization_helper, complex_rail_generator +from flatland.envs.rail_env import RailEnv + + +def test_speed_initialization_helper(): + np.random.seed(1) + speed_ratio_map = {1: 0.3, 2: 0.4, 3: 0.3} + actual_speeds = speed_initialization_helper(10, speed_ratio_map) + + # seed makes speed_initialization_helper deterministic -> check generated speeds. + assert actual_speeds == [2, 3, 1, 2, 1, 1, 1, 2, 2, 2] + + +def test_rail_env_speed_intializer(): + speed_ratio_map = {1: 0.3, 2: 0.4, 3: 0.1, 5: 0.2} + + def my_speed_initializer(nb_agents): + return speed_initialization_helper(nb_agents, speed_ratio_map) + + env = RailEnv(width=50, + height=50, + rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, + seed=0, speed_initializer=my_speed_initializer), + number_of_agents=10) + env.reset() + actual_speeds = list(map(lambda agent: agent.speed_data['speed'], env.agents)) + + expected_speed_set = set(speed_ratio_map.keys()) + + # check that the number of speeds generated is correct + assert len(actual_speeds) == env.get_num_agents() + + # check that only the speeds defined are generated + assert all({(actual_speed in expected_speed_set) for actual_speed in actual_speeds}) -- GitLab