diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index aa8e48023694c96941b59d6f78b3fd93edf81e9e..eb6108da77580313d7e9c517de6fa45e8b12a617 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -188,7 +188,7 @@ class RailEnv(Environment): self.distance_map = DistanceMap(self.agents, self.height, self.width) self.action_space = [1] - + self._seed() self._seed() @@ -361,7 +361,9 @@ class RailEnv(Environment): # Next malfunction in number of stops next_breakdown = int( self.np_random.exponential(scale=agent.malfunction_data['malfunction_rate'])) - agent.malfunction_data['next_malfunction'] = next_breakdown + next_breakdown = self.np_random.randint(self.min_number_of_steps_broken, + self.max_number_of_steps_broken + 1) + 1 + agent.malfunction_data['next_malfunction'] = 5 # next_breakdown # Duration of current malfunction num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken, @@ -754,3 +756,14 @@ class RailEnv(Environment): from importlib_resources import read_binary load_data = read_binary(package, resource) self.set_full_state_msg(load_data) + + def _exp_distirbution_synced(self, rate): + """ + Generates sample from exponential distribution + We need this to guarantee synchronity between different instances with same seed. + :param rate: + :return: + """ + u = self.np_random.rand() + x = - np.log(1 - u) * rate + return x diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 73c831a426d55abf45217d66cba57481f6cc63ea..2f735f219031ed325e37ff46ce9969abf27f5936 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -187,15 +187,9 @@ def test_malfunction_before_entry(): # reset to initialize agents_static env.reset(False, False, False, random_seed=10) env.agents[0].target = (0, 0) - assert env.agents[1].malfunction_data['malfunction'] == 11 - assert env.agents[2].malfunction_data['malfunction'] == 11 - assert env.agents[3].malfunction_data['malfunction'] == 11 - assert env.agents[4].malfunction_data['malfunction'] == 11 - assert env.agents[5].malfunction_data['malfunction'] == 0 - assert env.agents[6].malfunction_data['malfunction'] == 11 - assert env.agents[7].malfunction_data['malfunction'] == 11 - assert env.agents[8].malfunction_data['malfunction'] == 11 - assert env.agents[9].malfunction_data['malfunction'] == 0 + for a in range(env.get_num_agents()): + print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a, env.agents[a].malfunction_data[ + 'malfunction'])) for step in range(20): action_dict: Dict[int, RailEnvActions] = {} diff --git a/tests/test_random_seeding.py b/tests/test_random_seeding.py index fa4c6b975f792072815dba83a616ae89185731d6..389dabe7afcee52e1a716769836b914e80c080ce 100644 --- a/tests/test_random_seeding.py +++ b/tests/test_random_seeding.py @@ -61,8 +61,8 @@ def test_seeding_and_observations(): obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) ) - env.reset(False, False, False) - env2.reset(False, False, False) + env.reset(False, False, False, random_seed=12) + env2.reset(False, False, False, random_seed=12) # Check that both environments produce the same initial start positions assert env.agents[0].initial_position == env2.agents[0].initial_position @@ -129,8 +129,8 @@ def test_seeding_and_malfunction(): stochastic_data=stochastic_data, # Malfunction data generator ) - env.reset(False, False, False) - env2.reset(False, False, False) + env.reset(False, False, False, random_seed=12) + env2.reset(False, False, False, random_seed=12) # Check that both environments produce the same initial start positions assert env.agents[0].initial_position == env2.agents[0].initial_position @@ -149,9 +149,11 @@ def test_seeding_and_malfunction(): for a in range(env.get_num_agents()): action = np.random.randint(4) action_dict[a] = action + print(env.agents[a].malfunction_data['malfunction'], env2.agents[a].malfunction_data['malfunction']) env.step(action_dict) env2.step(action_dict) + # Check that both environments end up in the same position assert env.agents[0].position == env2.agents[0].position