From 96444273bc2f3338c47b180883d47b94e6db8362 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Tue, 24 Sep 2019 11:51:29 +0200
Subject: [PATCH] bugfix load flatland

---
 env_data/tests/test_001.pkl                | Bin 0 -> 461 bytes
 flatland/envs/distance_map.py              |   2 ++
 flatland/envs/observations.py              |   3 ++-
 flatland/envs/rail_env_utils.py            |  17 +++++++++++++++++
 flatland/envs/rail_generators.py           |  10 +++++++---
 flatland/envs/schedule_generators.py       |  10 +++++++---
 tests/test_flatland_core_transition_map.py |   1 -
 tests/test_flatland_envs_env_utils.py      |   5 +++++
 tests/test_flatland_envs_rail_env.py       |   1 -
 9 files changed, 40 insertions(+), 9 deletions(-)
 create mode 100644 env_data/tests/test_001.pkl
 create mode 100644 flatland/envs/rail_env_utils.py

diff --git a/env_data/tests/test_001.pkl b/env_data/tests/test_001.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..823e594e41f26f6febac6a241f3befcec9f81cf7
GIT binary patch
literal 461
zcmZo(l3tXVGIbUM9L!=lD}>Bp@jT1Ih>pY<&oVN=HF+R&7#P+jrl;nW6vr2rB$i|*
zPntG~k%4g%Gbcmy#)ACf%#zIfy!f=D#AFcTG|Z-@#RaLUDW~l}Km_)d6eZ?C)k0Lp
z19c^*=H#SSWR@_ruFg%&Nh{3*o58RFi4$LxSdz-HAuqL}BpykaVSQc^GQXH%XKsF3
hW?p)HQfgX$Q7W?Bp=pztIGHCgF$@_8FM|a51OPozqFw+1

literal 0
HcmV?d00001

diff --git a/flatland/envs/distance_map.py b/flatland/envs/distance_map.py
index 94019319..c9c6b003 100644
--- a/flatland/envs/distance_map.py
+++ b/flatland/envs/distance_map.py
@@ -58,6 +58,8 @@ class DistanceMap:
         self.reset_was_called = True
         self.agents = agents
         self.rail = rail
+        self.env_height = rail.height
+        self.env_width = rail.width
 
     def _compute(self, agents: List[EnvAgent], rail: GridTransitionMap):
         self.agents_previous_computation = self.agents
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 712d2425..6312cca3 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -167,7 +167,8 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         # Root node - current position
         # Here information about the agent itself is stored
-        observation = [0, 0, 0, 0, 0, 0, self.env.distance_map.get()[(handle, *agent.position, agent.direction)], 0, 0,
+        distance_map = self.env.distance_map.get()
+        observation = [0, 0, 0, 0, 0, 0, distance_map[(handle, *agent.position, agent.direction)], 0, 0,
                        agent.malfunction_data['malfunction'], agent.speed_data['speed']]
 
         visited = OrderedSet()
diff --git a/flatland/envs/rail_env_utils.py b/flatland/envs/rail_env_utils.py
new file mode 100644
index 00000000..722f54bc
--- /dev/null
+++ b/flatland/envs/rail_env_utils.py
@@ -0,0 +1,17 @@
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import rail_from_file
+from flatland.envs.schedule_generators import schedule_from_file
+
+
+def load_flatland_environment_from_file(file_name, load_from_package=None):
+    environment = RailEnv(width=1,
+                          height=1,
+                          rail_generator=rail_from_file(file_name, load_from_package),
+                          number_of_agents=1,
+                          schedule_generator=schedule_from_file(file_name,load_from_package),
+                          obs_builder_object=TreeObsForRailEnv(
+                              max_depth=2,
+                              predictor=ShortestPathPredictorForRailEnv(max_depth=10)))
+    return environment
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index 60c606f7..0bce90aa 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -204,7 +204,7 @@ def rail_from_manual_specifications_generator(rail_spec):
     return generator
 
 
-def rail_from_file(filename) -> RailGenerator:
+def rail_from_file(filename, load_from_package=None) -> RailGenerator:
     """
     Utility to load pickle file
 
@@ -221,8 +221,12 @@ def rail_from_file(filename) -> RailGenerator:
 
     def generator(width, height, num_agents, num_resets):
         rail_env_transitions = RailEnvTransitions()
-        with open(filename, "rb") as file_in:
-            load_data = file_in.read()
+        if load_from_package is not None:
+            from importlib_resources import read_binary
+            load_data = read_binary(load_from_package, filename)
+        else:
+            with open(filename, "rb") as file_in:
+                load_data = file_in.read()
         data = msgpack.unpackb(load_data, use_list=False)
 
         grid = np.array(data[b"grid"])
diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py
index b3576a2b..7f42feea 100644
--- a/flatland/envs/schedule_generators.py
+++ b/flatland/envs/schedule_generators.py
@@ -205,7 +205,7 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] =
     return generator
 
 
-def schedule_from_file(filename) -> ScheduleGenerator:
+def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator:
     """
     Utility to load pickle file
 
@@ -220,8 +220,12 @@ def schedule_from_file(filename) -> ScheduleGenerator:
     """
 
     def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct:
-        with open(filename, "rb") as file_in:
-            load_data = file_in.read()
+        if load_from_package is not None:
+            from importlib_resources import read_binary
+            load_data = read_binary(load_from_package, filename)
+        else:
+            with open(filename, "rb") as file_in:
+                load_data = file_in.read()
         data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8')
 
         # agents are always reset as not moving
diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py
index 4da1da4d..930cc24f 100644
--- a/tests/test_flatland_core_transition_map.py
+++ b/tests/test_flatland_core_transition_map.py
@@ -168,7 +168,6 @@ def test_get_entry_directions():
     south_symmetrical_switch = cells[6]
     north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180)
 
-    # Simple turn not in the base transitions ?
     south_east_turn = int('0100000000000010', 2)
     south_west_turn = transitions.rotate_transition(south_east_turn, 90)
     north_east_turn = transitions.rotate_transition(south_east_turn, 270)
diff --git a/tests/test_flatland_envs_env_utils.py b/tests/test_flatland_envs_env_utils.py
index cf5c8592..d0320ab7 100644
--- a/tests/test_flatland_envs_env_utils.py
+++ b/tests/test_flatland_envs_env_utils.py
@@ -4,6 +4,7 @@ import pytest
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_direction
 from flatland.core.grid.grid_utils import position_to_coordinate, coordinate_to_position
+from flatland.envs.rail_env_utils import load_flatland_environment_from_file
 
 depth_to_test = 5
 positions_to_test = [0, 5, 1, 6, 20, 30]
@@ -32,3 +33,7 @@ def test_get_direction():
     assert get_direction((1, 0), (0, 0)) == Grid4TransitionsEnum.NORTH
     with pytest.raises(Exception, match="Could not determine direction"):
         get_direction((0, 0), (0, 0))
+
+
+def test_load():
+    load_flatland_environment_from_file('test_001.pkl', 'env_data.tests')
diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py
index d5dc3ac7..0114730a 100644
--- a/tests/test_flatland_envs_rail_env.py
+++ b/tests/test_flatland_envs_rail_env.py
@@ -61,7 +61,6 @@ def test_rail_environment_single_agent():
     vertical_line = cells[1]
     south_symmetrical_switch = cells[6]
     north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180)
-    # Simple turn not in the base transitions ?
     south_east_turn = int('0100000000000010', 2)
     south_west_turn = transitions.rotate_transition(south_east_turn, 90)
     north_east_turn = transitions.rotate_transition(south_east_turn, 270)
-- 
GitLab