From 2ed3eb6ade58f0da035ccb77a124de86c69e4efb Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Wed, 10 Jul 2019 17:00:15 +0200
Subject: [PATCH] #92 reward function test

---
 tests/test_flatland_envs_observations.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py
index 4eb8d63..d400dc2 100644
--- a/tests/test_flatland_envs_observations.py
+++ b/tests/test_flatland_envs_observations.py
@@ -117,7 +117,7 @@ def test_reward_function_conflict(rendering=False):
         renderer = RenderTool(env, gl="PILSVG")
         renderer.renderEnv(show=True, show_observations=True)
 
-    iteration = -1
+    iteration = 0
     expected_positions = {
         0: {
             0: (5, 6),
@@ -147,8 +147,7 @@ def test_reward_function_conflict(rendering=False):
             1: (3, 7)
         },
     }
-    while not env.dones["__all__"] and iteration + 1 < 5:
-        iteration += 1
+    while iteration < 5:
         rewards = _step_along_shortest_path(env, obs_builder, rail)
 
         for agent in env.agents:
@@ -161,6 +160,8 @@ def test_reward_function_conflict(rendering=False):
         if rendering:
             renderer.renderEnv(show=True, show_observations=True)
 
+        iteration += 1
+
 
 def test_reward_function_waiting(rendering=False):
     rail, rail_map = make_simple_rail()
@@ -194,7 +195,7 @@ def test_reward_function_waiting(rendering=False):
         renderer = RenderTool(env, gl="PILSVG")
         renderer.renderEnv(show=True, show_observations=True)
 
-    iteration = -1
+    iteration = 0
     expectations = {
         0: {
             'positions': {
@@ -252,20 +253,20 @@ def test_reward_function_waiting(rendering=False):
         7: {
             'positions': {
                 0: (3, 1),
-                1: (5, 6),
+                1: (3, 8),
             },
             'rewards': [1, 1],
         },
         8: {
             'positions': {
                 0: (3, 1),
-                1: (5, 6),
+                1: (3, 8),
             },
             'rewards': [1, 1],
         },
     }
-    while not env.dones["__all__"] and iteration + 1 < 5:
-        iteration += 1
+    while iteration < 7:
+
         rewards = _step_along_shortest_path(env, obs_builder, rail)
 
         if rendering:
@@ -289,3 +290,4 @@ def test_reward_function_waiting(rendering=False):
                                                                                                    agent.handle,
                                                                                                    actual_reward,
                                                                                                    expected_reward)
+        iteration += 1
-- 
GitLab