From e1713c6f63ceaac2c31a0894bc8fd284b756a65f Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Thu, 11 Jul 2019 15:37:10 -0400
Subject: [PATCH] refactored how we import envs by moving it into a generator

---
 flatland/envs/generators.py   | 27 +++++++++++++++------------
 flatland/envs/observations.py |  2 +-
 flatland/envs/rail_env.py     |  9 +--------
 tests/test_file_load.py       |  5 ++---
 4 files changed, 19 insertions(+), 24 deletions(-)

diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index ee75df2..ff21046 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -1,9 +1,11 @@
+import msgpack
 import numpy as np
 
 from flatland.core.grid.grid4_utils import get_direction, mirror
 from flatland.core.grid.grid_utils import distance_on_rail
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
+from flatland.envs.agent_utils import EnvAgentStatic
 from flatland.envs.grid4_generators_utils import connect_rail
 from flatland.envs.grid4_generators_utils import get_rnd_agents_pos_tgt_dir_on_rail
 
@@ -195,7 +197,7 @@ def rail_from_manual_specifications_generator(rail_spec):
     return generator
 
 
-def rail_from_data(input_data):
+def rail_from_data(filename):
     """
     Utility to load pickle file
 
@@ -210,19 +212,20 @@ def rail_from_data(input_data):
         the matrix of correct 16-bit bitmaps for each rail_spec_of_cell.
     """
 
-    def generator():
-        data = msgpack.unpackb(msg_data, use_list=False)
-        self.rail.grid = np.array(input_data[b"grid"])
-        rail = GridTransitionMap(width=width, height=height, transitions=rail_env_transitions)
+    def generator(width, height, num_agents, num_resets):
+        rail_env_transitions = RailEnvTransitions()
+        with open(filename, "rb") as file_in:
+            load_data = file_in.read()
+        data = msgpack.unpackb(load_data, use_list=False)
+        grid = np.array(data[b"grid"])
+        rail = GridTransitionMap(width=np.shape(grid)[1], height=np.shape(grid)[0], transitions=rail_env_transitions)
+        rail.grid = grid
         # agents are always reset as not moving
-        self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
-        self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]]
+        agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
         # setup with loaded data
-        self.height, self.width = self.rail.grid.shape
-        self.rail.height = self.height
-        self.rail.width = self.width
-        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
-
+        agents_position = [a.position for a in agents_static]
+        agents_direction = [a.direction for a in agents_static]
+        agents_target = [a.target for a in agents_static]
         return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
 
     return generator
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 8fc70ce..b7d1ff2 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -60,6 +60,7 @@ class TreeObsForRailEnv(ObservationBuilder):
     def _compute_distance_map(self):
         agents = self.env.agents
         nb_agents = len(agents)
+        print(nb_agents)
         self.distance_map = np.inf * np.ones(shape=(nb_agents,
                                                     self.env.height,
                                                     self.env.width,
@@ -75,7 +76,6 @@ class TreeObsForRailEnv(ObservationBuilder):
         orientation within it) to each agent's target cell.
         """
         # Returns max distance to target, from the farthest away node, while filling in distance_map
-
         self.distance_map[target_nr, position[0], position[1], :] = 0
 
         # Fill in the (up to) 4 neighboring nodes
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 03879eb..34762c0 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -80,7 +80,6 @@ class RailEnv(Environment):
                  rail_generator=random_rail_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2),
-                 file_name=None
                  ):
         """
         Environment init.
@@ -133,10 +132,6 @@ class RailEnv(Environment):
         self.agents = [None] * number_of_agents  # live agents
         self.agents_static = [None] * number_of_agents  # static agent information
         self.num_resets = 0
-        if file_name:
-            self.loaded_file = file_name
-        else:
-            self.loaded_file = None
 
         self.action_space = [1]
         self.observation_space = self.obs_builder.observation_space  # updated on resets?
@@ -177,13 +172,11 @@ class RailEnv(Environment):
 
         if regen_rail or self.rail is None:
             self.rail = tRailAgents[0]
+            self.height, self.width = self.rail.grid.shape
 
         if replace_agents:
             self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5])
 
-        if self.loaded_file:
-            self.load(self.loaded_file)
-
         self.restart_agents()
 
         for i_agent in range(self.get_num_agents()):
diff --git a/tests/test_file_load.py b/tests/test_file_load.py
index af5644f..2b929b1 100644
--- a/tests/test_file_load.py
+++ b/tests/test_file_load.py
@@ -3,7 +3,7 @@
 
 import numpy as np
 
-from flatland.envs.generators import rail_from_GridTransitionMap_generator, empty_rail_generator
+from flatland.envs.generators import rail_from_GridTransitionMap_generator, rail_from_data
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
@@ -26,10 +26,9 @@ def test_load_pkl():
 
     env = RailEnv(width=1,
                   height=1,
-                  rail_generator=empty_rail_generator(),
+                  rail_generator=rail_from_data(file_name),
                   number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
-                  file_name=file_name
                   )
     rails_loaded = env.rail.grid
     agents_loaded = env.agents
-- 
GitLab