Support torch.compile()? #101
We do support torch.jit to some degree. It's similar to torch.compile. See also https://github.com/v0lta/PyTorch-Wavelet-Toolbox/blob/main/tests/test_jit.py.
Sincere thanks for your prompt response. In my opinion, `torch.compile` is more convenient and faster than `torch.jit`. I believe that once `ptwt` supports `torch.compile`, it will be even more convenient and user-friendly for the community.
We haven't tested it. Does a working torch.jit test imply that torch.compile works as well?

`torch.compile()` indicates that some operations in `ptwt` are not supported.
Could you provide an example of how to apply learnable wavelet transforms to a 3D tensor? The input has the shape [B, L, C], where B is the batch size, L the length, and C the number of channels. I want to perform the wavelet transform along the L dimension.
Sure:

```python
from tqdm import tqdm
import ptwt
import torch
from ptwt.wavelets_learnable import ProductFilter

torch.manual_seed(42)
# Random 4D input; wavedec3 transforms the last three axes.
aten = torch.randn(32, 32, 32, 32)
# Learnable wavelet initialized with four random length-4 filters.
wavelet = ProductFilter(torch.randn(4), torch.randn(4),
                        torch.randn(4), torch.randn(4))
opt = torch.optim.RMSprop(wavelet.parameters(), lr=0.01)
for _i in (bar := tqdm(range(5000))):
    # Analysis followed by synthesis with the same learnable filters.
    res = ptwt.waverec3(ptwt.wavedec3(aten, wavelet, level=4), wavelet)
    # Reconstruction error plus the filter penalty from wavelet_loss().
    cost = torch.mean((res - aten) ** 2) + wavelet.wavelet_loss()
    cost.backward()
    opt.step()
    opt.zero_grad()
    bar.set_description(f"cost: {cost.item():2.4e}")
```

Pass a tensor of shape [1, B, L, C] into `wavedec3` if you want to transform the batch dimension as well.
Hi, thank you for providing this toolbox. Is there any future work planned toward supporting torch.compile() in ptwt?