Adding RTDETRv2 #34773

Merged · 59 commits · Feb 6, 2025
Commits
457ec3c
cookiecutter add rtdetrv2
Nov 15, 2024
53dd35b
make modular working
Nov 15, 2024
0dcc0c5
working modelgit add .
Nov 16, 2024
13ed8e5
working modelgit add .
Nov 16, 2024
f1cf051
finalize moduar inheritence
Nov 16, 2024
40e4c20
finalize moduar inheritence
Nov 16, 2024
52e9d8f
Update src/transformers/models/rtdetrv2/modular_rtdetrv2.py
jadechoghari Nov 21, 2024
1cb8414
update modular and add rename
jadechoghari Nov 21, 2024
dae9513
remove output ckpt
jadechoghari Nov 21, 2024
e940d27
define loss_kwargs
jadechoghari Nov 21, 2024
45349a6
fix CamelCase naming
jadechoghari Nov 23, 2024
cd6ffca
fix naming + files
Nov 27, 2024
d762677
fix modular and convert file
jadechoghari Dec 27, 2024
03bb2cc
additional changes
jadechoghari Dec 27, 2024
6135be4
Merge branch 'main' into rtdetrv2
jadechoghari Dec 27, 2024
350cafe
fix modular
jadechoghari Dec 29, 2024
db724d9
fix import error (switch to lazy)
jadechoghari Dec 29, 2024
8d37be4
fix autobackbone
jadechoghari Dec 29, 2024
d773a97
make style
jadechoghari Dec 29, 2024
35f8db7
add
jadechoghari Dec 30, 2024
bc5c5b0
update testing
jadechoghari Dec 30, 2024
859f367
fix loss
jadechoghari Jan 1, 2025
cfb9ea4
remove old folder
jadechoghari Jan 1, 2025
9a9e9f1
fix testing for v2
jadechoghari Jan 1, 2025
a29eed2
update docstring
jadechoghari Jan 2, 2025
7e1d4dd
fix docstring
jadechoghari Jan 2, 2025
032b916
add resnetv2 (with modular bug to fix)
jadechoghari Jan 3, 2025
5e00df2
remove resnetv2 backbone
jadechoghari Jan 6, 2025
f212548
fix changes
jadechoghari Jan 10, 2025
d868782
small fixes
jadechoghari Jan 11, 2025
b267a20
remove rtdetrv2resnetconfig
jadechoghari Jan 11, 2025
42b7ca5
Merge branch 'main' into rtdetrv2
jadechoghari Jan 11, 2025
d766434
add rtdetrv2 name to convert
jadechoghari Jan 11, 2025
d12e03a
make style
jadechoghari Jan 11, 2025
b2cea40
Update docs/source/en/model_doc/rt_detr_v2.md
jadechoghari Jan 14, 2025
8ed6f75
Update src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py
jadechoghari Jan 14, 2025
d296af1
Update src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py
jadechoghari Jan 14, 2025
73274c9
fix modular typo after review
Jan 14, 2025
60a74d3
add reviewed changes
jadechoghari Jan 18, 2025
1f4c082
add final review changes
jadechoghari Jan 20, 2025
959ee76
Update docs/source/en/model_doc/rt_detr_v2.md
jadechoghari Jan 22, 2025
cef9935
Update src/transformers/models/rt_detr_v2/__init__.py
jadechoghari Jan 22, 2025
77f455e
Update src/transformers/models/rt_detr_v2/convert_rt_detr_v2_weights_…
jadechoghari Jan 22, 2025
f56f195
add review changes
jadechoghari Jan 22, 2025
859e37a
remove rtdetrv2 resnet
jadechoghari Jan 22, 2025
2869cf3
removing this weird project change
jadechoghari Jan 31, 2025
152d7df
Merge branch 'main' into rtdetrv2
jadechoghari Jan 31, 2025
75ac690
change ckpt name from jadechoghari to author
jadechoghari Feb 1, 2025
6bfdc76
Merge branch 'main' into rtdetrv2
jadechoghari Feb 3, 2025
fed0d54
implement review and update testing
jadechoghari Feb 4, 2025
8d93109
Merge branch 'rtdetrv2' of https://github.com/jadechoghari/transforme…
jadechoghari Feb 4, 2025
cb7bf37
update naming and remove wrong ckpt
jadechoghari Feb 5, 2025
83711d4
name
jadechoghari Feb 5, 2025
d306d92
make fix-copies
jadechoghari Feb 5, 2025
f344181
Merge branch 'main' into rtdetrv2
jadechoghari Feb 6, 2025
ade4e3f
Fix RT-DETR loss
qubvel Feb 6, 2025
dd3c155
Add resources, fix name
qubvel Feb 6, 2025
aedcf12
Fix repo in docs
qubvel Feb 6, 2025
13b1c47
Fix table name
qubvel Feb 6, 2025
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -709,6 +709,8 @@
title: ResNet
- local: model_doc/rt_detr
title: RT-DETR
- local: model_doc/rt_detr_v2
title: RT-DETRv2
- local: model_doc/segformer
title: SegFormer
- local: model_doc/seggpt
1 change: 1 addition & 0 deletions docs/source/en/index.md
@@ -305,6 +305,7 @@ Flax), PyTorch, and/or TensorFlow.
| [RoFormer](model_doc/roformer) | ✅ | ✅ | ✅ |
| [RT-DETR](model_doc/rt_detr) | ✅ | ❌ | ❌ |
| [RT-DETR-ResNet](model_doc/rt_detr_resnet) | ✅ | ❌ | ❌ |
| [RT-DETRv2](model_doc/rt_detr_v2) | ✅ | ❌ | ❌ |
| [RWKV](model_doc/rwkv) | ✅ | ❌ | ❌ |
| [SAM](model_doc/sam) | ✅ | ✅ | ❌ |
| [SeamlessM4T](model_doc/seamless_m4t) | ✅ | ❌ | ❌ |
97 changes: 97 additions & 0 deletions docs/source/en/model_doc/rt_detr_v2.md
@@ -0,0 +1,97 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# RT-DETRv2

## Overview

The RT-DETRv2 model was proposed in [RT-DETRv2: Improved Baseline with Bag-of-Freebies for Real-Time Detection Transformer](https://arxiv.org/abs/2407.17140) by Wenyu Lv, Yian Zhao, Qinyao Chang, Kui Huang, Guanzhong Wang, Yi Liu.

RT-DETRv2 refines RT-DETR by introducing selective multi-scale feature extraction, a discrete sampling operator for broader deployment compatibility, and improved training strategies like dynamic data augmentation and scale-adaptive hyperparameters. These changes enhance flexibility and practicality while maintaining real-time performance.

The abstract from the paper is the following:

*In this report, we present RT-DETRv2, an improved Real-Time DEtection TRansformer (RT-DETR). RT-DETRv2 builds upon the previous state-of-the-art real-time detector, RT-DETR, and opens up a set of bag-of-freebies for flexibility and practicality, as well as optimizing the training strategy to achieve enhanced performance. To improve the flexibility, we suggest setting a distinct number of sampling points for features at different scales in the deformable attention to achieve selective multi-scale feature extraction by the decoder. To enhance practicality, we propose an optional discrete sampling operator to replace the grid_sample operator that is specific to RT-DETR compared to YOLOs. This removes the deployment constraints typically associated with DETRs. For the training strategy, we propose dynamic data augmentation and scale-adaptive hyperparameters customization to improve performance without loss of speed.*

This model was contributed by [jadechoghari](https://huggingface.co/jadechoghari).
The original code can be found [here](https://github.com/lyuwenyu/RT-DETR).

## Usage tips

This second version of RT-DETR improves how the decoder finds objects in an image.

- **better sampling** – adjusts deformable-attention offsets, with a distinct number of sampling points per feature scale, so the model looks at the right areas
- **flexible attention** – can use smooth (bilinear) or fixed (discrete) sampling; see the configuration sketch after the example below
- **optimized processing** – improves how attention weights mix the sampled features

```py
>>> import torch
>>> import requests

>>> from PIL import Image
>>> from transformers import RTDetrV2ForObjectDetection, RTDetrImageProcessor

>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_v2_r18vd")
>>> model = RTDetrV2ForObjectDetection.from_pretrained("PekingU/rtdetr_v2_r18vd")

>>> inputs = image_processor(images=image, return_tensors="pt")

>>> with torch.no_grad():
... outputs = model(**inputs)

>>> results = image_processor.post_process_object_detection(outputs, target_sizes=torch.tensor([(image.height, image.width)]), threshold=0.5)

>>> for result in results:
... for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"]):
... score, label = score.item(), label_id.item()
... box = [round(i, 2) for i in box.tolist()]
... print(f"{model.config.id2label[label]}: {score:.2f} {box}")
cat: 0.97 [341.14, 25.11, 639.98, 372.89]
cat: 0.96 [12.78, 56.35, 317.67, 471.34]
remote: 0.95 [39.96, 73.12, 175.65, 117.44]
sofa: 0.86 [-0.11, 2.97, 639.89, 473.62]
sofa: 0.82 [-0.12, 1.78, 639.87, 473.52]
remote: 0.79 [333.65, 76.38, 370.69, 187.48]
```
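
The discrete sampling mode mentioned above is selected through the model configuration. Below is a minimal sketch; the attribute name `decoder_method` and the value `"discrete"` are assumptions for illustration, so check the [`RTDetrV2Config`] docstring for the exact names and allowed values.

```py
from transformers import RTDetrV2Config, RTDetrV2ForObjectDetection

# Assumed switch for fixed (discrete) sampling instead of bilinear grid_sample;
# verify the attribute name and its allowed values in RTDetrV2Config first.
config = RTDetrV2Config.from_pretrained("PekingU/rtdetr_v2_r18vd")
config.decoder_method = "discrete"

model = RTDetrV2ForObjectDetection.from_pretrained("PekingU/rtdetr_v2_r18vd", config=config)
```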

## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with RT-DETRv2.

<PipelineTag pipeline="object-detection"/>

- Scripts for finetuning [`RTDetrV2ForObjectDetection`] with [`Trainer`] or [Accelerate](https://huggingface.co/docs/accelerate/index) can be found [here](https://github.com/huggingface/transformers/tree/main/examples/pytorch/object-detection).
- See also: [Object detection task guide](../tasks/object_detection).
- Notebooks for [inference](https://github.com/qubvel/transformers-notebooks/blob/main/notebooks/RT_DETR_v2_inference.ipynb) and [fine-tuning](https://github.com/qubvel/transformers-notebooks/blob/main/notebooks/RT_DETR_v2_finetune_on_a_custom_dataset.ipynb) RT-DETRv2 on a custom dataset (🌎).


## RTDetrV2Config

[[autodoc]] RTDetrV2Config


## RTDetrV2Model

[[autodoc]] RTDetrV2Model
- forward

## RTDetrV2ForObjectDetection

[[autodoc]] RTDetrV2ForObjectDetection
- forward
6 changes: 6 additions & 0 deletions src/transformers/__init__.py
@@ -748,6 +748,7 @@
"RoFormerTokenizer",
],
"models.rt_detr": ["RTDetrConfig", "RTDetrResNetConfig"],
"models.rt_detr_v2": ["RTDetrV2Config"],
"models.rwkv": ["RwkvConfig"],
"models.sam": [
"SamConfig",
@@ -3454,6 +3455,9 @@
"RTDetrResNetPreTrainedModel",
]
)
_import_structure["models.rt_detr_v2"].extend(
["RTDetrV2ForObjectDetection", "RTDetrV2Model", "RTDetrV2PreTrainedModel"]
)
_import_structure["models.rwkv"].extend(
[
"RwkvForCausalLM",
@@ -5875,6 +5879,7 @@
RTDetrConfig,
RTDetrResNetConfig,
)
from .models.rt_detr_v2 import RTDetrV2Config
from .models.rwkv import RwkvConfig
from .models.sam import (
SamConfig,
@@ -8171,6 +8176,7 @@
RTDetrResNetBackbone,
RTDetrResNetPreTrainedModel,
)
from .models.rt_detr_v2 import RTDetrV2ForObjectDetection, RTDetrV2Model, RTDetrV2PreTrainedModel
from .models.rwkv import (
RwkvForCausalLM,
RwkvModel,
10 changes: 9 additions & 1 deletion src/transformers/loss/loss_rt_detr.py
@@ -18,7 +18,6 @@

from ..utils import is_scipy_available, is_vision_available, requires_backends
from .loss_for_object_detection import (
-    _set_aux_loss,
box_iou,
dice_loss,
generalized_box_iou,
@@ -35,6 +34,15 @@
from transformers.image_transforms import center_to_corners_format


# different for RT-DETR: not slicing the last element like in DETR one
@torch.jit.unused
def _set_aux_loss(outputs_class, outputs_coord):
# this is a workaround to make torchscript happy, as torchscript
# doesn't support dictionary with non-homogeneous values, such
# as a dict having both a Tensor and a list.
return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)]


class RTDetrHungarianMatcher(nn.Module):
"""This class computes an assignment between the targets and the predictions of the network

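A minimal sketch of what the `_set_aux_loss` variant above builds, using made-up tensor shapes: every auxiliary decoder layer keeps its own `logits`/`pred_boxes` dict, and unlike the DETR helper the last element is not sliced off.

```py
import torch

# Illustrative shapes only: (num_aux_layers, batch, num_queries, ...)
outputs_class = torch.randn(3, 2, 300, 80)
outputs_coord = torch.rand(3, 2, 300, 4)

aux_outputs = [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)]
print(len(aux_outputs))                # 3 — one dict per auxiliary decoder layer
print(aux_outputs[0]["logits"].shape)  # torch.Size([2, 300, 80])
```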
1 change: 1 addition & 0 deletions src/transformers/loss/loss_utils.py
@@ -132,4 +132,5 @@ def ForTokenClassification(logits, labels, config, **kwargs):
"GroundingDinoForObjectDetection": DeformableDetrForObjectDetectionLoss,
"ConditionalDetrForSegmentation": DeformableDetrForSegmentationLoss,
"RTDetrForObjectDetection": RTDetrForObjectDetectionLoss,
"RTDetrV2ForObjectDetection": RTDetrForObjectDetectionLoss,
}
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -233,6 +233,7 @@
roc_bert,
roformer,
rt_detr,
rt_detr_v2,
rwkv,
sam,
seamless_m4t,
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -259,6 +259,7 @@
("roformer", "RoFormerConfig"),
("rt_detr", "RTDetrConfig"),
("rt_detr_resnet", "RTDetrResNetConfig"),
("rt_detr_v2", "RTDetrV2Config"),
("rwkv", "RwkvConfig"),
("sam", "SamConfig"),
("seamless_m4t", "SeamlessM4TConfig"),
@@ -600,6 +601,7 @@
("roformer", "RoFormer"),
("rt_detr", "RT-DETR"),
("rt_detr_resnet", "RT-DETR-ResNet"),
("rt_detr_v2", "RT-DETRv2"),
("rwkv", "RWKV"),
("sam", "SAM"),
("seamless_m4t", "SeamlessM4T"),
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -238,6 +238,7 @@
("roc_bert", "RoCBertModel"),
("roformer", "RoFormerModel"),
("rt_detr", "RTDetrModel"),
("rt_detr_v2", "RTDetrV2Model"),
("rwkv", "RwkvModel"),
("sam", "SamModel"),
("seamless_m4t", "SeamlessM4TModel"),
@@ -897,6 +898,7 @@
("deta", "DetaForObjectDetection"),
("detr", "DetrForObjectDetection"),
("rt_detr", "RTDetrForObjectDetection"),
("rt_detr_v2", "RTDetrV2ForObjectDetection"),
("table-transformer", "TableTransformerForObjectDetection"),
("yolos", "YolosForObjectDetection"),
]
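With these mapping entries in place, the model is also reachable through the Auto API. A minimal sketch, reusing the checkpoint from the docs above:

```py
from transformers import AutoModelForObjectDetection

# The "rt_detr_v2" entry above routes this checkpoint to RTDetrV2ForObjectDetection.
model = AutoModelForObjectDetection.from_pretrained("PekingU/rtdetr_v2_r18vd")
print(type(model).__name__)  # RTDetrV2ForObjectDetection
```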
5 changes: 2 additions & 3 deletions src/transformers/models/rt_detr/modeling_rt_detr.py
@@ -2115,9 +2115,8 @@ def forward(

loss, loss_dict, auxiliary_outputs, enc_topk_logits, enc_topk_bboxes = None, None, None, None, None
if labels is not None:
-            if self.training and denoising_meta_values is not None:
-                enc_topk_logits = outputs.enc_topk_logits if return_dict else outputs[-5]
-                enc_topk_bboxes = outputs.enc_topk_bboxes if return_dict else outputs[-4]
+            enc_topk_logits = outputs.enc_topk_logits if return_dict else outputs[-5]
+            enc_topk_bboxes = outputs.enc_topk_bboxes if return_dict else outputs[-4]
loss, loss_dict, auxiliary_outputs = self.loss_function(
logits,
labels,
Expand Down
29 changes: 29 additions & 0 deletions src/transformers/models/rt_detr_v2/__init__.py
@@ -0,0 +1,29 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
from .configuration_rt_detr_v2 import *
from .modeling_rt_detr_v2 import *
else:
import sys

_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)