diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index d62c689aa9f6c788b54474b884d4101d98fb4ff0..4f016f2b3d4ee3e1918fbf47c9492f2c84228999 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -238,6 +238,8 @@ class RailEnv(Environment): agent.speed_data['position_fraction'] = 0.0 agent.malfunction_data['malfunction'] = 0 + self._agent_stopped(i_agent) + self.num_resets += 1 self._elapsed_steps = 0 diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 91c551db60f9d71d7aa0774ea8b6aaf42af3e35b..4122d56fc4721b820bcc252b58629cd96f6f8681 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -51,10 +51,11 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): def test_malfunction_process(): + # Set fixed malfunction duration for this test stochastic_data = {'prop_malfunction': 1., 'malfunction_rate': 5, 'min_duration': 3, - 'max_duration': 10} + 'max_duration': 3} np.random.seed(5) env = RailEnv(width=14, @@ -66,23 +67,44 @@ def test_malfunction_process(): stochastic_data=stochastic_data) obs = env.reset() + + # Check that a initial duration for malfunction was assigned + assert env.agents[0].malfunction_data['next_malfunction'] > 0 + agent_halts = 0 + total_down_time = 0 + agent_malfunctioning = False + agent_old_position = env.agents[0].position for step in range(100): actions = {} for i in range(len(obs)): actions[i] = np.argmax(obs[i]) + 1 if step % 5 == 0: + # Stop the agent and set it to be malfunctioning actions[0] = 4 + env.agents[0].malfunction_data['next_malfunction'] = 0 agent_halts += 1 + if env.agents[0].malfunction_data['malfunction'] > 0: + agent_malfunctioning = True + else: + agent_malfunctioning = False + obs, all_rewards, done, _ = env.step(actions) - if done["__all__"]: - break + if agent_malfunctioning: + assert agent_old_position == env.agents[0].position + + agent_old_position = env.agents[0].position + total_down_time += env.agents[0].malfunction_data['malfunction'] + # Check that the agents breaks twice - assert env.agents[0].malfunction_data['nr_malfunctions'] == 2 + assert env.agents[0].malfunction_data['nr_malfunctions'] == 5 + + # Check that 11 stops where performed + assert agent_halts == 20 - # Check that 7 stops where performed - assert agent_halts == 7 + # Check that malfunctioning data was standing around + assert total_down_time > 0