diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py index d2663916a17a70597d10e489da7aead4f8932dc4..0d6d309765690b1f95c681d7d109a13071d7f86b 100644 --- a/tests/test_flatland_envs_observations.py +++ b/tests/test_flatland_envs_observations.py @@ -41,7 +41,9 @@ def test_global_obs(): # If this assertion is wrong, it means that the observation returned # places the agent on an empty cell - assert (np.sum(rail_map * global_obs[0][1][:, :, :4].sum(2)) > 0) + obs_agents_state = global_obs[0][1] + obs_agents_state = obs_agents_state + 1 + assert (np.sum(rail_map * obs_agents_state[:, :, :4].sum(2)) > 0) def _step_along_shortest_path(env, obs_builder, rail):