diff --git a/tests/test_environments.py b/tests/test_environments.py index 210f1c76c8fd9978141a48189d5bcf2e31e68611..198104fc2c56db2dac0573939bbad2d071d5618b 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -140,8 +140,10 @@ def test_dead_end(): _ = rail_env.step({0: 1}) _ = rail_env.step({0: 3}) assert (rail_env.agents_position[0] == prev_pos) - - _, _, dones, _ = rail_env.step({0: 2}) + if rail_env.agents_position[0] != rail_env.agents_target[0]: + _, _, dones, _ = rail_env.step({0: 2}) + else: + _, _, dones, _ = rail_env.step({0: 0}) if i < 5: assert (not dones[0] and not dones['__all__'])