diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index 49209460dd191f1628aabac49d3ce5b72df7e817..0d42ab776e84e59a41f8b3ddb82dc236e1ae7525 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -23,14 +23,21 @@ class EnvAgentStatic(object): position = attrib() direction = attrib() target = attrib() - old_direction = attrib(default=None) + + def __init__(self, position, direction, target): + self.position = position + self.direction = direction + self.target = target @classmethod def from_lists(cls, positions, directions, targets): """ Create a list of EnvAgentStatics from lists of positions, directions and targets """ return list(starmap(EnvAgentStatic, zip(positions, directions, targets))) - + + def to_list(self): + return [self.position, self.direction, self.target] + @attrs class EnvAgent(EnvAgentStatic): @@ -41,6 +48,15 @@ class EnvAgent(EnvAgentStatic): forcing the env to refer to it in the EnvAgentStatic """ handle = attrib(default=None) + old_direction = attrib(default=None) + + def __init__(self, position, direction, target, handle, old_direction): + super(EnvAgent, self).__init__(position, direction, target) + self.handle = handle + self.old_direction = old_direction + + def to_list(self): + return [self.position, self.direction, self.target, self.handle, self.old_direction] @classmethod def from_static(cls, oStatic): diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index adb40f03b2ee100d8e12cfead660f9600dd9968b..228cc32537b3a464afcaf9f0555c444dd38339bc 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -5,7 +5,7 @@ Generator functions are functions that take width, height and num_resets as argu a GridTransitionMap object. """ import numpy as np -import pickle +import msgpack from flatland.core.env import Environment from flatland.core.env_observation_builder import TreeObsForRailEnv @@ -324,20 +324,39 @@ class RailEnv(Environment): # TODO: pass - def save(self, sFilename): - dSave = { - "grid": self.rail.grid, - "agents_static": self.agents_static + def get_full_state_msg(self): + grid_data = self.rail.grid.tolist() + agent_static_data = [agent.to_list() for agent in self.agents_static] + agent_data = [agent.to_list() for agent in self.agents] + msg_data = { + "grid": grid_data, + "agents_static": agent_static_data, + "agents": agent_data } - with open(sFilename, "wb") as fOut: - pickle.dump(dSave, fOut) - - def load(self, sFilename): - with open(sFilename, "rb") as fIn: - dLoad = pickle.load(fIn) - self.rail.grid = dLoad["grid"] - self.height, self.width = self.rail.grid.shape - self.agents_static = dLoad["agents_static"] - self.agents = [None] * self.get_num_agents() - self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) - + return msgpack.packb(msg_data, use_bin_type=True) + + def get_agent_state_msg(self): + agent_data = [agent.to_list() for agent in self.agents] + msg_data = { + "agents": agent_data + } + return msgpack.packb(msg_data, use_bin_type=True) + + def set_full_state_msg(self, msg_data): + data = msgpack.unpackb(msg_data, use_list=False) + self.rail.grid = np.array(data[b"grid"]) + self.agents_static = [EnvAgentStatic(d[0], d[1], d[2]) for d in data[b"agents_static"]] + self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]] + # setup with loaded data + self.height, self.width = self.rail.grid.shape + # self.agents = [None] * self.get_num_agents() + self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) + + def save(self, filename): + with open(filename, "wb") as file_out: + file_out.write(self.get_full_state_msg()) + + def load(self, filename): + with open(filename, "rb") as file_in: + load_data = file_in.read() + self.set_full_state_msg(load_data) diff --git a/requirements_dev.txt b/requirements_dev.txt index 40a6b7f683a9b268f210b48057d0186b4e89b92b..bcf770c9476f6333db5f1bcf72814671dcb4ad30 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -18,3 +18,4 @@ matplotlib==3.0.2 PyQt5==5.12 Pillow==5.4.1 +msgpack==0.6.1 diff --git a/tests/test_environments.py b/tests/test_environments.py index f12dfa3d6b57f76ce490c2c748fefc008ba371a5..dae7c130720a8ef8e6e3bf7681d741b2f26b2b02 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -4,6 +4,7 @@ import numpy as np from flatland.envs.rail_env import RailEnv from flatland.envs.generators import rail_from_GridTransitionMap_generator +from flatland.envs.generators import complex_rail_generator from flatland.core.transitions import Grid4Transitions from flatland.core.transition_map import GridTransitionMap from flatland.core.env_observation_builder import GlobalObsForRailEnv @@ -12,6 +13,30 @@ from flatland.envs.agent_utils import EnvAgent """Tests for `flatland` package.""" +def test_save_load(): + env = RailEnv(width=10, height=10, + rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=0), + number_of_agents=2) + env.reset() + agent_1_pos = env.agents_static[0].position + agent_1_dir = env.agents_static[0].direction + agent_1_tar = env.agents_static[0].target + agent_2_pos = env.agents_static[1].position + agent_2_dir = env.agents_static[1].direction + agent_2_tar = env.agents_static[1].target + env.save("test_save.dat") + env.load("test_save.dat") + assert(env.width == 10) + assert(env.height == 10) + assert(len(env.agents) == 2) + assert(agent_1_pos == env.agents_static[0].position) + assert(agent_1_dir == env.agents_static[0].direction) + assert(agent_1_tar == env.agents_static[0].target) + assert(agent_2_pos == env.agents_static[1].position) + assert(agent_2_dir == env.agents_static[1].direction) + assert(agent_2_tar == env.agents_static[1].target) + + def test_rail_environment_single_agent(): cells = [int('0000000000000000', 2), # empty cell - Case 0