Adding RT-DETRv2 for object detection (#34773)
* cookiecutter add rtdetrv2

* make modular working

* working model

* working model

* finalize modular inheritance

* finalize modular inheritance

* Update src/transformers/models/rtdetrv2/modular_rtdetrv2.py

Co-authored-by: Cyril Vallez <[email protected]>

* update modular and add rename

* remove output ckpt

* define loss_kwargs

* fix CamelCase naming

* fix naming + files

* fix modular and convert file

* additional changes

* fix modular

* fix import error (switch to lazy)

* fix autobackbone

* make style

* add

* update testing

* fix loss

* remove old folder

* fix testing for v2

* update docstring

* fix docstring

* add resnetv2 (with modular bug to fix)

* remove resnetv2 backbone

* fix changes

* small fixes

* remove rtdetrv2resnetconfig

* add rtdetrv2 name to convert

* make style

* Update docs/source/en/model_doc/rt_detr_v2.md

Co-authored-by: Steven Liu <[email protected]>

* Update src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py

Co-authored-by: Steven Liu <[email protected]>

* Update src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py

Co-authored-by: Steven Liu <[email protected]>

* fix modular typo after review

* add reviewed changes

* add final review changes

* Update docs/source/en/model_doc/rt_detr_v2.md

Co-authored-by: Cyril Vallez <[email protected]>

* Update src/transformers/models/rt_detr_v2/__init__.py

Co-authored-by: Cyril Vallez <[email protected]>

* Update src/transformers/models/rt_detr_v2/convert_rt_detr_v2_weights_to_hf.py

Co-authored-by: Cyril Vallez <[email protected]>

* add review changes

* remove rtdetrv2 resnet

* removing this weird project change

* change ckpt name from jadechoghari to author

* implement review and update testing

* update naming and remove wrong ckpt

* name

* make fix-copies

* Fix RT-DETR loss

* Add resources, fix name

* Fix repo in docs

* Fix table name

---------

Co-authored-by: jadechoghari <[email protected]>
Co-authored-by: Cyril Vallez <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
Co-authored-by: qubvel <[email protected]>
5 people authored Feb 6, 2025
1 parent 6246c03 commit 006d924
Showing 19 changed files with 4,431 additions and 4 deletions.
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -709,6 +709,8 @@
title: ResNet
- local: model_doc/rt_detr
title: RT-DETR
- local: model_doc/rt_detr_v2
title: RT-DETRv2
- local: model_doc/segformer
title: SegFormer
- local: model_doc/seggpt
1 change: 1 addition & 0 deletions docs/source/en/index.md
@@ -305,6 +305,7 @@ Flax), PyTorch, and/or TensorFlow.
| [RoFormer](model_doc/roformer) | ✅ | ✅ | ✅ |
| [RT-DETR](model_doc/rt_detr) | ✅ | ❌ | ❌ |
| [RT-DETR-ResNet](model_doc/rt_detr_resnet) | ✅ | ❌ | ❌ |
| [RT-DETRv2](model_doc/rt_detr_v2) | ✅ | ❌ | ❌ |
| [RWKV](model_doc/rwkv) | ✅ | ❌ | ❌ |
| [SAM](model_doc/sam) | ✅ | ✅ | ❌ |
| [SeamlessM4T](model_doc/seamless_m4t) | ✅ | ❌ | ❌ |
97 changes: 97 additions & 0 deletions docs/source/en/model_doc/rt_detr_v2.md
@@ -0,0 +1,97 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# RT-DETRv2

## Overview

The RT-DETRv2 model was proposed in [RT-DETRv2: Improved Baseline with Bag-of-Freebies for Real-Time Detection Transformer](https://arxiv.org/abs/2407.17140) by Wenyu Lv, Yian Zhao, Qinyao Chang, Kui Huang, Guanzhong Wang, Yi Liu.

RT-DETRv2 refines RT-DETR by introducing selective multi-scale feature extraction, a discrete sampling operator for broader deployment compatibility, and improved training strategies like dynamic data augmentation and scale-adaptive hyperparameters. These changes enhance flexibility and practicality while maintaining real-time performance.

The abstract from the paper is the following:

*In this report, we present RT-DETRv2, an improved Real-Time DEtection TRansformer (RT-DETR). RT-DETRv2 builds upon the previous state-of-the-art real-time detector, RT-DETR, and opens up a set of bag-of-freebies for flexibility and practicality, as well as optimizing the training strategy to achieve enhanced performance. To improve the flexibility, we suggest setting a distinct number of sampling points for features at different scales in the deformable attention to achieve selective multi-scale feature extraction by the decoder. To enhance practicality, we propose an optional discrete sampling operator to replace the grid_sample operator that is specific to RT-DETR compared to YOLOs. This removes the deployment constraints typically associated with DETRs. For the training strategy, we propose dynamic data augmentation and scale-adaptive hyperparameters customization to improve performance without loss of speed.*

This model was contributed by [jadechoghari](https://huggingface.co/jadechoghari).
The original code can be found [here](https://github.com/lyuwenyu/RT-DETR).

## Usage tips

This second version of RT-DETR improves how the decoder finds objects in an image.

- **better sampling** – uses a distinct number of sampling points per feature scale so the decoder looks at the right areas
- **flexible attention** – supports smooth (bilinear) or fixed (discrete) sampling, which eases deployment
- **optimized processing** – improves how attention weights mix information across scales

```py
>>> import torch
>>> import requests

>>> from PIL import Image
>>> from transformers import RTDetrV2ForObjectDetection, RTDetrImageProcessor

>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_v2_r18vd")
>>> model = RTDetrV2ForObjectDetection.from_pretrained("PekingU/rtdetr_v2_r18vd")

>>> inputs = image_processor(images=image, return_tensors="pt")

>>> with torch.no_grad():
... outputs = model(**inputs)

>>> results = image_processor.post_process_object_detection(outputs, target_sizes=torch.tensor([(image.height, image.width)]), threshold=0.5)

>>> for result in results:
... for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"]):
... score, label = score.item(), label_id.item()
... box = [round(i, 2) for i in box.tolist()]
... print(f"{model.config.id2label[label]}: {score:.2f} {box}")
cat: 0.97 [341.14, 25.11, 639.98, 372.89]
cat: 0.96 [12.78, 56.35, 317.67, 471.34]
remote: 0.95 [39.96, 73.12, 175.65, 117.44]
sofa: 0.86 [-0.11, 2.97, 639.89, 473.62]
sofa: 0.82 [-0.12, 1.78, 639.87, 473.52]
remote: 0.79 [333.65, 76.38, 370.69, 187.48]
```
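
This commit also registers RT-DETRv2 in the object-detection auto mappings (see `modeling_auto.py` below), so the high-level [`pipeline`] API should work as well. A minimal sketch using the same checkpoint as the example above:

```py
from transformers import pipeline

# Object-detection pipeline; "rt_detr_v2" resolves through the auto mapping added in this commit
detector = pipeline("object-detection", model="PekingU/rtdetr_v2_r18vd")

outputs = detector(
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    threshold=0.5,
)
# Each entry is a dict like {"score": ..., "label": ..., "box": {"xmin": ..., "ymin": ..., "xmax": ..., "ymax": ...}}
print(outputs)
```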

## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with RT-DETRv2.

<PipelineTag pipeline="object-detection"/>

- Scripts for fine-tuning [`RTDetrV2ForObjectDetection`] with [`Trainer`] or [Accelerate](https://huggingface.co/docs/accelerate/index) can be found [here](https://github.com/huggingface/transformers/tree/main/examples/pytorch/object-detection) (a minimal loading sketch follows this list).
- See also: [Object detection task guide](../tasks/object_detection).
- Notebooks for [inference](https://github.com/qubvel/transformers-notebooks/blob/main/notebooks/RT_DETR_v2_inference.ipynb) and [fine-tuning](https://github.com/qubvel/transformers-notebooks/blob/main/notebooks/RT_DETR_v2_finetune_on_a_custom_dataset.ipynb) RT-DETRv2 on a custom dataset (🌎).
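
As a starting point for fine-tuning, the model and processor can be loaded with a custom label set. A minimal sketch; the two label names below are placeholders for your own dataset:

```py
from transformers import AutoImageProcessor, AutoModelForObjectDetection

checkpoint = "PekingU/rtdetr_v2_r18vd"

# Placeholder label mapping for a hypothetical two-class dataset
id2label = {0: "helmet", 1: "vest"}
label2id = {name: idx for idx, name in id2label.items()}

image_processor = AutoImageProcessor.from_pretrained(checkpoint)
model = AutoModelForObjectDetection.from_pretrained(
    checkpoint,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,  # re-initialize the classification head for the new label count
)
# model and image_processor can then be passed to Trainer together with a dataset
# whose examples provide pixel_values and COCO-style detection annotations.
```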


## RTDetrV2Config

[[autodoc]] RTDetrV2Config


## RTDetrV2Model

[[autodoc]] RTDetrV2Model
- forward

## RTDetrV2ForObjectDetection

[[autodoc]] RTDetrV2ForObjectDetection
- forward
6 changes: 6 additions & 0 deletions src/transformers/__init__.py
@@ -748,6 +748,7 @@
"RoFormerTokenizer",
],
"models.rt_detr": ["RTDetrConfig", "RTDetrResNetConfig"],
"models.rt_detr_v2": ["RTDetrV2Config"],
"models.rwkv": ["RwkvConfig"],
"models.sam": [
"SamConfig",
@@ -3454,6 +3455,9 @@
"RTDetrResNetPreTrainedModel",
]
)
_import_structure["models.rt_detr_v2"].extend(
["RTDetrV2ForObjectDetection", "RTDetrV2Model", "RTDetrV2PreTrainedModel"]
)
_import_structure["models.rwkv"].extend(
[
"RwkvForCausalLM",
@@ -5875,6 +5879,7 @@
RTDetrConfig,
RTDetrResNetConfig,
)
from .models.rt_detr_v2 import RTDetrV2Config
from .models.rwkv import RwkvConfig
from .models.sam import (
SamConfig,
@@ -8171,6 +8176,7 @@
RTDetrResNetBackbone,
RTDetrResNetPreTrainedModel,
)
from .models.rt_detr_v2 import RTDetrV2ForObjectDetection, RTDetrV2Model, RTDetrV2PreTrainedModel
from .models.rwkv import (
RwkvForCausalLM,
RwkvModel,
10 changes: 9 additions & 1 deletion src/transformers/loss/loss_rt_detr.py
@@ -18,7 +18,6 @@

from ..utils import is_scipy_available, is_vision_available, requires_backends
from .loss_for_object_detection import (
_set_aux_loss,
box_iou,
dice_loss,
generalized_box_iou,
@@ -35,6 +34,15 @@
from transformers.image_transforms import center_to_corners_format


# different for RT-DETR: not slicing the last element like in DETR one
@torch.jit.unused
def _set_aux_loss(outputs_class, outputs_coord):
# this is a workaround to make torchscript happy, as torchscript
# doesn't support dictionary with non-homogeneous values, such
# as a dict having both a Tensor and a list.
return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)]


class RTDetrHungarianMatcher(nn.Module):
"""This class computes an assignment between the targets and the predictions of the network
1 change: 1 addition & 0 deletions src/transformers/loss/loss_utils.py
@@ -132,4 +132,5 @@ def ForTokenClassification(logits, labels, config, **kwargs):
"GroundingDinoForObjectDetection": DeformableDetrForObjectDetectionLoss,
"ConditionalDetrForSegmentation": DeformableDetrForSegmentationLoss,
"RTDetrForObjectDetection": RTDetrForObjectDetectionLoss,
"RTDetrV2ForObjectDetection": RTDetrForObjectDetectionLoss,
}
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -233,6 +233,7 @@
roc_bert,
roformer,
rt_detr,
rt_detr_v2,
rwkv,
sam,
seamless_m4t,
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -259,6 +259,7 @@
("roformer", "RoFormerConfig"),
("rt_detr", "RTDetrConfig"),
("rt_detr_resnet", "RTDetrResNetConfig"),
("rt_detr_v2", "RTDetrV2Config"),
("rwkv", "RwkvConfig"),
("sam", "SamConfig"),
("seamless_m4t", "SeamlessM4TConfig"),
@@ -600,6 +601,7 @@
("roformer", "RoFormer"),
("rt_detr", "RT-DETR"),
("rt_detr_resnet", "RT-DETR-ResNet"),
("rt_detr_v2", "RT-DETRv2"),
("rwkv", "RWKV"),
("sam", "SAM"),
("seamless_m4t", "SeamlessM4T"),
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -238,6 +238,7 @@
("roc_bert", "RoCBertModel"),
("roformer", "RoFormerModel"),
("rt_detr", "RTDetrModel"),
("rt_detr_v2", "RTDetrV2Model"),
("rwkv", "RwkvModel"),
("sam", "SamModel"),
("seamless_m4t", "SeamlessM4TModel"),
@@ -897,6 +898,7 @@
("deta", "DetaForObjectDetection"),
("detr", "DetrForObjectDetection"),
("rt_detr", "RTDetrForObjectDetection"),
("rt_detr_v2", "RTDetrV2ForObjectDetection"),
("table-transformer", "TableTransformerForObjectDetection"),
("yolos", "YolosForObjectDetection"),
]
5 changes: 2 additions & 3 deletions src/transformers/models/rt_detr/modeling_rt_detr.py
@@ -2115,9 +2115,8 @@ def forward(

loss, loss_dict, auxiliary_outputs, enc_topk_logits, enc_topk_bboxes = None, None, None, None, None
if labels is not None:
if self.training and denoising_meta_values is not None:
enc_topk_logits = outputs.enc_topk_logits if return_dict else outputs[-5]
enc_topk_bboxes = outputs.enc_topk_bboxes if return_dict else outputs[-4]
enc_topk_logits = outputs.enc_topk_logits if return_dict else outputs[-5]
enc_topk_bboxes = outputs.enc_topk_bboxes if return_dict else outputs[-4]
loss, loss_dict, auxiliary_outputs = self.loss_function(
logits,
labels,
29 changes: 29 additions & 0 deletions src/transformers/models/rt_detr_v2/__init__.py
@@ -0,0 +1,29 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
from .configuration_rt_detr_v2 import *
from .modeling_rt_detr_v2 import *
else:
import sys

_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)