diff --git a/flatland/evaluators/service.py b/flatland/evaluators/service.py index bbda700abd68f21dbada56d020937ab69b2d3f96..5e1e907612e480bfd967e8d2728c535c23ad541d 100644 --- a/flatland/evaluators/service.py +++ b/flatland/evaluators/service.py @@ -98,8 +98,8 @@ class FlatlandRemoteEvaluationService: video_generation_envs=[], report=None, verbose=False, - actionDir=None, - episodeDir=None, + actionDir=None, + episodeDir=None, mergeDir=None, use_pickle=False, shuffle=True, @@ -184,6 +184,7 @@ class FlatlandRemoteEvaluationService: self.env = False self.env_renderer = False self.reward = 0 + self.simulation_done = True self.simulation_count = -1 self.simulation_env_file_paths = [] self.simulation_rewards = [] @@ -286,7 +287,7 @@ class FlatlandRemoteEvaluationService: # if requested, only generate actions for those envs which don't already have them if self.mergeDir and self.missing_only: existing_paths = (itertools.chain.from_iterable( - [ glob.glob(os.path.join(self.mergeDir, f"envs/*.{ext}")) + [ glob.glob(os.path.join(self.mergeDir, f"envs/*.{ext}")) for ext in ["pkl", "mpk"] ])) existing_paths = [os.path.relpath(sPath, self.mergeDir) for sPath in existing_paths] env_paths = sorted(set(env_paths) - set(existing_paths)) @@ -474,7 +475,7 @@ class FlatlandRemoteEvaluationService: # self.current_step, # self.simulation_count # )) - + # print("Timeout of {}s in step {} of simulation {}".format( # COMMAND_TIMEOUT, # self.current_step, @@ -516,7 +517,7 @@ class FlatlandRemoteEvaluationService: _redis.rpush(command_response_channel, sResponse) def send_error(self, error_dict, suppress_logs=False): - """ For out-of-band errors like timeouts, + """ For out-of-band errors like timeouts, where we do not have a command, so we have no response channel! """ _redis = self.get_redis_connection() @@ -565,7 +566,14 @@ class FlatlandRemoteEvaluationService: Handles a ENV_CREATE command from the client TODO: Add a high level summary of everything thats happening here. """ + if not self.simulation_done: + # trying to reset a simulation before finishing the previous one + _command_response = self._error_template("CAN'T CREATE NEW ENV BEFORE PREVIOUS IS DONE") + self.send_response(_command_response, command) + raise Exception(_command_response['payload']) + self.simulation_count += 1 + self.simulation_done = False # reset the timeout flag / state. self.state_env_timed_out = False @@ -718,6 +726,8 @@ class FlatlandRemoteEvaluationService: if self.actionDir is not None: self.lActions.append(action) if done["__all__"]: + self.simulation_done = True + # Compute percentage complete complete = 0 for i_agent in range(self.env.get_num_agents()): @@ -727,9 +737,15 @@ class FlatlandRemoteEvaluationService: percentage_complete = complete * 1.0 / self.env.get_num_agents() self.simulation_percentage_complete[-1] = percentage_complete + print("Evaluation finished in {} timesteps. Percentage agents done: {:.3f}. Normalized reward: {:.3f}.".format( + self.simulation_steps[-1], + self.simulation_percentage_complete[-1], + self.simulation_rewards_normalized[-1] + )) + if self.actionDir is not None: self.save_actions() - + if self.episodeDir is not None: self.save_episode() @@ -759,26 +775,26 @@ class FlatlandRemoteEvaluationService: def save_actions(self): sfEnv = self.env_file_paths[self.simulation_count] - + sfActions = self.actionDir + "/" + sfEnv.replace(".pkl", ".json") print("env path: ", sfEnv, " sfActions:", sfActions) if not os.path.exists(os.path.dirname(sfActions)): os.makedirs(os.path.dirname(sfActions)) - + with open(sfActions, "w") as fOut: json.dump(self.lActions, fOut) self.lActions = [] - + def save_episode(self): sfEnv = self.env_file_paths[self.simulation_count] sfEpisode = self.episodeDir + "/" + sfEnv print("env path: ", sfEnv, " sfEpisode:", sfEpisode) RailEnvPersister.save_episode(self.env, sfEpisode) #self.env.save_episode(sfEpisode) - + def save_merged_env(self): sfEnv = self.env_file_paths[self.simulation_count] sfMergeEnv = self.mergeDir + "/" + sfEnv @@ -984,7 +1000,7 @@ class FlatlandRemoteEvaluationService: except timeout_decorator.timeout_decorator.TimeoutError: if self.previous_command['type'] == messages.FLATLAND_RL.ENV_STEP: self.send_error({"type":messages.FLATLAND_RL.ENV_STEP_TIMEOUT}) - + elif self.previous_command['type'] == messages.FLATLAND_RL.ENV_CREATE: self.send_error({"type":messages.FLATLAND_RL.ENV_RESET_TIMEOUT}) @@ -1044,7 +1060,7 @@ class FlatlandRemoteEvaluationService: print("client env_step timeout") self.handle_env_step_timeout(command) - + else: _error = self._error_template( "UNKNOWN_REQUEST:{}".format( @@ -1095,13 +1111,13 @@ if __name__ == "__main__": default=None, help="deprecated - use mergeDir. Folder containing the files for the test envs", required=False) - + parser.add_argument('--episodeDir', dest='episodeDir', default=None, help="deprecated - use mergeDir. Folder containing the files for the test envs", required=False) - + parser.add_argument('--mergeDir', dest='mergeDir', default=None,