From ff49ce2e7cb926c7dec649fb5f056ce76b218806 Mon Sep 17 00:00:00 2001 From: hagrid67 <jdhwatson@gmail.com> Date: Tue, 2 Jun 2020 20:32:30 +0100 Subject: [PATCH] removing some prints from persistence --- flatland/envs/persistence.py | 44 ++++++++++++++++++++++++++++++++---- 1 file changed, 39 insertions(+), 5 deletions(-) diff --git a/flatland/envs/persistence.py b/flatland/envs/persistence.py index 078b6052..e77e97aa 100644 --- a/flatland/envs/persistence.py +++ b/flatland/envs/persistence.py @@ -38,6 +38,12 @@ class RailEnvPersister(object): env_dict = cls.get_full_state(env) + #print(f"env save - agents: {env_dict['agents'][0]}") + #a0 = env_dict["agents"][0] + #print("agent type:", type(a0)) + + + if save_distance_maps is True: oDistMap = env.distance_map.get() if oDistMap is not None: @@ -49,18 +55,33 @@ class RailEnvPersister(object): print("[WARNING] Unable to save the distance map for this environment, as none was found !") with open(filename, "wb") as file_out: + if filename.endswith("mpk"): - file_out.write(msgpack.packb(env_dict)) + data = msgpack.packb(env_dict) + + elif filename.endswith("pkl"): - pickle.dump(env_dict, file_out) + data = pickle.dumps(env_dict) + #pickle.dump(env_dict, file_out) + + file_out.write(data) + + #with open(filename, "rb") as file_in: + if filename.endswith("mpk"): + #bytes_in = file_in.read() + dIn = msgpack.unpackb(data, encoding="utf-8") + #print(f"msgpack check - {dIn.keys()}") + #print(f"msgpack check - {dIn['agents'][0]}") + + @classmethod def save_episode(cls, env, filename): dict_env = cls.get_full_state(env) lAgents = dict_env["agents"] - print("Saving agents:", len(lAgents)) - print("Agent 0:", type(lAgents[0]), lAgents[0]) + #print("Saving agents:", len(lAgents)) + #print("Agent 0:", type(lAgents[0]), lAgents[0]) dict_env["episode"] = env.cur_episode dict_env["actions"] = env.list_actions @@ -123,6 +144,10 @@ class RailEnvPersister(object): else: print(f"filename {filename} must end with either pkl or mpk") env_dict = {} + + if "agents" in env_dict: + env_dict["agents"] = [EnvAgent(*d[0:12]) for d in env_dict["agents"]] + #print(f"env_dict agents: {env_dict['agents']}") return env_dict @@ -159,7 +184,15 @@ class RailEnvPersister(object): # no idea if this still works env.agents = EnvAgent.load_legacy_static_agent(env_dict["agents_static"]) else: - env.agents = [EnvAgent(*d[0:12]) for d in env_dict["agents"]] + agents_data = env_dict["agents"] + if len(agents_data)>0: + if type(agents_data[0]) is EnvAgent: + env.agents = agents_data + else: + env.agents = [EnvAgent(*d[0:12]) for d in env_dict["agents"]] + + #print(f"env agents: {env.agents}") + # setup with loaded data env.height, env.width = env.rail.grid.shape env.rail.height = env.height @@ -176,6 +209,7 @@ class RailEnvPersister(object): # msgpack cannot persist EnvAgent so use the Agent namedtuple. agent_data = [agent.to_agent() for agent in env.agents] + #print("get_full_state - agent_data:", agent_data) malfunction_data: MalfunctionProcessData = env.malfunction_process_data msg_data_dict = { -- GitLab