diff --git a/mmdet/apis/env.py b/mmdet/apis/env.py
index 20cd26dee8fbc258ffd4c50fef6e8468bf4ba094..57348da6681e02b122cadc977dd36e0d63223fd6 100644
--- a/mmdet/apis/env.py
+++ b/mmdet/apis/env.py
@@ -35,7 +35,41 @@ def _init_dist_mpi(backend, **kwargs):
 
 
 def _init_dist_slurm(backend, **kwargs):
-    raise NotImplementedError
+    # srun sets these for every task in the job step
+    proc_id = int(os.environ['SLURM_PROCID'])
+    ntasks = int(os.environ['SLURM_NTASKS'])
+    node_list = os.environ['SLURM_NODELIST']
+    num_gpus = torch.cuda.device_count()
+    # bind this process to one GPU on its node
+    torch.cuda.set_device(proc_id % num_gpus)
+    # reduce a compressed node list such as 'SH-IDC1-10-5-30-[62-64,66]'
+    # to its first hostname by cutting at the first '-' or ',' inside
+    # the brackets
+    if '[' in node_list:
+        beg = node_list.find('[')
+        pos1 = node_list.find('-', beg)
+        if pos1 < 0:
+            pos1 = len(node_list)
+        pos2 = node_list.find(',', beg)
+        if pos2 < 0:
+            pos2 = len(node_list)
+        node_list = node_list[:min(pos1, pos2)].replace('[', '')
+    # hostnames are assumed to encode the IP, e.g. 'SH-IDC1-10-5-30-62';
+    # strip the 8-char site prefix and map '-' to '.'
+    addr = node_list[8:].replace('-', '.')
+    # torch.distributed reads the rendezvous info from these variables
+    os.environ['MASTER_PORT'] = str(kwargs['port'])
+    os.environ['MASTER_ADDR'] = addr
+    os.environ['WORLD_SIZE'] = str(ntasks)
+    os.environ['RANK'] = str(proc_id)
+    if backend == 'nccl':
+        dist.init_process_group(backend='nccl')
+    else:
+        dist.init_process_group(
+            backend='gloo', rank=proc_id, world_size=ntasks)
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+    return rank, world_size
 
 
 def set_random_seed(seed):
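
Note: a minimal usage sketch for the new SLURM path, assuming init_dist()
in this module dispatches launcher 'slurm' to _init_dist_slurm and that the
port value comes from cfg.dist_params (the port below is a placeholder):

    # run under `srun` with one task per GPU; SLURM_PROCID, SLURM_NTASKS
    # and SLURM_NODELIST are set by the scheduler, the port is our choice
    from mmdet.apis.env import init_dist
    init_dist('slurm', backend='nccl', port=29500)
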
diff --git a/mmdet/models/backbones/__init__.py b/mmdet/models/backbones/__init__.py
index 0f82f92aad10ed86b6528f0554615d7e9589ce1c..b109ae7d79a9ab8e25013d8a576b3815560f2d5e 100644
--- a/mmdet/models/backbones/__init__.py
+++ b/mmdet/models/backbones/__init__.py
@@ -1,3 +1,4 @@
 from .resnet import ResNet
+from .resnext import ResNeXt
 
-__all__ = ['ResNet']
+__all__ = ['ResNet', 'ResNeXt']
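
Note: exporting ResNeXt here is what makes the new backbone selectable, e.g.
via a backbone dict with type='ResNeXt' in a config, or directly with
`from mmdet.models.backbones import ResNeXt`.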
diff --git a/mmdet/models/backbones/resnet.py b/mmdet/models/backbones/resnet.py
index 66684b154b5aea3364789495b43c8b31ab97745b..c1b400ea2e39346167e5df393a53ad646d4e1ba4 100644
--- a/mmdet/models/backbones/resnet.py
+++ b/mmdet/models/backbones/resnet.py
@@ -219,9 +219,13 @@ class ResNet(nn.Module):
         super(ResNet, self).__init__()
         if depth not in self.arch_settings:
             raise KeyError('invalid depth {} for resnet'.format(depth))
+        self.depth = depth
+        self.num_stages = num_stages
+        self.strides = strides
+        self.dilations = dilations
         assert num_stages >= 1 and num_stages <= 4
-        block, stage_blocks = self.arch_settings[depth]
-        stage_blocks = stage_blocks[:num_stages]
+        self.block, self.stage_blocks = self.arch_settings[depth]
+        self.stage_blocks = self.stage_blocks[:num_stages]
         assert len(strides) == len(dilations) == num_stages
         assert max(out_indices) < num_stages
 
@@ -240,12 +244,12 @@ class ResNet(nn.Module):
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
 
         self.res_layers = []
-        for i, num_blocks in enumerate(stage_blocks):
+        for i, num_blocks in enumerate(self.stage_blocks):
             stride = strides[i]
             dilation = dilations[i]
             planes = 64 * 2**i
             res_layer = make_res_layer(
-                block,
+                self.block,
                 self.inplanes,
                 planes,
                 num_blocks,
@@ -253,12 +257,13 @@ class ResNet(nn.Module):
                 dilation=dilation,
                 style=self.style,
                 with_cp=with_cp)
-            self.inplanes = planes * block.expansion
+            self.inplanes = planes * self.block.expansion
             layer_name = 'layer{}'.format(i + 1)
             self.add_module(layer_name, res_layer)
             self.res_layers.append(layer_name)
 
-        self.feat_dim = block.expansion * 64 * 2**(len(stage_blocks) - 1)
+        self.feat_dim = self.block.expansion * 64 * 2**(
+            len(self.stage_blocks) - 1)
 
     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
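
Note: promoting block/stage_blocks (and depth/strides/dilations) to instance
attributes is what lets subclasses rebuild the residual layers. A quick
sanity check, assuming the usual default arguments:

    from mmdet.models.backbones import ResNet

    resnet = ResNet(depth=50)
    assert resnet.stage_blocks == (3, 4, 6, 3)
    assert resnet.block.expansion == 4
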
diff --git a/mmdet/models/backbones/resnext.py b/mmdet/models/backbones/resnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..9de95ef3321714ddf660b1cbe3f5f2477a73f9d1
--- /dev/null
+++ b/mmdet/models/backbones/resnext.py
@@ -0,0 +1,209 @@
+import math
+
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+
+from .resnet import ResNet
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self,
+                 inplanes,
+                 planes,
+                 stride=1,
+                 dilation=1,
+                 downsample=None,
+                 groups=1,
+                 base_width=4,
+                 style='pytorch',
+                 with_cp=False):
+        """Bottleneck block.
+        If style is "pytorch", the stride-two layer is the 3x3 conv layer;
+        if it is "caffe", the stride-two layer is the first 1x1 conv layer.
+        """
+        super(Bottleneck, self).__init__()
+        assert style in ['pytorch', 'caffe']
+
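+        # ResNeXt widens the 3x3 conv into `groups` parallel branches;
+        # the total width is floor(planes * base_width / 64) * groups,
+        # e.g. 32x4d gives width 128 when planes == 64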
+        width = planes if groups == 1 else math.floor(
+            planes * (base_width / 64)) * groups
+
+        if style == 'pytorch':
+            conv1_stride = 1
+            conv2_stride = stride
+        else:
+            conv1_stride = stride
+            conv2_stride = 1
+        self.conv1 = nn.Conv2d(
+            inplanes, width, kernel_size=1, stride=conv1_stride, bias=False)
+        self.bn1 = nn.BatchNorm2d(width)
+        self.conv2 = nn.Conv2d(
+            width,
+            width,
+            kernel_size=3,
+            stride=conv2_stride,
+            padding=dilation,
+            dilation=dilation,
+            groups=groups,
+            bias=False)
+        self.bn2 = nn.BatchNorm2d(width)
+        self.conv3 = nn.Conv2d(
+            width, planes * self.expansion, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+        self.dilation = dilation
+        self.with_cp = with_cp
+
+    def forward(self, x):
+
+        def _inner_forward(x):
+            residual = x
+
+            out = self.conv1(x)
+            out = self.bn1(out)
+            out = self.relu(out)
+
+            out = self.conv2(out)
+            out = self.bn2(out)
+            out = self.relu(out)
+
+            out = self.conv3(out)
+            out = self.bn3(out)
+
+            if self.downsample is not None:
+                residual = self.downsample(x)
+
+            out += residual
+
+            return out
+
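+        # gradient checkpointing: skip storing activations and recompute
+        # the block during backward to save memory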
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(_inner_forward, x)
+        else:
+            out = _inner_forward(x)
+
+        out = self.relu(out)
+
+        return out
+
+
+def make_res_layer(block,
+                   inplanes,
+                   planes,
+                   blocks,
+                   stride=1,
+                   dilation=1,
+                   groups=1,
+                   base_width=4,
+                   style='pytorch',
+                   with_cp=False):
+    downsample = None
+    if stride != 1 or inplanes != planes * block.expansion:
+        downsample = nn.Sequential(
+            nn.Conv2d(
+                inplanes,
+                planes * block.expansion,
+                kernel_size=1,
+                stride=stride,
+                bias=False),
+            nn.BatchNorm2d(planes * block.expansion),
+        )
+
+    layers = []
+    layers.append(
+        block(
+            inplanes,
+            planes,
+            stride,
+            dilation,
+            downsample,
+            groups=groups,
+            base_width=base_width,
+            style=style,
+            with_cp=with_cp))
+    inplanes = planes * block.expansion
+    for i in range(1, blocks):
+        layers.append(
+            block(
+                inplanes,
+                planes,
+                1,
+                dilation,
+                groups=groups,
+                base_width=base_width,
+                style=style,
+                with_cp=with_cp))
+
+    return nn.Sequential(*layers)
+
+
+class ResNeXt(ResNet):
+    """ResNeXt backbone.
+
+    Args:
+        depth (int): Depth of resnext, from {50, 101, 152}.
+        num_stages (int): Resnet stages, normally 4.
+        groups (int): Groups (cardinality) of the 3x3 convolutions.
+        base_width (int): Base channels of each group.
+        strides (Sequence[int]): Strides of the first block of each stage.
+        dilations (Sequence[int]): Dilation of each stage.
+        out_indices (Sequence[int]): Output from which stages.
+        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+            layer is the 3x3 conv layer, otherwise the stride-two layer is
+            the first 1x1 conv layer.
+        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+            not freezing any parameters.
+        bn_eval (bool): Whether to set BN layers to eval mode, namely, freeze
+            running stats (mean and var).
+        bn_frozen (bool): Whether to freeze weight and bias of BN layers.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed.
+    """
+
+    arch_settings = {
+        50: (Bottleneck, (3, 4, 6, 3)),
+        101: (Bottleneck, (3, 4, 23, 3)),
+        152: (Bottleneck, (3, 8, 36, 3))
+    }
+
+    def __init__(self,
+                 *args,
+                 groups=1,
+                 base_width=4,
+                 **kwargs):
+        super(ResNeXt, self).__init__(*args, **kwargs)
+        self.groups = groups
+        self.base_width = base_width
+
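+        # rebuild the residual layers with grouped convolutions; add_module
+        # overwrites the layers that ResNet.__init__ registered under the
+        # same names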
+        self.inplanes = 64
+        self.res_layers = []
+        for i, num_blocks in enumerate(self.stage_blocks):
+            stride = self.strides[i]
+            dilation = self.dilations[i]
+            planes = 64 * 2**i
+            res_layer = make_res_layer(
+                self.block,
+                self.inplanes,
+                planes,
+                num_blocks,
+                stride=stride,
+                dilation=dilation,
+                groups=self.groups,
+                base_width=self.base_width,
+                style=self.style,
+                with_cp=self.with_cp)
+            self.inplanes = planes * self.block.expansion
+            layer_name = 'layer{}'.format(i + 1)
+            self.add_module(layer_name, res_layer)
+            self.res_layers.append(layer_name)
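
Note: a hedged construction sketch for the common 32x4d variant (argument
names and defaults are inherited from ResNet; configs normally go through
the model builder rather than direct instantiation):

    from mmdet.models.backbones import ResNeXt

    # ResNeXt-50 32x4d: cardinality 32, 4 base channels per group
    backbone = ResNeXt(depth=50, groups=32, base_width=4)
    backbone.init_weights()
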
diff --git a/tools/train.py b/tools/train.py
index 8e03628db5ea28d027ccdc3939c72bace482be93..ee52aea1083bebb739b86e60c1e75f195a3d50d9 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -14,6 +14,7 @@ def parse_args():
     parser = argparse.ArgumentParser(description='Train a detector')
     parser.add_argument('config', help='train config file path')
     parser.add_argument('--work_dir', help='the dir to save logs and models')
+    parser.add_argument('--resume_from', help='the checkpoint to resume from')
     parser.add_argument(
         '--validate',
         action='store_true',
@@ -43,6 +44,8 @@ def main():
     # update configs according to CLI args
     if args.work_dir is not None:
         cfg.work_dir = args.work_dir
+    if args.resume_from is not None:
+        cfg.resume_from = args.resume_from
     cfg.gpus = args.gpus
     if cfg.checkpoint_config is not None:
         # save mmdet version in checkpoints as meta data
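
Note: the new flag makes resuming a command-line override rather than a
config edit, e.g. (angle-bracket paths are placeholders):

    python tools/train.py <config.py> --work_dir <work_dir> \
        --resume_from <checkpoint.pth>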