Skip to content
Snippets Groups Projects
Commit cab970d7 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

added simple test for decay of next_malfunction step

parent 2ebc1dbe
No related branches found
No related tags found
No related merge requests found
......@@ -437,10 +437,6 @@ class RailEnv(Environment):
return True
# Decrease counter for next event only if agent is currently not broken and agent has a malfunction rate
if agent.malfunction_data['next_malfunction'] > 0 and agent.malfunction_data['malfunction'] < 1:
agent.malfunction_data['next_malfunction'] -= 1
def step(self, action_dict_: Dict[int, RailEnvActions]):
......
......@@ -156,14 +156,14 @@ def test_malfunction_process_statistically():
env.agents[0].target = (0, 0)
# Next line only for test generation
#agent_malfunction_list = [[] for i in range(20)]
# agent_malfunction_list = [[] for i in range(20)]
agent_malfunction_list = [[0, 0, 0, 0, 5, 5, 0, 0, 0, 0], [0, 0, 0, 0, 5, 5, 0, 0, 0, 0], [0, 0, 0, 0, 4, 4, 0, 0, 0, 0],
[0, 0, 0, 0, 3, 3, 0, 0, 0, 0], [0, 0, 0, 0, 2, 2, 0, 0, 0, 5], [0, 0, 0, 0, 1, 1, 5, 0, 0, 4],
[0, 0, 0, 5, 0, 0, 4, 5, 0, 3], [5, 0, 0, 4, 0, 0, 3, 4, 0, 2], [4, 5, 0, 3, 5, 5, 2, 3, 5, 1],
[3, 4, 0, 2, 4, 4, 1, 2, 4, 0], [2, 3, 5, 1, 3, 3, 0, 1, 3, 0], [1, 2, 4, 0, 2, 2, 0, 0, 2, 0],
[0, 1, 3, 0, 1, 1, 5, 0, 1, 0], [0, 0, 2, 0, 0, 0, 4, 0, 0, 0], [5, 0, 1, 0, 0, 0, 3, 5, 0, 5],
[4, 0, 0, 0, 5, 0, 2, 4, 0, 4], [3, 0, 0, 0, 4, 0, 1, 3, 5, 3], [2, 0, 0, 0, 3, 0, 0, 2, 4, 2],
[1, 0, 5, 5, 2, 0, 0, 1, 3, 1], [0, 5, 4, 4, 1, 0, 5, 0, 2, 0]]
[0, 0, 0, 0, 3, 3, 5, 0, 0, 0], [5, 0, 0, 5, 2, 2, 4, 5, 0, 5], [4, 5, 0, 4, 1, 1, 3, 4, 5, 4],
[3, 4, 0, 3, 0, 0, 2, 3, 4, 3], [2, 3, 5, 2, 0, 0, 1, 2, 3, 2], [1, 2, 4, 1, 5, 5, 0, 1, 2, 1],
[0, 1, 3, 0, 4, 4, 0, 0, 1, 0], [0, 0, 2, 0, 3, 3, 0, 0, 0, 0], [5, 0, 1, 0, 2, 2, 5, 5, 0, 5],
[4, 0, 0, 0, 1, 1, 4, 4, 5, 4], [3, 0, 0, 5, 0, 0, 3, 3, 4, 3], [2, 5, 0, 4, 0, 0, 2, 2, 3, 2],
[1, 4, 0, 3, 5, 5, 1, 1, 2, 1], [0, 3, 0, 2, 4, 4, 0, 0, 1, 0], [0, 2, 0, 1, 3, 3, 0, 0, 0, 0],
[5, 1, 0, 0, 2, 2, 5, 5, 0, 5], [4, 0, 5, 0, 1, 1, 4, 4, 5, 4]]
for step in range(20):
action_dict: Dict[int, RailEnvActions] = {}
......@@ -202,16 +202,16 @@ def test_malfunction_before_entry():
# Test initial malfunction values for all agents
# we want some agents to be malfuncitoning already and some to be working
# we want different next_malfunction values for the agents
assert env.agents[0].malfunction_data['next_malfunction'] == 5
assert env.agents[1].malfunction_data['next_malfunction'] == 6
assert env.agents[2].malfunction_data['next_malfunction'] == 6
assert env.agents[3].malfunction_data['next_malfunction'] == 3
assert env.agents[1].malfunction_data['next_malfunction'] == 5
assert env.agents[2].malfunction_data['next_malfunction'] == 5
assert env.agents[3].malfunction_data['next_malfunction'] == 2
assert env.agents[4].malfunction_data['next_malfunction'] == 1
assert env.agents[5].malfunction_data['next_malfunction'] == 1
assert env.agents[6].malfunction_data['next_malfunction'] == 3
assert env.agents[7].malfunction_data['next_malfunction'] == 4
assert env.agents[8].malfunction_data['next_malfunction'] == 6
assert env.agents[9].malfunction_data['next_malfunction'] == 0
assert env.agents[6].malfunction_data['next_malfunction'] == 2
assert env.agents[7].malfunction_data['next_malfunction'] == 3
assert env.agents[8].malfunction_data['next_malfunction'] == 5
assert env.agents[9].malfunction_data['next_malfunction'] == -1
assert env.agents[0].malfunction_data['malfunction'] == 0
assert env.agents[1].malfunction_data['malfunction'] == 0
assert env.agents[2].malfunction_data['malfunction'] == 0
......@@ -223,50 +223,41 @@ def test_malfunction_before_entry():
assert env.agents[8].malfunction_data['malfunction'] == 0
assert env.agents[9].malfunction_data['malfunction'] == 0
for a in range(env.get_num_agents()):
print("assert env.agents[{}].malfunction_data['next_malfunction'] == {}".format(a, env.agents[a].malfunction_data['next_malfunction']))
for a in range(env.get_num_agents()):
print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a, env.agents[a].malfunction_data[
'malfunction']))
for step in range(20):
action_dict: Dict[int, RailEnvActions] = {}
for agent in env.agents:
# We randomly select an action
action_dict[agent.handle] = RailEnvActions(2)
if step < 10:
action_dict[agent.handle] = RailEnvActions(0)
def test_next_malfunction_counter():
"""
Test that the next malfunction occurs when desired
Returns
-------
"""
# Set fixed malfunction duration for this test
rail, rail_map = make_simple_rail2()
action_dict: Dict[int, RailEnvActions] = {}
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(seed=2), # seed 12
number_of_agents=1,
random_seed=1,
)
# reset to initialize agents_static
env.reset(False, False, activate_agents=True, random_seed=10)
env.agents[0].malfunction_data['next_malfunction'] = 5
env.agents[0].malfunction_data['malfunction_rate'] = 5
env.agents[0].malfunction_data['malfunction'] = 0
env.agents[0].target =(0, 0), #Move the target out of range
print(env.agents[0].position, env.agents[0].malfunction_data['next_malfunction'])
for time_step in range(1, 6):
# Move in the env
env.step(action_dict)
# We want to check that all agents are malfunctioning and that their values changed
# Check that next_step decreases as expected
assert env.agents[0].malfunction_data['next_malfunction'] == 5 - time_step
# Test malfunction values for all agents after 20 steps
assert env.agents[0].malfunction_data['next_malfunction'] == 4
assert env.agents[1].malfunction_data['next_malfunction'] == 6
assert env.agents[2].malfunction_data['next_malfunction'] == 2
assert env.agents[3].malfunction_data['next_malfunction'] == 2
assert env.agents[4].malfunction_data['next_malfunction'] == 1
assert env.agents[5].malfunction_data['next_malfunction'] == 1
assert env.agents[6].malfunction_data['next_malfunction'] == 2
assert env.agents[7].malfunction_data['next_malfunction'] == 1
assert env.agents[8].malfunction_data['next_malfunction'] == 1
assert env.agents[9].malfunction_data['next_malfunction'] == 4
assert env.agents[0].malfunction_data['malfunction'] == 0
assert env.agents[1].malfunction_data['malfunction'] == 8
assert env.agents[2].malfunction_data['malfunction'] == 8
assert env.agents[3].malfunction_data['malfunction'] == 0
assert env.agents[4].malfunction_data['malfunction'] == 1
assert env.agents[5].malfunction_data['malfunction'] == 1
assert env.agents[6].malfunction_data['malfunction'] == 0
assert env.agents[7].malfunction_data['malfunction'] == 6
assert env.agents[8].malfunction_data['malfunction'] == 8
assert env.agents[9].malfunction_data['malfunction'] == 2
# Print for test generation
#for a in range(env.get_num_agents()):
# print("assert env.agents[{}].malfunction_data['next_malfunction'] == {}".format(a, env.agents[a].malfunction_data['next_malfunction']))
#for a in range(env.get_num_agents()):
# print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a, env.agents[a].malfunction_data[
# 'malfunction']))
def test_initial_malfunction():
stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment