diff --git a/tests/test_env_observation_builder.py b/tests/test_env_observation_builder.py index 4038fe675b1d7cfd7a0fd8447f2f81bc1fe86f04..ad393634aa4ab156e3a305dd8654d1d588127805 100644 --- a/tests/test_env_observation_builder.py +++ b/tests/test_env_observation_builder.py @@ -82,7 +82,7 @@ def test_global_obs(): # If this assertion is wrong, it means that the observation returned # places the agent on an empty cell - assert(np.sum(rail_map * global_obs[0][1][:, :, 0]) > 0) + assert(np.sum(rail_map * global_obs[0][1][:, :, :4].sum(2)) > 0) def main():