diff --git a/README.md b/README.md
index 2b666b2de461aba476be62d257b55e7c17291d63..0a77858939a8e1de0bac9d2656fa215cf38c629e 100644
--- a/README.md
+++ b/README.md
@@ -70,6 +70,7 @@ Results and models are available in the [Model zoo](docs/MODEL_ZOO.md).
 | Foveabox           | ✓        | ✓        | ☐        | ✗        | ✓     |
 | FreeAnchor         | ✓        | ✓        | ☐        | ✗        | ✓     |
 | NAS-FPN            | ✓        | ✓        | ☐        | ✗        | ✓     |
+| ATSS               | ✓        | ✓        | ☐        | ✗        | ✓     |
 
 Other features
 - [x] DCNv2
diff --git a/configs/atss/README.md b/configs/atss/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2457358ccce01bc3457f6d958b3da04456678211
--- /dev/null
+++ b/configs/atss/README.md
@@ -0,0 +1,20 @@
+# Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection
+
+
+## Introduction
+
+```
+@article{zhang2019bridging,
+  title   =  {Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection},
+  author  =  {Zhang, Shifeng and Chi, Cheng and Yao, Yongqiang and Lei, Zhen and Li, Stan Z.},
+  journal =  {arXiv preprint arXiv:1912.02424},
+  year    =  {2019}
+}
+```
+
+
+## Results and Models
+
+| Backbone  | Style   | Lr schd | Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | Download |
+|:---------:|:-------:|:-------:|:--------:|:-------------------:|:--------------:|:------:|:--------:|
+| R-50      | pytorch | 1x      | 3.6      | 0.357               | 12.8           |  39.2  | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/atss/atss_r50_fpn_1x_20200113-a7aa251e.pth)|
diff --git a/configs/atss/atss_r50_fpn_1x.py b/configs/atss/atss_r50_fpn_1x.py
new file mode 100644
index 0000000000000000000000000000000000000000..e19a39da967e588072692505048bf7cc9d9ff128
--- /dev/null
+++ b/configs/atss/atss_r50_fpn_1x.py
@@ -0,0 +1,127 @@
+# model settings
+model = dict(
+    type='ATSS',
+    pretrained='torchvision://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        style='pytorch'),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        start_level=1,
+        add_extra_convs=True,
+        extra_convs_on_inputs=False,
+        num_outs=5),
+    bbox_head=dict(
+        type='ATSSHead',
+        num_classes=81,
+        in_channels=256,
+        stacked_convs=4,
+        feat_channels=256,
+        octave_base_scale=8,
+        scales_per_octave=1,
+        anchor_ratios=[1.0],
+        anchor_strides=[8, 16, 32, 64, 128],
+        target_means=[.0, .0, .0, .0],
+        target_stds=[0.1, 0.1, 0.2, 0.2],
+        loss_cls=dict(
+            type='FocalLoss',
+            use_sigmoid=True,
+            gamma=2.0,
+            alpha=0.25,
+            loss_weight=1.0),
+        loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
+        loss_centerness=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)))
+# training and testing settings
+train_cfg = dict(
+    assigner=dict(type='ATSSAssigner', topk=9),
+    allowed_border=-1,
+    pos_weight=-1,
+    debug=False)
+test_cfg = dict(
+    nms_pre=1000,
+    min_bbox_size=0,
+    score_thr=0.05,
+    nms=dict(type='nms', iou_thr=0.6),
+    max_per_img=100)
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+    dict(type='RandomFlip', flip_ratio=0.5),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size_divisor=32),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(1333, 800),
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='Pad', size_divisor=32),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    imgs_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=test_pipeline))
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    step=[8, 11])
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+# runtime settings
+total_epochs = 12
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = './work_dirs/atss_r50_fpn_1x'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/docs/MODEL_ZOO.md b/docs/MODEL_ZOO.md
index 6b93186b988de8e3c5424d9d8c0ad27217a1d941..954a4772a2114b8983d8a3fa656104a1f64b8bdb 100644
--- a/docs/MODEL_ZOO.md
+++ b/docs/MODEL_ZOO.md
@@ -284,6 +284,9 @@ Please refer to [Rethinking ImageNet Pre-training](https://github.com/open-mmlab
 ### NAS-FPN
 Please refer to [NAS-FPN](https://github.com/open-mmlab/mmdetection/blob/master/configs/nas_fpn) for details.
 
+### ATSS
+Please refer to [ATSS](https://github.com/open-mmlab/mmdetection/blob/master/configs/atss) for details.
+
 ### Other datasets
 
 We also benchmark some methods on [PASCAL VOC](https://github.com/open-mmlab/mmdetection/blob/master/configs/pascal_voc), [Cityscapes](https://github.com/open-mmlab/mmdetection/blob/master/configs/cityscapes) and [WIDER FACE](https://github.com/open-mmlab/mmdetection/blob/master/configs/wider_face).
diff --git a/mmdet/core/anchor/__init__.py b/mmdet/core/anchor/__init__.py
index dfeb3b407300beee4d530c111b4187693b190e2b..06e2d12323bbf9dc95ba84a6aaddfa1ac6d5f5bd 100644
--- a/mmdet/core/anchor/__init__.py
+++ b/mmdet/core/anchor/__init__.py
@@ -1,10 +1,12 @@
 from .anchor_generator import AnchorGenerator
-from .anchor_target import anchor_inside_flags, anchor_target
+from .anchor_target import (anchor_inside_flags, anchor_target,
+                            images_to_levels, unmap)
 from .guided_anchor_target import ga_loc_target, ga_shape_target
 from .point_generator import PointGenerator
 from .point_target import point_target
 
 __all__ = [
     'AnchorGenerator', 'anchor_target', 'anchor_inside_flags', 'ga_loc_target',
-    'ga_shape_target', 'PointGenerator', 'point_target'
+    'ga_shape_target', 'PointGenerator', 'point_target', 'images_to_levels',
+    'unmap'
 ]
diff --git a/mmdet/core/bbox/assigners/__init__.py b/mmdet/core/bbox/assigners/__init__.py
index 93eebb775be7720f232f122050d5f753117f7731..4ed1d5643184e4c9b700bd9dd1ca39072f1e740a 100644
--- a/mmdet/core/bbox/assigners/__init__.py
+++ b/mmdet/core/bbox/assigners/__init__.py
@@ -1,10 +1,11 @@
 from .approx_max_iou_assigner import ApproxMaxIoUAssigner
 from .assign_result import AssignResult
+from .atss_assigner import ATSSAssigner
 from .base_assigner import BaseAssigner
 from .max_iou_assigner import MaxIoUAssigner
 from .point_assigner import PointAssigner
 
 __all__ = [
     'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult',
-    'PointAssigner'
+    'PointAssigner', 'ATSSAssigner'
 ]
diff --git a/mmdet/core/bbox/assigners/atss_assigner.py b/mmdet/core/bbox/assigners/atss_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..e442ac709d854a43883387ae88246776b5744743
--- /dev/null
+++ b/mmdet/core/bbox/assigners/atss_assigner.py
@@ -0,0 +1,159 @@
+import torch
+
+from ..geometry import bbox_overlaps
+from .assign_result import AssignResult
+from .base_assigner import BaseAssigner
+
+
+class ATSSAssigner(BaseAssigner):
+    """Assign a corresponding gt bbox or background to each bbox.
+
+    Each proposals will be assigned with `0` or a positive integer
+    indicating the ground truth index.
+
+    - 0: negative sample, no assigned gt
+    - positive integer: positive sample, index (1-based) of assigned gt
+
+    Args:
+        topk (float): number of bbox selected in each level
+    """
+
+    def __init__(self, topk):
+        self.topk = topk
+
+    # https://github.com/sfzhang15/ATSS/blob/master/atss_core/modeling/rpn/atss/loss.py
+
+    def assign(self,
+               bboxes,
+               num_level_bboxes,
+               gt_bboxes,
+               gt_bboxes_ignore=None,
+               gt_labels=None):
+        """Assign gt to bboxes.
+
+        The assignment is done in following steps
+
+        1. compute iou between all bbox (bbox of all pyramid levels) and gt
+        2. compute center distance between all bbox and gt
+        3. on each pyramid level, for each gt, select k bbox whose center
+           are closest to the gt center, so we total select k*l bbox as
+           candidates for each gt
+        4. get corresponding iou for the these candidates, and compute the
+           mean and std, set mean + std as the iou threshold
+        5. select these candidates whose iou are greater than or equal to
+           the threshold as postive
+        6. limit the positive sample's center in gt
+
+
+        Args:
+            bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
+            num_level_bboxes (List): num of bboxes in each level
+            gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
+            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+                labelled as `ignored`, e.g., crowd boxes in COCO.
+            gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
+
+        Returns:
+            :obj:`AssignResult`: The assign result.
+        """
+        INF = 100000000
+        bboxes = bboxes[:, :4]
+        num_gt, num_bboxes = gt_bboxes.size(0), bboxes.size(0)
+
+        # compute iou between all bbox and gt
+        overlaps = bbox_overlaps(bboxes, gt_bboxes)
+
+        # assign 0 by default
+        assigned_gt_inds = overlaps.new_full((num_bboxes, ),
+                                             0,
+                                             dtype=torch.long)
+
+        if num_gt == 0 or num_bboxes == 0:
+            # No ground truth or boxes, return empty assignment
+            max_overlaps = overlaps.new_zeros((num_bboxes, ))
+            if num_gt == 0:
+                # No truth, assign everything to background
+                assigned_gt_inds[:] = 0
+            if gt_labels is None:
+                assigned_labels = None
+            else:
+                assigned_labels = overlaps.new_zeros((num_bboxes, ),
+                                                     dtype=torch.long)
+            return AssignResult(
+                num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
+
+        # compute center distance between all bbox and gt
+        gt_cx = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0
+        gt_cy = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0
+        gt_points = torch.stack((gt_cx, gt_cy), dim=1)
+
+        bboxes_cx = (bboxes[:, 0] + bboxes[:, 2]) / 2.0
+        bboxes_cy = (bboxes[:, 1] + bboxes[:, 3]) / 2.0
+        bboxes_points = torch.stack((bboxes_cx, bboxes_cy), dim=1)
+
+        distances = (bboxes_points[:, None, :] -
+                     gt_points[None, :, :]).pow(2).sum(-1).sqrt()
+
+        # Selecting candidates based on the center distance
+        candidate_idxs = []
+        start_idx = 0
+        for level, bboxes_per_level in enumerate(num_level_bboxes):
+            # on each pyramid level, for each gt,
+            # select k bbox whose center are closest to the gt center
+            end_idx = start_idx + bboxes_per_level
+            distances_per_level = distances[start_idx:end_idx, :]
+            _, topk_idxs_per_level = distances_per_level.topk(
+                self.topk, dim=0, largest=False)
+            candidate_idxs.append(topk_idxs_per_level + start_idx)
+            start_idx = end_idx
+        candidate_idxs = torch.cat(candidate_idxs, dim=0)
+
+        # get corresponding iou for the these candidates, and compute the
+        # mean and std, set mean + std as the iou threshold
+        candidate_overlaps = overlaps[candidate_idxs, torch.arange(num_gt)]
+        overlaps_mean_per_gt = candidate_overlaps.mean(0)
+        overlaps_std_per_gt = candidate_overlaps.std(0)
+        overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt
+
+        is_pos = candidate_overlaps >= overlaps_thr_per_gt[None, :]
+
+        # limit the positive sample's center in gt
+        for gt_idx in range(num_gt):
+            candidate_idxs[:, gt_idx] += gt_idx * num_bboxes
+        ep_bboxes_cx = bboxes_cx.view(1, -1).expand(
+            num_gt, num_bboxes).contiguous().view(-1)
+        ep_bboxes_cy = bboxes_cy.view(1, -1).expand(
+            num_gt, num_bboxes).contiguous().view(-1)
+        candidate_idxs = candidate_idxs.view(-1)
+
+        # calculate the left, top, right, bottom distance between positive
+        # bbox center and gt side
+        l_ = ep_bboxes_cx[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 0]
+        t_ = ep_bboxes_cy[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 1]
+        r_ = gt_bboxes[:, 2] - ep_bboxes_cx[candidate_idxs].view(-1, num_gt)
+        b_ = gt_bboxes[:, 3] - ep_bboxes_cy[candidate_idxs].view(-1, num_gt)
+        is_in_gts = torch.stack([l_, t_, r_, b_], dim=1).min(dim=1)[0] > 0.01
+        is_pos = is_pos & is_in_gts
+
+        # if an anchor box is assigned to multiple gts,
+        # the one with the highest IoU will be selected.
+        overlaps_inf = torch.full_like(overlaps,
+                                       -INF).t().contiguous().view(-1)
+        index = candidate_idxs.view(-1)[is_pos.view(-1)]
+        overlaps_inf[index] = overlaps.t().contiguous().view(-1)[index]
+        overlaps_inf = overlaps_inf.view(num_gt, -1).t()
+
+        max_overlaps, argmax_overlaps = overlaps_inf.max(dim=1)
+        assigned_gt_inds[
+            max_overlaps != -INF] = argmax_overlaps[max_overlaps != -INF] + 1
+
+        if gt_labels is not None:
+            assigned_labels = assigned_gt_inds.new_zeros((num_bboxes, ))
+            pos_inds = torch.nonzero(assigned_gt_inds > 0).squeeze()
+            if pos_inds.numel() > 0:
+                assigned_labels[pos_inds] = gt_labels[
+                    assigned_gt_inds[pos_inds] - 1]
+        else:
+            assigned_labels = None
+        return AssignResult(
+            num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
diff --git a/mmdet/models/anchor_heads/__init__.py b/mmdet/models/anchor_heads/__init__.py
index c693c0fd209cf34eaf01a6ce5b28e40d883b6ea6..3364c4c9113a7f9f816bd6a56f515d186a9821ce 100644
--- a/mmdet/models/anchor_heads/__init__.py
+++ b/mmdet/models/anchor_heads/__init__.py
@@ -1,4 +1,5 @@
 from .anchor_head import AnchorHead
+from .atss_head import ATSSHead
 from .fcos_head import FCOSHead
 from .fovea_head import FoveaHead
 from .free_anchor_retina_head import FreeAnchorRetinaHead
@@ -14,5 +15,6 @@ from .ssd_head import SSDHead
 __all__ = [
     'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption', 'RPNHead',
     'GARPNHead', 'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead', 'SSDHead',
-    'FCOSHead', 'RepPointsHead', 'FoveaHead', 'FreeAnchorRetinaHead'
+    'FCOSHead', 'RepPointsHead', 'FoveaHead', 'FreeAnchorRetinaHead',
+    'ATSSHead'
 ]
diff --git a/mmdet/models/anchor_heads/atss_head.py b/mmdet/models/anchor_heads/atss_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0f2e0abc2703beef4a33d12d53b78393cb7bd5c
--- /dev/null
+++ b/mmdet/models/anchor_heads/atss_head.py
@@ -0,0 +1,487 @@
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from mmcv.cnn import normal_init
+
+from mmdet.core import (PseudoSampler, anchor_inside_flags, bbox2delta,
+                        build_assigner, delta2bbox, force_fp32,
+                        images_to_levels, multi_apply, multiclass_nms, unmap)
+from ..builder import build_loss
+from ..registry import HEADS
+from ..utils import ConvModule, Scale, bias_init_with_prob
+from .anchor_head import AnchorHead
+
+
+def reduce_mean(tensor):
+    if not (dist.is_available() and dist.is_initialized()):
+        return tensor
+    tensor = tensor.clone()
+    dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.reduce_op.SUM)
+    return tensor
+
+
+@HEADS.register_module
+class ATSSHead(AnchorHead):
+    """
+    Bridging the Gap Between Anchor-based and Anchor-free Detection via
+    Adaptive Training Sample Selection
+
+    ATSS head structure is similar with FCOS, however ATSS use anchor boxes
+    and assign label by Adaptive Training Sample Selection instead max-iou.
+
+    https://arxiv.org/abs/1912.02424
+    """
+
+    def __init__(self,
+                 num_classes,
+                 in_channels,
+                 stacked_convs=4,
+                 octave_base_scale=4,
+                 scales_per_octave=1,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
+                 loss_centerness=dict(
+                     type='CrossEntropyLoss',
+                     use_sigmoid=True,
+                     loss_weight=1.0),
+                 **kwargs):
+        self.stacked_convs = stacked_convs
+        self.octave_base_scale = octave_base_scale
+        self.scales_per_octave = scales_per_octave
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+
+        octave_scales = np.array(
+            [2**(i / scales_per_octave) for i in range(scales_per_octave)])
+        anchor_scales = octave_scales * octave_base_scale
+        super(ATSSHead, self).__init__(
+            num_classes, in_channels, anchor_scales=anchor_scales, **kwargs)
+
+        self.loss_centerness = build_loss(loss_centerness)
+
+    def _init_layers(self):
+        self.relu = nn.ReLU(inplace=True)
+        self.cls_convs = nn.ModuleList()
+        self.reg_convs = nn.ModuleList()
+        for i in range(self.stacked_convs):
+            chn = self.in_channels if i == 0 else self.feat_channels
+            self.cls_convs.append(
+                ConvModule(
+                    chn,
+                    self.feat_channels,
+                    3,
+                    stride=1,
+                    padding=1,
+                    conv_cfg=self.conv_cfg,
+                    norm_cfg=self.norm_cfg))
+            self.reg_convs.append(
+                ConvModule(
+                    chn,
+                    self.feat_channels,
+                    3,
+                    stride=1,
+                    padding=1,
+                    conv_cfg=self.conv_cfg,
+                    norm_cfg=self.norm_cfg))
+        self.atss_cls = nn.Conv2d(
+            self.feat_channels,
+            self.num_anchors * self.cls_out_channels,
+            3,
+            padding=1)
+        self.atss_reg = nn.Conv2d(
+            self.feat_channels, self.num_anchors * 4, 3, padding=1)
+        self.atss_centerness = nn.Conv2d(
+            self.feat_channels, self.num_anchors * 1, 3, padding=1)
+        self.scales = nn.ModuleList([Scale(1.0) for _ in self.anchor_strides])
+
+    def init_weights(self):
+        for m in self.cls_convs:
+            normal_init(m.conv, std=0.01)
+        for m in self.reg_convs:
+            normal_init(m.conv, std=0.01)
+        bias_cls = bias_init_with_prob(0.01)
+        normal_init(self.atss_cls, std=0.01, bias=bias_cls)
+        normal_init(self.atss_reg, std=0.01)
+        normal_init(self.atss_centerness, std=0.01)
+
+    def forward(self, feats):
+        return multi_apply(self.forward_single, feats, self.scales)
+
+    def forward_single(self, x, scale):
+        cls_feat = x
+        reg_feat = x
+        for cls_conv in self.cls_convs:
+            cls_feat = cls_conv(cls_feat)
+        for reg_conv in self.reg_convs:
+            reg_feat = reg_conv(reg_feat)
+        cls_score = self.atss_cls(cls_feat)
+        # we just follow atss, not apply exp in bbox_pred
+        bbox_pred = scale(self.atss_reg(reg_feat)).float()
+        centerness = self.atss_centerness(reg_feat)
+        return cls_score, bbox_pred, centerness
+
+    def loss_single(self, anchors, cls_score, bbox_pred, centerness, labels,
+                    label_weights, bbox_targets, num_total_samples, cfg):
+
+        anchors = anchors.reshape(-1, 4)
+        cls_score = cls_score.permute(0, 2, 3,
+                                      1).reshape(-1, self.cls_out_channels)
+        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
+        centerness = centerness.permute(0, 2, 3, 1).reshape(-1)
+        bbox_targets = bbox_targets.reshape(-1, 4)
+        labels = labels.reshape(-1)
+        label_weights = label_weights.reshape(-1)
+
+        # classification loss
+        loss_cls = self.loss_cls(
+            cls_score, labels, label_weights, avg_factor=num_total_samples)
+
+        pos_inds = torch.nonzero(labels).squeeze(1)
+
+        if len(pos_inds) > 0:
+            pos_bbox_targets = bbox_targets[pos_inds]
+            pos_bbox_pred = bbox_pred[pos_inds]
+            pos_anchors = anchors[pos_inds]
+            pos_centerness = centerness[pos_inds]
+
+            centerness_targets = self.centerness_target(
+                pos_anchors, pos_bbox_targets)
+            pos_decode_bbox_pred = delta2bbox(pos_anchors, pos_bbox_pred,
+                                              self.target_means,
+                                              self.target_stds)
+            pos_decode_bbox_targets = delta2bbox(pos_anchors, pos_bbox_targets,
+                                                 self.target_means,
+                                                 self.target_stds)
+
+            # regression loss
+            loss_bbox = self.loss_bbox(
+                pos_decode_bbox_pred,
+                pos_decode_bbox_targets,
+                weight=centerness_targets,
+                avg_factor=1.0)
+
+            # centerness loss
+            loss_centerness = self.loss_centerness(
+                pos_centerness,
+                centerness_targets,
+                avg_factor=num_total_samples)
+
+        else:
+            loss_bbox = loss_cls * 0
+            loss_centerness = loss_bbox * 0
+            centerness_targets = torch.tensor(0).cuda()
+
+        return loss_cls, loss_bbox, loss_centerness, centerness_targets.sum()
+
+    @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
+    def loss(self,
+             cls_scores,
+             bbox_preds,
+             centernesses,
+             gt_bboxes,
+             gt_labels,
+             img_metas,
+             cfg,
+             gt_bboxes_ignore=None):
+
+        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+        assert len(featmap_sizes) == len(self.anchor_generators)
+
+        device = cls_scores[0].device
+        anchor_list, valid_flag_list = self.get_anchors(
+            featmap_sizes, img_metas, device=device)
+        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+
+        cls_reg_targets = self.atss_target(
+            anchor_list,
+            valid_flag_list,
+            gt_bboxes,
+            img_metas,
+            cfg,
+            gt_bboxes_ignore_list=gt_bboxes_ignore,
+            gt_labels_list=gt_labels,
+            label_channels=label_channels)
+        if cls_reg_targets is None:
+            return None
+
+        (anchor_list, labels_list, label_weights_list, bbox_targets_list,
+         bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets
+
+        num_total_samples = reduce_mean(
+            torch.tensor(num_total_pos).cuda()).item()
+        num_total_samples = max(num_total_samples, 1.0)
+
+        losses_cls, losses_bbox, loss_centerness,\
+            bbox_avg_factor = multi_apply(
+                self.loss_single,
+                anchor_list,
+                cls_scores,
+                bbox_preds,
+                centernesses,
+                labels_list,
+                label_weights_list,
+                bbox_targets_list,
+                num_total_samples=num_total_samples,
+                cfg=cfg)
+
+        bbox_avg_factor = sum(bbox_avg_factor)
+        bbox_avg_factor = reduce_mean(bbox_avg_factor).item()
+        losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox))
+        return dict(
+            loss_cls=losses_cls,
+            loss_bbox=losses_bbox,
+            loss_centerness=loss_centerness)
+
+    def centerness_target(self, anchors, bbox_targets):
+        # only calculate pos centerness targets, otherwise there may be nan
+        gts = delta2bbox(anchors, bbox_targets, self.target_means,
+                         self.target_stds)
+        anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2
+        anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2
+        l_ = anchors_cx - gts[:, 0]
+        t_ = anchors_cy - gts[:, 1]
+        r_ = gts[:, 2] - anchors_cx
+        b_ = gts[:, 3] - anchors_cy
+
+        left_right = torch.stack([l_, r_], dim=1)
+        top_bottom = torch.stack([t_, b_], dim=1)
+        centerness = torch.sqrt(
+            (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) *
+            (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]))
+        assert not torch.isnan(centerness).any()
+        return centerness
+
+    @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
+    def get_bboxes(self,
+                   cls_scores,
+                   bbox_preds,
+                   centernesses,
+                   img_metas,
+                   cfg,
+                   rescale=False):
+
+        assert len(cls_scores) == len(bbox_preds)
+        num_levels = len(cls_scores)
+        device = cls_scores[0].device
+        mlvl_anchors = [
+            self.anchor_generators[i].grid_anchors(
+                cls_scores[i].size()[-2:],
+                self.anchor_strides[i],
+                device=device) for i in range(num_levels)
+        ]
+
+        result_list = []
+        for img_id in range(len(img_metas)):
+            cls_score_list = [
+                cls_scores[i][img_id].detach() for i in range(num_levels)
+            ]
+            bbox_pred_list = [
+                bbox_preds[i][img_id].detach() for i in range(num_levels)
+            ]
+            centerness_pred_list = [
+                centernesses[i][img_id].detach() for i in range(num_levels)
+            ]
+            img_shape = img_metas[img_id]['img_shape']
+            scale_factor = img_metas[img_id]['scale_factor']
+            proposals = self.get_bboxes_single(cls_score_list, bbox_pred_list,
+                                               centerness_pred_list,
+                                               mlvl_anchors, img_shape,
+                                               scale_factor, cfg, rescale)
+            result_list.append(proposals)
+        return result_list
+
+    def get_bboxes_single(self,
+                          cls_scores,
+                          bbox_preds,
+                          centernesses,
+                          mlvl_anchors,
+                          img_shape,
+                          scale_factor,
+                          cfg,
+                          rescale=False):
+        assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
+        mlvl_bboxes = []
+        mlvl_scores = []
+        mlvl_centerness = []
+        for cls_score, bbox_pred, centerness, anchors in zip(
+                cls_scores, bbox_preds, centernesses, mlvl_anchors):
+            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+
+            scores = cls_score.permute(1, 2, 0).reshape(
+                -1, self.cls_out_channels).sigmoid()
+            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
+            centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid()
+
+            nms_pre = cfg.get('nms_pre', -1)
+            if nms_pre > 0 and scores.shape[0] > nms_pre:
+                max_scores, _ = (scores * centerness[:, None]).max(dim=1)
+                _, topk_inds = max_scores.topk(nms_pre)
+                anchors = anchors[topk_inds, :]
+                bbox_pred = bbox_pred[topk_inds, :]
+                scores = scores[topk_inds, :]
+                centerness = centerness[topk_inds]
+
+            bboxes = delta2bbox(anchors, bbox_pred, self.target_means,
+                                self.target_stds, img_shape)
+            mlvl_bboxes.append(bboxes)
+            mlvl_scores.append(scores)
+            mlvl_centerness.append(centerness)
+
+        mlvl_bboxes = torch.cat(mlvl_bboxes)
+        if rescale:
+            mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
+
+        mlvl_scores = torch.cat(mlvl_scores)
+        padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
+        mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
+        mlvl_centerness = torch.cat(mlvl_centerness)
+
+        det_bboxes, det_labels = multiclass_nms(
+            mlvl_bboxes,
+            mlvl_scores,
+            cfg.score_thr,
+            cfg.nms,
+            cfg.max_per_img,
+            score_factors=mlvl_centerness)
+        return det_bboxes, det_labels
+
+    def atss_target(self,
+                    anchor_list,
+                    valid_flag_list,
+                    gt_bboxes_list,
+                    img_metas,
+                    cfg,
+                    gt_bboxes_ignore_list=None,
+                    gt_labels_list=None,
+                    label_channels=1,
+                    unmap_outputs=True):
+        """
+        almost the same with anchor_target, with a little modification,
+        here we need return the anchor
+        """
+        num_imgs = len(img_metas)
+        assert len(anchor_list) == len(valid_flag_list) == num_imgs
+
+        # anchor number of multi levels
+        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+        num_level_anchors_list = [num_level_anchors] * num_imgs
+
+        # concat all level anchors and flags to a single tensor
+        for i in range(num_imgs):
+            assert len(anchor_list[i]) == len(valid_flag_list[i])
+            anchor_list[i] = torch.cat(anchor_list[i])
+            valid_flag_list[i] = torch.cat(valid_flag_list[i])
+
+        # compute targets for each image
+        if gt_bboxes_ignore_list is None:
+            gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
+        if gt_labels_list is None:
+            gt_labels_list = [None for _ in range(num_imgs)]
+        (all_anchors, all_labels, all_label_weights, all_bbox_targets,
+         all_bbox_weights, pos_inds_list, neg_inds_list) = multi_apply(
+             self.atss_target_single,
+             anchor_list,
+             valid_flag_list,
+             num_level_anchors_list,
+             gt_bboxes_list,
+             gt_bboxes_ignore_list,
+             gt_labels_list,
+             img_metas,
+             cfg=cfg,
+             label_channels=label_channels,
+             unmap_outputs=unmap_outputs)
+        # no valid anchors
+        if any([labels is None for labels in all_labels]):
+            return None
+        # sampled anchors of all images
+        num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
+        num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
+        # split targets to a list w.r.t. multiple levels
+        anchors_list = images_to_levels(all_anchors, num_level_anchors)
+        labels_list = images_to_levels(all_labels, num_level_anchors)
+        label_weights_list = images_to_levels(all_label_weights,
+                                              num_level_anchors)
+        bbox_targets_list = images_to_levels(all_bbox_targets,
+                                             num_level_anchors)
+        bbox_weights_list = images_to_levels(all_bbox_weights,
+                                             num_level_anchors)
+        return (anchors_list, labels_list, label_weights_list,
+                bbox_targets_list, bbox_weights_list, num_total_pos,
+                num_total_neg)
+
+    def atss_target_single(self,
+                           flat_anchors,
+                           valid_flags,
+                           num_level_anchors,
+                           gt_bboxes,
+                           gt_bboxes_ignore,
+                           gt_labels,
+                           img_meta,
+                           cfg,
+                           label_channels=1,
+                           unmap_outputs=True):
+        inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
+                                           img_meta['img_shape'][:2],
+                                           cfg.allowed_border)
+        if not inside_flags.any():
+            return (None, ) * 6
+        # assign gt and sample anchors
+        anchors = flat_anchors[inside_flags, :]
+
+        num_level_anchors_inside = self.get_num_level_anchors_inside(
+            num_level_anchors, inside_flags)
+        bbox_assigner = build_assigner(cfg.assigner)
+        assign_result = bbox_assigner.assign(anchors, num_level_anchors_inside,
+                                             gt_bboxes, gt_bboxes_ignore,
+                                             gt_labels)
+
+        bbox_sampler = PseudoSampler()
+        sampling_result = bbox_sampler.sample(assign_result, anchors,
+                                              gt_bboxes)
+
+        num_valid_anchors = anchors.shape[0]
+        bbox_targets = torch.zeros_like(anchors)
+        bbox_weights = torch.zeros_like(anchors)
+        labels = anchors.new_zeros(num_valid_anchors, dtype=torch.long)
+        label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
+
+        pos_inds = sampling_result.pos_inds
+        neg_inds = sampling_result.neg_inds
+        if len(pos_inds) > 0:
+            pos_bbox_targets = bbox2delta(sampling_result.pos_bboxes,
+                                          sampling_result.pos_gt_bboxes,
+                                          self.target_means, self.target_stds)
+            bbox_targets[pos_inds, :] = pos_bbox_targets
+            bbox_weights[pos_inds, :] = 1.0
+            if gt_labels is None:
+                labels[pos_inds] = 1
+            else:
+                labels[pos_inds] = gt_labels[
+                    sampling_result.pos_assigned_gt_inds]
+            if cfg.pos_weight <= 0:
+                label_weights[pos_inds] = 1.0
+            else:
+                label_weights[pos_inds] = cfg.pos_weight
+        if len(neg_inds) > 0:
+            label_weights[neg_inds] = 1.0
+
+        # map up to original set of anchors
+        if unmap_outputs:
+            num_total_anchors = flat_anchors.size(0)
+            anchors = unmap(anchors, num_total_anchors, inside_flags)
+            labels = unmap(labels, num_total_anchors, inside_flags)
+            label_weights = unmap(label_weights, num_total_anchors,
+                                  inside_flags)
+            bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
+            bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
+
+        return (anchors, labels, label_weights, bbox_targets, bbox_weights,
+                pos_inds, neg_inds)
+
+    def get_num_level_anchors_inside(self, num_level_anchors, inside_flags):
+        split_inside_flags = torch.split(inside_flags, num_level_anchors)
+        num_level_anchors_inside = [
+            int(flags.sum()) for flags in split_inside_flags
+        ]
+        return num_level_anchors_inside
diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py
index f9f83c621def00b84052ff41266180cfa8dec788..70a53d876952bce9a0916785c43e5c5dcd7c73ba 100644
--- a/mmdet/models/detectors/__init__.py
+++ b/mmdet/models/detectors/__init__.py
@@ -1,3 +1,4 @@
+from .atss import ATSS
 from .base import BaseDetector
 from .cascade_rcnn import CascadeRCNN
 from .double_head_rcnn import DoubleHeadRCNN
@@ -16,7 +17,7 @@ from .single_stage import SingleStageDetector
 from .two_stage import TwoStageDetector
 
 __all__ = [
-    'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN',
+    'ATSS', 'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN',
     'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade',
     'DoubleHeadRCNN', 'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN',
     'RepPointsDetector', 'FOVEA'
diff --git a/mmdet/models/detectors/atss.py b/mmdet/models/detectors/atss.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac22bf928673334c52a67e653faadb63b8df9ea8
--- /dev/null
+++ b/mmdet/models/detectors/atss.py
@@ -0,0 +1,16 @@
+from ..registry import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module
+class ATSS(SingleStageDetector):
+
+    def __init__(self,
+                 backbone,
+                 neck,
+                 bbox_head,
+                 train_cfg=None,
+                 test_cfg=None,
+                 pretrained=None):
+        super(ATSS, self).__init__(backbone, neck, bbox_head, train_cfg,
+                                   test_cfg, pretrained)
diff --git a/mmdet/models/losses/__init__.py b/mmdet/models/losses/__init__.py
index 992741ae4b8cde5897782ecb4b420c6971ba8b8b..07731d71053c92bf8542180f20ca613d45acbed6 100644
--- a/mmdet/models/losses/__init__.py
+++ b/mmdet/models/losses/__init__.py
@@ -4,7 +4,8 @@ from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
                                  cross_entropy, mask_cross_entropy)
 from .focal_loss import FocalLoss, sigmoid_focal_loss
 from .ghm_loss import GHMC, GHMR
-from .iou_loss import BoundedIoULoss, IoULoss, bounded_iou_loss, iou_loss
+from .iou_loss import (BoundedIoULoss, GIoULoss, IoULoss, bounded_iou_loss,
+                       iou_loss)
 from .mse_loss import MSELoss, mse_loss
 from .smooth_l1_loss import SmoothL1Loss, smooth_l1_loss
 from .utils import reduce_loss, weight_reduce_loss, weighted_loss
@@ -14,6 +15,6 @@ __all__ = [
     'mask_cross_entropy', 'CrossEntropyLoss', 'sigmoid_focal_loss',
     'FocalLoss', 'smooth_l1_loss', 'SmoothL1Loss', 'balanced_l1_loss',
     'BalancedL1Loss', 'mse_loss', 'MSELoss', 'iou_loss', 'bounded_iou_loss',
-    'IoULoss', 'BoundedIoULoss', 'GHMC', 'GHMR', 'reduce_loss',
+    'IoULoss', 'BoundedIoULoss', 'GIoULoss', 'GHMC', 'GHMR', 'reduce_loss',
     'weight_reduce_loss', 'weighted_loss'
 ]
diff --git a/mmdet/models/losses/iou_loss.py b/mmdet/models/losses/iou_loss.py
index cf2499463d94a557e9ca85844c1717f186be3833..c19c1d1d6af86926c31c9a28e974ebfc046c2b33 100644
--- a/mmdet/models/losses/iou_loss.py
+++ b/mmdet/models/losses/iou_loss.py
@@ -69,6 +69,51 @@ def bounded_iou_loss(pred, target, beta=0.2, eps=1e-3):
     return loss
 
 
+@weighted_loss
+def giou_loss(pred, target, eps=1e-7):
+    """
+    Generalized Intersection over Union: A Metric and A Loss for
+    Bounding Box Regression
+    https://arxiv.org/abs/1902.09630
+
+    code refer to:
+    https://github.com/sfzhang15/ATSS/blob/master/atss_core/modeling/rpn/atss/loss.py#L36
+
+    Args:
+        pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
+            shape (n, 4).
+        target (Tensor): Corresponding gt bboxes, shape (n, 4).
+        eps (float): Eps to avoid log(0).
+
+    Return:
+        Tensor: Loss tensor.
+    """
+    # overlap
+    lt = torch.max(pred[:, :2], target[:, :2])
+    rb = torch.min(pred[:, 2:], target[:, 2:])
+    wh = (rb - lt + 1).clamp(min=0)
+    overlap = wh[:, 0] * wh[:, 1]
+
+    # union
+    ap = (pred[:, 2] - pred[:, 0] + 1) * (pred[:, 3] - pred[:, 1] + 1)
+    ag = (target[:, 2] - target[:, 0] + 1) * (target[:, 3] - target[:, 1] + 1)
+    union = ap + ag - overlap + eps
+
+    # IoU
+    ious = overlap / union
+
+    # enclose area
+    enclose_x1y1 = torch.min(pred[:, :2], target[:, :2])
+    enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:])
+    enclose_wh = (enclose_x2y2 - enclose_x1y1 + 1).clamp(min=0)
+    enclose_area = enclose_wh[:, 0] * enclose_wh[:, 1] + eps
+
+    # GIoU
+    gious = ious - (enclose_area - union) / enclose_area
+    loss = 1 - gious
+    return loss
+
+
 @LOSSES.register_module
 class IoULoss(nn.Module):
 
@@ -133,3 +178,35 @@ class BoundedIoULoss(nn.Module):
             avg_factor=avg_factor,
             **kwargs)
         return loss
+
+
+@LOSSES.register_module
+class GIoULoss(nn.Module):
+
+    def __init__(self, eps=1e-6, reduction='mean', loss_weight=1.0):
+        super(GIoULoss, self).__init__()
+        self.eps = eps
+        self.reduction = reduction
+        self.loss_weight = loss_weight
+
+    def forward(self,
+                pred,
+                target,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None,
+                **kwargs):
+        if weight is not None and not torch.any(weight > 0):
+            return (pred * weight).sum()  # 0
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
+        loss = self.loss_weight * giou_loss(
+            pred,
+            target,
+            weight,
+            eps=self.eps,
+            reduction=reduction,
+            avg_factor=avg_factor,
+            **kwargs)
+        return loss