Skip to content
Snippets Groups Projects
Commit fce9451b authored by hagrid67's avatar hagrid67
Browse files

adding missing files, and fixed malfunction_generators

parent 48b99c0c
No related branches found
No related tags found
No related merge requests found
......@@ -2,11 +2,11 @@
from typing import Callable, NamedTuple, Optional, Tuple
import msgpack
import numpy as np
from numpy.random.mtrand import RandomState
from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
from flatland.envs import persistence
Malfunction = NamedTuple('Malfunction', [('num_broken_steps', int)])
MalfunctionParameters = NamedTuple('MalfunctionParameters',
......@@ -28,7 +28,7 @@ def _malfunction_prob(rate: float) -> float:
return 1 - np.exp(- (1 / rate))
def malfunction_from_file(filename: str) -> Tuple[MalfunctionGenerator, MalfunctionProcessData]:
def malfunction_from_file(filename: str, load_from_package=None) -> Tuple[MalfunctionGenerator, MalfunctionProcessData]:
"""
Utility to load pickle file
......@@ -40,18 +40,26 @@ def malfunction_from_file(filename: str) -> Tuple[MalfunctionGenerator, Malfunct
-------
generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
"""
with open(filename, "rb") as file_in:
load_data = file_in.read()
data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8')
# with open(filename, "rb") as file_in:
# load_data = file_in.read()
# if filename.endswith("mpk"):
# data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8')
# elif filename.endswith("pkl"):
# data = pickle.loads(load_data)
env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package)
# TODO: make this better by using namedtuple in the pickle file. See issue 282
data['malfunction'] = MalfunctionProcessData._make(data['malfunction'])
if "malfunction" in data:
if "malfunction" in env_dict:
env_dict['malfunction'] = oMPD = MalfunctionProcessData._make(env_dict['malfunction'])
else:
oMPD=None
if oMPD is not None:
# Mean malfunction in number of time steps
mean_malfunction_rate = data["malfunction"].malfunction_rate
mean_malfunction_rate = oMPD.malfunction_rate
# Uniform distribution parameters for malfunction duration
min_number_of_steps_broken = data["malfunction"].min_duration
max_number_of_steps_broken = data["malfunction"].max_duration
min_number_of_steps_broken = oMPD.min_duration
max_number_of_steps_broken = oMPD.max_duration
else:
# Mean malfunction in number of time steps
mean_malfunction_rate = 0.
......
import pickle
import msgpack
import numpy as np
from flatland.envs import rail_env
#from flatland.core.env import Environment
from flatland.core.env_observation_builder import DummyObservationBuilder
#from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
#from flatland.core.grid.grid4_utils import get_new_position
#from flatland.core.grid.grid_utils import IntVector2D
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import Agent, EnvAgent, RailAgentStatus
from flatland.envs.distance_map import DistanceMap
#from flatland.envs.observations import GlobalObsForRailEnv
# cannot import objects / classes directly because of circular import
from flatland.envs import malfunction_generators as mal_gen
from flatland.envs import rail_generators as rail_gen
from flatland.envs import schedule_generators as sched_gen
class RailEnvPersister(object):
@classmethod
def save(cls, env, filename, save_distance_maps=False):
"""
Saves environment and distance map information in a file
Parameters:
---------
filename: string
save_distance_maps: bool
"""
env_dict = cls.get_full_state(env)
if save_distance_maps is True:
oDistMap = env.distance_map.get()
if oDistMap is not None:
if len(oDistMap) > 0:
env_dict["distance_map"] = oDistMap
else:
print("[WARNING] Unable to save the distance map for this environment, as none was found !")
else:
print("[WARNING] Unable to save the distance map for this environment, as none was found !")
with open(filename, "wb") as file_out:
if filename.endswith("mpk"):
file_out.write(msgpack.packb(env_dict))
elif filename.endswith("pkl"):
pickle.dump(env_dict, file_out)
@classmethod
def save_episode(cls, env, filename):
dict_env = cls.get_full_state(env)
lAgents = dict_env["agents"]
print("Saving agents:", len(lAgents))
print("Agent 0:", type(lAgents[0]), lAgents[0])
dict_env["episode"] = env.cur_episode
dict_env["shape"] = (env.width, env.height)
with open(filename, "wb") as file_out:
if filename.endswith(".mpk"):
file_out.write(msgpack.packb(dict_env))
elif filename.endswith(".pkl"):
pickle.dump(dict_env, file_out)
@classmethod
def load(cls, env, filename, load_from_package=None):
"""
Load environment with distance map from a file
Parameters:
-------
filename: string
"""
env_dict = cls.load_env_dict(filename, load_from_package=load_from_package)
cls.set_full_state(env, env_dict)
@classmethod
def load_new(cls, filename, load_from_package=None):
env_dict = cls.load_env_dict(filename, load_from_package=load_from_package)
# TODO: inefficient - each one of these generators loads the complete env file.
env = rail_env.RailEnv(width=1, height=1,
rail_generator=rail_gen.rail_from_file(filename),
schedule_generator=sched_gen.schedule_from_file(filename),
malfunction_generator_and_process_data=mal_gen.malfunction_from_file(filename),
obs_builder_object=DummyObservationBuilder(),
record_steps=True)
env.rail = GridTransitionMap(1,1) # dummy
cls.set_full_state(env, env_dict)
return env, env_dict
@classmethod
def load_env_dict(cls, filename, load_from_package=None):
if load_from_package is not None:
from importlib_resources import read_binary
load_data = read_binary(load_from_package, filename)
else:
with open(filename, "rb") as file_in:
load_data = file_in.read()
if filename.endswith("mpk"):
env_dict = msgpack.unpackb(load_data, use_list=False, encoding="utf-8")
elif filename.endswith("pkl"):
env_dict = pickle.loads(load_data)
else:
print(f"filename {filename} must end with either pkl or mpk")
env_dict = {}
return env_dict
@classmethod
def load_resource(cls, package, resource):
"""
Load environment (with distance map?) from a binary
"""
from importlib_resources import read_binary
load_data = read_binary(package, resource)
if resource.endswith("pkl"):
env_dict = pickle.loads(load_data)
elif resource.endswith("mpk"):
env_dict = msgpack.unpackb(load_data, encoding="utf-8")
cls.set_full_state(env, env_dict)
@classmethod
def set_full_state(cls, env, env_dict):
"""
Sets environment state from env_dict
Parameters
-------
env_dict: dict
"""
env.rail.grid = np.array(env_dict["grid"])
# agents are always reset as not moving
if "agents_static" in env_dict:
# no idea if this still works
env.agents = EnvAgent.load_legacy_static_agent(env_dict["agents_static"])
else:
env.agents = [EnvAgent(*d[0:12]) for d in env_dict["agents"]]
# setup with loaded data
env.height, env.width = env.rail.grid.shape
env.rail.height = env.height
env.rail.width = env.width
env.dones = dict.fromkeys(list(range(env.get_num_agents())) + ["__all__"], False)
@classmethod
def get_full_state(cls, env):
"""
Returns state of environment in dict object, ready for serialization
"""
grid_data = env.rail.grid.tolist()
# msgpack cannot persist EnvAgent so use the Agent namedtuple.
agent_data = [agent.to_agent() for agent in env.agents]
malfunction_data: MalfunctionProcessData = env.malfunction_process_data
msg_data_dict = {
"grid": grid_data,
"agents": agent_data,
"malfunction": malfunction_data}
return msg_data_dict
################################################################################################
# deprecated methods moved from RailEnv. Most likely broken.
def deprecated_get_full_state_msg(self) -> msgpack.Packer:
"""
Returns state of environment in msgpack object
"""
msg_data_dict = self.get_full_state_dict()
return msgpack.packb(msg_data_dict, use_bin_type=True)
def deprecated_get_agent_state_msg(self) -> msgpack.Packer:
"""
Returns agents information in msgpack object
"""
agent_data = [agent.to_agent() for agent in self.agents]
msg_data = {
"agents": agent_data}
return msgpack.packb(msg_data, use_bin_type=True)
def deprecated_get_full_state_dist_msg(self) -> msgpack.Packer:
"""
Returns environment information with distance map information as msgpack object
"""
grid_data = self.rail.grid.tolist()
agent_data = [agent.to_agent() for agent in self.agents]
# I think these calls do nothing - they create packed data and it is discarded
#msgpack.packb(grid_data, use_bin_type=True)
#msgpack.packb(agent_data, use_bin_type=True)
distance_map_data = self.distance_map.get()
malfunction_data: MalfunctionProcessData = self.malfunction_process_data
#msgpack.packb(distance_map_data, use_bin_type=True) # does nothing
msg_data = {
"grid": grid_data,
"agents": agent_data,
"distance_map": distance_map_data,
"malfunction": malfunction_data}
return msgpack.packb(msg_data, use_bin_type=True)
def deprecated_set_full_state_msg(self, msg_data):
"""
Sets environment state with msgdata object passed as argument
Parameters
-------
msg_data: msgpack object
"""
data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
self.rail.grid = np.array(data["grid"])
# agents are always reset as not moving
if "agents_static" in data:
self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"])
else:
self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
# setup with loaded data
self.height, self.width = self.rail.grid.shape
self.rail.height = self.height
self.rail.width = self.width
self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
def deprecated_set_full_state_dist_msg(self, msg_data):
"""
Sets environment grid state and distance map with msgdata object passed as argument
Parameters
-------
msg_data: msgpack object
"""
data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
self.rail.grid = np.array(data["grid"])
# agents are always reset as not moving
if "agents_static" in data:
self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"])
else:
self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
if "distance_map" in data.keys():
self.distance_map.set(data["distance_map"])
# setup with loaded data
self.height, self.width = self.rail.grid.shape
self.rail.height = self.height
self.rail.width = self.width
self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
\ No newline at end of file
......@@ -9,7 +9,7 @@ recordtype>=1.3
matplotlib>=3.0.2
Pillow>=5.4.1
CairoSVG>=2.3.1
msgpack>=1.0.0
msgpack>=0.6.1
msgpack-numpy>=0.4.4.0
svgutils>=0.3.1
pyarrow>=0.13.0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment