From 9e99adb7ad9a728a58ca86dd3f38c6366d6e95c7 Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Mon, 4 Nov 2019 14:31:33 -0500
Subject: [PATCH] fixed loading and saving of new level gernator objects

---
 flatland/envs/malfunction_generators.py |  9 +++++----
 flatland/envs/rail_env.py               | 10 ++++-----
 tests/test_malfunction_generators.py    | 27 +++++++++++++++++--------
 3 files changed, 29 insertions(+), 17 deletions(-)

diff --git a/flatland/envs/malfunction_generators.py b/flatland/envs/malfunction_generators.py
index 1810b12c..89af545b 100644
--- a/flatland/envs/malfunction_generators.py
+++ b/flatland/envs/malfunction_generators.py
@@ -41,14 +41,15 @@ def malfunction_from_file(filename) -> Tuple[MalfunctionGenerator, MalfunctionPr
     with open(filename, "rb") as file_in:
         load_data = file_in.read()
     data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8')
-
+    # TODO: make this better by using namedtuple in the pickle file
+    data['malfunction'] =MalfunctionProcessData._make(data['malfunction'])
     if "malfunction" in data:
         # Mean malfunction in number of time steps
-        mean_malfunction_rate = data["malfunction"]["malfunction_rate"]
+        mean_malfunction_rate = data["malfunction"].malfunction_rate
 
         # Uniform distribution parameters for malfunction duration
-        min_number_of_steps_broken = data["malfunction"]["min_duration"]
-        max_number_of_steps_broken = data["malfunction"]["max_duration"]
+        min_number_of_steps_broken = data["malfunction"].min_duration
+        max_number_of_steps_broken = data["malfunction"].max_duration
     else:
         # Mean malfunction in number of time steps
         mean_malfunction_rate = 0.
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index cba49f73..9025d415 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -19,7 +19,8 @@ from flatland.core.grid.grid_utils import IntVector2D
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, RailAgentStatus
 from flatland.envs.distance_map import DistanceMap
-from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction
+from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionGenerator, \
+    MalfunctionProcessData
 from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.rail_generators import random_rail_generator, RailGenerator
 from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator
@@ -159,7 +160,7 @@ class RailEnv(Environment):
         """
         super().__init__()
 
-        self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data
+        self.malfunction_generator, self.malfunction_process_data  = malfunction_generator_and_process_data
         self.rail_generator: RailGenerator = rail_generator
         self.schedule_generator: ScheduleGenerator = schedule_generator
         self.rail: Optional[GridTransitionMap] = None
@@ -802,8 +803,7 @@ class RailEnv(Environment):
         grid_data = self.rail.grid.tolist()
         agent_static_data = [agent.to_list() for agent in self.agents_static]
         agent_data = [agent.to_list() for agent in self.agents]
-        malfunction_data = {"malfunction_process_data": self.malfunction_process_data}
-
+        malfunction_data: MalfunctionProcessData = self.malfunction_process_data
         msgpack.packb(grid_data, use_bin_type=True)
         msgpack.packb(agent_data, use_bin_type=True)
         msgpack.packb(agent_static_data, use_bin_type=True)
@@ -825,7 +825,7 @@ class RailEnv(Environment):
         msgpack.packb(agent_data, use_bin_type=True)
         msgpack.packb(agent_static_data, use_bin_type=True)
         distance_map_data = self.distance_map.get()
-        malfunction_data = {"malfunction_process_data": self.malfunction_process_data}
+        malfunction_data: MalfunctionProcessData = self.malfunction_process_data
         msgpack.packb(distance_map_data, use_bin_type=True)
         msg_data = {
             "grid": grid_data,
diff --git a/tests/test_malfunction_generators.py b/tests/test_malfunction_generators.py
index 779af28c..e01c7483 100644
--- a/tests/test_malfunction_generators.py
+++ b/tests/test_malfunction_generators.py
@@ -45,18 +45,29 @@ def test_malfanction_to_and_from_file():
 
     rail, rail_map = make_simple_rail2()
 
-    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(seed=10), number_of_agents=1)
-
+    env = RailEnv(width=25,
+                  height=30,
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
+                  number_of_agents=10,
+                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data)
+                  )
     env.reset()
     env.save("./malfunction_saving_loading_tests.pkl")
 
     malfunction_generator, malfunction_process_data = malfunction_from_file("./malfunction_saving_loading_tests.pkl")
-    env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
-                   schedule_generator=random_schedule_generator(seed=10), number_of_agents=1)
+    env2 = RailEnv(width=25,
+                  height=30,
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
+                  number_of_agents=10,
+                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data)
+                  )
 
     env2.reset()
 
-    assert env2.mean_malfunction_rate == 1000
-    assert env2.min_number_of_steps_broken == 2
-    assert env2.max_number_of_steps_broken == 5
+    assert env2.malfunction_process_data ==  env.malfunction_process_data
+    assert env2.malfunction_process_data.malfunction_rate == 1000
+    assert env2.malfunction_process_data.min_duration == 2
+    assert env2.malfunction_process_data.max_duration == 5
+
-- 
GitLab