Skip to content
Snippets Groups Projects
Commit 7f351228 authored by u214892's avatar u214892
Browse files

#141 speed generator first implementation

parent ebbbe6fb
No related branches found
No related tags found
No related merge requests found
from typing import Mapping, Tuple, List, Callable
import msgpack import msgpack
import numpy as np import numpy as np
...@@ -27,7 +29,12 @@ def empty_rail_generator(): ...@@ -27,7 +29,12 @@ def empty_rail_generator():
return generator return generator
def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=99999, seed=0): def complex_rail_generator(nr_start_goal=1,
nr_extra=100,
min_dist=20,
max_dist=99999,
seed=0,
speed_initializer: Callable[[int], List[float]] = None):
""" """
Parameters Parameters
------- -------
...@@ -35,6 +42,8 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist= ...@@ -35,6 +42,8 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=
The width (number of cells) of the grid to generate. The width (number of cells) of the grid to generate.
height : int height : int
The height (number of cells) of the grid to generate. The height (number of cells) of the grid to generate.
speed_initializer : Callable[[int], List[float]]
Function that returns a list of speeds for the numer of agents given as argument.
Returns Returns
------- -------
...@@ -145,7 +154,11 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist= ...@@ -145,7 +154,11 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=
agents_target = [sg[1] for sg in start_goal[:num_agents]] agents_target = [sg[1] for sg in start_goal[:num_agents]]
agents_direction = start_dir[:num_agents] agents_direction = start_dir[:num_agents]
return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) if speed_initializer:
speeds = speed_initializer(num_agents)
else:
speeds = [1.0] * len(agents_position)
return grid_map, agents_position, agents_direction, agents_target, speeds
return generator return generator
...@@ -538,3 +551,24 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11): ...@@ -538,3 +551,24 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11):
return return_rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) return return_rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
return generator return generator
def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float, float]) -> List[float]:
"""
Parameters
-------
nb_agents : int
The number of agents to generate a speed for
speed_ratio_map : Mapping[float,float]
A map of speeds mappint to their ratio of appearance. The ratios must sum up to 1.
Returns
-------
List[float]
A list of size nb_agents of speeds with the corresponding probabilistic ratios.
"""
nb_classes = len(speed_ratio_map.keys())
speed_ratio_map_as_list: List[Tuple[float, float]] = list(speed_ratio_map.items())
speed_ratios = list(map(lambda t: t[1], speed_ratio_map_as_list))
speeds = list(map(lambda t: t[0], speed_ratio_map_as_list))
return list(map(lambda index: speeds[index], np.random.choice(nb_classes, nb_agents, p=speed_ratios)))
"""Test speed initialization by a map of speeds and their corresponding ratios."""
import numpy as np
from flatland.envs.generators import speed_initialization_helper, complex_rail_generator
from flatland.envs.rail_env import RailEnv
def test_speed_initialization_helper():
np.random.seed(1)
speed_ratio_map = {1: 0.3, 2: 0.4, 3: 0.3}
actual_speeds = speed_initialization_helper(10, speed_ratio_map)
# seed makes speed_initialization_helper deterministic -> check generated speeds.
assert actual_speeds == [2, 3, 1, 2, 1, 1, 1, 2, 2, 2]
def test_rail_env_speed_intializer():
speed_ratio_map = {1: 0.3, 2: 0.4, 3: 0.1, 5: 0.2}
def my_speed_initializer(nb_agents):
return speed_initialization_helper(nb_agents, speed_ratio_map)
env = RailEnv(width=50,
height=50,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999,
seed=0, speed_initializer=my_speed_initializer),
number_of_agents=10)
env.reset()
actual_speeds = list(map(lambda agent: agent.speed_data['speed'], env.agents))
expected_speed_set = set(speed_ratio_map.keys())
# check that the number of speeds generated is correct
assert len(actual_speeds) == env.get_num_agents()
# check that only the speeds defined are generated
assert all({(actual_speed in expected_speed_set) for actual_speed in actual_speeds})
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment