Skip to content
Snippets Groups Projects
Commit 2ed3eb6a authored by u214892's avatar u214892
Browse files

#92 reward function test

parent bf23a556
No related branches found
No related tags found
No related merge requests found
...@@ -117,7 +117,7 @@ def test_reward_function_conflict(rendering=False): ...@@ -117,7 +117,7 @@ def test_reward_function_conflict(rendering=False):
renderer = RenderTool(env, gl="PILSVG") renderer = RenderTool(env, gl="PILSVG")
renderer.renderEnv(show=True, show_observations=True) renderer.renderEnv(show=True, show_observations=True)
iteration = -1 iteration = 0
expected_positions = { expected_positions = {
0: { 0: {
0: (5, 6), 0: (5, 6),
...@@ -147,8 +147,7 @@ def test_reward_function_conflict(rendering=False): ...@@ -147,8 +147,7 @@ def test_reward_function_conflict(rendering=False):
1: (3, 7) 1: (3, 7)
}, },
} }
while not env.dones["__all__"] and iteration + 1 < 5: while iteration < 5:
iteration += 1
rewards = _step_along_shortest_path(env, obs_builder, rail) rewards = _step_along_shortest_path(env, obs_builder, rail)
for agent in env.agents: for agent in env.agents:
...@@ -161,6 +160,8 @@ def test_reward_function_conflict(rendering=False): ...@@ -161,6 +160,8 @@ def test_reward_function_conflict(rendering=False):
if rendering: if rendering:
renderer.renderEnv(show=True, show_observations=True) renderer.renderEnv(show=True, show_observations=True)
iteration += 1
def test_reward_function_waiting(rendering=False): def test_reward_function_waiting(rendering=False):
rail, rail_map = make_simple_rail() rail, rail_map = make_simple_rail()
...@@ -194,7 +195,7 @@ def test_reward_function_waiting(rendering=False): ...@@ -194,7 +195,7 @@ def test_reward_function_waiting(rendering=False):
renderer = RenderTool(env, gl="PILSVG") renderer = RenderTool(env, gl="PILSVG")
renderer.renderEnv(show=True, show_observations=True) renderer.renderEnv(show=True, show_observations=True)
iteration = -1 iteration = 0
expectations = { expectations = {
0: { 0: {
'positions': { 'positions': {
...@@ -252,20 +253,20 @@ def test_reward_function_waiting(rendering=False): ...@@ -252,20 +253,20 @@ def test_reward_function_waiting(rendering=False):
7: { 7: {
'positions': { 'positions': {
0: (3, 1), 0: (3, 1),
1: (5, 6), 1: (3, 8),
}, },
'rewards': [1, 1], 'rewards': [1, 1],
}, },
8: { 8: {
'positions': { 'positions': {
0: (3, 1), 0: (3, 1),
1: (5, 6), 1: (3, 8),
}, },
'rewards': [1, 1], 'rewards': [1, 1],
}, },
} }
while not env.dones["__all__"] and iteration + 1 < 5: while iteration < 7:
iteration += 1
rewards = _step_along_shortest_path(env, obs_builder, rail) rewards = _step_along_shortest_path(env, obs_builder, rail)
if rendering: if rendering:
...@@ -289,3 +290,4 @@ def test_reward_function_waiting(rendering=False): ...@@ -289,3 +290,4 @@ def test_reward_function_waiting(rendering=False):
agent.handle, agent.handle,
actual_reward, actual_reward,
expected_reward) expected_reward)
iteration += 1
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment