diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index dabb2609b173306342f86bc945de231d75c7d9de..cbcc4591464abf843ec600d97b15f47cb8a07d6b 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -239,32 +239,6 @@ class RailEnv(Environment): agent.reset() self.active_agents = [i for i in range(len(self.agents))] - @staticmethod - def compute_max_episode_steps(width: int, height: int, ratio_nr_agents_to_nr_cities: float = 20.0) -> int: - """ - compute_max_episode_steps(width, height, ratio_nr_agents_to_nr_cities, timedelay_factor, alpha) - - The method computes the max number of episode steps allowed - - Parameters - ---------- - width : int - width of environment - height : int - height of environment - ratio_nr_agents_to_nr_cities : float, optional - number_of_agents/number_of_cities - - Returns - ------- - max_episode_steps: int - maximum number of episode steps - - """ - timedelay_factor = 4 - alpha = 2 - return int(timedelay_factor * alpha * (width + height + ratio_nr_agents_to_nr_cities)) - def action_required(self, agent): """ Check if an agent needs to provide an action @@ -328,8 +302,6 @@ class RailEnv(Environment): if optionals and 'distance_map' in optionals: self.distance_map.set(optionals['distance_map']) - - if regenerate_schedule or regenerate_rail or self.get_num_agents() == 0: agents_hints = None if optionals and 'agents_hints' in optionals: @@ -339,13 +311,8 @@ class RailEnv(Environment): self.np_random) self.agents = EnvAgent.from_schedule(schedule) - if agents_hints and 'city_orientations' in agents_hints: - ratio_nr_agents_to_nr_cities = self.get_num_agents() / len(agents_hints['city_orientations']) - self._max_episode_steps = self.compute_max_episode_steps( - width=self.width, height=self.height, - ratio_nr_agents_to_nr_cities=ratio_nr_agents_to_nr_cities) - else: - self._max_episode_steps = self.compute_max_episode_steps(width=self.width, height=self.height) + # Get max number of allowed time steps from 
schedule generator + self._max_episode_steps = schedule.max_episode_steps self.agent_positions = np.zeros((self.height, self.width), dtype=int) - 1 @@ -836,7 +803,8 @@ class RailEnv(Environment): msg_data = { "grid": grid_data, "agents": agent_data, - "malfunction": malfunction_data} + "malfunction": malfunction_data, + "max_episode_steps": self._max_episode_steps} return msgpack.packb(msg_data, use_bin_type=True) def get_agent_state_msg(self) -> Packer: @@ -863,7 +831,8 @@ class RailEnv(Environment): "grid": grid_data, "agents": agent_data, "distance_map": distance_map_data, - "malfunction": malfunction_data} + "malfunction": malfunction_data, + "max_episode_steps": self._max_episode_steps} return msgpack.packb(msg_data, use_bin_type=True) def set_full_state_msg(self, msg_data): diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index fb355967b93f05042e6b31edfa7948e181da72c8..cb8569643698f68612720a175192af885a222d6d 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -167,6 +167,9 @@ def complex_rail_generator(nr_start_goal=1, if len(new_path) >= 2: nr_created += 1 + else: + # after too many failures we will give up + created_sanity += 1 return grid_map, {'agents_hints': { 'start_goal': start_goal, diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index 7736d53c667904e23718a78a2c4601fa2574c08b..54e2c52afc646540f54800e0674f773021302e6e 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -76,9 +76,13 @@ def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None, se speeds = speed_initialization_helper(num_agents, speed_ratio_map, seed=_runtime_seed, np_random=np_random) else: speeds = [1.0] * len(agents_position) + # Compute max number of steps with given schedule (Schedule.max_episode_steps is an int) + nice_factor = 1.5 # Factor to allow for more than minimal time + max_episode_steps = int(nice_factor * rail.height * rail.width) return 
Schedule(agent_positions=agents_position, agent_directions=agents_direction, - agent_targets=agents_target, agent_speeds=speeds, agent_malfunction_rates=None) + agent_targets=agents_target, agent_speeds=speeds, agent_malfunction_rates=None, + max_episode_steps=max_episode_steps) return generator @@ -162,9 +166,14 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, see speeds = speed_initialization_helper(num_agents, speed_ratio_map, seed=_runtime_seed, np_random=np_random) else: speeds = [1.0] * len(agents_position) + timedelay_factor = 4 + alpha = 2 + max_episode_steps = int( + timedelay_factor * alpha * (rail.width + rail.height + num_agents / len(city_positions))) return Schedule(agent_positions=agents_position, agent_directions=agents_direction, - agent_targets=agents_target, agent_speeds=speeds, agent_malfunction_rates=None) + agent_targets=agents_target, agent_speeds=speeds, agent_malfunction_rates=None, + max_episode_steps=max_episode_steps) return generator @@ -196,12 +205,12 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] = valid_positions.append((r, c)) if len(valid_positions) == 0: return Schedule(agent_positions=[], agent_directions=[], - agent_targets=[], agent_speeds=[], agent_malfunction_rates=None) + agent_targets=[], agent_speeds=[], agent_malfunction_rates=None, max_episode_steps=0) if len(valid_positions) < num_agents: warnings.warn("schedule_generators: len(valid_positions) < num_agents") return Schedule(agent_positions=[], agent_directions=[], - agent_targets=[], agent_speeds=[], agent_malfunction_rates=None) + agent_targets=[], agent_speeds=[], agent_malfunction_rates=None, max_episode_steps=0) agents_position_idx = [i for i in np_random.choice(len(valid_positions), num_agents, replace=False)] agents_position = [valid_positions[agents_position_idx[i]] for i in range(num_agents)] @@ -259,8 +268,14 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] = 
np_random.choice(len(valid_starting_directions), 1)[0]] agents_speed = speed_initialization_helper(num_agents, speed_ratio_map, seed=_runtime_seed, np_random=np_random) + + # Compute max number of steps with given schedule (Schedule.max_episode_steps is an int) + nice_factor = 1.5 # Factor to allow for more than minimal time + max_episode_steps = int(nice_factor * rail.height * rail.width) + return Schedule(agent_positions=agents_position, agent_directions=agents_direction, - agent_targets=agents_target, agent_speeds=agents_speed, agent_malfunction_rates=None) + agent_targets=agents_target, agent_speeds=agents_speed, agent_malfunction_rates=None, + max_episode_steps=max_episode_steps) return generator @@ -292,7 +307,11 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator: agents = EnvAgent.load_legacy_static_agent(data["agents_static"]) else: agents = [EnvAgent(*d[0:12]) for d in data["agents"]] - + if "max_episode_steps" in data: + max_episode_steps = data["max_episode_steps"] + else: + # If no max time was found return 0. 
+ max_episode_steps = 0 # setup with loaded data agents_position = [a.initial_position for a in agents] agents_direction = [a.direction for a in agents] @@ -301,6 +320,7 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator: agents_malfunction = [a.malfunction_data['malfunction_rate'] for a in agents] return Schedule(agent_positions=agents_position, agent_directions=agents_direction, - agent_targets=agents_target, agent_speeds=agents_speed, agent_malfunction_rates=None) + agent_targets=agents_target, agent_speeds=agents_speed, agent_malfunction_rates=None, + max_episode_steps=max_episode_steps) return generator diff --git a/flatland/envs/schedule_utils.py b/flatland/envs/schedule_utils.py index e89f170dbb87388bcecbc6b2e176ba277162a4db..a811ea4af7d7f7faccfe16e94adf117cba05d6b8 100644 --- a/flatland/envs/schedule_utils.py +++ b/flatland/envs/schedule_utils.py @@ -7,4 +7,5 @@ Schedule = NamedTuple('Schedule', [('agent_positions', IntVector2DArray), ('agent_directions', List[Grid4TransitionsEnum]), ('agent_targets', IntVector2DArray), ('agent_speeds', List[float]), - ('agent_malfunction_rates', List[int])]) + ('agent_malfunction_rates', List[int]), + ('max_episode_steps', int)]) diff --git a/tests/test_flatland_schedule_generators.py b/tests/test_flatland_schedule_generators.py new file mode 100644 index 0000000000000000000000000000000000000000..255b7f34b0c62260da7708b78523d3570e4b3252 --- /dev/null +++ b/tests/test_flatland_schedule_generators.py @@ -0,0 +1,74 @@ +from test_utils import create_and_save_env + +from flatland.envs.rail_generators import sparse_rail_generator, random_rail_generator, complex_rail_generator +from flatland.envs.schedule_generators import sparse_schedule_generator, random_schedule_generator, \ + complex_schedule_generator + + +def test_schedule_from_file(): + """ + Test to see that all parameters are loaded as expected + Returns + ------- + + """ + # Generate Sparse test env + rail_generator = 
sparse_rail_generator(max_num_cities=5, + seed=1, + grid_mode=False, + max_rails_between_cities=3, + max_rails_in_city=6, + ) + + # Different agent types (trains) with different speeds. + speed_ration_map = {1.: 0.25, # Fast passenger train + 1. / 2.: 0.25, # Fast freight train + 1. / 3.: 0.25, # Slow commuter train + 1. / 4.: 0.25} # Slow freight train + + schedule_generator = sparse_schedule_generator( + speed_ration_map) + + create_and_save_env(file_name="./sparse_env_test.pkl", rail_generator=rail_generator, + schedule_generator=schedule_generator) + + # Generate random test env + rail_generator = random_rail_generator() + + # Different agent types (trains) with different speeds. + speed_ration_map = {1.: 0.25, # Fast passenger train + 1. / 2.: 0.25, # Fast freight train + 1. / 3.: 0.25, # Slow commuter train + 1. / 4.: 0.25} # Slow freight train + + schedule_generator = random_schedule_generator( + speed_ration_map) + + create_and_save_env(file_name="./random_env_test.pkl", rail_generator=rail_generator, + schedule_generator=schedule_generator) + + # Generate complex test env + rail_generator = complex_rail_generator(nr_start_goal=10, + nr_extra=1, + min_dist=8, + max_dist=99999) + + # Different agent types (trains) with different speeds. + speed_ration_map = {1.: 0.25, # Fast passenger train + 1. / 2.: 0.25, # Fast freight train + 1. / 3.: 0.25, # Slow commuter train + 1. 
/ 4.: 0.25} # Slow freight train + + schedule_generator = complex_schedule_generator( + speed_ration_map) + + create_and_save_env(file_name="./complex_env_test.pkl", rail_generator=rail_generator, + schedule_generator=schedule_generator) + +# def test_sparse_schedule_generator(): + + +# def test_random_schedule_generator(): + + +# def test_complex_schedule_generator(): diff --git a/tests/test_utils.py b/tests/test_utils.py index e4fba2aebd795462971e8d1e8f16992c2affbac8..bb344962f84ab6b651e9cd688f7784804e8062c2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,6 +3,11 @@ from typing import List, Tuple, Optional import numpy as np from attr import attrs, attrib +from flatland.envs.malfunction_generators import MalfunctionParameters, malfunction_from_params + +from flatland.envs.rail_generators import RailGenerator + +from flatland.envs.schedule_generators import ScheduleGenerator from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.envs.agent_utils import EnvAgent, RailAgentStatus @@ -131,3 +136,22 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: replay = test_config.replay[step] _assert(a, rewards_dict[a], replay.reward, 'reward') + + +def create_and_save_env(file_name: str, schedule_generator: ScheduleGenerator, rail_generator: RailGenerator): + + + stochastic_data = MalfunctionParameters(malfunction_rate=1000, # Rate of malfunction occurrence + min_duration=15, # Minimal duration of malfunction + max_duration=50 # Max duration of malfunction + ) + + env = RailEnv(width=30, + height=30, + rail_generator=rail_generator, + schedule_generator=schedule_generator, + number_of_agents=10, + malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), + remove_agents_at_target=True) + env.reset(True, True) + env.save(file_name)