From 71576225a116aa07857308bdf919a039888badeb Mon Sep 17 00:00:00 2001
From: Giacomo Spigler <spiglerg@gmail.com>
Date: Sun, 21 Apr 2019 17:36:19 +0200
Subject: [PATCH] added rail_from_list_of_saved_GridTransitionMap_generator
 rail generator function

---
 examples/sample_10_10_rail.npy  | Bin 0 -> 328 bytes
 examples/temporary_example.py   |   6 ++++++
 flatland/core/transition_map.py |   6 +++++-
 flatland/envs/rail_env.py       |  29 ++++++++++++++++++++++++++++-
 4 files changed, 39 insertions(+), 2 deletions(-)
 create mode 100644 examples/sample_10_10_rail.npy

diff --git a/examples/sample_10_10_rail.npy b/examples/sample_10_10_rail.npy
new file mode 100644
index 0000000000000000000000000000000000000000..a8dc0d41ecfff0c5c3a8b7446b1dd6246573608e
GIT binary patch
literal 328
zcmbVEu?oUK49)2cMY{M0DGsO5!RjUs4&opvA_#SoqLnL%N~<m+Wa#Krp<m?KkMhzw
zID6qG?_S=^YrdK-)&g#DSGzFT#%h3Sc<m_{)&3M;;^4SB#k-L8NpMI)qbE8zhvwh*
zx-NR%j*Acc^EChf%QD2w5)z!;N&5u|08K)APY@FNg>N-&Lj)3_0h`OtOp}vFIVV^%
zCMhIlnlR(-*N2>|Gzq>{9}DT^m?05NN1Uyknxd3*tnsDRn-Qf=ySA+E0g&<1PyYY{
C{5nbi

literal 0
HcmV?d00001

diff --git a/examples/temporary_example.py b/examples/temporary_example.py
index 662bfe94..02c282cb 100644
--- a/examples/temporary_example.py
+++ b/examples/temporary_example.py
@@ -23,6 +23,12 @@ env = RailEnv(width=20,
               height=20,
               rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
               number_of_agents=10)
+
+# env = RailEnv(width=20,
+#               height=20,
+#               rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(['examples/sample_10_10_rail.npy']),
+#               number_of_agents=10)
+
 env.reset()
 
 env_renderer = RenderTool(env)
diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py
index 73bb6eef..6e37a1c9 100644
--- a/flatland/core/transition_map.py
+++ b/flatland/core/transition_map.py
@@ -282,7 +282,11 @@ class GridTransitionMap(TransitionMap):
             self.grid = new_grid
 
         else:
-            self.grid = self.grid * 0
+            if new_grid.dtype == np.uint16:
+                self.grid = np.zeros((self.height, self.width), dtype=np.uint16)
+            elif new_grid.dtype == np.uint64:
+                self.grid = np.zeros((self.height, self.width), dtype=np.uint64)
+
             self.grid[0:min(self.height, new_height), 0:min(self.width, new_width)] = new_grid[0:min(self.height, new_height), 0:min(self.width, new_width)]
 
 # TODO: GIACOMO: is it better to provide those methods with lists of cell_ids
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 8083f1fa..06b544e3 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -10,7 +10,7 @@ import numpy as np
 from flatland.core.env import Environment
 from flatland.core.env_observation_builder import TreeObsForRailEnv
 
-from flatland.core.transitions import RailEnvTransitions
+from flatland.core.transitions import Grid4Transitions, Grid8Transitions, RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
 
 
@@ -75,6 +75,33 @@ def rail_from_GridTransitionMap_generator(rail_map):
     return generator
 
 
+def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames):
+    """
+    Utility to sequentially and cyclically return GridTransitionMap-s from a list of files, on each environment reset.
+
+    Parameters
+    -------
+    list_of_filenames : list
+        List of filenames with the saved grids to load.
+
+    Returns
+    -------
+    function
+        Generator function that always returns the given `rail_map' object.
+    """
+    def generator(width, height, num_resets=0):
+        t_utils = RailEnvTransitions()
+        rail_map = GridTransitionMap(width=width, height=height, transitions=t_utils)
+        rail_map.load_transition_map(list_of_filenames[num_resets % len(list_of_filenames)], override_gridsize=False)
+
+        if rail_map.grid.dtype == np.uint64:
+            rail_map.transitions = Grid8Transitions()
+
+        return rail_map
+
+    return generator
+
+
 """
 def generate_rail_from_list_of_manual_specifications(list_of_specifications)
     def generator(width, height, num_resets=0):
-- 
GitLab