diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index 54e2c52afc646540f54800e0674f773021302e6e..3899798ad0900470e58c4fe1732275a25d841a5c 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -78,7 +78,7 @@ def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None, se speeds = [1.0] * len(agents_position) # Compute max number of steps with given schedule nice_factor = 1.5 # Factor to allow for more then minimal time - max_episode_steps = nice_factor * rail.height * rail.width + max_episode_steps = int(nice_factor * rail.height * rail.width) return Schedule(agent_positions=agents_position, agent_directions=agents_direction, agent_targets=agents_target, agent_speeds=speeds, agent_malfunction_rates=None, @@ -271,7 +271,7 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] = # Compute max number of steps with given schedule nice_factor = 1.5 # Factor to allow for more then minimal time - max_episode_steps = nice_factor * rail.height * rail.width + max_episode_steps = int(nice_factor * rail.height * rail.width) return Schedule(agent_positions=agents_position, agent_directions=agents_direction, agent_targets=agents_target, agent_speeds=agents_speed, agent_malfunction_rates=None, diff --git a/tests/test_flatland_schedule_generators.py b/tests/test_flatland_schedule_generators.py index 77d6b6cdc06ff06ef7f01ac9c2c57cf8e0040c27..487d619a347f7d4276f3e7e88a2986ac804669b0 100644 --- a/tests/test_flatland_schedule_generators.py +++ b/tests/test_flatland_schedule_generators.py @@ -73,27 +73,35 @@ def test_schedule_from_file(): rail_generator = rail_from_file("./sparse_env_test.pkl") schedule_generator = schedule_from_file("./sparse_env_test.pkl") sparse_env_from_file = RailEnv(width=1, height=1, rail_generator=rail_generator, schedule_generator=schedule_generator) - sparse_env_from_file.reset(True,True) + sparse_env_from_file.reset(True, True) + + # Assert loaded agent number is correct assert sparse_env_from_file.get_num_agents() == 10 + # Assert max steps is correct + assert sparse_env_from_file._max_episode_steps == 500 + # Random generator rail_generator = rail_from_file("./random_env_test.pkl") schedule_generator = schedule_from_file("./random_env_test.pkl") random_env_from_file = RailEnv(width=1, height=1, rail_generator=rail_generator, schedule_generator=schedule_generator) - random_env_from_file.reset(True,True) + random_env_from_file.reset(True, True) + + # Assert loaded agent number is correct assert random_env_from_file.get_num_agents() == 10 + # Assert max steps is correct + assert random_env_from_file._max_episode_steps == 1350 + # Complex generator rail_generator = rail_from_file("./complex_env_test.pkl") schedule_generator = schedule_from_file("./complex_env_test.pkl") complex_env_from_file = RailEnv(width=1, height=1, rail_generator=rail_generator, schedule_generator=schedule_generator) - complex_env_from_file.reset(True,True) - assert complex_env_from_file.get_num_agents() == 10 - -# def test_sparse_schedule_generator(): + complex_env_from_file.reset(True, True) + # Assert loaded agent number is correct + assert complex_env_from_file.get_num_agents() == 10 -# def test_random_schedule_generator(): - + # Assert max steps is correct + assert complex_env_from_file._max_episode_steps == 1350 -# def test_complex_schedule_generator():