From d41ef256998a48a585faa035daa6ec777e6204cd Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Thu, 11 Jul 2019 13:55:51 -0400 Subject: [PATCH] initial commit of new generator function --- flatland/envs/generators.py | 39 ++++++++++++++++++++++++++++++++++--- 1 file changed, 36 insertions(+), 3 deletions(-) diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index c3f569a..ee75df2 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -1,10 +1,10 @@ import numpy as np -from flatland.core.transition_map import GridTransitionMap +from flatland.core.grid.grid4_utils import get_direction, mirror +from flatland.core.grid.grid_utils import distance_on_rail from flatland.core.grid.rail_env_grid import RailEnvTransitions +from flatland.core.transition_map import GridTransitionMap from flatland.envs.grid4_generators_utils import connect_rail -from flatland.core.grid.grid_utils import distance_on_rail -from flatland.core.grid.grid4_utils import get_direction, mirror from flatland.envs.grid4_generators_utils import get_rnd_agents_pos_tgt_dir_on_rail @@ -195,6 +195,39 @@ def rail_from_manual_specifications_generator(rail_spec): return generator +def rail_from_data(input_data): + """ + Utility to load pickle file + + Parameters + ------- + input_file : Pickle file generated by env.save() or editor + + Returns + ------- + function + Generator function that always returns a GridTransitionMap object with + the matrix of correct 16-bit bitmaps for each rail_spec_of_cell. + """ + + def generator(): + data = msgpack.unpackb(msg_data, use_list=False) + self.rail.grid = np.array(input_data[b"grid"]) + rail = GridTransitionMap(width=width, height=height, transitions=rail_env_transitions) + # agents are always reset as not moving + self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]] + self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]] + # setup with loaded data + self.height, self.width = self.rail.grid.shape + self.rail.height = self.height + self.rail.width = self.width + self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) + + return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + + return generator + + def rail_from_GridTransitionMap_generator(rail_map): """ Utility to convert a rail given by a GridTransitionMap map with the correct -- GitLab