diff --git a/flatland/envs/malfunction_generators.py b/flatland/envs/malfunction_generators.py index 028c17a7be65a21a759767a45ad0eabde4f938d4..f8d1bc66b4f9a66c9657902aaa67ae42c9fd8c71 100644 --- a/flatland/envs/malfunction_generators.py +++ b/flatland/envs/malfunction_generators.py @@ -2,11 +2,11 @@ from typing import Callable, NamedTuple, Optional, Tuple -import msgpack import numpy as np from numpy.random.mtrand import RandomState from flatland.envs.agent_utils import EnvAgent, RailAgentStatus +from flatland.envs import persistence Malfunction = NamedTuple('Malfunction', [('num_broken_steps', int)]) MalfunctionParameters = NamedTuple('MalfunctionParameters', @@ -28,7 +28,7 @@ def _malfunction_prob(rate: float) -> float: return 1 - np.exp(- (1 / rate)) -def malfunction_from_file(filename: str) -> Tuple[MalfunctionGenerator, MalfunctionProcessData]: +def malfunction_from_file(filename: str, load_from_package=None) -> Tuple[MalfunctionGenerator, MalfunctionProcessData]: """ Utility to load pickle file @@ -40,18 +40,26 @@ def malfunction_from_file(filename: str) -> Tuple[MalfunctionGenerator, Malfunct ------- generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken """ - with open(filename, "rb") as file_in: - load_data = file_in.read() - data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8') + # with open(filename, "rb") as file_in: + # load_data = file_in.read() + + # if filename.endswith("mpk"): + # data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8') + # elif filename.endswith("pkl"): + # data = pickle.loads(load_data) + env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package) # TODO: make this better by using namedtuple in the pickle file. See issue 282 - data['malfunction'] = MalfunctionProcessData._make(data['malfunction']) - if "malfunction" in data: + if "malfunction" in env_dict: + env_dict['malfunction'] = oMPD = MalfunctionProcessData._make(env_dict['malfunction']) + else: + oMPD=None + if oMPD is not None: # Mean malfunction in number of time steps - mean_malfunction_rate = data["malfunction"].malfunction_rate + mean_malfunction_rate = oMPD.malfunction_rate # Uniform distribution parameters for malfunction duration - min_number_of_steps_broken = data["malfunction"].min_duration - max_number_of_steps_broken = data["malfunction"].max_duration + min_number_of_steps_broken = oMPD.min_duration + max_number_of_steps_broken = oMPD.max_duration else: # Mean malfunction in number of time steps mean_malfunction_rate = 0. diff --git a/flatland/envs/persistence.py b/flatland/envs/persistence.py new file mode 100644 index 0000000000000000000000000000000000000000..3689d20f2da2d6a053ff876731f63d0f6bc0a48b --- /dev/null +++ b/flatland/envs/persistence.py @@ -0,0 +1,266 @@ + + +import pickle +import msgpack +import numpy as np + +from flatland.envs import rail_env + +#from flatland.core.env import Environment +from flatland.core.env_observation_builder import DummyObservationBuilder +#from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions +#from flatland.core.grid.grid4_utils import get_new_position +#from flatland.core.grid.grid_utils import IntVector2D +from flatland.core.transition_map import GridTransitionMap +from flatland.envs.agent_utils import Agent, EnvAgent, RailAgentStatus +from flatland.envs.distance_map import DistanceMap + +#from flatland.envs.observations import GlobalObsForRailEnv + +# cannot import objects / classes directly because of circular import +from flatland.envs import malfunction_generators as mal_gen +from flatland.envs import rail_generators as rail_gen +from flatland.envs import schedule_generators as sched_gen + + +class RailEnvPersister(object): + + @classmethod + def save(cls, env, filename, save_distance_maps=False): + """ + Saves environment and distance map information in a file + + Parameters: + --------- + filename: string + save_distance_maps: bool + """ + + env_dict = cls.get_full_state(env) + + if save_distance_maps is True: + oDistMap = env.distance_map.get() + if oDistMap is not None: + if len(oDistMap) > 0: + env_dict["distance_map"] = oDistMap + else: + print("[WARNING] Unable to save the distance map for this environment, as none was found !") + else: + print("[WARNING] Unable to save the distance map for this environment, as none was found !") + + with open(filename, "wb") as file_out: + if filename.endswith("mpk"): + file_out.write(msgpack.packb(env_dict)) + elif filename.endswith("pkl"): + pickle.dump(env_dict, file_out) + + @classmethod + def save_episode(cls, env, filename): + dict_env = cls.get_full_state(env) + + lAgents = dict_env["agents"] + print("Saving agents:", len(lAgents)) + print("Agent 0:", type(lAgents[0]), lAgents[0]) + + dict_env["episode"] = env.cur_episode + dict_env["shape"] = (env.width, env.height) + + with open(filename, "wb") as file_out: + if filename.endswith(".mpk"): + file_out.write(msgpack.packb(dict_env)) + elif filename.endswith(".pkl"): + pickle.dump(dict_env, file_out) + + + @classmethod + def load(cls, env, filename, load_from_package=None): + """ + Load environment with distance map from a file + + Parameters: + ------- + filename: string + """ + env_dict = cls.load_env_dict(filename, load_from_package=load_from_package) + cls.set_full_state(env, env_dict) + + @classmethod + def load_new(cls, filename, load_from_package=None): + + env_dict = cls.load_env_dict(filename, load_from_package=load_from_package) + + + # TODO: inefficient - each one of these generators loads the complete env file. + env = rail_env.RailEnv(width=1, height=1, + rail_generator=rail_gen.rail_from_file(filename), + schedule_generator=sched_gen.schedule_from_file(filename), + malfunction_generator_and_process_data=mal_gen.malfunction_from_file(filename), + obs_builder_object=DummyObservationBuilder(), + record_steps=True) + + env.rail = GridTransitionMap(1,1) # dummy + + cls.set_full_state(env, env_dict) + + return env, env_dict + + @classmethod + def load_env_dict(cls, filename, load_from_package=None): + + if load_from_package is not None: + from importlib_resources import read_binary + load_data = read_binary(load_from_package, filename) + else: + with open(filename, "rb") as file_in: + load_data = file_in.read() + + if filename.endswith("mpk"): + env_dict = msgpack.unpackb(load_data, use_list=False, encoding="utf-8") + elif filename.endswith("pkl"): + env_dict = pickle.loads(load_data) + else: + print(f"filename {filename} must end with either pkl or mpk") + env_dict = {} + + return env_dict + + @classmethod + def load_resource(cls, package, resource): + """ + Load environment (with distance map?) from a binary + """ + from importlib_resources import read_binary + load_data = read_binary(package, resource) + + if resource.endswith("pkl"): + env_dict = pickle.loads(load_data) + elif resource.endswith("mpk"): + env_dict = msgpack.unpackb(load_data, encoding="utf-8") + + cls.set_full_state(env, env_dict) + + @classmethod + def set_full_state(cls, env, env_dict): + """ + Sets environment state from env_dict + + Parameters + ------- + env_dict: dict + """ + env.rail.grid = np.array(env_dict["grid"]) + + # agents are always reset as not moving + if "agents_static" in env_dict: + # no idea if this still works + env.agents = EnvAgent.load_legacy_static_agent(env_dict["agents_static"]) + else: + env.agents = [EnvAgent(*d[0:12]) for d in env_dict["agents"]] + # setup with loaded data + env.height, env.width = env.rail.grid.shape + env.rail.height = env.height + env.rail.width = env.width + env.dones = dict.fromkeys(list(range(env.get_num_agents())) + ["__all__"], False) + + @classmethod + def get_full_state(cls, env): + """ + Returns state of environment in dict object, ready for serialization + + """ + grid_data = env.rail.grid.tolist() + + # msgpack cannot persist EnvAgent so use the Agent namedtuple. + agent_data = [agent.to_agent() for agent in env.agents] + malfunction_data: MalfunctionProcessData = env.malfunction_process_data + + msg_data_dict = { + "grid": grid_data, + "agents": agent_data, + "malfunction": malfunction_data} + return msg_data_dict + + +################################################################################################ +# deprecated methods moved from RailEnv. Most likely broken. + + def deprecated_get_full_state_msg(self) -> msgpack.Packer: + """ + Returns state of environment in msgpack object + """ + msg_data_dict = self.get_full_state_dict() + return msgpack.packb(msg_data_dict, use_bin_type=True) + + def deprecated_get_agent_state_msg(self) -> msgpack.Packer: + """ + Returns agents information in msgpack object + """ + agent_data = [agent.to_agent() for agent in self.agents] + msg_data = { + "agents": agent_data} + return msgpack.packb(msg_data, use_bin_type=True) + + def deprecated_get_full_state_dist_msg(self) -> msgpack.Packer: + """ + Returns environment information with distance map information as msgpack object + """ + grid_data = self.rail.grid.tolist() + agent_data = [agent.to_agent() for agent in self.agents] + + # I think these calls do nothing - they create packed data and it is discarded + #msgpack.packb(grid_data, use_bin_type=True) + #msgpack.packb(agent_data, use_bin_type=True) + + distance_map_data = self.distance_map.get() + malfunction_data: MalfunctionProcessData = self.malfunction_process_data + #msgpack.packb(distance_map_data, use_bin_type=True) # does nothing + msg_data = { + "grid": grid_data, + "agents": agent_data, + "distance_map": distance_map_data, + "malfunction": malfunction_data} + return msgpack.packb(msg_data, use_bin_type=True) + + def deprecated_set_full_state_msg(self, msg_data): + """ + Sets environment state with msgdata object passed as argument + + Parameters + ------- + msg_data: msgpack object + """ + data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8') + self.rail.grid = np.array(data["grid"]) + # agents are always reset as not moving + if "agents_static" in data: + self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"]) + else: + self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]] + # setup with loaded data + self.height, self.width = self.rail.grid.shape + self.rail.height = self.height + self.rail.width = self.width + self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) + + def deprecated_set_full_state_dist_msg(self, msg_data): + """ + Sets environment grid state and distance map with msgdata object passed as argument + + Parameters + ------- + msg_data: msgpack object + """ + data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8') + self.rail.grid = np.array(data["grid"]) + # agents are always reset as not moving + if "agents_static" in data: + self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"]) + else: + self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]] + if "distance_map" in data.keys(): + self.distance_map.set(data["distance_map"]) + # setup with loaded data + self.height, self.width = self.rail.grid.shape + self.rail.height = self.height + self.rail.width = self.width + self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) \ No newline at end of file diff --git a/requirements_dev.txt b/requirements_dev.txt index e971cdd466203b5253e9700001ba0733dbf419c0..2103afa450280f65e8c1cc153c4edd60a81245ee 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -9,7 +9,7 @@ recordtype>=1.3 matplotlib>=3.0.2 Pillow>=5.4.1 CairoSVG>=2.3.1 -msgpack>=1.0.0 +msgpack>=0.6.1 msgpack-numpy>=0.4.4.0 svgutils>=0.3.1 pyarrow>=0.13.0