diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index ed507dca9de9b3e90d412e77a7204037a5a20975..8f2c5231232c3cc17f50de4b89b88f1c9fdc5d60 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -787,6 +787,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 num_agents -= 1 return grid_map, {'agents_hints': { + 'num_agents': num_agents, 'agent_start_targets_nodes': agent_start_targets_nodes, 'train_stations': train_stations }} diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index 0ebc6c71c17db308789a4baf0ec99729ec9991e8..2ef6dab85fc6145bbfb57b25994903bbd861f65f 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -60,6 +60,7 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None): train_stations = hints['train_stations'] agent_start_targets_nodes = hints['agent_start_targets_nodes'] + num_agents = hints['num_agents'] # Place agents and targets within available train stations agents_position = [] agents_target = [] @@ -207,7 +208,7 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> return generator -def agents_from_file(filename) -> ScheduleGenerator: +def schedule_from_file(filename) -> ScheduleGenerator: """ Utility to load pickle file diff --git a/tests/tests_generators.py b/tests/tests_generators.py index 610022cafe12fccb2cbbd5da57006e61c89faf28..8b0480c887a53ade155c28aa6199db3d32f19603 100644 --- a/tests/tests_generators.py +++ b/tests/tests_generators.py @@ -9,7 +9,7 @@ from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import rail_from_grid_transition_map, rail_from_file, complex_rail_generator, \ random_rail_generator, empty_rail_generator from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator, \ - agents_from_file + schedule_from_file from flatland.utils.simple_rail import make_simple_rail @@ -137,7 +137,7 @@ def tests_rail_from_file(): env = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name), - schedule_generator=agents_from_file(file_name), + schedule_generator=schedule_from_file(file_name), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) @@ -173,7 +173,7 @@ def tests_rail_from_file(): env2 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name_2), - schedule_generator=agents_from_file(file_name_2), + schedule_generator=schedule_from_file(file_name_2), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv(), ) @@ -190,7 +190,7 @@ def tests_rail_from_file(): env3 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name), - schedule_generator=agents_from_file(file_name), + schedule_generator=schedule_from_file(file_name), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv(), ) @@ -208,7 +208,7 @@ def tests_rail_from_file(): env4 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name_2), - schedule_generator=agents_from_file(file_name_2), + schedule_generator=schedule_from_file(file_name_2), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2), )