From b8bcda671e742d5133c259e8db29f8622657acf7 Mon Sep 17 00:00:00 2001
From: Kai Chen
Date: Thu, 4 Jul 2019 18:53:43 +0800
Subject: [PATCH] Use registry to manage datasets (#924)

* use registry to manage datasets

* bug fix for concat dataset

* update documentation to fit the new api
---
 GETTING_STARTED.md                 |  4 +-
 mmdet/datasets/__init__.py         | 11 +++--
 mmdet/datasets/builder.py          | 38 +++++++++++++++
 mmdet/datasets/coco.py             |  2 +
 mmdet/datasets/concat_dataset.py   | 22 ---------
 mmdet/datasets/custom.py           |  2 +
 mmdet/datasets/dataset_wrappers.py | 55 ++++++++++++++++++++++
 mmdet/datasets/registry.py         |  3 ++
 mmdet/datasets/repeat_dataset.py   | 19 --------
 mmdet/datasets/utils.py            | 52 +--------------------
 mmdet/datasets/voc.py              |  2 +
 mmdet/datasets/wider_face.py       | 11 +++--
 mmdet/datasets/xml_style.py        |  2 +
 mmdet/models/builder.py            | 27 ++---------
 mmdet/models/registry.py           | 38 +---------
 mmdet/utils/__init__.py            |  3 ++
 mmdet/utils/registry.py            | 74 ++++++++++++++++++++++++++++++
 tools/test.py                      |  4 +-
 tools/train.py                     |  4 +-
 19 files changed, 210 insertions(+), 163 deletions(-)
 create mode 100644 mmdet/datasets/builder.py
 delete mode 100644 mmdet/datasets/concat_dataset.py
 create mode 100644 mmdet/datasets/dataset_wrappers.py
 create mode 100644 mmdet/datasets/registry.py
 delete mode 100644 mmdet/datasets/repeat_dataset.py
 create mode 100644 mmdet/utils/__init__.py
 create mode 100644 mmdet/utils/registry.py

diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md
index 7de53d0..04b2d3a 100644
--- a/GETTING_STARTED.md
+++ b/GETTING_STARTED.md
@@ -150,8 +150,10 @@ In `mmdet/datasets/my_dataset.py`:
 
 ```python
 from .coco import CocoDataset
+from .registry import DATASETS
 
 
+@DATASETS.register_module
 class MyDataset(CocoDataset):
 
     CLASSES = ('a', 'b', 'c', 'd', 'e')
@@ -228,7 +230,7 @@ import torch.nn as nn
 from ..registry import BACKBONES
 
 
-@BACKBONES.register
+@BACKBONES.register_module
 class MobileNet(nn.Module):
 
     def __init__(self, arg1, arg2):
diff --git a/mmdet/datasets/__init__.py b/mmdet/datasets/__init__.py
index 2c1612d..ab3e549 100644
--- a/mmdet/datasets/__init__.py
+++ b/mmdet/datasets/__init__.py
@@ -4,14 +4,15 @@
 from .voc import VOCDataset
 from .wider_face import WIDERFaceDataset
 from .loader import GroupSampler, DistributedGroupSampler, build_dataloader
-from .utils import to_tensor, random_scale, show_ann, get_dataset
-from .concat_dataset import ConcatDataset
-from .repeat_dataset import RepeatDataset
+from .utils import to_tensor, random_scale, show_ann
+from .dataset_wrappers import ConcatDataset, RepeatDataset
 from .extra_aug import ExtraAugmentation
+from .registry import DATASETS
+from .builder import build_dataset
 
 __all__ = [
     'CustomDataset', 'XMLDataset', 'CocoDataset', 'VOCDataset', 'GroupSampler',
     'DistributedGroupSampler', 'build_dataloader', 'to_tensor', 'random_scale',
-    'show_ann', 'get_dataset', 'ConcatDataset', 'RepeatDataset',
-    'ExtraAugmentation', 'WIDERFaceDataset'
+    'show_ann', 'ConcatDataset', 'RepeatDataset', 'ExtraAugmentation',
+    'WIDERFaceDataset', 'DATASETS', 'build_dataset'
 ]
diff --git a/mmdet/datasets/builder.py b/mmdet/datasets/builder.py
new file mode 100644
index 0000000..6b1ffba
--- /dev/null
+++ b/mmdet/datasets/builder.py
@@ -0,0 +1,38 @@
+import copy
+
+from mmdet.utils import build_from_cfg
+from .dataset_wrappers import ConcatDataset, RepeatDataset
+from .registry import DATASETS
+
+
+def _concat_dataset(cfg):
+    ann_files = cfg['ann_file']
+    img_prefixes = cfg.get('img_prefix', None)
+    seg_prefixes = cfg.get('seg_prefix', None)
+    proposal_files = cfg.get('proposal_file', None)
+
+    datasets = []
+    num_dset = len(ann_files)
+    for i in range(num_dset):
+        data_cfg = copy.deepcopy(cfg)
+        data_cfg['ann_file'] = ann_files[i]
+        if isinstance(img_prefixes, (list, tuple)):
+            data_cfg['img_prefix'] = img_prefixes[i]
+        if isinstance(seg_prefixes, (list, tuple)):
+            data_cfg['seg_prefix'] = seg_prefixes[i]
+        if isinstance(proposal_files, (list, tuple)):
+            data_cfg['proposal_file'] = proposal_files[i]
+        datasets.append(build_dataset(data_cfg))
+
+    return ConcatDataset(datasets)
+
+
+def build_dataset(cfg):
+    if cfg['type'] == 'RepeatDataset':
+        dataset = RepeatDataset(build_dataset(cfg['dataset']), cfg['times'])
+    elif isinstance(cfg['ann_file'], (list, tuple)):
+        dataset = _concat_dataset(cfg)
+    else:
+        dataset = build_from_cfg(cfg, DATASETS)
+
+    return dataset
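Note on the new entry point: `build_dataset` dispatches on three cases — a `RepeatDataset` wrapper config, a list-valued `ann_file` (which `_concat_dataset` expands into a `ConcatDataset`), and the plain registry lookup via `build_from_cfg`. A minimal sketch of all three paths; `ToyDataset` is hypothetical, standing in for a real annotation-backed dataset:

```python
# Sketch only: exercises the three branches of build_dataset above.
from mmdet.datasets import DATASETS, build_dataset


@DATASETS.register_module
class ToyDataset(object):  # hypothetical, not part of the patch

    CLASSES = ('foo', 'bar')

    def __init__(self, ann_file, img_prefix=None):
        self.ann_file = ann_file

    def __len__(self):
        return 10


single = build_dataset(dict(type='ToyDataset', ann_file='a.json'))
concat = build_dataset(dict(type='ToyDataset', ann_file=['a.json', 'b.json']))
repeat = build_dataset(dict(
    type='RepeatDataset', times=5,
    dataset=dict(type='ToyDataset', ann_file='a.json')))
print(type(single).__name__, len(concat), len(repeat))  # ToyDataset 20 50
```

Since the `RepeatDataset` check runs before `ann_file` is inspected, a repeated dataset may itself wrap a concatenation.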
diff --git a/mmdet/datasets/coco.py b/mmdet/datasets/coco.py
index 0b3af9b..46ef709 100644
--- a/mmdet/datasets/coco.py
+++ b/mmdet/datasets/coco.py
@@ -2,8 +2,10 @@
 from pycocotools.coco import COCO
 
 from .custom import CustomDataset
+from .registry import DATASETS
 
 
+@DATASETS.register_module
 class CocoDataset(CustomDataset):
 
     CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
diff --git a/mmdet/datasets/concat_dataset.py b/mmdet/datasets/concat_dataset.py
deleted file mode 100644
index 195420a..0000000
--- a/mmdet/datasets/concat_dataset.py
+++ /dev/null
@@ -1,22 +0,0 @@
-import numpy as np
-from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
-
-
-class ConcatDataset(_ConcatDataset):
-    """A wrapper of concatenated dataset.
-
-    Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
-    concat the group flag for image aspect ratio.
-
-    Args:
-        datasets (list[:obj:`Dataset`]): A list of datasets.
-    """
-
-    def __init__(self, datasets):
-        super(ConcatDataset, self).__init__(datasets)
-        self.CLASSES = datasets[0].CLASSES
-        if hasattr(datasets[0], 'flag'):
-            flags = []
-            for i in range(0, len(datasets)):
-                flags.append(datasets[i].flag)
-            self.flag = np.concatenate(flags)
diff --git a/mmdet/datasets/custom.py b/mmdet/datasets/custom.py
index 9bf4731..ef8a1bc 100644
--- a/mmdet/datasets/custom.py
+++ b/mmdet/datasets/custom.py
@@ -5,12 +5,14 @@
 from mmcv.parallel import DataContainer as DC
 from torch.utils.data import Dataset
 
+from .registry import DATASETS
 from .transforms import (ImageTransform, BboxTransform, MaskTransform,
                          SegMapTransform, Numpy2Tensor)
 from .utils import to_tensor, random_scale
 from .extra_aug import ExtraAugmentation
 
 
+@DATASETS.register_module
 class CustomDataset(Dataset):
     """Custom dataset for detection.
 
diff --git a/mmdet/datasets/dataset_wrappers.py b/mmdet/datasets/dataset_wrappers.py
new file mode 100644
index 0000000..e749cb0
--- /dev/null
+++ b/mmdet/datasets/dataset_wrappers.py
@@ -0,0 +1,55 @@
+import numpy as np
+from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
+
+from .registry import DATASETS
+
+
+@DATASETS.register_module
+class ConcatDataset(_ConcatDataset):
+    """A wrapper of concatenated dataset.
+
+    Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
+    concat the group flag for image aspect ratio.
+
+    Args:
+        datasets (list[:obj:`Dataset`]): A list of datasets.
+    """
+
+    def __init__(self, datasets):
+        super(ConcatDataset, self).__init__(datasets)
+        self.CLASSES = datasets[0].CLASSES
+        if hasattr(datasets[0], 'flag'):
+            flags = []
+            for i in range(0, len(datasets)):
+                flags.append(datasets[i].flag)
+            self.flag = np.concatenate(flags)
+
+
+@DATASETS.register_module
+class RepeatDataset(object):
+    """A wrapper of repeated dataset.
+
+    The length of repeated dataset will be `times` larger than the original
+    dataset. This is useful when the data loading time is long but the dataset
+    is small. Using RepeatDataset can reduce the data loading time between
+    epochs.
+
+    Args:
+        dataset (:obj:`Dataset`): The dataset to be repeated.
+        times (int): Repeat times.
+    """
+
+    def __init__(self, dataset, times):
+        self.dataset = dataset
+        self.times = times
+        self.CLASSES = dataset.CLASSES
+        if hasattr(self.dataset, 'flag'):
+            self.flag = np.tile(self.dataset.flag, times)
+
+        self._ori_len = len(self.dataset)
+
+    def __getitem__(self, idx):
+        return self.dataset[idx % self._ori_len]
+
+    def __len__(self):
+        return self.times * self._ori_len
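Both wrappers are registered like ordinary datasets, but they are typically driven from a config file rather than instantiated directly. An abridged, hypothetical config fragment in the style of the PASCAL VOC configs (paths and `times` are placeholders):

```python
data = dict(
    train=dict(
        type='RepeatDataset',  # resolved first by build_dataset
        times=3,
        dataset=dict(
            type='VOCDataset',
            # one img_prefix per ann_file; _concat_dataset pairs them up
            ann_file=[
                'data/VOCdevkit/VOC2007/ImageSets/Main/trainval.txt',
                'data/VOCdevkit/VOC2012/ImageSets/Main/trainval.txt'
            ],
            img_prefix=['data/VOCdevkit/VOC2007/',
                        'data/VOCdevkit/VOC2012/'])))
```

`RepeatDataset` tiles the aspect-ratio `flag` array and `ConcatDataset` concatenates it, so group sampling keeps working through either wrapper.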
+ """ + + def __init__(self, datasets): + super(ConcatDataset, self).__init__(datasets) + self.CLASSES = datasets[0].CLASSES + if hasattr(datasets[0], 'flag'): + flags = [] + for i in range(0, len(datasets)): + flags.append(datasets[i].flag) + self.flag = np.concatenate(flags) + + +@DATASETS.register_module +class RepeatDataset(object): + """A wrapper of repeated dataset. + + The length of repeated dataset will be `times` larger than the original + dataset. This is useful when the data loading time is long but the dataset + is small. Using RepeatDataset can reduce the data loading time between + epochs. + + Args: + dataset (:obj:`Dataset`): The dataset to be repeated. + times (int): Repeat times. + """ + + def __init__(self, dataset, times): + self.dataset = dataset + self.times = times + self.CLASSES = dataset.CLASSES + if hasattr(self.dataset, 'flag'): + self.flag = np.tile(self.dataset.flag, times) + + self._ori_len = len(self.dataset) + + def __getitem__(self, idx): + return self.dataset[idx % self._ori_len] + + def __len__(self): + return self.times * self._ori_len diff --git a/mmdet/datasets/registry.py b/mmdet/datasets/registry.py new file mode 100644 index 0000000..e726624 --- /dev/null +++ b/mmdet/datasets/registry.py @@ -0,0 +1,3 @@ +from mmdet.utils import Registry + +DATASETS = Registry('dataset') diff --git a/mmdet/datasets/repeat_dataset.py b/mmdet/datasets/repeat_dataset.py deleted file mode 100644 index 7e99293..0000000 --- a/mmdet/datasets/repeat_dataset.py +++ /dev/null @@ -1,19 +0,0 @@ -import numpy as np - - -class RepeatDataset(object): - - def __init__(self, dataset, times): - self.dataset = dataset - self.times = times - self.CLASSES = dataset.CLASSES - if hasattr(self.dataset, 'flag'): - self.flag = np.tile(self.dataset.flag, times) - - self._ori_len = len(self.dataset) - - def __getitem__(self, idx): - return self.dataset[idx % self._ori_len] - - def __len__(self): - return self.times * self._ori_len diff --git a/mmdet/datasets/utils.py b/mmdet/datasets/utils.py index 8fdba7f..9f4f46c 100644 --- a/mmdet/datasets/utils.py +++ b/mmdet/datasets/utils.py @@ -1,15 +1,9 @@ -import copy from collections import Sequence -import mmcv -from mmcv.runner import obj_from_dict -import torch - import matplotlib.pyplot as plt +import mmcv import numpy as np -from .concat_dataset import ConcatDataset -from .repeat_dataset import RepeatDataset -from .. 
diff --git a/mmdet/datasets/voc.py b/mmdet/datasets/voc.py
index ba1c772..77bffe3 100644
--- a/mmdet/datasets/voc.py
+++ b/mmdet/datasets/voc.py
@@ -1,6 +1,8 @@
+from .registry import DATASETS
 from .xml_style import XMLDataset
 
 
+@DATASETS.register_module
 class VOCDataset(XMLDataset):
 
     CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
diff --git a/mmdet/datasets/wider_face.py b/mmdet/datasets/wider_face.py
index ad52a95..b83e3d6 100644
--- a/mmdet/datasets/wider_face.py
+++ b/mmdet/datasets/wider_face.py
@@ -3,16 +3,18 @@
 
 import mmcv
 
+from .registry import DATASETS
 from .xml_style import XMLDataset
 
 
+@DATASETS.register_module
 class WIDERFaceDataset(XMLDataset):
     """
     Reader for the WIDER Face dataset in PASCAL VOC format.
     Conversion scripts can be found in https://github.com/sovrasov/wider-face-pascal-voc-annotations
     """
 
-    CLASSES = ('face',)
+    CLASSES = ('face', )
 
     def __init__(self, **kwargs):
         super(WIDERFaceDataset, self).__init__(**kwargs)
@@ -31,7 +33,10 @@ def load_annotations(self, ann_file):
             height = int(size.find('height').text)
             folder = root.find('folder').text
             img_infos.append(
-                dict(id=img_id, filename=osp.join(folder, filename),
-                     width=width, height=height))
+                dict(
+                    id=img_id,
+                    filename=osp.join(folder, filename),
+                    width=width,
+                    height=height))
 
         return img_infos
diff --git a/mmdet/datasets/xml_style.py b/mmdet/datasets/xml_style.py
index e0c6ac1..39d5704 100644
--- a/mmdet/datasets/xml_style.py
+++ b/mmdet/datasets/xml_style.py
@@ -5,8 +5,10 @@
 import numpy as np
 
 from .custom import CustomDataset
+from .registry import DATASETS
 
 
+@DATASETS.register_module
 class XMLDataset(CustomDataset):
 
     def __init__(self, min_size=None, **kwargs):
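The same two-line pattern shown above for `VOCDataset` and `WIDERFaceDataset` is all a downstream project needs to expose its own VOC/XML-style data; a hypothetical `mmdet/datasets/my_xml_dataset.py`:

```python
from .registry import DATASETS
from .xml_style import XMLDataset


@DATASETS.register_module
class MyXMLDataset(XMLDataset):  # hypothetical example class

    CLASSES = ('cat', 'dog')
```

After the module is imported in `mmdet/datasets/__init__.py`, configs can reference it with `type='MyXMLDataset'`.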
diff --git a/mmdet/models/builder.py b/mmdet/models/builder.py
index 93cdb19..0c9b644 100644
--- a/mmdet/models/builder.py
+++ b/mmdet/models/builder.py
@@ -1,35 +1,18 @@
-import mmcv
 from torch import nn
 
+from mmdet.utils import build_from_cfg
 from .registry import (BACKBONES, NECKS, ROI_EXTRACTORS, SHARED_HEADS, HEADS,
                        LOSSES, DETECTORS)
 
 
-def _build_module(cfg, registry, default_args):
-    assert isinstance(cfg, dict) and 'type' in cfg
-    assert isinstance(default_args, dict) or default_args is None
-    args = cfg.copy()
-    obj_type = args.pop('type')
-    if mmcv.is_str(obj_type):
-        if obj_type not in registry.module_dict:
-            raise KeyError('{} is not in the {} registry'.format(
-                obj_type, registry.name))
-        obj_type = registry.module_dict[obj_type]
-    elif not isinstance(obj_type, type):
-        raise TypeError('type must be a str or valid type, but got {}'.format(
-            type(obj_type)))
-    if default_args is not None:
-        for name, value in default_args.items():
-            args.setdefault(name, value)
-    return obj_type(**args)
-
-
 def build(cfg, registry, default_args=None):
     if isinstance(cfg, list):
-        modules = [_build_module(cfg_, registry, default_args) for cfg_ in cfg]
+        modules = [
+            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
+        ]
         return nn.Sequential(*modules)
     else:
-        return _build_module(cfg, registry, default_args)
+        return build_from_cfg(cfg, registry, default_args)
 
 
 def build_backbone(cfg):
diff --git a/mmdet/models/registry.py b/mmdet/models/registry.py
index 533fdf8..78ef248 100644
--- a/mmdet/models/registry.py
+++ b/mmdet/models/registry.py
@@ -1,40 +1,4 @@
-import torch.nn as nn
-
-
-class Registry(object):
-
-    def __init__(self, name):
-        self._name = name
-        self._module_dict = dict()
-
-    @property
-    def name(self):
-        return self._name
-
-    @property
-    def module_dict(self):
-        return self._module_dict
-
-    def _register_module(self, module_class):
-        """Register a module.
-
-        Args:
-            module (:obj:`nn.Module`): Module to be registered.
-        """
-        if not issubclass(module_class, nn.Module):
-            raise TypeError(
-                'module must be a child of nn.Module, but got {}'.format(
-                    module_class))
-        module_name = module_class.__name__
-        if module_name in self._module_dict:
-            raise KeyError('{} is already registered in {}'.format(
-                module_name, self.name))
-        self._module_dict[module_name] = module_class
-
-    def register_module(self, cls):
-        self._register_module(cls)
-        return cls
-
+from mmdet.utils import Registry
 
 BACKBONES = Registry('backbone')
 NECKS = Registry('neck')
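`build` keeps its two call forms, now backed by the shared `build_from_cfg`: a single config dict yields one module, a list of dicts yields an `nn.Sequential`. A self-contained toy sketch (`TinyBackbone` is invented for illustration, not an mmdet component):

```python
import torch.nn as nn

from mmdet.models.builder import build
from mmdet.models.registry import BACKBONES


@BACKBONES.register_module
class TinyBackbone(nn.Module):  # toy module for demonstration only

    def __init__(self, channels=8):
        super(TinyBackbone, self).__init__()
        self.conv = nn.Conv2d(3, channels, 3)

    def forward(self, x):
        return self.conv(x)


one = build(dict(type='TinyBackbone'), BACKBONES)
seq = build([dict(type='TinyBackbone'),
             dict(type='TinyBackbone', channels=16)], BACKBONES)
assert isinstance(seq, nn.Sequential)
```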
- """ - if not issubclass(module_class, nn.Module): - raise TypeError( - 'module must be a child of nn.Module, but got {}'.format( - module_class)) - module_name = module_class.__name__ - if module_name in self._module_dict: - raise KeyError('{} is already registered in {}'.format( - module_name, self.name)) - self._module_dict[module_name] = module_class - - def register_module(self, cls): - self._register_module(cls) - return cls - +from mmdet.utils import Registry BACKBONES = Registry('backbone') NECKS = Registry('neck') diff --git a/mmdet/utils/__init__.py b/mmdet/utils/__init__.py new file mode 100644 index 0000000..c0a1244 --- /dev/null +++ b/mmdet/utils/__init__.py @@ -0,0 +1,3 @@ +from .registry import Registry, build_from_cfg + +__all__ = ['Registry', 'build_from_cfg'] diff --git a/mmdet/utils/registry.py b/mmdet/utils/registry.py new file mode 100644 index 0000000..e39552a --- /dev/null +++ b/mmdet/utils/registry.py @@ -0,0 +1,74 @@ +import inspect + +import mmcv + + +class Registry(object): + + def __init__(self, name): + self._name = name + self._module_dict = dict() + + def __repr__(self): + format_str = self.__class__.__name__ + '(name={}, items={})'.format( + self._name, list(self._module_dict.keys())) + return format_str + + @property + def name(self): + return self._name + + @property + def module_dict(self): + return self._module_dict + + def get(self, key): + return self._module_dict.get(key, None) + + def _register_module(self, module_class): + """Register a module. + + Args: + module (:obj:`nn.Module`): Module to be registered. + """ + if not inspect.isclass(module_class): + raise TypeError('module must be a class, but got {}'.format( + type(module_class))) + module_name = module_class.__name__ + if module_name in self._module_dict: + raise KeyError('{} is already registered in {}'.format( + module_name, self.name)) + self._module_dict[module_name] = module_class + + def register_module(self, cls): + self._register_module(cls) + return cls + + +def build_from_cfg(cfg, registry, default_args=None): + """Build a module from config dict. + + Args: + cfg (dict): Config dict. It should at least contain the key "type". + registry (:obj:`Registry`): The registry to search the type from. + default_args (dict, optional): Default initialization arguments. + + Returns: + obj: The constructed object. 
+ """ + assert isinstance(cfg, dict) and 'type' in cfg + assert isinstance(default_args, dict) or default_args is None + args = cfg.copy() + obj_type = args.pop('type') + if mmcv.is_str(obj_type): + obj_type = registry.get(obj_type) + if obj_type is None: + raise KeyError('{} is not in the {} registry'.format( + obj_type, registry.name)) + elif not inspect.isclass(obj_type): + raise TypeError('type must be a str or valid type, but got {}'.format( + type(obj_type))) + if default_args is not None: + for name, value in default_args.items(): + args.setdefault(name, value) + return obj_type(**args) diff --git a/tools/test.py b/tools/test.py index 54f074d..e33a0c4 100644 --- a/tools/test.py +++ b/tools/test.py @@ -12,7 +12,7 @@ from mmdet.apis import init_dist from mmdet.core import results2json, coco_eval, wrap_fp16_model -from mmdet.datasets import build_dataloader, get_dataset +from mmdet.datasets import build_dataloader, build_dataset from mmdet.models import build_detector @@ -147,7 +147,7 @@ def main(): # build the dataloader # TODO: support multiple images per gpu (only minor changes are needed) - dataset = get_dataset(cfg.data.test) + dataset = build_dataset(cfg.data.test) data_loader = build_dataloader( dataset, imgs_per_gpu=1, diff --git a/tools/train.py b/tools/train.py index d8bb9dc..ee2012f 100644 --- a/tools/train.py +++ b/tools/train.py @@ -5,7 +5,7 @@ from mmcv import Config from mmdet import __version__ -from mmdet.datasets import get_dataset +from mmdet.datasets import build_dataset from mmdet.apis import (train_detector, init_dist, get_root_logger, set_random_seed) from mmdet.models import build_detector @@ -75,7 +75,7 @@ def main(): model = build_detector( cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg) - train_dataset = get_dataset(cfg.data.train) + train_dataset = build_dataset(cfg.data.train) if cfg.checkpoint_config is not None: # save mmdet version, config file content and class names in # checkpoints as meta data