diff --git a/flatland/envs/persistence.py b/flatland/envs/persistence.py index 31339183f851155d446238bb50a051cba2e9ec36..3952b7944449cb37a5135a54794f81d9a2810104 100644 --- a/flatland/envs/persistence.py +++ b/flatland/envs/persistence.py @@ -38,9 +38,11 @@ class RailEnvPersister(object): env_dict = cls.get_full_state(env) - #print(f"env save - agents: {env_dict['agents'][0]}") - #a0 = env_dict["agents"][0] - #print("agent type:", type(a0)) + # We have an unresolved problem with msgpack loading the list of agents + # see also 20 lines below. + # print(f"env save - agents: {env_dict['agents'][0]}") + # a0 = env_dict["agents"][0] + # print("agent type:", type(a0)) @@ -66,23 +68,21 @@ class RailEnvPersister(object): file_out.write(data) - #with open(filename, "rb") as file_in: - if filename.endswith("mpk"): - #bytes_in = file_in.read() - dIn = msgpack.unpackb(data, encoding="utf-8") - #print(f"msgpack check - {dIn.keys()}") - #print(f"msgpack check - {dIn['agents'][0]}") + # We have an unresovled problem with msgpack loading the list of Agents + # with open(filename, "rb") as file_in: + # if filename.endswith("mpk"): + # bytes_in = file_in.read() + # dIn = msgpack.unpackb(data, encoding="utf-8") + # print(f"msgpack check - {dIn.keys()}") + # print(f"msgpack check - {dIn['agents'][0]}") @classmethod def save_episode(cls, env, filename): dict_env = cls.get_full_state(env) - - lAgents = dict_env["agents"] - #print("Saving agents:", len(lAgents)) - #print("Agent 0:", type(lAgents[0]), lAgents[0]) + # Add additional info to dict_env before saving dict_env["episode"] = env.cur_episode dict_env["actions"] = env.list_actions dict_env["shape"] = (env.width, env.height) @@ -94,7 +94,6 @@ class RailEnvPersister(object): elif filename.endswith(".pkl"): pickle.dump(dict_env, file_out) - @classmethod def load(cls, env, filename, load_from_package=None): """ @@ -127,7 +126,6 @@ class RailEnvPersister(object): env.rail = GridTransitionMap(1,1) # dummy cls.set_full_state(env, env_dict) - return env, env_dict @classmethod @@ -148,9 +146,13 @@ class RailEnvPersister(object): print(f"filename {filename} must end with either pkl or mpk") env_dict = {} - if "agents" in env_dict: + # Replace the agents tuple with EnvAgent objects + if "agents_static" in env_dict: + env_dict["agents"] = EnvAgent.load_legacy_static_agent(env_dict["agents_static"]) + # remove the legacy key + del env_dict["agents_static"] + elif "agents" in env_dict: env_dict["agents"] = [EnvAgent(*d[0:12]) for d in env_dict["agents"]] - #print(f"env_dict agents: {env_dict['agents']}") return env_dict @@ -181,22 +183,13 @@ class RailEnvPersister(object): env_dict: dict """ env.rail.grid = np.array(env_dict["grid"]) - - # agents are always reset as not moving - if "agents_static" in env_dict: - # no idea if this still works - env.agents = EnvAgent.load_legacy_static_agent(env_dict["agents_static"]) - else: - agents_data = env_dict["agents"] - if len(agents_data)>0: - if type(agents_data[0]) is EnvAgent: - env.agents = agents_data - else: - env.agents = [EnvAgent(*d[0:12]) for d in env_dict["agents"]] - #print(f"env agents: {env.agents}") + # Initialise the env with the frozen agents in the file + env.agents = env_dict.get("agents", []) + + # For consistency, set number_of_agents, which is the number which will be generated on reset + env.number_of_agents = env.get_num_agents() - # setup with loaded data env.height, env.width = env.rail.grid.shape env.rail.height = env.height env.rail.width = env.width