Skip to content
Snippets Groups Projects
Commit 3fac7ebd authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

updated tests to new exponential distirbution random generator

parent 0ff89ada
No related branches found
No related tags found
No related merge requests found
...@@ -360,10 +360,8 @@ class RailEnv(Environment): ...@@ -360,10 +360,8 @@ class RailEnv(Environment):
# Next malfunction in number of stops # Next malfunction in number of stops
next_breakdown = int( next_breakdown = int(
self.np_random.exponential(scale=agent.malfunction_data['malfunction_rate'])) self._exp_distirbution_synced(rate=agent.malfunction_data['malfunction_rate']))
next_breakdown = self.np_random.randint(self.min_number_of_steps_broken, agent.malfunction_data['next_malfunction'] = next_breakdown
self.max_number_of_steps_broken + 1) + 1
agent.malfunction_data['next_malfunction'] = 5 # next_breakdown
# Duration of current malfunction # Duration of current malfunction
num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken, num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken,
......
...@@ -170,7 +170,7 @@ def test_malfunction_before_entry(): ...@@ -170,7 +170,7 @@ def test_malfunction_before_entry():
"""Tests that malfunctions are produced by stochastic_data!""" """Tests that malfunctions are produced by stochastic_data!"""
# Set fixed malfunction duration for this test # Set fixed malfunction duration for this test
stochastic_data = {'prop_malfunction': 1., stochastic_data = {'prop_malfunction': 1.,
'malfunction_rate': 2, 'malfunction_rate': 1,
'min_duration': 10, 'min_duration': 10,
'max_duration': 10} 'max_duration': 10}
...@@ -187,9 +187,17 @@ def test_malfunction_before_entry(): ...@@ -187,9 +187,17 @@ def test_malfunction_before_entry():
# reset to initialize agents_static # reset to initialize agents_static
env.reset(False, False, False, random_seed=10) env.reset(False, False, False, random_seed=10)
env.agents[0].target = (0, 0) env.agents[0].target = (0, 0)
for a in range(env.get_num_agents()):
print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a, env.agents[a].malfunction_data[ assert env.agents[1].malfunction_data['malfunction'] == 11
'malfunction'])) assert env.agents[2].malfunction_data['malfunction'] == 11
assert env.agents[3].malfunction_data['malfunction'] == 11
assert env.agents[4].malfunction_data['malfunction'] == 11
assert env.agents[5].malfunction_data['malfunction'] == 11
assert env.agents[6].malfunction_data['malfunction'] == 11
assert env.agents[7].malfunction_data['malfunction'] == 11
assert env.agents[8].malfunction_data['malfunction'] == 11
assert env.agents[9].malfunction_data['malfunction'] == 11
for step in range(20): for step in range(20):
action_dict: Dict[int, RailEnvActions] = {} action_dict: Dict[int, RailEnvActions] = {}
...@@ -200,16 +208,16 @@ def test_malfunction_before_entry(): ...@@ -200,16 +208,16 @@ def test_malfunction_before_entry():
action_dict[agent.handle] = RailEnvActions(0) action_dict[agent.handle] = RailEnvActions(0)
env.step(action_dict) env.step(action_dict)
assert env.agents[1].malfunction_data['malfunction'] == 1 assert env.agents[1].malfunction_data['malfunction'] == 1
assert env.agents[2].malfunction_data['malfunction'] == 1 assert env.agents[2].malfunction_data['malfunction'] == 1
assert env.agents[3].malfunction_data['malfunction'] == 1 assert env.agents[3].malfunction_data['malfunction'] == 1
assert env.agents[4].malfunction_data['malfunction'] == 1 assert env.agents[4].malfunction_data['malfunction'] == 1
assert env.agents[5].malfunction_data['malfunction'] == 2 assert env.agents[5].malfunction_data['malfunction'] == 1
assert env.agents[6].malfunction_data['malfunction'] == 1 assert env.agents[6].malfunction_data['malfunction'] == 1
assert env.agents[7].malfunction_data['malfunction'] == 1 assert env.agents[7].malfunction_data['malfunction'] == 1
assert env.agents[8].malfunction_data['malfunction'] == 1 assert env.agents[8].malfunction_data['malfunction'] == 1
assert env.agents[9].malfunction_data['malfunction'] == 3 assert env.agents[9].malfunction_data['malfunction'] == 1
# Print for test generation # Print for test generation
# for a in range(env.get_num_agents()): # for a in range(env.get_num_agents()):
# print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a, # print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a,
...@@ -220,7 +228,7 @@ def test_malfunction_before_entry(): ...@@ -220,7 +228,7 @@ def test_malfunction_before_entry():
def test_initial_malfunction(): def test_initial_malfunction():
stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents
'malfunction_rate': 70, # Rate of malfunction occurence 'malfunction_rate': 100, # Rate of malfunction occurence
'min_duration': 2, # Minimal duration of malfunction 'min_duration': 2, # Minimal duration of malfunction
'max_duration': 5 # Max duration of malfunction 'max_duration': 5 # Max duration of malfunction
} }
...@@ -230,7 +238,7 @@ def test_initial_malfunction(): ...@@ -230,7 +238,7 @@ def test_initial_malfunction():
env = RailEnv(width=25, env = RailEnv(width=25,
height=30, height=30,
rail_generator=rail_from_grid_transition_map(rail), rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(), schedule_generator=random_schedule_generator(seed=10),
number_of_agents=1, number_of_agents=1,
stochastic_data=stochastic_data, # Malfunction data generator stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=SingleAgentNavigationObs() obs_builder_object=SingleAgentNavigationObs()
...@@ -238,6 +246,7 @@ def test_initial_malfunction(): ...@@ -238,6 +246,7 @@ def test_initial_malfunction():
# reset to initialize agents_static # reset to initialize agents_static
env.reset(False, False, True, random_seed=10) env.reset(False, False, True, random_seed=10)
print(env.agents[0].malfunction_data)
env.agents[0].target = (0, 5) env.agents[0].target = (0, 5)
set_penalties_for_replay(env) set_penalties_for_replay(env)
replay_config = ReplayConfig( replay_config = ReplayConfig(
......
...@@ -168,3 +168,23 @@ def test_seeding_and_malfunction(): ...@@ -168,3 +168,23 @@ def test_seeding_and_malfunction():
assert env.agents[9].position == env2.agents[9].position assert env.agents[9].position == env2.agents[9].position
for a in range(env.get_num_agents()): for a in range(env.get_num_agents()):
print("assert env.agents[{}].position == env2.agents[{}].position".format(a, a)) print("assert env.agents[{}].position == env2.agents[{}].position".format(a, a))
def tests_new_distributio():
def _exp_distirbution_synced(rate):
"""
Generates sample from exponential distribution
We need this to guarantee synchronity between different instances with same seed.
:param rate:
:return:
"""
u = np.random.rand()
x = - np.log(1 - u) * rate
return x
numbers = []
for i in range(100):
rate1 = 2
rate2 = 100
print((_exp_distirbution_synced(rate1), _exp_distirbution_synced(rate2)))
print(numbers)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment