diff --git a/flatland/envs/timetable_generators.py b/flatland/envs/timetable_generators.py new file mode 100644 index 0000000000000000000000000000000000000000..b7876d742f61db830883f828faaf99a39a48bc65 --- /dev/null +++ b/flatland/envs/timetable_generators.py @@ -0,0 +1,96 @@ +import os +import json +import itertools +import warnings +from typing import Tuple, List, Callable, Mapping, Optional, Any +from flatland.envs.timetable_utils import Timetable + +import numpy as np +from numpy.random.mtrand import RandomState + +from flatland.envs.agent_utils import EnvAgent +from flatland.envs.distance_map import DistanceMap +from flatland.envs.rail_env_shortest_paths import get_shortest_paths + +def len_handle_none(v): + if v is not None: + return len(v) + else: + return 0 + +def timetable_generator(agents: List[EnvAgent], distance_map: DistanceMap, + agents_hints: dict, np_random: RandomState = None) -> Timetable: + """ + Calculates earliest departure and latest arrival times for the agents + This is the new addition in Flatland 3 + Also calculates the max episodes steps based on the density of the timetable + + inputs: + agents - List of all the agents rail_env.agents + distance_map - Distance map of positions to tagets of each agent in each direction + agent_hints - Uses the number of cities + np_random - RNG state for seeding + returns: + Timetable with the latest_arrivals, earliest_departures and max_episdode_steps + """ + # max_episode_steps calculation + if agents_hints: + city_positions = agents_hints['city_positions'] + num_cities = len(city_positions) + else: + num_cities = 2 + + timedelay_factor = 4 + alpha = 2 + max_episode_steps = int(timedelay_factor * alpha * \ + (distance_map.rail.width + distance_map.rail.height + (len(agents) / num_cities))) + + # Multipliers + old_max_episode_steps_multiplier = 3.0 + new_max_episode_steps_multiplier = 1.5 + travel_buffer_multiplier = 1.3 # must be strictly lesser than new_max_episode_steps_multiplier + assert new_max_episode_steps_multiplier > travel_buffer_multiplier + end_buffer_multiplier = 0.05 + mean_shortest_path_multiplier = 0.2 + + shortest_paths = get_shortest_paths(distance_map) + shortest_paths_lengths = [len_handle_none(v) for k,v in shortest_paths.items()] + + # Find mean_shortest_path_time + agent_speeds = [agent.speed_data['speed'] for agent in agents] + agent_shortest_path_times = np.array(shortest_paths_lengths)/ np.array(agent_speeds) + mean_shortest_path_time = np.mean(agent_shortest_path_times) + + # Deciding on a suitable max_episode_steps + longest_speed_normalized_time = np.max(agent_shortest_path_times) + mean_path_delay = mean_shortest_path_time * mean_shortest_path_multiplier + max_episode_steps_new = int(np.ceil(longest_speed_normalized_time * new_max_episode_steps_multiplier) + mean_path_delay) + + max_episode_steps_old = int(max_episode_steps * old_max_episode_steps_multiplier) + + max_episode_steps = min(max_episode_steps_new, max_episode_steps_old) + + end_buffer = int(max_episode_steps * end_buffer_multiplier) + latest_arrival_max = max_episode_steps-end_buffer + + # Useless unless needed by returning + earliest_departures = [] + latest_arrivals = [] + + for agent in agents: + agent_shortest_path_time = agent_shortest_path_times[agent.handle] + agent_travel_time_max = int(np.ceil((agent_shortest_path_time * travel_buffer_multiplier) + mean_path_delay)) + + departure_window_max = max(latest_arrival_max - agent_travel_time_max, 1) + + earliest_departure = np_random.randint(0, departure_window_max) + latest_arrival = earliest_departure + agent_travel_time_max + + earliest_departures.append(earliest_departure) + latest_arrivals.append(latest_arrival) + + agent.earliest_departure = earliest_departure + agent.latest_arrival = latest_arrival + + return Timetable(earliest_departures=earliest_departures, latest_arrivals=latest_arrivals, + max_episode_steps=max_episode_steps) diff --git a/flatland/envs/timetable_utils.py b/flatland/envs/timetable_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..548624f2c08879ce0e507224e61b6fe43ffb955b --- /dev/null +++ b/flatland/envs/timetable_utils.py @@ -0,0 +1,14 @@ +from typing import List, NamedTuple + +from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.core.grid.grid_utils import IntVector2DArray + +Line = NamedTuple('Line', [('agent_positions', IntVector2DArray), + ('agent_directions', List[Grid4TransitionsEnum]), + ('agent_targets', IntVector2DArray), + ('agent_speeds', List[float]), + ('agent_malfunction_rates', List[int])]) + +Timetable = NamedTuple('Timetable', [('earliest_departures', List[int]), + ('latest_arrivals', List[int]), + ('max_episode_steps', int)])