From 78bab5e7bf0af8a62d77747c2bd4f4795c558b6b Mon Sep 17 00:00:00 2001 From: BIGWangYuDong Date: Sun, 17 Jul 2022 16:03:04 +0000 Subject: [PATCH] [Refactor] Fully refactor yolact --- .../decoupled_solo_light_r50_fpn_3x_coco.py | 3 +- configs/solo/solo_r50_fpn_1x_coco.py | 1 + configs/solo/solo_r50_fpn_3x_coco.py | 3 +- .../solov2_light_r50_fpn_mstrain_3x_coco.py | 3 +- configs/solov2/solov2_r50_fpn_1x_coco.py | 1 + .../solov2/solov2_r50_fpn_mstrain_3x_coco.py | 3 +- configs/yolact/yolact_r50_1x8_coco.py | 122 +- configs/yolact/yolact_r50_8x8_coco.py | 27 +- mmdet/datasets/transforms/loading.py | 22 +- mmdet/models/dense_heads/__init__.py | 12 +- mmdet/models/dense_heads/anchor_head.py | 3 + mmdet/models/dense_heads/base_dense_head.py | 29 + mmdet/models/dense_heads/base_mask_head.py | 28 +- mmdet/models/dense_heads/yolact_head.py | 1236 ++++++++++------- .../detectors/single_stage_instance_seg.py | 69 +- mmdet/models/detectors/yolact.py | 131 +- .../roi_heads/mask_heads/fcn_mask_head.py | 2 +- .../test_single_stage_instance_seg.py | 37 +- 18 files changed, 963 insertions(+), 769 deletions(-) diff --git a/configs/solo/decoupled_solo_light_r50_fpn_3x_coco.py b/configs/solo/decoupled_solo_light_r50_fpn_3x_coco.py index 7d6b5ed1b2a..e40a7cba7c0 100644 --- a/configs/solo/decoupled_solo_light_r50_fpn_3x_coco.py +++ b/configs/solo/decoupled_solo_light_r50_fpn_3x_coco.py @@ -30,11 +30,10 @@ file_client_args={{_base_.file_client_args}}), dict(type='LoadAnnotations', with_bbox=True, with_mask=True), dict( - # TODO: Update after mmcv.RandomChoiceResize finish refactor type='RandomChoiceResize', scales=[(852, 512), (852, 480), (852, 448), (852, 416), (852, 384), (852, 352)], - resize_cfg=dict(type='Resize', keep_ratio=True)), + keep_ratio=True), dict(type='RandomFlip', prob=0.5), dict(type='PackDetInputs') ] diff --git a/configs/solo/solo_r50_fpn_1x_coco.py b/configs/solo/solo_r50_fpn_1x_coco.py index ae6c8eda283..595e9ffe148 100644 --- a/configs/solo/solo_r50_fpn_1x_coco.py +++ b/configs/solo/solo_r50_fpn_1x_coco.py @@ -10,6 +10,7 @@ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], bgr_to_rgb=True, + pad_mask=True, pad_size_divisor=32), backbone=dict( type='ResNet', diff --git a/configs/solo/solo_r50_fpn_3x_coco.py b/configs/solo/solo_r50_fpn_3x_coco.py index 49dd7a564bf..c30d41f6d92 100644 --- a/configs/solo/solo_r50_fpn_3x_coco.py +++ b/configs/solo/solo_r50_fpn_3x_coco.py @@ -6,11 +6,10 @@ file_client_args={{_base_.file_client_args}}), dict(type='LoadAnnotations', with_bbox=True, with_mask=True), dict( - # TODO: Update after mmcv.RandomChoiceResize finish refactor type='RandomChoiceResize', scales=[(1333, 800), (1333, 768), (1333, 736), (1333, 704), (1333, 672), (1333, 640)], - resize_cfg=dict(type='Resize', keep_ratio=True)), + keep_ratio=True), dict(type='RandomFlip', prob=0.5), dict(type='PackDetInputs') ] diff --git a/configs/solov2/solov2_light_r50_fpn_mstrain_3x_coco.py b/configs/solov2/solov2_light_r50_fpn_mstrain_3x_coco.py index b86139c7d55..2e3ce5fa103 100644 --- a/configs/solov2/solov2_light_r50_fpn_mstrain_3x_coco.py +++ b/configs/solov2/solov2_light_r50_fpn_mstrain_3x_coco.py @@ -15,11 +15,10 @@ file_client_args={{_base_.file_client_args}}), dict(type='LoadAnnotations', with_bbox=True, with_mask=True), dict( - # TODO: Update after mmcv.RandomChoiceResize finish refactor type='RandomChoiceResize', scales=[(768, 512), (768, 480), (768, 448), (768, 416), (768, 384), (768, 352)], - resize_cfg=dict(type='Resize', keep_ratio=True)), + keep_ratio=True), dict(type='RandomFlip', 
prob=0.5), dict(type='PackDetInputs') ] diff --git a/configs/solov2/solov2_r50_fpn_1x_coco.py b/configs/solov2/solov2_r50_fpn_1x_coco.py index 37dba2cfc39..138ca010b5f 100644 --- a/configs/solov2/solov2_r50_fpn_1x_coco.py +++ b/configs/solov2/solov2_r50_fpn_1x_coco.py @@ -11,6 +11,7 @@ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], bgr_to_rgb=True, + pad_mask=True, pad_size_divisor=32), backbone=dict( type='ResNet', diff --git a/configs/solov2/solov2_r50_fpn_mstrain_3x_coco.py b/configs/solov2/solov2_r50_fpn_mstrain_3x_coco.py index c072f9ddd00..ca7a6a16772 100644 --- a/configs/solov2/solov2_r50_fpn_mstrain_3x_coco.py +++ b/configs/solov2/solov2_r50_fpn_mstrain_3x_coco.py @@ -6,11 +6,10 @@ file_client_args={{_base_.file_client_args}}), dict(type='LoadAnnotations', with_bbox=True, with_mask=True), dict( - # TODO: Update after mmcv.RandomChoiceResize finish refactor type='RandomChoiceResize', scales=[(1333, 800), (1333, 768), (1333, 736), (1333, 704), (1333, 672), (1333, 640)], - resize_cfg=dict(type='Resize', keep_ratio=True)), + keep_ratio=True), dict(type='RandomFlip', prob=0.5), dict(type='PackDetInputs') ] diff --git a/configs/yolact/yolact_r50_1x8_coco.py b/configs/yolact/yolact_r50_1x8_coco.py index dbced5a1a69..592c631fab1 100644 --- a/configs/yolact/yolact_r50_1x8_coco.py +++ b/configs/yolact/yolact_r50_1x8_coco.py @@ -1,9 +1,18 @@ -_base_ = '../_base_/default_runtime.py' - +_base_ = [ + '../_base_/datasets/coco_instance.py', '../_base_/default_runtime.py' +] +img_norm_cfg = dict( + mean=[123.68, 116.78, 103.94], std=[58.40, 57.12, 57.38], to_rgb=True) # model settings -img_size = 550 +input_size = 550 model = dict( type='YOLACT', + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=img_norm_cfg['mean'], + std=img_norm_cfg['std'], + bgr_to_rgb=img_norm_cfg['to_rgb'], + pad_mask=True), backbone=dict( type='ResNet', depth=50, @@ -56,11 +65,8 @@ num_protos=32, num_classes=80, max_masks_to_train=100, - loss_mask_weight=6.125), - segm_head=dict( - type='YOLACTSegmHead', - num_classes=80, - in_channels=256, + loss_mask_weight=6.125, + with_seg_branch=True, loss_segm=dict( type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)), # training and testing settings @@ -72,6 +78,7 @@ min_pos_iou=0., ignore_iof_thr=-1, gt_max_assign_all=False), + sampler=dict(type='PseudoSampler'), # YOLACT should use PseudoSampler # smoothl1_beta=1., allowed_border=-1, pos_weight=-1, @@ -81,16 +88,16 @@ nms_pre=1000, min_bbox_size=0, score_thr=0.05, + mask_thr=0.5, iou_thr=0.5, top_k=200, - max_per_img=100)) + max_per_img=100, + mask_thr_binary=0.5)) # dataset settings -dataset_type = 'CocoDataset' -data_root = 'data/coco/' -img_norm_cfg = dict( - mean=[123.68, 116.78, 103.94], std=[58.40, 57.12, 57.38], to_rgb=True) train_pipeline = [ - dict(type='LoadImageFromFile'), + dict( + type='LoadImageFromFile', + file_client_args={{_base_.file_client_args}}), dict(type='LoadAnnotations', with_bbox=True, with_mask=True), dict(type='FilterAnnotations', min_gt_bbox_wh=(4.0, 4.0)), dict( @@ -102,62 +109,61 @@ type='MinIoURandomCrop', min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), min_crop_size=0.3), - dict(type='Resize', img_scale=(img_size, img_size), keep_ratio=False), - dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Resize', scale=(input_size, input_size), keep_ratio=False), + dict(type='RandomFlip', prob=0.5), dict( type='PhotoMetricDistortion', brightness_delta=32, contrast_range=(0.5, 1.5), saturation_range=(0.5, 1.5), hue_delta=18), - dict(type='Normalize', **img_norm_cfg), - 
dict(type='DefaultFormatBundle'), - dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), + dict(type='PackDetInputs') ] test_pipeline = [ dict(type='LoadImageFromFile'), + dict(type='Resize', scale=(input_size, input_size), keep_ratio=False), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] +train_dataloader = dict( + batch_size=8, + num_workers=4, + batch_sampler=None, + dataset=dict(pipeline=train_pipeline)) +val_dataloader = dict(dataset=dict(pipeline=test_pipeline)) +test_dataloader = val_dataloader + +max_epochs = 55 +# training schedule for 55e +train_cfg = dict( + type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +# learning rate +param_scheduler = [ + dict(type='LinearLR', start_factor=0.1, by_epoch=False, begin=0, end=500), dict( - type='MultiScaleFlipAug', - img_scale=(img_size, img_size), - flip=False, - transforms=[ - dict(type='Resize', keep_ratio=False), - dict(type='Normalize', **img_norm_cfg), - dict(type='ImageToTensor', keys=['img']), - dict(type='Collect', keys=['img']), - ]) + type='MultiStepLR', + begin=0, + end=max_epochs, + by_epoch=True, + milestones=[20, 42, 49, 52], + gamma=0.1) ] -data = dict( - samples_per_gpu=8, - workers_per_gpu=4, - train=dict( - type=dataset_type, - ann_file=data_root + 'annotations/instances_train2017.json', - img_prefix=data_root + 'train2017/', - pipeline=train_pipeline), - val=dict( - type=dataset_type, - ann_file=data_root + 'annotations/instances_val2017.json', - img_prefix=data_root + 'val2017/', - pipeline=test_pipeline), - test=dict( - type=dataset_type, - ann_file=data_root + 'annotations/instances_val2017.json', - img_prefix=data_root + 'val2017/', - pipeline=test_pipeline)) + # optimizer -optimizer = dict(type='SGD', lr=1e-3, momentum=0.9, weight_decay=5e-4) -optimizer_config = dict() -# learning policy -lr_config = dict( - policy='step', - warmup='linear', - warmup_iters=500, - warmup_ratio=0.1, - step=[20, 42, 49, 52]) -runner = dict(type='EpochBasedRunner', max_epochs=55) -cudnn_benchmark = True -evaluation = dict(metric=['bbox', 'segm']) +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='SGD', lr=1e-3, momentum=0.9, weight_decay=5e-4)) + +custom_hooks = [ + dict(type='CheckInvalidLossHook', interval=50, priority='VERY_LOW') +] + +env_cfg = dict(cudnn_benchmark=True) # NOTE: `auto_scale_lr` is for automatically scaling LR, # USER SHOULD NOT CHANGE ITS VALUES. 
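For readers migrating their own configs, the retired `lr_config`/`runner` pair maps onto the `param_scheduler`/`train_cfg` pair used above. A minimal sketch of the schedule this config encodes (pure Python, no MMEngine required; the helper `lr_at` is illustrative only, values are taken from this config):

```python
# Sketch of the 500-iter linear warmup (LinearLR, by_epoch=False) followed
# by MultiStepLR decay at the listed milestones, for the base LR of 1e-3.
base_lr = 1e-3
warmup_iters, start_factor = 500, 0.1
milestones, gamma = [20, 42, 49, 52], 0.1

def lr_at(epoch: int, iteration: int) -> float:
    """Approximate LR at a given epoch / global iteration."""
    if iteration < warmup_iters:
        alpha = iteration / warmup_iters
        return base_lr * (start_factor + (1 - start_factor) * alpha)
    # MultiStepLR: multiply by gamma once per passed milestone.
    return base_lr * gamma ** sum(epoch >= m for m in milestones)

assert lr_at(0, 0) == 1e-4            # warmup starts at 0.1 * lr
assert lr_at(10, 10_000) == 1e-3      # full LR before the first milestone
assert abs(lr_at(45, 50_000) - 1e-5) < 1e-15  # after milestones 20 and 42
```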
diff --git a/configs/yolact/yolact_r50_8x8_coco.py b/configs/yolact/yolact_r50_8x8_coco.py index 41003ab42bf..899f77cb18a 100644 --- a/configs/yolact/yolact_r50_8x8_coco.py +++ b/configs/yolact/yolact_r50_8x8_coco.py @@ -1,15 +1,22 @@ _base_ = 'yolact_r50_1x8_coco.py' -optimizer = dict(type='SGD', lr=8e-3, momentum=0.9, weight_decay=5e-4) -optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) -# learning policy -lr_config = dict( - policy='step', - warmup='linear', - warmup_iters=1000, - warmup_ratio=0.1, - step=[20, 42, 49, 52]) - +# optimizer +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(lr=8e-3), + clip_grad=dict(max_norm=35, norm_type=2)) +# learning rate +max_epochs = 55 +param_scheduler = [ + dict(type='LinearLR', start_factor=0.1, by_epoch=False, begin=0, end=1000), + dict( + type='MultiStepLR', + begin=0, + end=max_epochs, + by_epoch=True, + milestones=[20, 42, 49, 52], + gamma=0.1) +] # NOTE: `auto_scale_lr` is for automatically scaling LR, # USER SHOULD NOT CHANGE ITS VALUES. # base_batch_size = (8 GPUs) x (8 samples per GPU) diff --git a/mmdet/datasets/transforms/loading.py b/mmdet/datasets/transforms/loading.py index 1db34a3f55b..8fff5ea25c4 100644 --- a/mmdet/datasets/transforms/loading.py +++ b/mmdet/datasets/transforms/loading.py @@ -659,9 +659,25 @@ def transform(self, results: dict) -> Union[dict, None]: Returns: dict: Updated result dict. """ - assert 'gt_bboxes' in results - gt_bboxes = results['gt_bboxes'] - if gt_bboxes.shape[0] == 0: + # gt_masks may not match gt_bboxes, because gt_masks + # are not added to instances when ignore is True + if 'gt_ignore_flags' in results and 'gt_masks' in results: + valid_idx = np.where(results['gt_ignore_flags'] == 0)[0] + keys = ('gt_bboxes', 'gt_bboxes_labels', 'gt_ignore_flags') + for key in keys: + if key in results: + results[key] = results[key][valid_idx] + + if self.by_box: + assert 'gt_bboxes' in results + gt_bboxes = results['gt_bboxes'] + instance_num = gt_bboxes.shape[0] + if self.by_mask: + assert 'gt_masks' in results + gt_masks = results['gt_masks'] + instance_num = len(gt_masks) + + if instance_num == 0: return results tests = [] diff --git a/mmdet/models/dense_heads/__init__.py b/mmdet/models/dense_heads/__init__.py index b775293b1df..01e90c9ea0e 100644 --- a/mmdet/models/dense_heads/__init__.py +++ b/mmdet/models/dense_heads/__init__.py @@ -38,7 +38,7 @@ from .ssd_head import SSDHead from .tood_head import TOODHead from .vfnet_head import VFNetHead -from .yolact_head import YOLACTHead, YOLACTProtonet, YOLACTSegmHead +from .yolact_head import YOLACTHead, YOLACTProtonet from .yolo_head import YOLOV3Head from .yolof_head import YOLOFHead from .yolox_head import YOLOXHead @@ -49,11 +49,11 @@ 'SSDHead', 'FCOSHead', 'RepPointsHead', 'FoveaHead', 'FreeAnchorRetinaHead', 'ATSSHead', 'FSAFHead', 'NASFCOSHead', 'PISARetinaHead', 'PISASSDHead', 'GFLHead', 'CornerHead', 'YOLACTHead', - 'YOLACTSegmHead', 'YOLACTProtonet', 'YOLOV3Head', 'PAAHead', - 'SABLRetinaHead', 'CentripetalHead', 'VFNetHead', 'StageCascadeRPNHead', - 'CascadeRPNHead', 'EmbeddingRPNHead', 'LDHead', 'CascadeRPNHead', - 'AutoAssignHead', 'DETRHead', 'YOLOFHead', 'DeformableDETRHead', - 'SOLOHead', 'DecoupledSOLOHead', 'CenterNetHead', 'YOLOXHead', + 'YOLACTProtonet', 'YOLOV3Head', 'PAAHead', 'SABLRetinaHead', + 'CentripetalHead', 'VFNetHead', 'StageCascadeRPNHead', 'CascadeRPNHead', + 'EmbeddingRPNHead', 'LDHead', 'CascadeRPNHead', 'AutoAssignHead', + 'DETRHead', 'YOLOFHead', 'DeformableDETRHead', 'SOLOHead', +
'DecoupledSOLOHead', 'CenterNetHead', 'YOLOXHead', 'DecoupledSOLOLightHead', 'LADHead', 'TOODHead', 'MaskFormerHead', 'Mask2FormerHead', 'SOLOV2Head', 'DDODHead', 'CenterNetUpdateHead' ] diff --git a/mmdet/models/dense_heads/anchor_head.py b/mmdet/models/dense_heads/anchor_head.py index 8893a394b70..99346e28b69 100644 --- a/mmdet/models/dense_heads/anchor_head.py +++ b/mmdet/models/dense_heads/anchor_head.py @@ -383,6 +383,9 @@ def get_targets(self, # `avg_factor` is usually equal to the number of positive priors. avg_factor = sum( [results.avg_factor for results in sampling_results_list]) + # update `_raw_positive_infos`, which will be used when calling + # `get_positive_infos`. + self._raw_positive_infos.update(sampling_results=sampling_results_list) # split targets to a list w.r.t. multiple levels labels_list = images_to_levels(all_labels, num_level_anchors) label_weights_list = images_to_levels(all_label_weights, diff --git a/mmdet/models/dense_heads/base_dense_head.py b/mmdet/models/dense_heads/base_dense_head.py index dbe341495c7..cdb87c141cc 100644 --- a/mmdet/models/dense_heads/base_dense_head.py +++ b/mmdet/models/dense_heads/base_dense_head.py @@ -58,6 +58,9 @@ class BaseDenseHead(BaseModule, metaclass=ABCMeta): def __init__(self, init_cfg: OptMultiConfig = None) -> None: super().__init__(init_cfg=init_cfg) + # `_raw_positive_infos` will be used in `get_positive_infos`, which + # can be used to get positive information. + self._raw_positive_infos = dict() def init_weights(self) -> None: """Initialize the weights.""" @@ -68,6 +71,32 @@ def init_weights(self) -> None: if hasattr(m, 'conv_offset'): constant_init(m.conv_offset, 0) + def get_positive_infos(self) -> InstanceList: + """Get positive information from sampling results. + + Returns: + list[:obj:`InstanceData`]: Positive information of each image, + usually including positive bboxes, positive labels, positive + priors, etc. + """ + if len(self._raw_positive_infos) == 0: + return None + + sampling_results = self._raw_positive_infos.get( + 'sampling_results', None) + assert sampling_results is not None + positive_infos = [] + for sampling_result in sampling_results: + pos_info = InstanceData() + pos_info.bboxes = sampling_result.pos_gt_bboxes + pos_info.labels = sampling_result.pos_gt_labels + pos_info.priors = sampling_result.pos_priors + pos_info.pos_assigned_gt_inds = \ + sampling_result.pos_assigned_gt_inds + pos_info.pos_inds = sampling_result.pos_inds + positive_infos.append(pos_info) + return positive_infos + def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList) -> dict: """Perform forward propagation and loss calculation of the detection head on the features of the upstream network. diff --git a/mmdet/models/dense_heads/base_mask_head.py b/mmdet/models/dense_heads/base_mask_head.py index 2093702c8ab..fcb95c91ce5 100644 --- a/mmdet/models/dense_heads/base_mask_head.py +++ b/mmdet/models/dense_heads/base_mask_head.py @@ -1,13 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved.
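The `_raw_positive_infos` cache added above is what decouples label assignment (done in the bbox head) from mask supervision. A self-contained sketch of the pattern, with `SimpleNamespace` standing in for mmdet's `SamplingResult`/`InstanceData` (an assumption made only so the snippet runs without mmdet installed):

```python
import torch
from types import SimpleNamespace

_raw_positive_infos = dict()

# During loss computation the head caches the raw sampling results
# (cf. the `update` call in AnchorHead.get_targets above).
sampling_results = [
    SimpleNamespace(                       # one entry per image
        pos_inds=torch.tensor([2, 7]),
        pos_assigned_gt_inds=torch.tensor([0, 1]),
        pos_gt_bboxes=torch.rand(2, 4),
        pos_gt_labels=torch.tensor([5, 11]),
        pos_priors=torch.rand(2, 4))
]
_raw_positive_infos.update(sampling_results=sampling_results)

# Later, mirroring BaseDenseHead.get_positive_infos: repack per image.
positive_infos = []
for sampling_result in _raw_positive_infos['sampling_results']:
    positive_infos.append(
        SimpleNamespace(
            bboxes=sampling_result.pos_gt_bboxes,
            labels=sampling_result.pos_gt_labels,
            priors=sampling_result.pos_priors,
            pos_inds=sampling_result.pos_inds,
            pos_assigned_gt_inds=sampling_result.pos_assigned_gt_inds))
assert len(positive_infos) == 1  # one InstanceData-like record per image
```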
from abc import ABCMeta, abstractmethod -from typing import List, Optional, Tuple, Union +from typing import List, Tuple, Union -from mmcv.runner import BaseModule +from mmengine.model import BaseModule from torch import Tensor from mmdet.data_elements import SampleList from mmdet.utils import InstanceList, OptInstanceList, OptMultiConfig -from ..task_modules.samplers import SamplingResult from ..utils import unpack_gt_instances @@ -32,7 +31,7 @@ def predict_by_feat(self, *args, **kwargs): def loss(self, x: Union[List[Tensor], Tuple[Tensor]], batch_data_samples: SampleList, - positive_infos: Optional[List[SamplingResult]] = None, + positive_infos: OptInstanceList = None, **kwargs) -> dict: """Perform forward propagation and loss calculation of the mask head on the features of the upstream network. @@ -43,20 +42,20 @@ def loss(self, batch_data_samples (list[:obj:`DetDataSample`]): Each item contains the meta information of each image and corresponding annotations. - positive_infos (list[:obj:``], optional): Information + positive_infos (list[:obj:`InstanceData`], optional): Information of positive samples. Used when the label assignment is - done outside the MaskHead, e.g., in BboxHead in + done outside the MaskHead, e.g., BboxHead in YOLACT or CondInst, etc. When the label assignment is done in - MaskHead, it would be None, like SOLO. All values + MaskHead, it would be None, like SOLO or SOLOv2. All values in it should have shape (num_positive_samples, *). + Returns: - dict[str, Tensor]: A dictionary of loss components. + dict: A dictionary of loss components. """ if positive_infos is None: outs = self(x) else: - # TODO: Currently not checked outs = self(x, positive_infos) assert isinstance(outs, tuple), 'Forward results should be a tuple, ' \ @@ -84,7 +83,7 @@ def predict(self, x: Tuple[Tensor], batch_data_samples: SampleList, rescale: bool = False, - bbox_results_list: OptInstanceList = None, + results_list: OptInstanceList = None, **kwargs) -> InstanceList: """Test function without test-time augmentation. @@ -96,7 +95,7 @@ def predict(self, `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. rescale (bool, optional): Whether to rescale the results. Defaults to False. - bbox_results_list (list[obj:`InstanceData`], optional): Detection + results_list (list[obj:`InstanceData`], optional): Detection results of each image after the post process. Only exist if there is a `bbox_head`, like `YOLACT`, `CondInst`, etc. @@ -114,13 +113,16 @@ def predict(self, batch_img_metas = [ data_samples.metainfo for data_samples in batch_data_samples ] + if results_list is None: + outs = self(x) + else: + outs = self(x, results_list) - outs = self(x) results_list = self.predict_by_feat( *outs, batch_img_metas=batch_img_metas, rescale=rescale, - bbox_results_list=bbox_results_list, + results_list=results_list, **kwargs) return results_list diff --git a/mmdet/models/dense_heads/yolact_head.py b/mmdet/models/dense_heads/yolact_head.py index 4a7ddd2b939..fee0a57fedc 100644 --- a/mmdet/models/dense_heads/yolact_head.py +++ b/mmdet/models/dense_heads/yolact_head.py @@ -1,16 +1,24 @@ # Copyright (c) OpenMMLab. All rights reserved. 
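A toy illustration (not mmdet code; the class and variable names below are hypothetical) of the control flow `BaseMaskHead` standardizes: heads whose label assignment happens in a bbox head (YOLACT, CondInst) receive `positive_infos`/`results_list`, while heads that assign internally (SOLO, SOLOv2) are called on features alone:

```python
import torch

class ToyMaskHead:
    """Stand-in head showing both calling paths of BaseMaskHead.loss."""

    def forward(self, x, positive_infos=None):
        # BaseMaskHead.loss asserts that forward returns a tuple.
        if positive_infos is None:
            return (x.mean(),)
        return (x.mean(), len(positive_infos))

    def loss(self, x, positive_infos=None):
        if positive_infos is None:
            outs = self.forward(x)                   # SOLO-style path
        else:
            outs = self.forward(x, positive_infos)   # YOLACT-style path
        assert isinstance(outs, tuple), \
            'Forward results should be a tuple, even if only one item'
        return outs

head = ToyMaskHead()
print(head.loss(torch.rand(2, 8, 4, 4)))                # no positive_infos
print(head.loss(torch.rand(2, 8, 4, 4), ['img0_pos']))  # with positive_infos
```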
+import copy +from typing import List, Optional + import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import ConvModule +from mmengine.data import InstanceData from mmengine.model import BaseModule, ModuleList +from torch import Tensor -from mmdet.registry import MODELS, TASK_UTILS -from ..builder import build_loss +from mmdet.registry import MODELS +from mmdet.utils import (ConfigType, InstanceList, OptConfigType, + OptInstanceList, OptMultiConfig) from ..layers import fast_nms from ..utils import images_to_levels, multi_apply, select_single_mlvl +from ..utils.misc import empty_instances from .anchor_head import AnchorHead +from .base_mask_head import BaseMaskHead @MODELS.register_module() @@ -29,65 +37,65 @@ class YOLACTHead(AnchorHead): num_classes (int): Number of categories excluding the background category. in_channels (int): Number of channels in the input feature map. - anchor_generator (dict): Config dict for anchor generator - loss_cls (dict): Config of classification loss. - loss_bbox (dict): Config of localization loss. + anchor_generator (:obj:`ConfigDict` or dict): Config dict for + anchor generator + loss_cls (:obj:`ConfigDict` or dict): Config of classification loss. + loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss. num_head_convs (int): Number of the conv layers shared by box and cls branches. num_protos (int): Number of the mask coefficients. use_ohem (bool): If true, ``loss_single_OHEM`` will be used for cls loss calculation. If false, ``loss_single`` will be used. - conv_cfg (dict): Dictionary to construct and config conv layer. - norm_cfg (dict): Dictionary to construct and config norm layer. - init_cfg (dict or list[dict], optional): Initialization config dict. + conv_cfg (:obj:`ConfigDict` or dict, optional): Dictionary to + construct and config conv layer. + norm_cfg (:obj:`ConfigDict` or dict, optional): Dictionary to + construct and config norm layer. + init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or + list[dict], optional): Initialization config dict. 
""" def __init__(self, - num_classes, - in_channels, - anchor_generator=dict( + num_classes: int, + in_channels: int, + anchor_generator: ConfigType = dict( type='AnchorGenerator', octave_base_scale=3, scales_per_octave=1, ratios=[0.5, 1.0, 2.0], strides=[8, 16, 32, 64, 128]), - loss_cls=dict( + loss_cls: ConfigType = dict( type='CrossEntropyLoss', use_sigmoid=False, reduction='none', loss_weight=1.0), - loss_bbox=dict( + loss_bbox: ConfigType = dict( type='SmoothL1Loss', beta=1.0, loss_weight=1.5), - num_head_convs=1, - num_protos=32, - use_ohem=True, - conv_cfg=None, - norm_cfg=None, - init_cfg=dict( + num_head_convs: int = 1, + num_protos: int = 32, + use_ohem: bool = True, + conv_cfg: OptConfigType = None, + norm_cfg: OptConfigType = None, + init_cfg: OptMultiConfig = dict( type='Xavier', distribution='uniform', bias=0, layer='Conv2d'), - **kwargs): + **kwargs) -> None: self.num_head_convs = num_head_convs self.num_protos = num_protos self.use_ohem = use_ohem self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg - super(YOLACTHead, self).__init__( - num_classes, - in_channels, + super().__init__( + num_classes=num_classes, + in_channels=in_channels, loss_cls=loss_cls, loss_bbox=loss_bbox, anchor_generator=anchor_generator, init_cfg=init_cfg, **kwargs) - if self.use_ohem: - sampler_cfg = dict(type='PseudoSampler') - self.sampler = TASK_UTILS.build(sampler_cfg, context=self) - self.sampling = False - def _init_layers(self): + def _init_layers(self) -> None: """Initialize layers of the head.""" self.relu = nn.ReLU(inplace=True) self.head_convs = ModuleList() @@ -115,7 +123,7 @@ def _init_layers(self): 3, padding=1) - def forward_single(self, x): + def forward_single(self, x: Tensor) -> tuple: """Forward feature of a single scale level. Args: @@ -123,12 +131,13 @@ def forward_single(self, x): Returns: tuple: - cls_score (Tensor): Cls scores for a single scale level \ - the channels number is num_anchors * num_classes. - bbox_pred (Tensor): Box energies / deltas for a single scale \ - level, the channels number is num_anchors * 4. - coeff_pred (Tensor): Mask coefficients for a single scale \ - level, the channels number is num_anchors * num_protos. + + - cls_score (Tensor): Cls scores for a single scale level + the channels number is num_anchors * num_classes. + - bbox_pred (Tensor): Box energies / deltas for a single scale + level, the channels number is num_anchors * 4. + - coeff_pred (Tensor): Mask coefficients for a single scale + level, the channels number is num_anchors * num_protos. """ for head_conv in self.head_convs: x = head_conv(x) @@ -137,37 +146,38 @@ def forward_single(self, x): coeff_pred = self.conv_coeff(x).tanh() return cls_score, bbox_pred, coeff_pred - def loss(self, - cls_scores, - bbox_preds, - gt_bboxes, - gt_labels, - img_metas, - gt_bboxes_ignore=None): - """A combination of the func:``AnchorHead.loss`` and - func:``SSDHead.loss``. + def loss_by_feat( + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + coeff_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> dict: + """Calculate the loss based on the features extracted by the bbox head. When ``self.use_ohem == True``, it functions like ``SSDHead.loss``, - otherwise, it follows ``AnchorHead.loss``. Besides, it additionally - returns ``sampling_results``. + otherwise, it follows ``AnchorHead.loss``. 
Args: cls_scores (list[Tensor]): Box scores for each scale level - Has shape (N, num_anchors * num_classes, H, W) + has shape (N, num_anchors * num_classes, H, W). bbox_preds (list[Tensor]): Box energies / deltas for each scale - level with shape (N, num_anchors * 4, H, W) - gt_bboxes (list[Tensor]): Ground truth bboxes for each image with - shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. - gt_labels (list[Tensor]): Class indices corresponding to each box - img_metas (list[dict]): Meta information of each image, e.g., + level with shape (N, num_anchors * 4, H, W). + coeff_preds (list[Tensor]): Mask coefficients for each scale + level with shape (N, num_anchors * num_protos, H, W) + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., image size, scaling factor, etc. - gt_bboxes_ignore (None | list[Tensor]): Specify which bounding - boxes can be ignored when computing the loss. Default: None + batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. Returns: - tuple: - dict[str, Tensor]: A dictionary of loss components. - List[:obj:``SamplingResult``]: Sampler results for each image. + dict: A dictionary of loss components. """ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] assert len(featmap_sizes) == self.prior_generator.num_levels @@ -175,25 +185,20 @@ def loss(self, device = cls_scores[0].device anchor_list, valid_flag_list = self.get_anchors( - featmap_sizes, img_metas, device=device) - label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1 + featmap_sizes, batch_img_metas, device=device) cls_reg_targets = self.get_targets( anchor_list, valid_flag_list, - gt_bboxes, - img_metas, - gt_bboxes_ignore_list=gt_bboxes_ignore, - gt_labels_list=gt_labels, - label_channels=label_channels, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore, unmap_outputs=not self.use_ohem, return_sampling_results=True) - if cls_reg_targets is None: - return None (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, - num_total_pos, num_total_neg, sampling_results) = cls_reg_targets + avg_factor, sampling_results) = cls_reg_targets if self.use_ohem: - num_images = len(img_metas) + num_images = len(batch_img_metas) all_cls_scores = torch.cat([ s.permute(0, 2, 3, 1).reshape( num_images, -1, self.cls_out_channels) for s in cls_scores @@ -222,7 +227,7 @@ def loss(self, 'bbox predications become infinite or NaN!' 
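For context on the `use_ohem` branch dispatched just below: all positives are kept, and only the hardest negatives (ranked by classification loss, capped at `train_cfg['neg_pos_ratio']` times the positive count) contribute. A standalone sketch under the conventions used here (background category id equals `num_classes`; a ratio of 3 is assumed for illustration):

```python
import torch

loss_cls_all = torch.rand(12)   # per-anchor CE loss with reduction='none'
labels = torch.tensor([0, 80, 80, 3, 80, 80, 80, 80, 1, 80, 80, 80])
num_classes = 80                # FG ids are [0, num_classes - 1]
neg_pos_ratio = 3               # assumed value of train_cfg['neg_pos_ratio']

pos_inds = ((labels >= 0) & (labels < num_classes)).nonzero().reshape(-1)
neg_inds = (labels == num_classes).nonzero().reshape(-1)

num_pos = pos_inds.size(0)
num_neg = min(neg_pos_ratio * num_pos, neg_inds.size(0))
topk_loss_neg, _ = loss_cls_all[neg_inds].topk(num_neg)

# Divide by the average factor; with PseudoSampler this is the positive count.
loss_cls = (loss_cls_all[pos_inds].sum() + topk_loss_neg.sum()) / num_pos
print(num_pos, num_neg, float(loss_cls))  # 3 positives -> at most 9 negatives
```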
losses_cls, losses_bbox = multi_apply( - self.loss_single_OHEM, + self.OHEMloss_by_feat_single, all_cls_scores, all_bbox_preds, all_anchors, all_labels, all_label_weights, all_bbox_targets, all_bbox_weights, - num_total_samples=num_total_pos) + avg_factor=avg_factor) else: - num_total_samples = ( - num_total_pos + - num_total_neg if self.sampling else num_total_pos) - # anchor number of multi levels num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] # concat all level anchors and flags to a single tensor @@ -245,7 +246,7 @@ def loss(self, all_anchor_list = images_to_levels(concat_anchor_list, num_level_anchors) losses_cls, losses_bbox = multi_apply( - self.loss_single, + self.loss_by_feat_single, cls_scores, bbox_preds, all_anchor_list, labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, - num_total_samples=num_total_samples) + avg_factor=avg_factor) + losses = dict(loss_cls=losses_cls, loss_bbox=losses_bbox) + # update `_raw_positive_infos`, which will be used when calling + # `get_positive_infos`. + self._raw_positive_infos.update(coeff_preds=coeff_preds) + return losses + + def OHEMloss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor, + anchors: Tensor, labels: Tensor, + label_weights: Tensor, bbox_targets: Tensor, + bbox_weights: Tensor, + avg_factor: int) -> tuple: + """Compute loss of a single image. Similar to + func:``SSDHead.loss_by_feat_single``. - return dict( - loss_cls=losses_cls, loss_bbox=losses_bbox), sampling_results + Args: + cls_score (Tensor): Box scores for each image. + Has shape (num_total_anchors, num_classes). + bbox_pred (Tensor): Box energies / deltas for each image + level with shape (num_total_anchors, 4). + anchors (Tensor): Box reference for each scale level with shape + (num_total_anchors, 4). + labels (Tensor): Labels of each anchor with shape + (num_total_anchors,). + label_weights (Tensor): Label weights of each anchor with shape + (num_total_anchors,) + bbox_targets (Tensor): BBox regression targets of each anchor + with shape (num_total_anchors, 4). + bbox_weights (Tensor): BBox regression loss weights of each anchor + with shape (num_total_anchors, 4). + avg_factor (int): Average factor that is used to average + the loss. When using sampling method, avg_factor is usually + the sum of positive and negative priors. When using + `PseudoSampler`, `avg_factor` is usually equal to the number + of positive priors. + + Returns: + Tuple[Tensor, Tensor]: A tuple of cls loss and bbox loss of one + feature map.
+ """ - def loss_single_OHEM(self, cls_score, bbox_pred, anchors, labels, - label_weights, bbox_targets, bbox_weights, - num_total_samples): - """"See func:``SSDHead.loss``.""" loss_cls_all = self.loss_cls(cls_score, labels, label_weights) # FG cat_id: [0, num_classes -1], BG cat_id: num_classes @@ -274,33 +307,64 @@ def loss_single_OHEM(self, cls_score, bbox_pred, anchors, labels, if num_pos_samples == 0: num_neg_samples = neg_inds.size(0) else: - num_neg_samples = self.train_cfg.neg_pos_ratio * num_pos_samples + num_neg_samples = self.train_cfg['neg_pos_ratio'] * \ + num_pos_samples if num_neg_samples > neg_inds.size(0): num_neg_samples = neg_inds.size(0) topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples) loss_cls_pos = loss_cls_all[pos_inds].sum() loss_cls_neg = topk_loss_cls_neg.sum() - loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples + loss_cls = (loss_cls_pos + loss_cls_neg) / avg_factor if self.reg_decoded_bbox: # When the regression loss (e.g. `IouLoss`, `GIouLoss`) # is applied directly on the decoded bounding boxes, it # decodes the already encoded coordinates to absolute format. bbox_pred = self.bbox_coder.decode(anchors, bbox_pred) loss_bbox = self.loss_bbox( - bbox_pred, - bbox_targets, - bbox_weights, - avg_factor=num_total_samples) + bbox_pred, bbox_targets, bbox_weights, avg_factor=avg_factor) return loss_cls[None], loss_bbox - def get_bboxes(self, - cls_scores, - bbox_preds, - coeff_preds, - img_metas, - cfg=None, - rescale=False): - """"Similar to func:``AnchorHead.get_bboxes``, but additionally + def get_positive_infos(self) -> InstanceList: + """Get positive information from sampling results. + + Returns: + list[:obj:`InstanceData`]: Positive Information of each image, + usually including positive bboxes, positive labels, positive + priors, positive coeffs, etc. + """ + assert len(self._raw_positive_infos) > 0 + sampling_results = self._raw_positive_infos['sampling_results'] + num_imgs = len(sampling_results) + + coeff_pred_list = [] + for coeff_pred_per_level in self._raw_positive_infos['coeff_preds']: + coeff_pred_per_level = \ + coeff_pred_per_level.permute( + 0, 2, 3, 1).reshape(num_imgs, -1, self.num_protos) + coeff_pred_list.append(coeff_pred_per_level) + coeff_preds = torch.cat(coeff_pred_list, dim=1) + + pos_info_list = [] + for idx, sampling_result in enumerate(sampling_results): + pos_info = InstanceData() + coeff_preds_single = coeff_preds[idx] + pos_info.pos_assigned_gt_inds = \ + sampling_result.pos_assigned_gt_inds + pos_info.pos_inds = sampling_result.pos_inds + pos_info.coeffs = coeff_preds_single[sampling_result.pos_inds] + pos_info.bboxes = sampling_result.pos_gt_bboxes + pos_info_list.append(pos_info) + return pos_info_list + + def predict_by_feat(self, + cls_scores, + bbox_preds, + coeff_preds, + batch_img_metas, + cfg=None, + rescale=True, + **kwargs): + """Similar to func:``AnchorHead.get_bboxes``, but additionally processes coeff_preds. Args: @@ -310,99 +374,108 @@ def get_bboxes(self, level with shape (N, num_anchors * 4, H, W) coeff_preds (list[Tensor]): Mask coefficients for each scale level with shape (N, num_anchors * num_protos, H, W) - img_metas (list[dict]): Meta information of each image, e.g., - image size, scaling factor, etc. + batch_img_metas (list[dict]): Batch image meta info. cfg (mmcv.Config | None): Test / postprocessing configuration, if None, test_cfg would be used rescale (bool): If True, return boxes in original image space. - Default: False. + Defaults to True. 
Returns: - list[tuple[Tensor, Tensor, Tensor]]: Each item in result_list is - a 3-tuple. The first item is an (n, 5) tensor, where the - first 4 columns are bounding box positions - (tl_x, tl_y, br_x, br_y) and the 5-th column is a score - between 0 and 1. The second item is an (n,) tensor where each - item is the predicted class label of the corresponding box. - The third item is an (n, num_protos) tensor where each item - is the predicted mask coefficients of instance inside the - corresponding box. + list[:obj:`InstanceData`]: Object detection results of each image + after the post process. Each item usually contains following keys. + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - coeffs (Tensor): the predicted mask coefficients of + instance inside the corresponding box has a shape + (n, num_protos). """ assert len(cls_scores) == len(bbox_preds) num_levels = len(cls_scores) device = cls_scores[0].device featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] - mlvl_anchors = self.prior_generator.grid_priors( + mlvl_priors = self.prior_generator.grid_priors( featmap_sizes, device=device) - det_bboxes = [] - det_labels = [] - det_coeffs = [] - for img_id in range(len(img_metas)): + result_list = [] + for img_id in range(len(batch_img_metas)): + img_meta = batch_img_metas[img_id] cls_score_list = select_single_mlvl(cls_scores, img_id) bbox_pred_list = select_single_mlvl(bbox_preds, img_id) coeff_pred_list = select_single_mlvl(coeff_preds, img_id) - img_shape = img_metas[img_id]['img_shape'] - scale_factor = img_metas[img_id]['scale_factor'] - bbox_res = self._get_bboxes_single(cls_score_list, bbox_pred_list, - coeff_pred_list, mlvl_anchors, - img_shape, scale_factor, cfg, - rescale) - det_bboxes.append(bbox_res[0]) - det_labels.append(bbox_res[1]) - det_coeffs.append(bbox_res[2]) - return det_bboxes, det_labels, det_coeffs - - def _get_bboxes_single(self, - cls_score_list, - bbox_pred_list, - coeff_preds_list, - mlvl_anchors, - img_shape, - scale_factor, - cfg, - rescale=False): - """"Similar to func:``AnchorHead._get_bboxes_single``, but additionally - processes coeff_preds_list and uses fast NMS instead of traditional - NMS. + results = self._predict_by_feat_single( + cls_score_list=cls_score_list, + bbox_pred_list=bbox_pred_list, + coeff_preds_list=coeff_pred_list, + mlvl_priors=mlvl_priors, + img_meta=img_meta, + cfg=cfg, + rescale=rescale) + result_list.append(results) + return result_list + + def _predict_by_feat_single(self, + cls_score_list: List[Tensor], + bbox_pred_list: List[Tensor], + coeff_preds_list: List[Tensor], + mlvl_priors: List[Tensor], + img_meta: dict, + cfg: ConfigType, + rescale: bool = True) -> InstanceData: + """Transform a single image's features extracted from the head into + bbox results. Similar to func:``AnchorHead._predict_by_feat_single``, + but additionally processes coeff_preds_list and uses fast NMS instead + of traditional NMS. Args: cls_score_list (list[Tensor]): Box scores for a single scale level - Has shape (num_anchors * num_classes, H, W). + Has shape (num_priors * num_classes, H, W). bbox_pred_list (list[Tensor]): Box energies / deltas for a single - scale level with shape (num_anchors * 4, H, W). + scale level with shape (num_priors * 4, H, W). 
coeff_preds_list (list[Tensor]): Mask coefficients for a single - scale level with shape (num_anchors * num_protos, H, W). - mlvl_anchors (list[Tensor]): Box reference for a single scale level - with shape (num_total_anchors, 4). - img_shape (tuple[int]): Shape of the input image, - (height, width, 3). - scale_factor (ndarray): Scale factor of the image arange as - (w_scale, h_scale, w_scale, h_scale). - cfg (mmcv.Config): Test / postprocessing configuration, + scale level with shape (num_priors * num_protos, H, W). + mlvl_priors (list[Tensor]): Each element in the list is + the priors of a single level in feature pyramid, + has shape (num_priors, 4). + img_meta (dict): Image meta info. + cfg (mmengine.Config): Test / postprocessing configuration, if None, test_cfg would be used. rescale (bool): If True, return boxes in original image space. + Defaults to True. Returns: - tuple[Tensor, Tensor, Tensor]: The first item is an (n, 5) tensor, - where the first 4 columns are bounding box positions - (tl_x, tl_y, br_x, br_y) and the 5-th column is a score between - 0 and 1. The second item is an (n,) tensor where each item is - the predicted class label of the corresponding box. The third - item is an (n, num_protos) tensor where each item is the - predicted mask coefficients of instance inside the - corresponding box. + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - coeffs (Tensor): the predicted mask coefficients of + instance inside the corresponding box has a shape + (n, num_protos).
""" + assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_priors) + cfg = self.test_cfg if cfg is None else cfg - assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors) + cfg = copy.deepcopy(cfg) + img_shape = img_meta['img_shape'] nms_pre = cfg.get('nms_pre', -1) - mlvl_bboxes = [] + + mlvl_bbox_preds = [] + mlvl_valid_priors = [] mlvl_scores = [] mlvl_coeffs = [] - for cls_score, bbox_pred, coeff_pred, anchors in \ + for cls_score, bbox_pred, coeff_pred, priors in \ zip(cls_score_list, bbox_pred_list, - coeff_preds_list, mlvl_anchors): + coeff_preds_list, mlvl_priors): assert cls_score.size()[-2:] == bbox_pred.size()[-2:] cls_score = cls_score.permute(1, 2, 0).reshape(-1, self.cls_out_channels) @@ -424,158 +497,96 @@ def _get_bboxes_single(self, # BG cat_id: num_class max_scores, _ = scores[:, :-1].max(dim=1) _, topk_inds = max_scores.topk(nms_pre) - anchors = anchors[topk_inds, :] + priors = priors[topk_inds, :] bbox_pred = bbox_pred[topk_inds, :] scores = scores[topk_inds, :] coeff_pred = coeff_pred[topk_inds, :] - bboxes = self.bbox_coder.decode( - anchors, bbox_pred, max_shape=img_shape) - mlvl_bboxes.append(bboxes) + + mlvl_bbox_preds.append(bbox_pred) + mlvl_valid_priors.append(priors) mlvl_scores.append(scores) mlvl_coeffs.append(coeff_pred) - mlvl_bboxes = torch.cat(mlvl_bboxes) - if rescale: - mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor) - mlvl_scores = torch.cat(mlvl_scores) - mlvl_coeffs = torch.cat(mlvl_coeffs) - if self.use_sigmoid_cls: - # Add a dummy background class to the backend when using sigmoid - # remind that we set FG labels to [0, num_class-1] since mmdet v2.0 - # BG cat_id: num_class - padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) - mlvl_scores = torch.cat([mlvl_scores, padding], dim=1) - det_bboxes, det_labels, det_coeffs = fast_nms(mlvl_bboxes, mlvl_scores, - mlvl_coeffs, - cfg.score_thr, - cfg.iou_thr, cfg.top_k, - cfg.max_per_img) - return det_bboxes, det_labels, det_coeffs - - -@MODELS.register_module() -class YOLACTSegmHead(BaseModule): - """YOLACT segmentation head used in https://arxiv.org/abs/1904.02689. - - Apply a semantic segmentation loss on feature space using layers that are - only evaluated during training to increase performance with no speed - penalty. - - Args: - in_channels (int): Number of channels in the input feature map. - num_classes (int): Number of categories excluding the background - category. - loss_segm (dict): Config of semantic segmentation loss. - init_cfg (dict or list[dict], optional): Initialization config dict. - """ - - def __init__(self, - num_classes, - in_channels=256, - loss_segm=dict( - type='CrossEntropyLoss', - use_sigmoid=True, - loss_weight=1.0), - init_cfg=dict( - type='Xavier', - distribution='uniform', - override=dict(name='segm_conv'))): - super(YOLACTSegmHead, self).__init__(init_cfg) - self.in_channels = in_channels - self.num_classes = num_classes - self.loss_segm = build_loss(loss_segm) - self._init_layers() - self.fp16_enabled = False - - def _init_layers(self): - """Initialize layers of the head.""" - self.segm_conv = nn.Conv2d( - self.in_channels, self.num_classes, kernel_size=1) - def forward(self, x): - """Forward feature from the upstream network. 
+ bbox_pred = torch.cat(mlvl_bbox_preds) + priors = torch.cat(mlvl_valid_priors) + multi_bboxes = self.bbox_coder.decode( + priors, bbox_pred, max_shape=img_shape) + + multi_scores = torch.cat(mlvl_scores) + multi_coeffs = torch.cat(mlvl_coeffs) + + return self._bbox_post_process( + multi_bboxes=multi_bboxes, + multi_scores=multi_scores, + multi_coeffs=multi_coeffs, + cfg=cfg, + rescale=rescale, + img_meta=img_meta) + + def _bbox_post_process(self, + multi_bboxes: Tensor, + multi_scores: Tensor, + multi_coeffs: Tensor, + cfg: ConfigType, + rescale: bool = False, + img_meta: Optional[dict] = None, + **kwargs) -> InstanceData: + """bbox post-processing method. + + The boxes are rescaled to the original image scale and the NMS + operation is applied. Usually `with_nms` is False when used for aug test. Args: - x (Tensor): Feature from the upstream network, which is - a 4D-tensor. - - Returns: - Tensor: Predicted semantic segmentation map with shape - (N, num_classes, H, W). - """ - return self.segm_conv(x) - - def loss(self, segm_pred, gt_masks, gt_labels): - """Compute loss of the head. - - Args: - segm_pred (list[Tensor]): Predicted semantic segmentation map - with shape (N, num_classes, H, W). - gt_masks (list[Tensor]): Ground truth masks for each image with - the same shape of the input image. - gt_labels (list[Tensor]): Class indices corresponding to each box. + multi_bboxes (Tensor): Predicted bboxes concatenated from all levels. + multi_scores (Tensor): Bbox scores concatenated from all levels. + multi_coeffs (Tensor): Mask coefficients concatenated from all levels. + cfg (ConfigDict): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + img_meta (dict, optional): Image meta info. Defaults to None. Returns: - dict[str, Tensor]: A dictionary of loss components. + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - coeffs (Tensor): the predicted mask coefficients of + instance inside the corresponding box has a shape + (n, num_protos). """ - loss_segm = [] - num_imgs, num_classes, mask_h, mask_w = segm_pred.size() - for idx in range(num_imgs): - cur_segm_pred = segm_pred[idx] - cur_gt_masks = gt_masks[idx].float() - cur_gt_labels = gt_labels[idx] - segm_targets = self.get_targets(cur_segm_pred, cur_gt_masks, - cur_gt_labels) - if segm_targets is None: - loss = self.loss_segm(cur_segm_pred, - torch.zeros_like(cur_segm_pred), - torch.zeros_like(cur_segm_pred)) - else: - loss = self.loss_segm( - cur_segm_pred, - segm_targets, - avg_factor=num_imgs * mask_h * mask_w) - loss_segm.append(loss) - return dict(loss_segm=loss_segm) - - def get_targets(self, segm_pred, gt_masks, gt_labels): - """Compute semantic segmentation targets for each image. + if rescale: + assert img_meta.get('scale_factor') is not None + multi_bboxes /= multi_bboxes.new_tensor( + img_meta['scale_factor']).repeat((1, 2)) - Args: - segm_pred (Tensor): Predicted semantic segmentation map - with shape (num_classes, H, W). - gt_masks (Tensor): Ground truth masks for each image with - the same shape of the input image.
- gt_labels (Tensor): Class indices corresponding to each box. + if self.use_sigmoid_cls: + # Add a dummy background class to the backend when using sigmoid + # remind that we set FG labels to [0, num_class-1] since mmdet v2.0 + # BG cat_id: num_class - Returns: - Tensor: Semantic segmentation targets with shape - (num_classes, H, W). - """ - if gt_masks.size(0) == 0: - return None - num_classes, mask_h, mask_w = segm_pred.size() - with torch.no_grad(): - downsampled_masks = F.interpolate( - gt_masks.unsqueeze(0), (mask_h, mask_w), - mode='bilinear', - align_corners=False).squeeze(0) - downsampled_masks = downsampled_masks.gt(0.5).float() - segm_targets = torch.zeros_like(segm_pred, requires_grad=False) - for obj_idx in range(downsampled_masks.size(0)): - segm_targets[gt_labels[obj_idx] - 1] = torch.max( - segm_targets[gt_labels[obj_idx] - 1], - downsampled_masks[obj_idx]) - return segm_targets - - def simple_test(self, feats, img_metas, rescale=False): - """Test function without test-time augmentation.""" - raise NotImplementedError( - 'simple_test of YOLACTSegmHead is not implemented ' - 'because this head is only evaluated during training') + padding = multi_scores.new_zeros(multi_scores.shape[0], 1) + multi_scores = torch.cat([multi_scores, padding], dim=1) + det_bboxes, det_labels, det_coeffs = fast_nms( + multi_bboxes, multi_scores, multi_coeffs, cfg.score_thr, + cfg.iou_thr, cfg.top_k, cfg.max_per_img) + results = InstanceData() + results.bboxes = det_bboxes[:, :4] + results.scores = det_bboxes[:, -1] + results.labels = det_labels + results.coeffs = det_coeffs + return results @MODELS.register_module() -class YOLACTProtonet(BaseModule): +class YOLACTProtonet(BaseMaskHead): """YOLACT mask head used in https://arxiv.org/abs/1904.02689. This head outputs the mask prototypes for YOLACT. @@ -584,45 +595,69 @@ class YOLACTProtonet(BaseModule): in_channels (int): Number of channels in the input feature map. proto_channels (tuple[int]): Output channels of protonet convs. proto_kernel_sizes (tuple[int]): Kernel sizes of protonet convs. - include_last_relu (Bool): If keep the last relu of protonet. + include_last_relu (bool): If keep the last relu of protonet. num_protos (int): Number of prototypes. num_classes (int): Number of categories excluding the background category. loss_mask_weight (float): Reweight the mask loss by this factor. max_masks_to_train (int): Maximum number of masks to train for each image. - init_cfg (dict or list[dict], optional): Initialization config dict. + with_seg_branch (bool): Whether to apply a semantic segmentation + branch and calculate loss during training to increase + performance with no speed penalty. Defaults to True. + loss_segm (:obj:`ConfigDict` or dict, optional): Config of + semantic segmentation loss. + train_cfg (:obj:`ConfigDict` or dict, optional): Training config + of head. + test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of + head. + init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or + list[dict], optional): Initialization config dict. 
""" - def __init__(self, - num_classes, - in_channels=256, - proto_channels=(256, 256, 256, None, 256, 32), - proto_kernel_sizes=(3, 3, 3, -2, 3, 1), - include_last_relu=True, - num_protos=32, - loss_mask_weight=1.0, - max_masks_to_train=100, - init_cfg=dict( - type='Xavier', - distribution='uniform', - override=dict(name='protonet'))): - super(YOLACTProtonet, self).__init__(init_cfg) + def __init__( + self, + num_classes: int, + in_channels: int = 256, + proto_channels: tuple = (256, 256, 256, None, 256, 32), + proto_kernel_sizes: tuple = (3, 3, 3, -2, 3, 1), + include_last_relu: bool = True, + num_protos: int = 32, + loss_mask_weight: float = 1.0, + max_masks_to_train: int = 100, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + with_seg_branch: bool = True, + loss_segm: ConfigType = dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + init_cfg=dict( + type='Xavier', + distribution='uniform', + override=dict(name='protonet')) + ) -> None: + super().__init__(init_cfg=init_cfg) self.in_channels = in_channels self.proto_channels = proto_channels self.proto_kernel_sizes = proto_kernel_sizes self.include_last_relu = include_last_relu - self.protonet = self._init_layers() + + # Segmentation branch + self.with_seg_branch = with_seg_branch + self.segm_branch = SegmentationModule( + num_classes=num_classes, in_channels=in_channels) \ + if with_seg_branch else None + self.loss_segm = MODELS.build(loss_segm) if with_seg_branch else None self.loss_mask_weight = loss_mask_weight self.num_protos = num_protos self.num_classes = num_classes self.max_masks_to_train = max_masks_to_train - self.fp16_enabled = False + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self._init_layers() - def _init_layers(self): - """A helper function to take a config setting and turn it into a - network.""" + def _init_layers(self) -> None: + """Initialize layers of the head.""" # Possible patterns: # ( 256, 3) -> conv # ( 256,-2) -> deconv @@ -655,146 +690,118 @@ def _init_layers(self): else in_channels if not self.include_last_relu: protonets = protonets[:-1] - return nn.Sequential(*protonets) - - def forward_dummy(self, x): - prototypes = self.protonet(x) - return prototypes + self.protonet = nn.Sequential(*protonets) - def forward(self, x, coeff_pred, bboxes, img_meta, sampling_results=None): + def forward(self, x: tuple, positive_infos: InstanceList) -> tuple: """Forward feature from the upstream network to get prototypes and linearly combine the prototypes, using masks coefficients, into instance masks. Finally, crop the instance masks with given bboxes. Args: - x (Tensor): Feature from the upstream network, which is + x (Tuple[Tensor]): Feature from the upstream network, which is a 4D-tensor. - coeff_pred (list[Tensor]): Mask coefficients for each scale - level with shape (N, num_anchors * num_protos, H, W). - bboxes (list[Tensor]): Box used for cropping with shape - (N, num_anchors * 4, H, W). During training, they are - ground truth boxes. During testing, they are predicted - boxes. - img_meta (list[dict]): Meta information of each image, e.g., - image size, scaling factor, etc. - sampling_results (List[:obj:``SamplingResult``]): Sampler results - for each image. + positive_infos (List[:obj:``InstanceData``]): Positive information + that calculate from detect head. Returns: - list[Tensor]: Predicted instance segmentation masks. + tuple: Predicted instance segmentation masks and + semantic segmentation map. 
""" - prototypes = self.protonet(x) - prototypes = prototypes.permute(0, 2, 3, 1).contiguous() - - num_imgs = x.size(0) + # YOLACT used single feature map to get segmentation masks + single_x = x[0] - # The reason for not using self.training is that - # val workflow will have a dimension mismatch error. - # Note that this writing method is very tricky. - # Fix https://github.com/open-mmlab/mmdetection/issues/5978 - is_train_or_val_workflow = (coeff_pred[0].dim() == 4) + # YOLACT segmentation branch, if not training or segmentation branch + # is None, will not process the forward function. + if self.segm_branch is not None and self.training: + segm_preds = self.segm_branch(single_x) + else: + segm_preds = None + # YOLACT mask head + prototypes = self.protonet(single_x) + prototypes = prototypes.permute(0, 2, 3, 1).contiguous() - # Train or val workflow - if is_train_or_val_workflow: - coeff_pred_list = [] - for coeff_pred_per_level in coeff_pred: - coeff_pred_per_level = \ - coeff_pred_per_level.permute( - 0, 2, 3, 1).reshape(num_imgs, -1, self.num_protos) - coeff_pred_list.append(coeff_pred_per_level) - coeff_pred = torch.cat(coeff_pred_list, dim=1) + num_imgs = single_x.size(0) mask_pred_list = [] for idx in range(num_imgs): cur_prototypes = prototypes[idx] - cur_coeff_pred = coeff_pred[idx] - cur_bboxes = bboxes[idx] - cur_img_meta = img_meta[idx] - - # Testing state - if not is_train_or_val_workflow: - bboxes_for_cropping = cur_bboxes - else: - cur_sampling_results = sampling_results[idx] - pos_assigned_gt_inds = \ - cur_sampling_results.pos_assigned_gt_inds - bboxes_for_cropping = cur_bboxes[pos_assigned_gt_inds].clone() - pos_inds = cur_sampling_results.pos_inds - cur_coeff_pred = cur_coeff_pred[pos_inds] + pos_coeffs = positive_infos[idx].coeffs # Linearly combine the prototypes with the mask coefficients - mask_pred = cur_prototypes @ cur_coeff_pred.t() + mask_pred = cur_prototypes @ pos_coeffs.t() mask_pred = torch.sigmoid(mask_pred) - - h, w = cur_img_meta['img_shape'][:2] - bboxes_for_cropping[:, 0] /= w - bboxes_for_cropping[:, 1] /= h - bboxes_for_cropping[:, 2] /= w - bboxes_for_cropping[:, 3] /= h - - mask_pred = self.crop(mask_pred, bboxes_for_cropping) - mask_pred = mask_pred.permute(2, 0, 1).contiguous() mask_pred_list.append(mask_pred) - return mask_pred_list + return mask_pred_list, segm_preds - def loss(self, mask_pred, gt_masks, gt_bboxes, img_meta, sampling_results): - """Compute loss of the head. + def loss_by_feat(self, mask_preds: List[Tensor], segm_preds: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], positive_infos: InstanceList, + **kwargs) -> dict: + """Calculate the loss based on the features extracted by the mask head. Args: - mask_pred (list[Tensor]): Predicted prototypes with shape - (num_classes, H, W). - gt_masks (list[Tensor]): Ground truth masks for each image with - the same shape of the input image. - gt_bboxes (list[Tensor]): Ground truth bboxes for each image with - shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. - img_meta (list[dict]): Meta information of each image, e.g., - image size, scaling factor, etc. - sampling_results (List[:obj:``SamplingResult``]): Sampler results - for each image. + mask_preds (list[Tensor]): List of predicted prototypes, each has + shape (num_classes, H, W). + segm_preds (Tensor): Predicted semantic segmentation map with + shape (N, num_classes, H, W) + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. 
It usually includes ``bboxes``, ``masks``, + and ``labels`` attributes. + batch_img_metas (list[dict]): Meta information of multiple images. + positive_infos (List[:obj:``InstanceData``]): Information of + positive samples of each image that are assigned in detection + head. Returns: dict[str, Tensor]: A dictionary of loss components. """ + assert positive_infos is not None, \ + 'positive_infos should not be None in `YOLACTProtonet`' + losses = dict() + + # crop + croped_mask_pred = self.crop_mask_preds(mask_preds, batch_img_metas, + positive_infos) + loss_mask = [] - num_imgs = len(mask_pred) + loss_segm = [] + num_imgs, _, mask_h, mask_w = segm_preds.size() + assert num_imgs == len(croped_mask_pred) + segm_avg_factor = num_imgs * mask_h * mask_w total_pos = 0 + + if self.segm_branch is not None: + assert segm_preds is not None + for idx in range(num_imgs): - cur_mask_pred = mask_pred[idx] - cur_gt_masks = gt_masks[idx].float() - cur_gt_bboxes = gt_bboxes[idx] - cur_img_meta = img_meta[idx] - cur_sampling_results = sampling_results[idx] - - pos_assigned_gt_inds = cur_sampling_results.pos_assigned_gt_inds - num_pos = pos_assigned_gt_inds.size(0) - # Since we're producing (near) full image masks, - # it'd take too much vram to backprop on every single mask. - # Thus we select only a subset. - if num_pos > self.max_masks_to_train: - perm = torch.randperm(num_pos) - select = perm[:self.max_masks_to_train] - cur_mask_pred = cur_mask_pred[select] - pos_assigned_gt_inds = pos_assigned_gt_inds[select] - num_pos = self.max_masks_to_train - total_pos += num_pos + img_meta = batch_img_metas[idx] - gt_bboxes_for_reweight = cur_gt_bboxes[pos_assigned_gt_inds] + (mask_pred, pos_mask_targets, segm_targets, num_pos, + gt_bboxes_for_reweight) = self._get_targets_single( + croped_mask_pred[idx], segm_preds[idx], + batch_gt_instances[idx], positive_infos[idx]) - mask_targets = self.get_targets(cur_mask_pred, cur_gt_masks, - pos_assigned_gt_inds) - if num_pos == 0: - loss = cur_mask_pred.sum() * 0. - elif mask_targets is None: - loss = F.binary_cross_entropy(cur_mask_pred, - torch.zeros_like(cur_mask_pred), - torch.zeros_like(cur_mask_pred)) + # segmentation loss + if self.with_seg_branch: + if segm_targets is None: + loss = segm_preds[idx].sum() * 0. + else: + loss = self.loss_segm( + segm_preds[idx], + segm_targets, + avg_factor=segm_avg_factor) + loss_segm.append(loss) + # mask loss + total_pos += num_pos + if num_pos == 0 or pos_mask_targets is None: + loss = mask_pred.sum() * 0. else: - cur_mask_pred = torch.clamp(cur_mask_pred, 0, 1) + mask_pred = torch.clamp(mask_pred, 0, 1) loss = F.binary_cross_entropy( - cur_mask_pred, mask_targets, + mask_pred, pos_mask_targets, reduction='none') * self.loss_mask_weight - h, w = cur_img_meta['img_shape'][:2] + h, w = img_meta['img_shape'][:2] gt_bboxes_width = (gt_bboxes_for_reweight[:, 2] - gt_bboxes_for_reweight[:, 0]) / w gt_bboxes_height = (gt_bboxes_for_reweight[:, 3] - @@ -808,76 +815,141 @@ def loss(self, mask_pred, gt_masks, gt_bboxes, img_meta, sampling_results): total_pos += 1 # avoid nan loss_mask = [x / total_pos for x in loss_mask] - return dict(loss_mask=loss_mask) + losses.update(loss_mask=loss_mask) + if self.with_seg_branch: + losses.update(loss_segm=loss_segm) + + return losses - def get_targets(self, mask_pred, gt_masks, pos_assigned_gt_inds): - """Compute instance segmentation targets for each image. 
+    def _get_targets_single(self, mask_pred: Tensor, segm_pred: Tensor,
+                            gt_instances: InstanceData,
+                            positive_info: InstanceData):
+        """Compute targets for predictions of a single image.

         Args:
             mask_pred (Tensor): Predicted masks of a single image, with
                 shape (num_pos, mask_h, mask_w).
-            gt_masks (Tensor): Ground truth masks for each image with
-                the same shape of the input image.
-            pos_assigned_gt_inds (Tensor): GT indices of the corresponding
-                positive samples.
+            segm_pred (Tensor): Predicted semantic segmentation map
+                with shape (num_classes, H, W).
+            gt_instances (:obj:`InstanceData`): Ground truth of instance
+                annotations. It should include ``bboxes``, ``labels``,
+                and ``masks`` attributes.
+            positive_info (:obj:`InstanceData`): Information of positive
+                samples assigned by the detection head. It usually
+                contains the following keys.
+
+                - pos_assigned_gt_inds (Tensor): Assigned GT indices of
+                  positive proposals, with shape (num_pos, ).
+                - pos_inds (Tensor): Indices of positive samples in the
+                  image, with shape (num_pos, ).
+                - coeffs (Tensor): Positive mask coefficients
+                  with shape (num_pos, num_protos).
+                - bboxes (Tensor): Positive bboxes with shape
+                  (num_pos, 4).

         Returns:
-            Tensor: Instance segmentation targets with shape
-                (num_instances, H, W).
+            tuple: Usually returns a tuple containing learning targets.
+
+            - mask_pred (Tensor): Positive predicted masks with shape
+              (num_pos, mask_h, mask_w).
+            - pos_mask_targets (Tensor): Positive mask targets with shape
+              (num_pos, mask_h, mask_w).
+            - segm_targets (Tensor): Semantic segmentation targets with shape
+              (num_classes, segm_h, segm_w).
+            - num_pos (int): Number of positive samples.
+            - gt_bboxes_for_reweight (Tensor): GT bboxes matched to the
+              positive priors, with shape (num_pos, 4).
         """
+        gt_bboxes = gt_instances.bboxes
+        gt_labels = gt_instances.labels
+        device = gt_bboxes.device
+        gt_masks = gt_instances.masks.to_tensor(
+            dtype=torch.bool, device=device).float()
         if gt_masks.size(0) == 0:
-            return None
+            return mask_pred, None, None, 0, None
+
+        # compute semantic segmentation targets
+        if segm_pred is not None:
+            num_classes, segm_h, segm_w = segm_pred.size()
+            with torch.no_grad():
+                downsampled_masks = F.interpolate(
+                    gt_masks.unsqueeze(0), (segm_h, segm_w),
+                    mode='bilinear',
+                    align_corners=False).squeeze(0)
+                downsampled_masks = downsampled_masks.gt(0.5).float()
+                segm_targets = torch.zeros_like(segm_pred, requires_grad=False)
+                for obj_idx in range(downsampled_masks.size(0)):
+                    segm_targets[gt_labels[obj_idx] - 1] = torch.max(
+                        segm_targets[gt_labels[obj_idx] - 1],
+                        downsampled_masks[obj_idx])
+        else:
+            segm_targets = None
+        # compute mask targets
+        pos_assigned_gt_inds = positive_info.pos_assigned_gt_inds
+        num_pos = pos_assigned_gt_inds.size(0)
+        # Since we're producing (near) full image masks,
+        # it'd take too much vram to backprop on every single mask.
+        # Thus we select only a subset.
+        if num_pos > self.max_masks_to_train:
+            perm = torch.randperm(num_pos)
+            select = perm[:self.max_masks_to_train]
+            mask_pred = mask_pred[select]
+            pos_assigned_gt_inds = pos_assigned_gt_inds[select]
+            num_pos = self.max_masks_to_train
+
+        gt_bboxes_for_reweight = gt_bboxes[pos_assigned_gt_inds]
+
         mask_h, mask_w = mask_pred.shape[-2:]
         gt_masks = F.interpolate(
             gt_masks.unsqueeze(0), (mask_h, mask_w),
             mode='bilinear',
             align_corners=False).squeeze(0)
         gt_masks = gt_masks.gt(0.5).float()
-        mask_targets = gt_masks[pos_assigned_gt_inds]
-        return mask_targets
+        pos_mask_targets = gt_masks[pos_assigned_gt_inds]
+
+        return (mask_pred, pos_mask_targets, segm_targets, num_pos,
+                gt_bboxes_for_reweight)

-    def get_seg_masks(self, mask_pred, label_pred, img_meta, rescale):
-        """Resize, binarize, and format the instance mask predictions.
+    def crop_mask_preds(self, mask_preds: List[Tensor],
+                        batch_img_metas: List[dict],
+                        positive_infos: InstanceList) -> list:
+        """Crop predicted masks by zeroing out everything not in the predicted
+        bbox.

         Args:
-            mask_pred (Tensor): shape (N, H, W).
-            label_pred (Tensor): shape (N, ).
-            img_meta (dict): Meta information of each image, e.g.,
-                image size, scaling factor, etc.
-            rescale (bool): If rescale is False, then returned masks will
-                fit the scale of imgs[0].
+            mask_preds (list[Tensor]): Predicted masks to be cropped, each
+                with shape (H, W, num_pos).
+            batch_img_metas (list[dict]): Meta information of multiple images.
+            positive_infos (list[:obj:`InstanceData`]): Positive
+                information calculated from the detection head.
+
         Returns:
-            list[ndarray]: Mask predictions grouped by their predicted classes.
+            list: The cropped masks.
         """
-        ori_shape = img_meta['ori_shape']
-        scale_factor = img_meta['scale_factor']
-        if rescale:
-            img_h, img_w = ori_shape[:2]
-        else:
-            img_h = np.round(ori_shape[0] * scale_factor[1]).astype(np.int32)
-            img_w = np.round(ori_shape[1] * scale_factor[0]).astype(np.int32)
-
-        cls_segms = [[] for _ in range(self.num_classes)]
-        if mask_pred.size(0) == 0:
-            return cls_segms
-
-        mask_pred = F.interpolate(
-            mask_pred.unsqueeze(0), (img_h, img_w),
-            mode='bilinear',
-            align_corners=False).squeeze(0) > 0.5
-        mask_pred = mask_pred.cpu().numpy().astype(np.uint8)
-
-        for m, l in zip(mask_pred, label_pred):
-            cls_segms[l].append(m)
-        return cls_segms
+        croped_mask_preds = []
+        for img_meta, mask_pred, cur_info in zip(batch_img_metas, mask_preds,
+                                                 positive_infos):
+            bboxes_for_cropping = copy.deepcopy(cur_info.bboxes)
+            h, w = img_meta['img_shape'][:2]
+            bboxes_for_cropping[:, 0::2] /= w
+            bboxes_for_cropping[:, 1::2] /= h
+            mask_pred = self.crop_single(mask_pred, bboxes_for_cropping)
+            mask_pred = mask_pred.permute(2, 0, 1).contiguous()
+            croped_mask_preds.append(mask_pred)
+        return croped_mask_preds

-    def crop(self, masks, boxes, padding=1):
-        """Crop predicted masks by zeroing out everything not in the predicted
-        bbox.
+    def crop_single(self,
+                    masks: Tensor,
+                    boxes: Tensor,
+                    padding: int = 1) -> Tensor:
+        """Crop the predicted masks of a single image by zeroing out
+        everything outside the predicted bbox.

         Args:
-            masks (Tensor): shape [H, W, N].
-            boxes (Tensor): bbox coords in relative point form with
+            masks (Tensor): Predicted masks, with shape [H, W, N].
+            boxes (Tensor): Bbox coords in relative point form with
                 shape [N, 4].
+            padding (int): Image padding size.

         Returns:
             Tensor: The cropped masks.
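The body of `crop` itself is unchanged by this patch and therefore elided between the hunks, which makes the cropping trick easy to miss when reading the diff. Below is a minimal, self-contained sketch of the same idea — zeroing out mask activations that fall outside each instance's relative-coordinate box — written against plain PyTorch. The helper names `sanitize` and `crop_masks` are illustrative only; they mirror the broad logic of `sanitize_coordinates`/`crop_single`, not their exact signatures.

    import torch


    def sanitize(x1, x2, size, padding=1):
        # Relative -> absolute coords, enforce x1 < x2,
        # then clamp into [0, size] with a small padding margin.
        x1, x2 = x1 * size, x2 * size
        lo = torch.clamp(torch.min(x1, x2) - padding, min=0)
        hi = torch.clamp(torch.max(x1, x2) + padding, max=size)
        return lo.long(), hi.long()


    def crop_masks(masks, boxes, padding=1):
        # masks: (H, W, N) mask activations; boxes: (N, 4) relative coords.
        h, w, n = masks.size()
        x1, x2 = sanitize(boxes[:, 0], boxes[:, 2], w, padding)
        y1, y2 = sanitize(boxes[:, 1], boxes[:, 3], h, padding)
        cols = torch.arange(w).view(1, -1, 1).expand(h, w, n)
        rows = torch.arange(h).view(-1, 1, 1).expand(h, w, n)
        # A pixel survives only if it lies inside its own instance's box.
        keep = ((cols >= x1.view(1, 1, -1)) & (cols < x2.view(1, 1, -1)) &
                (rows >= y1.view(1, 1, -1)) & (rows < y2.view(1, 1, -1)))
        return masks * keep.float()


    masks = torch.rand(138, 138, 3)  # three instances on a 138x138 grid
    boxes = torch.tensor([[0.1, 0.1, 0.5, 0.6],
                          [0.3, 0.2, 0.9, 0.8],
                          [0.0, 0.4, 0.4, 1.0]])
    print(crop_masks(masks, boxes).shape)  # torch.Size([138, 138, 3])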
@@ -904,7 +976,12 @@ def crop(self, masks, boxes, padding=1):

         return masks * crop_mask.float()

-    def sanitize_coordinates(self, x1, x2, img_size, padding=0, cast=True):
+    def sanitize_coordinates(self,
+                             x1: Tensor,
+                             x2: Tensor,
+                             img_size: int,
+                             padding: int = 0,
+                             cast: bool = True) -> tuple:
         """Sanitizes the input coordinates so that x1 < x2, x1 != x2, x1 >= 0,
         and x2 <= image_size. Also converts from relative to absolute
         coordinates and casts the results to long tensors.
@@ -913,16 +990,17 @@ def sanitize_coordinates(self, x1, x2, img_size, padding=0, cast=True):
             copy if necessary.

         Args:
-            _x1 (Tensor): shape (N, ).
-            _x2 (Tensor): shape (N, ).
+            x1 (Tensor): shape (N, ).
+            x2 (Tensor): shape (N, ).
             img_size (int): Size of the input image.
             padding (int): x1 >= padding, x2 <= image_size-padding.
             cast (bool): If cast is false, the result won't be cast to longs.

         Returns:
             tuple:
-                x1 (Tensor): Sanitized _x1.
-                x2 (Tensor): Sanitized _x2.
+
+            - x1 (Tensor): Sanitized _x1.
+            - x2 (Tensor): Sanitized _x2.
         """
         x1 = x1 * img_size
         x2 = x2 * img_size
@@ -935,67 +1013,160 @@ def sanitize_coordinates(self, x1, x2, img_size, padding=0, cast=True):
             x2 = torch.clamp(x2 + padding, max=img_size)

         return x1, x2

-    def simple_test(self,
-                    feats,
-                    det_bboxes,
-                    det_labels,
-                    det_coeffs,
-                    img_metas,
-                    rescale=False):
-        """Test function without test-time augmentation.
+    def predict_by_feat(self,
+                        mask_preds: List[Tensor],
+                        segm_preds: Tensor,
+                        results_list: InstanceList,
+                        batch_img_metas: List[dict],
+                        rescale: bool = True,
+                        **kwargs) -> InstanceList:
+        """Transform a batch of output features extracted from the head into
+        mask results.

         Args:
-            feats (tuple[torch.Tensor]): Multi-level features from the
-                upstream network, each is a 4D-tensor.
-            det_bboxes (list[Tensor]): BBox results of each image. each
-                element is (n, 5) tensor, where 5 represent
-                (tl_x, tl_y, br_x, br_y, score) and the score between 0 and 1.
-            det_labels (list[Tensor]): BBox results of each image. each
-                element is (n, ) tensor, each element represents the class
-                label of the corresponding box.
-            det_coeffs (list[Tensor]): BBox coefficient of each image. each
-                element is (n, m) tensor, m is vector length.
-            img_metas (list[dict]): Meta information of each image, e.g.,
-                image size, scaling factor, etc.
+            mask_preds (list[Tensor]): Predicted masks, each with shape
+                (H, W, num_dets).
+            segm_preds (Tensor): Predicted semantic segmentation map with
+                shape (N, num_classes, H, W); unused at test time.
+            results_list (list[:obj:`InstanceData`]): BBoxHead results of
+                each image.
+            batch_img_metas (list[dict]): Meta information of all images.
             rescale (bool, optional): Whether to rescale the results.
-                Defaults to False.
+                Defaults to True.

         Returns:
-            list[list]: encoded masks. The c-th item in the outer list
-                corresponds to the c-th class. Given the c-th outer list, the
-                i-th item in that inner list is the mask for the i-th box with
-                class label c.
+            list[:obj:`InstanceData`]: Processed results of multiple
+            images. Each :obj:`InstanceData` usually contains the
+            following keys.
+
+            - scores (Tensor): Classification scores, has shape
+              (num_instance,).
+            - labels (Tensor): Has shape (num_instances,).
+            - masks (Tensor): Processed mask results, has
+              shape (num_instances, h, w).
+ """ + assert len(mask_preds) == len(results_list) == len(batch_img_metas) + + croped_mask_pred = self.crop_mask_preds(mask_preds, batch_img_metas, + results_list) + + for img_id in range(len(batch_img_metas)): + img_meta = batch_img_metas[img_id] + results = results_list[img_id] + bboxes = results.bboxes + mask_pred = croped_mask_pred[img_id] + if bboxes.shape[0] == 0 or mask_pred.shape[0] == 0: + results_list[img_id] = empty_instances( + [img_meta], + bboxes.device, + task_type='mask', + instance_results=[results])[0] + else: + im_mask = self._predict_by_feat_single( + mask_pred=croped_mask_pred[img_id], + bboxes=bboxes, + img_meta=img_meta, + rescale=rescale) + results.masks = im_mask + return results_list + + def _predict_by_feat_single(self, + mask_pred: Tensor, + bboxes: Tensor, + img_meta: dict, + rescale: bool, + cfg: OptConfigType = None): + """Transform a single image's features extracted from the head into + mask results. + + Args: + mask_pred (Tensor): Predicted prototypes, has shape [H, W, N]. + bboxes (Tensor): Bbox coords in relative point form with + shape [N, 4]. + img_meta (dict): Meta information of each image, e.g., + image size, scaling factor, etc. + rescale (bool): If rescale is False, then returned masks will + fit the scale of imgs[0]. + cfg (dict, optional): Config used in test phase. + Defaults to None. + + Returns: + :obj:`InstanceData`: Processed results of single image. + it usually contains following keys. + + - scores (Tensor): Classification scores, has shape + (num_instance,). + - labels (Tensor): Has shape (num_instances,). + - masks (Tensor): Processed mask results, has + shape (num_instances, h, w). """ - num_imgs = len(img_metas) - scale_factors = tuple(meta['scale_factor'] for meta in img_metas) - if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes): - segm_results = [[[] for _ in range(self.num_classes)] - for _ in range(num_imgs)] + cfg = self.test_cfg if cfg is None else cfg + scale_factor = bboxes.new_tensor(img_meta['scale_factor']).repeat( + (1, 2)) + img_h, img_w = img_meta['ori_shape'][:2] + if rescale: # in-placed rescale the bboxes + scale_factor = bboxes.new_tensor(img_meta['scale_factor']).repeat( + (1, 2)) + bboxes /= scale_factor else: - # if det_bboxes is rescaled to the original image size, we need to - # rescale it back to the testing scale to obtain RoIs. 
-            if rescale and not isinstance(scale_factors[0], float):
-                scale_factors = [
-                    torch.from_numpy(scale_factor).to(det_bboxes[0].device)
-                    for scale_factor in scale_factors
-                ]
-            _bboxes = [
-                det_bboxes[i][:, :4] *
-                scale_factors[i] if rescale else det_bboxes[i][:, :4]
-                for i in range(len(det_bboxes))
-            ]
-            mask_preds = self.forward(feats[0], det_coeffs, _bboxes, img_metas)
-            # apply mask post-processing to each image individually
-            segm_results = []
-            for i in range(num_imgs):
-                if det_bboxes[i].shape[0] == 0:
-                    segm_results.append([[] for _ in range(self.num_classes)])
-                else:
-                    segm_result = self.get_seg_masks(mask_preds[i],
-                                                     det_labels[i],
-                                                     img_metas[i], rescale)
-                    segm_results.append(segm_result)
-        return segm_results
+            w_scale, h_scale = scale_factor[0, 0], scale_factor[0, 1]
+            img_h = np.round(img_h * h_scale.item()).astype(np.int32)
+            img_w = np.round(img_w * w_scale.item()).astype(np.int32)
+
+        masks = F.interpolate(
+            mask_pred.unsqueeze(0), (img_h, img_w),
+            mode='bilinear',
+            align_corners=False).squeeze(0) > cfg.mask_thr
+
+        if cfg.mask_thr_binary < 0:
+            # for visualization and debugging
+            masks = (masks * 255).to(dtype=torch.uint8)
+
+        return masks
+
+
+class SegmentationModule(BaseModule):
+    """YOLACT segmentation branch used in `YOLACT
+    <https://arxiv.org/abs/1904.02689>`_.
+
+    In mmdet v2.x `segm_loss` is calculated in YOLACTSegmHead, while in
+    mmdet v3.x `SegmentationModule` is used to obtain the predicted semantic
+    segmentation map and `segm_loss` is calculated in YOLACTProtonet.
+
+    Args:
+        num_classes (int): Number of categories excluding the background
+            category.
+        in_channels (int): Number of channels in the input feature map.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        in_channels: int = 256,
+        init_cfg: ConfigType = dict(
+            type='Xavier',
+            distribution='uniform',
+            override=dict(name='segm_conv'))
+    ) -> None:
+        super().__init__(init_cfg=init_cfg)
+        self.in_channels = in_channels
+        self.num_classes = num_classes
+        self._init_layers()
+
+    def _init_layers(self) -> None:
+        """Initialize layers of the head."""
+        self.segm_conv = nn.Conv2d(
+            self.in_channels, self.num_classes, kernel_size=1)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward feature from the upstream network.
+
+        Args:
+            x (Tensor): Feature from the upstream network, which is
+                a 4D-tensor.
+
+        Returns:
+            Tensor: Predicted semantic segmentation map with shape
+                (N, num_classes, H, W).
+        """
+        return self.segm_conv(x)


 class InterpolateModule(BaseModule):
@@ -1004,12 +1175,19 @@ class InterpolateModule(BaseModule):
     Any arguments you give it just get passed along for the ride.
     """

-    def __init__(self, *args, init_cfg=None, **kwargs):
-        super().__init__(init_cfg)
-
+    def __init__(self, *args, init_cfg=None, **kwargs) -> None:
+        super().__init__(init_cfg=init_cfg)
         self.args = args
         self.kwargs = kwargs

-    def forward(self, x):
-        """Forward features from the upstream network."""
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward features from the upstream network.
+
+        Args:
+            x (Tensor): Feature from the upstream network, which is
+                a 4D-tensor.
+
+        Returns:
+            Tensor: A 4D-tensor feature map.
+ """ return F.interpolate(x, *self.args, **self.kwargs) diff --git a/mmdet/models/detectors/single_stage_instance_seg.py b/mmdet/models/detectors/single_stage_instance_seg.py index 26b72605792..c89c0f1b562 100644 --- a/mmdet/models/detectors/single_stage_instance_seg.py +++ b/mmdet/models/detectors/single_stage_instance_seg.py @@ -1,10 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy -from typing import List, Tuple +from typing import Tuple from torch import Tensor -from mmdet.data_elements import SampleList +from mmdet.data_elements import OptSampleList, SampleList from mmdet.registry import MODELS from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig from .base import BaseDetector @@ -63,8 +63,10 @@ def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]: x = self.neck(x) return x - def _forward(self, batch_inputs: Tensor, *args, **kwargs) \ - -> Tuple[List[Tensor]]: + def _forward(self, + batch_inputs: Tensor, + batch_data_samples: OptSampleList = None, + **kwargs) -> tuple: """Network forward process. Usually includes backbone, neck and head forward without any post-processing. @@ -72,17 +74,27 @@ def _forward(self, batch_inputs: Tensor, *args, **kwargs) \ batch_inputs (Tensor): Inputs with shape (N, C, H, W). Returns: - tuple[list]: A tuple of features from ``bbox_head`` forward. + tuple: A tuple of features from ``bbox_head`` forward. """ outs = () # backbone x = self.extract_feat(batch_inputs) # bbox_head + positive_infos = None if self.with_bbox: - # TODO: current not supported - pass + assert batch_data_samples is not None + bbox_outs = self.bbox_head.forward(x) + outs = outs + (bbox_outs, ) + # It is necessary to use `bbox_head.loss` to update + # `_raw_positive_infos` which will be used in `get_positive_infos` + # positive_infos will be used in the following mask head. + _ = self.bbox_head.loss(x, batch_data_samples, **kwargs) + positive_infos = self.bbox_head.get_positive_infos() # mask_head - mask_outs = self.mask_head.forward(x) + if positive_infos is None: + mask_outs = self.mask_head.forward(x) + else: + mask_outs = self.mask_head.forward(x, positive_infos) outs = outs + (mask_outs, ) return outs @@ -97,20 +109,19 @@ def loss(self, batch_inputs: Tensor, batch_data_samples: SampleList, as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. Returns: - dict[str, Tensor]: A dictionary of loss components. + dict: A dictionary of loss components. """ x = self.extract_feat(batch_inputs) losses = dict() - # TODO: Check the logic in CondInst and YOLACT positive_infos = None # CondInst and YOLACT have bbox_head if self.with_bbox: bbox_losses = self.bbox_head.loss(x, batch_data_samples, **kwargs) - # TODO: enhance the logic when refactor YOLACT - if bbox_losses.get('positive_infos', None) is not None: - positive_infos = bbox_losses.pop('positive_infos') losses.update(bbox_losses) + # get positive information from bbox head, which will be used + # in the following mask head. + positive_infos = self.bbox_head.get_positive_infos() mask_loss = self.mask_head.loss( x, batch_data_samples, positive_infos=positive_infos, **kwargs) @@ -123,7 +134,7 @@ def loss(self, batch_inputs: Tensor, batch_data_samples: SampleList, def predict(self, batch_inputs: Tensor, batch_data_samples: SampleList, - rescale: bool = False, + rescale: bool = True, **kwargs) -> SampleList: """Perform forward propagation of the mask head and predict mask results on the features of the upstream network. @@ -142,27 +153,27 @@ def predict(self, 'pred_instances'. 
And the ``pred_instances`` usually
            contains the following keys.

-                - scores (Tensor): Classification scores, has a shape
-                    (num_instance, )
-                - labels (Tensor): Labels of bboxes, has a shape
-                    (num_instances, ).
-                - bboxes (Tensor): Has a shape (num_instances, 4),
-                    the last dimension 4 arrange as (x1, y1, x2, y2).
-                - masks (Tensor): Has a shape (num_instances, H, W).
+            - scores (Tensor): Classification scores, has a shape
+              (num_instances, ).
+            - labels (Tensor): Labels of bboxes, has a shape
+              (num_instances, ).
+            - bboxes (Tensor): Has a shape (num_instances, 4),
+              the last dimension 4 arranged as (x1, y1, x2, y2).
+            - masks (Tensor): Has a shape (num_instances, H, W).
         """
         x = self.extract_feat(batch_inputs)
         if self.with_bbox:
-            # TODO: currently not checked
-            bbox_results_list = self.bbox_head.predict(
-                x, batch_data_samples, rescale=rescale)
+            # The bbox branch does not need to be rescaled to the original
+            # image scale here, because the mask branch rescales both
+            # bboxes and masks at the same time.
+            bbox_rescale = rescale if not self.with_mask else False
+            results_list = self.bbox_head.predict(
+                x, batch_data_samples, rescale=bbox_rescale)
         else:
-            bbox_results_list = None
+            results_list = None

         results_list = self.mask_head.predict(
-            x,
-            batch_data_samples,
-            rescale=rescale,
-            bbox_results_list=bbox_results_list)
+            x, batch_data_samples, rescale=rescale, results_list=results_list)

         # convert to DetDataSample
         results_list = self.convert_to_datasample(results_list)
diff --git a/mmdet/models/detectors/yolact.py b/mmdet/models/detectors/yolact.py
index 139573931ea..116638266f8 100644
--- a/mmdet/models/detectors/yolact.py
+++ b/mmdet/models/detectors/yolact.py
@@ -1,120 +1,29 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import torch
-
-from mmdet.data_elements.bbox import bbox2result
 from mmdet.registry import MODELS
-from .single_stage import SingleStageDetector
+from mmdet.utils.typing import ConfigType, OptConfigType, OptMultiConfig
+from .single_stage_instance_seg import SingleStageInstanceSegmentor


 @MODELS.register_module()
-class YOLACT(SingleStageDetector):
+class YOLACT(SingleStageInstanceSegmentor):
     """Implementation of `YOLACT <https://arxiv.org/abs/1904.02689>`_"""

     def __init__(self,
-                 backbone,
-                 neck,
-                 bbox_head,
-                 segm_head,
-                 mask_head,
-                 train_cfg=None,
-                 test_cfg=None,
-                 pretrained=None,
-                 init_cfg=None):
-        super(YOLACT, self).__init__(backbone, neck, bbox_head, train_cfg,
-                                     test_cfg, pretrained, init_cfg)
-        self.segm_head = MODELS.build(segm_head)
-        self.mask_head = MODELS.build(mask_head)
-
-    def forward_dummy(self, img):
-        """Used for computing network flops.
-
-        See `mmdetection/tools/analysis_tools/get_flops.py`
-        """
-        feat = self.extract_feat(img)
-        bbox_outs = self.bbox_head(feat)
-        prototypes = self.mask_head.forward_dummy(feat[0])
-        return (bbox_outs, prototypes)
-
-    def forward_train(self,
-                      img,
-                      img_metas,
-                      gt_bboxes,
-                      gt_labels,
-                      gt_bboxes_ignore=None,
-                      gt_masks=None):
-        """
-        Args:
-            img (Tensor): of shape (N, C, H, W) encoding input images.
-                Typically these should be mean centered and std scaled.
-            img_metas (list[dict]): list of image info dict where each dict
-                has: 'img_shape', 'scale_factor', 'flip', and may also contain
-                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
-                For details on the values of these keys see
-                `mmdet/datasets/pipelines/formatting.py:Collect`.
-            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
-                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
- gt_labels (list[Tensor]): class indices corresponding to each box - gt_bboxes_ignore (None | list[Tensor]): specify which bounding - boxes can be ignored when computing the loss. - gt_masks (None | Tensor) : true segmentation masks for each box - used if the architecture supports a segmentation task. - - Returns: - dict[str, Tensor]: a dictionary of loss components - """ - # convert Bitmap mask or Polygon Mask to Tensor here - gt_masks = [ - gt_mask.to_tensor(dtype=torch.uint8, device=img.device) - for gt_mask in gt_masks - ] - - x = self.extract_feat(img) - - cls_score, bbox_pred, coeff_pred = self.bbox_head(x) - bbox_head_loss_inputs = (cls_score, bbox_pred) + (gt_bboxes, gt_labels, - img_metas) - losses, sampling_results = self.bbox_head.loss( - *bbox_head_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore) - - segm_head_outs = self.segm_head(x[0]) - loss_segm = self.segm_head.loss(segm_head_outs, gt_masks, gt_labels) - losses.update(loss_segm) - - mask_pred = self.mask_head(x[0], coeff_pred, gt_bboxes, img_metas, - sampling_results) - loss_mask = self.mask_head.loss(mask_pred, gt_masks, gt_bboxes, - img_metas, sampling_results) - losses.update(loss_mask) - - # check NaN and Inf - for loss_name in losses.keys(): - assert torch.isfinite(torch.stack(losses[loss_name]))\ - .all().item(), '{} becomes infinite or NaN!'\ - .format(loss_name) - - return losses - - def simple_test(self, img, img_metas, rescale=False): - """Test function without test-time augmentation.""" - feat = self.extract_feat(img) - det_bboxes, det_labels, det_coeffs = self.bbox_head.simple_test( - feat, img_metas, rescale=rescale) - bbox_results = [ - bbox2result(det_bbox, det_label, self.bbox_head.num_classes) - for det_bbox, det_label in zip(det_bboxes, det_labels) - ] - - segm_results = self.mask_head.simple_test( - feat, - det_bboxes, - det_labels, - det_coeffs, - img_metas, - rescale=rescale) - - return list(zip(bbox_results, segm_results)) - - def aug_test(self, imgs, img_metas, rescale=False): - """Test with augmentations.""" - raise NotImplementedError( - 'YOLACT does not support test-time augmentation') + backbone: ConfigType, + neck: ConfigType, + bbox_head: ConfigType, + mask_head: ConfigType, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + mask_head=mask_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) diff --git a/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py b/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py index 5de5f8a81ba..723929cfe4d 100644 --- a/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py +++ b/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py @@ -337,7 +337,7 @@ def _predict_by_feat_single(self, # In AugTest, has been activated before mask_pred = bboxes.new_tensor(mask_pred) - if rescale: + if rescale: # in-placed rescale the bboxes bboxes /= scale_factor else: w_scale, h_scale = scale_factor[0, 0], scale_factor[0, 1] diff --git a/tests/test_models/test_detectors/test_single_stage_instance_seg.py b/tests/test_models/test_detectors/test_single_stage_instance_seg.py index 8b3e49fe773..8398019ae15 100644 --- a/tests/test_models/test_detectors/test_single_stage_instance_seg.py +++ b/tests/test_models/test_detectors/test_single_stage_instance_seg.py @@ -20,7 +20,8 @@ def setUp(self): 'solo/decoupled_solo_r50_fpn_1x_coco.py', 
'solo/decoupled_solo_light_r50_fpn_3x_coco.py', 'solov2/solov2_r50_fpn_1x_coco.py', - 'solov2/solov2_light_r18_fpn_mstrain_3x_coco.py' + 'solov2/solov2_light_r18_fpn_mstrain_3x_coco.py', + 'yolact/yolact_r50_1x8_coco.py', ]) def test_init(self, cfg_file): model = get_detector_cfg(cfg_file) @@ -31,6 +32,8 @@ def test_init(self, cfg_file): self.assertTrue(detector.backbone) self.assertTrue(detector.neck) self.assertTrue(detector.mask_head) + if detector.with_bbox: + self.assertTrue(detector.bbox_head) @parameterized.expand([ ('solo/solo_r50_fpn_1x_coco.py', ('cpu', 'cuda')), @@ -38,6 +41,7 @@ def test_init(self, cfg_file): ('solo/decoupled_solo_light_r50_fpn_3x_coco.py', ('cpu', 'cuda')), ('solov2/solov2_r50_fpn_1x_coco.py', ('cpu', 'cuda')), ('solov2/solov2_light_r18_fpn_mstrain_3x_coco.py', ('cpu', 'cuda')), + ('yolact/yolact_r50_1x8_coco.py', ('cpu', 'cuda')), ]) def test_single_stage_forward_loss_mode(self, cfg_file, devices): model = get_detector_cfg(cfg_file) @@ -71,6 +75,7 @@ def test_single_stage_forward_loss_mode(self, cfg_file, devices): ('solo/decoupled_solo_light_r50_fpn_3x_coco.py', ('cpu', 'cuda')), ('solov2/solov2_r50_fpn_1x_coco.py', ('cpu', 'cuda')), ('solov2/solov2_light_r18_fpn_mstrain_3x_coco.py', ('cpu', 'cuda')), + ('yolact/yolact_r50_1x8_coco.py', ('cpu', 'cuda')), ]) def test_single_stage_forward_predict_mode(self, cfg_file, devices): model = get_detector_cfg(cfg_file) @@ -101,3 +106,33 @@ def test_single_stage_forward_predict_mode(self, cfg_file, devices): batch_inputs, data_samples, mode='predict') self.assertEqual(len(batch_results), 2) self.assertIsInstance(batch_results[0], DetDataSample) + + @parameterized.expand([ + ('solo/solo_r50_fpn_1x_coco.py', ('cpu', 'cuda')), + ('solo/decoupled_solo_r50_fpn_1x_coco.py', ('cpu', 'cuda')), + ('solo/decoupled_solo_light_r50_fpn_3x_coco.py', ('cpu', 'cuda')), + ('solov2/solov2_r50_fpn_1x_coco.py', ('cpu', 'cuda')), + ('solov2/solov2_light_r18_fpn_mstrain_3x_coco.py', ('cpu', 'cuda')), + ('yolact/yolact_r50_1x8_coco.py', ('cpu', 'cuda')), + ]) + def test_single_stage_forward_tensor_mode(self, cfg_file, devices): + model = get_detector_cfg(cfg_file) + model.backbone.init_cfg = None + + from mmdet.models import build_detector + assert all([device in ['cpu', 'cuda'] for device in devices]) + + for device in devices: + detector = build_detector(model) + + if device == 'cuda': + if not torch.cuda.is_available(): + return unittest.skip('test requires GPU and torch+cuda') + detector = detector.cuda() + + packed_inputs = demo_mm_inputs(2, [[3, 128, 128], [3, 125, 130]]) + batch_inputs, data_samples = detector.data_preprocessor( + packed_inputs, False) + batch_results = detector.forward( + batch_inputs, data_samples, mode='tensor') + self.assertIsInstance(batch_results, tuple)
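For reference, the core data flow this refactor preserves — assembling per-instance masks from shared prototypes and per-box coefficients — boils down to a single matrix product followed by a sigmoid. The sketch below is hypothetical standalone PyTorch with made-up shapes, not the mmdet API:

    import torch

    num_protos, num_pos, h, w = 32, 5, 138, 138
    prototypes = torch.rand(h, w, num_protos)  # protonet output for one image
    coeffs = torch.randn(num_pos, num_protos)  # coefficients from the bbox head

    # One matrix product combines the prototypes per instance, then a sigmoid
    # turns the linear combination into per-pixel mask probabilities.
    masks = torch.sigmoid(prototypes @ coeffs.t())
    assert masks.shape == (h, w, num_pos)

Cropping these masks by their boxes (as in `crop_single` above) and thresholding them against `mask_thr` then yields the final binary instance masks.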