Skip to content
Snippets Groups Projects
Unverified Commit 9df04d54 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Potential bug fix for GuidedAnchorHead (#754)

* code formatting for guided_anchor_head.py

* bug fix for using multi_apply
parent 726ebdc9
No related branches found
No related tags found
No related merge requests found
......@@ -36,15 +36,14 @@ class FeatureAdaption(nn.Module):
deformable_groups=4):
super(FeatureAdaption, self).__init__()
offset_channels = kernel_size * kernel_size * 2
self.conv_offset = nn.Conv2d(2,
deformable_groups * offset_channels,
1,
bias=False)
self.conv_adaption = DeformConv(in_channels,
out_channels,
kernel_size=kernel_size,
padding=(kernel_size - 1) // 2,
deformable_groups=deformable_groups)
self.conv_offset = nn.Conv2d(
2, deformable_groups * offset_channels, 1, bias=False)
self.conv_adaption = DeformConv(
in_channels,
out_channels,
kernel_size=kernel_size,
padding=(kernel_size - 1) // 2,
deformable_groups=deformable_groups)
self.relu = nn.ReLU(inplace=True)
def init_weights(self):
......@@ -109,20 +108,23 @@ class GuidedAnchorHead(AnchorHead):
target_stds=(1.0, 1.0, 1.0, 1.0),
deformable_groups=4,
loc_filter_thr=0.01,
loss_loc=dict(type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_shape=dict(type='IoULoss',
style='bounded',
beta=0.2,
loss_weight=1.0),
loss_cls=dict(type='CrossEntropyLoss',
use_sigmoid=True,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
loss_weight=1.0)):
loss_loc=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_shape=dict(
type='IoULoss',
style='bounded',
beta=0.2,
loss_weight=1.0),
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
loss_weight=1.0),
loss_bbox=dict(
type='SmoothL1Loss', beta=1.0, loss_weight=1.0)):
super(AnchorHead, self).__init__()
self.in_channels = in_channels
self.num_classes = num_classes
......@@ -258,8 +260,8 @@ class GuidedAnchorHead(AnchorHead):
inside_flags_list.append(inside_flags)
# inside_flag for a position is true if any anchor in this
# position is true
inside_flags = (torch.stack(inside_flags_list, 0).sum(dim=0) >
0)
inside_flags = (
torch.stack(inside_flags_list, 0).sum(dim=0) > 0)
multi_level_flags.append(inside_flags)
inside_flag_list.append(multi_level_flags)
return approxs_list, inside_flag_list
......@@ -347,11 +349,12 @@ class GuidedAnchorHead(AnchorHead):
-1, 2).detach()[mask]
bbox_deltas = anchor_deltas.new_full(squares.size(), 0)
bbox_deltas[:, 2:] = anchor_deltas
guided_anchors = delta2bbox(squares,
bbox_deltas,
self.anchoring_means,
self.anchoring_stds,
wh_ratio_clip=1e-6)
guided_anchors = delta2bbox(
squares,
bbox_deltas,
self.anchoring_means,
self.anchoring_stds,
wh_ratio_clip=1e-6)
return guided_anchors, mask
def loss_shape_single(self, shape_pred, bbox_anchors, bbox_gts,
......@@ -368,23 +371,26 @@ class GuidedAnchorHead(AnchorHead):
bbox_anchors_ = bbox_anchors[inds]
bbox_gts_ = bbox_gts[inds]
anchor_weights_ = anchor_weights[inds]
pred_anchors_ = delta2bbox(bbox_anchors_,
bbox_deltas_,
self.anchoring_means,
self.anchoring_stds,
wh_ratio_clip=1e-6)
loss_shape = self.loss_shape(pred_anchors_,
bbox_gts_,
anchor_weights_,
avg_factor=anchor_total_num)
pred_anchors_ = delta2bbox(
bbox_anchors_,
bbox_deltas_,
self.anchoring_means,
self.anchoring_stds,
wh_ratio_clip=1e-6)
loss_shape = self.loss_shape(
pred_anchors_,
bbox_gts_,
anchor_weights_,
avg_factor=anchor_total_num)
return loss_shape
def loss_loc_single(self, loc_pred, loc_target, loc_weight, loc_avg_factor,
cfg):
loss_loc = self.loss_loc(loc_pred.reshape(-1, 1),
loc_target.reshape(-1, 1).long(),
loc_weight.reshape(-1, 1),
avg_factor=loc_avg_factor)
loss_loc = self.loss_loc(
loc_pred.reshape(-1, 1),
loc_target.reshape(-1, 1).long(),
loc_weight.reshape(-1, 1),
avg_factor=loc_avg_factor)
return loss_loc
def loss(self,
......@@ -418,41 +424,44 @@ class GuidedAnchorHead(AnchorHead):
# get shape targets
sampling = False if not hasattr(cfg, 'ga_sampler') else True
shape_targets = ga_shape_target(approxs_list,
inside_flag_list,
squares_list,
gt_bboxes,
img_metas,
self.approxs_per_octave,
cfg,
sampling=sampling)
shape_targets = ga_shape_target(
approxs_list,
inside_flag_list,
squares_list,
gt_bboxes,
img_metas,
self.approxs_per_octave,
cfg,
sampling=sampling)
if shape_targets is None:
return None
(bbox_anchors_list, bbox_gts_list, anchor_weights_list, anchor_fg_num,
anchor_bg_num) = shape_targets
anchor_total_num = (anchor_fg_num if not sampling else anchor_fg_num +
anchor_bg_num)
anchor_total_num = (
anchor_fg_num if not sampling else anchor_fg_num + anchor_bg_num)
# get anchor targets
sampling = False if self.cls_focal_loss else True
label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
cls_reg_targets = anchor_target(guided_anchors_list,
inside_flag_list,
gt_bboxes,
img_metas,
self.target_means,
self.target_stds,
cfg,
gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels,
label_channels=label_channels,
sampling=sampling)
cls_reg_targets = anchor_target(
guided_anchors_list,
inside_flag_list,
gt_bboxes,
img_metas,
self.target_means,
self.target_stds,
cfg,
gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels,
label_channels=label_channels,
sampling=sampling)
if cls_reg_targets is None:
return None
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
num_total_pos, num_total_neg) = cls_reg_targets
num_total_samples = (num_total_pos if self.cls_focal_loss else
num_total_pos + num_total_neg)
num_total_samples = (
num_total_pos if self.cls_focal_loss else num_total_pos +
num_total_neg)
# get classification and bbox regression losses
losses_cls, losses_bbox = multi_apply(
......@@ -467,24 +476,32 @@ class GuidedAnchorHead(AnchorHead):
cfg=cfg)
# get anchor location loss
losses_loc, = multi_apply(self.loss_loc_single,
loc_preds,
loc_targets,
loc_weights,
loc_avg_factor=loc_avg_factor,
cfg=cfg)
losses_loc = []
for i in range(len(loc_preds)):
loss_loc = self.loss_loc_single(
loc_preds[i],
loc_targets[i],
loc_weights[i],
loc_avg_factor=loc_avg_factor,
cfg=cfg)
losses_loc.append(loss_loc)
# get anchor shape loss
losses_shape, = multi_apply(self.loss_shape_single,
shape_preds,
bbox_anchors_list,
bbox_gts_list,
anchor_weights_list,
anchor_total_num=anchor_total_num)
return dict(loss_cls=losses_cls,
loss_bbox=losses_bbox,
loss_shape=losses_shape,
loss_loc=losses_loc)
losses_shape = []
for i in range(len(shape_preds)):
loss_shape = self.loss_shape_single(
shape_preds[i],
bbox_anchors_list[i],
bbox_gts_list[i],
anchor_weights_list[i],
anchor_total_num=anchor_total_num)
losses_shape.append(loss_shape)
return dict(
loss_cls=losses_cls,
loss_bbox=losses_bbox,
loss_shape=losses_shape,
loss_loc=losses_loc)
def get_bboxes(self,
cls_scores,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment