-
Notifications
You must be signed in to change notification settings - Fork 200
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Better Bfloat16 support #777
Conversation
src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
Outdated
Show resolved
Hide resolved
src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
Outdated
Show resolved
Hide resolved
I think this is good, but there are a few things that we may want to revisit down the line:
* >>> import torch
>>> x = torch.rand((1,),dtype=torch.bfloat16,device="cuda:0")
>>> y = torch.rand((1,),dtype=torch.bfloat16,device="cuda:0")
>>> r = x + y # Output type == input type
>>> r.dtype
torch.bfloat16 >>> import torch
>>> x = torch.rand((1,),dtype=torch.bfloat16,device="cuda:0")
>>> y = torch.rand((1,),dtype=torch.float32,device="cuda:0")
>>> r = x + y # Implicit upcast of x
>>> r.dtype
torch.float32 |
Regarding 1, I will rename the function to match its functionality. Regarding 2, In this current implementation, actually the output of QuantTensor.int() will always be float32 (even though the original QuantTensor was in float16, for example). |
src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
Outdated
Show resolved
Hide resolved
d5f6a3c
to
024563c
Compare
024563c
to
e3f98d4
Compare
No description provided.