Tensorflow support #153
With the release of TensorFlow 2, the TensorFlow model zoos were deprecated in favour of tfhub.dev and Kaggle. To load these models in C++, a new API was added. It requires manually building TensorFlow, which can be done with Bazel.
Instead of depending on TF, we should create a TF2 parser in MIGraphX which can handle the new format.
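The first step such a parser would take can be sketched without any TF dependency: resolve the fixed file names inside the export directory. A minimal Python sketch (illustrative only; `find_saved_model_files` is a hypothetical helper, not MIGraphX API — the file-name constants are the ones from `tensorflow/cc/saved_model/constants.h`):

```python
import os

# File names fixed by the SavedModel format
# (see tensorflow/cc/saved_model/constants.h).
SAVED_MODEL_PB = "saved_model.pb"
VARIABLES_DIR = "variables"
VARIABLES_INDEX = "variables.index"

def find_saved_model_files(export_dir):
    """Return (path to saved_model.pb, path to variables.index or None)."""
    pb_path = os.path.join(export_dir, SAVED_MODEL_PB)
    if not os.path.isfile(pb_path):
        raise FileNotFoundError(f"no {SAVED_MODEL_PB} in {export_dir}")
    index_path = os.path.join(export_dir, VARIABLES_DIR, VARIABLES_INDEX)
    return pb_path, index_path if os.path.isfile(index_path) else None
```

A real parser would then decode `saved_model.pb` and the variables files; this only locates them.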
The SavedModel directory will have the following layout (source):

```
assets/
assets.extra/
variables/
    variables.data-?????-of-?????
    variables.index
saved_model.pb
```

An example SavedModel: https://www.kaggle.com/models/tensorflow/resnet-50 (it only has variables, no assets). The file names come from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/saved_model/constants.h. The saved_model proto is here. Inspecting any protobuf: Some of these protos are already in MGX, and the matching protos are in TF and TSL. The parsed protobuf contains a meta_graph, which holds the following: GraphDef, SaverDef, CollectionDef, SignatureDef, AssetFileDef, SavedObjectGraph. Some files to check for a start: loader.cc, reader.cc, env.cc
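Even without TF's `.proto` definitions, `saved_model.pb` can be inspected at the wire-format level with plain Python: each top-level entry starts with a varint tag encoding a `(field_number, wire_type)` pair. A hedged sketch (recovering the field *names* still requires the real proto files; this only lists field numbers and wire types):

```python
def read_varint(buf, pos):
    """Decode a protobuf varint starting at pos; return (value, new_pos)."""
    result = shift = 0
    while True:
        b = buf[pos]
        result |= (b & 0x7F) << shift
        pos += 1
        if not b & 0x80:
            return result, pos
        shift += 7

def list_top_level_fields(buf):
    """Return [(field_number, wire_type), ...] for the top-level message."""
    pos = 0
    fields = []
    while pos < len(buf):
        tag, pos = read_varint(buf, pos)
        field, wire = tag >> 3, tag & 7
        fields.append((field, wire))
        if wire == 0:                       # varint payload
            _, pos = read_varint(buf, pos)
        elif wire == 1:                     # 64-bit payload
            pos += 8
        elif wire == 2:                     # length-delimited payload
            length, pos = read_varint(buf, pos)
            pos += length
        elif wire == 5:                     # 32-bit payload
            pos += 4
        else:
            raise ValueError(f"unsupported wire type {wire}")
    return fields
```

Running this over the bytes of a `saved_model.pb` shows which SavedModel fields are populated; for the full structure, `protoc` with the TF proto files is still needed.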
Check https://github.com/onnx/tensorflow-onnx to see if it can be used to avoid or work around a TF2 implementation.
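Whether tf2onnx can handle a given model reduces to whether every op type in its GraphDef has an ONNX mapping (the BERT tests below fail on Xla* ops). The pre-check can be sketched as a plain set difference; all names and data here are illustrative, not tf2onnx internals:

```python
def unsupported_ops(graph_op_types, convertible_ops):
    """Return op types present in the graph that have no ONNX mapping."""
    return sorted(set(graph_op_types) - set(convertible_ops))

# Illustrative data only: a BERT-like graph containing Xla* ops.
graph_ops = ["MatMul", "Softmax", "XlaDotV2", "XlaGather"]
known_ops = {"MatMul", "Softmax", "Add"}
print(unsupported_ops(graph_ops, known_ops))  # ['XlaDotV2', 'XlaGather']
```

An empty result would mean conversion can at least be attempted; any Xla* entry predicts the failures seen below.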
Tensorflow -> ONNX

Official documentation is here.

Setup

```shell
python3 -m venv tf_onnx
. tf_onnx/bin/activate
pip install tensorflow tf2onnx
```

Test with ResNet50

```shell
curl -L -o ./resnet50-model.tar.gz \
  https://www.kaggle.com/api/v1/models/tensorflow/resnet-50/tensorFlow2/classification/1/download
mkdir resnet50
tar xzvf resnet50-model.tar.gz -C resnet50
python -m tf2onnx.convert --saved-model resnet50 --output resnet50.onnx
```

It should convert fine.

Test with BERT

```shell
curl -L -o ./bert-model.tar.gz \
  https://www.kaggle.com/api/v1/models/google/bert/tensorFlow2/answer-equivalence-bem/1/download
mkdir bert
tar xzvf bert-model.tar.gz -C bert
python -m tf2onnx.convert --saved-model bert --output bert.onnx
```

It will fail, because the Xla* operators are not supported.

```shell
curl -L -o ./bert-qa-model.tar.gz \
  https://www.kaggle.com/api/v1/models/seesee/bert/tensorFlow2/uncased-tf2-qa/1/download
mkdir bert-qa
tar xzvf bert-qa-model.tar.gz -C bert-qa
python -m tf2onnx.convert --saved-model bert-qa --output bert-qa.onnx
```

Sadly, similar results :(

Test with SDXL

No premade model is available; the closest is:
What is uploaded to HuggingFace is either safetensors or h5; there is no SavedModel format. In theory, all it would take to convert it is:

```python
import tf2onnx
import onnx
from tensorflow.keras.models import load_model

model = load_model('path/to/keras_model_you_want_to_convert.h5')
onnx_model, _ = tf2onnx.convert.from_keras(model)
onnx.save(onnx_model, 'new_model.onnx')
```

It might work for some models, but sadly, in practice it does not, because only the weights are saved, not the model config.
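The failure mode can be illustrated without TensorFlow: rebuilding a model needs both its configuration (architecture) and its weights, and a weights-only h5 carries just the latter. A toy stand-in using plain dicts in place of real HDF5 files (all function names here are hypothetical, mimicking the Keras save/load flow):

```python
# Toy model "files": plain dicts instead of real HDF5.
def save_full(model):
    # model.save() stores both the architecture config and the weights.
    return {"config": model["config"], "weights": model["weights"]}

def save_weights_only(model):
    # model.save_weights() stores the weights alone.
    return {"weights": model["weights"]}

def load_model_from(blob):
    # Loading must rebuild the architecture from the stored config first;
    # without it there is nothing to attach the weights to.
    if "config" not in blob:
        raise ValueError("no model config found in the file")
    return {"config": blob["config"], "weights": blob["weights"]}

model = {"config": {"layers": ["dense"]}, "weights": [0.1, 0.2]}
assert load_model_from(save_full(model)) == model   # full save round-trips
try:
    load_model_from(save_weights_only(model))
except ValueError as e:
    print(e)  # no model config found in the file
```

This mirrors the error Keras raises when `load_model` is pointed at a weights-only h5, which is why the HuggingFace uploads cannot be converted this way.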
Tracking issue for all TensorFlow-related tasks.