persistence.py 11.4 KB
Newer Older
1
2
3
4


import pickle
import msgpack
5
import msgpack_numpy
6
7
8
9
10
11
12
13
14
15
import numpy as np

from flatland.envs import rail_env 

#from flatland.core.env import Environment
from flatland.core.env_observation_builder import DummyObservationBuilder
#from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
#from flatland.core.grid.grid4_utils import get_new_position
#from flatland.core.grid.grid_utils import IntVector2D
from flatland.core.transition_map import GridTransitionMap
16
from flatland.envs.agent_utils import Agent, EnvAgent
17
18
19
20
21
22
23
from flatland.envs.distance_map import DistanceMap

#from flatland.envs.observations import GlobalObsForRailEnv

# cannot import objects / classes directly because of circular import
from flatland.envs import malfunction_generators as mal_gen
from flatland.envs import rail_generators as rail_gen
Dipam Chakraborty's avatar
Dipam Chakraborty committed
24
from flatland.envs import line_generators as line_gen
25

26
msgpack_numpy.patch()
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42

class RailEnvPersister(object):

    @classmethod
    def save(cls, env, filename, save_distance_maps=False):
        """
        Saves environment and distance map information in a file

        Parameters:
        ---------
        filename: string
        save_distance_maps: bool
        """

        env_dict = cls.get_full_state(env)

43
44
45
46
47
        # We have an unresolved problem with msgpack loading the list of agents
        # see also 20 lines below.
        # print(f"env save - agents: {env_dict['agents'][0]}")
        # a0 = env_dict["agents"][0]
        # print("agent type:", type(a0))
48
49
50



51
52
53
54
55
56
57
58
59
60
61
        if save_distance_maps is True:
            oDistMap = env.distance_map.get() 
            if oDistMap is not None:
                if len(oDistMap) > 0:
                    env_dict["distance_map"] = oDistMap
                else:
                    print("[WARNING] Unable to save the distance map for this environment, as none was found !")
            else:
                print("[WARNING] Unable to save the distance map for this environment, as none was found !")

        with open(filename, "wb") as file_out:
62

63
            if filename.endswith("mpk"):
64
65
66
                data = msgpack.packb(env_dict)
                
                
67
            elif filename.endswith("pkl"):
68
69
70
71
72
                data = pickle.dumps(env_dict)
                #pickle.dump(env_dict, file_out)

            file_out.write(data)

73
74
75
76
77
78
79
        # We have an unresovled problem with msgpack loading the list of Agents 
        # with open(filename, "rb") as file_in:
        # if filename.endswith("mpk"):
            # bytes_in = file_in.read()
            # dIn = msgpack.unpackb(data, encoding="utf-8")
            # print(f"msgpack check - {dIn.keys()}")
            # print(f"msgpack check - {dIn['agents'][0]}")
80
81

                
82
83
84
85
86

    @classmethod
    def save_episode(cls, env, filename):
        dict_env = cls.get_full_state(env)

87
        # Add additional info to dict_env before saving
88
        dict_env["episode"] = env.cur_episode
89
        dict_env["actions"] = env.list_actions
90
        dict_env["shape"] = (env.width, env.height)
91
        dict_env["max_episode_steps"] = env._max_episode_steps
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115

        with open(filename, "wb") as file_out:
            if filename.endswith(".mpk"):
                file_out.write(msgpack.packb(dict_env))
            elif filename.endswith(".pkl"):
                pickle.dump(dict_env, file_out)

    @classmethod
    def load(cls, env, filename, load_from_package=None):
        """
        Load environment with distance map from a file

        Parameters:
        -------
        filename: string
        """
        env_dict = cls.load_env_dict(filename, load_from_package=load_from_package)
        cls.set_full_state(env, env_dict)

    @classmethod
    def load_new(cls, filename, load_from_package=None):
        """
        Creates a new RailEnv from a saved environment file

        The env is built with file-based rail, line and malfunction
        generators and a DummyObservationBuilder, and the loaded state is
        then applied to it with set_full_state.

        Parameters:
        -------
        filename: string
        load_from_package: passed through to load_env_dict and to the
            rail / line generators

        Returns:
        -------
        (env, env_dict): the new environment and the dict it was loaded from
        """
        env_dict = cls.load_env_dict(filename, load_from_package=load_from_package)

        # Grid dimensions: number of rows, and length of the first row.
        grid_rows = env_dict["grid"]
        n_rows, n_cols = len(grid_rows), len(grid_rows[0])

        # TODO: inefficient - each one of these generators loads the complete env file.
        rail_generator = rail_gen.rail_from_file(filename, load_from_package=load_from_package)
        line_generator = line_gen.line_from_file(filename, load_from_package=load_from_package)
        malfunction_generator = mal_gen.FileMalfunctionGen(env_dict)

        env = rail_env.RailEnv(
            width=n_cols,
            height=n_rows,
            rail_generator=rail_generator,
            line_generator=line_generator,
            malfunction_generator=malfunction_generator,
            obs_builder_object=DummyObservationBuilder(),
            record_steps=True,
        )

        # Dummy 1x1 map; set_full_state overwrites its grid, height and width.
        env.rail = GridTransitionMap(1, 1)

        cls.set_full_state(env, env_dict)
        return env, env_dict

    @classmethod
    def load_env_dict(cls, filename, load_from_package=None):

        if load_from_package is not None:
            from importlib_resources import read_binary
            load_data = read_binary(load_from_package, filename)
        else:
            with open(filename, "rb") as file_in:
                load_data = file_in.read()

        if filename.endswith("mpk"):
            env_dict = msgpack.unpackb(load_data, use_list=False, encoding="utf-8")
        elif filename.endswith("pkl"):
151
152
153
154
155
            try:
                env_dict = pickle.loads(load_data)
            except ValueError:
                print("pickle failed to load file:", filename, " trying msgpack (deprecated)...")
                env_dict = msgpack.unpackb(load_data, use_list=False, encoding="utf-8")
156
157
158
        else:
            print(f"filename {filename} must end with either pkl or mpk")
            env_dict = {}
159
        
160
161
162
163
164
165
        # Replace the agents tuple with EnvAgent objects
        if "agents_static" in env_dict:
            env_dict["agents"] = EnvAgent.load_legacy_static_agent(env_dict["agents_static"])
            # remove the legacy key
            del env_dict["agents_static"]
        elif "agents" in env_dict:
166
            env_dict["agents"] = [EnvAgent(*d[0:len(d)]) for d in env_dict["agents"]]
167
168
169
170
171
172
173
174

        return env_dict

    @classmethod
    def load_resource(cls, package, resource):
        """
        Creates a new environment from a file stored as a package resource

        Thin wrapper around load_new: `resource` is used as the filename
        and `package` as load_from_package.

        Returns:
        -------
        whatever load_new returns for this resource
        """
        loaded = cls.load_new(resource, load_from_package=package)
        return loaded
186
187
188
189
190
191
192
193
194
195
196

    @classmethod
    def set_full_state(cls, env, env_dict):
        """
        Sets environment state from env_dict 

        Parameters
        -------
        env_dict: dict
        """
        env.rail.grid = np.array(env_dict["grid"])
197

198
199
200
201
202
        # Initialise the env with the frozen agents in the file
        env.agents = env_dict.get("agents", [])

        # For consistency, set number_of_agents, which is the number which will be generated on reset
        env.number_of_agents = env.get_num_agents()
203

204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
        env.height, env.width = env.rail.grid.shape
        env.rail.height = env.height
        env.rail.width = env.width
        env.dones = dict.fromkeys(list(range(env.get_num_agents())) + ["__all__"], False)

    @classmethod
    def get_full_state(cls, env):
        """
        Returns state of environment in dict object, ready for serialization

        """
        grid_data = env.rail.grid.tolist()

        # msgpack cannot persist EnvAgent so use the Agent namedtuple.
        agent_data = [agent.to_agent() for agent in env.agents]
219
        #print("get_full_state - agent_data:", agent_data)
220
        malfunction_data: mal_gen.MalfunctionProcessData = env.malfunction_process_data
221
222
223
224

        msg_data_dict = {
            "grid": grid_data,
            "agents": agent_data,
225
226
227
            "malfunction": malfunction_data,
            "max_episode_steps": env._max_episode_steps,
            }
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
        return msg_data_dict


################################################################################################
# deprecated methods moved from RailEnv.  Most likely broken.

    def deprecated_get_full_state_msg(self) -> bytes:
        """
        Returns state of environment as msgpack-packed bytes

        NOTE(review): calls self.get_full_state_dict(), which is not defined
        in the visible part of this class - presumably `self` is expected to
        be a RailEnv providing it; confirm before using this method.
        """
        # msgpack.packb returns the packed bytes, not a Packer object.
        return msgpack.packb(self.get_full_state_dict(), use_bin_type=True)

    def deprecated_get_agent_state_msg(self) -> bytes:
        """
        Returns agents information as msgpack-packed bytes

        The packed dict has a single key, "agents", holding one to_agent()
        result per entry of self.agents.
        """
        agent_data = [agent.to_agent() for agent in self.agents]
        # msgpack.packb returns the packed bytes, not a Packer object.
        return msgpack.packb({"agents": agent_data}, use_bin_type=True)

    def deprecated_get_full_state_dist_msg(self) -> bytes:
        """
        Returns environment information with distance map information as
        msgpack-packed bytes

        The packed dict has the keys "grid", "agents", "distance_map" and
        "malfunction".
        """
        grid_data = self.rail.grid.tolist()
        agent_data = [agent.to_agent() for agent in self.agents]
        distance_map_data = self.distance_map.get()
        # Only the module is imported (as mal_gen), so the class has to be
        # referenced through it.
        malfunction_data: mal_gen.MalfunctionProcessData = self.malfunction_process_data

        msg_data = {
            "grid": grid_data,
            "agents": agent_data,
            "distance_map": distance_map_data,
            "malfunction": malfunction_data,
        }
        # msgpack.packb returns the packed bytes, not a Packer object.
        return msgpack.packb(msg_data, use_bin_type=True)

    def deprecated_set_full_state_msg(self, msg_data):
        """
        Sets environment state with msgdata object passed as argument

        Restores the rail grid, the agents, the env / rail dimensions and a
        fresh `dones` dict from the unpacked data.

        Parameters
        -------
        msg_data: msgpack object
        """
        # raw=False decodes msgpack strings as UTF-8 str; it replaces the
        # `encoding` argument, which msgpack 1.0 no longer accepts.
        data = msgpack.unpackb(msg_data, use_list=False, raw=False)
        self.rail.grid = np.array(data["grid"])
        # agents are always reset as not moving
        if "agents_static" in data:
            self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"])
        else:
            # Only the first 12 fields of each stored agent are used.
            self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
        # setup with loaded data
        self.height, self.width = self.rail.grid.shape
        self.rail.height = self.height
        self.rail.width = self.width
        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)

    def deprecated_set_full_state_dist_msg(self, msg_data):
        """
        Sets environment grid state and distance map with msgdata object passed as argument

        Restores the rail grid, the agents, the distance map (only if the
        data has a "distance_map" entry), the env / rail dimensions and a
        fresh `dones` dict from the unpacked data.

        Parameters
        -------
        msg_data: msgpack object
        """
        # raw=False decodes msgpack strings as UTF-8 str; it replaces the
        # `encoding` argument, which msgpack 1.0 no longer accepts.
        data = msgpack.unpackb(msg_data, use_list=False, raw=False)
        self.rail.grid = np.array(data["grid"])
        # agents are always reset as not moving
        if "agents_static" in data:
            self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"])
        else:
            # Only the first 12 fields of each stored agent are used.
            self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
        if "distance_map" in data:
            self.distance_map.set(data["distance_map"])
        # setup with loaded data
        self.height, self.width = self.rail.grid.shape
        self.rail.height = self.height
        self.rail.width = self.width
        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)