From 3fac7ebdfdc1b957cce5ef0d609ff069a4460415 Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Tue, 8 Oct 2019 15:50:45 -0400 Subject: [PATCH] updated tests to new exponential distirbution random generator --- flatland/envs/rail_env.py | 6 ++---- tests/test_flatland_malfunction.py | 27 ++++++++++++++++++--------- tests/test_random_seeding.py | 20 ++++++++++++++++++++ 3 files changed, 40 insertions(+), 13 deletions(-) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index eb6108da..7e4697dd 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -360,10 +360,8 @@ class RailEnv(Environment): # Next malfunction in number of stops next_breakdown = int( - self.np_random.exponential(scale=agent.malfunction_data['malfunction_rate'])) - next_breakdown = self.np_random.randint(self.min_number_of_steps_broken, - self.max_number_of_steps_broken + 1) + 1 - agent.malfunction_data['next_malfunction'] = 5 # next_breakdown + self._exp_distirbution_synced(rate=agent.malfunction_data['malfunction_rate'])) + agent.malfunction_data['next_malfunction'] = next_breakdown # Duration of current malfunction num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken, diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 2f735f21..fe41c2b1 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -170,7 +170,7 @@ def test_malfunction_before_entry(): """Tests that malfunctions are produced by stochastic_data!""" # Set fixed malfunction duration for this test stochastic_data = {'prop_malfunction': 1., - 'malfunction_rate': 2, + 'malfunction_rate': 1, 'min_duration': 10, 'max_duration': 10} @@ -187,9 +187,17 @@ def test_malfunction_before_entry(): # reset to initialize agents_static env.reset(False, False, False, random_seed=10) env.agents[0].target = (0, 0) - for a in range(env.get_num_agents()): - print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a, env.agents[a].malfunction_data[ - 'malfunction'])) + + assert env.agents[1].malfunction_data['malfunction'] == 11 + assert env.agents[2].malfunction_data['malfunction'] == 11 + assert env.agents[3].malfunction_data['malfunction'] == 11 + assert env.agents[4].malfunction_data['malfunction'] == 11 + assert env.agents[5].malfunction_data['malfunction'] == 11 + assert env.agents[6].malfunction_data['malfunction'] == 11 + assert env.agents[7].malfunction_data['malfunction'] == 11 + assert env.agents[8].malfunction_data['malfunction'] == 11 + assert env.agents[9].malfunction_data['malfunction'] == 11 + for step in range(20): action_dict: Dict[int, RailEnvActions] = {} @@ -200,16 +208,16 @@ def test_malfunction_before_entry(): action_dict[agent.handle] = RailEnvActions(0) env.step(action_dict) + assert env.agents[1].malfunction_data['malfunction'] == 1 assert env.agents[2].malfunction_data['malfunction'] == 1 assert env.agents[3].malfunction_data['malfunction'] == 1 assert env.agents[4].malfunction_data['malfunction'] == 1 - assert env.agents[5].malfunction_data['malfunction'] == 2 + assert env.agents[5].malfunction_data['malfunction'] == 1 assert env.agents[6].malfunction_data['malfunction'] == 1 assert env.agents[7].malfunction_data['malfunction'] == 1 assert env.agents[8].malfunction_data['malfunction'] == 1 - assert env.agents[9].malfunction_data['malfunction'] == 3 - + assert env.agents[9].malfunction_data['malfunction'] == 1 # Print for test generation # for a in range(env.get_num_agents()): # print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a, @@ -220,7 +228,7 @@ def test_malfunction_before_entry(): def test_initial_malfunction(): stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents - 'malfunction_rate': 70, # Rate of malfunction occurence + 'malfunction_rate': 100, # Rate of malfunction occurence 'min_duration': 2, # Minimal duration of malfunction 'max_duration': 5 # Max duration of malfunction } @@ -230,7 +238,7 @@ def test_initial_malfunction(): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), + schedule_generator=random_schedule_generator(seed=10), number_of_agents=1, stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=SingleAgentNavigationObs() @@ -238,6 +246,7 @@ def test_initial_malfunction(): # reset to initialize agents_static env.reset(False, False, True, random_seed=10) + print(env.agents[0].malfunction_data) env.agents[0].target = (0, 5) set_penalties_for_replay(env) replay_config = ReplayConfig( diff --git a/tests/test_random_seeding.py b/tests/test_random_seeding.py index 389dabe7..b374b21d 100644 --- a/tests/test_random_seeding.py +++ b/tests/test_random_seeding.py @@ -168,3 +168,23 @@ def test_seeding_and_malfunction(): assert env.agents[9].position == env2.agents[9].position for a in range(env.get_num_agents()): print("assert env.agents[{}].position == env2.agents[{}].position".format(a, a)) + + +def tests_new_distributio(): + def _exp_distirbution_synced(rate): + """ + Generates sample from exponential distribution + We need this to guarantee synchronity between different instances with same seed. + :param rate: + :return: + """ + u = np.random.rand() + x = - np.log(1 - u) * rate + return x + + numbers = [] + for i in range(100): + rate1 = 2 + rate2 = 100 + print((_exp_distirbution_synced(rate1), _exp_distirbution_synced(rate2))) + print(numbers) -- GitLab