diff --git a/mmdet/utils/registry.py b/mmdet/utils/registry.py index a1cc87dcfdb5d954ec6fa496b4ffc52ac14dbbfa..4ad9f876ce73ba9cd9bd4eb70e9f2bbae76a6ffa 100644 --- a/mmdet/utils/registry.py +++ b/mmdet/utils/registry.py @@ -1,4 +1,5 @@ import inspect +from functools import partial import mmcv @@ -25,7 +26,7 @@ class Registry(object): def get(self, key): return self._module_dict.get(key, None) - def _register_module(self, module_class): + def _register_module(self, module_class, force=False): """Register a module. Args: @@ -35,13 +36,15 @@ class Registry(object): raise TypeError('module must be a class, but got {}'.format( type(module_class))) module_name = module_class.__name__ - if module_name in self._module_dict: + if not force and module_name in self._module_dict: raise KeyError('{} is already registered in {}'.format( module_name, self.name)) self._module_dict[module_name] = module_class - def register_module(self, cls): - self._register_module(cls) + def register_module(self, cls=None, force=False): + if cls is None: + return partial(self.register_module, force=force) + self._register_module(cls, force=force) return cls