diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index c3f569ae0d927a6e71803a24d921337c65d39c29..ee75df2834898042f12df6065c220ea51358c0bb 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -1,10 +1,10 @@ import numpy as np -from flatland.core.transition_map import GridTransitionMap +from flatland.core.grid.grid4_utils import get_direction, mirror +from flatland.core.grid.grid_utils import distance_on_rail from flatland.core.grid.rail_env_grid import RailEnvTransitions +from flatland.core.transition_map import GridTransitionMap from flatland.envs.grid4_generators_utils import connect_rail -from flatland.core.grid.grid_utils import distance_on_rail -from flatland.core.grid.grid4_utils import get_direction, mirror from flatland.envs.grid4_generators_utils import get_rnd_agents_pos_tgt_dir_on_rail @@ -195,6 +195,39 @@ def rail_from_manual_specifications_generator(rail_spec): return generator +def rail_from_data(input_data): + """ + Utility to load pickle file + + Parameters + ------- + input_file : Pickle file generated by env.save() or editor + + Returns + ------- + function + Generator function that always returns a GridTransitionMap object with + the matrix of correct 16-bit bitmaps for each rail_spec_of_cell. + """ + + def generator(): + data = msgpack.unpackb(msg_data, use_list=False) + self.rail.grid = np.array(input_data[b"grid"]) + rail = GridTransitionMap(width=width, height=height, transitions=rail_env_transitions) + # agents are always reset as not moving + self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]] + self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]] + # setup with loaded data + self.height, self.width = self.rail.grid.shape + self.rail.height = self.height + self.rail.width = self.width + self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) + + return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + + return generator + + def rail_from_GridTransitionMap_generator(rail_map): """ Utility to convert a rail given by a GridTransitionMap map with the correct