RuntimeError: Module <class 'brevitas.proxy.float_runtime_quant.ActFloatQuantProxyFromInjector'> not supported for export #1091
Comments
Hello, Unfortunately, the torch quant/dequant ops that we normally use to map quantization (see https://pytorch.org/docs/stable/generated/torch.quantize_per_tensor.html) don't support minifloat/fp8 quantization. With respect to ONNX, we can only export FP8 and not lower bit-widths, for similar reasons: ONNX supports only a few FP8 types and does not allow defining a custom minifloat data type with arbitrary mantissa and exponent bit-widths (or other configurations). If you tell us what your goal is with this export flow, maybe we can guide you towards a custom solution while we build a generic export flow. |
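To make the "arbitrary mantissa and exponent bit-width" point concrete, here is a minimal sketch of how the largest finite value of a minifloat format depends on those two bit-widths. It assumes IEEE-style conventions (bias of 2^(e-1) - 1, with the all-ones exponent code reserved for inf/NaN); the function name is hypothetical and this is not taken from Brevitas. Some FP8 variants reclaim the top exponent code and therefore reach a larger maximum than this formula gives.

```python
def minifloat_max_normal(exponent_bits: int, mantissa_bits: int) -> float:
    """Largest finite value of an IEEE-style minifloat.

    Assumption: bias = 2**(exponent_bits - 1) - 1, and the all-ones
    exponent code is reserved for inf/NaN (so it is not usable for
    finite values).
    """
    bias = 2 ** (exponent_bits - 1) - 1
    # Largest usable exponent code is (2**e - 2); subtract the bias.
    max_exponent = (2 ** exponent_bits - 2) - bias
    # Mantissa with all bits set: 1.111...1 in binary.
    max_mantissa = 2.0 - 2.0 ** (-mantissa_bits)
    return max_mantissa * 2.0 ** max_exponent


# Each (exponent, mantissa) pair defines a different data type.
for name, e, m in [("E5M2", 5, 2), ("E4M3", 4, 3), ("E3M2", 3, 2)]:
    print(name, minifloat_max_normal(e, m))
```

Every pair of bit-widths yields a distinct numeric range, which is why a format that only enumerates a handful of fixed FP8 types cannot describe the general case.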
Hey @Giuseppe5 - Thank you for the speedy reply. That makes a lot of sense and sounds like something that is a bit tricky to do generically given the complexity of the different ops needed to support it correctly. I am running a series of experiments to look at the effect quantisation has on adversarial robustness and want to quantise the models in an HPC-style environment, save them somewhere sensible and then evaluate the robustness downstream. It doesn't necessarily have to be a valid torch/onnx model that can be loaded into another library. I tried to use I'd appreciate any ideas you may have to save the models to be loaded again later. |
We recently merged into dev the possibility to export minifloat to QONNX (QONNX ref and QONNX Minifloat ref). This representation allows us to explicitly represent all the various minifloat configurations that Brevitas can simulate. Would this work for you? Do you need a torch-based export for this task or do you just need to represent (even with custom ops) the computational graph? |
This sounds like something that could be very useful for my use case but does not cover everything. Ideally, it would be That being said, the features of |
Another option is to just save the state dict with Talking with @nickfraser, we might try a few ideas on how to make Brevitas compatible with serialization. It should be a quick experiment, and I'll keep the issue updated if it works out so that you can replicate it while we go through all the PR process. |
That sounds wonderful. Thank you! Re. carrying |
To be fair, that pin might be outdated. When numpy 2.0 was released, a lot of things broke and our decision was to wait until it stabilized a bit before we tried to unpin it. We don't use any specific numpy functions that are no longer available in 2.0 or anything like that. I will open a PR with the unpinned version to see how many things break. Fingers crossed: #1093 |
Sounds great! Thank you 👍 |
I think there is reasonable optimism that Brevitas no longer has any hard requirement on numpy. Having said that, I notice that torch keeps installing numpy 1.26 even if I don't specify any particular version. I tried manually upgrading it (with torch 2.4) and everything seems to work fine. With an older version of torch there seem to be conflicts. Let me know what you see from your side. I will still look into serialization but this could be the fastest way to get there IMHO. |
What is the best branch to install/set up to test this? And any luck with the serialisation experimentation? |
After installing Brevitas, just update numpy to whatever version you need and everything should work. Be sure to have
I opened a draft PR with an example: #1096. It seems to work locally, but there could be side effects when using the pickled model downstream. I believe all the issues can be fixed relatively easily. Wouldn't you also need to carry Brevitas downstream with pickle? |
I would need to carry I will pull the draft PR version and have an experiment this week and report back! (Probably close of play Thursday!) |
Hey @Giuseppe5 - Sorry for being slow to get to this. I have looked at the linked PR and the code example used Edit: I've worked out that the layers get replaced with quant versions and can now see the things the injector adds like |
The unloading/saving/loading of the injectors happens within the context manager I created in that PR. When you enter the context manager, the injectors are temporarily detached, allowing you to serialize your model (the call to torch.save in the context manager would fail otherwise). After you exit the context manager, the injectors are re-attached. The serialized model won't have them, and this might cause some issues when re-loading the model with torch.load, but I believe we can address any bug that comes up because of that. To sum up: after you generate a model with the script you mentioned, enter the context manager and save it, then try to re-load the model and use it as you normally would. If you see bugs, let me know and I'll advise on how to proceed/update the PR. Does this answer your question? Unfortunately I am not quite sure I understand the part after the EDIT, so I hope this is sufficient to get you unblocked. |
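The detach, serialize, re-attach pattern described above can be sketched with the standard library alone. This is only an illustration of the general idea, not the implementation in the draft PR: the `detached` helper and the `injector` attribute name are hypothetical, and a `threading.Lock` stands in for any attribute that cannot be pickled.

```python
import contextlib
import pickle
import threading
import types


@contextlib.contextmanager
def detached(obj, attr_name):
    """Temporarily remove `attr_name` from `obj`, restoring it on exit."""
    saved = getattr(obj, attr_name)
    delattr(obj, attr_name)
    try:
        yield obj
    finally:
        # Re-attach even if serialization raised an exception.
        setattr(obj, attr_name, saved)


# Toy stand-in for a layer: picklable weights plus an unpicklable helper.
layer = types.SimpleNamespace(
    weight=[0.5, -1.0, 2.0],
    injector=threading.Lock(),  # lock objects cannot be pickled
)

# Pickling the object as-is fails because of the lock.
try:
    pickle.dumps(layer)
    failed_without_detach = False
except TypeError:
    failed_without_detach = True

# Inside the context manager the helper is gone, so pickling succeeds.
with detached(layer, "injector"):
    payload = pickle.dumps(layer)

restored = pickle.loads(payload)
```

As in the description above, the live object gets its helper back after the `with` block, while the restored copy does not have it, so any downstream code that relies on that attribute would need to handle its absence.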
Thanks for the speedy reply @Giuseppe5. I now understand so thank you for the response. I am having some issues setting up the draft PR. I have pulled it locally and have tried several different methods of installing (
Any ideas? |
|
Thank you @nickfraser! I managed to get it set up by starting afresh and installing the
@Giuseppe5 - I have managed to successfully get the toy example in #1096 (with a few mods - see below!) working for both saving and loading.
Saving
import brevitas.nn as qnn
from brevitas.export.inference.manager import quant_inference_mode
import torch
model = qnn.QuantLinear(3, 8)
# This is needed to suppress an error about not being
# able to save while in training mode
model.eval()
with quant_inference_mode(model, delete_injector=True):
    b = model(torch.randn(1, 3))
    # Amended this to save the state_dict rather than the model itself
    torch.save(model.state_dict(), "test_dict.pickle")
Loading
import brevitas.nn as qnn
from brevitas.export.inference.manager import quant_inference_mode
import torch
model = qnn.QuantLinear(3, 8)
model.eval()
model.load_state_dict(torch.load(open("test_dict.pickle", "rb")))
with quant_inference_mode(model):
    b = model(torch.randn(1, 3))
print(b)
I am time limited this week but I will try to have a go with a more complex model (such as a TorchVision ImageNet ResNet) and then report back findings! Thank you for the changes and tips to date. I really appreciate it! |
Just one comment: I believe that if you're only storing the state dict, you don't need the context manager + delete injector. That part is needed only if you try to serialize (i.e.,
model = qnn.QuantLinear(3, 8)
# This is needed to suppress an error about not being
# able to save while in training mode
model.eval()
b = model(torch.randn(1, 3))
torch.save(model.state_dict(), "test_dict.pickle")
In any case, feel free to experiment and keep us posted with updates/issues :) |
Also, it is always recommended to do at least one forward pass with your quantized model before saving/exporting, because some quant parameters require a forward pass for initialization. |
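A toy illustration of why the forward pass matters before saving. This is a sketch under stated assumptions, not how Brevitas initializes its parameters: `ToyQuantLayer` is a hypothetical class whose scale is only set from the first batch it sees, and plain `pickle` stands in for `torch.save`.

```python
import pickle


class ToyQuantLayer:
    """Hypothetical layer whose scale is only set on the first forward pass."""

    def __init__(self):
        self.scale = None  # not initialized until data is seen

    def forward(self, values):
        if self.scale is None:
            # Initialize the scale from the first batch (max absolute value).
            self.scale = max(abs(v) for v in values)
        return [v / self.scale for v in values]

    def state_dict(self):
        return {"scale": self.scale}

    def load_state_dict(self, state):
        self.scale = state["scale"]


layer = ToyQuantLayer()
saved_too_early = pickle.dumps(layer.state_dict())      # scale is still None
layer.forward([0.5, -2.0, 1.0])                         # first pass sets the scale
saved_after_forward = pickle.dumps(layer.state_dict())  # scale is now stored

# Downstream: rebuild the layer from the same code, then load the saved state.
reloaded = ToyQuantLayer()
reloaded.load_state_dict(pickle.loads(saved_after_forward))
```

The state saved before the forward pass carries no usable scale, so a layer restored from it would re-initialize from whatever data it sees first downstream. Note also that the state-dict route requires the layer's class definition to be available when loading.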
Describe the bug
Attempting to save PTQ TorchVision models using the ptq_benchmark_torchvision.py script, after amending the script to save the model using export_torch_qcdq as a final step. The traceback is below:
Reproducibility
To Reproduce
Steps to reproduce the behavior. For example:
Add export_torch_qcdq() and associated args to the end of the ptq_benchmark_torchvision.py script.
Expected behavior
The model should be saved.
Additional context
I have tried torch.save() natively and it doesn't work either.