Commit 0179a4a0 authored by nilabha

Added support for saving checkpoints to Weights & Biases (W&B)

parent 119356f9
Pipeline #5585 passed with stage
in 3 minutes and 53 seconds
......@@ -223,6 +223,8 @@ def run(args, parser):
del exp['config']['evaluation_config']['wandb']
if args.custom_fn:
custom_fn = globals()[exp['config'].get("env_config",{}).get("custom_fn","imitation_ppo_train_fn")]
if args.save_checkpoint:
exp['config']['env_config']['save_checkpoint'] = True
if args.config_file:
# TODO should be in exp['config'] directly
exp['config']['env_config']['yaml_config'] = args.config_file
......
......@@ -119,6 +119,11 @@ def create_parser(parser_creator=None):
help="Whether the experiment requires video recording during evaluation"
"Default evaluation config is default_render.yaml "
"Can also be done via custom evaluation config set (eval_generator:test_render) under configs")
parser.add_argument(
"-s",
"--save-checkpoint",
action="store_true",
help="Whether the experiment will save the checkpoints to weights and biases")
parser.add_argument(
"--bind-all",
action="store_true",
......
......@@ -8,6 +8,7 @@ import wandb
from ray import tune
from ray.tune.utils import flatten_dict
from glob import glob
def find(pattern, path):
......@@ -39,6 +40,7 @@ class WandbLogger(tune.logger.Logger):
'''
resume = self.config.get("env_config", {}).get("resume", False)
self.clear_folders(resume)
self.saved_checkpoints = []
def clear_folders(self, resume):
env_name = "flatland"
......@@ -95,6 +97,12 @@ class WandbLogger(tune.logger.Logger):
if self._save_folder:
metrics = self.update_video_metrics(result, metrics)
if self.config.get("env_config", {}).get("save_checkpoint", None):
_cur_checkpoints = glob(os.path.join(self.logdir, "checkpoint*"))
# Remove already saved checkpoints
_cur_checkpoints = list(set(_cur_checkpoints)-set(self.saved_checkpoints))
metrics['checkpoint'] = _cur_checkpoints
self.saved_checkpoints.extend(_cur_checkpoints)
queue.put(metrics)
def update_video_metrics(self, result, metrics):
......@@ -180,5 +188,14 @@ def wandb_process(queue, config):
if "KILL" in metrics:
break
if "checkpoint" in metrics:
_checkpoints = metrics['checkpoint']
for _checkpoint in _checkpoints:
if os.path.exists(_checkpoint):
try:
wandb.save(_checkpoint)
except Exception as e:
print("Error Occurred in saving checkpoints:",e)
del metrics['checkpoint']
run.log(metrics)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment