Skip to content
Snippets Groups Projects
Commit d7b5e131 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

fixed divergence test

parent d5b16a52
No related branches found
No related tags found
No related merge requests found
......@@ -398,8 +398,8 @@ class RailEnv(Environment):
Malfunction generator that breaks agents at a given rate. It does randomly chose agent to break during the run
"""
if np.random.random() < self._malfunction_prob(rate):
breaking_agent = random.choice(self.agents)
if self.np_random.randn() < self._malfunction_prob(rate):
breaking_agent = self.np_random.choice(self.agents)
if breaking_agent.malfunction_data['malfunction'] < 1:
num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken,
self.max_number_of_steps_broken + 1)
......
......@@ -120,8 +120,7 @@ def test_malfunction_process():
def test_malfunction_process_statistically():
"""Tests hat malfunctions are produced by stochastic_data!"""
# Set fixed malfunction duration for this test
stochastic_data = {'prop_malfunction': 1.,
'malfunction_rate': 5,
stochastic_data = {'malfunction_rate': 5,
'min_duration': 5,
'max_duration': 5}
......@@ -142,19 +141,19 @@ def test_malfunction_process_statistically():
env.agents[0].target = (0, 0)
# Next line only for test generation
# agent_malfunction_list = [[] for i in range(20)]
agent_malfunction_list = [[0, 0, 0, 0, 5, 5, 0, 0, 0, 0], [0, 0, 0, 0, 5, 5, 0, 0, 0, 0],
[0, 0, 0, 0, 4, 4, 0, 0, 0, 0],
[0, 0, 0, 0, 3, 3, 0, 0, 0, 0], [0, 0, 0, 0, 2, 2, 0, 0, 0, 5],
[0, 0, 0, 0, 1, 1, 5, 0, 0, 4],
[0, 0, 0, 5, 0, 0, 4, 5, 0, 3], [5, 0, 0, 4, 0, 0, 3, 4, 0, 2],
[4, 5, 0, 3, 5, 5, 2, 3, 5, 1],
[3, 4, 0, 2, 4, 4, 1, 2, 4, 0], [2, 3, 5, 1, 3, 3, 0, 1, 3, 0],
[1, 2, 4, 0, 2, 2, 0, 0, 2, 0],
[0, 1, 3, 0, 1, 1, 5, 0, 1, 0], [0, 0, 2, 0, 0, 0, 4, 0, 0, 0],
[5, 0, 1, 0, 0, 0, 3, 5, 0, 5],
[4, 0, 0, 0, 5, 0, 2, 4, 0, 4], [3, 0, 0, 0, 4, 0, 1, 3, 5, 3],
[2, 0, 0, 0, 3, 0, 0, 2, 4, 2],
[1, 0, 5, 5, 2, 0, 0, 1, 3, 1], [0, 5, 4, 4, 1, 0, 5, 0, 2, 0]]
agent_malfunction_list = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 4], [0, 0, 0, 0, 0, 0, 0, 0, 0, 3],
[4, 0, 0, 0, 0, 0, 0, 0, 0, 2],
[3, 0, 0, 0, 0, 0, 0, 0, 0, 1], [2, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 4, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 3, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 2, 4, 0, 0], [0, 0, 0, 0, 0, 0, 1, 3, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 2, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 1, 0, 4], [0, 0, 0, 0, 0, 0, 0, 0, 0, 3],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 2],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 1], [4, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
for step in range(20):
action_dict: Dict[int, RailEnvActions] = {}
......@@ -170,10 +169,9 @@ def test_malfunction_process_statistically():
def test_malfunction_before_entry():
"""Tests that malfunctions are working properlz for agents before entering the environment!"""
"""Tests that malfunctions are working properly for agents before entering the environment!"""
# Set fixed malfunction duration for this test
stochastic_data = {'prop_malfunction': 1.,
'malfunction_rate': 5,
stochastic_data = {'malfunction_rate': 1,
'min_duration': 10,
'max_duration': 10}
......@@ -193,51 +191,15 @@ def test_malfunction_before_entry():
# Test initial malfunction values for all agents
# we want some agents to be malfuncitoning already and some to be working
# we want different next_malfunction values for the agents
assert env.agents[0].malfunction_data['malfunction'] == 0
assert env.agents[1].malfunction_data['malfunction'] == 0
assert env.agents[2].malfunction_data['malfunction'] == 0
assert env.agents[3].malfunction_data['malfunction'] == 0
assert env.agents[4].malfunction_data['malfunction'] == 10
assert env.agents[5].malfunction_data['malfunction'] == 10
assert env.agents[4].malfunction_data['malfunction'] == 0
assert env.agents[5].malfunction_data['malfunction'] == 0
assert env.agents[6].malfunction_data['malfunction'] == 0
assert env.agents[7].malfunction_data['malfunction'] == 0
assert env.agents[8].malfunction_data['malfunction'] == 0
assert env.agents[9].malfunction_data['malfunction'] == 0
def test_next_malfunction_counter():
"""
Test that the next malfunction occurs when desired
Returns
-------
"""
# Set fixed malfunction duration for this test
rail, rail_map = make_simple_rail2()
action_dict: Dict[int, RailEnvActions] = {}
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(seed=2), # seed 12
number_of_agents=1,
random_seed=1,
)
# reset to initialize agents_static
env.reset(False, False, activate_agents=True, random_seed=10)
env.agents[0].malfunction_data['next_malfunction'] = 5
env.agents[0].malfunction_data['malfunction_rate'] = 5
env.agents[0].malfunction_data['malfunction'] = 0
env.agents[0].target = (0, 0), # Move the target out of range
print(env.agents[0].position, env.agents[0].malfunction_data['next_malfunction'])
for time_step in range(1, 6):
# Move in the env
env.step(action_dict)
# Check that next_step decreases as expected
assert env.agents[0].malfunction_data['next_malfunction'] == 5 - time_step
assert env.agents[9].malfunction_data['malfunction'] == 9
def test_malfunction_values_and_behavior():
......@@ -251,8 +213,7 @@ def test_malfunction_values_and_behavior():
rail, rail_map = make_simple_rail2()
action_dict: Dict[int, RailEnvActions] = {}
stochastic_data = {'prop_malfunction': 1.,
'malfunction_rate': 5,
stochastic_data = {'malfunction_rate': 5,
'min_duration': 10,
'max_duration': 10}
env = RailEnv(width=25,
......@@ -263,23 +224,18 @@ def test_malfunction_values_and_behavior():
number_of_agents=1,
random_seed=1,
)
# reset to initialize agents_static
env.reset(False, False, activate_agents=True, random_seed=10)
env.agents[0].malfunction_data['next_malfunction'] = 5
env.agents[0].malfunction_data['malfunction_rate'] = 50
env.agents[0].malfunction_data['malfunction'] = 0
env.agents[0].target = (0, 0), # Move the target out of range
print(env.agents[0].position, env.agents[0].malfunction_data['next_malfunction'])
for time_step in range(1, 16):
# Assertions
assert_list = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 0, 9, 8, 7, 6]
print("[")
for time_step in range(15):
# Move in the env
env.step(action_dict)
print(time_step)
# Check that next_step decreases as expected
if env.agents[0].malfunction_data['malfunction'] < 1:
assert env.agents[0].malfunction_data['next_malfunction'] == np.clip(5 - time_step, 0, 100)
else:
assert env.agents[0].malfunction_data['malfunction'] == np.clip(10 - (time_step - 6), 0, 100)
assert env.agents[0].malfunction_data['malfunction'] == assert_list[time_step]
def test_initial_malfunction():
......@@ -529,45 +485,14 @@ def test_initial_malfunction_do_nothing():
run_replay_config(env, [replay_config], activate_agents=False)
def test_initial_nextmalfunction_not_below_zero():
random.seed(0)
np.random.seed(0)
stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents
'malfunction_rate': 70, # Rate of malfunction occurence
'min_duration': 2, # Minimal duration of malfunction
'max_duration': 5 # Max duration of malfunction
}
rail, rail_map = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=SingleAgentNavigationObs()
)
# reset to initialize agents_static
env.reset()
agent = env.agents[0]
env.step({})
# was next_malfunction was -1 befor the bugfix https://gitlab.aicrowd.com/flatland/flatland/issues/186
assert agent.malfunction_data['next_malfunction'] >= 0, \
"next_malfunction should be >=0, found {}".format(agent.malfunction_data['next_malfunction'])
def tests_random_interference_from_outside():
"""Tests that malfunctions are produced by stochastic_data!"""
# Set fixed malfunction duration for this test
stochastic_data = {'prop_malfunction': 1.,
'malfunction_rate': 1,
stochastic_data = {'malfunction_rate': 1,
'min_duration': 10,
'max_duration': 10}
rail, rail_map = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
......@@ -579,9 +504,7 @@ def tests_random_interference_from_outside():
env.reset()
# reset to initialize agents_static
env.agents[0].speed_data['speed'] = 0.33
env.agents[0].initial_position = (3, 0)
env.agents[0].target = (3, 9)
env.reset(False, False, False)
env.reset(False, False, False, random_seed=10)
env_data = []
for step in range(200):
......@@ -612,11 +535,8 @@ def tests_random_interference_from_outside():
env.reset()
# reset to initialize agents_static
env.agents[0].speed_data['speed'] = 0.33
env.agents[0].initial_position = (3, 0)
env.agents[0].target = (3, 9)
env.reset(False, False, False)
env.reset(False, False, False, random_seed=10)
# Print for test generation
dummy_list = [1, 2, 6, 7, 8, 9, 4, 5, 4]
for step in range(200):
action_dict: Dict[int, RailEnvActions] = {}
......
......@@ -118,8 +118,6 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
# recognizes the agent as potentially malfuncitoning
# We also set next malfunction to infitiy to avoid interference with our tests
agent.malfunction_data['malfunction'] = replay.set_malfunction
agent.malfunction_data['malfunction_rate'] = max(agent.malfunction_data['malfunction_rate'], 1)
agent.malfunction_data['next_malfunction'] = np.inf
agent.malfunction_data['moving_before_malfunction'] = agent.moving
_assert(a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
print(step)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment