Add T2T_ViT #2426

Draft: brianhou0208 wants to merge 2 commits into main
Conversation

@brianhou0208 (Contributor) commented Jan 22, 2025

Hi @rwightman, this PR resolves #2364, please check.

Result

Tested the T2T-ViT models and weights on the ImageNet validation set:

| Model | Acc@1 | Acc@5 | FLOPs (G) | MACs (G) | Params (M) |
|---|---|---|---|---|---|
| t2t_vit_7 | 71.6760 | 90.8860 | 2.0261 | 0.9755 | 4.2557 |
| t2t_vit_10 | 75.1500 | 92.8060 | 2.6476 | 1.2854 | 5.8347 |
| t2t_vit_12 | 76.4800 | 93.4840 | 3.0620 | 1.492 | 6.8874 |
| t2t_vit_14 | 81.5000 | 95.6660 | 8.7526 | 4.334 | 21.4658 |
| t2t_vit_19 | 81.9320 | 95.7440 | 15.6663 | 7.7868 | 39.0851 |
| t2t_vit_24 | 82.2760 | 95.8860 | 25.4543 | 12.6759 | 64.0010 |
| t2t_vit_t_14 | 81.6880 | 95.8520 | 8.6881 | 4.334 | 21.4654 |
| t2t_vit_t_19 | 82.4420 | 96.0820 | 15.6018 | 7.7868 | 39.0847 |
| t2t_vit_t_24 | 82.5540 | 96.0640 | 25.3898 | 12.6759 | 64.0006 |
Test code:
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import timm
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.utils.metrics import AverageMeter, accuracy

device = torch.device('cuda:0')

if __name__ == "__main__":
    val_dataset = datasets.ImageFolder(
        './data/val',
        transforms.Compose([
            transforms.Resize(int(224 / 0.9), interpolation=3),  # 3 = PIL bicubic
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)])
    )

    val_loader = DataLoader(
        val_dataset, batch_size=256, shuffle=False, num_workers=16, pin_memory=True)
    
    for name in timm.list_models('t2t_vit*'):
        model = timm.create_model(name, pretrained=True).eval()
        model.to(device)
        top1 = AverageMeter()
        top5 = AverageMeter()

        with torch.no_grad():
            for images, target in tqdm(val_loader):
                images = images.to(device)
                target = target.to(device)
                output = model(images)
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                top1.update(acc1, images.size(0))
                top5.update(acc5, images.size(0))
        print(f"Model {name} ACC@1 {top1.avg:.4f} ACC@5 {top5.avg:.4f}")
Output log:
100%|██████████████████████████████████████████████| 196/196 [00:39<00:00,  4.92it/s]
Model t2t_vit_7 ACC@1 71.6760 ACC@5 90.8860
FLOPs 2.0261 GFLOPS / MACs 975.534 MMACs / Params 4.2557 M

100%|██████████████████████████████████████████████| 196/196 [00:39<00:00,  4.96it/s]
Model t2t_vit_10 ACC@1 75.1500 ACC@5 92.8060
FLOPs 2.6476 GFLOPS / MACs 1.2854 GMACs / Params 5.8347 M

100%|██████████████████████████████████████████████| 196/196 [00:40<00:00,  4.88it/s]
Model t2t_vit_12 ACC@1 76.4800 ACC@5 93.4840
FLOPs 3.062 GFLOPS / MACs 1.492 GMACs / Params 6.8874 M

100%|██████████████████████████████████████████████| 196/196 [01:08<00:00,  2.87it/s]
Model t2t_vit_14 ACC@1 81.5000 ACC@5 95.6660
FLOPs 8.7526 GFLOPS / MACs 4.334 GMACs / Params 21.4658 M

100%|██████████████████████████████████████████████| 196/196 [01:45<00:00,  1.86it/s]
Model t2t_vit_19 ACC@1 81.9320 ACC@5 95.7440
FLOPs 15.6663 GFLOPS / MACs 7.7868 GMACs / Params 39.0851 M

100%|██████████████████████████████████████████████| 196/196 [02:31<00:00,  1.30it/s]
Model t2t_vit_24 ACC@1 82.2760 ACC@5 95.8860
FLOPs 25.4543 GFLOPS / MACs 12.6759 GMACs / Params 64.001 M

100%|██████████████████████████████████████████████| 196/196 [01:28<00:00,  2.20it/s]
Model t2t_vit_t_14 ACC@1 81.6880 ACC@5 95.8520
FLOPs 8.6881 GFLOPS / MACs 4.334 GMACs / Params 21.4654 M

100%|██████████████████████████████████████████████| 196/196 [02:04<00:00,  1.57it/s]
Model t2t_vit_t_19 ACC@1 82.4420 ACC@5 96.0820
FLOPs 15.6018 GFLOPS / MACs 7.7868 GMACs / Params 39.0847 M

100%|██████████████████████████████████████████████| 196/196 [02:51<00:00,  1.15it/s]
Model t2t_vit_t_24 ACC@1 82.5540 ACC@5 96.0640
FLOPs 25.3898 GFLOPS / MACs 12.6759 GMACs / Params 64.0006 M
FLOPs/MACs/Params calculation tool

Report from calflops:

from calflops import calculate_flops
def flops_param(model):
    flops, macs, params = calculate_flops(
        model=model,
        input_shape=(1, 3, 224, 224),
        output_as_string=True,
        output_precision=4,
        print_detailed=False,
        print_results=False
    )
    print(f"FLOPs {flops} / MACs {macs} / Params {params}")

Reference

Paper: https://arxiv.org/pdf/2101.11986
Code: https://github.com/yitu-opensource/T2T-ViT

brianhou0208 marked this pull request as draft on January 23, 2025, 16:06
@rwightman (Collaborator) commented:

@brianhou0208 thanks for the work, and it looks like a good job getting it into shape. I took a closer look using your code, but I have some doubts about this model:

  1. It requires a workaround with AMP + float16 to avoid NaN (see next post).
  2. Compared to simpler models it's really not performing better given the speed, especially against these: https://huggingface.co/collections/timm/searching-for-better-vit-baselines-663eb74f64f847d2f35a9c19 . They are faster with better accuracy at a fraction of the param count, and they have fewer MACs/activations. Even some models that have been around longer, like deit3 (e.g. deit3_medium_patch16_224), are faster/simpler/smaller than these.

For speed comparisons I disabled F.sdpa in the existing vits to be fair. Simpler vits with higher accuracy (ImageNet-1k pretrain, also to be fair) are often 30-40% faster.

So I'm not convinced this is worth the add. Was there a particular reason you were interested in the model?
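
For anyone wanting to reproduce that kind of speed comparison, here is a rough sketch (assumptions: a CUDA device, and that timm's set_fused_attn helper from timm.layers is available to turn off F.scaled_dot_product_attention; this is not the exact script used for the numbers above):

import torch
import timm
from timm.layers import set_fused_attn

set_fused_attn(False)  # fall back to the non-fused attention path in existing vits

x = torch.randn(64, 3, 224, 224, device='cuda')
for name in ('deit3_medium_patch16_224', 't2t_vit_14'):  # t2t_vit_14 assumes this PR is applied
    model = timm.create_model(name, pretrained=False).cuda().eval()
    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    with torch.no_grad():
        model(x)  # warmup
        torch.cuda.synchronize()
        start.record()
        for _ in range(20):
            model(x)
        end.record()
    torch.cuda.synchronize()
    print(f'{name}: {start.elapsed_time(end) / 20:.1f} ms / iter')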

@rwightman (Collaborator) commented:

    def single_attn(self, x: torch.Tensor) -> torch.Tensor:
        k, q, v = torch.split(self.kqv(x), self.emb, dim=-1)

        if not torch.jit.is_scripting():
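            # run the attention math in float32 (autocast disabled) to avoid the NaNs seen under AMP / float16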
            with torch.autocast(device_type=v.device.type, enabled=False):
                y = self._attn_impl(k, q, v)
        else:
            y = self._attn_impl(k, q, v)

        # skip connection
        y = v + self.dp(self.proj(y))  # same as token_transformer in T2T layer, use v as skip connection
        return y

    def _attn_impl(self, k, q, v):
        kp, qp = self.prm_exp(k), self.prm_exp(q)  # (B, T, m), (B, T, m)
        D = torch.einsum('bti,bi->bt', qp, kp.sum(dim=1)).unsqueeze(dim=2)  # (B, T, m) * (B, m) -> (B, T, 1)
        kptv = torch.einsum('bin,bim->bnm', v.float(), kp)  # (B, emb, m)
        y = torch.einsum('bti,bni->btn', qp, kptv) / (D.repeat(1, 1, self.emb) + self.epsilon)  # (B, T, emb)/Diag
        return y
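
A quick way to see why the fp32 fallback is needed is to run the model under autocast and check the output for NaNs. This is a hypothetical sketch (not from the PR); it assumes a CUDA device and the t2t_vit_14 model name added by this PR:

import torch
import timm

model = timm.create_model('t2t_vit_14', pretrained=False).cuda().eval()
x = torch.randn(2, 3, 224, 224, device='cuda')
with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16):
    out = model(x)
# expected False with the fp32 fallback above; NaNs reportedly appear without it
print('any NaN in output:', torch.isnan(out).any().item())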

@brianhou0208 (Contributor, Author) commented:

Hi @rwightman, I agree with your observation. The T2T-ViT model doesn't have a real advantage over other models. The only distinctive point might be that it doesn't use nn.Conv2d at all, relying on nn.Unfold to extract patches instead.
Most ViT-based models require some form of convolution for input processing, but the T2T-ViT architecture bypasses convolution entirely; maybe that aspect could be explored further...
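
For context, here is a minimal sketch (not from the PR) showing that an nn.Unfold + nn.Linear patch embedding computes the same thing as an nn.Conv2d whose kernel size equals its stride, which is why the conv-free formulation is possible at all:

import torch
import torch.nn as nn

B, C, H, W, P, D = 2, 3, 224, 224, 16, 192  # illustrative sizes
x = torch.randn(B, C, H, W)

conv = nn.Conv2d(C, D, kernel_size=P, stride=P)
linear = nn.Linear(C * P * P, D)
# copy the conv weights into the linear layer so the two paths are numerically comparable
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(D, -1))
    linear.bias.copy_(conv.bias)

y_conv = conv(x).flatten(2).transpose(1, 2)                                # (B, N, D)
y_unfold = linear(nn.Unfold(kernel_size=P, stride=P)(x).transpose(1, 2))  # (B, N, D)
print(torch.allclose(y_conv, y_unfold, atol=1e-5))  # True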

Another issue comes up when loading pretrained weights and testing whether first_conv adapts to the input shape (C, H, W). If first_conv is set to None, the test_model_default_cfgs_non_std test fails:

first_conv = cfg['first_conv']
if isinstance(first_conv, str):
    first_conv = (first_conv,)
assert isinstance(first_conv, (tuple, list))
for fc in first_conv:
    assert fc + ".weight" in state_dict.keys(), f'{fc} not in model params'

In test_model_load_pretrained, if first_conv points to an nn.Linear layer rather than an nn.Conv2d (as in T2T-ViT, which has no conv), passing its weight to adapt_input_conv also raises an error:
input_convs = pretrained_cfg.get('first_conv', None)
if input_convs is not None and in_chans != 3:
    if isinstance(input_convs, str):
        input_convs = (input_convs,)
    for input_conv_name in input_convs:
        weight_name = input_conv_name + '.weight'
        try:
            state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name])
            _logger.info(
                f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)')
        except NotImplementedError as e:
            del state_dict[weight_name]
            strict = False
            _logger.warning(
                f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')

Since fixing this would involve modifying test_models, and adding T2T-ViT may not be worth the effort, I should probably close this PR.

@rwightman (Collaborator) commented Jan 24, 2025

@brianhou0208 I don't know that not having the input conv is a 'feature'. My very first vit impl here, before the official JAX code that used the Conv2d trick was released, was this:

class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    Unfold image into fixed size patches, flatten into seq, project to embedding dim.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, flatten_channels_last=False):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        assert img_size[0] % patch_size[0] == 0, 'image height must be divisible by the patch height'
        assert img_size[1] % patch_size[1] == 0, 'image width must be divisible by the patch width'
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        patch_dim = in_chans * patch_size[0] * patch_size[1]
        self.img_size = img_size
        self.patch_size = patch_size
        self.flatten_channels_last = flatten_channels_last
        self.num_patches = num_patches
        self.proj = nn.Linear(patch_dim, embed_dim)

    def forward(self, x):
        B, C, H, W = x.shape
        Ph, Pw = self.patch_size
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        if self.flatten_channels_last:
            # flatten patches with channels last like the paper (likely using TF)
            x = x.unfold(2, Ph, Ph).unfold(3, Pw, Pw).permute(0, 2, 3, 4, 5, 1).reshape(B, -1, Ph * Pw * C)
        else:
            x = x.permute(0, 2, 3, 1).unfold(1, Ph, Ph).unfold(2, Pw, Pw).reshape(B, -1, C * Ph * Pw)
        x = self.proj(x)
        return x

The conv approach was faster since it's a single optimized kernel rather than a chain of API calls. I suppose torch.compile would rectify most of that, but I still don't see the downside to the conv.
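
As a rough illustration of that point (my sketch, not rwightman's benchmark; assumes a CUDA device), the two patchify paths can be timed like this:

import time

import torch
import torch.nn as nn

x = torch.randn(64, 3, 224, 224, device='cuda')
conv = nn.Conv2d(3, 768, kernel_size=16, stride=16).cuda()
unfold = nn.Unfold(kernel_size=16, stride=16)
linear = nn.Linear(3 * 16 * 16, 768).cuda()

def bench(fn, iters=50):
    # simple wall-clock timing with a CUDA sync before and after
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters * 1000  # ms per iteration

with torch.no_grad():
    print('conv patchify   ', bench(lambda: conv(x).flatten(2).transpose(1, 2)))
    print('unfold + linear ', bench(lambda: linear(unfold(x).transpose(1, 2))))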

Also, the packed vit I started working on (have yet to pick it back up) has to push patchification even further into the data pipeline: https://github.com/huggingface/pytorch-image-models/blob/379780bb6ca3304d63bf8ca789d5bbce5949d0b5/timm/models/vision_transformer_packed.py

Merging this pull request may close: [FEATURE] add t2t_vit (#2364)