From ae4648f9672cbe17f3d17e3b3fdbdd0da9750f3f Mon Sep 17 00:00:00 2001 From: Kai Chen <chenkaidev@gmail.com> Date: Sun, 28 Apr 2019 16:37:05 -0700 Subject: [PATCH] add a tool to process models to be published --- tools/publish_model.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 tools/publish_model.py diff --git a/tools/publish_model.py b/tools/publish_model.py new file mode 100644 index 0000000..39795f1 --- /dev/null +++ b/tools/publish_model.py @@ -0,0 +1,34 @@ +import argparse +import subprocess +import torch + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Process a checkpoint to be published') + parser.add_argument('in_file', help='input checkpoint filename') + parser.add_argument('out_file', help='output checkpoint filename') + args = parser.parse_args() + return args + + +def process_checkpoint(in_file, out_file): + checkpoint = torch.load(in_file, map_location='cpu') + # remove optimizer for smaller file size + if 'optimizer' in checkpoint: + del checkpoint['optimizer'] + # if it is necessary to remove some sensitive data in checkpoint['meta'], + # add the code here. + torch.save(checkpoint, out_file) + sha = subprocess.check_output(['sha256sum', out_file]).decode() + final_file = out_file.rstrip('.pth') + '-{}.pth'.format(sha[:8]) + subprocess.Popen(['mv', out_file, final_file]) + + +def main(): + args = parse_args() + process_checkpoint(args.in_file, args.out_file) + + +if __name__ == '__main__': + main() -- GitLab