Skip to content
Snippets Groups Projects
Commit 206a66f6 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

added new malfunction test to catch pre-active malfunction changes

parent 0725bab9
No related branches found
No related tags found
No related merge requests found
......@@ -177,6 +177,46 @@ def test_malfunction_process_statistically():
assert env.agents[0].malfunction_data["nr_malfunctions"] == 4
def test_malfunction_before_entry():
"""Tests hat malfunctions are produced by stochastic_data!"""
# Set fixed malfunction duration for this test
stochastic_data = {'prop_malfunction': 1.,
'malfunction_rate': 2,
'min_duration': 10,
'max_duration': 10}
random.seed(0)
np.random.seed(0)
rail, rail_map = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=SingleAgentNavigationObs()
)
# reset to initialize agents_static
env.reset(False, False, False)
env.agents[0].target = (0, 0)
nb_malfunction = 0
for step in range(20):
action_dict: Dict[int, RailEnvActions] = {}
for agent in env.agents:
# We randomly select an action
if step < 10:
action_dict[agent.handle] = RailEnvActions(0)
assert env.agents[0].malfunction_data['malfunction'] == 0
else:
action_dict[agent.handle] = RailEnvActions(2)
print(env.agents[0].malfunction_data)
env.step(action_dict)
assert env.agents[0].malfunction_data['malfunction'] > 0
def test_initial_malfunction():
random.seed(0)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment