Skip to content
Snippets Groups Projects
Forked from nikhil_rayaprolu / food-round2
737 commits behind the upstream repository.
sampling.py 12.70 KiB
import numpy as np
import torch

from .geometry import bbox_overlaps


def random_choice(gallery, num):
    """Draw ``num`` distinct entries from ``gallery`` uniformly at random.

    The draw is made by shuffling the index range with ``np.random`` and
    keeping the first ``num`` indices, so it follows numpy's global seed.

    Args:
        gallery(list, ndarray or Tensor): pool to draw from; a list is
            converted to an ndarray first
        num(int): number of entries to draw, at most ``len(gallery)``
    Returns:
        ndarray or Tensor: the drawn entries; an ndarray for list/ndarray
            input, otherwise the result of indexing ``gallery`` with a long
            index tensor placed on the same CUDA device as ``gallery``
    """
    assert len(gallery) >= num
    gallery = np.array(gallery) if isinstance(gallery, list) else gallery
    order = np.arange(len(gallery))
    np.random.shuffle(order)
    picked = order[:num]
    if isinstance(gallery, np.ndarray):
        return gallery[picked]
    # tensor input: the index has to be a long tensor on the gallery's device
    picked = torch.from_numpy(picked).long()
    if gallery.is_cuda:
        picked = picked.cuda(gallery.get_device())
    return gallery[picked]


def bbox_assign(proposals,
                gt_bboxes,
                gt_crowd_bboxes=None,
                gt_labels=None,
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=.0,
                crowd_thr=-1):
    """Match every proposal/anchor to a gt bbox, to background or to "ignore"

    The overlap matrix between proposals and gt bboxes is computed first.
    When crowd handling is active (crowd_thr > 0 and non-empty
    gt_crowd_bboxes), every proposal whose largest iof (intersection over
    foreground) with a crowd bbox exceeds crowd_thr has its whole overlap row
    set to -1. The matrix is then passed to bbox_assign_via_overlaps, which
    assigns each proposal -1 (don't care), 0 (negative sample) or the 1-based
    index of its gt.
    Args:
        proposals(Tensor): proposals or RPN anchors, shape (n, 4)
        gt_bboxes(Tensor): shape (k, 4)
        gt_crowd_bboxes(Tensor, optional): shape (m, 4)
        gt_labels(Tensor, optional): shape (k, )
        pos_iou_thr(float): iou threshold for positive bboxes
        neg_iou_thr(float or tuple): iou threshold for negative bboxes
        min_pos_iou(float): minimum iou for a bbox to be considered as a
                            positive bbox; for RPN it is usually set as 0,
                            for Fast R-CNN it is usually set as pos_iou_thr
        crowd_thr(float): iof threshold above which a proposal is ignored;
                          values <= 0 switch crowd handling off
    Returns:
        tuple: (assigned_gt_inds, argmax_overlaps, max_overlaps) when
            gt_labels is None, otherwise (assigned_gt_inds, assigned_labels,
            argmax_overlaps, max_overlaps); every entry has shape (n, )
    Raises:
        ValueError: if the overlap matrix has no elements
    """
    iou_matrix = bbox_overlaps(proposals, gt_bboxes)
    if iou_matrix.numel() == 0:
        raise ValueError('No gt bbox or proposals')

    use_crowd = (crowd_thr > 0 and gt_crowd_bboxes is not None
                 and gt_crowd_bboxes.numel() > 0)
    if use_crowd:
        crowd_iof = bbox_overlaps(proposals, gt_crowd_bboxes, mode='iof')
        max_crowd_iof = crowd_iof.max(dim=1)[0]
        ignored = torch.nonzero(max_crowd_iof > crowd_thr).long()
        if ignored.numel() > 0:
            # a row of -1 keeps the proposal out of both the positive and the
            # float-threshold negative tests of the assignment step
            iou_matrix[ignored, :] = -1

    return bbox_assign_via_overlaps(iou_matrix, gt_labels, pos_iou_thr,
                                    neg_iou_thr, min_pos_iou)


def bbox_assign_via_overlaps(overlaps,
                             gt_labels=None,
                             pos_iou_thr=0.5,
                             neg_iou_thr=0.5,
                             min_pos_iou=.0):
    """Assign a corresponding gt bbox or background to each proposal/anchor

    Each proposal is assigned -1, 0, or a positive number. -1 means don't
    care, 0 means negative sample, a positive number is the index (1-based)
    of the assigned gt.
    The assignment is done in the following steps, the order matters:
    1. assign every proposal to -1
    2. assign proposals whose max iou over all gts lies in the negative
    range to 0
    3. for each proposal, if the iou with its nearest gt >= pos_iou_thr,
    assign it to that gt
    4. for each gt whose best iou >= min_pos_iou, assign its nearest
    proposals (may be more than one, ties included) to itself
    Args:
        overlaps(Tensor): overlaps between n proposals and k gt_bboxes,
                          shape (n, k)
        gt_labels(Tensor, optional): shape (k, )
        pos_iou_thr(float): iou threshold for positive bboxes
        neg_iou_thr(number or pair): a single number t marks proposals with
                          0 <= max iou < t as negative; a (low, high)
                          tuple/list marks low <= max iou < high. Any other
                          type leaves step 2 out (no negatives assigned).
        min_pos_iou(float): minimum iou for a bbox to be considered as a
                            positive bbox; for RPN it is usually set as 0,
                            for Fast R-CNN it is usually set as pos_iou_thr
    Returns:
        tuple: (assigned_gt_inds, argmax_overlaps, max_overlaps) when
            gt_labels is None, otherwise (assigned_gt_inds, assigned_labels,
            argmax_overlaps, max_overlaps); every entry has shape (n, ).
            assigned_labels is 0 for every non-positive proposal.
    Raises:
        ValueError: if overlaps has no elements
    """
    num_bboxes, num_gts = overlaps.size(0), overlaps.size(1)
    if overlaps.numel() == 0:
        raise ValueError('No gt bbox or proposals')

    # 1. assign -1 by default
    assigned_gt_inds = overlaps.new_full((num_bboxes, ), -1, dtype=torch.long)

    # per proposal: best iou over all gts and which gt achieves it
    max_overlaps, argmax_overlaps = overlaps.max(dim=1)
    # per gt: best iou over all proposals
    gt_max_overlaps, _ = overlaps.max(dim=0)

    # 2. assign negative. Plain ints (e.g. 1) and list pairs are accepted as
    # well, they used to fall through silently and produced no negatives.
    neg_range = None
    if isinstance(neg_iou_thr, (int, float)):
        neg_range = (0, neg_iou_thr)
    elif isinstance(neg_iou_thr, (tuple, list)):
        assert len(neg_iou_thr) == 2
        neg_range = (neg_iou_thr[0], neg_iou_thr[1])
    if neg_range is not None:
        assigned_gt_inds[(max_overlaps >= neg_range[0])
                         & (max_overlaps < neg_range[1])] = 0

    # 3. assign positive: above positive IoU threshold
    pos_mask = max_overlaps >= pos_iou_thr
    assigned_gt_inds[pos_mask] = argmax_overlaps[pos_mask] + 1

    # 4. assign fg: for each gt, the proposals with the highest IoU. Later
    # gts overwrite earlier ones on shared proposals, so this stays a loop.
    for i in range(num_gts):
        if gt_max_overlaps[i] >= min_pos_iou:
            assigned_gt_inds[overlaps[:, i] == gt_max_overlaps[i]] = i + 1

    if gt_labels is None:
        return assigned_gt_inds, argmax_overlaps, max_overlaps

    assigned_labels = assigned_gt_inds.new_zeros((num_bboxes, ))
    # boolean mask instead of nonzero().squeeze(): keeps the index 1-d even
    # when there is exactly one positive
    fg_mask = assigned_gt_inds > 0
    assigned_labels[fg_mask] = gt_labels[assigned_gt_inds[fg_mask] - 1]
    return assigned_gt_inds, assigned_labels, argmax_overlaps, max_overlaps


def sample_positives(assigned_gt_inds, num_expected, balance_sampling=True):
    """Balance sampling for positive bboxes/anchors

    Positives are the entries of assigned_gt_inds that are > 0. If there are
    at most num_expected of them, all are returned. Otherwise, without
    balance_sampling, num_expected of them are drawn uniformly; with
    balance_sampling:
    1. calculate average positive num for each gt: num_per_gt
    2. sample at most num_per_gt positives for each gt
    3. if that yields too few, top up randomly from the positives that were
    not picked yet; if it yields too many, randomly cut down to num_expected
    Args:
        assigned_gt_inds(Tensor): per-proposal assignment, shape (n, )
        num_expected(int): maximum number of positives to return
        balance_sampling(bool): whether to balance positives across gts
    Returns:
        Tensor: indices (into assigned_gt_inds) of the sampled positives.
            When there is no positive at all, the untouched empty result of
            torch.nonzero is returned.
    """
    pos_inds = torch.nonzero(assigned_gt_inds > 0)
    if pos_inds.numel() != 0:
        pos_inds = pos_inds.squeeze(1)
    if pos_inds.numel() <= num_expected:
        return pos_inds
    if not balance_sampling:
        return random_choice(pos_inds, num_expected)

    unique_gt_inds = torch.unique(assigned_gt_inds[pos_inds].cpu())
    num_gts = len(unique_gt_inds)
    num_per_gt = int(round(num_expected / float(num_gts)) + 1)
    sampled_inds = []
    for gt_ind in unique_gt_inds.tolist():
        inds = torch.nonzero(assigned_gt_inds == gt_ind)
        if inds.numel() == 0:
            continue
        inds = inds.squeeze(1)
        if len(inds) > num_per_gt:
            inds = random_choice(inds, num_per_gt)
        sampled_inds.append(inds)
    sampled_inds = torch.cat(sampled_inds)

    if len(sampled_inds) < num_expected:
        num_extra = num_expected - len(sampled_inds)
        # Compare plain ints: tensors hash by identity, so a set difference
        # over 0-d tensors removes nothing and the extras could repeat
        # indices that were already sampled.
        already_sampled = set(sampled_inds.tolist())
        extra_inds = np.array(
            [i for i in pos_inds.tolist() if i not in already_sampled],
            dtype=np.int64)
        if len(extra_inds) > num_extra:
            extra_inds = random_choice(extra_inds, num_extra)
        extra_inds = torch.from_numpy(extra_inds).to(
            assigned_gt_inds.device).long()
        sampled_inds = torch.cat([sampled_inds, extra_inds])
    elif len(sampled_inds) > num_expected:
        sampled_inds = random_choice(sampled_inds, num_expected)
    return sampled_inds


def sample_negatives(assigned_gt_inds,
                     num_expected,
                     max_overlaps=None,
                     balance_thr=0,
                     hard_fraction=0.5):
    """Balance sampling for negative bboxes/anchors

    Negatives are the entries of assigned_gt_inds that equal 0. If there are
    at most num_expected of them, all are returned. With balance_thr <= 0
    num_expected of them are drawn uniformly. Otherwise negatives are split
    into 2 sets by their max overlap: hard (iou >= balance_thr) and easy
    (0 <= iou < balance_thr). About num_expected * hard_fraction are drawn
    from the hard set, the rest from the easy set, and any shortfall is
    filled randomly from the negatives that were not picked yet.
    Args:
        assigned_gt_inds(Tensor): per-proposal assignment, shape (n, )
        num_expected(int): maximum number of negatives to return
        max_overlaps(Tensor, optional): max iou of each proposal, indexed
            like assigned_gt_inds; required when balance_thr > 0
        balance_thr(float): iou separating easy from hard negatives;
            values <= 0 disable balancing
        hard_fraction(float): share of num_expected reserved for hard
            negatives
    Returns:
        Tensor: indices (into assigned_gt_inds) of the sampled negatives.
            When there is no negative at all, the untouched empty result of
            torch.nonzero is returned.
    """
    neg_inds = torch.nonzero(assigned_gt_inds == 0)
    if neg_inds.numel() != 0:
        neg_inds = neg_inds.squeeze(1)
    if len(neg_inds) <= num_expected:
        return neg_inds
    if balance_thr <= 0:
        # uniform sampling among all negative samples
        return random_choice(neg_inds, num_expected)

    assert max_overlaps is not None
    overlaps_np = max_overlaps.cpu().numpy()
    neg_np = neg_inds.cpu().numpy()

    # Index sets are intersected (rather than indexing overlaps_np with
    # neg_np) so that the result only ever contains real negatives.
    easy_cands = np.where(
        np.logical_and(overlaps_np >= 0, overlaps_np < balance_thr))[0]
    hard_cands = np.where(overlaps_np >= balance_thr)[0]
    # intersect1d returns sorted int64 arrays, also when empty; this replaces
    # the former `dtype=np.int`, an alias removed in NumPy 1.24
    easy_neg_inds = np.intersect1d(neg_np, easy_cands)
    hard_neg_inds = np.intersect1d(neg_np, hard_cands)

    num_expected_hard = int(num_expected * hard_fraction)
    if len(hard_neg_inds) > num_expected_hard:
        sampled_hard_inds = random_choice(hard_neg_inds, num_expected_hard)
    else:
        sampled_hard_inds = hard_neg_inds
    # whatever the hard set could not deliver is taken from the easy set
    num_expected_easy = num_expected - len(sampled_hard_inds)
    if len(easy_neg_inds) > num_expected_easy:
        sampled_easy_inds = random_choice(easy_neg_inds, num_expected_easy)
    else:
        sampled_easy_inds = easy_neg_inds
    sampled_inds = np.concatenate((sampled_easy_inds, sampled_hard_inds))

    if len(sampled_inds) < num_expected:
        num_extra = num_expected - len(sampled_inds)
        extra_inds = np.setdiff1d(neg_np, sampled_inds)
        if len(extra_inds) > num_extra:
            extra_inds = random_choice(extra_inds, num_extra)
        sampled_inds = np.concatenate((sampled_inds, extra_inds))
    return torch.from_numpy(sampled_inds).long().to(assigned_gt_inds.device)


def bbox_sampling(assigned_gt_inds,
                  num_expected,
                  pos_fraction,
                  neg_pos_ub,
                  pos_balance_sampling=True,
                  max_overlaps=None,
                  neg_balance_thr=0,
                  neg_hard_fraction=0.5):
    """Sample positive and negative proposals from an assignment

    Up to int(num_expected * pos_fraction) positives are sampled first. The
    negative quota is then the smaller of the remaining budget
    (num_expected minus the sampled positives) and an upper bound of
    int(neg_pos_ub * sampled positives), or int(neg_pos_ub) when no positive
    was sampled.
    Args:
        assigned_gt_inds(Tensor): per-proposal assignment, shape (n, )
        num_expected(int): total number of samples wanted
        pos_fraction(float): share of num_expected reserved for positives
        neg_pos_ub(number): upper bound on the negative/positive ratio
        pos_balance_sampling(bool): forwarded to sample_positives
        max_overlaps(Tensor, optional): forwarded to sample_negatives
        neg_balance_thr(float): forwarded to sample_negatives
        neg_hard_fraction(float): forwarded to sample_negatives
    Returns:
        tuple: (pos_inds, neg_inds), index tensors into assigned_gt_inds
    """
    pos_quota = int(num_expected * pos_fraction)
    pos_inds = sample_positives(assigned_gt_inds, pos_quota,
                                pos_balance_sampling)
    num_pos = pos_inds.numel()

    if num_pos > 0:
        neg_cap = int(neg_pos_ub * num_pos)
    else:
        neg_cap = int(neg_pos_ub)
    neg_quota = min(neg_cap, num_expected - num_pos)

    neg_inds = sample_negatives(assigned_gt_inds, neg_quota, max_overlaps,
                                neg_balance_thr, neg_hard_fraction)
    return pos_inds, neg_inds



def sample_proposals(proposals_list, gt_bboxes_list, gt_crowds_list,
                     gt_labels_list, cfg):
    """Run sample_proposals_single on every image of a batch

    The i-th entries of the four input lists are sampled together with the
    shared cfg. The per-image result tuples are then transposed, so the
    return value is a tuple holding one list per output field of
    sample_proposals_single, each list having one entry per image.
    """
    shared_cfg = [cfg] * len(proposals_list)
    per_image = [
        sample_proposals_single(*img_args)
        for img_args in zip(proposals_list, gt_bboxes_list, gt_crowds_list,
                            gt_labels_list, shared_cfg)
    ]
    # list of tuple to tuple of list
    return tuple(list(field) for field in zip(*per_image))


def sample_proposals_single(proposals,
                            gt_bboxes,
                            gt_crowds,
                            gt_labels,
                            cfg):
    """Assign and sample the proposals of a single image

    Only the first 4 columns of proposals are used. Proposals are assigned
    with bbox_assign (cfg.pos_iou_thr is used both as the positive threshold
    and as min_pos_iou), optionally extended by the gt bboxes themselves
    (cfg.add_gt_as_proposals), and then sampled with bbox_sampling.
    Args:
        proposals(Tensor): shape (n, >=4)
        gt_bboxes(Tensor): shape (k, 4)
        gt_crowds(Tensor or None): crowd bboxes, passed to bbox_assign
        gt_labels(Tensor): shape (k, )
        cfg: object providing pos_iou_thr, neg_iou_thr, crowd_thr,
            add_gt_as_proposals, roi_batch_size, pos_fraction, neg_pos_ub,
            pos_balance_sampling and neg_balance_thr
    Returns:
        tuple: (pos_inds, neg_inds, pos_proposals, neg_proposals,
            pos_assigned_gt_inds, pos_gt_bboxes, pos_gt_labels), where
            pos_assigned_gt_inds is 0-based. Indices refer to the proposal
            list after the gts were prepended, if they were.
    """
    proposals = proposals[:, :4]
    assigned_gt_inds, assigned_labels, _, max_overlaps = \
        bbox_assign(
            proposals, gt_bboxes, gt_crowds, gt_labels, cfg.pos_iou_thr,
            cfg.neg_iou_thr, cfg.pos_iou_thr, cfg.crowd_thr)
    if cfg.add_gt_as_proposals:
        num_gts = len(gt_labels)
        proposals = torch.cat([gt_bboxes, proposals], dim=0)
        gt_assign_self = torch.arange(
            1, num_gts + 1, dtype=torch.long, device=proposals.device)
        assigned_gt_inds = torch.cat([gt_assign_self, assigned_gt_inds])
        assigned_labels = torch.cat([gt_labels, assigned_labels])
        # Keep max_overlaps aligned with the extended proposal list (each gt
        # overlaps itself with IoU 1). Without this, balanced negative
        # sampling compares indices that are shifted by num_gts.
        max_overlaps = torch.cat(
            [max_overlaps.new_ones(num_gts), max_overlaps])

    pos_inds, neg_inds = bbox_sampling(
        assigned_gt_inds, cfg.roi_batch_size, cfg.pos_fraction, cfg.neg_pos_ub,
        cfg.pos_balance_sampling, max_overlaps, cfg.neg_balance_thr)
    pos_proposals = proposals[pos_inds]
    neg_proposals = proposals[neg_inds]
    # assigned indices are 1-based; shift to index into gt_bboxes
    pos_assigned_gt_inds = assigned_gt_inds[pos_inds] - 1
    pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :]
    pos_gt_labels = assigned_labels[pos_inds]

    return (pos_inds, neg_inds, pos_proposals, neg_proposals,
            pos_assigned_gt_inds, pos_gt_bboxes, pos_gt_labels)