Use registry to manage datasets (#924)
* use registry to manage datasets

* bug fix for concat dataset

* update documentation to fit the new api
hellock authored Jul 4, 2019
1 parent 2fb1531 commit b8bcda6
Showing 19 changed files with 210 additions and 163 deletions.
4 changes: 3 additions & 1 deletion GETTING_STARTED.md
@@ -150,8 +150,10 @@ In `mmdet/datasets/my_dataset.py`:
 
 ```python
 from .coco import CocoDataset
+from .registry import DATASETS
 
 
+@DATASETS.register_module
 class MyDataset(CocoDataset):
 
     CLASSES = ('a', 'b', 'c', 'd', 'e')
@@ -228,7 +230,7 @@ import torch.nn as nn
 from ..registry import BACKBONES
 
 
-@BACKBONES.register
+@BACKBONES.register_module
 class MobileNet(nn.Module):
 
     def __init__(self, arg1, arg2):
11 changes: 6 additions & 5 deletions mmdet/datasets/__init__.py
@@ -4,14 +4,15 @@
 from .voc import VOCDataset
 from .wider_face import WIDERFaceDataset
 from .loader import GroupSampler, DistributedGroupSampler, build_dataloader
-from .utils import to_tensor, random_scale, show_ann, get_dataset
-from .concat_dataset import ConcatDataset
-from .repeat_dataset import RepeatDataset
+from .utils import to_tensor, random_scale, show_ann
+from .dataset_wrappers import ConcatDataset, RepeatDataset
 from .extra_aug import ExtraAugmentation
+from .registry import DATASETS
+from .builder import build_dataset
 
 __all__ = [
     'CustomDataset', 'XMLDataset', 'CocoDataset', 'VOCDataset', 'GroupSampler',
     'DistributedGroupSampler', 'build_dataloader', 'to_tensor', 'random_scale',
-    'show_ann', 'get_dataset', 'ConcatDataset', 'RepeatDataset',
-    'ExtraAugmentation', 'WIDERFaceDataset'
+    'show_ann', 'ConcatDataset', 'RepeatDataset', 'ExtraAugmentation',
+    'WIDERFaceDataset', 'DATASETS', 'build_dataset'
 ]
38 changes: 38 additions & 0 deletions mmdet/datasets/builder.py
@@ -0,0 +1,38 @@
import copy

from mmdet.utils import build_from_cfg
from .dataset_wrappers import ConcatDataset, RepeatDataset
from .registry import DATASETS


def _concat_dataset(cfg):
    ann_files = cfg['ann_file']
    img_prefixes = cfg.get('img_prefix', None)
    seg_prefixes = cfg.get('seg_prefixes', None)
    proposal_files = cfg.get('proposal_file', None)

    datasets = []
    num_dset = len(ann_files)
    for i in range(num_dset):
        data_cfg = copy.deepcopy(cfg)
        data_cfg['ann_file'] = ann_files[i]
        if isinstance(img_prefixes, (list, tuple)):
            data_cfg['img_prefix'] = img_prefixes[i]
        if isinstance(seg_prefixes, (list, tuple)):
            data_cfg['seg_prefix'] = seg_prefixes[i]
        if isinstance(proposal_files, (list, tuple)):
            data_cfg['proposal_file'] = proposal_files[i]
        datasets.append(build_dataset(data_cfg))

    return ConcatDataset(datasets)


def build_dataset(cfg):
    if cfg['type'] == 'RepeatDataset':
        dataset = RepeatDataset(build_dataset(cfg['dataset']), cfg['times'])
    elif isinstance(cfg['ann_file'], (list, tuple)):
        dataset = _concat_dataset(cfg)
    else:
        dataset = build_from_cfg(cfg, DATASETS)

    return dataset
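
A hedged usage sketch of the new entry point. Only the keys that `build_dataset` itself reads (`type`, `ann_file`, `dataset`, `times`) come from this diff; the file paths below are hypothetical, and real dataset configs carry additional required fields (image scales, normalization, and so on) that are omitted here for brevity.

```python
from mmdet.datasets import build_dataset

# A plain dict config: 'type' is looked up in the DATASETS registry and the
# remaining keys are forwarded to the chosen dataset class. Real configs
# carry more required fields (img_scale, img_norm_cfg, ...), omitted here.
dataset = build_dataset(
    dict(
        type='CocoDataset',
        ann_file='data/coco/annotations/instances_train2017.json',
        img_prefix='data/coco/train2017/'))

# A list of ann_files routes through _concat_dataset, which builds one
# dataset per annotation file and wraps them all in ConcatDataset.
concat = build_dataset(
    dict(
        type='VOCDataset',
        ann_file=['data/VOC2007/train.txt', 'data/VOC2012/train.txt'],
        img_prefix=['data/VOC2007/', 'data/VOC2012/']))

# type='RepeatDataset' builds the inner config recursively, then wraps it.
repeated = build_dataset(
    dict(
        type='RepeatDataset',
        times=10,
        dataset=dict(
            type='CocoDataset',
            ann_file='data/coco/annotations/instances_train2017.json',
            img_prefix='data/coco/train2017/')))
```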
2 changes: 2 additions & 0 deletions mmdet/datasets/coco.py
@@ -2,8 +2,10 @@
 from pycocotools.coco import COCO
 
 from .custom import CustomDataset
+from .registry import DATASETS
 
 
+@DATASETS.register_module
 class CocoDataset(CustomDataset):
 
     CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
22 changes: 0 additions & 22 deletions mmdet/datasets/concat_dataset.py

This file was deleted.

2 changes: 2 additions & 0 deletions mmdet/datasets/custom.py
@@ -5,12 +5,14 @@
 from mmcv.parallel import DataContainer as DC
 from torch.utils.data import Dataset
 
+from .registry import DATASETS
 from .transforms import (ImageTransform, BboxTransform, MaskTransform,
                          SegMapTransform, Numpy2Tensor)
 from .utils import to_tensor, random_scale
 from .extra_aug import ExtraAugmentation
 
 
+@DATASETS.register_module
 class CustomDataset(Dataset):
     """Custom dataset for detection.
55 changes: 55 additions & 0 deletions mmdet/datasets/dataset_wrappers.py
@@ -0,0 +1,55 @@
import numpy as np
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset

from .registry import DATASETS


@DATASETS.register_module
class ConcatDataset(_ConcatDataset):
    """A wrapper of concatenated dataset.

    Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
    concat the group flag for image aspect ratio.

    Args:
        datasets (list[:obj:`Dataset`]): A list of datasets.
    """

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__(datasets)
        self.CLASSES = datasets[0].CLASSES
        if hasattr(datasets[0], 'flag'):
            flags = []
            for i in range(0, len(datasets)):
                flags.append(datasets[i].flag)
            self.flag = np.concatenate(flags)


@DATASETS.register_module
class RepeatDataset(object):
    """A wrapper of repeated dataset.

    The length of repeated dataset will be `times` larger than the original
    dataset. This is useful when the data loading time is long but the dataset
    is small. Using RepeatDataset can reduce the data loading time between
    epochs.

    Args:
        dataset (:obj:`Dataset`): The dataset to be repeated.
        times (int): Repeat times.
    """

    def __init__(self, dataset, times):
        self.dataset = dataset
        self.times = times
        self.CLASSES = dataset.CLASSES
        if hasattr(self.dataset, 'flag'):
            self.flag = np.tile(self.dataset.flag, times)

        self._ori_len = len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx % self._ori_len]

    def __len__(self):
        return self.times * self._ori_len
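
A small runnable sketch of the wrap-around indexing: `TinyDataset` below is a hypothetical stand-in providing only the attributes `RepeatDataset` actually touches (`CLASSES`, `__getitem__`, `__len__`; the optional `flag` is absent here).

```python
from mmdet.datasets import RepeatDataset


class TinyDataset(object):
    """Hypothetical stand-in with just the attributes RepeatDataset uses."""

    CLASSES = ('a', 'b')

    def __getitem__(self, idx):
        return idx

    def __len__(self):
        return 3


repeated = RepeatDataset(TinyDataset(), times=4)
assert len(repeated) == 12        # times * original length
assert repeated[7] == 7 % 3 == 1  # indices wrap modulo the original length
```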
3 changes: 3 additions & 0 deletions mmdet/datasets/registry.py
@@ -0,0 +1,3 @@
from mmdet.utils import Registry

DATASETS = Registry('dataset')
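
`mmdet.utils.Registry` itself is not part of this diff. The sketch below is an assumption reconstructed from how it is used here: it needs a `name` and a `module_dict` keyed by class name (both referenced by the `_build_module` helper removed from `mmdet/models/builder.py` further down), and a `register_module` method applied as a bare class decorator.

```python
# Hedged sketch; the real mmdet.utils.Registry may differ in details
# (e.g. duplicate-registration checks).
class Registry(object):

    def __init__(self, name):
        self.name = name       # used in the builder's KeyError message
        self.module_dict = {}  # maps class names to registered classes

    def register_module(self, cls):
        # Applied as a bare decorator (@DATASETS.register_module), so it
        # receives the class itself and must return it unchanged.
        self.module_dict[cls.__name__] = cls
        return cls


DATASETS = Registry('dataset')
```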
19 changes: 0 additions & 19 deletions mmdet/datasets/repeat_dataset.py

This file was deleted.

52 changes: 2 additions & 50 deletions mmdet/datasets/utils.py
@@ -1,15 +1,9 @@
-import copy
 from collections import Sequence
 
-import mmcv
-from mmcv.runner import obj_from_dict
-import torch
-
 import matplotlib.pyplot as plt
+import mmcv
 import numpy as np
-from .concat_dataset import ConcatDataset
-from .repeat_dataset import RepeatDataset
-from .. import datasets
+import torch
 
 
 def to_tensor(data):
@@ -72,45 +66,3 @@ def show_ann(coco, img, ann_info):
     plt.axis('off')
     coco.showAnns(ann_info)
     plt.show()
-
-
-def get_dataset(data_cfg):
-    if data_cfg['type'] == 'RepeatDataset':
-        return RepeatDataset(
-            get_dataset(data_cfg['dataset']), data_cfg['times'])
-
-    if isinstance(data_cfg['ann_file'], (list, tuple)):
-        ann_files = data_cfg['ann_file']
-        num_dset = len(ann_files)
-    else:
-        ann_files = [data_cfg['ann_file']]
-        num_dset = 1
-
-    if 'proposal_file' in data_cfg.keys():
-        if isinstance(data_cfg['proposal_file'], (list, tuple)):
-            proposal_files = data_cfg['proposal_file']
-        else:
-            proposal_files = [data_cfg['proposal_file']]
-    else:
-        proposal_files = [None] * num_dset
-    assert len(proposal_files) == num_dset
-
-    if isinstance(data_cfg['img_prefix'], (list, tuple)):
-        img_prefixes = data_cfg['img_prefix']
-    else:
-        img_prefixes = [data_cfg['img_prefix']] * num_dset
-    assert len(img_prefixes) == num_dset
-
-    dsets = []
-    for i in range(num_dset):
-        data_info = copy.deepcopy(data_cfg)
-        data_info['ann_file'] = ann_files[i]
-        data_info['proposal_file'] = proposal_files[i]
-        data_info['img_prefix'] = img_prefixes[i]
-        dset = obj_from_dict(data_info, datasets)
-        dsets.append(dset)
-    if len(dsets) > 1:
-        dset = ConcatDataset(dsets)
-    else:
-        dset = dsets[0]
-    return dset
2 changes: 2 additions & 0 deletions mmdet/datasets/voc.py
@@ -1,6 +1,8 @@
+from .registry import DATASETS
 from .xml_style import XMLDataset
 
 
+@DATASETS.register_module
 class VOCDataset(XMLDataset):
 
     CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
11 changes: 8 additions & 3 deletions mmdet/datasets/wider_face.py
@@ -3,16 +3,18 @@
 
 import mmcv
 
+from .registry import DATASETS
 from .xml_style import XMLDataset
 
 
+@DATASETS.register_module
 class WIDERFaceDataset(XMLDataset):
     """
     Reader for the WIDER Face dataset in PASCAL VOC format.
     Conversion scripts can be found in
     https://github.com/sovrasov/wider-face-pascal-voc-annotations
     """
-    CLASSES = ('face',)
+    CLASSES = ('face', )
 
     def __init__(self, **kwargs):
         super(WIDERFaceDataset, self).__init__(**kwargs)
@@ -31,7 +33,10 @@ def load_annotations(self, ann_file):
             height = int(size.find('height').text)
             folder = root.find('folder').text
             img_infos.append(
-                dict(id=img_id, filename=osp.join(folder, filename),
-                     width=width, height=height))
+                dict(
+                    id=img_id,
+                    filename=osp.join(folder, filename),
+                    width=width,
+                    height=height))
 
         return img_infos
2 changes: 2 additions & 0 deletions mmdet/datasets/xml_style.py
@@ -5,8 +5,10 @@
 import numpy as np
 
 from .custom import CustomDataset
+from .registry import DATASETS
 
 
+@DATASETS.register_module
 class XMLDataset(CustomDataset):
 
     def __init__(self, min_size=None, **kwargs):
27 changes: 5 additions & 22 deletions mmdet/models/builder.py
@@ -1,35 +1,18 @@
-import mmcv
 from torch import nn
 
+from mmdet.utils import build_from_cfg
 from .registry import (BACKBONES, NECKS, ROI_EXTRACTORS, SHARED_HEADS, HEADS,
                        LOSSES, DETECTORS)
 
 
-def _build_module(cfg, registry, default_args):
-    assert isinstance(cfg, dict) and 'type' in cfg
-    assert isinstance(default_args, dict) or default_args is None
-    args = cfg.copy()
-    obj_type = args.pop('type')
-    if mmcv.is_str(obj_type):
-        if obj_type not in registry.module_dict:
-            raise KeyError('{} is not in the {} registry'.format(
-                obj_type, registry.name))
-        obj_type = registry.module_dict[obj_type]
-    elif not isinstance(obj_type, type):
-        raise TypeError('type must be a str or valid type, but got {}'.format(
-            type(obj_type)))
-    if default_args is not None:
-        for name, value in default_args.items():
-            args.setdefault(name, value)
-    return obj_type(**args)
-
-
 def build(cfg, registry, default_args=None):
     if isinstance(cfg, list):
-        modules = [_build_module(cfg_, registry, default_args) for cfg_ in cfg]
+        modules = [
+            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
+        ]
         return nn.Sequential(*modules)
     else:
-        return _build_module(cfg, registry, default_args)
+        return build_from_cfg(cfg, registry, default_args)
 
 
 def build_backbone(cfg):
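
`build_from_cfg` from `mmdet.utils` is not shown in this commit; presumably it centralizes the logic of the deleted `_build_module` above so that datasets and models share one builder. A hedged reconstruction along those lines:

```python
import mmcv


def build_from_cfg(cfg, registry, default_args=None):
    """Hedged reconstruction, mirroring the deleted _build_module above."""
    assert isinstance(cfg, dict) and 'type' in cfg
    assert isinstance(default_args, dict) or default_args is None
    args = cfg.copy()
    obj_type = args.pop('type')
    if mmcv.is_str(obj_type):
        # Resolve a string type through the registry's module_dict.
        if obj_type not in registry.module_dict:
            raise KeyError('{} is not in the {} registry'.format(
                obj_type, registry.name))
        obj_type = registry.module_dict[obj_type]
    elif not isinstance(obj_type, type):
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(obj_type)))
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_type(**args)
```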
(Diffs for the remaining changed files are not shown.)
