Skip to content
Snippets Groups Projects
Commit ca044f83 authored by maljx's avatar maljx
Browse files

using msgpack to save/load state

parent 216d1ae6
No related branches found
No related tags found
No related merge requests found
......@@ -23,14 +23,21 @@ class EnvAgentStatic(object):
position = attrib()
direction = attrib()
target = attrib()
old_direction = attrib(default=None)
def __init__(self, position, direction, target):
self.position = position
self.direction = direction
self.target = target
@classmethod
def from_lists(cls, positions, directions, targets):
""" Create a list of EnvAgentStatics from lists of positions, directions and targets
"""
return list(starmap(EnvAgentStatic, zip(positions, directions, targets)))
def to_list(self):
return [self.position, self.direction, self.target]
@attrs
class EnvAgent(EnvAgentStatic):
......@@ -41,6 +48,15 @@ class EnvAgent(EnvAgentStatic):
forcing the env to refer to it in the EnvAgentStatic
"""
handle = attrib(default=None)
old_direction = attrib(default=None)
def __init__(self, position, direction, target, handle, old_direction):
super(EnvAgent, self).__init__(position, direction, target)
self.handle = handle
self.old_direction = old_direction
def to_list(self):
return [self.position, self.direction, self.target, self.handle, self.old_direction]
@classmethod
def from_static(cls, oStatic):
......
......@@ -5,7 +5,7 @@ Generator functions are functions that take width, height and num_resets as argu
a GridTransitionMap object.
"""
import numpy as np
import pickle
import msgpack
from flatland.core.env import Environment
from flatland.core.env_observation_builder import TreeObsForRailEnv
......@@ -324,20 +324,39 @@ class RailEnv(Environment):
# TODO:
pass
def save(self, sFilename):
dSave = {
"grid": self.rail.grid,
"agents_static": self.agents_static
def get_full_state_msg(self):
grid_data = self.rail.grid.tolist()
agent_static_data = [agent.to_list() for agent in self.agents_static]
agent_data = [agent.to_list() for agent in self.agents]
msg_data = {
"grid": grid_data,
"agents_static": agent_static_data,
"agents": agent_data
}
with open(sFilename, "wb") as fOut:
pickle.dump(dSave, fOut)
def load(self, sFilename):
with open(sFilename, "rb") as fIn:
dLoad = pickle.load(fIn)
self.rail.grid = dLoad["grid"]
self.height, self.width = self.rail.grid.shape
self.agents_static = dLoad["agents_static"]
self.agents = [None] * self.get_num_agents()
self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
return msgpack.packb(msg_data, use_bin_type=True)
def get_agent_state_msg(self):
agent_data = [agent.to_list() for agent in self.agents]
msg_data = {
"agents": agent_data
}
return msgpack.packb(msg_data, use_bin_type=True)
def set_full_state_msg(self, msg_data):
data = msgpack.unpackb(msg_data, use_list=False)
self.rail.grid = np.array(data[b"grid"])
self.agents_static = [EnvAgentStatic(d[0], d[1], d[2]) for d in data[b"agents_static"]]
self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]]
# setup with loaded data
self.height, self.width = self.rail.grid.shape
# self.agents = [None] * self.get_num_agents()
self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
def save(self, filename):
with open(filename, "wb") as file_out:
file_out.write(self.get_full_state_msg())
def load(self, filename):
with open(filename, "rb") as file_in:
load_data = file_in.read()
self.set_full_state_msg(load_data)
......@@ -18,3 +18,4 @@ matplotlib==3.0.2
PyQt5==5.12
Pillow==5.4.1
msgpack==0.6.1
......@@ -4,6 +4,7 @@ import numpy as np
from flatland.envs.rail_env import RailEnv
from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.envs.generators import complex_rail_generator
from flatland.core.transitions import Grid4Transitions
from flatland.core.transition_map import GridTransitionMap
from flatland.core.env_observation_builder import GlobalObsForRailEnv
......@@ -12,6 +13,30 @@ from flatland.envs.agent_utils import EnvAgent
"""Tests for `flatland` package."""
def test_save_load():
env = RailEnv(width=10, height=10,
rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=0),
number_of_agents=2)
env.reset()
agent_1_pos = env.agents_static[0].position
agent_1_dir = env.agents_static[0].direction
agent_1_tar = env.agents_static[0].target
agent_2_pos = env.agents_static[1].position
agent_2_dir = env.agents_static[1].direction
agent_2_tar = env.agents_static[1].target
env.save("test_save.dat")
env.load("test_save.dat")
assert(env.width == 10)
assert(env.height == 10)
assert(len(env.agents) == 2)
assert(agent_1_pos == env.agents_static[0].position)
assert(agent_1_dir == env.agents_static[0].direction)
assert(agent_1_tar == env.agents_static[0].target)
assert(agent_2_pos == env.agents_static[1].position)
assert(agent_2_dir == env.agents_static[1].direction)
assert(agent_2_tar == env.agents_static[1].target)
def test_rail_environment_single_agent():
cells = [int('0000000000000000', 2), # empty cell - Case 0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment