diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index eb6108da77580313d7e9c517de6fa45e8b12a617..7e4697dd6743f2d23fe2dcaf7abe1ddebb2020f0 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -360,10 +360,8 @@ class RailEnv(Environment): # Next malfunction in number of stops next_breakdown = int( - self.np_random.exponential(scale=agent.malfunction_data['malfunction_rate'])) - next_breakdown = self.np_random.randint(self.min_number_of_steps_broken, - self.max_number_of_steps_broken + 1) + 1 - agent.malfunction_data['next_malfunction'] = 5 # next_breakdown + self._exp_distirbution_synced(rate=agent.malfunction_data['malfunction_rate'])) + agent.malfunction_data['next_malfunction'] = next_breakdown # Duration of current malfunction num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken, diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 2f735f219031ed325e37ff46ce9969abf27f5936..fe41c2b1461fcb86554ee7b1cefbd34ce408e8dd 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -170,7 +170,7 @@ def test_malfunction_before_entry(): """Tests that malfunctions are produced by stochastic_data!""" # Set fixed malfunction duration for this test stochastic_data = {'prop_malfunction': 1., - 'malfunction_rate': 2, + 'malfunction_rate': 1, 'min_duration': 10, 'max_duration': 10} @@ -187,9 +187,17 @@ def test_malfunction_before_entry(): # reset to initialize agents_static env.reset(False, False, False, random_seed=10) env.agents[0].target = (0, 0) - for a in range(env.get_num_agents()): - print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a, env.agents[a].malfunction_data[ - 'malfunction'])) + + assert env.agents[1].malfunction_data['malfunction'] == 11 + assert env.agents[2].malfunction_data['malfunction'] == 11 + assert env.agents[3].malfunction_data['malfunction'] == 11 + assert env.agents[4].malfunction_data['malfunction'] 
== 11 + assert env.agents[5].malfunction_data['malfunction'] == 11 + assert env.agents[6].malfunction_data['malfunction'] == 11 + assert env.agents[7].malfunction_data['malfunction'] == 11 + assert env.agents[8].malfunction_data['malfunction'] == 11 + assert env.agents[9].malfunction_data['malfunction'] == 11 + for step in range(20): action_dict: Dict[int, RailEnvActions] = {} @@ -200,16 +208,16 @@ def test_malfunction_before_entry(): action_dict[agent.handle] = RailEnvActions(0) env.step(action_dict) + assert env.agents[1].malfunction_data['malfunction'] == 1 assert env.agents[2].malfunction_data['malfunction'] == 1 assert env.agents[3].malfunction_data['malfunction'] == 1 assert env.agents[4].malfunction_data['malfunction'] == 1 - assert env.agents[5].malfunction_data['malfunction'] == 2 + assert env.agents[5].malfunction_data['malfunction'] == 1 assert env.agents[6].malfunction_data['malfunction'] == 1 assert env.agents[7].malfunction_data['malfunction'] == 1 assert env.agents[8].malfunction_data['malfunction'] == 1 - assert env.agents[9].malfunction_data['malfunction'] == 3 - + assert env.agents[9].malfunction_data['malfunction'] == 1 # Print for test generation # for a in range(env.get_num_agents()): # print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a, @@ -220,7 +228,7 @@ def test_malfunction_before_entry(): def test_initial_malfunction(): stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents - 'malfunction_rate': 70, # Rate of malfunction occurence + 'malfunction_rate': 100, # Rate of malfunction occurence 'min_duration': 2, # Minimal duration of malfunction 'max_duration': 5 # Max duration of malfunction } @@ -230,7 +238,7 @@ def test_initial_malfunction(): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), + schedule_generator=random_schedule_generator(seed=10), number_of_agents=1, stochastic_data=stochastic_data, # 
Malfunction data generator obs_builder_object=SingleAgentNavigationObs() @@ -238,6 +246,7 @@ def test_initial_malfunction(): # reset to initialize agents_static env.reset(False, False, True, random_seed=10) + print(env.agents[0].malfunction_data) env.agents[0].target = (0, 5) set_penalties_for_replay(env) replay_config = ReplayConfig( diff --git a/tests/test_random_seeding.py b/tests/test_random_seeding.py index 389dabe7afcee52e1a716769836b914e80c080ce..b374b21dc8a17e950b48aeae85b570a0067d3f6a 100644 --- a/tests/test_random_seeding.py +++ b/tests/test_random_seeding.py @@ -168,3 +168,34 @@ def test_seeding_and_malfunction(): assert env.agents[9].position == env2.agents[9].position for a in range(env.get_num_agents()): print("assert env.agents[{}].position == env2.agents[{}].position".format(a, a))
+
+
+def tests_new_distributio():
+    """Check that the inverse-transform exponential sampler is reproducible.
+
+    Two generators seeded identically must yield non-negative samples that
+    differ only by the ratio of the rates, and must stay in step afterwards.
+    """
+
+    def _exp_distirbution_synced(rate, rng):
+        """Return ``-log(1 - u) * rate`` for one uniform draw ``u`` from ``rng``.
+
+        Named after the ``RailEnv`` method of the same (misspelled) name that
+        this patch calls; kept local so the test has no environment dependency.
+        """
+        u = rng.rand()
+        return -np.log(1 - u) * rate
+
+    rate1 = 2
+    rate2 = 100
+    rng_a = np.random.RandomState(10)
+    rng_b = np.random.RandomState(10)
+    for _ in range(100):
+        sample_a = _exp_distirbution_synced(rate1, rng_a)
+        sample_b = _exp_distirbution_synced(rate2, rng_b)
+        # Both streams consumed one uniform each, so they stay aligned and the
+        # samples differ only by the ratio of the rates.
+        assert sample_a >= 0
+        assert np.isclose(sample_b, sample_a * rate2 / rate1)
+    # Same seed, same number of draws: the generators must still be in step.
+    assert rng_a.rand() == rng_b.rand()