diff --git a/mmdet/core/__init__.py b/mmdet/core/__init__.py index 0327750379f42e2fc43f81d8c579cc5ffd2cf4a1..1eb03f76acdfaff65c96e5316e3b4898b7a7af6a 100644 --- a/mmdet/core/__init__.py +++ b/mmdet/core/__init__.py @@ -3,5 +3,6 @@ from .bbox_ops import * from .mask_ops import * from .losses import * from .eval import * +from .parallel import * from .post_processing import * from .utils import * diff --git a/mmdet/nn/parallel/__init__.py b/mmdet/core/parallel/__init__.py similarity index 100% rename from mmdet/nn/parallel/__init__.py rename to mmdet/core/parallel/__init__.py diff --git a/mmdet/nn/parallel/_functions.py b/mmdet/core/parallel/_functions.py similarity index 100% rename from mmdet/nn/parallel/_functions.py rename to mmdet/core/parallel/_functions.py diff --git a/mmdet/nn/parallel/data_parallel.py b/mmdet/core/parallel/data_parallel.py similarity index 100% rename from mmdet/nn/parallel/data_parallel.py rename to mmdet/core/parallel/data_parallel.py diff --git a/mmdet/nn/parallel/distributed.py b/mmdet/core/parallel/distributed.py similarity index 100% rename from mmdet/nn/parallel/distributed.py rename to mmdet/core/parallel/distributed.py diff --git a/mmdet/nn/parallel/scatter_gather.py b/mmdet/core/parallel/scatter_gather.py similarity index 99% rename from mmdet/nn/parallel/scatter_gather.py rename to mmdet/core/parallel/scatter_gather.py index f5f7c588f4b137a6d36fcb49a0b520c60faa6d9a..02849dc01bc4e8af6420df69a4f5dcd5650655c6 100644 --- a/mmdet/nn/parallel/scatter_gather.py +++ b/mmdet/core/parallel/scatter_gather.py @@ -1,6 +1,7 @@ import torch -from ._functions import Scatter from torch.nn.parallel._functions import Scatter as OrigScatter + +from ._functions import Scatter from mmdet.datasets.utils import DataContainer diff --git a/mmdet/core/utils/hooks.py b/mmdet/core/utils/hooks.py index 05441601ba792e44b47d89e5f405bb5092286d3f..9772d4d64f1f5860a7646c23c950b620a052ae20 100644 --- a/mmdet/core/utils/hooks.py +++ b/mmdet/core/utils/hooks.py @@ -7,11 +7,11 @@ import mmcv import numpy as np import torch from mmcv.torchpack import Hook -from mmdet.datasets.loader import collate -from mmdet.nn.parallel import scatter from pycocotools.cocoeval import COCOeval from ..eval import eval_recalls +from ..parallel import scatter +from mmdet.datasets.loader import collate class EmptyCacheHook(Hook): diff --git a/mmdet/nn/__init__.py b/mmdet/nn/__init__.py deleted file mode 100644 index 1b627f5e7b807b1c6ae321c775c8fc8d03266238..0000000000000000000000000000000000000000 --- a/mmdet/nn/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .parallel import MMDataParallel, MMDistributedDataParallel diff --git a/tools/test.py b/tools/test.py index 773136d4c8a63d1d76fc60eee7218d1a30c8bda9..0a43cdc316506ecf2b7addb8d11e3dc7dc30507b 100644 --- a/tools/test.py +++ b/tools/test.py @@ -5,10 +5,9 @@ import mmcv from mmcv.torchpack import load_checkpoint, parallel_test, obj_from_dict from mmdet import datasets -from mmdet.core import results2json, coco_eval +from mmdet.core import scatter, MMDataParallel, results2json, coco_eval from mmdet.datasets.loader import collate, build_dataloader from mmdet.models import build_detector, detectors -from mmdet.nn.parallel import scatter, MMDataParallel def single_test(model, data_loader, show=False): diff --git a/tools/train.py b/tools/train.py index 8fd43807967fef6b17695158a4f67514b0a0ab5d..fd47b1375622693e7e0fdd31c5e746df9dc5ff5a 100644 --- a/tools/train.py +++ b/tools/train.py @@ -8,10 +8,10 @@ from mmcv import Config from mmcv.torchpack import Runner, obj_from_dict from mmdet import datasets -from mmdet.core import init_dist, DistOptimizerHook, DistSamplerSeedHook +from mmdet.core import (init_dist, DistOptimizerHook, DistSamplerSeedHook, + MMDataParallel, MMDistributedDataParallel) from mmdet.datasets.loader import build_dataloader from mmdet.models import build_detector -from mmdet.nn.parallel import MMDataParallel, MMDistributedDataParallel def parse_losses(losses):