Support torch.compile()? #101

Open
lijun2005 opened this issue Aug 8, 2024 · 6 comments

@lijun2005

Hi, thank you for providing this toolbox. Is there any future work planned toward allowing torch.compile() for ptwt?

@v0lta
Owner

v0lta commented Aug 9, 2024

We do support torch.jit to some degree. It's similar to torch.compile. See also the JIT tests: https://github.com/v0lta/PyTorch-Wavelet-Toolbox/blob/main/tests/test_jit.py

@lijun2005
Author

lijun2005 commented Aug 9, 2024 via email

@v0lta
Owner

v0lta commented Aug 22, 2024

We haven't tested it. Does a working torch.jit test imply that torch.compile works as well?

@lijun2005
Author

lijun2005 commented Aug 22, 2024 via email

@lijun2005
Author

Could you provide an example of how to apply learnable wavelet transforms to a 3D tensor? The input has the shape [B, L, C], where B is the batch size, L the length, and C the number of channels. I want to perform the wavelet transform along the L dimension.

@v0lta
Owner

v0lta commented Sep 25, 2024

Sure:

import ptwt
import torch
from tqdm import tqdm
from ptwt.wavelets_learnable import ProductFilter

torch.manual_seed(42)
# Random test data; wavedec3 transforms the last three axes.
data = torch.randn(32, 32, 32, 32)
# Four learnable filters (dec_lo, dec_hi, rec_lo, rec_hi), randomly initialised.
wavelet = ProductFilter(torch.randn(4), torch.randn(4),
                        torch.randn(4), torch.randn(4))
opt = torch.optim.RMSprop(wavelet.parameters(), lr=0.01)

for _i in (bar := tqdm(range(5000))):
    # Analysis followed by synthesis with the same learnable wavelet.
    res = ptwt.waverec3(ptwt.wavedec3(data, wavelet, level=4), wavelet)
    # Reconstruction error plus the wavelet loss of the filters.
    cost = torch.mean((res - data) ** 2) + wavelet.wavelet_loss()
    cost.backward()
    opt.step()
    opt.zero_grad()
    bar.set_description(f"cost: {cost.item():2.4e}")

wavedec3 transforms the last three axes, so pass a tensor of shape [1, B, L, C] if you want to transform the batch dimension as well.
I hope this helps.
See https://arxiv.org/pdf/2004.09569 for more information regarding learnable wavelets.
