wandblogger.py 7.26 KB
Newer Older
1
import multiprocessing
MasterScrat's avatar
MasterScrat committed
2
import numbers
3
from pprint import pprint
MasterScrat's avatar
MasterScrat committed
4 5 6

import wandb
from ray import tune
7
import os, fnmatch
8 9
import shutil

10 11 12 13 14 15 16 17

def find(pattern, path):
    """Recursively collect files below ``path`` whose base name matches ``pattern``.

    ``pattern`` is a shell-style glob (``fnmatch`` syntax) applied to file
    names only, not to directory names. Returns a list of joined paths in
    ``os.walk`` traversal order; a missing ``path`` yields an empty list.
    """
    matches = []
    for dirpath, _subdirs, filenames in os.walk(path):
        matches.extend(
            os.path.join(dirpath, fname)
            for fname in fnmatch.filter(filenames, pattern)
        )
    return matches
MasterScrat's avatar
MasterScrat committed
18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35

# ray 0.8.1 reorganized ray.tune.util -> ray.tune.utils
try:
    from ray.tune.utils import flatten_dict
except ImportError:
    from ray.tune.util import flatten_dict


class WandbLogger(tune.logger.Logger):
    """Pass WandbLogger to the loggers argument of tune.run

       tune.run("PG", loggers=[WandbLogger], config={
           "monitor": True, "env_config": {
               "wandb": {"project": "my-project-name"}}})

    Metrics are forwarded through a ``multiprocessing.Queue`` to one child
    process per experiment tag (see ``wandb_process``), which owns the wandb
    run. When ``env_config["render"]`` is set, ``*0.mp4`` videos found under
    ``env_config["video_dir"]`` (default ``"flatland"``) are attached to the
    logged metrics as ``wandb.Video`` objects and copied to ``self.logdir``
    on ``close()``.
    """

    def _init(self):
        self._config = None
        # experiment_tag -> queue feeding that tag's wandb child process.
        self.metrics_queue_dict = {}
        self.reset_state()
        env_config = self.config.get("env_config", {})
        if env_config.get("render", None):
            # Clean or resume the video folder. To avoid wiping the current
            # video folder, add `resume: True` under `env_config`.
            resume = env_config.get("resume", False)
            self.clear_folders(resume)

    def clear_folders(self, resume):
        """Resolve the video folder and, unless ``resume`` is truthy, delete it.

        Sets ``self._save_folder`` from ``env_config["video_dir"]`` (default
        ``"flatland"``). Removal errors are printed rather than raised.
        """
        env_name = "flatland"
        self._save_folder = self.config.get("env_config", {}).get("video_dir", env_name)
        if not resume and os.path.exists(self._save_folder):
            try:
                shutil.rmtree(self._save_folder)
            except OSError as e:
                print("Error: %s - %s." % (e.filename, e.strerror))

    def reset_state(self):
        """Forget all video bookkeeping."""
        # video path -> st_ctime at the time it was last attached to metrics,
        # so unchanged videos are not uploaded again.
        self._upload_files = {}
        # video path -> descriptive wandb key built from the env/training
        # state observed when the video was first seen.
        self._file_map = {}
        self._save_folder = None

    def _get_queue(self, result):
        """Return the metrics queue for ``result``'s experiment tag.

        On first use of a tag, starts a ``wandb_process`` child fed by a new
        queue and prints a short banner.
        """
        experiment_tag = result.get('experiment_tag', 'no_experiment_tag')
        queue = self.metrics_queue_dict.get(experiment_tag)
        if queue is None:
            experiment_id = result.get('experiment_id', 'no_experiment_id')
            print("=" * 50)
            print("Setting up new w&b logger")
            print("Experiment tag:", experiment_tag)
            print("Experiment id:", experiment_id)
            config = result.get("config")
            queue = multiprocessing.Queue()
            p = multiprocessing.Process(target=wandb_process, args=(queue, config,))
            p.start()
            self.metrics_queue_dict[experiment_tag] = queue
            print("=" * 50)
        return queue

    def on_result(self, result):
        """Send the numeric metrics of ``result`` (plus any new videos) to wandb.

        Nested dicts are flattened with ``/`` as delimiter; non-numeric values
        and the ``done``/``config``/``pid``/``timestamp`` entries are dropped.
        """
        queue = self._get_queue(result)

        tmp = result.copy()
        for k in ("done", "config", "pid", "timestamp"):
            tmp.pop(k, None)

        metrics = {
            key: value
            for key, value in flatten_dict(tmp, delimiter="/").items()
            if isinstance(value, numbers.Number)
        }

        if self.config.get("env_config", {}).get("render", None):
            # Re-resolve the folder each time so config changes made during
            # training are picked up; resume=True so no videos are deleted.
            self.clear_folders(resume=True)
            if self._save_folder:
                metrics = self.update_video_metrics(result, metrics)

        queue.put(metrics)

    def update_video_metrics(self, result, metrics):
        """Add new or changed ``*0.mp4`` videos to ``metrics`` as ``wandb.Video``.

        Each video is keyed by a name built the first time it is seen from
        the video file name, training iteration, total timesteps, mean
        completion percentage and the file's ctime. A video is attached only
        once it exceeds 1000 bytes and only if its ctime changed since the
        last time it was attached. Returns the (mutated) ``metrics`` dict.
        """
        iterations = result['training_iteration']
        steps = result['timesteps_total']
        custom_metrics = result.get('custom_metrics') or {}
        perc_comp_mean = custom_metrics.get('percentage_complete_mean', 0) * 100

        # *1.mp4 videos are ignored: they only contain the last frame.
        found_videos = find('*0.mp4', self._save_folder)
        # Sort by ctime so videos are uploaded in creation order.
        found_videos.sort(key=os.path.getctime)

        for found_video in found_videos:
            stat = os.stat(found_video)
            if found_video not in self._file_map:
                # Name the video after the state observed when it was first
                # seen; later observations are discarded.
                # TODO: the exact env state at recording time is unknown; to
                # identify it reliably, record the video with the env
                # iteration and/or steps in its file name.
                file_split = os.path.basename(found_video).split('.')
                self._file_map[found_video] = ".".join([
                    file_split[0] + "-" + str(iterations),
                    str(steps),
                    str(int(perc_comp_mean)),
                    ".".join(file_split[2:-1]),
                    str(stat.st_ctime),
                ])

            # Only upload videos that have been flushed (> 1000 bytes) and
            # whose ctime changed since they were last uploaded.
            if stat.st_size > 1000 and stat.st_ctime > self._upload_files.get(found_video, 0):
                metrics[self._file_map[found_video]] = wandb.Video(found_video, format="mp4")
                # Store the ctime (not the size) so the comparison above is
                # between like quantities.
                self._upload_files[found_video] = stat.st_ctime

        return metrics

    def close(self):
        """Stop the wandb child processes and copy uploaded videos to the logdir.

        A ``None`` sentinel is sent to each child so it finishes its run and
        exits; otherwise the non-daemon children would loop forever. Copy
        errors are printed rather than raised.
        """
        for queue in self.metrics_queue_dict.values():
            queue.put(None)
            # Flushes buffered items; no more data will be put from here.
            queue.close()
        self.metrics_queue_dict = {}
        wandb.join()

        for src in self._upload_files:
            dst = os.path.join(self.logdir, os.path.basename(src))
            try:
                shutil.copy2(src, dst)
                # Metadata is expected next to the video as "<name>.meta.json";
                # derive it from the extension only, not from every "mp4"
                # substring in the path.
                shutil.copy2(os.path.splitext(src)[0] + ".meta.json",
                             os.path.splitext(dst)[0] + ".meta.json")
            except OSError as e:
                print("Error: %s - %s." % (e.filename, e.strerror))

        self.reset_state()


# each logger has to run in a separate process
def wandb_process(queue, config):
    """Own one wandb run and log every metrics dict received on ``queue``.

    ``config`` (possibly ``None``) is the trial config. ``env_config["wandb"]``
    is passed to ``wandb.init``, and every top-level key except
    ``"callbacks"`` is copied into ``wandb.config`` unless already set there.
    If ``env_config`` has a ``"yaml_config"`` entry, that file is saved to the
    run. A ``None`` item on the queue ends the loop and finishes the run.
    """
    config = config or {}
    env_config = config.get("env_config") or {}
    run = wandb.init(reinit=True, **env_config.get("wandb", {}))

    for k, v in config.items():
        if k != "callbacks" and wandb.config.get(k) is None:
            wandb.config[k] = v

    if 'yaml_config' in env_config:
        yaml_config = env_config['yaml_config']
        print("Saving full experiment config:", yaml_config)
        wandb.save(yaml_config)

    while True:
        metrics = queue.get()
        if metrics is None:
            break
        run.log(metrics)
    wandb.join()