Commit 25b3062e authored by Jyotish P

feat: add script to test evaluator and starter kit locally
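The script drives a full local evaluation round: it copies the evaluator's files into a starter kit checkout, runs the kit's entrypoint against each configured dataset, scores the results with AIcrowdEvaluator.evaluate(), and serves a live status page at http://127.0.0.1:20222/. Assuming the script is saved as local_run.py (the file name is not shown on this page), a typical invocation would be:

    python local_run.py --starter-kit-path ../my-starter-kit --dataset-path ../dataset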

parent 35986721
1 merge request: !4 feat: add script to test evaluator and starter kit locally
@@ -101,7 +101,7 @@ class MNISTEvaluator:
         gt_labels = gt_labels[gt_labels["image_id"].isin(predictions["image_id"].values)]
         gt_labels = gt_labels.sort_values(by="image_id")
-        assert len(gt_labels) == len(predictions)
+        assert len(gt_labels) == len(predictions), f"{len(gt_labels)} != {len(predictions)}"
         scores = {
             "score": f1_score(gt_labels["label"].values, predictions["label"].values, average="macro"),
@@ -13,7 +13,7 @@ env = jinja2.Environment(autoescape=True, loader=loader)
 class Constants:
-    SHARED_DISK = os.getenv("AICROWD_SHARED_DIR", "test_debug_data/shared")
+    SHARED_DISK = os.getenv("AICROWD_SHARED_DIR", "results/shared")
     DATASET_DIR = os.getenv("AICROWD_DATASET_DIR", "test_debug_data/data")
     GROUND_TRUTH_DIR = os.getenv(
         "AICROWD_GROUND_TRUTH_DIR", "test_debug_data/ground_truth"
@@ -40,12 +40,12 @@ class AIcrowdEvaluator:
             try:
                 data = pd.read_csv(csv_file_path)
             except:
-                data = []
+                data = pd.DataFrame()
             predictions.append(data)
         predictions = pd.concat(predictions, axis=0, ignore_index=True)
         predictions.to_csv(Constants.MERGED_PREDICTIONS_FILE_PATH)

-    def render_status_update(self) -> str:
+    def render_current_status_as_markdown(self) -> str:
         """This method updates the GitLab issue page comment with the current evaluation
         progress. The content returned should be in markdown.
@@ -105,7 +105,7 @@ class AIcrowdEvaluator:
 if __name__ == "__main__":
     evaluator = AIcrowdEvaluator()
-    score = evaluator.evaluate()
-    print(score)
+    # score = evaluator.evaluate()
+    # print(score)
-    print(evaluator.render_status_update())
+    print(evaluator.render_current_status_as_markdown())
import argparse
import contextlib
import json
import os
import shutil
import socket
import subprocess
import sys
import time
import webbrowser
from functools import partial
from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread

import yaml
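# Local working directories: evaluation artifacts are written under results/,
# the live-rendered status page under results/rendering/. The tree is
# recreated from scratch on every run.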
RESULTS_DIR = "results"
SHARED_DIR = f"{RESULTS_DIR}/shared"
RENDERING_DIR = f"{RESULTS_DIR}/rendering"
if os.path.exists(RESULTS_DIR):
shutil.rmtree(RESULTS_DIR)
os.mkdir(RESULTS_DIR)
os.mkdir(SHARED_DIR)
os.mkdir(RENDERING_DIR)
class EvaluationError(Exception):
pass
def build_parser():
parser = argparse.ArgumentParser(description="End to end testing for evaluator")
parser.add_argument(
"--starter-kit-path",
required=True,
dest="starter_kit_path",
help="Path to starter kit code",
)
parser.add_argument(
"--dataset-path",
required=True,
dest="dataset_path",
help="Path to dir containing extracted dataset",
)
return parser
def build_args():
parser = build_parser()
args = parser.parse_args()
validate_starter_kit(args.starter_kit_path)
validate_dataset(args.dataset_path)
return args
def validate_starter_kit(path: str):
with open(os.path.join(path, "aicrowd.json")) as fp:
aicrowd_json = json.load(fp)
assert "challenge_id" in aicrowd_json
def validate_dataset(path: str):
evaluator_cfg = load_evaluator_cfg()
assert os.path.exists(
os.path.join(
path,
evaluator_cfg["evaluation"]["debug_run"].get(
"dataset_path", "debug_test_data"
),
)
), "Debug dataset not present, make sure you have a `debug_test_data` directory"
for run in evaluator_cfg["evaluation"]["runs"]:
dataset_path = os.path.join(path, run.get("dataset_path", "test_data"))
assert os.path.exists(dataset_path), f"Dataset doesn't exist at {dataset_path}"
assert os.path.exists(
os.path.join(path, "ground_truth_data")
), "Ground truth data should be in `ground_truth_data` dir"
def load_evaluator_cfg() -> dict:
with open("aicrowd.yaml") as fp:
cfg = yaml.safe_load(fp.read())
return cfg
def run_evaluation():
args = build_args()
cfg = load_evaluator_cfg()
files_to_revert = prepare_artifacts(args, cfg["evaluation"]["global"]["files"])
Thread(target=run_live_rendering, args=(args,), daemon=True).start()
Thread(target=run_http_server, daemon=False).start()
Thread(target=open_browser_tab).start()
try:
run_predictions(
args, cfg, cfg["evaluation"]["debug_run"]
)
for run in cfg["evaluation"]["runs"]:
run_files_to_revert = prepare_artifacts(args, run.get("files", {}))
run_predictions(args, cfg, run)
revert_artifacts(run_files_to_revert)
except EvaluationError as err:
revert_artifacts(files_to_revert)
raise err
revert_artifacts(files_to_revert)
run_scoring(args)
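# Copies evaluator-provided files from the local data/ directory into the starter
# kit, backing up any file it overwrites with a ".original" suffix so that
# revert_artifacts() can restore the starter kit afterwards.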
def prepare_artifacts(args, file_list: dict):
files_to_revert = []
for src, dst in file_list.items():
src = os.path.join("data", src)
dst = os.path.join(args.starter_kit_path, dst)
if os.path.exists(dst):
shutil.move(dst, dst + ".original")
files_to_revert.append(dst)
shutil.copy2(src, dst)
os.chmod(dst, 0o777)
return files_to_revert
def revert_artifacts(file_list: list):
for file_path in file_list:
shutil.move(file_path + ".original", file_path)
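# Runs the starter kit's entrypoint (run.sh by default) with the dataset and
# shared-directory paths exposed through environment variables, merging in any
# env overrides from the global config and the current run.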
def run_predictions(args, cfg, run):
dataset_path = os.path.join(args.dataset_path, run.get("dataset_path", "test_data"))
env = {
**os.environ,
"AICROWD_DATASET_PATH": os.path.abspath(dataset_path),
"AICROWD_SHARED_DIR": os.path.abspath(SHARED_DIR),
}
for env_name, env_value in (
cfg["evaluation"].get("global", {}).get("env", {}).items()
):
env[env_name] = env_value
for env_name, env_value in run.get("env", {}).items():
env[env_name] = env_value
cmd = "./" + run.get("entrypoint", "run.sh").replace("/home/aicrowd", "")
proc = subprocess.run(cmd, env=env, shell=True, cwd=args.starter_kit_path)
if proc.returncode != 0:
raise EvaluationError("Failed to run starter kit code")
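# Imports the evaluator and scores the submission against ground_truth_data,
# verifying that the returned scores are JSON-serializable.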
def run_scoring(args):
os.environ["AICROWD_SHARED_DIR"] = SHARED_DIR
os.environ["AICROWD_GROUND_TRUTH_DIR"] = os.path.join(
args.dataset_path, "ground_truth_data"
)
from evaluator import AIcrowdEvaluator
evaluator = AIcrowdEvaluator()
scores = evaluator.evaluate()
# check if json serializable
json.dumps(scores)
print("#" * 50)
print("#" * 50)
print("Evaluation Results")
print("==================")
print(scores)
print("Evaluation completed! Press ctrl+c to stop the rendering server.")
print("#" * 50)
print("#" * 50)
def run_live_rendering(args):
os.environ["AICROWD_SHARED_DIR"] = SHARED_DIR
os.environ["AICROWD_GROUND_TRUTH_DIR"] = os.path.join(
args.dataset_path, "ground_truth_data"
)
from evaluator import AIcrowdEvaluator
evaluator = AIcrowdEvaluator()
html_headers = """
<!doctype html>
<html>
<head>
<meta charset="utf-8"/>
<title>Marked in the browser</title>
<script src="https://twemoji.maxcdn.com/v/latest/twemoji.min.js" crossorigin="anonymous"></script>
<link href="https://projects.iamcal.com/js-emoji/demo/emoji.css">
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Lato&display=swap" rel="stylesheet">
<style>
body { font-family: 'Lato', sans-serif; }
#content {
width: fit-content;
margin-left: auto;
margin-right: auto;
margin-top: 2rem;
}
table {
border-spacing: 0;
border: #bbb 0.5px solid;
border-radius: 5px;
}
table th {
background-color: #ddd;
}
table th, table td {
padding: 0.5rem 1rem;
margin: 0;
border: #bbb 0.5px solid;
}
code {
background: #eee;
padding: 0.1rem 0.6rem;
border-radius: 5px;
}
</style>
</head>
<body>
<div class="container">
<div id="content">"""
html_footers = """
</div></div><script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
<script src="https://projects.iamcal.com/js-emoji/lib/emoji.js"></script>
<script>
document.getElementById('content').innerHTML = marked.parse(document.getElementById('content').innerHTML);
var emoji = new EmojiConvertor();
document.getElementById('content').innerHTML = emoji.replace_colons(document.getElementById('content').innerHTML);
setTimeout("location.reload(true);", 5000);
</script>
</body>
</html>
"""
while True:
status = evaluator.render_current_status_as_markdown()
with open(os.path.join(RENDERING_DIR, "index.html"), "w") as fp:
fp.write(html_headers)
fp.write(status)
fp.write(html_footers)
time.sleep(5)
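# Serves the rendering directory on http://127.0.0.1:20222/ with a dual-stack
# HTTP server, closely mirroring the test() helper in CPython's http.server module.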
def run_http_server():
handler_class = partial(SimpleHTTPRequestHandler, directory=RENDERING_DIR)
# ensure dual-stack is not disabled; ref #38907
class DualStackServer(ThreadingHTTPServer):
def server_bind(self):
# suppress exception when protocol is IPv4
with contextlib.suppress(Exception):
self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
return super().server_bind()
def test(HandlerClass, ServerClass, protocol="HTTP/1.0", port=8000, bind=None):
infos = socket.getaddrinfo(
bind,
port,
type=socket.SOCK_STREAM,
flags=socket.AI_PASSIVE,
)
ServerClass.address_family, _, _, _, addr = next(iter(infos))
HandlerClass.protocol_version = protocol
with ServerClass(addr, HandlerClass) as httpd:
host, port = httpd.socket.getsockname()[:2]
url_host = f"[{host}]" if ":" in host else host
print(
f"Serving HTTP on {host} port {port} "
f"(http://{url_host}:{port}/) ..."
)
try:
httpd.serve_forever()
except KeyboardInterrupt:
print("\nKeyboard interrupt received, exiting.")
sys.exit(0)
test(
HandlerClass=handler_class,
ServerClass=DualStackServer,
port=20222,
bind="127.0.0.1",
)
def open_browser_tab():
time.sleep(1)
webbrowser.open("http://127.0.0.1:20222/index.html", new=2)
def main():
    run_evaluation()

if __name__ == "__main__":
    main()