Skip to content
Snippets Groups Projects
Commit 1158cdc2 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

updated malfunction test

parent 2ec6d6f4
No related branches found
No related tags found
No related merge requests found
......@@ -238,6 +238,8 @@ class RailEnv(Environment):
agent.speed_data['position_fraction'] = 0.0
agent.malfunction_data['malfunction'] = 0
self._agent_stopped(i_agent)
self.num_resets += 1
self._elapsed_steps = 0
......
......@@ -51,10 +51,11 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
def test_malfunction_process():
# Set fixed malfunction duration for this test
stochastic_data = {'prop_malfunction': 1.,
'malfunction_rate': 5,
'min_duration': 3,
'max_duration': 10}
'max_duration': 3}
np.random.seed(5)
env = RailEnv(width=14,
......@@ -66,23 +67,44 @@ def test_malfunction_process():
stochastic_data=stochastic_data)
obs = env.reset()
# Check that a initial duration for malfunction was assigned
assert env.agents[0].malfunction_data['next_malfunction'] > 0
agent_halts = 0
total_down_time = 0
agent_malfunctioning = False
agent_old_position = env.agents[0].position
for step in range(100):
actions = {}
for i in range(len(obs)):
actions[i] = np.argmax(obs[i]) + 1
if step % 5 == 0:
# Stop the agent and set it to be malfunctioning
actions[0] = 4
env.agents[0].malfunction_data['next_malfunction'] = 0
agent_halts += 1
if env.agents[0].malfunction_data['malfunction'] > 0:
agent_malfunctioning = True
else:
agent_malfunctioning = False
obs, all_rewards, done, _ = env.step(actions)
if done["__all__"]:
break
if agent_malfunctioning:
assert agent_old_position == env.agents[0].position
agent_old_position = env.agents[0].position
total_down_time += env.agents[0].malfunction_data['malfunction']
# Check that the agents breaks twice
assert env.agents[0].malfunction_data['nr_malfunctions'] == 2
assert env.agents[0].malfunction_data['nr_malfunctions'] == 5
# Check that 11 stops where performed
assert agent_halts == 20
# Check that 7 stops where performed
assert agent_halts == 7
# Check that malfunctioning data was standing around
assert total_down_time > 0
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment