Skip to content
Snippets Groups Projects
Commit ff49ce2e authored by hagrid67's avatar hagrid67
Browse files

removing some prints from persistence

parent efd8dbba
No related branches found
No related tags found
No related merge requests found
......@@ -38,6 +38,12 @@ class RailEnvPersister(object):
env_dict = cls.get_full_state(env)
#print(f"env save - agents: {env_dict['agents'][0]}")
#a0 = env_dict["agents"][0]
#print("agent type:", type(a0))
if save_distance_maps is True:
oDistMap = env.distance_map.get()
if oDistMap is not None:
......@@ -49,18 +55,33 @@ class RailEnvPersister(object):
print("[WARNING] Unable to save the distance map for this environment, as none was found !")
with open(filename, "wb") as file_out:
if filename.endswith("mpk"):
file_out.write(msgpack.packb(env_dict))
data = msgpack.packb(env_dict)
elif filename.endswith("pkl"):
pickle.dump(env_dict, file_out)
data = pickle.dumps(env_dict)
#pickle.dump(env_dict, file_out)
file_out.write(data)
#with open(filename, "rb") as file_in:
if filename.endswith("mpk"):
#bytes_in = file_in.read()
dIn = msgpack.unpackb(data, encoding="utf-8")
#print(f"msgpack check - {dIn.keys()}")
#print(f"msgpack check - {dIn['agents'][0]}")
@classmethod
def save_episode(cls, env, filename):
dict_env = cls.get_full_state(env)
lAgents = dict_env["agents"]
print("Saving agents:", len(lAgents))
print("Agent 0:", type(lAgents[0]), lAgents[0])
#print("Saving agents:", len(lAgents))
#print("Agent 0:", type(lAgents[0]), lAgents[0])
dict_env["episode"] = env.cur_episode
dict_env["actions"] = env.list_actions
......@@ -123,6 +144,10 @@ class RailEnvPersister(object):
else:
print(f"filename {filename} must end with either pkl or mpk")
env_dict = {}
if "agents" in env_dict:
env_dict["agents"] = [EnvAgent(*d[0:12]) for d in env_dict["agents"]]
#print(f"env_dict agents: {env_dict['agents']}")
return env_dict
......@@ -159,7 +184,15 @@ class RailEnvPersister(object):
# no idea if this still works
env.agents = EnvAgent.load_legacy_static_agent(env_dict["agents_static"])
else:
env.agents = [EnvAgent(*d[0:12]) for d in env_dict["agents"]]
agents_data = env_dict["agents"]
if len(agents_data)>0:
if type(agents_data[0]) is EnvAgent:
env.agents = agents_data
else:
env.agents = [EnvAgent(*d[0:12]) for d in env_dict["agents"]]
#print(f"env agents: {env.agents}")
# setup with loaded data
env.height, env.width = env.rail.grid.shape
env.rail.height = env.height
......@@ -176,6 +209,7 @@ class RailEnvPersister(object):
# msgpack cannot persist EnvAgent so use the Agent namedtuple.
agent_data = [agent.to_agent() for agent in env.agents]
#print("get_full_state - agent_data:", agent_data)
malfunction_data: MalfunctionProcessData = env.malfunction_process_data
msg_data_dict = {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment