From ff49ce2e7cb926c7dec649fb5f056ce76b218806 Mon Sep 17 00:00:00 2001
From: hagrid67 <jdhwatson@gmail.com>
Date: Tue, 2 Jun 2020 20:32:30 +0100
Subject: [PATCH] removing some prints from persistence

---
 flatland/envs/persistence.py | 44 ++++++++++++++++++++++++++++++++----
 1 file changed, 39 insertions(+), 5 deletions(-)

diff --git a/flatland/envs/persistence.py b/flatland/envs/persistence.py
index 078b6052..e77e97aa 100644
--- a/flatland/envs/persistence.py
+++ b/flatland/envs/persistence.py
@@ -38,6 +38,12 @@ class RailEnvPersister(object):
 
         env_dict = cls.get_full_state(env)
 
+        #print(f"env save - agents: {env_dict['agents'][0]}")
+        #a0 = env_dict["agents"][0]
+        #print("agent type:", type(a0))
+
+
+
         if save_distance_maps is True:
             oDistMap = env.distance_map.get() 
             if oDistMap is not None:
@@ -49,18 +55,33 @@ class RailEnvPersister(object):
                 print("[WARNING] Unable to save the distance map for this environment, as none was found !")
 
         with open(filename, "wb") as file_out:
+
             if filename.endswith("mpk"):
-                file_out.write(msgpack.packb(env_dict))
+                data = msgpack.packb(env_dict)
+                
+                
             elif filename.endswith("pkl"):
-                pickle.dump(env_dict, file_out)
+                data = pickle.dumps(env_dict)
+                #pickle.dump(env_dict, file_out)
+
+            file_out.write(data)
+
+        #with open(filename, "rb") as file_in:
+        if filename.endswith("mpk"):
+            #bytes_in = file_in.read()
+            dIn = msgpack.unpackb(data, encoding="utf-8")
+            #print(f"msgpack check - {dIn.keys()}")
+            #print(f"msgpack check - {dIn['agents'][0]}")
+
+                
 
     @classmethod
     def save_episode(cls, env, filename):
         dict_env = cls.get_full_state(env)
         
         lAgents = dict_env["agents"]
-        print("Saving agents:", len(lAgents))
-        print("Agent 0:", type(lAgents[0]), lAgents[0])
+        #print("Saving agents:", len(lAgents))
+        #print("Agent 0:", type(lAgents[0]), lAgents[0])
 
         dict_env["episode"] = env.cur_episode
         dict_env["actions"] = env.list_actions
@@ -123,6 +144,10 @@ class RailEnvPersister(object):
         else:
             print(f"filename {filename} must end with either pkl or mpk")
             env_dict = {}
+        
+        if "agents" in env_dict:
+            env_dict["agents"] = [EnvAgent(*d[0:12]) for d in env_dict["agents"]]
+            #print(f"env_dict agents: {env_dict['agents']}")
 
         return env_dict
 
@@ -159,7 +184,15 @@ class RailEnvPersister(object):
             # no idea if this still works
             env.agents = EnvAgent.load_legacy_static_agent(env_dict["agents_static"])
         else:
-            env.agents = [EnvAgent(*d[0:12]) for d in env_dict["agents"]]
+            agents_data = env_dict["agents"]
+            if len(agents_data)>0:
+                if type(agents_data[0]) is EnvAgent:
+                    env.agents = agents_data
+                else:
+                    env.agents = [EnvAgent(*d[0:12]) for d in env_dict["agents"]]
+
+                #print(f"env agents: {env.agents}")
+
         # setup with loaded data
         env.height, env.width = env.rail.grid.shape
         env.rail.height = env.height
@@ -176,6 +209,7 @@ class RailEnvPersister(object):
 
         # msgpack cannot persist EnvAgent so use the Agent namedtuple.
         agent_data = [agent.to_agent() for agent in env.agents]
+        #print("get_full_state - agent_data:", agent_data)
         malfunction_data: MalfunctionProcessData = env.malfunction_process_data
 
         msg_data_dict = {
-- 
GitLab