Skip to content
Snippets Groups Projects
Commit 088919f6 authored by pangjm's avatar pangjm
Browse files

Update single base version

parent 108fc9e1
No related branches found
No related tags found
No related merge requests found
Showing
with 564 additions and 20 deletions
TDL.md 0 → 100644
### MMCV
- [ ] Implement the attr 'get' of 'Config'
- [ ] Config bugs: None type to '{}' with addict
- [ ] Default logger should be only with gpu0
- [ ] Unit Test: mmcv and mmcv.torchpack
### MMDetection
#### Basic
- [ ] Implement training function without distributed
- [ ] Verify nccl/nccl2/gloo
- [ ] Replace UGLY code: params plug in 'args' to reach a global flow
- [ ] Replace 'print' by 'logger'
#### Testing
- [ ] Implement distributed testing
- [ ] Implement single gpu testing
#### Refactor
- [ ] Re-consider params names
- [ ] Refactor functions in 'core'
- [ ] Merge single test & aug test as one function, so as other redundancy
#### New features
- [ ] Plug loss params into Config
- [ ] Multi-head communication
from .anchor_generator import * from .train_engine import *
from .test_engine import *
from .rpn_ops import *
from .bbox_ops import * from .bbox_ops import *
from .mask_ops import * from .mask_ops import *
from .losses import *
from .eval import * from .eval import *
from .nn import * from .post_processing import *
from .targets import * from .utils import *
from .geometry import bbox_overlaps from .geometry import bbox_overlaps
from .sampling import (random_choice, bbox_assign, bbox_assign_via_overlaps, from .sampling import (random_choice, bbox_assign, bbox_assign_via_overlaps,
bbox_sampling, sample_positives, sample_negatives) bbox_sampling, sample_positives, sample_negatives,
sample_proposals)
from .transforms import (bbox_transform, bbox_transform_inv, bbox_flip, from .transforms import (bbox_transform, bbox_transform_inv, bbox_flip,
bbox_mapping, bbox_mapping_back, bbox2roi, roi2bbox) bbox_mapping, bbox_mapping_back, bbox2roi, roi2bbox,
bbox2result)
from .bbox_target import bbox_target
__all__ = [ __all__ = [
'bbox_overlaps', 'random_choice', 'bbox_assign', 'bbox_overlaps', 'random_choice', 'bbox_assign',
'bbox_assign_via_overlaps', 'bbox_sampling', 'sample_positives', 'bbox_assign_via_overlaps', 'bbox_sampling', 'sample_positives',
'sample_negatives', 'bbox_transform', 'bbox_transform_inv', 'bbox_flip', 'sample_negatives', 'bbox_transform', 'bbox_transform_inv', 'bbox_flip',
'bbox_mapping', 'bbox_mapping_back', 'bbox2roi', 'roi2bbox' 'bbox_mapping', 'bbox_mapping_back', 'bbox2roi', 'roi2bbox', 'bbox2result',
'bbox_target', 'sample_proposals'
] ]
import mmcv
import torch
from .geometry import bbox_overlaps
from .transforms import bbox_transform, bbox_transform_inv
def bbox_target(pos_proposals_list,
neg_proposals_list,
pos_gt_bboxes_list,
pos_gt_labels_list,
cfg,
reg_num_classes=1,
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0],
return_list=False):
img_per_gpu = len(pos_proposals_list)
all_labels = []
all_label_weights = []
all_bbox_targets = []
all_bbox_weights = []
for img_id in range(img_per_gpu):
pos_proposals = pos_proposals_list[img_id]
neg_proposals = neg_proposals_list[img_id]
pos_gt_bboxes = pos_gt_bboxes_list[img_id]
pos_gt_labels = pos_gt_labels_list[img_id]
debug_img = debug_imgs[img_id] if cfg.debug else None
labels, label_weights, bbox_targets, bbox_weights = proposal_target_single(
pos_proposals, neg_proposals, pos_gt_bboxes, pos_gt_labels,
reg_num_classes, cfg, target_means, target_stds)
all_labels.append(labels)
all_label_weights.append(label_weights)
all_bbox_targets.append(bbox_targets)
all_bbox_weights.append(bbox_weights)
if return_list:
return all_labels, all_label_weights, all_bbox_targets, all_bbox_weights
labels = torch.cat(all_labels, 0)
label_weights = torch.cat(all_label_weights, 0)
bbox_targets = torch.cat(all_bbox_targets, 0)
bbox_weights = torch.cat(all_bbox_weights, 0)
return labels, label_weights, bbox_targets, bbox_weights
def proposal_target_single(pos_proposals,
neg_proposals,
pos_gt_bboxes,
pos_gt_labels,
reg_num_classes,
cfg,
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]):
num_pos = pos_proposals.size(0)
num_neg = neg_proposals.size(0)
num_samples = num_pos + num_neg
labels = pos_proposals.new_zeros(num_samples, dtype=torch.long)
label_weights = pos_proposals.new_zeros(num_samples)
bbox_targets = pos_proposals.new_zeros(num_samples, 4)
bbox_weights = pos_proposals.new_zeros(num_samples, 4)
if num_pos > 0:
labels[:num_pos] = pos_gt_labels
pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight
label_weights[:num_pos] = pos_weight
pos_bbox_targets = bbox_transform(pos_proposals, pos_gt_bboxes,
target_means, target_stds)
bbox_targets[:num_pos, :] = pos_bbox_targets
bbox_weights[:num_pos, :] = 1
if num_neg > 0:
label_weights[-num_neg:] = 1.0
if reg_num_classes > 1:
bbox_targets, bbox_weights = expand_target(bbox_targets, bbox_weights,
labels, reg_num_classes)
return labels, label_weights, bbox_targets, bbox_weights
def expand_target(bbox_targets, bbox_weights, labels, num_classes):
bbox_targets_expand = bbox_targets.new_zeros((bbox_targets.size(0),
4 * num_classes))
bbox_weights_expand = bbox_weights.new_zeros((bbox_weights.size(0),
4 * num_classes))
for i in torch.nonzero(labels > 0).squeeze(-1):
start, end = labels[i] * 4, (labels[i] + 1) * 4
bbox_targets_expand[i, start:end] = bbox_targets[i, :]
bbox_weights_expand[i, start:end] = bbox_weights[i, :]
return bbox_targets_expand, bbox_weights_expand
...@@ -42,7 +42,7 @@ def bbox_assign(proposals, ...@@ -42,7 +42,7 @@ def bbox_assign(proposals,
min_pos_iou(float): minimum iou for a bbox to be considered as a positive bbox, min_pos_iou(float): minimum iou for a bbox to be considered as a positive bbox,
for RPN, it is usually set as 0, for Fast R-CNN, for RPN, it is usually set as 0, for Fast R-CNN,
it is usually set as pos_iou_thr it is usually set as pos_iou_thr
crowd_thr: ignore proposals which have iof(intersection over foreground) with crowd_thr: ignore proposals which have iof(intersection over foreground) with
crowd bboxes over crowd_thr crowd bboxes over crowd_thr
Returns: Returns:
tuple: (assigned_gt_inds, argmax_overlaps, max_overlaps), shape (n, ) tuple: (assigned_gt_inds, argmax_overlaps, max_overlaps), shape (n, )
...@@ -253,3 +253,43 @@ def bbox_sampling(assigned_gt_inds, ...@@ -253,3 +253,43 @@ def bbox_sampling(assigned_gt_inds,
max_overlaps, neg_balance_thr, max_overlaps, neg_balance_thr,
neg_hard_fraction) neg_hard_fraction)
return pos_inds, neg_inds return pos_inds, neg_inds
def sample_proposals(proposals_list, gt_bboxes_list, gt_crowds_list,
gt_labels_list, cfg):
cfg_list = [cfg for _ in range(len(proposals_list))]
results = map(sample_proposals_single, proposals_list, gt_bboxes_list,
gt_crowds_list, gt_labels_list, cfg_list)
# list of tuple to tuple of list
return tuple(map(list, zip(*results)))
def sample_proposals_single(proposals,
gt_bboxes,
gt_crowds,
gt_labels,
cfg):
proposals = proposals[:, :4]
assigned_gt_inds, assigned_labels, argmax_overlaps, max_overlaps = \
bbox_assign(
proposals, gt_bboxes, gt_crowds, gt_labels, cfg.pos_iou_thr,
cfg.neg_iou_thr, cfg.pos_iou_thr, cfg.crowd_thr)
if cfg.add_gt_as_proposals:
proposals = torch.cat([gt_bboxes, proposals], dim=0)
gt_assign_self = torch.arange(
1, len(gt_labels) + 1, dtype=torch.long, device=proposals.device)
assigned_gt_inds = torch.cat([gt_assign_self, assigned_gt_inds])
assigned_labels = torch.cat([gt_labels, assigned_labels])
pos_inds, neg_inds = bbox_sampling(
assigned_gt_inds, cfg.roi_batch_size, cfg.pos_fraction, cfg.neg_pos_ub,
cfg.pos_balance_sampling, max_overlaps, cfg.neg_balance_thr)
pos_proposals = proposals[pos_inds]
neg_proposals = proposals[neg_inds]
pos_assigned_gt_inds = assigned_gt_inds[pos_inds] - 1
pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :]
pos_gt_labels = assigned_labels[pos_inds]
return (pos_inds, neg_inds, pos_proposals, neg_proposals,
pos_assigned_gt_inds, pos_gt_bboxes, pos_gt_labels)
...@@ -126,3 +126,22 @@ def roi2bbox(rois): ...@@ -126,3 +126,22 @@ def roi2bbox(rois):
bbox = rois[inds, 1:] bbox = rois[inds, 1:]
bbox_list.append(bbox) bbox_list.append(bbox)
return bbox_list return bbox_list
def bbox2result(bboxes, labels, num_classes):
"""Convert detection results to a list of numpy arrays
Args:
bboxes (Tensor): shape (n, 5)
labels (Tensor): shape (n, )
num_classes (int): class number, including background class
Returns:
list(ndarray): bbox results of each class
"""
if bboxes.shape[0] == 0:
return [
np.zeros((0, 5), dtype=np.float32) for i in range(num_classes - 1)
]
else:
bboxes = bboxes.cpu().numpy()
labels = labels.cpu().numpy()
return [bboxes[labels == i, :] for i in range(num_classes - 1)]
from .losses import (
weighted_nll_loss, weighted_cross_entropy, weighted_binary_cross_entropy,
sigmoid_focal_loss, weighted_sigmoid_focal_loss, mask_cross_entropy,
weighted_mask_cross_entropy, smooth_l1_loss, weighted_smoothl1, accuracy)
__all__ = [
'weighted_nll_loss', 'weighted_cross_entropy',
'weighted_binary_cross_entropy', 'sigmoid_focal_loss',
'weighted_sigmoid_focal_loss', 'mask_cross_entropy',
'weighted_mask_cross_entropy', 'smooth_l1_loss', 'weighted_smoothl1',
'accuracy'
]
# TODO merge naive and weighted loss to one function.
import torch
import torch.nn.functional as F
from ..bbox_ops import bbox_transform_inv, bbox_overlaps
def weighted_nll_loss(pred, label, weight, ave_factor=None):
if ave_factor is None:
ave_factor = max(torch.sum(weight > 0).float().item(), 1.)
raw = F.nll_loss(pred, label, size_average=False, reduce=False)
return torch.sum(raw * weight)[None] / ave_factor
def weighted_cross_entropy(pred, label, weight, ave_factor=None):
if ave_factor is None:
ave_factor = max(torch.sum(weight > 0).float().item(), 1.)
raw = F.cross_entropy(pred, label, size_average=False, reduce=False)
return torch.sum(raw * weight)[None] / ave_factor
def weighted_binary_cross_entropy(pred, label, weight, ave_factor=None):
if ave_factor is None:
ave_factor = max(torch.sum(weight > 0).float().item(), 1.)
return F.binary_cross_entropy_with_logits(
pred, label.float(), weight.float(),
size_average=False)[None] / ave_factor
def sigmoid_focal_loss(pred,
target,
weight,
gamma=2.0,
alpha=0.25,
size_average=True):
pred_sigmoid = pred.sigmoid()
pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
weight = weight * pt.pow(gamma)
return F.binary_cross_entropy_with_logits(
pred, target, weight, size_average=size_average)
def weighted_sigmoid_focal_loss(pred,
target,
weight,
gamma=2.0,
alpha=0.25,
ave_factor=None,
num_classes=80):
if ave_factor is None:
ave_factor = torch.sum(weight > 0).float().item() / num_classes + 1e-6
return sigmoid_focal_loss(
pred, target, weight, gamma=gamma, alpha=alpha,
size_average=False)[None] / ave_factor
def mask_cross_entropy(pred, target, label):
num_rois = pred.size()[0]
inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
pred_slice = pred[inds, label].squeeze(1)
return F.binary_cross_entropy_with_logits(
pred_slice, target, size_average=True)[None]
def weighted_mask_cross_entropy(pred, target, weight, label):
num_rois = pred.size()[0]
num_samples = torch.sum(weight > 0).float().item() + 1e-6
assert num_samples >= 1
inds = torch.arange(0, num_rois).long().cuda()
pred_slice = pred[inds, label].squeeze(1)
return F.binary_cross_entropy_with_logits(
pred_slice, target, weight, size_average=False)[None] / num_samples
def smooth_l1_loss(pred, target, beta=1.0, size_average=True, reduce=True):
assert beta > 0
assert pred.size() == target.size() and target.numel() > 0
diff = torch.abs(pred - target)
loss = torch.where(diff < beta, 0.5 * diff * diff / beta,
diff - 0.5 * beta)
if size_average:
loss /= pred.numel()
if reduce:
loss = loss.sum()
return loss
def weighted_smoothl1(pred, target, weight, beta=1.0, ave_factor=None):
if ave_factor is None:
ave_factor = torch.sum(weight > 0).float().item() / 4 + 1e-6
loss = smooth_l1_loss(pred, target, beta, size_average=False, reduce=False)
return torch.sum(loss * weight)[None] / ave_factor
def accuracy(pred, target, topk=1):
if isinstance(topk, int):
topk = (topk, )
return_single = True
maxk = max(topk)
_, pred_label = pred.topk(maxk, 1, True, True)
pred_label = pred_label.t()
correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / pred.size(0)))
return res[0] if return_single else res
...@@ -2,9 +2,10 @@ from .segms import (flip_segms, polys_to_mask, mask_to_bbox, ...@@ -2,9 +2,10 @@ from .segms import (flip_segms, polys_to_mask, mask_to_bbox,
polys_to_mask_wrt_box, polys_to_boxes, rle_mask_voting, polys_to_mask_wrt_box, polys_to_boxes, rle_mask_voting,
rle_mask_nms, rle_masks_to_boxes) rle_mask_nms, rle_masks_to_boxes)
from .utils import split_combined_gt_polys from .utils import split_combined_gt_polys
from .mask_target import mask_target
__all__ = [ __all__ = [
'flip_segms', 'polys_to_mask', 'mask_to_bbox', 'polys_to_mask_wrt_box', 'flip_segms', 'polys_to_mask', 'mask_to_bbox', 'polys_to_mask_wrt_box',
'polys_to_boxes', 'rle_mask_voting', 'rle_mask_nms', 'rle_masks_to_boxes', 'polys_to_boxes', 'rle_mask_voting', 'rle_mask_nms', 'rle_masks_to_boxes',
'split_combined_gt_polys' 'split_combined_gt_polys', 'mask_target'
] ]
import torch
import numpy as np
from .segms import polys_to_mask_wrt_box
def mask_target(pos_proposals_list, pos_assigned_gt_inds_list, gt_polys_list,
img_meta, cfg):
cfg_list = [cfg for _ in range(len(pos_proposals_list))]
img_metas = [img_meta for _ in range(len(pos_proposals_list))]
mask_targets = map(mask_target_single, pos_proposals_list,
pos_assigned_gt_inds_list, gt_polys_list, img_metas,
cfg_list)
mask_targets = torch.cat(tuple(mask_targets), dim=0)
return mask_targets
def mask_target_single(pos_proposals, pos_assigned_gt_inds, gt_polys,
img_meta, cfg):
mask_size = cfg.mask_size
num_pos = pos_proposals.size(0)
mask_targets = pos_proposals.new_zeros((num_pos, mask_size, mask_size))
if num_pos > 0:
pos_proposals = pos_proposals.cpu().numpy()
pos_assigned_gt_inds = pos_assigned_gt_inds.cpu().numpy()
scale_factor = img_meta['scale_factor'][0].cpu().numpy()
for i in range(num_pos):
bbox = pos_proposals[i, :] / scale_factor
polys = gt_polys[pos_assigned_gt_inds[i]]
mask = polys_to_mask_wrt_box(polys, bbox, mask_size)
mask = np.array(mask > 0, dtype=np.float32)
mask_targets[i, ...] = torch.from_numpy(mask).to(
mask_targets.device)
return mask_targets
import torch import torch
from mmcv.ops import nms from mmdet.ops import nms
import numpy as np import numpy as np
from ..bbox_ops import bbox_mapping_back from ..bbox_ops import bbox_mapping_back
......
from .anchor_generator import *
from .anchor_target import *
import torch
import numpy as np
from ..bbox_ops import (bbox_assign, bbox_transform, bbox_sampling)
def anchor_target(anchor_list, valid_flag_list, featmap_sizes, gt_bboxes_list,
img_shapes, target_means, target_stds, cfg):
"""Compute anchor regression and classification targets
Args:
anchor_list(list): anchors of each feature map level
featuremap_sizes(list): feature map sizes
gt_bboxes_list(list): ground truth bbox of images in a mini-batch
img_shapes(list): shape of each image in a mini-batch
cfg(dict): configs
Returns:
tuple
"""
if len(featmap_sizes) == len(anchor_list):
all_anchors = torch.cat(anchor_list, 0)
anchor_nums = [anchors.size(0) for anchors in anchor_list]
use_isomerism_anchors = False
elif len(img_shapes) == len(anchor_list):
# using different anchors for different images
all_anchors_list = [
torch.cat(anchor_list[img_id], 0)
for img_id in range(len(img_shapes))
]
anchor_nums = [anchors.size(0) for anchors in anchor_list[0]]
use_isomerism_anchors = True
else:
raise ValueError('length of anchor_list should be equal to number of '
'feature lvls or number of images in a batch')
all_labels = []
all_label_weights = []
all_bbox_targets = []
all_bbox_weights = []
num_total_sampled = 0
for img_id in range(len(img_shapes)):
if isinstance(valid_flag_list[img_id], list):
valid_flags = torch.cat(valid_flag_list[img_id], 0)
else:
valid_flags = valid_flag_list[img_id]
if use_isomerism_anchors:
all_anchors = all_anchors_list[img_id]
inside_flags = anchor_inside_flags(all_anchors, valid_flags,
img_shapes[img_id][:2],
cfg.allowed_border)
if not inside_flags.any():
return None
gt_bboxes = gt_bboxes_list[img_id]
anchor_targets = anchor_target_single(all_anchors, inside_flags,
gt_bboxes, target_means,
target_stds, cfg)
(labels, label_weights, bbox_targets, bbox_weights, pos_inds,
neg_inds) = anchor_targets
all_labels.append(labels)
all_label_weights.append(label_weights)
all_bbox_targets.append(bbox_targets)
all_bbox_weights.append(bbox_weights)
num_total_sampled += max(pos_inds.numel() + neg_inds.numel(), 1)
all_labels = torch.stack(all_labels, 0)
all_label_weights = torch.stack(all_label_weights, 0)
all_bbox_targets = torch.stack(all_bbox_targets, 0)
all_bbox_weights = torch.stack(all_bbox_weights, 0)
# split into different feature levels
labels_list = []
label_weights_list = []
bbox_targets_list = []
bbox_weights_list = []
start = 0
for anchor_num in anchor_nums:
end = start + anchor_num
labels_list.append(all_labels[:, start:end].squeeze(0))
label_weights_list.append(all_label_weights[:, start:end].squeeze(0))
bbox_targets_list.append(all_bbox_targets[:, start:end].squeeze(0))
bbox_weights_list.append(all_bbox_weights[:, start:end].squeeze(0))
start = end
return (labels_list, label_weights_list, bbox_targets_list,
bbox_weights_list, num_total_sampled)
def anchor_target_single(all_anchors, inside_flags, gt_bboxes, target_means,
target_stds, cfg):
num_total_anchors = all_anchors.size(0)
anchors = all_anchors[inside_flags, :]
assigned_gt_inds, argmax_overlaps, max_overlaps = bbox_assign(
anchors,
gt_bboxes,
pos_iou_thr=cfg.pos_iou_thr,
neg_iou_thr=cfg.neg_iou_thr,
min_pos_iou=cfg.min_pos_iou)
pos_inds, neg_inds = bbox_sampling(assigned_gt_inds, cfg.anchor_batch_size,
cfg.pos_fraction, cfg.neg_pos_ub,
cfg.pos_balance_sampling, max_overlaps,
cfg.neg_balance_thr)
bbox_targets = torch.zeros_like(anchors)
bbox_weights = torch.zeros_like(anchors)
labels = torch.zeros_like(assigned_gt_inds)
label_weights = torch.zeros_like(assigned_gt_inds, dtype=torch.float)
if len(pos_inds) > 0:
pos_inds = unique(pos_inds)
pos_anchors = anchors[pos_inds, :]
pos_gt_bbox = gt_bboxes[assigned_gt_inds[pos_inds] - 1, :]
pos_bbox_targets = bbox_transform(pos_anchors, pos_gt_bbox,
target_means, target_stds)
bbox_targets[pos_inds, :] = pos_bbox_targets
bbox_weights[pos_inds, :] = 1.0
labels[pos_inds] = 1
if cfg.pos_weight <= 0:
label_weights[pos_inds] = 1.0
else:
label_weights[pos_inds] = cfg.pos_weight
if len(neg_inds) > 0:
neg_inds = unique(neg_inds)
label_weights[neg_inds] = 1.0
# map up to original set of anchors
labels = unmap(labels, num_total_anchors, inside_flags)
label_weights = unmap(label_weights, num_total_anchors, inside_flags)
bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
neg_inds)
def anchor_inside_flags(all_anchors, valid_flags, img_shape, allowed_border=0):
img_h, img_w = img_shape.float()
if allowed_border >= 0:
inside_flags = valid_flags & \
(all_anchors[:, 0] >= -allowed_border) & \
(all_anchors[:, 1] >= -allowed_border) & \
(all_anchors[:, 2] < img_w + allowed_border) & \
(all_anchors[:, 3] < img_h + allowed_border)
else:
inside_flags = valid_flags
return inside_flags
def unique(tensor):
if tensor.is_cuda:
u_tensor = np.unique(tensor.cpu().numpy())
return tensor.new_tensor(u_tensor)
else:
return torch.unique(tensor)
def unmap(data, count, inds, fill=0):
""" Unmap a subset of item (data) back to the original set of items (of
size count) """
if data.dim() == 1:
ret = data.new_full((count, ), fill)
ret[inds] = data
else:
new_size = (count, ) + data.size()[1:]
ret = data.new_full(new_size, fill)
ret[inds, :] = data
return ret
from .anchor_target import anchor_target
from .bbox_target import bbox_target
from .mask_target import mask_target
__all__ = ['anchor_target', 'bbox_target', 'mask_target']
def anchor_target():
pass
def bbox_target():
pass
def mask_target():
pass
from mmdet.datasets import collate
from mmdet.nn.parallel import scatter
__all__ = ['_data_func']
def _data_func(data, gpu_id):
imgs, img_metas = tuple(
scatter(collate([data], samples_per_gpu=1), [gpu_id])[0])
return dict(
img=imgs,
img_meta=img_metas,
return_loss=False,
return_bboxes=True,
rescale=True)
import numpy as np
import torch
from collections import OrderedDict
from mmdet.nn.parallel import scatter
def parse_losses(losses):
log_vars = OrderedDict()
for loss_key, loss_value in losses.items():
if isinstance(loss_value, dict):
for _key, _value in loss_value.items():
if isinstance(_value, list):
_value = sum([_loss.mean() for _loss in _value])
else:
_value = _value.mean()
log_vars[_keys] = _value
elif isinstance(loss_value, list):
log_vars[loss_key] = sum(_loss.mean() for _loss in loss_value)
else:
log_vars[loss_key] = loss_value.mean()
loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)
log_vars['loss'] = loss
for _key, _value in log_vars.items():
log_vars[_key] = _value.item()
return loss, log_vars
def batch_processor(model, data, train_mode, args=None):
data = scatter(data, [torch.cuda.current_device()])[0]
losses = model(**data)
loss, log_vars = parse_losses(losses)
outputs = dict(
loss=loss / args.world_size,
log_vars=log_vars,
num_samples=len(data['img'].data))
return outputs
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment