From 1158cdc213b69092a81b115e9274179340f2209e Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Sat, 10 Aug 2019 16:37:16 -0400 Subject: [PATCH] updated malfunction test --- flatland/envs/rail_env.py | 2 ++ tests/test_flatland_malfunction.py | 34 ++++++++++++++++++++++++------ 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index d62c689a..4f016f2b 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -238,6 +238,8 @@ class RailEnv(Environment): agent.speed_data['position_fraction'] = 0.0 agent.malfunction_data['malfunction'] = 0 + self._agent_stopped(i_agent) + self.num_resets += 1 self._elapsed_steps = 0 diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 91c551db..4122d56f 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -51,10 +51,11 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): def test_malfunction_process(): + # Set fixed malfunction duration for this test stochastic_data = {'prop_malfunction': 1., 'malfunction_rate': 5, 'min_duration': 3, - 'max_duration': 10} + 'max_duration': 3} np.random.seed(5) env = RailEnv(width=14, @@ -66,23 +67,44 @@ def test_malfunction_process(): stochastic_data=stochastic_data) obs = env.reset() + + # Check that a initial duration for malfunction was assigned + assert env.agents[0].malfunction_data['next_malfunction'] > 0 + agent_halts = 0 + total_down_time = 0 + agent_malfunctioning = False + agent_old_position = env.agents[0].position for step in range(100): actions = {} for i in range(len(obs)): actions[i] = np.argmax(obs[i]) + 1 if step % 5 == 0: + # Stop the agent and set it to be malfunctioning actions[0] = 4 + env.agents[0].malfunction_data['next_malfunction'] = 0 agent_halts += 1 + if env.agents[0].malfunction_data['malfunction'] > 0: + agent_malfunctioning = True + else: + agent_malfunctioning = False + obs, all_rewards, done, _ = env.step(actions) - if done["__all__"]: - break + if agent_malfunctioning: + assert agent_old_position == env.agents[0].position + + agent_old_position = env.agents[0].position + total_down_time += env.agents[0].malfunction_data['malfunction'] + # Check that the agents breaks twice - assert env.agents[0].malfunction_data['nr_malfunctions'] == 2 + assert env.agents[0].malfunction_data['nr_malfunctions'] == 5 + + # Check that 11 stops where performed + assert agent_halts == 20 - # Check that 7 stops where performed - assert agent_halts == 7 + # Check that malfunctioning data was standing around + assert total_down_time > 0 -- GitLab