From 00431259aa0cd53085bebc4221c6d6b6c3db63e3 Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Fri, 25 Oct 2019 17:37:13 -0400
Subject: [PATCH] updated tests to new malfunction behavior

---
 flatland/envs/rail_env.py          |  2 +-
 tests/test_flatland_malfunction.py | 31 +++++++++++++++++-------------
 2 files changed, 19 insertions(+), 14 deletions(-)

diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index fa2301b8..6b3721ac 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -429,7 +429,7 @@ class RailEnv(Environment):
             agent.malfunction_data['next_malfunction'] = max(next_breakdown, 1)
             # Duration of current malfunction
             num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken,
-                                                      self.max_number_of_steps_broken + 1) + 1
+                                                      self.max_number_of_steps_broken + 1)
             agent.malfunction_data['malfunction'] = num_broken_steps
             agent.malfunction_data['moving_before_malfunction'] = agent.moving
 
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index d3dc6fa7..bec8dae5 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -156,14 +156,14 @@ def test_malfunction_process_statistically():
 
     env.agents[0].target = (0, 0)
     # Next line only for test generation
-    # agent_malfunction_list = [[] for i in range(20)]
+    #agent_malfunction_list = [[] for i in range(20)]
     agent_malfunction_list = [[0, 0, 0, 0, 5, 5, 0, 0, 0, 0], [0, 0, 0, 0, 5, 5, 0, 0, 0, 0], [0, 0, 0, 0, 4, 4, 0, 0, 0, 0],
      [0, 0, 0, 0, 3, 3, 0, 0, 0, 0], [0, 0, 0, 0, 2, 2, 0, 0, 0, 5], [0, 0, 0, 0, 1, 1, 5, 0, 0, 4],
-     [0, 0, 0, 5, 0, 0, 4, 5, 0, 3], [5, 0, 0, 4, 5, 5, 3, 4, 0, 2], [4, 5, 0, 3, 4, 4, 2, 3, 5, 1],
-     [3, 4, 0, 2, 3, 3, 1, 2, 4, 0], [2, 3, 5, 1, 2, 2, 0, 1, 3, 0], [1, 2, 4, 0, 1, 1, 5, 0, 2, 0],
-     [0, 1, 3, 0, 0, 0, 4, 0, 1, 0], [5, 0, 2, 0, 0, 5, 3, 5, 0, 5], [4, 0, 1, 0, 0, 4, 2, 4, 0, 4],
-     [3, 0, 0, 0, 0, 3, 1, 3, 5, 3], [2, 0, 0, 0, 0, 2, 0, 2, 4, 2], [1, 0, 5, 5, 5, 1, 5, 1, 3, 1],
-     [0, 0, 4, 4, 4, 0, 4, 0, 2, 0], [5, 0, 3, 3, 3, 5, 3, 5, 1, 5]]
+     [0, 0, 0, 5, 0, 0, 4, 5, 0, 3], [5, 0, 0, 4, 0, 0, 3, 4, 0, 2], [4, 5, 0, 3, 5, 5, 2, 3, 5, 1],
+     [3, 4, 0, 2, 4, 4, 1, 2, 4, 0], [2, 3, 5, 1, 3, 3, 0, 1, 3, 0], [1, 2, 4, 0, 2, 2, 0, 0, 2, 0],
+     [0, 1, 3, 0, 1, 1, 5, 0, 1, 0], [0, 0, 2, 0, 0, 0, 4, 0, 0, 0], [5, 0, 1, 0, 0, 0, 3, 5, 0, 5],
+     [4, 0, 0, 0, 5, 0, 2, 4, 0, 4], [3, 0, 0, 0, 4, 0, 1, 3, 5, 3], [2, 0, 0, 0, 3, 0, 0, 2, 4, 2],
+     [1, 0, 5, 5, 2, 0, 0, 1, 3, 1], [0, 5, 4, 4, 1, 0, 5, 0, 2, 0]]
 
     for step in range(20):
         action_dict: Dict[int, RailEnvActions] = {}
@@ -223,6 +223,11 @@ def test_malfunction_before_entry():
     assert env.agents[8].malfunction_data['malfunction'] == 0
     assert env.agents[9].malfunction_data['malfunction'] == 0
 
+    for a in range(env.get_num_agents()):
+        print("assert env.agents[{}].malfunction_data['next_malfunction'] == {}".format(a, env.agents[a].malfunction_data['next_malfunction']))
+    for a in range(env.get_num_agents()):
+        print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a, env.agents[a].malfunction_data[
+                'malfunction']))
     for step in range(20):
         action_dict: Dict[int, RailEnvActions] = {}
         for agent in env.agents:
@@ -309,18 +314,18 @@ def test_initial_malfunction():
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.MOVE_FORWARD,
                 malfunction=1,
-                reward=env.start_penalty + env.step_penalty * 1.0
-                # malfunctioning ends: starting and running at speed 1.0
-            ),
+                reward= env.step_penalty * 1.0
+
+            ),# malfunctioning ends: starting and running at speed 1.0
             Replay(
-                position=(3, 3),
+                position=(3, 2),
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.MOVE_FORWARD,
                 malfunction=0,
-                reward=env.step_penalty * 1.0  # running at speed 1.0
+                reward=env.start_penalty +env.step_penalty * 1.0  # running at speed 1.0
             ),
             Replay(
-                position=(3, 4),
+                position=(3, 3),
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.MOVE_FORWARD,
                 malfunction=0,
@@ -654,7 +659,7 @@ def test_last_malfunction_step():
             action_dict[agent.handle] = RailEnvActions(2)
 
         # Check if the agent is still allowed to move in this step
-        if env.agents[0].malfunction_data['malfunction'] > 1 or env.agents[0].malfunction_data['next_malfunction'] < 1:
+        if env.agents[0].malfunction_data['malfunction'] > 0 or env.agents[0].malfunction_data['next_malfunction'] < 1:
             agent_can_move = False
         else:
             agent_can_move = True
-- 
GitLab