From 57ce38217fe097253a02c480a9905b567d78e5e7 Mon Sep 17 00:00:00 2001 From: hagrid67 <jdhwatson@gmail.com> Date: Mon, 28 Sep 2020 13:31:24 +0100 Subject: [PATCH] added FileMalfunctionGen to replace file_malfunction_generator --- flatland/envs/malfunction_generators.py | 284 +++++++++++++----------- flatland/envs/persistence.py | 5 +- 2 files changed, 154 insertions(+), 135 deletions(-) diff --git a/flatland/envs/malfunction_generators.py b/flatland/envs/malfunction_generators.py index 62a92080..0d27913d 100644 --- a/flatland/envs/malfunction_generators.py +++ b/flatland/envs/malfunction_generators.py @@ -8,13 +8,17 @@ from numpy.random.mtrand import RandomState from flatland.envs.agent_utils import EnvAgent, RailAgentStatus from flatland.envs import persistence -Malfunction = NamedTuple('Malfunction', [('num_broken_steps', int)]) + +# why do we have both MalfunctionParameters and MalfunctionProcessData - they are both the same! MalfunctionParameters = NamedTuple('MalfunctionParameters', [('malfunction_rate', float), ('min_duration', int), ('max_duration', int)]) -MalfunctionGenerator = Callable[[EnvAgent, RandomState, bool], Optional[Malfunction]] MalfunctionProcessData = NamedTuple('MalfunctionProcessData', [('malfunction_rate', float), ('min_duration', int), ('max_duration', int)]) +Malfunction = NamedTuple('Malfunction', [('num_broken_steps', int)]) + +# Why is the return value Optional? We always return a Malfunction. +MalfunctionGenerator = Callable[[EnvAgent, RandomState, bool], Optional[Malfunction]] def _malfunction_prob(rate: float) -> float: """ @@ -27,6 +31,146 @@ def _malfunction_prob(rate: float) -> float: else: return 1 - np.exp(-rate) +class ParamMalfunctionGen(object): + """ Preserving old behaviour of using MalfunctionParameters for constructor, + but returning MalfunctionProcessData in get_process_data. + Data structure and content is the same. + """ + def __init__(self, parameters: MalfunctionParameters): + #self.mean_malfunction_rate = parameters.malfunction_rate + #self.min_number_of_steps_broken = parameters.min_duration + #self.max_number_of_steps_broken = parameters.max_duration + self.MFP = parameters + + def generate(self, + agent: EnvAgent = None, + np_random: RandomState = None, + reset=False) -> Optional[Malfunction]: + + # Dummy reset function as we don't implement specific seeding here + if reset: + return Malfunction(0) + + if agent.malfunction_data['malfunction'] < 1: + if np_random.rand() < _malfunction_prob(self.MFP.malfunction_rate): + num_broken_steps = np_random.randint(self.MFP.min_duration, + self.MFP.max_duration + 1) + 1 + return Malfunction(num_broken_steps) + return Malfunction(0) + + def get_process_data(self): + return MalfunctionProcessData(*self.MFP) + + +class NoMalfunctionGen(ParamMalfunctionGen): + def __init__(self): + super().__init__(MalfunctionParameters(0,0,0)) + +class FileMalfunctionGen(ParamMalfunctionGen): + def __init__(self, env_dict=None, filename=None, load_from_package=None): + """ uses env_dict if populated, otherwise tries to load from file / package. + """ + if env_dict is None: + env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package) + + if "malfunction" in env_dict: + oMFP = MalfunctionParameters(*env_dict["malfunction"]) + else: + oMFP = MalfunctionParameters(0,0,0) # no malfunctions + super().__init__(oMFP) + + +################################################################################################ +# OLD / DEPRECATED generator functions below. To be removed. + +def no_malfunction_generator() -> Tuple[MalfunctionGenerator, MalfunctionProcessData]: + """ + Malfunction generator which generates no malfunctions + + Parameters + ---------- + Nothing + + Returns + ------- + generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken + """ + print("DEPRECATED - use NoMalfunctionGen instead of no_malfunction_generator") + # Mean malfunction in number of time steps + mean_malfunction_rate = 0. + + # Uniform distribution parameters for malfunction duration + min_number_of_steps_broken = 0 + max_number_of_steps_broken = 0 + + def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]: + return Malfunction(0) + + return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken, + max_number_of_steps_broken) + + + +def single_malfunction_generator(earlierst_malfunction: int, malfunction_duration: int) -> Tuple[ + MalfunctionGenerator, MalfunctionProcessData]: + """ + Malfunction generator which guarantees exactly one malfunction during an episode of an ACTIVE agent. + + Parameters + ---------- + earlierst_malfunction: Earliest possible malfunction onset + malfunction_duration: The duration of the single malfunction + + Returns + ------- + generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken + """ + # Mean malfunction in number of time steps + mean_malfunction_rate = 0. + + # Uniform distribution parameters for malfunction duration + min_number_of_steps_broken = 0 + max_number_of_steps_broken = 0 + + # Keep track of the total number of malfunctions in the env + global_nr_malfunctions = 0 + + # Malfunction calls per agent + malfunction_calls = dict() + + def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]: + # We use the global variable to assure only a single malfunction in the env + nonlocal global_nr_malfunctions + nonlocal malfunction_calls + + # Reset malfunciton generator + if reset: + nonlocal global_nr_malfunctions + nonlocal malfunction_calls + global_nr_malfunctions = 0 + malfunction_calls = dict() + return Malfunction(0) + + # No more malfunctions if we already had one, ignore all updates + if global_nr_malfunctions > 0: + return Malfunction(0) + + # Update number of calls per agent + if agent.handle in malfunction_calls: + malfunction_calls[agent.handle] += 1 + else: + malfunction_calls[agent.handle] = 1 + + # Break an agent that is active at the time of the malfunction + if agent.status == RailAgentStatus.ACTIVE and malfunction_calls[agent.handle] >= earlierst_malfunction: + global_nr_malfunctions += 1 + return Malfunction(malfunction_duration) + else: + return Malfunction(0) + + return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken, + max_number_of_steps_broken) + def malfunction_from_file(filename: str, load_from_package=None) -> Tuple[MalfunctionGenerator, MalfunctionProcessData]: """ @@ -40,13 +184,9 @@ def malfunction_from_file(filename: str, load_from_package=None) -> Tuple[Malfun ------- generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken """ - # with open(filename, "rb") as file_in: - # load_data = file_in.read() - # if filename.endswith("mpk"): - # data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8') - # elif filename.endswith("pkl"): - # data = pickle.loads(load_data) + print("DEPRECATED - use FileMalfunctionGen instead of malfunction_from_file") + env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package) # TODO: make this better by using namedtuple in the pickle file. See issue 282 if "malfunction" in env_dict: @@ -111,6 +251,9 @@ def malfunction_from_params(parameters: MalfunctionParameters) -> Tuple[Malfunct ------- generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken """ + + print("DEPRECATED - use ParamMalfunctionGen instead of malfunction_from_params") + mean_malfunction_rate = parameters.malfunction_rate min_number_of_steps_broken = parameters.min_duration max_number_of_steps_broken = parameters.max_duration @@ -142,128 +285,3 @@ def malfunction_from_params(parameters: MalfunctionParameters) -> Tuple[Malfunct return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken) - -class ParamMalfunctionGen(object): - def __init__(self, parameters: MalfunctionParameters): - self.mean_malfunction_rate = parameters.malfunction_rate - self.min_number_of_steps_broken = parameters.min_duration - self.max_number_of_steps_broken = parameters.max_duration - - def generate(self, agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]: - - # Dummy reset function as we don't implement specific seeding here - if reset: - return Malfunction(0) - - if agent.malfunction_data['malfunction'] < 1: - if np_random.rand() < _malfunction_prob(self.mean_malfunction_rate): - num_broken_steps = np_random.randint(self.min_number_of_steps_broken, - self.max_number_of_steps_broken + 1) + 1 - return Malfunction(num_broken_steps) - return Malfunction(0) - - def get_process_data(self): - return MalfunctionProcessData( - self.mean_malfunction_rate, - self.min_number_of_steps_broken, - self.max_number_of_steps_broken) - - -class NoMalfunctionGen(ParamMalfunctionGen): - def __init__(self): - self.mean_malfunction_rate = 0. - self.min_number_of_steps_broken = 0 - self.max_number_of_steps_broken = 0 - - def generate(self, agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]: - return Malfunction(0) - - - - -def no_malfunction_generator() -> Tuple[MalfunctionGenerator, MalfunctionProcessData]: - """ - Malfunction generator which generates no malfunctions - - Parameters - ---------- - Nothing - - Returns - ------- - generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken - """ - # Mean malfunction in number of time steps - mean_malfunction_rate = 0. - - # Uniform distribution parameters for malfunction duration - min_number_of_steps_broken = 0 - max_number_of_steps_broken = 0 - - def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]: - return Malfunction(0) - - return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken, - max_number_of_steps_broken) - - - -def single_malfunction_generator(earlierst_malfunction: int, malfunction_duration: int) -> Tuple[ - MalfunctionGenerator, MalfunctionProcessData]: - """ - Malfunction generator which guarantees exactly one malfunction during an episode of an ACTIVE agent. - - Parameters - ---------- - earlierst_malfunction: Earliest possible malfunction onset - malfunction_duration: The duration of the single malfunction - - Returns - ------- - generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken - """ - # Mean malfunction in number of time steps - mean_malfunction_rate = 0. - - # Uniform distribution parameters for malfunction duration - min_number_of_steps_broken = 0 - max_number_of_steps_broken = 0 - - # Keep track of the total number of malfunctions in the env - global_nr_malfunctions = 0 - - # Malfunction calls per agent - malfunction_calls = dict() - - def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]: - # We use the global variable to assure only a single malfunction in the env - nonlocal global_nr_malfunctions - nonlocal malfunction_calls - - # Reset malfunciton generator - if reset: - nonlocal global_nr_malfunctions - nonlocal malfunction_calls - global_nr_malfunctions = 0 - malfunction_calls = dict() - return Malfunction(0) - - # No more malfunctions if we already had one, ignore all updates - if global_nr_malfunctions > 0: - return Malfunction(0) - - # Update number of calls per agent - if agent.handle in malfunction_calls: - malfunction_calls[agent.handle] += 1 - else: - malfunction_calls[agent.handle] = 1 - - # Break an agent that is active at the time of the malfunction - if agent.status == RailAgentStatus.ACTIVE and malfunction_calls[agent.handle] >= earlierst_malfunction: - global_nr_malfunctions += 1 - return Malfunction(malfunction_duration) - else: - return Malfunction(0) - - return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken, - max_number_of_steps_broken) diff --git a/flatland/envs/persistence.py b/flatland/envs/persistence.py index 1b0f05f1..bc4b169b 100644 --- a/flatland/envs/persistence.py +++ b/flatland/envs/persistence.py @@ -124,8 +124,9 @@ class RailEnvPersister(object): load_from_package=load_from_package), schedule_generator=sched_gen.schedule_from_file(filename, load_from_package=load_from_package), - malfunction_generator_and_process_data=mal_gen.malfunction_from_file(filename, - load_from_package=load_from_package), + #malfunction_generator_and_process_data=mal_gen.malfunction_from_file(filename, + # load_from_package=load_from_package), + malfunction_generator=mal_gen.FileMalfunctionGen(env_dict), obs_builder_object=DummyObservationBuilder(), record_steps=True) -- GitLab