diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 6c9ec67d616978f1cffe5267f4e2274809a51cf3..39bc1088f9b354d78692d7ecf8ade37c3896112d 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -542,6 +542,15 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11, seed=1) -> R return generator +def compute_max_episode_steps(width: int, + height: int, + num_agents: int, + num_cities: int = 1, + timedelay_factor: int = 4, + alpha: int = 2) -> int: + return int(timedelay_factor * alpha * (width + height + (num_agents/num_cities))) + + def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_rails_between_cities: int = 4, max_rails_in_city: int = 4, seed: int = 1) -> RailGenerator: """ @@ -611,11 +620,14 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ # Generate start target pairs agent_start_targets_cities = _generate_start_target_pairs(num_agents, num_cities, train_stations, city_orientations) + max_episode_steps = compute_max_episode_steps(width=width, height=height, num_agents=num_agents, num_cities=num_cities) + return grid_map, {'agents_hints': { 'num_agents': num_agents, 'agent_start_targets_cities': agent_start_targets_cities, 'train_stations': train_stations, - 'city_orientations': city_orientations + 'city_orientations': city_orientations, + 'max_episode_steps': max_episode_steps }} def _generate_random_city_positions(num_cities: int, city_radius: int, width: int,