Skip to content
Snippets Groups Projects
Unverified Commit d95727b2 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

add a field to support the evaluation interval (#849)

parent e4917130
No related branches found
No related tags found
No related merge requests found
...@@ -171,6 +171,7 @@ log_config = dict( ...@@ -171,6 +171,7 @@ log_config = dict(
# dict(type='TensorboardLoggerHook') # dict(type='TensorboardLoggerHook')
]) ])
# yapf:enable # yapf:enable
evaluation = dict(interval=1)
# runtime settings # runtime settings
total_epochs = 12 total_epochs = 12
dist_params = dict(backend='nccl') dist_params = dict(backend='nccl')
......
...@@ -91,8 +91,8 @@ def build_optimizer(model, optimizer_cfg): ...@@ -91,8 +91,8 @@ def build_optimizer(model, optimizer_cfg):
paramwise_options = optimizer_cfg.pop('paramwise_options', None) paramwise_options = optimizer_cfg.pop('paramwise_options', None)
# if no paramwise option is specified, just use the global setting # if no paramwise option is specified, just use the global setting
if paramwise_options is None: if paramwise_options is None:
return obj_from_dict( return obj_from_dict(optimizer_cfg, torch.optim,
optimizer_cfg, torch.optim, dict(params=model.parameters())) dict(params=model.parameters()))
else: else:
assert isinstance(paramwise_options, dict) assert isinstance(paramwise_options, dict)
# get base lr and weight decay # get base lr and weight decay
...@@ -154,15 +154,19 @@ def _dist_train(model, dataset, cfg, validate=False): ...@@ -154,15 +154,19 @@ def _dist_train(model, dataset, cfg, validate=False):
# register eval hooks # register eval hooks
if validate: if validate:
val_dataset_cfg = cfg.data.val val_dataset_cfg = cfg.data.val
eval_cfg = cfg.get('evaluation', {})
if isinstance(model.module, RPN): if isinstance(model.module, RPN):
# TODO: implement recall hooks for other datasets # TODO: implement recall hooks for other datasets
runner.register_hook(CocoDistEvalRecallHook(val_dataset_cfg)) runner.register_hook(
CocoDistEvalRecallHook(val_dataset_cfg, **eval_cfg))
else: else:
dataset_type = getattr(datasets, val_dataset_cfg.type) dataset_type = getattr(datasets, val_dataset_cfg.type)
if issubclass(dataset_type, datasets.CocoDataset): if issubclass(dataset_type, datasets.CocoDataset):
runner.register_hook(CocoDistEvalmAPHook(val_dataset_cfg)) runner.register_hook(
CocoDistEvalmAPHook(val_dataset_cfg, **eval_cfg))
else: else:
runner.register_hook(DistEvalmAPHook(val_dataset_cfg)) runner.register_hook(
DistEvalmAPHook(val_dataset_cfg, **eval_cfg))
if cfg.resume_from: if cfg.resume_from:
runner.resume(cfg.resume_from) runner.resume(cfg.resume_from)
......
...@@ -116,9 +116,11 @@ class CocoDistEvalRecallHook(DistEvalHook): ...@@ -116,9 +116,11 @@ class CocoDistEvalRecallHook(DistEvalHook):
def __init__(self, def __init__(self,
dataset, dataset,
interval=1,
proposal_nums=(100, 300, 1000), proposal_nums=(100, 300, 1000),
iou_thrs=np.arange(0.5, 0.96, 0.05)): iou_thrs=np.arange(0.5, 0.96, 0.05)):
super(CocoDistEvalRecallHook, self).__init__(dataset) super(CocoDistEvalRecallHook, self).__init__(
dataset, interval=interval)
self.proposal_nums = np.array(proposal_nums, dtype=np.int32) self.proposal_nums = np.array(proposal_nums, dtype=np.int32)
self.iou_thrs = np.array(iou_thrs, dtype=np.float32) self.iou_thrs = np.array(iou_thrs, dtype=np.float32)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment