diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 996301a84f06f185ba9ed605ea1145f404c8b16e..4335109931188d40868b227ae1bac6b0a66aa36e 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -418,6 +418,34 @@ class RailEnv(Environment): self.rail.width = self.width self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) + def set_full_state_dist_msg(self, msg_data): + data = msgpack.unpackb(msg_data, use_list=False) + self.rail.grid = np.array(data[b"grid"]) + # agents are always reset as not moving + self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]] + self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]] + # setup with loaded data + self.height, self.width = self.rail.grid.shape + self.rail.height = self.height + self.rail.width = self.width + self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) + + def get_full_state_dist_msg(self): + grid_data = self.rail.grid.tolist() + agent_static_data = [agent.to_list() for agent in self.agents_static] + agent_data = [agent.to_list() for agent in self.agents] + + msgpack.packb(grid_data) + msgpack.packb(agent_data) + msgpack.packb(agent_static_data) + + msg_data = { + "grid": grid_data, + "agents_static": agent_static_data, + "agents": agent_data} + return msgpack.packb(msg_data, use_bin_type=True) + + def save(self, filename): with open(filename, "wb") as file_out: file_out.write(self.get_full_state_msg())