
How to implement a detach operation similar to PyTorch? #1138

Closed
zaoanhh opened this issue Dec 17, 2024 · 2 comments

Comments

zaoanhh commented Dec 17, 2024

I've recently been doing some unsupervised training with Lux, and I think Lux is pretty cool. But I ran into a problem: I use the model's output to generate labels for subsequent calculations. In PyTorch it looks like:

import torch
import torch.nn.functional as F
y = model(x)
y1 = y.detach()
softmax_y1 = F.softmax(y1, dim=1)
pred_class_indices = torch.argmax(softmax_y1, dim=1)
num_classes = 4
labels_true = F.one_hot(pred_class_indices, num_classes=num_classes)

In Lux, how should I implement the detach operation so that gradients are not tracked during label generation?

avik-pal (Member) commented:

You do ChainRulesCore.ignore_derivatives(y)
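
For reference, a minimal sketch of how this could look inside a Lux loss function. The function name pseudo_label_loss, the cross-entropy line, and the use of OneHotArrays.onehotbatch are illustrative assumptions, not part of the original answer; the key piece is wrapping the label generation in ChainRulesCore.ignore_derivatives so the AD system treats the result as a constant:

using Lux, NNlib, ChainRulesCore
using OneHotArrays  # assumption: onehotbatch is only used for the one-hot encoding

# Sketch: generate pseudo-labels from the model output without gradients
# flowing back through them (the analogue of y.detach() in PyTorch).
function pseudo_label_loss(model, ps, st, x)
    y, st = model(x, ps, st)                  # forward pass (differentiable)

    # Everything inside this block is invisible to the AD system (e.g. Zygote),
    # so no gradients are tracked during label generation.
    labels_true = ChainRulesCore.ignore_derivatives() do
        softmax_y = softmax(y; dims = 1)      # classes along dim 1 (column-major layout)
        pred_class_indices = [argmax(c) for c in eachcol(softmax_y)]
        num_classes = 4
        onehotbatch(pred_class_indices, 1:num_classes)   # one-hot pseudo-labels
    end

    # Differentiable part of the loss, using the detached labels as constants.
    loss = -sum(labels_true .* logsoftmax(y; dims = 1)) / size(y, 2)
    return loss, st
end

ignore_derivatives can also be applied directly to a value, e.g. y1 = ChainRulesCore.ignore_derivatives(y) as in the answer above, or used via the ChainRulesCore.@ignore_derivatives macro; the do-block form just makes explicit that the whole label-generation step is excluded from differentiation.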

zaoanhh (Author) commented Dec 19, 2024

> You do ChainRulesCore.ignore_derivatives(y)

Thank you!
