Skip to content
Snippets Groups Projects
Commit a0791a3f authored by Erik Nygren's avatar Erik Nygren
Browse files

initial commit

parent 2c17b423
No related branches found
No related tags found
No related merge requests found
...@@ -418,6 +418,34 @@ class RailEnv(Environment): ...@@ -418,6 +418,34 @@ class RailEnv(Environment):
self.rail.width = self.width self.rail.width = self.width
self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
def set_full_state_dist_msg(self, msg_data):
data = msgpack.unpackb(msg_data, use_list=False)
self.rail.grid = np.array(data[b"grid"])
# agents are always reset as not moving
self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]]
# setup with loaded data
self.height, self.width = self.rail.grid.shape
self.rail.height = self.height
self.rail.width = self.width
self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
def get_full_state_dist_msg(self):
grid_data = self.rail.grid.tolist()
agent_static_data = [agent.to_list() for agent in self.agents_static]
agent_data = [agent.to_list() for agent in self.agents]
msgpack.packb(grid_data)
msgpack.packb(agent_data)
msgpack.packb(agent_static_data)
msg_data = {
"grid": grid_data,
"agents_static": agent_static_data,
"agents": agent_data}
return msgpack.packb(msg_data, use_bin_type=True)
def save(self, filename): def save(self, filename):
with open(filename, "wb") as file_out: with open(filename, "wb") as file_out:
file_out.write(self.get_full_state_msg()) file_out.write(self.get_full_state_msg())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment