diff --git a/flatland/cli.py b/flatland/cli.py index 46bf5aecbb47f04bd00f2772897650aa4c47c12e..32e8d9dc786b0412795694fc985c90aa55fc2e91 100644 --- a/flatland/cli.py +++ b/flatland/cli.py @@ -8,6 +8,8 @@ import time from flatland.envs.generators import complex_rail_generator from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool +from flatland.evaluators.service import FlatlandRemoteEvaluationService +import redis @click.command() @@ -49,5 +51,35 @@ def demo(args=None): return 0 +@click.command() +@click.option('--tests', + type=click.Path(exists=True), + help="Path to folder containing Flatland tests", + required=True + ) +@click.option('--service_id', + default="FLATLAND_RL_SERVICE_ID", + help="Evaluation Service ID. This has to match the service id on the client.", + required=False + ) +def evaluator(tests, service_id): + try: + redis_connection = redis.Redis() + redis_connection.ping() + except redis.exceptions.ConnectionError as e: + raise Exception( + "\nRedis server does not seem to be running on your localhost.\n" + "Please ensure that you have a redis server running on your localhost" + ) + + grader = FlatlandRemoteEvaluationService( + test_env_folder=tests, + flatland_rl_service_id=service_id, + visualize=False, + verbose=False + ) + grader.run() + + if __name__ == "__main__": sys.exit(demo()) # pragma: no cover diff --git a/flatland/evaluators/client.py b/flatland/evaluators/client.py index cc31caaf1f6d2f1721fac550751b85390842cc4a..a4968c0c8e827c060e0e3f7de0cf28cc0658089b 100644 --- a/flatland/evaluators/client.py +++ b/flatland/evaluators/client.py @@ -175,6 +175,13 @@ class FlatlandRemoteClient(object): self.test_envs_root, test_env_file_path ) + if not os.path.exists(test_env_file_path): + raise Exception( + "\nWe cannot seem to find the env file paths at the required location.\n" + "Did you remember to set the AICROWD_TESTS_FOLDER environment variable " + "to point to the location of the Tests folder ? \n" + "We are currently looking at `{}` for the tests".format(self.test_envs_root) + ) print("Current env path : ", test_env_file_path) self.env = RailEnv( width=1, diff --git a/flatland/evaluators/service.py b/flatland/evaluators/service.py index 56fbb7a22705f5bf0d8fc213b5c1ac05cde852d6..201a3978e5dcebb01592f97116f9cdd7d8387646 100644 --- a/flatland/evaluators/service.py +++ b/flatland/evaluators/service.py @@ -62,7 +62,6 @@ class FlatlandRemoteEvaluationService: self.test_env_folder = test_env_folder self.video_generation_envs = video_generation_envs self.env_file_paths = self.get_env_filepaths() - print(self.env_file_paths) # Logging and Reporting related vars self.verbose = verbose @@ -277,12 +276,11 @@ class FlatlandRemoteEvaluationService: """ test_env_file_path = self.env_file_paths[self.simulation_count] - print("__ Env Path : ", test_env_file_path) + print("Evaluating : {}".format(test_env_file_path)) test_env_file_path = os.path.join( self.test_env_folder, test_env_file_path ) - print("__ Processed Env Path : ", test_env_file_path) del self.env self.env = RailEnv( width=1, @@ -475,6 +473,13 @@ class FlatlandRemoteEvaluationService: self.evaluation_state["score"]["score"] = mean_percentage_complete self.evaluation_state["score"]["score_secondary"] = mean_reward self.handle_aicrowd_success_event(self.evaluation_state) + print("#"*100) + print("EVALUATION COMPLETE !!") + print("#"*100) + print("# Mean Reward : {}".format(mean_reward)) + print("# Mean Percentage Complete : {}".format(mean_percentage_complete)) + print("#"*100) + print("#"*100) def report_error(self, error_message, command_response_channel): """ @@ -518,7 +523,7 @@ class FlatlandRemoteEvaluationService: Main runner function which waits for commands from the client and acts accordingly. """ - print("Listening for commands at : ", self.command_channel) + print("Listening at : ", self.command_channel) while True: command = self.get_next_command() @@ -604,7 +609,6 @@ if __name__ == "__main__": result = grader.run() if result['type'] == messages.FLATLAND_RL.ENV_SUBMIT_RESPONSE: cumulative_results = result['payload'] - print("Results : ", cumulative_results) elif result['type'] == messages.FLATLAND_RL.ERROR: error = result['payload'] raise Exception("Evaluation Failed : {}".format(str(error))) diff --git a/setup.py b/setup.py index 60fd120984176742ecf3d82ab094ff92e6f0310c..131cc983d228abffa22f7e4415edbce0ced19668 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ def get_all_svg_files(directory='./svg/'): ret = [] for dirpath, subdirs, files in os.walk(directory): for f in files: - ret.append(os.path.join(dirpath,f)) + ret.append(os.path.join(dirpath, f)) return ret @@ -24,7 +24,7 @@ def get_all_images_files(directory='./images/'): ret = [] for dirpath, subdirs, files in os.walk(directory): for f in files: - ret.append(os.path.join(dirpath,f)) + ret.append(os.path.join(dirpath, f)) return ret @@ -32,7 +32,7 @@ def get_all_notebook_files(directory='./notebooks/'): ret = [] for dirpath, subdirs, files in os.walk(directory): for f in files: - ret.append(os.path.join(dirpath,f)) + ret.append(os.path.join(dirpath, f)) return ret @@ -63,6 +63,7 @@ setup( entry_points={ 'console_scripts': [ 'flatland-demo=flatland.cli:demo', + 'flatland-evaluator=flatland.cli:evaluator' ], }, install_requires=requirements,