diff --git a/flatland/evaluators/service.py b/flatland/evaluators/service.py index 235ab78b21be96635922927a9b4e607b122cc09f..5ae0580dbd10f66a70dced7a21f0c47cd4e206df 100644 --- a/flatland/evaluators/service.py +++ b/flatland/evaluators/service.py @@ -53,12 +53,16 @@ class FlatlandRemoteEvaluationService: remote_db=0, remote_password=None, visualize=False, + video_generation_envs=[], report=None, verbose=False): # Test Env folder Paths self.test_env_folder = test_env_folder + self.video_generation_envs = video_generation_envs + self.video_generation_indices = [] self.env_file_paths = self.get_env_filepaths() + print(self.video_generation_indices) # Logging and Reporting related vars self.verbose = verbose @@ -141,6 +145,21 @@ class FlatlandRemoteEvaluationService: env_paths.append( os.path.join(root, file) ) + env_paths = sorted(env_paths) + print(self.video_generation_envs) + for _idx, env_path in enumerate(env_paths): + """ + Here we collect the indices of the environments for which + we need to generate the videos + + We increment the simulation count on env_create + so the 1st simulation has an index of 1, when comparing in + env_step + """ + for vg_env in self.video_generation_envs: + print(vg_env, env_path) + if vg_env in env_path: + self.video_generation_indices.append(_idx+1) return sorted(env_paths) def instantiate_redis_connection_pool(self): @@ -367,13 +386,18 @@ class FlatlandRemoteEvaluationService: # Record Frame if self.visualize: - self.env_renderer.render_env(show=True, show_observations=False, show_predictions=False) - self.env_renderer.gl.save_image( - os.path.join( - self.vizualization_folder_name, - "flatland_frame_{:04d}.png".format(self.record_frame_step) - )) - self.record_frame_step += 1 + self.env_renderer.render_env(show=False, show_observations=False, show_predictions=False) + """ + Only save the frames for environments which are separately provided + in video_generation_indices param + """ + if self.simulation_count in self.video_generation_indices: + self.env_renderer.gl.save_image( + os.path.join( + self.vizualization_folder_name, + "flatland_frame_{:04d}.png".format(self.record_frame_step) + )) + self.record_frame_step += 1 # Build and send response _command_response = {} @@ -575,7 +599,8 @@ if __name__ == "__main__": test_env_folder=test_folder, flatland_rl_service_id=args.service_id, verbose=True, - visualize=True + visualize=True, + video_generation_envs=["Test_0/Level_1.pkl"] ) result = grader.run() if result['type'] == messages.FLATLAND_RL.ENV_SUBMIT_RESPONSE: