From 71576225a116aa07857308bdf919a039888badeb Mon Sep 17 00:00:00 2001 From: Giacomo Spigler <spiglerg@gmail.com> Date: Sun, 21 Apr 2019 17:36:19 +0200 Subject: [PATCH] added rail_from_list_of_saved_GridTransitionMap_generator rail generator function --- examples/sample_10_10_rail.npy | Bin 0 -> 328 bytes examples/temporary_example.py | 6 ++++++ flatland/core/transition_map.py | 6 +++++- flatland/envs/rail_env.py | 29 ++++++++++++++++++++++++++++- 4 files changed, 39 insertions(+), 2 deletions(-) create mode 100644 examples/sample_10_10_rail.npy diff --git a/examples/sample_10_10_rail.npy b/examples/sample_10_10_rail.npy new file mode 100644 index 0000000000000000000000000000000000000000..a8dc0d41ecfff0c5c3a8b7446b1dd6246573608e GIT binary patch literal 328 zcmbVEu?oUK49)2cMY{M0DGsO5!RjUs4&opvA_#SoqLnL%N~<m+Wa#Krp<m?KkMhzw zID6qG?_S=^YrdK-)&g#DSGzFT#%h3Sc<m_{)&3M;;^4SB#k-L8NpMI)qbE8zhvwh* zx-NR%j*Acc^EChf%QD2w5)z!;N&5u|08K)APY@FNg>N-&Lj)3_0h`OtOp}vFIVV^% zCMhIlnlR(-*N2>|Gzq>{9}DT^m?05NN1Uyknxd3*tnsDRn-Qf=ySA+E0g&<1PyYY{ C{5nbi literal 0 HcmV?d00001 diff --git a/examples/temporary_example.py b/examples/temporary_example.py index 662bfe94..02c282cb 100644 --- a/examples/temporary_example.py +++ b/examples/temporary_example.py @@ -23,6 +23,12 @@ env = RailEnv(width=20, height=20, rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), number_of_agents=10) + +# env = RailEnv(width=20, +# height=20, +# rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(['examples/sample_10_10_rail.npy']), +# number_of_agents=10) + env.reset() env_renderer = RenderTool(env) diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index 73bb6eef..6e37a1c9 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -282,7 +282,11 @@ class GridTransitionMap(TransitionMap): self.grid = new_grid else: - self.grid = self.grid * 0 + if new_grid.dtype == np.uint16: + self.grid = np.zeros((self.height, self.width), dtype=np.uint16) + elif new_grid.dtype == np.uint64: + self.grid = np.zeros((self.height, self.width), dtype=np.uint64) + self.grid[0:min(self.height, new_height), 0:min(self.width, new_width)] = new_grid[0:min(self.height, new_height), 0:min(self.width, new_width)] # TODO: GIACOMO: is it better to provide those methods with lists of cell_ids diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 8083f1fa..06b544e3 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -10,7 +10,7 @@ import numpy as np from flatland.core.env import Environment from flatland.core.env_observation_builder import TreeObsForRailEnv -from flatland.core.transitions import RailEnvTransitions +from flatland.core.transitions import Grid4Transitions, Grid8Transitions, RailEnvTransitions from flatland.core.transition_map import GridTransitionMap @@ -75,6 +75,33 @@ def rail_from_GridTransitionMap_generator(rail_map): return generator +def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames): + """ + Utility to sequentially and cyclically return GridTransitionMap-s from a list of files, on each environment reset. + + Parameters + ------- + list_of_filenames : list + List of filenames with the saved grids to load. + + Returns + ------- + function + Generator function that always returns the given `rail_map' object. + """ + def generator(width, height, num_resets=0): + t_utils = RailEnvTransitions() + rail_map = GridTransitionMap(width=width, height=height, transitions=t_utils) + rail_map.load_transition_map(list_of_filenames[num_resets % len(list_of_filenames)], override_gridsize=False) + + if rail_map.grid.dtype == np.uint64: + rail_map.transitions = Grid8Transitions() + + return rail_map + + return generator + + """ def generate_rail_from_list_of_manual_specifications(list_of_specifications) def generator(width, height, num_resets=0): -- GitLab