diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index b90d38a427533846baadb704f5137c90c1044f73..fffe7ff786a32a6796af9667f1dfb9a3eb92ce9c 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -8,7 +8,7 @@ from typing import Tuple, Optional, NamedTuple, List from attr import attr, attrs, attrib, Factory from flatland.core.grid.grid4 import Grid4TransitionsEnum -from flatland.envs.schedule_utils import Line +from flatland.envs.timetable_utils import Line class RailAgentStatus(IntEnum): WAITING = 0 diff --git a/flatland/envs/line_generators.py b/flatland/envs/line_generators.py index e7d6ac170aa82f86e75fe291c454648defedabd9..74d01e6f23856e9f14d2fbe70eb2bdbfb85175be 100644 --- a/flatland/envs/line_generators.py +++ b/flatland/envs/line_generators.py @@ -8,7 +8,7 @@ from numpy.random.mtrand import RandomState from flatland.core.grid.grid4_utils import get_new_position from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgent -from flatland.envs.schedule_utils import Line +from flatland.envs.timetable_utils import Line from flatland.envs import persistence AgentPosition = Tuple[int, int] diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 591ac48b27718911aa0788c275e29d3e67fe3c60..4ee6dbaed0af3ef6a29b287f3d343830432f73e4 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -23,7 +23,7 @@ from flatland.envs.rail_env_action import RailEnvActions from flatland.envs import malfunction_generators as mal_gen from flatland.envs import rail_generators as rail_gen from flatland.envs import line_generators as line_gen -from flatland.envs.schedule_generators import schedule_generator +from flatland.envs.timetable_generators import timetable_generator from flatland.envs import persistence from flatland.envs import agent_chains as ac @@ -369,14 +369,14 @@ class RailEnv(Environment): self.distance_map.reset(self.agents, self.rail) # NEW : Time Schedule Generation - schedule = schedule_generator(self.agents, self.distance_map, + timetable = timetable_generator(self.agents, self.distance_map, agents_hints, self.np_random) - self._max_episode_steps = schedule.max_episode_steps + self._max_episode_steps = timetable.max_episode_steps for agent_i, agent in enumerate(self.agents): - agent.earliest_departure = schedule.earliest_departures[agent_i] - agent.latest_arrival = schedule.latest_arrivals[agent_i] + agent.earliest_departure = timetable.earliest_departures[agent_i] + agent.latest_arrival = timetable.latest_arrivals[agent_i] # Agent Positions Map self.agent_positions = np.zeros((self.height, self.width), dtype=int) - 1 diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index 1650ae579cd830deefbc4286da8ca8922ef0cd49..6abaddd0098a2bff27fff06de2cbcacda8a05ac6 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -1,84 +1 @@ -import os -import json -import itertools -import warnings -from typing import Tuple, List, Callable, Mapping, Optional, Any -from flatland.envs.schedule_utils import Schedule - -import numpy as np -from numpy.random.mtrand import RandomState - -from flatland.envs.agent_utils import EnvAgent -from flatland.envs.distance_map import DistanceMap -from flatland.envs.rail_env_shortest_paths import get_shortest_paths - -def len_handle_none(v): - if v is not None: - return len(v) - else: - return 0 - -def schedule_generator(agents: List[EnvAgent], distance_map: DistanceMap, - agents_hints: dict, np_random: RandomState = None) -> Schedule: - - # max_episode_steps calculation - if agents_hints: - city_positions = agents_hints['city_positions'] - num_cities = len(city_positions) - else: - num_cities = 2 - - timedelay_factor = 4 - alpha = 2 - max_episode_steps = int(timedelay_factor * alpha * \ - (distance_map.rail.width + distance_map.rail.height + (len(agents) / num_cities))) - - # Multipliers - old_max_episode_steps_multiplier = 3.0 - new_max_episode_steps_multiplier = 1.5 - travel_buffer_multiplier = 1.3 # must be strictly lesser than new_max_episode_steps_multiplier - assert new_max_episode_steps_multiplier > travel_buffer_multiplier - end_buffer_multiplier = 0.05 - mean_shortest_path_multiplier = 0.2 - - shortest_paths = get_shortest_paths(distance_map) - shortest_paths_lengths = [len_handle_none(v) for k,v in shortest_paths.items()] - - # Find mean_shortest_path_time - agent_speeds = [agent.speed_data['speed'] for agent in agents] - agent_shortest_path_times = np.array(shortest_paths_lengths)/ np.array(agent_speeds) - mean_shortest_path_time = np.mean(agent_shortest_path_times) - - # Deciding on a suitable max_episode_steps - longest_speed_normalized_time = np.max(agent_shortest_path_times) - mean_path_delay = mean_shortest_path_time * mean_shortest_path_multiplier - max_episode_steps_new = int(np.ceil(longest_speed_normalized_time * new_max_episode_steps_multiplier) + mean_path_delay) - - max_episode_steps_old = int(max_episode_steps * old_max_episode_steps_multiplier) - - max_episode_steps = min(max_episode_steps_new, max_episode_steps_old) - - end_buffer = int(max_episode_steps * end_buffer_multiplier) - latest_arrival_max = max_episode_steps-end_buffer - - # Useless unless needed by returning - earliest_departures = [] - latest_arrivals = [] - - for agent in agents: - agent_shortest_path_time = agent_shortest_path_times[agent.handle] - agent_travel_time_max = int(np.ceil((agent_shortest_path_time * travel_buffer_multiplier) + mean_path_delay)) - - departure_window_max = max(latest_arrival_max - agent_travel_time_max, 1) - - earliest_departure = np_random.randint(0, departure_window_max) - latest_arrival = earliest_departure + agent_travel_time_max - - earliest_departures.append(earliest_departure) - latest_arrivals.append(latest_arrival) - - agent.earliest_departure = earliest_departure - agent.latest_arrival = latest_arrival - - return Schedule(earliest_departures=earliest_departures, latest_arrivals=latest_arrivals, - max_episode_steps=max_episode_steps) +raise ImportError(" Schedule Generators is now renamed to line_generators, any reference to schedule should be replaced with line") \ No newline at end of file diff --git a/flatland/envs/schedule_utils.py b/flatland/envs/schedule_utils.py deleted file mode 100644 index d8340b2708d98e4bddcc525e8169a8f659058aa9..0000000000000000000000000000000000000000 --- a/flatland/envs/schedule_utils.py +++ /dev/null @@ -1,14 +0,0 @@ -from typing import List, NamedTuple - -from flatland.core.grid.grid4 import Grid4TransitionsEnum -from flatland.core.grid.grid_utils import IntVector2DArray - -Line = NamedTuple('Line', [('agent_positions', IntVector2DArray), - ('agent_directions', List[Grid4TransitionsEnum]), - ('agent_targets', IntVector2DArray), - ('agent_speeds', List[float]), - ('agent_malfunction_rates', List[int])]) - -Schedule = NamedTuple('Schedule', [('earliest_departures', List[int]), - ('latest_arrivals', List[int]), - ('max_episode_steps', int)])