diff --git a/tests/test_env_observation_builder.py b/tests/test_env_observation_builder.py
index 13f8d8f9d7df1808bfb69261ee25bf06f9da2d45..4038fe675b1d7cfd7a0fd8447f2f81bc1fe86f04 100644
--- a/tests/test_env_observation_builder.py
+++ b/tests/test_env_observation_builder.py
@@ -82,7 +82,7 @@ def test_global_obs():
 
     # If this assertion is wrong, it means that the observation returned
     # places the agent on an empty cell
-    assert(np.sum(rail_map * global_obs[0][1][0]) > 0)
+    assert(np.sum(rail_map * global_obs[0][1][:, :, 0]) > 0)
 
 
 def main():