From 6328547fdbc8d792807aecfe16804f1e76d45790 Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Wed, 9 Oct 2019 09:15:05 -0400
Subject: [PATCH] updated stochastic malfunction test

---
 flatland/envs/rail_env.py          |  2 +-
 tests/test_flatland_malfunction.py | 30 +++++++++++++++++++++---------
 2 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 945771aa..86987a56 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -354,7 +354,7 @@ class RailEnv(Environment):
         # If counter has come to zero --> Agent has malfunction
         # set next malfunction time and duration of current malfunction
         if agent.malfunction_data['malfunction_rate'] >= 1 and 1 > agent.malfunction_data['malfunction'] and \
-            agent.malfunction_data['next_malfunction'] <= 0:
+            agent.malfunction_data['next_malfunction'] < 1:
             # Increase number of malfunctions
             agent.malfunction_data['nr_malfunctions'] += 1
 
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index 8e56a6a1..35e41b7e 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -136,9 +136,9 @@ def test_malfunction_process_statistically():
     """Tests hat malfunctions are produced by stochastic_data!"""
     # Set fixed malfunction duration for this test
     stochastic_data = {'prop_malfunction': 1.,
-                       'malfunction_rate': 2,
-                       'min_duration': 3,
-                       'max_duration': 3}
+                       'malfunction_rate': 5,
+                       'min_duration': 5,
+                       'max_duration': 5}
 
     rail, rail_map = make_simple_rail2()
 
@@ -146,7 +146,7 @@ def test_malfunction_process_statistically():
                   height=30,
                   rail_generator=rail_from_grid_transition_map(rail),
                   schedule_generator=random_schedule_generator(),
-                  number_of_agents=1,
+                  number_of_agents=10,
                   stochastic_data=stochastic_data,  # Malfunction data generator
                   obs_builder_object=SingleAgentNavigationObs()
                   )
@@ -155,15 +155,27 @@ def test_malfunction_process_statistically():
 
     env.agents[0].target = (0, 0)
     nb_malfunction = 0
+    agent_malfunction_list = [[6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6],
+                              [6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0],
+                              [6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2],
+                              [6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0],
+                              [6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 6, 5],
+                              [6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6, 5],
+                              [6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1],
+                              [6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4],
+                              [6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 6],
+                              [6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2]]
+
     for step in range(20):
         action_dict: Dict[int, RailEnvActions] = {}
-        for agent in env.agents:
+        for agent_idx in range(env.get_num_agents()):
             # We randomly select an action
-            action_dict[agent.handle] = RailEnvActions(np.random.randint(4))
-
+            action_dict[agent_idx] = RailEnvActions(np.random.randint(4))
+            # For generating tests only:
+            # agent_malfunction_list[agent_idx].append(env.agents[agent_idx].malfunction_data['malfunction'])
+            assert env.agents[agent_idx].malfunction_data['malfunction'] == agent_malfunction_list[agent_idx][step]
         env.step(action_dict)
-    # check that generation of malfunctions works as expected
-    assert env.agents[0].malfunction_data["nr_malfunctions"] == 4
+
 
 
 def test_malfunction_before_entry():
-- 
GitLab