From 19654d0ea29cbb507c09e4ac11ede8fb4072a2c7 Mon Sep 17 00:00:00 2001 From: "S.P. Mohanty" <spmohanty91@gmail.com> Date: Tue, 8 Oct 2019 13:22:57 +0200 Subject: [PATCH] Instantiate the seeds early on in init --- flatland/envs/rail_env.py | 3 +-- flatland/evaluators/client.py | 14 ++++++++++++-- flatland/evaluators/service.py | 14 ++++++++++++-- tests/test_flatland_malfunction.py | 27 ++++++--------------------- 4 files changed, 31 insertions(+), 27 deletions(-) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index bf438a56..e45bc0ae 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -162,6 +162,7 @@ class RailEnv(Environment): self.rail: Optional[GridTransitionMap] = None self.width = width self.height = height + self._seed() self.remove_agents_at_target = remove_agents_at_target @@ -187,8 +188,6 @@ class RailEnv(Environment): self.action_space = [1] - self._seed() - # Stochastic train malfunctioning parameters if stochastic_data is not None: prop_malfunction = stochastic_data['prop_malfunction'] diff --git a/flatland/evaluators/client.py b/flatland/evaluators/client.py index 92f1a49d..d9c9ae99 100644 --- a/flatland/evaluators/client.py +++ b/flatland/evaluators/client.py @@ -166,6 +166,8 @@ class FlatlandRemoteClient(object): _request['payload'] = {} _response = self._blocking_request(_request) observation = _response['payload']['observation'] + info = _response['payload']['info'] + random_seed = _response['payload']['random_seed'] if not observation: # If the observation is False, @@ -196,10 +198,18 @@ class FlatlandRemoteClient(object): self.env._max_episode_steps = \ int(1.5 * (self.env.width + self.env.height)) - local_observation = self.env.reset() + local_observation = self.env.reset(random_seed=random_seed) + + local_observation, info = self.env.reset( + regen_rail=False, + replace_agents=False, + activate_agents=False, + random_seed=random_seed + ) + # Use the local observation # as the remote server uses a dummy observation builder - return local_observation + return local_observation, info def env_step(self, action, render=False): """ diff --git a/flatland/evaluators/service.py b/flatland/evaluators/service.py index 4f273be4..023730dc 100644 --- a/flatland/evaluators/service.py +++ b/flatland/evaluators/service.py @@ -332,14 +332,22 @@ class FlatlandRemoteEvaluationService: self.simulation_steps.append(0) self.current_step = 0 - - _observation = self.env.reset() + + RANDOM_SEED = 1001 + _observation, _info = self.env.reset( + regen_rail=False, + replace_agents=False, + activate_agents=False, + random_seed=RANDOM_SEED + ) _command_response = {} _command_response['type'] = messages.FLATLAND_RL.ENV_CREATE_RESPONSE _command_response['payload'] = {} _command_response['payload']['observation'] = _observation _command_response['payload']['env_file_path'] = self.env_file_paths[self.simulation_count] + _command_response['payload']['info'] = _info + _command_response['payload']['random_seed'] = RANDOM_SEED else: """ All test env evaluations are complete @@ -349,6 +357,8 @@ class FlatlandRemoteEvaluationService: _command_response['payload'] = {} _command_response['payload']['observation'] = False _command_response['payload']['env_file_path'] = False + _command_response['payload']['info'] = False + _command_response['payload']['random_seed'] = RANDOM_SEED self.send_response(_command_response, command) ##################################################################### diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index d5fce5d6..e87bd93d 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -70,8 +70,8 @@ def test_malfunction_process(): 'malfunction_rate': 1000, 'min_duration': 3, 'max_duration': 3} - random.seed(0) - np.random.seed(0) + # random.seed(0) + # np.random.seed(0) stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents 'malfunction_rate': 70, # Rate of malfunction occurence @@ -90,9 +90,7 @@ def test_malfunction_process(): obs_builder_object=SingleAgentNavigationObs() ) # reset to initialize agents_static - env.reset() - - obs = env.reset(False, False, True) + obs = env.reset(False, False, True, random_seed=0) # Check that a initial duration for malfunction was assigned assert env.agents[0].malfunction_data['next_malfunction'] > 0 @@ -147,10 +145,6 @@ def test_malfunction_process_statistically(): 'min_duration': 3, 'max_duration': 3} - random.seed(0) - np.random.seed(0) - - rail, rail_map = make_simple_rail2() env = RailEnv(width=25, @@ -162,7 +156,7 @@ def test_malfunction_process_statistically(): obs_builder_object=SingleAgentNavigationObs() ) # reset to initialize agents_static - env.reset(False, False, False) + env.reset(False, False, False, random_seed=0) env.agents[0].target = (0, 0) nb_malfunction = 0 for step in range(20): @@ -185,9 +179,6 @@ def test_malfunction_before_entry(): 'min_duration': 10, 'max_duration': 10} - random.seed(0) - np.random.seed(0) - rail, rail_map = make_simple_rail2() env = RailEnv(width=25, @@ -199,7 +190,7 @@ def test_malfunction_before_entry(): obs_builder_object=SingleAgentNavigationObs() ) # reset to initialize agents_static - env.reset(False, False, False) + env.reset(False, False, False, random_seed=0) env.agents[0].target = (0, 0) nb_malfunction = 0 for step in range(20): @@ -219,9 +210,6 @@ def test_malfunction_before_entry(): def test_initial_malfunction(): - random.seed(0) - np.random.seed(0) - stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents 'malfunction_rate': 70, # Rate of malfunction occurence 'min_duration': 2, # Minimal duration of malfunction @@ -240,7 +228,7 @@ def test_initial_malfunction(): ) # reset to initialize agents_static - env.reset(False, False, True) + env.reset(False, False, True, random_seed=0) set_penalties_for_replay(env) replay_config = ReplayConfig( @@ -294,9 +282,6 @@ def test_initial_malfunction(): def test_initial_malfunction_stop_moving(): - random.seed(0) - np.random.seed(0) - stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents 'malfunction_rate': 70, # Rate of malfunction occurence 'min_duration': 2, # Minimal duration of malfunction -- GitLab