diff --git a/flatland/envs/rail_env_utils.py b/flatland/envs/rail_env_utils.py index 94f1c2d4b5defbabb08ffaae8757ae98b2332fd9..22f73f98dddf01a9e34b55b62ea1d81c69d69713 100644 --- a/flatland/envs/rail_env_utils.py +++ b/flatland/envs/rail_env_utils.py @@ -33,7 +33,7 @@ def load_flatland_environment_from_file(file_name: str, max_depth=2, predictor=ShortestPathPredictorForRailEnv(max_depth=10)) environment = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name, load_from_package), - schedule_generator=line_from_file(file_name, load_from_package), + line_generator=line_from_file(file_name, load_from_package), number_of_agents=1, obs_builder_object=obs_builder_object, record_steps=record_steps, diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index ac25118d8a3647924c3fb5fc3880ec83a4f9c03c..a8bd42b990794a0e61477f3f7ada2bf0cdfe176f 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -32,6 +32,7 @@ def schedule_generator(agents: List[EnvAgent], config_speeds: List[float], dist old_max_episode_steps_multiplier = 3.0 new_max_episode_steps_multiplier = 1.5 travel_buffer_multiplier = 1.3 # must be strictly lesser than new_max_episode_steps_multiplier + assert new_max_episode_steps_multiplier > travel_buffer_multiplier end_buffer_multiplier = 0.05 mean_shortest_path_multiplier = 0.2 @@ -39,20 +40,14 @@ def schedule_generator(agents: List[EnvAgent], config_speeds: List[float], dist shortest_paths_lengths = [len(v) for k,v in shortest_paths.items()] # Find mean_shortest_path_time - agent_shortest_path_times = [] - for agent in agents: - speed = agent.speed_data['speed'] - distance = shortest_paths_lengths[agent.handle] - agent_shortest_path_times.append(int(np.ceil(distance / speed))) - + agent_speeds = [agent.speed_data['speed'] for agent in agents] + agent_shortest_path_times = np.array(shortest_paths_lengths)/ np.array(agent_speeds) mean_shortest_path_time = np.mean(agent_shortest_path_times) # Deciding on a suitable max_episode_steps - max_sp_len = max(shortest_paths_lengths) # longest path - min_speed = min(config_speeds) # slowest possible speed in config - - longest_sp_time = max_sp_len / min_speed - max_episode_steps_new = int(np.ceil(longest_sp_time * new_max_episode_steps_multiplier)) + longest_speed_normalized_time = np.max(agent_shortest_path_times) + mean_path_delay = mean_shortest_path_time * mean_shortest_path_multiplier + max_episode_steps_new = int(np.ceil(longest_speed_normalized_time * new_max_episode_steps_multiplier) + mean_path_delay) max_episode_steps_old = int(max_episode_steps * old_max_episode_steps_multiplier) @@ -67,8 +62,7 @@ def schedule_generator(agents: List[EnvAgent], config_speeds: List[float], dist for agent in agents: agent_shortest_path_time = agent_shortest_path_times[agent.handle] - agent_travel_time_max = int(np.ceil((agent_shortest_path_time * travel_buffer_multiplier) \ - + (mean_shortest_path_time * mean_shortest_path_multiplier))) + agent_travel_time_max = int(np.ceil((agent_shortest_path_time * travel_buffer_multiplier) + mean_path_delay)) departure_window_max = max(latest_arrival_max - agent_travel_time_max, 1)