From 57ce38217fe097253a02c480a9905b567d78e5e7 Mon Sep 17 00:00:00 2001
From: hagrid67 <jdhwatson@gmail.com>
Date: Mon, 28 Sep 2020 13:31:24 +0100
Subject: [PATCH] added FileMalfunctionGen to replace
 file_malfunction_generator

---
 flatland/envs/malfunction_generators.py | 284 +++++++++++++-----------
 flatland/envs/persistence.py            |   5 +-
 2 files changed, 154 insertions(+), 135 deletions(-)

diff --git a/flatland/envs/malfunction_generators.py b/flatland/envs/malfunction_generators.py
index 62a92080..0d27913d 100644
--- a/flatland/envs/malfunction_generators.py
+++ b/flatland/envs/malfunction_generators.py
@@ -8,13 +8,17 @@ from numpy.random.mtrand import RandomState
 from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
 from flatland.envs import persistence
 
-Malfunction = NamedTuple('Malfunction', [('num_broken_steps', int)])
+
+# why do we have both MalfunctionParameters and MalfunctionProcessData - they are both the same!
 MalfunctionParameters = NamedTuple('MalfunctionParameters',
                                    [('malfunction_rate', float), ('min_duration', int), ('max_duration', int)])
-MalfunctionGenerator = Callable[[EnvAgent, RandomState, bool], Optional[Malfunction]]
 MalfunctionProcessData = NamedTuple('MalfunctionProcessData',
                                     [('malfunction_rate', float), ('min_duration', int), ('max_duration', int)])
 
+Malfunction = NamedTuple('Malfunction', [('num_broken_steps', int)])
+
+# Why is the return value Optional?  We always return a Malfunction.
+MalfunctionGenerator = Callable[[EnvAgent, RandomState, bool], Optional[Malfunction]]
 
 def _malfunction_prob(rate: float) -> float:
     """
@@ -27,6 +31,146 @@ def _malfunction_prob(rate: float) -> float:
     else:
         return 1 - np.exp(-rate)
 
+class ParamMalfunctionGen(object):
+    """ Preserving old behaviour of using MalfunctionParameters for constructor,
+        but returning MalfunctionProcessData in get_process_data.  
+        Data structure and content is the same.
+    """
+    def __init__(self, parameters: MalfunctionParameters):
+        #self.mean_malfunction_rate = parameters.malfunction_rate
+        #self.min_number_of_steps_broken = parameters.min_duration
+        #self.max_number_of_steps_broken = parameters.max_duration
+        self.MFP = parameters
+
+    def generate(self,
+        agent: EnvAgent = None,
+        np_random: RandomState = None,
+        reset=False) -> Optional[Malfunction]:
+      
+        # Dummy reset function as we don't implement specific seeding here
+        if reset:
+            return Malfunction(0)
+
+        if agent.malfunction_data['malfunction'] < 1:
+            if np_random.rand() < _malfunction_prob(self.MFP.malfunction_rate):
+                num_broken_steps = np_random.randint(self.MFP.min_duration,
+                                                     self.MFP.max_duration + 1) + 1
+                return Malfunction(num_broken_steps)
+        return Malfunction(0)
+
+    def get_process_data(self):
+        return MalfunctionProcessData(*self.MFP)
+
+
+class NoMalfunctionGen(ParamMalfunctionGen):
+    def __init__(self):
+        super().__init__(MalfunctionParameters(0,0,0))
+
+class FileMalfunctionGen(ParamMalfunctionGen):
+    def __init__(self, env_dict=None, filename=None, load_from_package=None):
+        """ uses env_dict if populated, otherwise tries to load from file / package.
+        """
+        if env_dict is None:
+             env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package)
+
+        if "malfunction" in env_dict:
+            oMFP = MalfunctionParameters(*env_dict["malfunction"])
+        else:
+            oMFP = MalfunctionParameters(0,0,0)  # no malfunctions
+        super().__init__(oMFP)
+
+
+################################################################################################
+# OLD / DEPRECATED generator functions below. To be removed.
+
+def no_malfunction_generator() -> Tuple[MalfunctionGenerator, MalfunctionProcessData]:
+    """
+    Malfunction generator which generates no malfunctions
+
+    Parameters
+    ----------
+    Nothing
+
+    Returns
+    -------
+    generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
+    """
+    print("DEPRECATED - use NoMalfunctionGen instead of no_malfunction_generator")
+    # Mean malfunction in number of time steps
+    mean_malfunction_rate = 0.
+
+    # Uniform distribution parameters for malfunction duration
+    min_number_of_steps_broken = 0
+    max_number_of_steps_broken = 0
+
+    def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
+        return Malfunction(0)
+
+    return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
+                                             max_number_of_steps_broken)
+
+
+
+def single_malfunction_generator(earlierst_malfunction: int, malfunction_duration: int) -> Tuple[
+    MalfunctionGenerator, MalfunctionProcessData]:
+    """
+    Malfunction generator which guarantees exactly one malfunction during an episode of an ACTIVE agent.
+
+    Parameters
+    ----------
+    earlierst_malfunction: Earliest possible malfunction onset
+    malfunction_duration: The duration of the single malfunction
+
+    Returns
+    -------
+    generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
+    """
+    # Mean malfunction in number of time steps
+    mean_malfunction_rate = 0.
+
+    # Uniform distribution parameters for malfunction duration
+    min_number_of_steps_broken = 0
+    max_number_of_steps_broken = 0
+
+    # Keep track of the total number of malfunctions in the env
+    global_nr_malfunctions = 0
+
+    # Malfunction calls per agent
+    malfunction_calls = dict()
+
+    def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
+        # We use the global variable to assure only a single malfunction in the env
+        nonlocal global_nr_malfunctions
+        nonlocal malfunction_calls
+
+        # Reset malfunciton generator
+        if reset:
+            nonlocal global_nr_malfunctions
+            nonlocal malfunction_calls
+            global_nr_malfunctions = 0
+            malfunction_calls = dict()
+            return Malfunction(0)
+
+        # No more malfunctions if we already had one, ignore all updates
+        if global_nr_malfunctions > 0:
+            return Malfunction(0)
+
+        # Update number of calls per agent
+        if agent.handle in malfunction_calls:
+            malfunction_calls[agent.handle] += 1
+        else:
+            malfunction_calls[agent.handle] = 1
+
+        # Break an agent that is active at the time of the malfunction
+        if agent.status == RailAgentStatus.ACTIVE and malfunction_calls[agent.handle] >= earlierst_malfunction:
+            global_nr_malfunctions += 1
+            return Malfunction(malfunction_duration)
+        else:
+            return Malfunction(0)
+
+    return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
+                                             max_number_of_steps_broken)
+
 
 def malfunction_from_file(filename: str, load_from_package=None) -> Tuple[MalfunctionGenerator, MalfunctionProcessData]:
     """
@@ -40,13 +184,9 @@ def malfunction_from_file(filename: str, load_from_package=None) -> Tuple[Malfun
     -------
     generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
     """
-    # with open(filename, "rb") as file_in:
-    #     load_data = file_in.read()
 
-    # if filename.endswith("mpk"):
-    #     data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8')
-    # elif filename.endswith("pkl"):
-    #     data = pickle.loads(load_data)
+    print("DEPRECATED - use FileMalfunctionGen instead of malfunction_from_file")
+
     env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package)
     # TODO: make this better by using namedtuple in the pickle file. See issue 282
     if "malfunction" in env_dict:
@@ -111,6 +251,9 @@ def malfunction_from_params(parameters: MalfunctionParameters) -> Tuple[Malfunct
     -------
     generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
     """
+    
+    print("DEPRECATED - use ParamMalfunctionGen instead of malfunction_from_params")
+
     mean_malfunction_rate = parameters.malfunction_rate
     min_number_of_steps_broken = parameters.min_duration
     max_number_of_steps_broken = parameters.max_duration
@@ -142,128 +285,3 @@ def malfunction_from_params(parameters: MalfunctionParameters) -> Tuple[Malfunct
     return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
                                              max_number_of_steps_broken)
 
-
-class ParamMalfunctionGen(object):
-    def __init__(self, parameters: MalfunctionParameters):
-        self.mean_malfunction_rate = parameters.malfunction_rate
-        self.min_number_of_steps_broken = parameters.min_duration
-        self.max_number_of_steps_broken = parameters.max_duration
-
-    def generate(self, agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
-      
-        # Dummy reset function as we don't implement specific seeding here
-        if reset:
-            return Malfunction(0)
-
-        if agent.malfunction_data['malfunction'] < 1:
-            if np_random.rand() < _malfunction_prob(self.mean_malfunction_rate):
-                num_broken_steps = np_random.randint(self.min_number_of_steps_broken,
-                                                     self.max_number_of_steps_broken + 1) + 1
-                return Malfunction(num_broken_steps)
-        return Malfunction(0)
-
-    def get_process_data(self):
-        return MalfunctionProcessData(
-            self.mean_malfunction_rate, 
-            self.min_number_of_steps_broken,
-            self.max_number_of_steps_broken)
-
-
-class NoMalfunctionGen(ParamMalfunctionGen):
-    def __init__(self):
-        self.mean_malfunction_rate = 0.
-        self.min_number_of_steps_broken = 0
-        self.max_number_of_steps_broken = 0
-
-    def generate(self, agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
-        return Malfunction(0)
-
-    
-
-
-def no_malfunction_generator() -> Tuple[MalfunctionGenerator, MalfunctionProcessData]:
-    """
-    Malfunction generator which generates no malfunctions
-
-    Parameters
-    ----------
-    Nothing
-
-    Returns
-    -------
-    generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
-    """
-    # Mean malfunction in number of time steps
-    mean_malfunction_rate = 0.
-
-    # Uniform distribution parameters for malfunction duration
-    min_number_of_steps_broken = 0
-    max_number_of_steps_broken = 0
-
-    def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
-        return Malfunction(0)
-
-    return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
-                                             max_number_of_steps_broken)
-
-
-
-def single_malfunction_generator(earlierst_malfunction: int, malfunction_duration: int) -> Tuple[
-    MalfunctionGenerator, MalfunctionProcessData]:
-    """
-    Malfunction generator which guarantees exactly one malfunction during an episode of an ACTIVE agent.
-
-    Parameters
-    ----------
-    earlierst_malfunction: Earliest possible malfunction onset
-    malfunction_duration: The duration of the single malfunction
-
-    Returns
-    -------
-    generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
-    """
-    # Mean malfunction in number of time steps
-    mean_malfunction_rate = 0.
-
-    # Uniform distribution parameters for malfunction duration
-    min_number_of_steps_broken = 0
-    max_number_of_steps_broken = 0
-
-    # Keep track of the total number of malfunctions in the env
-    global_nr_malfunctions = 0
-
-    # Malfunction calls per agent
-    malfunction_calls = dict()
-
-    def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
-        # We use the global variable to assure only a single malfunction in the env
-        nonlocal global_nr_malfunctions
-        nonlocal malfunction_calls
-
-        # Reset malfunciton generator
-        if reset:
-            nonlocal global_nr_malfunctions
-            nonlocal malfunction_calls
-            global_nr_malfunctions = 0
-            malfunction_calls = dict()
-            return Malfunction(0)
-
-        # No more malfunctions if we already had one, ignore all updates
-        if global_nr_malfunctions > 0:
-            return Malfunction(0)
-
-        # Update number of calls per agent
-        if agent.handle in malfunction_calls:
-            malfunction_calls[agent.handle] += 1
-        else:
-            malfunction_calls[agent.handle] = 1
-
-        # Break an agent that is active at the time of the malfunction
-        if agent.status == RailAgentStatus.ACTIVE and malfunction_calls[agent.handle] >= earlierst_malfunction:
-            global_nr_malfunctions += 1
-            return Malfunction(malfunction_duration)
-        else:
-            return Malfunction(0)
-
-    return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
-                                             max_number_of_steps_broken)
diff --git a/flatland/envs/persistence.py b/flatland/envs/persistence.py
index 1b0f05f1..bc4b169b 100644
--- a/flatland/envs/persistence.py
+++ b/flatland/envs/persistence.py
@@ -124,8 +124,9 @@ class RailEnvPersister(object):
                     load_from_package=load_from_package),
                 schedule_generator=sched_gen.schedule_from_file(filename,
                     load_from_package=load_from_package),
-                malfunction_generator_and_process_data=mal_gen.malfunction_from_file(filename,
-                    load_from_package=load_from_package),
+                #malfunction_generator_and_process_data=mal_gen.malfunction_from_file(filename,
+                #    load_from_package=load_from_package),
+                malfunction_generator=mal_gen.FileMalfunctionGen(env_dict),
                 obs_builder_object=DummyObservationBuilder(),
                 record_steps=True)
 
-- 
GitLab