Skip to content
Snippets Groups Projects
Commit a3d3baa1 authored by spmohanty's avatar spmohanty
Browse files

Addresses #124 - Implements a flatland-evaluator script to easily run tests

parent 5b5d979a
No related branches found
No related tags found
No related merge requests found
......@@ -8,6 +8,8 @@ import time
from flatland.envs.generators import complex_rail_generator
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool
from flatland.evaluators.service import FlatlandRemoteEvaluationService
import redis
@click.command()
......@@ -49,5 +51,35 @@ def demo(args=None):
return 0
@click.command()
@click.option('--tests',
type=click.Path(exists=True),
help="Path to folder containing Flatland tests",
required=True
)
@click.option('--service_id',
default="FLATLAND_RL_SERVICE_ID",
help="Evaluation Service ID. This has to match the service id on the client.",
required=False
)
def evaluator(tests, service_id):
try:
redis_connection = redis.Redis()
redis_connection.ping()
except redis.exceptions.ConnectionError as e:
raise Exception(
"\nRedis server does not seem to be running on your localhost.\n"
"Please ensure that you have a redis server running on your localhost"
)
grader = FlatlandRemoteEvaluationService(
test_env_folder=tests,
flatland_rl_service_id=service_id,
visualize=False,
verbose=False
)
grader.run()
if __name__ == "__main__":
sys.exit(demo()) # pragma: no cover
......@@ -175,6 +175,13 @@ class FlatlandRemoteClient(object):
self.test_envs_root,
test_env_file_path
)
if not os.path.exists(test_env_file_path):
raise Exception(
"\nWe cannot seem to find the env file paths at the required location.\n"
"Did you remember to set the AICROWD_TESTS_FOLDER environment variable "
"to point to the location of the Tests folder ? \n"
"We are currently looking at `{}` for the tests".format(self.test_envs_root)
)
print("Current env path : ", test_env_file_path)
self.env = RailEnv(
width=1,
......
......@@ -62,7 +62,6 @@ class FlatlandRemoteEvaluationService:
self.test_env_folder = test_env_folder
self.video_generation_envs = video_generation_envs
self.env_file_paths = self.get_env_filepaths()
print(self.env_file_paths)
# Logging and Reporting related vars
self.verbose = verbose
......@@ -277,12 +276,11 @@ class FlatlandRemoteEvaluationService:
"""
test_env_file_path = self.env_file_paths[self.simulation_count]
print("__ Env Path : ", test_env_file_path)
print("Evaluating : {}".format(test_env_file_path))
test_env_file_path = os.path.join(
self.test_env_folder,
test_env_file_path
)
print("__ Processed Env Path : ", test_env_file_path)
del self.env
self.env = RailEnv(
width=1,
......@@ -475,6 +473,13 @@ class FlatlandRemoteEvaluationService:
self.evaluation_state["score"]["score"] = mean_percentage_complete
self.evaluation_state["score"]["score_secondary"] = mean_reward
self.handle_aicrowd_success_event(self.evaluation_state)
print("#"*100)
print("EVALUATION COMPLETE !!")
print("#"*100)
print("# Mean Reward : {}".format(mean_reward))
print("# Mean Percentage Complete : {}".format(mean_percentage_complete))
print("#"*100)
print("#"*100)
def report_error(self, error_message, command_response_channel):
"""
......@@ -518,7 +523,7 @@ class FlatlandRemoteEvaluationService:
Main runner function which waits for commands from the client
and acts accordingly.
"""
print("Listening for commands at : ", self.command_channel)
print("Listening at : ", self.command_channel)
while True:
command = self.get_next_command()
......@@ -604,7 +609,6 @@ if __name__ == "__main__":
result = grader.run()
if result['type'] == messages.FLATLAND_RL.ENV_SUBMIT_RESPONSE:
cumulative_results = result['payload']
print("Results : ", cumulative_results)
elif result['type'] == messages.FLATLAND_RL.ERROR:
error = result['payload']
raise Exception("Evaluation Failed : {}".format(str(error)))
......
......@@ -16,7 +16,7 @@ def get_all_svg_files(directory='./svg/'):
ret = []
for dirpath, subdirs, files in os.walk(directory):
for f in files:
ret.append(os.path.join(dirpath,f))
ret.append(os.path.join(dirpath, f))
return ret
......@@ -24,7 +24,7 @@ def get_all_images_files(directory='./images/'):
ret = []
for dirpath, subdirs, files in os.walk(directory):
for f in files:
ret.append(os.path.join(dirpath,f))
ret.append(os.path.join(dirpath, f))
return ret
......@@ -32,7 +32,7 @@ def get_all_notebook_files(directory='./notebooks/'):
ret = []
for dirpath, subdirs, files in os.walk(directory):
for f in files:
ret.append(os.path.join(dirpath,f))
ret.append(os.path.join(dirpath, f))
return ret
......@@ -63,6 +63,7 @@ setup(
entry_points={
'console_scripts': [
'flatland-demo=flatland.cli:demo',
'flatland-evaluator=flatland.cli:evaluator'
],
},
install_requires=requirements,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment