Skip to content
Snippets Groups Projects
Commit 1fffac8e authored by eric_a_scuccimarra's avatar eric_a_scuccimarra
Browse files

Update train.py

parent 16ec475f
No related branches found
No related tags found
No related merge requests found
......@@ -6,8 +6,8 @@ import torch
from mmcv import Config
from mmdet import __version__
from mmcv.runner import init_dist
from mmdet.apis import (get_root_logger, set_random_seed, train_detector)
from mmdet.apis import (get_root_logger, init_dist, set_random_seed,
train_detector)
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
......@@ -97,20 +97,14 @@ def main():
CLASSES=datasets[0].CLASSES)
# add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES
try:
train_detector(
model,
datasets,
cfg,
distributed=distributed,
validate=args.validate,
logger=logger)
except Exception as error:
print("Error... saving checkpoint")
torch.save(model.state_dict(), './models/crash.pth')
raise error
train_detector(
model,
datasets,
cfg,
distributed=distributed,
validate=args.validate,
logger=logger)
if __name__ == '__main__':
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment