diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index bf438a565bc177b347e5f0d696f2c62fcccde098..e45bc0ae6fba6586c3bf35dcd7b48c0c4b7c34ee 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -162,6 +162,7 @@ class RailEnv(Environment): self.rail: Optional[GridTransitionMap] = None self.width = width self.height = height + self._seed() self.remove_agents_at_target = remove_agents_at_target @@ -187,8 +188,6 @@ class RailEnv(Environment): self.action_space = [1] - self._seed() - # Stochastic train malfunctioning parameters if stochastic_data is not None: prop_malfunction = stochastic_data['prop_malfunction'] diff --git a/flatland/evaluators/client.py b/flatland/evaluators/client.py index 92f1a49d777552b2108aff3963aa5c1bc84fdbfc..d9c9ae9915da6521c39733cce649e2d64f41f1e6 100644 --- a/flatland/evaluators/client.py +++ b/flatland/evaluators/client.py @@ -166,6 +166,8 @@ class FlatlandRemoteClient(object): _request['payload'] = {} _response = self._blocking_request(_request) observation = _response['payload']['observation'] + info = _response['payload']['info'] + random_seed = _response['payload']['random_seed'] if not observation: # If the observation is False, @@ -196,10 +198,18 @@ class FlatlandRemoteClient(object): self.env._max_episode_steps = \ int(1.5 * (self.env.width + self.env.height)) - local_observation = self.env.reset() + local_observation = self.env.reset(random_seed=random_seed) + + local_observation, info = self.env.reset( + regen_rail=False, + replace_agents=False, + activate_agents=False, + random_seed=random_seed + ) + # Use the local observation # as the remote server uses a dummy observation builder - return local_observation + return local_observation, info def env_step(self, action, render=False): """ diff --git a/flatland/evaluators/service.py b/flatland/evaluators/service.py index 4f273be466a1c2f95b55cafece4783bad91e0d2c..023730dce5ddce411b7aa9951e9e8aa8f2f04f8a 100644 --- a/flatland/evaluators/service.py +++ 
b/flatland/evaluators/service.py @@ -332,14 +332,22 @@ class FlatlandRemoteEvaluationService: self.simulation_steps.append(0) self.current_step = 0 - - _observation = self.env.reset() + + RANDOM_SEED = 1001 + _observation, _info = self.env.reset( + regen_rail=False, + replace_agents=False, + activate_agents=False, + random_seed=RANDOM_SEED + ) _command_response = {} _command_response['type'] = messages.FLATLAND_RL.ENV_CREATE_RESPONSE _command_response['payload'] = {} _command_response['payload']['observation'] = _observation _command_response['payload']['env_file_path'] = self.env_file_paths[self.simulation_count] + _command_response['payload']['info'] = _info + _command_response['payload']['random_seed'] = RANDOM_SEED else: """ All test env evaluations are complete @@ -349,6 +357,8 @@ class FlatlandRemoteEvaluationService: _command_response['payload'] = {} _command_response['payload']['observation'] = False _command_response['payload']['env_file_path'] = False + _command_response['payload']['info'] = False + _command_response['payload']['random_seed'] = False self.send_response(_command_response, command) ##################################################################### diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index d5fce5d6c8fa4cf37b4c72f6264726f6d502f77d..e87bd93d546ab2b1436f49b15a803b5dac16d0b1 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -70,8 +70,8 @@ def test_malfunction_process(): 'malfunction_rate': 1000, 'min_duration': 3, 'max_duration': 3} - random.seed(0) - np.random.seed(0) + # random.seed(0) + # np.random.seed(0) stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents 'malfunction_rate': 70, # Rate of malfunction occurence @@ -90,9 +90,7 @@ def test_malfunction_process(): obs_builder_object=SingleAgentNavigationObs() ) # reset to initialize agents_static - env.reset() - - obs = env.reset(False, False, True) + obs = env.reset(False, 
False, True, random_seed=0) # Check that a initial duration for malfunction was assigned assert env.agents[0].malfunction_data['next_malfunction'] > 0 @@ -147,10 +145,6 @@ def test_malfunction_process_statistically(): 'min_duration': 3, 'max_duration': 3} - random.seed(0) - np.random.seed(0) - - rail, rail_map = make_simple_rail2() env = RailEnv(width=25, @@ -162,7 +156,7 @@ def test_malfunction_process_statistically(): obs_builder_object=SingleAgentNavigationObs() ) # reset to initialize agents_static - env.reset(False, False, False) + env.reset(False, False, False, random_seed=0) env.agents[0].target = (0, 0) nb_malfunction = 0 for step in range(20): @@ -185,9 +179,6 @@ def test_malfunction_before_entry(): 'min_duration': 10, 'max_duration': 10} - random.seed(0) - np.random.seed(0) - rail, rail_map = make_simple_rail2() env = RailEnv(width=25, @@ -199,7 +190,7 @@ def test_malfunction_before_entry(): obs_builder_object=SingleAgentNavigationObs() ) # reset to initialize agents_static - env.reset(False, False, False) + env.reset(False, False, False, random_seed=0) env.agents[0].target = (0, 0) nb_malfunction = 0 for step in range(20): @@ -219,9 +210,6 @@ def test_malfunction_before_entry(): def test_initial_malfunction(): - random.seed(0) - np.random.seed(0) - stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents 'malfunction_rate': 70, # Rate of malfunction occurence 'min_duration': 2, # Minimal duration of malfunction @@ -240,7 +228,7 @@ def test_initial_malfunction(): ) # reset to initialize agents_static - env.reset(False, False, True) + env.reset(False, False, True, random_seed=0) set_penalties_for_replay(env) replay_config = ReplayConfig( @@ -294,9 +282,6 @@ def test_initial_malfunction(): def test_initial_malfunction_stop_moving(): - random.seed(0) - np.random.seed(0) - stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents 'malfunction_rate': 70, # Rate of malfunction occurence 'min_duration': 2, # Minimal duration 
of malfunction