Commit 0725bab9 authored by Erik Nygren's avatar Erik Nygren 🚅
Browse files

fixed malfunction tests

parent 285e714f
Pipeline #2387 passed with stages
in 36 minutes and 12 seconds
......@@ -283,8 +283,8 @@ class RailEnv(Environment):
self.set_agent_active(i_agent)
for i_agent, agent in enumerate(self.agents):
if agent.status != RailAgentStatus.ACTIVE:
continue
# if agent.status != RailAgentStatus.ACTIVE:
# continue
# A proportion of agent in the environment will receive a positive malfunction rate
if np.random.random() < self.proportion_malfunctioning_trains:
......
......@@ -129,7 +129,7 @@ def test_malfunction_process():
total_down_time += env.agents[0].malfunction_data['malfunction']
# Check that the appropriate number of malfunctions is achieved
assert env.agents[0].malfunction_data['nr_malfunctions'] == 11, "Actual {}".format(
assert env.agents[0].malfunction_data['nr_malfunctions'] == 21, "Actual {}".format(
env.agents[0].malfunction_data['nr_malfunctions'])
# Check that 20 stops where performed
......@@ -150,11 +150,6 @@ def test_malfunction_process_statistically():
random.seed(0)
np.random.seed(0)
stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents
'malfunction_rate': 70, # Rate of malfunction occurence
'min_duration': 2, # Minimal duration of malfunction
'max_duration': 5 # Max duration of malfunction
}
rail, rail_map = make_simple_rail2()
......@@ -167,20 +162,19 @@ def test_malfunction_process_statistically():
obs_builder_object=SingleAgentNavigationObs()
)
# reset to initialize agents_static
env.reset(False, False, True)
env.reset(False, False, False)
env.agents[0].target = (0, 0)
nb_malfunction = 0
for step in range(100):
for step in range(20):
action_dict: Dict[int, RailEnvActions] = {}
for agent in env.agents:
if agent.malfunction_data['malfunction'] > 0:
nb_malfunction += 1
# We randomly select an action
action_dict[agent.handle] = RailEnvActions(np.random.randint(4))
env.step(action_dict)
# check that generation of malfunctions works as expected
assert nb_malfunction == 3, "nb_malfunction={}".format(nb_malfunction)
assert env.agents[0].malfunction_data["nr_malfunctions"] == 4
def test_initial_malfunction():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment