Skip to content
Snippets Groups Projects
Commit 407a6a39 authored by hagrid67's avatar hagrid67
Browse files

moved the agents_static legacy logic, and the tuple to EnvAgent conversion, into the load_env_dict

parent f151b4c7
No related branches found
No related tags found
No related merge requests found
......@@ -38,9 +38,11 @@ class RailEnvPersister(object):
env_dict = cls.get_full_state(env)
#print(f"env save - agents: {env_dict['agents'][0]}")
#a0 = env_dict["agents"][0]
#print("agent type:", type(a0))
# We have an unresolved problem with msgpack loading the list of agents
# see also 20 lines below.
# print(f"env save - agents: {env_dict['agents'][0]}")
# a0 = env_dict["agents"][0]
# print("agent type:", type(a0))
......@@ -66,23 +68,21 @@ class RailEnvPersister(object):
file_out.write(data)
#with open(filename, "rb") as file_in:
if filename.endswith("mpk"):
#bytes_in = file_in.read()
dIn = msgpack.unpackb(data, encoding="utf-8")
#print(f"msgpack check - {dIn.keys()}")
#print(f"msgpack check - {dIn['agents'][0]}")
# We have an unresovled problem with msgpack loading the list of Agents
# with open(filename, "rb") as file_in:
# if filename.endswith("mpk"):
# bytes_in = file_in.read()
# dIn = msgpack.unpackb(data, encoding="utf-8")
# print(f"msgpack check - {dIn.keys()}")
# print(f"msgpack check - {dIn['agents'][0]}")
@classmethod
def save_episode(cls, env, filename):
dict_env = cls.get_full_state(env)
lAgents = dict_env["agents"]
#print("Saving agents:", len(lAgents))
#print("Agent 0:", type(lAgents[0]), lAgents[0])
# Add additional info to dict_env before saving
dict_env["episode"] = env.cur_episode
dict_env["actions"] = env.list_actions
dict_env["shape"] = (env.width, env.height)
......@@ -94,7 +94,6 @@ class RailEnvPersister(object):
elif filename.endswith(".pkl"):
pickle.dump(dict_env, file_out)
@classmethod
def load(cls, env, filename, load_from_package=None):
"""
......@@ -127,7 +126,6 @@ class RailEnvPersister(object):
env.rail = GridTransitionMap(1,1) # dummy
cls.set_full_state(env, env_dict)
return env, env_dict
@classmethod
......@@ -148,9 +146,13 @@ class RailEnvPersister(object):
print(f"filename {filename} must end with either pkl or mpk")
env_dict = {}
if "agents" in env_dict:
# Replace the agents tuple with EnvAgent objects
if "agents_static" in env_dict:
env_dict["agents"] = EnvAgent.load_legacy_static_agent(env_dict["agents_static"])
# remove the legacy key
del env_dict["agents_static"]
elif "agents" in env_dict:
env_dict["agents"] = [EnvAgent(*d[0:12]) for d in env_dict["agents"]]
#print(f"env_dict agents: {env_dict['agents']}")
return env_dict
......@@ -181,22 +183,13 @@ class RailEnvPersister(object):
env_dict: dict
"""
env.rail.grid = np.array(env_dict["grid"])
# agents are always reset as not moving
if "agents_static" in env_dict:
# no idea if this still works
env.agents = EnvAgent.load_legacy_static_agent(env_dict["agents_static"])
else:
agents_data = env_dict["agents"]
if len(agents_data)>0:
if type(agents_data[0]) is EnvAgent:
env.agents = agents_data
else:
env.agents = [EnvAgent(*d[0:12]) for d in env_dict["agents"]]
#print(f"env agents: {env.agents}")
# Initialise the env with the frozen agents in the file
env.agents = env_dict.get("agents", [])
# For consistency, set number_of_agents, which is the number which will be generated on reset
env.number_of_agents = env.get_num_agents()
# setup with loaded data
env.height, env.width = env.rail.grid.shape
env.rail.height = env.height
env.rail.width = env.width
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment