From 2ed3eb6ade58f0da035ccb77a124de86c69e4efb Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Wed, 10 Jul 2019 17:00:15 +0200 Subject: [PATCH] #92 reward function test --- tests/test_flatland_envs_observations.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py index 4eb8d63..d400dc2 100644 --- a/tests/test_flatland_envs_observations.py +++ b/tests/test_flatland_envs_observations.py @@ -117,7 +117,7 @@ def test_reward_function_conflict(rendering=False): renderer = RenderTool(env, gl="PILSVG") renderer.renderEnv(show=True, show_observations=True) - iteration = -1 + iteration = 0 expected_positions = { 0: { 0: (5, 6), @@ -147,8 +147,7 @@ def test_reward_function_conflict(rendering=False): 1: (3, 7) }, } - while not env.dones["__all__"] and iteration + 1 < 5: - iteration += 1 + while iteration < 5: rewards = _step_along_shortest_path(env, obs_builder, rail) for agent in env.agents: @@ -161,6 +160,8 @@ def test_reward_function_conflict(rendering=False): if rendering: renderer.renderEnv(show=True, show_observations=True) + iteration += 1 + def test_reward_function_waiting(rendering=False): rail, rail_map = make_simple_rail() @@ -194,7 +195,7 @@ def test_reward_function_waiting(rendering=False): renderer = RenderTool(env, gl="PILSVG") renderer.renderEnv(show=True, show_observations=True) - iteration = -1 + iteration = 0 expectations = { 0: { 'positions': { @@ -252,20 +253,20 @@ def test_reward_function_waiting(rendering=False): 7: { 'positions': { 0: (3, 1), - 1: (5, 6), + 1: (3, 8), }, 'rewards': [1, 1], }, 8: { 'positions': { 0: (3, 1), - 1: (5, 6), + 1: (3, 8), }, 'rewards': [1, 1], }, } - while not env.dones["__all__"] and iteration + 1 < 5: - iteration += 1 + while iteration < 7: + rewards = _step_along_shortest_path(env, obs_builder, rail) if rendering: @@ -289,3 +290,4 @@ def test_reward_function_waiting(rendering=False): agent.handle, actual_reward, expected_reward) + iteration += 1 -- GitLab