Add T2T_ViT #2426
Conversation
@brianhou0208 thanks for the work, and it looks like you did a good job getting it in shape. I took a closer look using your code, but I have some doubts about this model.
For speed comparisons I disabled F.sdpa in the existing vit to be fair. Simpler vits with higher accuracy (imagenet-1k pretrain, also to be fair) are often 30-40% faster. So I'm not convinced this is worth the add. Was there a particular reason you had interest in the model?
```python
def single_attn(self, x: torch.Tensor) -> torch.Tensor:
    k, q, v = torch.split(self.kqv(x), self.emb, dim=-1)
    if not torch.jit.is_scripting():
        with torch.autocast(device_type=v.device.type, enabled=False):
            y = self._attn_impl(k, q, v)
    else:
        y = self._attn_impl(k, q, v)
    # skip connection, same as token_transformer in T2T layer: use v as the residual
    y = v + self.dp(self.proj(y))
    return y

def _attn_impl(self, k, q, v):
    kp, qp = self.prm_exp(k), self.prm_exp(q)  # (B, T, m), (B, T, m)
    D = torch.einsum('bti,bi->bt', qp, kp.sum(dim=1)).unsqueeze(dim=2)  # (B, T, m) * (B, m) -> (B, T, 1)
    kptv = torch.einsum('bin,bim->bnm', v.float(), kp)  # (B, emb, m)
    y = torch.einsum('bti,bni->btn', qp, kptv) / (D.repeat(1, 1, self.emb) + self.epsilon)  # (B, T, emb) / diag
    return y
```
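A minimal standalone sketch (not the PR's exact code) of the shape flow through those Performer-style linear-attention einsums, using random tensors in place of `prm_exp` outputs; the dimension names and `1e-8` epsilon here are illustrative assumptions. Note that since `D` has shape `(B, T, 1)`, broadcasting makes the explicit `repeat` unnecessary:

```python
import torch

B, T, emb, m = 2, 16, 64, 32          # batch, tokens, embed dim, num random features
qp = torch.rand(B, T, m)              # stand-in for prm_exp(q): positive features
kp = torch.rand(B, T, m)              # stand-in for prm_exp(k)
v = torch.rand(B, T, emb)

# normalizer: qp dotted with the feature-wise sum of kp -> (B, T, 1)
D = torch.einsum('bti,bi->bt', qp, kp.sum(dim=1)).unsqueeze(dim=2)
# contract values against key features -> (B, emb, m)
kptv = torch.einsum('bin,bim->bnm', v, kp)
# attend in feature space and normalize -> (B, T, emb)
y = torch.einsum('bti,bni->btn', qp, kptv) / (D + 1e-8)

print(y.shape)  # torch.Size([2, 16, 64])
```

This computes softmax-free attention in O(T·m·emb) rather than the O(T²) of standard attention, which is the motivation for the design, though as noted above it does not translate into a wall-clock win over simpler vits.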
Hi @rwightman, I agree with your observation. The T2T-ViT model does not have advantages over other models; the only advantage might be that it does not use any convolution layers. Another issue occurs when using pre-trained weights and testing whether the structure of first_conv is adaptive to the number of inputs (C, H, W). See pytorch-image-models/tests/test_models.py Lines 371 to 376 in d81da93.
In test_model_load_pretrained, if first_conv is, as in T2T-ViT, not a Conv, passing this parameter to nn.Linear instead of nn.Conv2d will also report an error. See pytorch-image-models/timm/models/_builder.py Lines 225 to 239 in d81da93.
Since this involves modifying
@brianhou0208 I don't know if not having the input conv is a 'feature'. My very first vit impl here, before the official JAX code (which used the Conv2D trick) was released, was this: pytorch-image-models/timm/models/vision_transformer.py Lines 139 to 169 in 7613094
The conv approach was faster since it was an optimized kernel rather than a chain of API calls. I suppose torch.compile would rectify most of that, but I still don't see the downside to the conv. Also, the packed vit I started working on (and have yet to pick back up) has to push patchification further into the data pipeline: https://github.com/huggingface/pytorch-image-models/blob/379780bb6ca3304d63bf8ca789d5bbce5949d0b5/timm/models/vision_transformer_packed.py
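To make the "Conv2D trick" concrete, here is a hedged sketch (dimension names are illustrative, not from the PR) showing that a strided `nn.Conv2d` patch embedding computes the same projection as an explicit `F.unfold` followed by `nn.Linear`, once the weights are shared; the conv simply fuses the two steps into one optimized kernel:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

B, C, H, W, P, D = 2, 3, 8, 8, 4, 16   # batch, channels, height, width, patch size, embed dim

conv = nn.Conv2d(C, D, kernel_size=P, stride=P)
linear = nn.Linear(C * P * P, D)
# share weights: conv weight (D, C, P, P) flattens to the linear weight (D, C*P*P)
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(D, -1))
    linear.bias.copy_(conv.bias)

x = torch.rand(B, C, H, W)
out_conv = conv(x).flatten(2).transpose(1, 2)                 # (B, N, D), N = (H/P)*(W/P)
patches = F.unfold(x, kernel_size=P, stride=P).transpose(1, 2)  # (B, N, C*P*P)
out_linear = linear(patches)

print(torch.allclose(out_conv, out_linear, atol=1e-5))  # True
```

The unfold+linear path is closer to what a conv-free model like T2T-ViT does internally, which is why the two give identical outputs but different speeds.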
Hi @rwightman, this PR resolves #2364, please check.
Result
Tested the T2T-ViT models and weights on the ImageNet val dataset.
test code
output log
calculate FLOPs/MACs/Params tool
report from calflops
Reference
paper: https://arxiv.org/pdf/2101.11986
code: https://github.com/yitu-opensource/T2T-ViT