diff --git a/mmdet/core/__init__.py b/mmdet/core/__init__.py index 20e124bd21f6c99e660a08e6821cfcdbc7dcfb5a..05788f1eb9131b0b514c6aeaecd8131f85dc31e4 100644 --- a/mmdet/core/__init__.py +++ b/mmdet/core/__init__.py @@ -1,6 +1,6 @@ from .anchor import * # noqa: F401, F403 -from .bbox_ops import * # noqa: F401, F403 -from .mask_ops import * # noqa: F401, F403 +from .bbox import * # noqa: F401, F403 +from .mask import * # noqa: F401, F403 from .losses import * # noqa: F401, F403 from .eval import * # noqa: F401, F403 from .parallel import * # noqa: F401, F403 diff --git a/mmdet/core/anchor/anchor_target.py b/mmdet/core/anchor/anchor_target.py index f449507499e9f35092a2f7430c70f5c2fe0ff24c..ad81e390e6dcb2a064862818a34ea99adbe462e0 100644 --- a/mmdet/core/anchor/anchor_target.py +++ b/mmdet/core/anchor/anchor_target.py @@ -1,6 +1,6 @@ import torch -from ..bbox_ops import bbox_assign, bbox2delta, bbox_sampling +from ..bbox import bbox_assign, bbox2delta, bbox_sampling from ..utils import multi_apply diff --git a/mmdet/core/bbox_ops/__init__.py b/mmdet/core/bbox/__init__.py similarity index 100% rename from mmdet/core/bbox_ops/__init__.py rename to mmdet/core/bbox/__init__.py diff --git a/mmdet/core/bbox_ops/bbox_target.py b/mmdet/core/bbox/bbox_target.py similarity index 100% rename from mmdet/core/bbox_ops/bbox_target.py rename to mmdet/core/bbox/bbox_target.py diff --git a/mmdet/core/bbox_ops/geometry.py b/mmdet/core/bbox/geometry.py similarity index 100% rename from mmdet/core/bbox_ops/geometry.py rename to mmdet/core/bbox/geometry.py diff --git a/mmdet/core/bbox_ops/sampling.py b/mmdet/core/bbox/sampling.py similarity index 98% rename from mmdet/core/bbox_ops/sampling.py rename to mmdet/core/bbox/sampling.py index 80f8c8207cc55b37d647bef33f0486d0a49ccd4a..976cd9507f2279b663d3f5e09ed1180da5b457c1 100644 --- a/mmdet/core/bbox_ops/sampling.py +++ b/mmdet/core/bbox/sampling.py @@ -5,6 +5,11 @@ from .geometry import bbox_overlaps def random_choice(gallery, num): + """Random select some elements from the gallery. + + It seems that Pytorch's implementation is slower than numpy so we use numpy + to randperm the indices. + """ assert len(gallery) >= num if isinstance(gallery, list): gallery = np.array(gallery) @@ -12,9 +17,7 @@ def random_choice(gallery, num): np.random.shuffle(cands) rand_inds = cands[:num] if not isinstance(gallery, np.ndarray): - rand_inds = torch.from_numpy(rand_inds).long() - if gallery.is_cuda: - rand_inds = rand_inds.cuda(gallery.get_device()) + rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device) return gallery[rand_inds] diff --git a/mmdet/core/bbox_ops/transforms.py b/mmdet/core/bbox/transforms.py similarity index 100% rename from mmdet/core/bbox_ops/transforms.py rename to mmdet/core/bbox/transforms.py diff --git a/mmdet/core/mask_ops/__init__.py b/mmdet/core/mask/__init__.py similarity index 100% rename from mmdet/core/mask_ops/__init__.py rename to mmdet/core/mask/__init__.py diff --git a/mmdet/core/mask_ops/mask_target.py b/mmdet/core/mask/mask_target.py similarity index 100% rename from mmdet/core/mask_ops/mask_target.py rename to mmdet/core/mask/mask_target.py diff --git a/mmdet/core/mask_ops/segms.py b/mmdet/core/mask/segms.py similarity index 100% rename from mmdet/core/mask_ops/segms.py rename to mmdet/core/mask/segms.py diff --git a/mmdet/core/mask_ops/utils.py b/mmdet/core/mask/utils.py similarity index 100% rename from mmdet/core/mask_ops/utils.py rename to mmdet/core/mask/utils.py diff --git a/mmdet/core/post_processing/merge_augs.py b/mmdet/core/post_processing/merge_augs.py index 2b8d861a6745b90dd33b77ae4bda65bfd825d9a7..00f65b049ccf2b00a0fee73cc64ac257415425ea 100644 --- a/mmdet/core/post_processing/merge_augs.py +++ b/mmdet/core/post_processing/merge_augs.py @@ -3,7 +3,7 @@ import torch import numpy as np from mmdet.ops import nms -from ..bbox_ops import bbox_mapping_back +from ..bbox import bbox_mapping_back def merge_aug_proposals(aug_proposals, img_metas, rpn_test_cfg): diff --git a/mmdet/datasets/transforms.py b/mmdet/datasets/transforms.py index a7e72e0ee48b7e0c63eae065721abe2ecad62cec..d2daad15046ba9f4b6f288ef800b0ccb30f94d97 100644 --- a/mmdet/datasets/transforms.py +++ b/mmdet/datasets/transforms.py @@ -2,7 +2,7 @@ import mmcv import numpy as np import torch -from mmdet.core.mask_ops import segms +from mmdet.core.mask import segms __all__ = [ 'ImageTransform', 'BboxTransform', 'PolyMaskTransform', 'Numpy2Tensor'