From c2352d5260868e6ad119d751914424923a16c95a Mon Sep 17 00:00:00 2001
From: Cao Yuhang <yhcao6@gmail.com>
Date: Wed, 15 May 2019 16:21:53 +0800
Subject: [PATCH] Add tool to analyze log (#648)

* add tool to analyze log

* minor fix

* plot multi keys, use subparser, modify code

* add two method add_plot_parser and add_time_parser, minor fix

* minor fix

* move add parser func outside main

* move subparser.add inside add parser func
---
 tools/analyze_logs.py | 178 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 178 insertions(+)
 create mode 100644 tools/analyze_logs.py

diff --git a/tools/analyze_logs.py b/tools/analyze_logs.py
new file mode 100644
index 0000000..b56f3de
--- /dev/null
+++ b/tools/analyze_logs.py
@@ -0,0 +1,178 @@
+import argparse
+import json
+from collections import defaultdict
+
+import matplotlib.pyplot as plt
+import numpy as np
+import seaborn as sns
+
+
+def cal_train_time(log_dicts, args):
+    for i, log_dict in enumerate(log_dicts):
+        print('{}Analyze train time of {}{}'.format('-' * 5, args.json_logs[i],
+                                                    '-' * 5))
+        all_times = []
+        for epoch in log_dict.keys():
+            if args.include_outliers:
+                all_times.append(log_dict[epoch]['time'])
+            else:
+                all_times.append(log_dict[epoch]['time'][1:])
+        all_times = np.array(all_times)
+        epoch_ave_time = all_times.mean(-1)
+        slowest_epoch = epoch_ave_time.argmax()
+        fastest_epoch = epoch_ave_time.argmin()
+        std_over_epoch = epoch_ave_time.std()
+        print('slowest epoch {}, average time is {:.4f}'.format(
+            slowest_epoch + 1, epoch_ave_time[slowest_epoch]))
+        print('fastest epoch {}, average time is {:.4f}'.format(
+            fastest_epoch + 1, epoch_ave_time[fastest_epoch]))
+        print('time std over epochs is {:.4f}'.format(std_over_epoch))
+        print('average iter time: {:.4f} s/iter'.format(np.mean(all_times)))
+        print()
+
+
+def plot_curve(log_dicts, args):
+    if args.backend is not None:
+        plt.switch_backend(args.backend)
+    sns.set_style(args.style)
+    # if legend is None, use {filename}_{key} as legend
+    legend = args.legend
+    if legend is None:
+        legend = []
+        for json_log in args.json_logs:
+            for metric in args.keys:
+                legend.append('{}_{}'.format(json_log, metric))
+    assert len(legend) == (len(args.json_logs) * len(args.keys))
+    metrics = args.keys
+
+    num_metrics = len(metrics)
+    for i, log_dict in enumerate(log_dicts):
+        epochs = list(log_dict.keys())
+        for j, metric in enumerate(metrics):
+            print('plot curve of {}, metric is {}'.format(
+                args.json_logs[i], metric))
+            assert metric in log_dict[
+                epochs[0]], '{} does not contain metric {}'.format(
+                    args.json_logs[i], metric)
+
+            if 'mAP' in metric:
+                xs = np.arange(1, max(epochs) + 1)
+                ys = []
+                for epoch in epochs:
+                    ys += log_dict[epoch][metric]
+                ax = plt.gca()
+                ax.set_xticks(xs)
+                plt.xlabel('epoch')
+                plt.plot(xs, ys, label=legend[i * num_metrics + j], marker='o')
+            else:
+                xs = []
+                ys = []
+                num_iters_per_epoch = log_dict[epochs[0]]['iter'][-1]
+                for epoch in epochs:
+                    iters = log_dict[epoch]['iter']
+                    if log_dict[epoch]['mode'][-1] == 'val':
+                        iters = iters[:-1]
+                    xs.append(
+                        np.array(iters) + (epoch - 1) * num_iters_per_epoch)
+                    ys.append(np.array(log_dict[epoch][metric][:len(iters)]))
+                xs = np.concatenate(xs)
+                ys = np.concatenate(ys)
+                plt.xlabel('iter')
+                plt.plot(
+                    xs, ys, label=legend[i * num_metrics + j], linewidth=0.5)
+            plt.legend()
+        if args.title is not None:
+            plt.title(args.title)
+    if args.out is None:
+        plt.show()
+    else:
+        print('save curve to: {}'.format(args.out))
+        plt.savefig(args.out)
+        plt.cla()
+
+
+def add_plot_parser(subparsers):
+    parser_plt = subparsers.add_parser(
+        'plot_curve', help='parser for plotting curves')
+    parser_plt.add_argument(
+        'json_logs',
+        type=str,
+        nargs='+',
+        help='path of train log in json format')
+    parser_plt.add_argument(
+        '--keys',
+        type=str,
+        nargs='+',
+        default=['bbox_mAP'],
+        help='the metric that you want to plot')
+    parser_plt.add_argument('--title', type=str, help='title of figure')
+    parser_plt.add_argument(
+        '--legend',
+        type=str,
+        nargs='+',
+        default=None,
+        help='legend of each plot')
+    parser_plt.add_argument(
+        '--backend', type=str, default=None, help='backend of plt')
+    parser_plt.add_argument(
+        '--style', type=str, default='dark', help='style of plt')
+    parser_plt.add_argument('--out', type=str, default=None)
+
+
+def add_time_parser(subparsers):
+    parser_time = subparsers.add_parser(
+        'cal_train_time',
+        help='parser for computing the average time per training iteration')
+    parser_time.add_argument(
+        'json_logs',
+        type=str,
+        nargs='+',
+        help='path of train log in json format')
+    parser_time.add_argument(
+        '--include-outliers',
+        action='store_true',
+        help='include the first value of every epoch when computing '
+        'the average time')
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Analyze Json Log')
+    # currently only support plot curve and calculate average train time
+    subparsers = parser.add_subparsers(dest='task', help='task parser')
+    add_plot_parser(subparsers)
+    add_time_parser(subparsers)
+    args = parser.parse_args()
+    return args
+
+
+def load_json_logs(json_logs):
+    # load and convert json_logs to log_dict, key is epoch, value is a sub dict
+    # keys of sub dict is different metrics, e.g. memory, bbox_mAP
+    # value of sub dict is a list of corresponding values of all iterations
+    log_dicts = [dict() for _ in json_logs]
+    for json_log, log_dict in zip(json_logs, log_dicts):
+        with open(json_log, 'r') as log_file:
+            for l in log_file:
+                log = json.loads(l.strip())
+                epoch = log.pop('epoch')
+                if epoch not in log_dict:
+                    log_dict[epoch] = defaultdict(list)
+                for k, v in log.items():
+                    log_dict[epoch][k].append(v)
+    return log_dicts
+
+
+def main():
+    args = parse_args()
+
+    json_logs = args.json_logs
+    for json_log in json_logs:
+        assert json_log.endswith('.json')
+
+    log_dicts = load_json_logs(json_logs)
+
+    eval(args.task)(log_dicts, args)
+
+
+if __name__ == '__main__':
+    main()
-- 
GitLab