diff --git a/examples/sample_10_10_rail.npy b/examples/sample_10_10_rail.npy new file mode 100644 index 0000000000000000000000000000000000000000..a8dc0d41ecfff0c5c3a8b7446b1dd6246573608e Binary files /dev/null and b/examples/sample_10_10_rail.npy differ diff --git a/examples/temporary_example.py b/examples/temporary_example.py index 662bfe94d0cd72ff685d5546efbc70a3d641c057..02c282cb374914651d063a2b118fb688257e7631 100644 --- a/examples/temporary_example.py +++ b/examples/temporary_example.py @@ -23,6 +23,12 @@ env = RailEnv(width=20, height=20, rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), number_of_agents=10) + +# env = RailEnv(width=20, +# height=20, +# rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(['examples/sample_10_10_rail.npy']), +# number_of_agents=10) + env.reset() env_renderer = RenderTool(env) diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index 73bb6eeffd9bf61d418adf67e9d49bec5a12234b..6e37a1c9cc92f31ec561ba129639d4e6e9425f6e 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -282,7 +282,11 @@ class GridTransitionMap(TransitionMap): self.grid = new_grid else: - self.grid = self.grid * 0 + if new_grid.dtype == np.uint16: + self.grid = np.zeros((self.height, self.width), dtype=np.uint16) + elif new_grid.dtype == np.uint64: + self.grid = np.zeros((self.height, self.width), dtype=np.uint64) + self.grid[0:min(self.height, new_height), 0:min(self.width, new_width)] = new_grid[0:min(self.height, new_height), 0:min(self.width, new_width)] # TODO: GIACOMO: is it better to provide those methods with lists of cell_ids diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 8083f1fa1459001cf577a3bd1fb5b0aec6a57e2d..06b544e307db017b874cfd61a20fe72d15ca8b80 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -10,7 +10,7 @@ import numpy as np from flatland.core.env import Environment from flatland.core.env_observation_builder import TreeObsForRailEnv -from flatland.core.transitions import RailEnvTransitions +from flatland.core.transitions import Grid4Transitions, Grid8Transitions, RailEnvTransitions from flatland.core.transition_map import GridTransitionMap @@ -75,6 +75,33 @@ def rail_from_GridTransitionMap_generator(rail_map): return generator +def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames): + """ + Utility to sequentially and cyclically return GridTransitionMap-s from a list of files, on each environment reset. + + Parameters + ------- + list_of_filenames : list + List of filenames with the saved grids to load. + + Returns + ------- + function + Generator function that always returns the given `rail_map' object. + """ + def generator(width, height, num_resets=0): + t_utils = RailEnvTransitions() + rail_map = GridTransitionMap(width=width, height=height, transitions=t_utils) + rail_map.load_transition_map(list_of_filenames[num_resets % len(list_of_filenames)], override_gridsize=False) + + if rail_map.grid.dtype == np.uint64: + rail_map.transitions = Grid8Transitions() + + return rail_map + + return generator + + """ def generate_rail_from_list_of_manual_specifications(list_of_specifications) def generator(width, height, num_resets=0):