From d8253ebb1129ff4dc491ad5d3631c448bdfd8346 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Wed, 10 Jul 2019 17:38:12 -0400
Subject: [PATCH] added pkl_load functionality to the rail_env. It now sets all
 the info from pickle file before running all the prediction and observations.

---
 flatland/envs/observations.py | 21 ++++++++++++---------
 flatland/envs/rail_env.py     | 25 +++++++++++++++++++------
 2 files changed, 31 insertions(+), 15 deletions(-)

diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index b1f46ec..8fc70ce 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -172,18 +172,21 @@ class TreeObsForRailEnv(ObservationBuilder):
         if handles is None:
             handles = []
         if self.predictor:
+            self.max_prediction_depth = 0
             self.predicted_pos = {}
             self.predicted_dir = {}
             self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map})
-            for t in range(len(self.predictions[0])):
-                pos_list = []
-                dir_list = []
-                for a in handles:
-                    pos_list.append(self.predictions[a][t][1:3])
-                    dir_list.append(self.predictions[a][t][3])
-                self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
-                self.predicted_dir.update({t: dir_list})
-            self.max_prediction_depth = len(self.predicted_pos)
+            if self.predictions:
+
+                for t in range(len(self.predictions[0])):
+                    pos_list = []
+                    dir_list = []
+                    for a in handles:
+                        pos_list.append(self.predictions[a][t][1:3])
+                        dir_list.append(self.predictions[a][t][3])
+                    self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
+                    self.predicted_dir.update({t: dir_list})
+                self.max_prediction_depth = len(self.predicted_pos)
         observations = {}
         for h in handles:
             observations[h] = self.get(h)
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 149e32b..029851a 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -80,6 +80,7 @@ class RailEnv(Environment):
                  rail_generator=random_rail_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2),
+                 pkl_load=None
                  ):
         """
         Environment init.
@@ -110,6 +111,7 @@ class RailEnv(Environment):
         obs_builder_object: ObservationBuilder object
             ObservationBuilder-derived object that takes builds observation
             vectors for each agent.
+        pkl_load: you can load a pickle file.
         """
 
         self.rail_generator = rail_generator
@@ -117,12 +119,6 @@ class RailEnv(Environment):
         self.width = width
         self.height = height
 
-        self.obs_builder = obs_builder_object
-        self.obs_builder._set_env(self)
-
-        self.action_space = [1]
-        self.observation_space = self.obs_builder.observation_space  # updated on resets?
-
         self.rewards = [0] * number_of_agents
         self.done = False
 
@@ -135,6 +131,17 @@ class RailEnv(Environment):
         self.agents = [None] * number_of_agents  # live agents
         self.agents_static = [None] * number_of_agents  # static agent information
         self.num_resets = 0
+        if pkl_load:
+            self.loaded_data = pkl_load
+        else:
+            self.loaded_data = None
+
+        self.obs_builder = obs_builder_object
+        self.obs_builder._set_env(self)
+
+        self.action_space = [1]
+        self.observation_space = self.obs_builder.observation_space  # updated on resets?
+
         self.reset()
         self.num_resets = 0  # yes, set it to zero again!
 
@@ -175,6 +182,9 @@ class RailEnv(Environment):
         if replace_agents:
             self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5])
 
+        if self.loaded_data:
+            self.load_pkl(self.loaded_data)
+
         self.restart_agents()
 
         for i_agent in range(self.get_num_agents()):
@@ -425,6 +435,9 @@ class RailEnv(Environment):
             load_data = file_in.read()
             self.set_full_state_msg(load_data)
 
+    def load_pkl(self, pkl_data):
+        self.set_full_state_msg(pkl_data)
+
     def load_resource(self, package, resource):
         from importlib_resources import read_binary
         load_data = read_binary(package, resource)
-- 
GitLab