Skip to content
Snippets Groups Projects
Commit ff49ce2e authored by hagrid67's avatar hagrid67
Browse files

removing some prints from persistence

parent efd8dbba
No related branches found
No related tags found
1 merge request!288merging manually from branch 223_UpdateEditor_55_Notebooks
Pipeline #4383 failed
......@@ -38,6 +38,12 @@ class RailEnvPersister(object):
env_dict = cls.get_full_state(env)
#print(f"env save - agents: {env_dict['agents'][0]}")
#a0 = env_dict["agents"][0]
#print("agent type:", type(a0))
if save_distance_maps is True:
oDistMap = env.distance_map.get()
if oDistMap is not None:
......@@ -49,18 +55,33 @@ class RailEnvPersister(object):
print("[WARNING] Unable to save the distance map for this environment, as none was found !")
with open(filename, "wb") as file_out:
if filename.endswith("mpk"):
file_out.write(msgpack.packb(env_dict))
data = msgpack.packb(env_dict)
elif filename.endswith("pkl"):
pickle.dump(env_dict, file_out)
data = pickle.dumps(env_dict)
#pickle.dump(env_dict, file_out)
file_out.write(data)
#with open(filename, "rb") as file_in:
if filename.endswith("mpk"):
#bytes_in = file_in.read()
dIn = msgpack.unpackb(data, encoding="utf-8")
#print(f"msgpack check - {dIn.keys()}")
#print(f"msgpack check - {dIn['agents'][0]}")
@classmethod
def save_episode(cls, env, filename):
dict_env = cls.get_full_state(env)
lAgents = dict_env["agents"]
print("Saving agents:", len(lAgents))
print("Agent 0:", type(lAgents[0]), lAgents[0])
#print("Saving agents:", len(lAgents))
#print("Agent 0:", type(lAgents[0]), lAgents[0])
dict_env["episode"] = env.cur_episode
dict_env["actions"] = env.list_actions
......@@ -123,6 +144,10 @@ class RailEnvPersister(object):
else:
print(f"filename {filename} must end with either pkl or mpk")
env_dict = {}
if "agents" in env_dict:
env_dict["agents"] = [EnvAgent(*d[0:12]) for d in env_dict["agents"]]
#print(f"env_dict agents: {env_dict['agents']}")
return env_dict
......@@ -159,7 +184,15 @@ class RailEnvPersister(object):
# no idea if this still works
env.agents = EnvAgent.load_legacy_static_agent(env_dict["agents_static"])
else:
env.agents = [EnvAgent(*d[0:12]) for d in env_dict["agents"]]
agents_data = env_dict["agents"]
if len(agents_data)>0:
if type(agents_data[0]) is EnvAgent:
env.agents = agents_data
else:
env.agents = [EnvAgent(*d[0:12]) for d in env_dict["agents"]]
#print(f"env agents: {env.agents}")
# setup with loaded data
env.height, env.width = env.rail.grid.shape
env.rail.height = env.height
......@@ -176,6 +209,7 @@ class RailEnvPersister(object):
# msgpack cannot persist EnvAgent so use the Agent namedtuple.
agent_data = [agent.to_agent() for agent in env.agents]
#print("get_full_state - agent_data:", agent_data)
malfunction_data: MalfunctionProcessData = env.malfunction_process_data
msg_data_dict = {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment