diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py
index 4eb8d633c5a246c5a755a922cae02484bb1ac696..d400dc226a71e1a9c185012fccae3c852bcd42aa 100644
--- a/tests/test_flatland_envs_observations.py
+++ b/tests/test_flatland_envs_observations.py
@@ -117,7 +117,7 @@ def test_reward_function_conflict(rendering=False):
         renderer = RenderTool(env, gl="PILSVG")
         renderer.renderEnv(show=True, show_observations=True)
 
-    iteration = -1
+    iteration = 0
     expected_positions = {
         0: {
             0: (5, 6),
@@ -147,8 +147,7 @@ def test_reward_function_conflict(rendering=False):
             1: (3, 7)
         },
     }
-    while not env.dones["__all__"] and iteration + 1 < 5:
-        iteration += 1
+    while iteration < 5:
         rewards = _step_along_shortest_path(env, obs_builder, rail)
 
         for agent in env.agents:
@@ -161,6 +160,8 @@ def test_reward_function_conflict(rendering=False):
         if rendering:
             renderer.renderEnv(show=True, show_observations=True)
 
+        iteration += 1
+
 
 def test_reward_function_waiting(rendering=False):
     rail, rail_map = make_simple_rail()
@@ -194,7 +195,7 @@ def test_reward_function_waiting(rendering=False):
         renderer = RenderTool(env, gl="PILSVG")
         renderer.renderEnv(show=True, show_observations=True)
 
-    iteration = -1
+    iteration = 0
     expectations = {
         0: {
             'positions': {
@@ -252,20 +253,20 @@ def test_reward_function_waiting(rendering=False):
         7: {
             'positions': {
                 0: (3, 1),
-                1: (5, 6),
+                1: (3, 8),
             },
             'rewards': [1, 1],
         },
         8: {
             'positions': {
                 0: (3, 1),
-                1: (5, 6),
+                1: (3, 8),
             },
             'rewards': [1, 1],
         },
     }
-    while not env.dones["__all__"] and iteration + 1 < 5:
-        iteration += 1
+    while iteration < 7:
+
         rewards = _step_along_shortest_path(env, obs_builder, rail)
 
         if rendering:
@@ -289,3 +290,4 @@ def test_reward_function_waiting(rendering=False):
                 agent.handle,
                 actual_reward,
                 expected_reward)
+        iteration += 1