From a3d3baa16e8c5ba82c3873f1a317ab23ee898f24 Mon Sep 17 00:00:00 2001
From: SP Mohanty <spmohanty91@gmail.com>
Date: Thu, 25 Jul 2019 04:07:26 +0200
Subject: [PATCH] Addresses #124 - Implements a flatland-evaluator script to
 easily run tests

---
 flatland/cli.py                | 32 ++++++++++++++++++++++++++++++++
 flatland/evaluators/client.py  |  7 +++++++
 flatland/evaluators/service.py | 14 +++++++++-----
 setup.py                       |  7 ++++---
 4 files changed, 52 insertions(+), 8 deletions(-)

diff --git a/flatland/cli.py b/flatland/cli.py
index 46bf5aec..32e8d9dc 100644
--- a/flatland/cli.py
+++ b/flatland/cli.py
@@ -8,6 +8,8 @@ import time
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
+from flatland.evaluators.service import FlatlandRemoteEvaluationService
+import redis


 @click.command()
@@ -49,5 +51,35 @@ def demo(args=None):
     return 0


+@click.command()
+@click.option('--tests',
+              type=click.Path(exists=True),
+              help="Path to folder containing Flatland tests",
+              required=True
+              )
+@click.option('--service_id',
+              default="FLATLAND_RL_SERVICE_ID",
+              help="Evaluation Service ID. This has to match the service id on the client.",
+              required=False
+              )
+def evaluator(tests, service_id):
+    try:
+        redis_connection = redis.Redis()
+        redis_connection.ping()
+    except redis.exceptions.ConnectionError:
+        raise Exception(
+            "\nRedis server does not seem to be running on localhost.\n"
+            "Please start a local redis server before running the evaluator."
+        )
+
+    grader = FlatlandRemoteEvaluationService(
+        test_env_folder=tests,
+        flatland_rl_service_id=service_id,
+        visualize=False,
+        verbose=False
+    )
+    grader.run()
+
+
 if __name__ == "__main__":
     sys.exit(demo())  # pragma: no cover
diff --git a/flatland/evaluators/client.py b/flatland/evaluators/client.py
index cc31caaf..a4968c0c 100644
--- a/flatland/evaluators/client.py
+++ b/flatland/evaluators/client.py
@@ -175,6 +175,13 @@ class FlatlandRemoteClient(object):
             self.test_envs_root,
             test_env_file_path
         )
+        if not os.path.exists(test_env_file_path):
+            raise Exception(
+                "\nThe env file could not be found at the expected location.\n"
+                "Did you remember to set the AICROWD_TESTS_FOLDER environment variable "
+                "to point to the location of the Tests folder?\n"
\n" + "We are currently looking at `{}` for the tests".format(self.test_envs_root) + ) print("Current env path : ", test_env_file_path) self.env = RailEnv( width=1, diff --git a/flatland/evaluators/service.py b/flatland/evaluators/service.py index 56fbb7a2..201a3978 100644 --- a/flatland/evaluators/service.py +++ b/flatland/evaluators/service.py @@ -62,7 +62,6 @@ class FlatlandRemoteEvaluationService: self.test_env_folder = test_env_folder self.video_generation_envs = video_generation_envs self.env_file_paths = self.get_env_filepaths() - print(self.env_file_paths) # Logging and Reporting related vars self.verbose = verbose @@ -277,12 +276,11 @@ class FlatlandRemoteEvaluationService: """ test_env_file_path = self.env_file_paths[self.simulation_count] - print("__ Env Path : ", test_env_file_path) + print("Evaluating : {}".format(test_env_file_path)) test_env_file_path = os.path.join( self.test_env_folder, test_env_file_path ) - print("__ Processed Env Path : ", test_env_file_path) del self.env self.env = RailEnv( width=1, @@ -475,6 +473,13 @@ class FlatlandRemoteEvaluationService: self.evaluation_state["score"]["score"] = mean_percentage_complete self.evaluation_state["score"]["score_secondary"] = mean_reward self.handle_aicrowd_success_event(self.evaluation_state) + print("#"*100) + print("EVALUATION COMPLETE !!") + print("#"*100) + print("# Mean Reward : {}".format(mean_reward)) + print("# Mean Percentage Complete : {}".format(mean_percentage_complete)) + print("#"*100) + print("#"*100) def report_error(self, error_message, command_response_channel): """ @@ -518,7 +523,7 @@ class FlatlandRemoteEvaluationService: Main runner function which waits for commands from the client and acts accordingly. """ - print("Listening for commands at : ", self.command_channel) + print("Listening at : ", self.command_channel) while True: command = self.get_next_command() @@ -604,7 +609,6 @@ if __name__ == "__main__": result = grader.run() if result['type'] == messages.FLATLAND_RL.ENV_SUBMIT_RESPONSE: cumulative_results = result['payload'] - print("Results : ", cumulative_results) elif result['type'] == messages.FLATLAND_RL.ERROR: error = result['payload'] raise Exception("Evaluation Failed : {}".format(str(error))) diff --git a/setup.py b/setup.py index 60fd1209..131cc983 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ def get_all_svg_files(directory='./svg/'): ret = [] for dirpath, subdirs, files in os.walk(directory): for f in files: - ret.append(os.path.join(dirpath,f)) + ret.append(os.path.join(dirpath, f)) return ret @@ -24,7 +24,7 @@ def get_all_images_files(directory='./images/'): ret = [] for dirpath, subdirs, files in os.walk(directory): for f in files: - ret.append(os.path.join(dirpath,f)) + ret.append(os.path.join(dirpath, f)) return ret @@ -32,7 +32,7 @@ def get_all_notebook_files(directory='./notebooks/'): ret = [] for dirpath, subdirs, files in os.walk(directory): for f in files: - ret.append(os.path.join(dirpath,f)) + ret.append(os.path.join(dirpath, f)) return ret @@ -63,6 +63,7 @@ setup( entry_points={ 'console_scripts': [ 'flatland-demo=flatland.cli:demo', + 'flatland-evaluator=flatland.cli:evaluator' ], }, install_requires=requirements, -- GitLab