Skip to content
Snippets Groups Projects
Commit 2121d51e authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

fixed malfunction tests

parent b9092b25
No related branches found
No related tags found
No related merge requests found
......@@ -82,7 +82,7 @@ def test_malfunction_process():
obs_builder_object=SingleAgentNavigationObs()
)
# reset to initialize agents_static
obs, info = env.reset(False, False, True, random_seed=0)
obs, info = env.reset(False, False, True, random_seed=10)
print(env.agents[0].malfunction_data)
# Check that a initial duration for malfunction was assigned
assert env.agents[0].malfunction_data['next_malfunction'] > 0
......@@ -151,7 +151,7 @@ def test_malfunction_process_statistically():
obs_builder_object=SingleAgentNavigationObs()
)
# reset to initialize agents_static
env.reset(True, True, False, random_seed=0)
env.reset(True, True, False, random_seed=10)
env.agents[0].target = (0, 0)
nb_malfunction = 0
......@@ -163,11 +163,11 @@ def test_malfunction_process_statistically():
env.step(action_dict)
# check that generation of malfunctions works as expected
assert env.agents[0].malfunction_data["nr_malfunctions"] == 4
assert env.agents[0].malfunction_data["nr_malfunctions"] == 5
def test_malfunction_before_entry():
"""Tests hat malfunctions are produced by stochastic_data!"""
"""Tests that malfunctions are produced by stochastic_data!"""
# Set fixed malfunction duration for this test
stochastic_data = {'prop_malfunction': 1.,
'malfunction_rate': 2,
......@@ -179,28 +179,48 @@ def test_malfunction_before_entry():
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
schedule_generator=random_schedule_generator(seed=2), # seed 12
number_of_agents=10,
random_seed=1,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=SingleAgentNavigationObs()
)
# reset to initialize agents_static
env.reset(False, False, False, random_seed=0)
env.reset(False, False, False, random_seed=10)
env.agents[0].target = (0, 0)
nb_malfunction = 0
assert env.agents[1].malfunction_data['malfunction'] == 11
assert env.agents[2].malfunction_data['malfunction'] == 11
assert env.agents[3].malfunction_data['malfunction'] == 11
assert env.agents[4].malfunction_data['malfunction'] == 11
assert env.agents[5].malfunction_data['malfunction'] == 0
assert env.agents[6].malfunction_data['malfunction'] == 11
assert env.agents[7].malfunction_data['malfunction'] == 11
assert env.agents[8].malfunction_data['malfunction'] == 11
assert env.agents[9].malfunction_data['malfunction'] == 0
for step in range(20):
action_dict: Dict[int, RailEnvActions] = {}
for agent in env.agents:
# We randomly select an action
action_dict[agent.handle] = RailEnvActions(2)
if step < 10:
action_dict[agent.handle] = RailEnvActions(0)
assert env.agents[0].malfunction_data['malfunction'] == 0
else:
action_dict[agent.handle] = RailEnvActions(2)
print(env.agents[0].malfunction_data)
env.step(action_dict)
assert env.agents[0].malfunction_data['malfunction'] > 0
assert env.agents[1].malfunction_data['malfunction'] == 1
assert env.agents[2].malfunction_data['malfunction'] == 1
assert env.agents[3].malfunction_data['malfunction'] == 1
assert env.agents[4].malfunction_data['malfunction'] == 1
assert env.agents[5].malfunction_data['malfunction'] == 2
assert env.agents[6].malfunction_data['malfunction'] == 1
assert env.agents[7].malfunction_data['malfunction'] == 1
assert env.agents[8].malfunction_data['malfunction'] == 1
assert env.agents[9].malfunction_data['malfunction'] == 3
# Print for test generation
# for a in range(env.get_num_agents()):
# print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a,
# env.agents[a].malfunction_data[
# 'malfunction']))
def test_initial_malfunction():
......@@ -223,8 +243,8 @@ def test_initial_malfunction():
)
# reset to initialize agents_static
env.reset(False, False, True, random_seed=0)
env.reset(False, False, True, random_seed=10)
env.agents[0].target = (0, 5)
set_penalties_for_replay(env)
replay_config = ReplayConfig(
replay=[
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment