Skip to content

Commit

Permalink
Refactor pisa_roi_head
Browse files Browse the repository at this point in the history
  • Loading branch information
jbwang1997 authored and ZwwWayne committed Jul 19, 2022
1 parent 950497e commit 5a2ef66
Show file tree
Hide file tree
Showing 6 changed files with 309 additions and 170 deletions.
87 changes: 66 additions & 21 deletions mmdet/models/roi_heads/bbox_heads/bbox_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,16 +102,19 @@ def __init__(self,
# TODO: Create a SeesawBBoxHead to simplify logic in BBoxHead
@property
def custom_cls_channels(self) -> bool:
    """Report the ``custom_cls_channels`` flag carried by ``loss_cls``.

    Returns:
        bool: The ``custom_cls_channels`` attribute of ``self.loss_cls``,
        or ``False`` when the loss object does not define it.
    """
    cls_loss = self.loss_cls
    return getattr(cls_loss, 'custom_cls_channels', False)

# TODO: Create a SeesawBBoxHead to simplify logic in BBoxHead
@property
def custom_activation(self) -> bool:
    """Report the ``custom_activation`` flag carried by ``loss_cls``.

    Returns:
        bool: The ``custom_activation`` attribute of ``self.loss_cls``,
        or ``False`` when the loss object does not define it.
    """
    cls_loss = self.loss_cls
    return getattr(cls_loss, 'custom_activation', False)

# TODO: Create a SeesawBBoxHead to simplify logic in BBoxHead
@property
def custom_accuracy(self) -> bool:
    """Report the ``custom_accuracy`` flag carried by ``loss_cls``.

    Returns:
        bool: The ``custom_accuracy`` attribute of ``self.loss_cls``,
        or ``False`` when the loss object does not define it.
    """
    cls_loss = self.loss_cls
    return getattr(cls_loss, 'custom_accuracy', False)

def forward(self, x: Tuple[Tensor]) -> tuple:
Expand Down Expand Up @@ -233,24 +236,24 @@ def get_targets(self,
Tuple[Tensor]: Ground truth for proposals in a single image.
Containing the following list of Tensors:
- labels (list[Tensor],Tensor): Gt_labels for all
proposals in a batch, each tensor in list has
shape (num_proposals,) when `concat=False`, otherwise
just a single tensor has shape (num_all_proposals,).
- label_weights (list[Tensor],Tensor): Labels_weights for
all proposals in a batch, each tensor in list has
shape (num_proposals,) when `concat=False`, otherwise
just a single tensor has shape (num_all_proposals,).
- bbox_targets (list[Tensor],Tensor): Regression target
for all proposals in a batch, each tensor in list
has shape (num_proposals, 4) when `concat=False`,
otherwise just a single tensor has shape
(num_all_proposals, 4), the last dimension 4 represents
[tl_x, tl_y, br_x, br_y].
- bbox_weights (list[Tensor],Tensor): Regression weights for
all proposals in a batch, each tensor in list has shape
(num_proposals, 4) when `concat=False`, otherwise just a
single tensor has shape (num_all_proposals, 4).
- labels (list[Tensor],Tensor): Gt_labels for all
proposals in a batch, each tensor in list has
shape (num_proposals,) when `concat=False`, otherwise
just a single tensor has shape (num_all_proposals,).
- label_weights (list[Tensor],Tensor): Labels_weights for
all proposals in a batch, each tensor in list has
shape (num_proposals,) when `concat=False`, otherwise
just a single tensor has shape (num_all_proposals,).
- bbox_targets (list[Tensor],Tensor): Regression target
for all proposals in a batch, each tensor in list
has shape (num_proposals, 4) when `concat=False`,
otherwise just a single tensor has shape
(num_all_proposals, 4), the last dimension 4 represents
[tl_x, tl_y, br_x, br_y].
- bbox_weights (list[Tensor],Tensor): Regression weights for
all proposals in a batch, each tensor in list has shape
(num_proposals, 4) when `concat=False`, otherwise just a
single tensor has shape (num_all_proposals, 4).
"""
pos_priors_list = [res.pos_priors for res in sampling_results]
neg_priors_list = [res.neg_priors for res in sampling_results]
Expand Down Expand Up @@ -305,7 +308,50 @@ def loss_and_target(self,
"""

cls_reg_targets = self.get_targets(sampling_results, rcnn_train_cfg)
(labels, label_weights, bbox_targets, bbox_weights) = cls_reg_targets
losses = self.loss(cls_score, bbox_pred, rois, *cls_reg_targets)

# cls_reg_targets is only for cascade rcnn
return dict(loss_bbox=losses, bbox_targets=cls_reg_targets)

def loss(self,
cls_score: Tensor,
bbox_pred: Tensor,
rois: Tensor,
labels: Tensor,
label_weights: Tensor,
bbox_targets: Tensor,
bbox_weights: Tensor,
reduction_override: Optional[str] = None) -> dict:
"""Calculate the loss based on the network predictions and targets.
Args:
cls_score (Tensor): Classification prediction
results of all class, has shape
(batch_size * num_proposals_single_image, num_classes)
bbox_pred (Tensor): Regression prediction results,
has shape
(batch_size * num_proposals_single_image, 4), the last
dimension 4 represents [tl_x, tl_y, br_x, br_y].
rois (Tensor): RoIs with the shape
(batch_size * num_proposals_single_image, 5) where the first
column indicates batch id of each RoI.
labels (Tensor): Gt_labels for all proposals in a batch, has
shape (batch_size * num_proposals_single_image, ).
label_weights (Tensor): Labels_weights for all proposals in a
batch, has shape (batch_size * num_proposals_single_image, ).
bbox_targets (Tensor): Regression target for all proposals in a
batch, has shape (batch_size * num_proposals_single_image, 4),
the last dimension 4 represents [tl_x, tl_y, br_x, br_y].
bbox_weights (Tensor): Regression weights for all proposals in a
batch, has shape (batch_size * num_proposals_single_image, 4).
reduction_override (str, optional): The reduction
method used to override the original reduction
method of the loss. Options are "none",
"mean" and "sum". Defaults to None.
Returns:
dict: A dictionary of loss.
"""

losses = dict()

Expand Down Expand Up @@ -356,8 +402,7 @@ def loss_and_target(self,
else:
losses['loss_bbox'] = bbox_pred[pos_inds].sum()

# cls_reg_targets is only for cascade rcnn
return dict(loss_bbox=losses, bbox_targets=cls_reg_targets)
return losses

def predict_by_feat(self,
rois: Tuple[Tensor],
Expand Down
13 changes: 7 additions & 6 deletions mmdet/models/roi_heads/dynamic_roi_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,12 @@ def bbox_loss(self, x: Tuple[Tensor],
pos_inds = bbox_targets[3][:, 0].nonzero().squeeze(1)
num_pos = len(pos_inds)
num_imgs = len(sampling_results)
cur_target = bbox_targets[2][pos_inds, :2].abs().mean(dim=1)
beta_topk = min(self.train_cfg.dynamic_rcnn.beta_topk * num_imgs,
num_pos)
cur_target = torch.kthvalue(cur_target, beta_topk)[0].item()
self.beta_history.append(cur_target)
if num_pos > 0:
cur_target = bbox_targets[2][pos_inds, :2].abs().mean(dim=1)
beta_topk = min(self.train_cfg.dynamic_rcnn.beta_topk * num_imgs,
num_pos)
cur_target = torch.kthvalue(cur_target, beta_topk)[0].item()
self.beta_history.append(cur_target)

return bbox_results

Expand All @@ -151,7 +152,7 @@ def update_hyperparameters(self):
self.bbox_assigner.pos_iou_thr = new_iou_thr
self.bbox_assigner.neg_iou_thr = new_iou_thr
self.bbox_assigner.min_pos_iou = new_iou_thr
if (np.median(self.beta_history) < EPS):
if (not self.beta_history) or (np.median(self.beta_history) < EPS):
# avoid 0 or too small value for new_beta
new_beta = self.bbox_head.loss_bbox.beta
else:
Expand Down
148 changes: 68 additions & 80 deletions mmdet/models/roi_heads/pisa_roi_head.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

from torch import Tensor

from mmdet.data_elements import DetDataSample
from mmdet.data_elements.bbox import bbox2roi
from mmdet.models.task_modules import SamplingResult
from mmdet.registry import MODELS
from mmdet.utils import InstanceList
from ..losses.pisa_loss import carl_loss, isr_p
from ..utils import unpack_gt_instances
from .standard_roi_head import StandardRoIHead


Expand All @@ -10,107 +18,87 @@ class PISARoIHead(StandardRoIHead):
r"""The RoI head for `Prime Sample Attention in Object Detection
<https://arxiv.org/abs/1904.04821>`_."""

def forward_train(self,
x,
img_metas,
proposal_list,
gt_bboxes,
gt_labels,
gt_bboxes_ignore=None,
gt_masks=None):
"""Forward function for training.
def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList,
batch_data_samples: List[DetDataSample]) -> dict:
"""Perform forward propagation and loss calculation of the detection
roi on the features of the upstream network.
Args:
x (list[Tensor]): List of multi-level img features.
img_metas (list[dict]): List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`mmdet/datasets/pipelines/formatting.py:Collect`.
proposals (list[Tensors]): List of region proposals.
gt_bboxes (list[Tensor]): Each item are the truth boxes for each
image in [tl_x, tl_y, br_x, br_y] format.
gt_labels (list[Tensor]): Class indices corresponding to each box
gt_bboxes_ignore (list[Tensor], optional): Specify which bounding
boxes can be ignored when computing the loss.
gt_masks (None | Tensor) : True segmentation masks for each box
used if the architecture supports a segmentation task.
x (tuple[Tensor]): Tuple of multi-level img features.
rpn_results_list (list[:obj:`InstanceData`]): List of region
proposals.
batch_data_samples (list[:obj:`DetDataSample`]): The batch
data samples. It usually includes information such
as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
Returns:
dict[str, Tensor]: a dictionary of loss components
dict[str, Tensor]: A dictionary of loss components
"""
assert len(rpn_results_list) == len(batch_data_samples)
outputs = unpack_gt_instances(batch_data_samples)
batch_gt_instances, batch_gt_instances_ignore, _ = outputs

# assign gts and sample proposals
if self.with_bbox or self.with_mask:
num_imgs = len(img_metas)
if gt_bboxes_ignore is None:
gt_bboxes_ignore = [None for _ in range(num_imgs)]
sampling_results = []
neg_label_weights = []
for i in range(num_imgs):
assign_result = self.bbox_assigner.assign(
proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i],
gt_labels[i])
sampling_result = self.bbox_sampler.sample(
assign_result,
proposal_list[i],
gt_bboxes[i],
gt_labels[i],
feats=[lvl_feat[i][None] for lvl_feat in x])
# neg label weight is obtained by sampling when using ISR-N
neg_label_weight = None
if isinstance(sampling_result, tuple):
sampling_result, neg_label_weight = sampling_result
sampling_results.append(sampling_result)
neg_label_weights.append(neg_label_weight)
num_imgs = len(batch_data_samples)
sampling_results = []
neg_label_weights = []
for i in range(num_imgs):
# rename rpn_results.bboxes to rpn_results.priors
rpn_results = rpn_results_list[i]
rpn_results.priors = rpn_results.pop('bboxes')

assign_result = self.bbox_assigner.assign(
rpn_results, batch_gt_instances[i],
batch_gt_instances_ignore[i])
sampling_result = self.bbox_sampler.sample(
assign_result,
rpn_results,
batch_gt_instances[i],
feats=[lvl_feat[i][None] for lvl_feat in x])
if isinstance(sampling_result, tuple):
sampling_result, neg_label_weight = sampling_result
sampling_results.append(sampling_result)
neg_label_weights.append(neg_label_weight)

losses = dict()
# bbox head forward and loss
if self.with_bbox:
bbox_results = self._bbox_forward_train(
x,
sampling_results,
gt_bboxes,
gt_labels,
img_metas,
neg_label_weights=neg_label_weights)
bbox_results = self.bbox_loss(
x, sampling_results, neg_label_weights=neg_label_weights)
losses.update(bbox_results['loss_bbox'])

# mask head forward and loss
if self.with_mask:
mask_results = self._mask_forward_train(x, sampling_results,
bbox_results['bbox_feats'],
gt_masks, img_metas)
mask_results = self.mask_loss(x, sampling_results,
bbox_results['bbox_feats'],
batch_gt_instances)
losses.update(mask_results['loss_mask'])

return losses

def _bbox_forward(self, x, rois):
"""Box forward function used in both training and testing."""
# TODO: a more flexible way to decide which feature maps to use
bbox_feats = self.bbox_roi_extractor(
x[:self.bbox_roi_extractor.num_inputs], rois)
if self.with_shared_head:
bbox_feats = self.shared_head(bbox_feats)
cls_score, bbox_pred = self.bbox_head(bbox_feats)

bbox_results = dict(
cls_score=cls_score, bbox_pred=bbox_pred, bbox_feats=bbox_feats)
return bbox_results
def bbox_loss(self,
x: Tuple[Tensor],
sampling_results: List[SamplingResult],
neg_label_weights: List[Tensor] = None) -> dict:
"""Perform forward propagation and loss calculation of the bbox head on
the features of the upstream network.
def _bbox_forward_train(self,
x,
sampling_results,
gt_bboxes,
gt_labels,
img_metas,
neg_label_weights=None):
"""Run forward function and calculate loss for box head in training."""
rois = bbox2roi([res.bboxes for res in sampling_results])
Args:
x (tuple[Tensor]): Tuple of multi-level img features.
sampling_results (list[:obj:`SamplingResult`]): Sampling results.
bbox_results = self._bbox_forward(x, rois)
Returns:
dict[str, Tensor]: Usually returns a dictionary with keys:
bbox_targets = self.bbox_head.get_targets(sampling_results, gt_bboxes,
gt_labels, self.train_cfg)
- `cls_score` (Tensor): Classification scores.
- `bbox_pred` (Tensor): Box energies / deltas.
- `bbox_feats` (Tensor): Extract bbox RoI features.
- `loss_bbox` (dict): A dictionary of bbox loss components.
"""
rois = bbox2roi([res.priors for res in sampling_results])
bbox_results = self._bbox_forward(x, rois)
bbox_targets = self.bbox_head.get_targets(sampling_results,
self.train_cfg)

# neg_label_weights obtained by sampler is image-wise, mapping back to
# the corresponding location in label weights
Expand Down
Loading

0 comments on commit 5a2ef66

Please sign in to comment.