diff --git a/.gitignore b/.gitignore index 82c633d27c6d9c3b332b27e0e73a7eb05385934e..9cdbd1b04e3227442af0509adf7ced42411b266e 100644 --- a/.gitignore +++ b/.gitignore @@ -111,3 +111,5 @@ ENV/ images/test/ test_save.dat + +.visualizations \ No newline at end of file diff --git a/flatland/evaluators/aicrowd_helpers.py b/flatland/evaluators/aicrowd_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..2a4458c754b1384da0b0ad61ba034e9d68a2429b --- /dev/null +++ b/flatland/evaluators/aicrowd_helpers.py @@ -0,0 +1,98 @@ +import os +import boto3 +import uuid +import subprocess + +############################################################### +# Expected Env Variables +############################################################### +# Default Values to be provided : +# S3_BUCKET : aicrowd-production +# S3_UPLOAD_PATH_TEMPLATE : misc/flatland-rl-Media/{}.mp4 +############################################################### +AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID", False) +AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY", False) +S3_BUCKET = os.getenv("S3_BUCKET", False) +S3_UPLOAD_PATH_TEMPLATE = os.getenv("S3_UPLOAD_PATH_TEMPLATE", False) + + +def get_boto_client(): + if not AWS_ACCESS_KEY_ID or not AWS_SECRET_ACCESS_KEY: + raise Exception("AWS Credentials not provided..") + return boto3.client( + 's3', + aws_access_key_id=AWS_ACCESS_KEY_ID, + aws_secret_access_key=AWS_SECRET_ACCESS_KEY + ) + + +def upload_to_s3(localpath): + s3 = get_boto_client() + if not S3_UPLOAD_PATH_TEMPLATE: + raise Exception("S3_UPLOAD_PATH_TEMPLATE not provided...") + if not S3_BUCKET: + raise Exception("S3_BUCKET not provided...") + + image_target_key = S3_UPLOAD_PATH_TEMPLATE.format(str(uuid.uuid4())) + s3.put_object( + ACL="public-read", + Bucket=S3_BUCKET, + Key=image_target_key, + Body=open(localpath, 'rb') + ) + return image_target_key + + +def make_subprocess_call(command, shell=False): + result = subprocess.run( + command.split(), + shell=shell, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ) + stdout = result.stdout.decode('utf-8') + stderr = result.stderr.decode('utf-8') + return result.returncode, stdout, stderr + + +def generate_movie_from_screenshot(frames_folder): + """ + Expects the frames in the frames_folder folder + and then use ffmpeg to generate the video + which writes the output to the frames_folder + """ + # Generate Thumbnail Video + print("Generating Thumbnail...") + frames_path = os.path.join(frames_folder, "flatland_frame_%04d.png") + thumb_output_path = os.path.join(frames_folder, "out_thumb.mp4") + return_code, output, output_err = make_subprocess_call( + "/usr/bin/ffmpeg -r 25 -start_number 0 -i " + + frames_path + + " -c:v libx264 -vf fps=25 -pix_fmt yuv420p -s 320x240 " + + thumb_output_path + ) + if return_code != 0: + raise Exception(output_err) + + # Generate Normal Sized Video + print("Generating Normal Video...") + frames_path = os.path.join(frames_folder, "flatland_frame_%04d.png") + output_path = os.path.join(frames_folder, "out.mp4") + return_code, output, output_err = make_subprocess_call( + "/usr/bin/ffmpeg -r 25 -start_number 0 -i " + + frames_path + + " -c:v libx264 -vf fps=25 -pix_fmt yuv420p -s 320x240 " + + output_path + ) + if return_code != 0: + raise Exception(output_err) + + return output_path, thumb_output_path + + +def is_grading(): + return os.getenv("CROWDAI_IS_GRADING", False) or \ + os.getenv("AICROWD_IS_GRADING", False) + + + diff --git a/flatland/evaluators/service.py b/flatland/evaluators/service.py index 3c1fb2e9013896990204e291e239c4de4becc21c..281b8ef689b853c70abfd3bef3e1f7b23b7b35b8 100644 --- a/flatland/evaluators/service.py +++ b/flatland/evaluators/service.py @@ -5,10 +5,13 @@ from flatland.envs.generators import rail_from_file from flatland.envs.rail_env import RailEnv from flatland.core.env_observation_builder import DummyObservationBuilder from flatland.evaluators import messages +from flatland.evaluators import aicrowd_helpers +from flatland.utils.rendertools import RenderTool import numpy as np import msgpack import msgpack_numpy as m import os +import shutil import timeout_decorator import time import traceback @@ -77,7 +80,7 @@ class FlatlandRemoteEvaluationService: # RailEnv specific variables self.env = False - self.env_available = False + self.env_renderer = False self.reward = 0 self.simulation_count = 0 self.simulation_rewards = [] @@ -87,6 +90,16 @@ class FlatlandRemoteEvaluationService: self.begin_simulation = False self.current_step = 0 self.visualize = visualize + self.vizualization_folder_name = "./.visualizations" + self.record_frame_step = 0 + + if self.visualize: + try: + shutil.rmtree(self.vizualization_folder_name) + except Exception as e: + print(e) + + os.mkdir(self.vizualization_folder_name) def get_env_filepaths(self): """ @@ -248,12 +261,15 @@ class FlatlandRemoteEvaluationService: rail_generator=rail_from_file(test_env_file_path), obs_builder_object=DummyObservationBuilder() ) + if self.visualize: + if self.env_renderer: + del self.env_renderer + self.env_renderer = RenderTool(self.env, gl="PILSVG", ) # Set max episode steps allowed self.env._max_episode_steps = \ int(1.5 * (self.env.width + self.env.height)) - self.env_available = True self.simulation_count += 1 if self.begin_simulation: @@ -320,6 +336,16 @@ class FlatlandRemoteEvaluationService: complete += 1 percentage_complete = complete * 1.0 / self.env.get_num_agents() self.simulation_percentage_complete[-1] = percentage_complete + + # Record Frame + if self.visualize: + self.env_renderer.render_env(show=True, show_observations=False, show_predictions=False) + self.env_renderer.gl.save_image( + os.path.join( + self.vizualization_folder_name, + "flatland_frame_{:04d}.png".format(self.record_frame_step) + )) + self.record_frame_step += 1 # Build and send response _command_response = {} @@ -460,7 +486,8 @@ if __name__ == "__main__": grader = FlatlandRemoteEvaluationService( test_env_folder=test_folder, flatland_rl_service_id=args.service_id, - verbose=True + verbose=True, + visualize=True ) result = grader.run() if result['type'] == messages.FLATLAND_RL.ENV_SUBMIT_RESPONSE: diff --git a/requirements_dev.txt b/requirements_dev.txt index 1f6f86a9fcd4aad08b640866f4c67f1975fa3551..3beb1cfb6ebf7ef3606d512c3ff043f98bf4967d 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -4,6 +4,7 @@ twine>=1.12.1 pytest>=3.8.2 pytest-runner>=4.2 crowdai-api>=0.1.21 +boto3>=1.9.194 numpy>=1.16.2 recordtype>=1.3 xarray>=0.11.3