diff --git a/flatland/envs/malfunction_generators.py b/flatland/envs/malfunction_generators.py
index 028c17a7be65a21a759767a45ad0eabde4f938d4..f8d1bc66b4f9a66c9657902aaa67ae42c9fd8c71 100644
--- a/flatland/envs/malfunction_generators.py
+++ b/flatland/envs/malfunction_generators.py
@@ -2,11 +2,11 @@
 
 from typing import Callable, NamedTuple, Optional, Tuple
 
-import msgpack
 import numpy as np
 from numpy.random.mtrand import RandomState
 
 from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
+from flatland.envs import persistence 
 
 Malfunction = NamedTuple('Malfunction', [('num_broken_steps', int)])
 MalfunctionParameters = NamedTuple('MalfunctionParameters',
@@ -28,7 +28,7 @@ def _malfunction_prob(rate: float) -> float:
         return 1 - np.exp(- (1 / rate))
 
 
-def malfunction_from_file(filename: str) -> Tuple[MalfunctionGenerator, MalfunctionProcessData]:
+def malfunction_from_file(filename: str, load_from_package=None) -> Tuple[MalfunctionGenerator, MalfunctionProcessData]:
     """
     Utility to load pickle file
 
@@ -40,18 +40,26 @@ def malfunction_from_file(filename: str) -> Tuple[MalfunctionGenerator, Malfunct
     -------
     generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
     """
-    with open(filename, "rb") as file_in:
-        load_data = file_in.read()
-    data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8')
+    # with open(filename, "rb") as file_in:
+    #     load_data = file_in.read()
+    
+    # if filename.endswith("mpk"):
+    #     data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8')
+    # elif filename.endswith("pkl"):
+    #     data = pickle.loads(load_data)
+    env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package)
     # TODO: make this better by using namedtuple in the pickle file. See issue 282
-    data['malfunction'] = MalfunctionProcessData._make(data['malfunction'])
-    if "malfunction" in data:
+    if "malfunction" in env_dict:
+        env_dict['malfunction'] = oMPD = MalfunctionProcessData._make(env_dict['malfunction'])
+    else:
+        oMPD=None
+    if oMPD is not None:
         # Mean malfunction in number of time steps
-        mean_malfunction_rate = data["malfunction"].malfunction_rate
+        mean_malfunction_rate = oMPD.malfunction_rate
 
         # Uniform distribution parameters for malfunction duration
-        min_number_of_steps_broken = data["malfunction"].min_duration
-        max_number_of_steps_broken = data["malfunction"].max_duration
+        min_number_of_steps_broken = oMPD.min_duration
+        max_number_of_steps_broken = oMPD.max_duration
     else:
         # Mean malfunction in number of time steps
         mean_malfunction_rate = 0.
diff --git a/flatland/envs/persistence.py b/flatland/envs/persistence.py
new file mode 100644
index 0000000000000000000000000000000000000000..3689d20f2da2d6a053ff876731f63d0f6bc0a48b
--- /dev/null
+++ b/flatland/envs/persistence.py
@@ -0,0 +1,266 @@
+
+
+import pickle
+import msgpack
+import numpy as np
+
+from flatland.envs import rail_env 
+
+#from flatland.core.env import Environment
+from flatland.core.env_observation_builder import DummyObservationBuilder
+#from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
+#from flatland.core.grid.grid4_utils import get_new_position
+#from flatland.core.grid.grid_utils import IntVector2D
+from flatland.core.transition_map import GridTransitionMap
+from flatland.envs.agent_utils import Agent, EnvAgent, RailAgentStatus
+from flatland.envs.distance_map import DistanceMap
+
+#from flatland.envs.observations import GlobalObsForRailEnv
+
+# cannot import objects / classes directly because of circular import
+from flatland.envs import malfunction_generators as mal_gen
+from flatland.envs import rail_generators as rail_gen
+from flatland.envs import schedule_generators as sched_gen
+
+
+class RailEnvPersister(object):
+
+    @classmethod
+    def save(cls, env, filename, save_distance_maps=False):
+        """
+        Saves environment and distance map information in a file
+
+        Parameters:
+        ---------
+        filename: string
+        save_distance_maps: bool
+        """
+
+        env_dict = cls.get_full_state(env)
+
+        if save_distance_maps is True:
+            oDistMap = env.distance_map.get() 
+            if oDistMap is not None:
+                if len(oDistMap) > 0:
+                    env_dict["distance_map"] = oDistMap
+                else:
+                    print("[WARNING] Unable to save the distance map for this environment, as none was found !")
+            else:
+                print("[WARNING] Unable to save the distance map for this environment, as none was found !")
+
+        with open(filename, "wb") as file_out:
+            if filename.endswith("mpk"):
+                file_out.write(msgpack.packb(env_dict))
+            elif filename.endswith("pkl"):
+                pickle.dump(env_dict, file_out)
+
+    @classmethod
+    def save_episode(cls, env, filename):
+        dict_env = cls.get_full_state(env)
+        
+        lAgents = dict_env["agents"]
+        print("Saving agents:", len(lAgents))
+        print("Agent 0:", type(lAgents[0]), lAgents[0])
+
+        dict_env["episode"] = env.cur_episode
+        dict_env["shape"] = (env.width, env.height)
+
+        with open(filename, "wb") as file_out:
+            if filename.endswith(".mpk"):
+                file_out.write(msgpack.packb(dict_env))
+            elif filename.endswith(".pkl"):
+                pickle.dump(dict_env, file_out)
+
+
+    @classmethod
+    def load(cls, env, filename, load_from_package=None):
+        """
+        Load environment with distance map from a file
+
+        Parameters:
+        -------
+        filename: string
+        """
+        env_dict = cls.load_env_dict(filename, load_from_package=load_from_package)
+        cls.set_full_state(env, env_dict)
+
+    @classmethod
+    def load_new(cls, filename, load_from_package=None):
+
+        env_dict = cls.load_env_dict(filename, load_from_package=load_from_package)
+
+
+        # TODO: inefficient - each one of these generators loads the complete env file.
+        env = rail_env.RailEnv(width=1, height=1,
+                rail_generator=rail_gen.rail_from_file(filename),
+                schedule_generator=sched_gen.schedule_from_file(filename),
+                malfunction_generator_and_process_data=mal_gen.malfunction_from_file(filename),
+                obs_builder_object=DummyObservationBuilder(),
+                record_steps=True)
+
+        env.rail = GridTransitionMap(1,1) # dummy        
+
+        cls.set_full_state(env, env_dict)
+
+        return env, env_dict
+
+    @classmethod
+    def load_env_dict(cls, filename, load_from_package=None):
+
+        if load_from_package is not None:
+            from importlib_resources import read_binary
+            load_data = read_binary(load_from_package, filename)
+        else:
+            with open(filename, "rb") as file_in:
+                load_data = file_in.read()
+
+        if filename.endswith("mpk"):
+            env_dict = msgpack.unpackb(load_data, use_list=False, encoding="utf-8")
+        elif filename.endswith("pkl"):
+            env_dict = pickle.loads(load_data)
+        else:
+            print(f"filename {filename} must end with either pkl or mpk")
+            env_dict = {}
+
+        return env_dict
+
+    @classmethod
+    def load_resource(cls, package, resource):
+        """
+        Load environment (with distance map?) from a binary
+        """
+        from importlib_resources import read_binary
+        load_data = read_binary(package, resource)
+
+        if resource.endswith("pkl"):
+            env_dict = pickle.loads(load_data)
+        elif resource.endswith("mpk"):
+            env_dict = msgpack.unpackb(load_data, encoding="utf-8")
+        
+        cls.set_full_state(env, env_dict)
+
+    @classmethod
+    def set_full_state(cls, env, env_dict):
+        """
+        Sets environment state from env_dict 
+
+        Parameters
+        -------
+        env_dict: dict
+        """
+        env.rail.grid = np.array(env_dict["grid"])
+        
+        # agents are always reset as not moving
+        if "agents_static" in env_dict:
+            # no idea if this still works
+            env.agents = EnvAgent.load_legacy_static_agent(env_dict["agents_static"])
+        else:
+            env.agents = [EnvAgent(*d[0:12]) for d in env_dict["agents"]]
+        # setup with loaded data
+        env.height, env.width = env.rail.grid.shape
+        env.rail.height = env.height
+        env.rail.width = env.width
+        env.dones = dict.fromkeys(list(range(env.get_num_agents())) + ["__all__"], False)
+
+    @classmethod
+    def get_full_state(cls, env):
+        """
+        Returns state of environment in dict object, ready for serialization
+
+        """
+        grid_data = env.rail.grid.tolist()
+
+        # msgpack cannot persist EnvAgent so use the Agent namedtuple.
+        agent_data = [agent.to_agent() for agent in env.agents]
+        malfunction_data: MalfunctionProcessData = env.malfunction_process_data
+
+        msg_data_dict = {
+            "grid": grid_data,
+            "agents": agent_data,
+            "malfunction": malfunction_data}
+        return msg_data_dict
+
+
+################################################################################################
+# deprecated methods moved from RailEnv.  Most likely broken.
+
+    def deprecated_get_full_state_msg(self) -> msgpack.Packer:
+        """
+        Returns state of environment in msgpack object
+        """
+        msg_data_dict = self.get_full_state_dict()
+        return msgpack.packb(msg_data_dict, use_bin_type=True)
+
+    def deprecated_get_agent_state_msg(self) -> msgpack.Packer:
+        """
+        Returns agents information in msgpack object
+        """
+        agent_data = [agent.to_agent() for agent in self.agents]
+        msg_data = {
+            "agents": agent_data}
+        return msgpack.packb(msg_data, use_bin_type=True)
+
+    def deprecated_get_full_state_dist_msg(self) -> msgpack.Packer:
+        """
+        Returns environment information with distance map information as msgpack object
+        """
+        grid_data = self.rail.grid.tolist()
+        agent_data = [agent.to_agent() for agent in self.agents]
+
+        # I think these calls do nothing - they create packed data and it is discarded
+        #msgpack.packb(grid_data, use_bin_type=True)
+        #msgpack.packb(agent_data, use_bin_type=True)
+
+        distance_map_data = self.distance_map.get()
+        malfunction_data: MalfunctionProcessData = self.malfunction_process_data
+        #msgpack.packb(distance_map_data, use_bin_type=True)  # does nothing
+        msg_data = {
+            "grid": grid_data,
+            "agents": agent_data,
+            "distance_map": distance_map_data,
+            "malfunction": malfunction_data}
+        return msgpack.packb(msg_data, use_bin_type=True)
+
+    def deprecated_set_full_state_msg(self, msg_data):
+        """
+        Sets environment state with msgdata object passed as argument
+
+        Parameters
+        -------
+        msg_data: msgpack object
+        """
+        data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
+        self.rail.grid = np.array(data["grid"])
+        # agents are always reset as not moving
+        if "agents_static" in data:
+            self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"])
+        else:
+            self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
+        # setup with loaded data
+        self.height, self.width = self.rail.grid.shape
+        self.rail.height = self.height
+        self.rail.width = self.width
+        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
+
+    def deprecated_set_full_state_dist_msg(self, msg_data):
+        """
+        Sets environment grid state and distance map with msgdata object passed as argument
+
+        Parameters
+        -------
+        msg_data: msgpack object
+        """
+        data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
+        self.rail.grid = np.array(data["grid"])
+        # agents are always reset as not moving
+        if "agents_static" in data:
+            self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"])
+        else:
+            self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
+        if "distance_map" in data.keys():
+            self.distance_map.set(data["distance_map"])
+        # setup with loaded data
+        self.height, self.width = self.rail.grid.shape
+        self.rail.height = self.height
+        self.rail.width = self.width
+        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
\ No newline at end of file
diff --git a/requirements_dev.txt b/requirements_dev.txt
index e971cdd466203b5253e9700001ba0733dbf419c0..2103afa450280f65e8c1cc153c4edd60a81245ee 100644
--- a/requirements_dev.txt
+++ b/requirements_dev.txt
@@ -9,7 +9,7 @@ recordtype>=1.3
 matplotlib>=3.0.2
 Pillow>=5.4.1
 CairoSVG>=2.3.1
-msgpack>=1.0.0
+msgpack>=0.6.1
 msgpack-numpy>=0.4.4.0
 svgutils>=0.3.1
 pyarrow>=0.13.0