From c65f00bc45f0ce64a7ba416617a80c613d39dd99 Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Mon, 28 Oct 2019 16:04:36 -0400
Subject: [PATCH] added simple test for decay of next_malfunction step

---
 flatland/envs/rail_env.py          |   4 --
 tests/test_flatland_malfunction.py | 101 +++++++++++++----------------
 2 files changed, 46 insertions(+), 59 deletions(-)

diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index f8534640..9a25f954 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -430,10 +430,6 @@ class RailEnv(Environment):
 
             return True
 
-        # Decrease counter for next event only if agent is currently not broken and agent has a malfunction rate
-        if agent.malfunction_data['next_malfunction'] > 0 and agent.malfunction_data['malfunction'] < 1:
-            agent.malfunction_data['next_malfunction'] -= 1
-
 
 
     def step(self, action_dict_: Dict[int, RailEnvActions]):
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index bec8dae5..4fbfa86e 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -156,14 +156,14 @@ def test_malfunction_process_statistically():
 
     env.agents[0].target = (0, 0)
     # Next line only for test generation
-    #agent_malfunction_list = [[] for i in range(20)]
+    # agent_malfunction_list = [[] for i in range(20)]
     agent_malfunction_list = [[0, 0, 0, 0, 5, 5, 0, 0, 0, 0], [0, 0, 0, 0, 5, 5, 0, 0, 0, 0], [0, 0, 0, 0, 4, 4, 0, 0, 0, 0],
-     [0, 0, 0, 0, 3, 3, 0, 0, 0, 0], [0, 0, 0, 0, 2, 2, 0, 0, 0, 5], [0, 0, 0, 0, 1, 1, 5, 0, 0, 4],
-     [0, 0, 0, 5, 0, 0, 4, 5, 0, 3], [5, 0, 0, 4, 0, 0, 3, 4, 0, 2], [4, 5, 0, 3, 5, 5, 2, 3, 5, 1],
-     [3, 4, 0, 2, 4, 4, 1, 2, 4, 0], [2, 3, 5, 1, 3, 3, 0, 1, 3, 0], [1, 2, 4, 0, 2, 2, 0, 0, 2, 0],
-     [0, 1, 3, 0, 1, 1, 5, 0, 1, 0], [0, 0, 2, 0, 0, 0, 4, 0, 0, 0], [5, 0, 1, 0, 0, 0, 3, 5, 0, 5],
-     [4, 0, 0, 0, 5, 0, 2, 4, 0, 4], [3, 0, 0, 0, 4, 0, 1, 3, 5, 3], [2, 0, 0, 0, 3, 0, 0, 2, 4, 2],
-     [1, 0, 5, 5, 2, 0, 0, 1, 3, 1], [0, 5, 4, 4, 1, 0, 5, 0, 2, 0]]
+     [0, 0, 0, 0, 3, 3, 5, 0, 0, 0], [5, 0, 0, 5, 2, 2, 4, 5, 0, 5], [4, 5, 0, 4, 1, 1, 3, 4, 5, 4],
+     [3, 4, 0, 3, 0, 0, 2, 3, 4, 3], [2, 3, 5, 2, 0, 0, 1, 2, 3, 2], [1, 2, 4, 1, 5, 5, 0, 1, 2, 1],
+     [0, 1, 3, 0, 4, 4, 0, 0, 1, 0], [0, 0, 2, 0, 3, 3, 0, 0, 0, 0], [5, 0, 1, 0, 2, 2, 5, 5, 0, 5],
+     [4, 0, 0, 0, 1, 1, 4, 4, 5, 4], [3, 0, 0, 5, 0, 0, 3, 3, 4, 3], [2, 5, 0, 4, 0, 0, 2, 2, 3, 2],
+     [1, 4, 0, 3, 5, 5, 1, 1, 2, 1], [0, 3, 0, 2, 4, 4, 0, 0, 1, 0], [0, 2, 0, 1, 3, 3, 0, 0, 0, 0],
+     [5, 1, 0, 0, 2, 2, 5, 5, 0, 5], [4, 0, 5, 0, 1, 1, 4, 4, 5, 4]]
 
     for step in range(20):
         action_dict: Dict[int, RailEnvActions] = {}
@@ -202,16 +202,16 @@ def test_malfunction_before_entry():
     # Test initial malfunction values for all agents
     # we want some agents to be malfuncitoning already and some to be working
     # we want different next_malfunction values for the agents
-    assert env.agents[0].malfunction_data['next_malfunction'] == 5
-    assert env.agents[1].malfunction_data['next_malfunction'] == 6
-    assert env.agents[2].malfunction_data['next_malfunction'] == 6
-    assert env.agents[3].malfunction_data['next_malfunction'] == 3
+
+    assert env.agents[1].malfunction_data['next_malfunction'] == 5
+    assert env.agents[2].malfunction_data['next_malfunction'] == 5
+    assert env.agents[3].malfunction_data['next_malfunction'] == 2
     assert env.agents[4].malfunction_data['next_malfunction'] == 1
     assert env.agents[5].malfunction_data['next_malfunction'] == 1
-    assert env.agents[6].malfunction_data['next_malfunction'] == 3
-    assert env.agents[7].malfunction_data['next_malfunction'] == 4
-    assert env.agents[8].malfunction_data['next_malfunction'] == 6
-    assert env.agents[9].malfunction_data['next_malfunction'] == 0
+    assert env.agents[6].malfunction_data['next_malfunction'] == 2
+    assert env.agents[7].malfunction_data['next_malfunction'] == 3
+    assert env.agents[8].malfunction_data['next_malfunction'] == 5
+    assert env.agents[9].malfunction_data['next_malfunction'] == -1
     assert env.agents[0].malfunction_data['malfunction'] == 0
     assert env.agents[1].malfunction_data['malfunction'] == 0
     assert env.agents[2].malfunction_data['malfunction'] == 0
@@ -223,50 +223,41 @@ def test_malfunction_before_entry():
     assert env.agents[8].malfunction_data['malfunction'] == 0
     assert env.agents[9].malfunction_data['malfunction'] == 0
 
-    for a in range(env.get_num_agents()):
-        print("assert env.agents[{}].malfunction_data['next_malfunction'] == {}".format(a, env.agents[a].malfunction_data['next_malfunction']))
-    for a in range(env.get_num_agents()):
-        print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a, env.agents[a].malfunction_data[
-                'malfunction']))
-    for step in range(20):
-        action_dict: Dict[int, RailEnvActions] = {}
-        for agent in env.agents:
-            # We randomly select an action
-            action_dict[agent.handle] = RailEnvActions(2)
-            if step < 10:
-                action_dict[agent.handle] = RailEnvActions(0)
+def test_next_malfunction_counter():
+    """
+    Check that 'next_malfunction' drops by one on every env step.
+    The counter is forced to 5 and the env is stepped five times with an
+    empty action dict; after step k the counter must equal 5 - k.
 
+    """
+    # One-agent env built on the rail returned by make_simple_rail2()
+
+    rail, _ = make_simple_rail2()
+    action_dict: Dict[int, RailEnvActions] = {}
+
+    env = RailEnv(width=25,
+                  height=30,
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(seed=2),
+                  number_of_agents=1,
+                  random_seed=1,
+                  )
+    # reset to initialize agents_static
+    env.reset(False, False, activate_agents=True, random_seed=10)
+    env.agents[0].malfunction_data['next_malfunction'] = 5
+    env.agents[0].malfunction_data['malfunction_rate'] = 5
+    env.agents[0].malfunction_data['malfunction'] = 0
+    env.agents[0].target = (0, 0)  # Move the target out of range
+    print(env.agents[0].position, env.agents[0].malfunction_data['next_malfunction'])
+
+    for time_step in range(1, 6):
+        # Step with the (empty) action dict, i.e. no action is given to the agent
         env.step(action_dict)
 
-    # We want to check that all agents are malfunctioning and that their values changed
+        # Check that next_malfunction decreases as expected
+        assert env.agents[0].malfunction_data['next_malfunction'] == 5 - time_step
+
 
-    # Test  malfunction values for all agents after 20 steps
-    assert env.agents[0].malfunction_data['next_malfunction'] == 4
-    assert env.agents[1].malfunction_data['next_malfunction'] == 6
-    assert env.agents[2].malfunction_data['next_malfunction'] == 2
-    assert env.agents[3].malfunction_data['next_malfunction'] == 2
-    assert env.agents[4].malfunction_data['next_malfunction'] == 1
-    assert env.agents[5].malfunction_data['next_malfunction'] == 1
-    assert env.agents[6].malfunction_data['next_malfunction'] == 2
-    assert env.agents[7].malfunction_data['next_malfunction'] == 1
-    assert env.agents[8].malfunction_data['next_malfunction'] == 1
-    assert env.agents[9].malfunction_data['next_malfunction'] == 4
-    assert env.agents[0].malfunction_data['malfunction'] == 0
-    assert env.agents[1].malfunction_data['malfunction'] == 8
-    assert env.agents[2].malfunction_data['malfunction'] == 8
-    assert env.agents[3].malfunction_data['malfunction'] == 0
-    assert env.agents[4].malfunction_data['malfunction'] == 1
-    assert env.agents[5].malfunction_data['malfunction'] == 1
-    assert env.agents[6].malfunction_data['malfunction'] == 0
-    assert env.agents[7].malfunction_data['malfunction'] == 6
-    assert env.agents[8].malfunction_data['malfunction'] == 8
-    assert env.agents[9].malfunction_data['malfunction'] == 2
-    # Print for test generation
-    #for a in range(env.get_num_agents()):
-    #    print("assert env.agents[{}].malfunction_data['next_malfunction'] == {}".format(a, env.agents[a].malfunction_data['next_malfunction']))
-    #for a in range(env.get_num_agents()):
-    #    print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a, env.agents[a].malfunction_data[
-    #            'malfunction']))
 
 def test_initial_malfunction():
     stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
-- 
GitLab