From 096fd93376921fd3a442287b45b0e291149bde31 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Mon, 16 Sep 2019 19:43:16 +0200
Subject: [PATCH] #178 bugfix initial malfunction

---
 tests/test_flatland_malfunction.py |  3 +--
 tests/test_multi_speed.py          | 22 ++++++++++++++++------
 2 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index e74666e1..fb191fd9 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -153,7 +153,6 @@ def test_malfunction_process_statistically():
     assert nb_malfunction > 150
 
 
-# TODO test DO_NOTHING!
 def test_initial_malfunction(rendering=True):
     random.seed(0)
     stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
@@ -240,7 +239,7 @@ def test_initial_malfunction(rendering=True):
         _assert(agent.direction, replay.direction, 'direction')
         _assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
 
-        if replay.action:
+        if replay.action is not None:
             assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
             _, _, _, info_dict = env.step({0: replay.action})
 
diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py
index 529e9412..b83f133f 100644
--- a/tests/test_multi_speed.py
+++ b/tests/test_multi_speed.py
@@ -1,5 +1,3 @@
-import time
-
 import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
@@ -97,6 +95,8 @@ def test_multi_speed_init():
                 old_pos[i_agent] = env.agents[i_agent].position
 
 
+# TODO test penalties!
+# TODO test invalid actions!
 def test_multispeed_actions_no_malfunction_no_blocking(rendering=True):
     """Test that actions are correctly performed on cell exit for a single agent."""
     rail, rail_map = make_simple_rail()
@@ -191,7 +191,6 @@ def test_multispeed_actions_no_malfunction_no_blocking(rendering=True):
         speed=0.5
     )
 
-    # TODO test penalties!
     agentStatic: EnvAgentStatic = env.agents_static[0]
     info_dict = {
         'action_required': [True]
@@ -216,7 +215,7 @@ def test_multispeed_actions_no_malfunction_no_blocking(rendering=True):
         _assert(agent.position, replay.position, 'position')
         _assert(agent.direction, replay.direction, 'direction')
 
-        if replay.action:
+        if replay.action is not None:
             assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
             _, _, _, info_dict = env.step({0: replay.action})
 
@@ -424,7 +423,7 @@ def test_multispeed_actions_no_malfunction_blocking(rendering=True):
             _assert(a, agent.position, replay.position, 'position')
             _assert(a, agent.direction, replay.direction, 'direction')
 
-            if replay.action:
+            if replay.action is not None:
                 assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format(
                     step, a, True)
                 action_dict[a] = replay.action
@@ -534,9 +533,20 @@ def test_multispeed_actions_malfunction_no_blocking(rendering=True):
                 direction=Grid4TransitionsEnum.SOUTH,
                 action=None
             ),
+            # DO_NOTHING keeps moving!
+            Replay(
+                position=(5, 6),
+                direction=Grid4TransitionsEnum.SOUTH,
+                action=RailEnvActions.DO_NOTHING
+            ),
             Replay(
                 position=(5, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
+                action=None
+            ),
+            Replay(
+                position=(6, 6),
+                direction=Grid4TransitionsEnum.SOUTH,
                 action=RailEnvActions.MOVE_FORWARD
             ),
 
@@ -573,7 +583,7 @@ def test_multispeed_actions_malfunction_no_blocking(rendering=True):
         if replay.malfunction:
             agent.malfunction_data['malfunction'] = 2
 
-        if replay.action:
+        if replay.action is not None:
             assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
             _, _, _, info_dict = env.step({0: replay.action})
 
-- 
GitLab