Skip to content

Commit

Permalink
[Refactor] Fully refactor yolact
Browse files Browse the repository at this point in the history
  • Loading branch information
BIGWangYuDong authored and ZwwWayne committed Jul 19, 2022
1 parent 5a2ef66 commit 78bab5e
Show file tree
Hide file tree
Showing 18 changed files with 963 additions and 769 deletions.
3 changes: 1 addition & 2 deletions configs/solo/decoupled_solo_light_r50_fpn_3x_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,10 @@
file_client_args={{_base_.file_client_args}}),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(
# TODO: Update after mmcv.RandomChoiceResize finish refactor
type='RandomChoiceResize',
scales=[(852, 512), (852, 480), (852, 448), (852, 416), (852, 384),
(852, 352)],
resize_cfg=dict(type='Resize', keep_ratio=True)),
keep_ratio=True),
dict(type='RandomFlip', prob=0.5),
dict(type='PackDetInputs')
]
Expand Down
1 change: 1 addition & 0 deletions configs/solo/solo_r50_fpn_1x_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_mask=True,
pad_size_divisor=32),
backbone=dict(
type='ResNet',
Expand Down
3 changes: 1 addition & 2 deletions configs/solo/solo_r50_fpn_3x_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
file_client_args={{_base_.file_client_args}}),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(
# TODO: Update after mmcv.RandomChoiceResize finish refactor
type='RandomChoiceResize',
scales=[(1333, 800), (1333, 768), (1333, 736), (1333, 704),
(1333, 672), (1333, 640)],
resize_cfg=dict(type='Resize', keep_ratio=True)),
keep_ratio=True),
dict(type='RandomFlip', prob=0.5),
dict(type='PackDetInputs')
]
Expand Down
3 changes: 1 addition & 2 deletions configs/solov2/solov2_light_r50_fpn_mstrain_3x_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@
file_client_args={{_base_.file_client_args}}),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(
# TODO: Update after mmcv.RandomChoiceResize finish refactor
type='RandomChoiceResize',
scales=[(768, 512), (768, 480), (768, 448), (768, 416), (768, 384),
(768, 352)],
resize_cfg=dict(type='Resize', keep_ratio=True)),
keep_ratio=True),
dict(type='RandomFlip', prob=0.5),
dict(type='PackDetInputs')
]
Expand Down
1 change: 1 addition & 0 deletions configs/solov2/solov2_r50_fpn_1x_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_mask=True,
pad_size_divisor=32),
backbone=dict(
type='ResNet',
Expand Down
3 changes: 1 addition & 2 deletions configs/solov2/solov2_r50_fpn_mstrain_3x_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
file_client_args={{_base_.file_client_args}}),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(
# TODO: Update after mmcv.RandomChoiceResize finish refactor
type='RandomChoiceResize',
scales=[(1333, 800), (1333, 768), (1333, 736), (1333, 704),
(1333, 672), (1333, 640)],
resize_cfg=dict(type='Resize', keep_ratio=True)),
keep_ratio=True),
dict(type='RandomFlip', prob=0.5),
dict(type='PackDetInputs')
]
Expand Down
122 changes: 64 additions & 58 deletions configs/yolact/yolact_r50_1x8_coco.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
_base_ = '../_base_/default_runtime.py'

_base_ = [
'../_base_/datasets/coco_instance.py', '../_base_/default_runtime.py'
]
img_norm_cfg = dict(
mean=[123.68, 116.78, 103.94], std=[58.40, 57.12, 57.38], to_rgb=True)
# model settings
img_size = 550
input_size = 550
model = dict(
type='YOLACT',
data_preprocessor=dict(
type='DetDataPreprocessor',
mean=img_norm_cfg['mean'],
std=img_norm_cfg['std'],
bgr_to_rgb=img_norm_cfg['to_rgb'],
pad_mask=True),
backbone=dict(
type='ResNet',
depth=50,
Expand Down Expand Up @@ -56,11 +65,8 @@
num_protos=32,
num_classes=80,
max_masks_to_train=100,
loss_mask_weight=6.125),
segm_head=dict(
type='YOLACTSegmHead',
num_classes=80,
in_channels=256,
loss_mask_weight=6.125,
with_seg_branch=True,
loss_segm=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
# training and testing settings
Expand All @@ -72,6 +78,7 @@
min_pos_iou=0.,
ignore_iof_thr=-1,
gt_max_assign_all=False),
sampler=dict(type='PseudoSampler'), # YOLACT should use PseudoSampler
# smoothl1_beta=1.,
allowed_border=-1,
pos_weight=-1,
Expand All @@ -81,16 +88,16 @@
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
mask_thr=0.5,
iou_thr=0.5,
top_k=200,
max_per_img=100))
max_per_img=100,
mask_thr_binary=0.5))
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.68, 116.78, 103.94], std=[58.40, 57.12, 57.38], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='LoadImageFromFile',
file_client_args={{_base_.file_client_args}}),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(type='FilterAnnotations', min_gt_bbox_wh=(4.0, 4.0)),
dict(
Expand All @@ -102,62 +109,61 @@
type='MinIoURandomCrop',
min_ious=(0.1, 0.3, 0.5, 0.7, 0.9),
min_crop_size=0.3),
dict(type='Resize', img_scale=(img_size, img_size), keep_ratio=False),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Resize', scale=(input_size, input_size), keep_ratio=False),
dict(type='RandomFlip', prob=0.5),
dict(
type='PhotoMetricDistortion',
brightness_delta=32,
contrast_range=(0.5, 1.5),
saturation_range=(0.5, 1.5),
hue_delta=18),
dict(type='Normalize', **img_norm_cfg),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
dict(type='PackDetInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(input_size, input_size), keep_ratio=False),
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor'))
]
train_dataloader = dict(
batch_size=8,
num_workers=4,
batch_sampler=None,
dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader

max_epochs = 55
# training schedule for 55e
train_cfg = dict(
type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# learning rate
param_scheduler = [
dict(type='LinearLR', start_factor=0.1, by_epoch=False, begin=0, end=500),
dict(
type='MultiScaleFlipAug',
img_scale=(img_size, img_size),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=False),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
type='MultiStepLR',
begin=0,
end=max_epochs,
by_epoch=True,
milestones=[20, 42, 49, 52],
gamma=0.1)
]
data = dict(
samples_per_gpu=8,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline))

# optimizer
optimizer = dict(type='SGD', lr=1e-3, momentum=0.9, weight_decay=5e-4)
optimizer_config = dict()
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.1,
step=[20, 42, 49, 52])
runner = dict(type='EpochBasedRunner', max_epochs=55)
cudnn_benchmark = True
evaluation = dict(metric=['bbox', 'segm'])
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='SGD', lr=1e-3, momentum=0.9, weight_decay=5e-4))

custom_hooks = [
dict(type='CheckInvalidLossHook', interval=50, priority='VERY_LOW')
]

env_cfg = dict(cudnn_benchmark=True)

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# USER SHOULD NOT CHANGE ITS VALUES.
Expand Down
27 changes: 17 additions & 10 deletions configs/yolact/yolact_r50_8x8_coco.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
_base_ = 'yolact_r50_1x8_coco.py'

optimizer = dict(type='SGD', lr=8e-3, momentum=0.9, weight_decay=5e-4)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=1000,
warmup_ratio=0.1,
step=[20, 42, 49, 52])

# optimizer
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(lr=8e-3),
clip_grad=dict(max_norm=35, norm_type=2))
# learning rate
max_epochs = 55
param_scheduler = [
dict(type='LinearLR', start_factor=0.1, by_epoch=False, begin=0, end=1000),
dict(
type='MultiStepLR',
begin=0,
end=max_epochs,
by_epoch=True,
milestones=[20, 42, 49, 52],
gamma=0.1)
]
# NOTE: `auto_scale_lr` is for automatically scaling LR,
# USER SHOULD NOT CHANGE ITS VALUES.
# base_batch_size = (8 GPUs) x (8 samples per GPU)
Expand Down
22 changes: 19 additions & 3 deletions mmdet/datasets/transforms/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,9 +659,25 @@ def transform(self, results: dict) -> Union[dict, None]:
Returns:
dict: Updated result dict.
"""
assert 'gt_bboxes' in results
gt_bboxes = results['gt_bboxes']
if gt_bboxes.shape[0] == 0:
# gt_masks may not match gt_bboxes in length, because the mask of an
# instance is not added to gt_masks when its ignore flag is True
if 'gt_ignore_flags' in results and 'gt_masks' in results:
vaild_idx = np.where(results['gt_ignore_flags'] == 0)[0]
keys = ('gt_bboxes', 'gt_bboxes_labels', 'gt_ignore_flags')
for key in keys:
if key in results:
results[key] = results[key][vaild_idx]

if self.by_box:
assert 'gt_bboxes' in results
gt_bboxes = results['gt_bboxes']
instance_num = gt_bboxes.shape[0]
if self.by_mask:
assert 'gt_masks' in results
gt_masks = results['gt_masks']
instance_num = len(gt_masks)

if instance_num == 0:
return results

tests = []
Expand Down
12 changes: 6 additions & 6 deletions mmdet/models/dense_heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from .ssd_head import SSDHead
from .tood_head import TOODHead
from .vfnet_head import VFNetHead
from .yolact_head import YOLACTHead, YOLACTProtonet, YOLACTSegmHead
from .yolact_head import YOLACTHead, YOLACTProtonet
from .yolo_head import YOLOV3Head
from .yolof_head import YOLOFHead
from .yolox_head import YOLOXHead
Expand All @@ -49,11 +49,11 @@
'SSDHead', 'FCOSHead', 'RepPointsHead', 'FoveaHead',
'FreeAnchorRetinaHead', 'ATSSHead', 'FSAFHead', 'NASFCOSHead',
'PISARetinaHead', 'PISASSDHead', 'GFLHead', 'CornerHead', 'YOLACTHead',
'YOLACTSegmHead', 'YOLACTProtonet', 'YOLOV3Head', 'PAAHead',
'SABLRetinaHead', 'CentripetalHead', 'VFNetHead', 'StageCascadeRPNHead',
'CascadeRPNHead', 'EmbeddingRPNHead', 'LDHead', 'CascadeRPNHead',
'AutoAssignHead', 'DETRHead', 'YOLOFHead', 'DeformableDETRHead',
'SOLOHead', 'DecoupledSOLOHead', 'CenterNetHead', 'YOLOXHead',
'YOLACTProtonet', 'YOLOV3Head', 'PAAHead', 'SABLRetinaHead',
'CentripetalHead', 'VFNetHead', 'StageCascadeRPNHead', 'CascadeRPNHead',
'EmbeddingRPNHead', 'LDHead', 'CascadeRPNHead', 'AutoAssignHead',
'DETRHead', 'YOLOFHead', 'DeformableDETRHead', 'SOLOHead',
'DecoupledSOLOHead', 'CenterNetHead', 'YOLOXHead',
'DecoupledSOLOLightHead', 'LADHead', 'TOODHead', 'MaskFormerHead',
'Mask2FormerHead', 'SOLOV2Head', 'DDODHead', 'CenterNetUpdateHead'
]
3 changes: 3 additions & 0 deletions mmdet/models/dense_heads/anchor_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,9 @@ def get_targets(self,
# `avg_factor` is usually equal to the number of positive priors.
avg_factor = sum(
[results.avg_factor for results in sampling_results_list])
# update `_raw_positive_infos`, which will be used when calling
# `get_positive_infos`.
self._raw_positive_infos.update(sampling_results=sampling_results_list)
# split targets to a list w.r.t. multiple levels
labels_list = images_to_levels(all_labels, num_level_anchors)
label_weights_list = images_to_levels(all_label_weights,
Expand Down
29 changes: 29 additions & 0 deletions mmdet/models/dense_heads/base_dense_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ class BaseDenseHead(BaseModule, metaclass=ABCMeta):

def __init__(self, init_cfg: OptMultiConfig = None) -> None:
    """Initialize the head and its positive-sample cache.

    Args:
        init_cfg: Initialization config, forwarded unchanged to the
            parent constructor.
    """
    super().__init__(init_cfg=init_cfg)
    # Starts out empty; `get_positive_infos` reads the positive
    # information it returns from this dict.
    self._raw_positive_infos = {}

def init_weights(self) -> None:
"""Initialize the weights."""
Expand All @@ -68,6 +71,32 @@ def init_weights(self) -> None:
if hasattr(m, 'conv_offset'):
constant_init(m.conv_offset, 0)

def get_positive_infos(self) -> InstanceList:
    """Get positive information from sampling results.

    Reads the ``sampling_results`` entry stored in
    ``self._raw_positive_infos`` and copies the positive-sample attributes
    of every sampling result into a fresh ``InstanceData``.

    Returns:
        list[:obj:`InstanceData`] or None: Positive information of each
        image, usually including positive bboxes, positive labels, positive
        priors, etc. ``None`` is returned when ``_raw_positive_infos`` is
        empty.
    """
    if len(self._raw_positive_infos) == 0:
        return None

    sampling_results = self._raw_positive_infos.get(
        'sampling_results', None)
    assert sampling_results is not None
    positive_infos = []
    # Iterate over the results themselves: wrapping the list in
    # `enumerate` would yield (index, result) tuples, which do not have
    # the `pos_*` attributes accessed below.
    for sampling_result in sampling_results:
        pos_info = InstanceData()
        pos_info.bboxes = sampling_result.pos_gt_bboxes
        pos_info.labels = sampling_result.pos_gt_labels
        pos_info.priors = sampling_result.pos_priors
        pos_info.pos_assigned_gt_inds = \
            sampling_result.pos_assigned_gt_inds
        pos_info.pos_inds = sampling_result.pos_inds
        positive_infos.append(pos_info)
    return positive_infos

def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList) -> dict:
"""Perform forward propagation and loss calculation of the detection
head on the features of the upstream network.
Expand Down
Loading

0 comments on commit 78bab5e

Please sign in to comment.