Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix torch.clamp issue #237 #238

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

CryptoSalamander
Copy link

This PR is related to #237 !
There are two options to fix this problem,

  1. just convert max tensor to scalar (this PR)
  2. make max tensor be loaded the same device with self.logit_scale like below:
device = self.logit_scale.device
max_tensor = torch.log(torch.tensor(1. / 0.01)).to(device)
logit_scale = torch.clamp(self.logit_scale, max=max_tensor).exp()

I think the first option is better due to its simplicity.
I tested both options on my datasets, it seems no difference in cuda memory allocation & inference speed.

@CryptoSalamander
Copy link
Author

@ancientmooner Could you please check issue #237 ?

@juncgu
Copy link

juncgu commented Apr 11, 2023

@CryptoSalamander, I would prefer to use the second option. I faced the same issue when using torch 2.0, and the item() method in the first option will lead torch.dynamo to break WindowAttention into two graphs when tracing the module.

@CryptoSalamander CryptoSalamander force-pushed the fix/clamp-issue branch 3 times, most recently from ad49644 to 38eba56 Compare April 11, 2023 13:57
@CryptoSalamander
Copy link
Author

CryptoSalamander commented Apr 11, 2023

@juncgu Thanks for your suggestion. I have modified the code as the second option.
Could you please take a look at this PR? @ancientmooner

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants