From 096fd93376921fd3a442287b45b0e291149bde31 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Mon, 16 Sep 2019 19:43:16 +0200 Subject: [PATCH] #178 bugfix initial malfunction --- tests/test_flatland_malfunction.py | 3 +-- tests/test_multi_speed.py | 22 ++++++++++++++++------ 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index e74666e1..fb191fd9 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -153,7 +153,6 @@ def test_malfunction_process_statistically(): assert nb_malfunction > 150 -# TODO test DO_NOTHING! def test_initial_malfunction(rendering=True): random.seed(0) stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents @@ -240,7 +239,7 @@ def test_initial_malfunction(rendering=True): _assert(agent.direction, replay.direction, 'direction') _assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction') - if replay.action: + if replay.action is not None: assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True) _, _, _, info_dict = env.step({0: replay.action}) diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py index 529e9412..b83f133f 100644 --- a/tests/test_multi_speed.py +++ b/tests/test_multi_speed.py @@ -1,5 +1,3 @@ -import time - import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum @@ -97,6 +95,8 @@ def test_multi_speed_init(): old_pos[i_agent] = env.agents[i_agent].position +# TODO test penalties! +# TODO test invalid actions! def test_multispeed_actions_no_malfunction_no_blocking(rendering=True): """Test that actions are correctly performed on cell exit for a single agent.""" rail, rail_map = make_simple_rail() @@ -191,7 +191,6 @@ def test_multispeed_actions_no_malfunction_no_blocking(rendering=True): speed=0.5 ) - # TODO test penalties! agentStatic: EnvAgentStatic = env.agents_static[0] info_dict = { 'action_required': [True] @@ -216,7 +215,7 @@ def test_multispeed_actions_no_malfunction_no_blocking(rendering=True): _assert(agent.position, replay.position, 'position') _assert(agent.direction, replay.direction, 'direction') - if replay.action: + if replay.action is not None: assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True) _, _, _, info_dict = env.step({0: replay.action}) @@ -424,7 +423,7 @@ def test_multispeed_actions_no_malfunction_blocking(rendering=True): _assert(a, agent.position, replay.position, 'position') _assert(a, agent.direction, replay.direction, 'direction') - if replay.action: + if replay.action is not None: assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format( step, a, True) action_dict[a] = replay.action @@ -534,9 +533,20 @@ def test_multispeed_actions_malfunction_no_blocking(rendering=True): direction=Grid4TransitionsEnum.SOUTH, action=None ), + # DO_NOTHING keeps moving! + Replay( + position=(5, 6), + direction=Grid4TransitionsEnum.SOUTH, + action=RailEnvActions.DO_NOTHING + ), Replay( position=(5, 6), direction=Grid4TransitionsEnum.SOUTH, + action=None + ), + Replay( + position=(6, 6), + direction=Grid4TransitionsEnum.SOUTH, action=RailEnvActions.MOVE_FORWARD ), @@ -573,7 +583,7 @@ def test_multispeed_actions_malfunction_no_blocking(rendering=True): if replay.malfunction: agent.malfunction_data['malfunction'] = 2 - if replay.action: + if replay.action is not None: assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True) _, _, _, info_dict = env.step({0: replay.action}) -- GitLab