From 8b5566f8bc7bbaa0b13656e6dee26c074e5b29a0 Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Thu, 31 Oct 2019 11:14:13 -0400
Subject: [PATCH] Introducing malfunction_generators This resolves issue #273

added test for saving and loading malfunction parameters
---
 tests/tests_malfunction_generators.py | 78 +++++++++++++++++++++++++++
 1 file changed, 78 insertions(+)
 create mode 100644 tests/tests_malfunction_generators.py

diff --git a/tests/tests_malfunction_generators.py b/tests/tests_malfunction_generators.py
new file mode 100644
index 00000000..fa455b75
--- /dev/null
+++ b/tests/tests_malfunction_generators.py
@@ -0,0 +1,78 @@
+import random
+from typing import Dict, List
+
+import numpy as np
+from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
+
+from flatland.core.env_observation_builder import ObservationBuilder
+from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.agent_utils import RailAgentStatus
+from flatland.envs.malfunction_generators import malfunction_from_params, malfunction_from_file
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+from flatland.envs.rail_generators import rail_from_grid_transition_map
+from flatland.envs.schedule_generators import random_schedule_generator
+from flatland.utils.simple_rail import make_simple_rail2
+
+
+def test_malfanction_from_params():
+    """
+    Test loading malfunction from
+    Returns
+    -------
+
+    """
+    stochastic_data = {'malfunction_rate': 1000,  # Rate of malfunction occurence
+                       'min_duration': 2,  # Minimal duration of malfunction
+                       'max_duration': 5  # Max duration of malfunction
+                       }
+
+    rail, rail_map = make_simple_rail2()
+
+    env = RailEnv(width=25,
+                  height=30,
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(seed=10),
+                  number_of_agents=1,
+                  malfunction_generator=malfunction_from_params(stochastic_data))
+    env.reset()
+    assert env.mean_malfunction_rate == 1000
+    assert env.min_number_of_steps_broken == 2
+    assert env.max_number_of_steps_broken == 5
+
+def test_malfanction_to_and_from_file():
+    """
+    Test loading malfunction from
+    Returns
+    -------
+
+    """
+    stochastic_data = {'malfunction_rate': 1000,  # Rate of malfunction occurence
+                       'min_duration': 2,  # Minimal duration of malfunction
+                       'max_duration': 5  # Max duration of malfunction
+                       }
+
+    rail, rail_map = make_simple_rail2()
+
+    env = RailEnv(width=25,
+                  height=30,
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(seed=10),
+                  number_of_agents=1,
+                  malfunction_generator=malfunction_from_params(stochastic_data))
+
+    env.reset()
+    env.save("./malfunction_saving_loading_tests.pkl")
+
+    env2 = RailEnv(width=25,
+                  height=30,
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(seed=10),
+                  number_of_agents=1,
+                  malfunction_generator=malfunction_from_file("./malfunction_saving_loading_tests.pkl"))
+
+    env2.reset()
+
+    assert env2.mean_malfunction_rate == 1000
+    assert env2.min_number_of_steps_broken == 2
+    assert env2.max_number_of_steps_broken == 5
-- 
GitLab