From 4d538a4155be094ec868282aa520dec1992f5fb3 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Sat, 13 Jul 2019 12:08:18 -0400
Subject: [PATCH] updated load and save function. Now also distance maps are
 stored. Additional package msgpack-numpy needed for ndarray. This saves tons
 of time when loading precomputed files.

---
 flatland/envs/generators.py   |  7 ++++--
 flatland/envs/observations.py | 10 ++++++--
 flatland/envs/rail_env.py     | 46 +++++++++++++++++++++++++++--------
 requirements_dev.txt          |  3 ++-
 4 files changed, 51 insertions(+), 15 deletions(-)

diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index ec579c1d..9b33a55c 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -227,8 +227,11 @@ def rail_from_file(filename):
         agents_position = [a.position for a in agents_static]
         agents_direction = [a.direction for a in agents_static]
         agents_target = [a.target for a in agents_static]
-        return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
-
+        if len(data) > 3:
+            distance_maps = data[b"distance_maps"]
+            return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position), distance_maps
+        else:
+            return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
     return generator
 
 
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 3929b9e1..4385d4da 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -41,22 +41,28 @@ class TreeObsForRailEnv(ObservationBuilder):
         self.agents_previous_reset = None
         self.tree_explored_actions = [1, 2, 3, 0]
         self.tree_explorted_actions_char = ['L', 'F', 'R', 'B']
+        self.distance_map = None
 
     def reset(self):
         agents = self.env.agents
         nb_agents = len(agents)
-
         compute_distance_map = True
         if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset):
             compute_distance_map = False
             for i in range(nb_agents):
                 if agents[i].target != self.agents_previous_reset[i].target:
                     compute_distance_map = True
-        self.agents_previous_reset = agents
+
+        # Don't compute the distance map if it was loaded
+        if self.agents_previous_reset is None and self.distance_map is not None:
+            self.location_has_target = {tuple(agent.target): 1 for agent in agents}
+            compute_distance_map = False
 
         if compute_distance_map:
             self._compute_distance_map()
 
+        self.agents_previous_reset = agents
+
     def _compute_distance_map(self):
         agents = self.env.agents
         nb_agents = len(agents)
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 43351099..f082f080 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -6,6 +6,7 @@ Definition of the RailEnv environment.
 from enum import IntEnum
 
 import msgpack
+import msgpack_numpy as m
 import numpy as np
 
 from flatland.core.env import Environment
@@ -14,6 +15,8 @@ from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
 from flatland.envs.generators import random_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 
+m.patch()
+
 
 class RailEnvActions(IntEnum):
     DO_NOTHING = 0  # implies change of direction in a dead-end!
@@ -170,6 +173,10 @@ class RailEnv(Environment):
         """
         tRailAgents = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
 
+        # Check if generator provided a distance map TODO: Make this check safer!
+        if len(tRailAgents) > 5:
+            self.obs_builder.distance_map = tRailAgents[-1]
+
         if regen_rail or self.rail is None:
             self.rail = tRailAgents[0]
             self.height, self.width = self.rail.grid.shape
@@ -424,6 +431,8 @@ class RailEnv(Environment):
         # agents are always reset as not moving
         self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
         self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]]
+        if hasattr(self.obs_builder, 'distance_map'):
+            self.obs_builder.distance_map = data[b"distance_maps"]
         # setup with loaded data
         self.height, self.width = self.rail.grid.shape
         self.rail.height = self.height
@@ -438,22 +447,39 @@ class RailEnv(Environment):
         msgpack.packb(grid_data)
         msgpack.packb(agent_data)
         msgpack.packb(agent_static_data)
+        if hasattr(self.obs_builder, 'distance_map'):
+            distance_map_data = self.obs_builder.distance_map
+            msgpack.packb(distance_map_data)
+            msg_data = {
+                "grid": grid_data,
+                "agents_static": agent_static_data,
+                "agents": agent_data,
+                "distance_maps": distance_map_data}
+        else:
+            msg_data = {
+                "grid": grid_data,
+                "agents_static": agent_static_data,
+                "agents": agent_data}
 
-        msg_data = {
-            "grid": grid_data,
-            "agents_static": agent_static_data,
-            "agents": agent_data}
         return msgpack.packb(msg_data, use_bin_type=True)
 
-
     def save(self, filename):
-        with open(filename, "wb") as file_out:
-            file_out.write(self.get_full_state_msg())
+        if hasattr(self.obs_builder, 'distance_map'):
+            with open(filename, "wb") as file_out:
+                file_out.write(self.get_full_state_dist_msg())
+        else:
+            with open(filename, "wb") as file_out:
+                file_out.write(self.get_full_state_msg())
 
     def load(self, filename):
-        with open(filename, "rb") as file_in:
-            load_data = file_in.read()
-            self.set_full_state_msg(load_data)
+        if hasattr(self.obs_builder, 'distance_map'):
+            with open(filename, "rb") as file_in:
+                load_data = file_in.read()
+                self.set_full_state_dist_msg(load_data)
+        else:
+            with open(filename, "rb") as file_in:
+                load_data = file_in.read()
+                self.set_full_state_msg(load_data)
 
     def load_pkl(self, pkl_data):
         self.set_full_state_msg(pkl_data)
diff --git a/requirements_dev.txt b/requirements_dev.txt
index ea46eb24..edd6ee28 100644
--- a/requirements_dev.txt
+++ b/requirements_dev.txt
@@ -3,13 +3,14 @@ tox>=3.5.2
 twine>=1.12.1
 pytest>=3.8.2
 pytest-runner>=4.2
-numpy>=1.16.4
+numpy>=1.16.2
 recordtype>=1.3
 xarray>=0.11.3
 matplotlib>=3.0.2
 Pillow>=5.4.1
 CairoSVG>=2.3.1
 msgpack>=0.6.1
+msgpack-numpy>=0.4.4.0
 svgutils>=0.3.1
 screeninfo>=0.3.1
 pyarrow>=0.13.0
-- 
GitLab