From 9a6efb0583d4952739805f708864735d84fdc17c Mon Sep 17 00:00:00 2001
From: Giacomo Spigler <spiglerg@gmail.com>
Date: Sun, 21 Apr 2019 17:23:16 +0200
Subject: [PATCH] added save/load gridmap for GridTransitionMap

---
 examples/temporary_example.py            | 32 ++++++------------
 flatland/core/env_observation_builder.py | 16 ++++-----
 flatland/core/transition_map.py          | 43 +++++++++++++++++++++++-
 3 files changed, 61 insertions(+), 30 deletions(-)

diff --git a/examples/temporary_example.py b/examples/temporary_example.py
index 1d13d2ce..662bfe94 100644
--- a/examples/temporary_example.py
+++ b/examples/temporary_example.py
@@ -9,24 +9,14 @@ from flatland.utils.rendertools import *
 random.seed(0)
 np.random.seed(0)
 
-"""
-transition_probability = [1.0,  # empty cell - Case 0
-                          3.0,  # Case 1 - straight
-                          1.0,  # Case 2 - simple switch
-                          3.0,  # Case 3 - diamond drossing
-                          2.0,  # Case 4 - single slip
-                          1.0,  # Case 5 - double slip
-                          1.0,  # Case 6 - symmetrical
-                          1.0]  # Case 7 - dead end
-"""
 transition_probability = [1.0,  # empty cell - Case 0
                           1.0,  # Case 1 - straight
-                          0.5,  # Case 2 - simple switch
-                          0.2,  # Case 3 - diamond drossing
+                          1.0,  # Case 2 - simple switch
+                          0.3,  # Case 3 - diamond drossing
                           0.5,  # Case 4 - single slip
-                          0.1,  # Case 5 - double slip
+                          0.5,  # Case 5 - double slip
                           0.2,  # Case 6 - symmetrical
-                          0.01]  # Case 7 - dead end
+                          0.0]  # Case 7 - dead end
 
 # Example generate a random rail
 env = RailEnv(width=20,
@@ -38,12 +28,12 @@ env.reset()
 env_renderer = RenderTool(env)
 env_renderer.renderEnv(show=True)
 
+"""
 # Example generate a rail given a manual specification,
 # a map of tuples (cell_type, rotation)
 specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
          [(7, 270), (1, 90), (1, 90), (1, 90), (2, 90), (7, 90)]]
 
-"""
 env = RailEnv(width=6,
               height=2,
               rail_generator=rail_from_manual_specifications_generator(specs),
@@ -56,20 +46,20 @@ env.agents_target[0] = [1, 1]
 env.agents_direction[0] = 1
 # TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too!
 env.obs_builder.reset()
-#"""
-
+"""
 env = RailEnv(width=7,
               height=7,
               rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
               number_of_agents=2)
 
-# TODO: delete next line
-#for i in range(4):
-#    print(env.obs_builder.distance_map[0, :, :, i])
+# Print the distance map of each cell to the target of the first agent
+# for i in range(4):
+#     print(env.obs_builder.distance_map[0, :, :, i])
 
+# Print the observation vector for agent 0
 obs, all_rewards, done, _ = env.step({0:0})
 for i in range(env.number_of_agents):
-    env.obs_builder.util_print_obs_subtree(tree=obs[i], num_elements_per_node=5)
+    env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=5)
 
 env_renderer = RenderTool(env)
 env_renderer.renderEnv(show=True)
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index a0def07e..d7bee930 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -383,15 +383,15 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         return observation
 
-    def util_print_obs_subtree(self, tree, num_elements_per_node=5, prompt='', current_depth=0):
+    def util_print_obs_subtree(self, tree, num_features_per_node=5, prompt='', current_depth=0):
         """
         Utility function to pretty-print tree observations returned by this object.
         """
-        if len(tree) < num_elements_per_node:
+        if len(tree) < num_features_per_node:
             return
 
         depth = 0
-        tmp = len(tree)/num_elements_per_node-1
+        tmp = len(tree)/num_features_per_node-1
         pow4 = 4
         while tmp > 0:
             tmp -= pow4
@@ -400,12 +400,12 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         prompt_ = ['L:', 'F:', 'R:', 'B:']
 
-        print("  "*current_depth + prompt, tree[0:num_elements_per_node])
-        child_size = (len(tree)-num_elements_per_node)//4
+        print("  "*current_depth + prompt, tree[0:num_features_per_node])
+        child_size = (len(tree)-num_features_per_node)//4
         for children in range(4):
-            child_tree = tree[(num_elements_per_node+children*child_size):
-                              (num_elements_per_node+(children+1)*child_size)]
+            child_tree = tree[(num_features_per_node+children*child_size):
+                              (num_features_per_node+(children+1)*child_size)]
             self.util_print_obs_subtree(child_tree,
-                                        num_elements_per_node,
+                                        num_features_per_node,
                                         prompt=prompt_[children],
                                         current_depth=current_depth+1)
diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py
index d3fcf5c8..73bb6eef 100644
--- a/flatland/core/transition_map.py
+++ b/flatland/core/transition_map.py
@@ -118,7 +118,7 @@ class GridTransitionMap(TransitionMap):
             Width of the grid.
         height : int
             Height of the grid.
-        transitions_class : Transitions object
+        transitions : Transitions object
             The Transitions object to use to encode/decode transitions over the
             grid.
 
@@ -243,6 +243,47 @@ class GridTransitionMap(TransitionMap):
             return
         self.transitions.set_transition(self.grid[cell_id[0]][cell_id[1]], cell_id[2], transition_index, new_transition)
 
+    def save_transition_map(self, filename):
+        """
+        Save the transitions grid as `filename', in npy format.
+
+        Parameters
+        ----------
+        filename : string
+            Name of the file to which to save the transitions grid.
+
+        """
+        np.save(filename, self.grid)
+
+    def load_transition_map(self, filename, override_gridsize=True):
+        """
+        Load the transitions grid from `filename' (npy format).
+        The load function only updates the transitions grid, and possibly width and height, but the object has to be
+        initialized with the correct `transitions' object anyway.
+
+        Parameters
+        ----------
+        filename : string
+            Name of the file from which to load the transitions grid.
+        override_gridsize : bool
+            If override_gridsize=True, the width and height of the GridTransitionMap object are replaced with the size
+            of the map loaded from `filename'. If override_gridsize=False, the transitions grid is either cropped (if
+            the grid size is larger than (height,width) ) or padded with zeros (if the grid size is smaller than (height,width) )
+
+        """
+        new_grid = np.load(filename)
+
+        new_height = new_grid.shape[0]
+        new_width = new_grid.shape[1]
+
+        if override_gridsize:
+            self.width = new_width
+            self.height = new_height
+            self.grid = new_grid
+
+        else:
+            self.grid = self.grid * 0
+            self.grid[0:min(self.height, new_height), 0:min(self.width, new_width)] = new_grid[0:min(self.height, new_height), 0:min(self.width, new_width)]
 
 # TODO: GIACOMO: is it better to provide those methods with lists of cell_ids
 # (most general implementation) or to make Grid-class specific methods for
-- 
GitLab