From f391218c2714481cd2e964b18492a490fed23a79 Mon Sep 17 00:00:00 2001 From: Yifan Shen Date: Mon, 10 Jun 2024 15:00:16 -0400 Subject: [PATCH] 8.0b1 Release (#2232) * 8.0b1 release * auto rerun flaky tests --------- Co-authored-by: Yifan Shen --- NOTICE.txt | 23 + coremlpython/CoreMLPython.h | 32 +- coremlpython/CoreMLPython.mm | 46 +- coremltools/__init__.py | 6 +- coremltools/_deps/__init__.py | 2 +- coremltools/converters/__init__.py | 9 +- coremltools/converters/_converters_entry.py | 79 +- coremltools/converters/mil/Makefile | 3 +- coremltools/converters/mil/__init__.py | 2 +- .../mil/_deployment_compatibility.py | 5 + .../converters/mil/backend/mil/helper.py | 32 +- .../converters/mil/backend/mil/load.py | 280 +- .../passes/adjust_io_to_supported_types.py | 7 +- .../mil/passes/sanitize_name_strings.py | 11 +- .../mil/backend/mil/passes/test_passes.py | 26 +- .../converters/mil/backend/mil/test_load.py | 1098 ++++++++ .../backend/mil/test_model_input_params.py | 195 -- coremltools/converters/mil/backend/nn/load.py | 7 +- coremltools/converters/mil/converter.py | 2 +- coremltools/converters/mil/debugging_utils.py | 10 +- coremltools/converters/mil/frontend/_utils.py | 114 +- .../mil/frontend/milproto/helper.py | 5 + .../converters/mil/frontend/milproto/load.py | 104 +- .../mil/frontend/milproto/test_load.py | 208 +- .../converters/mil/frontend/tensorflow/ops.py | 1 + .../mil/frontend/tensorflow/test/test_ops.py | 6 - .../test/test_tf2_conversion_api.py | 15 +- .../mil/frontend/torch/converter.py | 670 ++++- .../mil/frontend/torch/internal_graph.py | 6 +- .../converters/mil/frontend/torch/load.py | 4 +- .../converters/mil/frontend/torch/ops.py | 297 ++- .../mil/frontend/torch/quantization_ops.py | 4 +- .../torch/test/test_executorch_e2e.py | 678 +++-- .../torch/test/test_torch_conversion_api.py | 46 +- .../mil/frontend/torch/test/test_torch_ops.py | 534 +++- .../torch/test/test_torch_quantization_ops.py | 518 +++- .../torch/test/test_torch_stateful_model.py | 1181 +++++++++ .../mil/frontend/torch/test/testing_utils.py | 8 +- coremltools/converters/mil/input_types.py | 89 +- coremltools/converters/mil/mil/block.py | 2 + coremltools/converters/mil/mil/builder.py | 18 +- coremltools/converters/mil/mil/input_type.py | 12 +- coremltools/converters/mil/mil/operation.py | 8 +- .../converters/mil/mil/ops/defs/__init__.py | 2 +- .../converters/mil/mil/ops/defs/_utils.py | 87 - .../mil/mil/ops/defs/complex_dialect_ops.py | 45 +- .../mil/ops/defs/coreml_dialect/__init__.py | 6 + .../mil/mil/ops/defs/coreml_dialect/ops.py | 67 + .../mil/mil/ops/defs/iOS15/control_flow.py | 22 + .../mil/mil/ops/defs/iOS15/recurrent.py | 3 - .../mil/ops/defs/iOS15/tensor_operation.py | 27 +- .../mil/mil/ops/defs/iOS16/constexpr_ops.py | 6 +- .../mil/mil/ops/defs/iOS18/__init__.py | 22 + .../mil/mil/ops/defs/iOS18/compression.py | 791 ++++++ .../mil/mil/ops/defs/iOS18/recurrent.py | 35 + .../mil/mil/ops/defs/iOS18/states.py | 43 + .../ops/defs/iOS18/tensor_transformation.py | 152 ++ .../mil/mil/ops/defs/iOS18/transformers.py | 166 ++ .../converters/mil/mil/ops/registry.py | 1 + .../mil/ops/tests/coreml_dialect/__init__.py | 4 + .../coreml_dialect/test_coreml_dialect.py | 67 + .../mil/ops/tests/iOS14/test_control_flow.py | 25 + .../mil/mil/ops/tests/iOS14/test_conv.py | 18 +- .../ops/tests/iOS14/test_image_resizing.py | 26 +- .../ops/tests/iOS14/test_tensor_operation.py | 13 + .../tests/iOS14/test_tensor_transformation.py | 2 +- .../mil/ops/tests/iOS16/test_constexpr_ops.py | 10 +- .../mil/ops/tests/iOS17/test_quantization.py | 155 +- .../ops/tests/iOS17/test_scatter_gather.py | 69 +- .../mil/mil/ops/tests/iOS18/__init__.py | 9 + .../mil/ops/tests/iOS18/test_compression.py | 1961 ++++++++++++++ .../mil/mil/ops/tests/iOS18/test_recurrent.py | 178 ++ .../mil/mil/ops/tests/iOS18/test_states.py | 342 +++ .../tests/iOS18/test_tensor_transformation.py | 545 ++++ .../mil/ops/tests/iOS18/test_transformers.py | 419 +++ .../mil/mil/ops/tests/test_utils.py | 48 - .../mil/mil/ops/tests/testing_utils.py | 54 +- .../converters/mil/mil/passes/__init__.py | 1 + .../defs/cleanup/const_deduplication.py | 66 +- .../defs/cleanup/dead_code_elimination.py | 5 + .../defs/cleanup/expand_dynamic_linear.py | 11 +- .../defs/cleanup/topological_reorder.py | 4 + .../defs/optimize_activation_quantization.py | 419 +++ .../mil/passes/defs/optimize_normalization.py | 7 +- .../mil/passes/defs/optimize_quantization.py | 140 +- .../mil/mil/passes/defs/preprocess.py | 17 +- .../mil/mil/passes/defs/quantization.py | 2 + .../mil/mil/passes/defs/randomize.py | 59 + .../converters/mil/mil/passes/graph_pass.py | 2 +- .../mil/mil/passes/pass_pipeline.py | 9 +- .../mil/passes/tests/test_cleanup_passes.py | 141 +- .../mil/passes/tests/test_pass_pipeline.py | 7 + .../mil/mil/passes/tests/test_passes.py | 134 +- .../passes/tests/test_quantization_passes.py | 216 +- coremltools/converters/mil/mil/program.py | 54 +- coremltools/converters/mil/mil/scope.py | 7 +- .../converters/mil/mil/tests/test_programs.py | 21 +- .../converters/mil/mil/tests/test_types.py | 90 + .../converters/mil/mil/types/__init__.py | 30 +- .../converters/mil/mil/types/type_int.py | 40 +- .../converters/mil/mil/types/type_mapping.py | 75 +- .../converters/mil/mil/types/type_state.py | 48 + .../converters/mil/mil/types/type_tensor.py | 6 + coremltools/converters/mil/mil/var.py | 37 +- coremltools/converters/mil/testing_reqs.py | 8 +- coremltools/converters/mil/testing_utils.py | 101 +- coremltools/models/_compiled_model.py | 69 +- coremltools/models/_deprecation.py | 2 +- .../models/ml_program/compression_utils.py | 2 +- coremltools/models/model.py | 166 +- .../neural_network/flexible_shape_utils.py | 78 +- coremltools/models/utils.py | 535 +++- coremltools/optimize/__init__.py | 2 - coremltools/optimize/coreml/__init__.py | 1 + coremltools/optimize/coreml/_config.py | 277 +- .../coreml/_post_training_quantization.py | 191 +- .../optimize/coreml/_quantization_passes.py | 875 ++++-- coremltools/optimize/coreml/_utils.py | 422 ++- .../optimize/coreml/experimental/__init__.py | 7 + .../optimize/coreml/experimental/_config.py | 97 + .../coreml/experimental/_model_debugger.py | 332 +++ .../_post_training_quantization.py | 237 ++ .../experimental/_quantization_passes.py | 251 ++ coremltools/optimize/torch/__init__.py | 3 +- coremltools/optimize/torch/_logging.py | 2 +- coremltools/optimize/torch/_typing.py | 2 +- coremltools/optimize/torch/_utils/__init__.py | 2 +- .../optimize/torch/_utils/dist_utils.py | 36 + .../optimize/torch/_utils/fsdp_utils.py | 70 + coremltools/optimize/torch/_utils/k_means.py | 921 +++++++ .../optimize/torch/_utils/math_utils.py | 2 +- .../optimize/torch/_utils/metadata_utils.py | 138 + .../optimize/torch/_utils/python_utils.py | 64 + coremltools/optimize/torch/_utils/registry.py | 109 + .../optimize/torch/_utils/report_utils.py | 102 + .../optimize/torch/_utils/state_dict_utils.py | 2 +- .../optimize/torch/_utils/torch_utils.py | 112 +- .../optimize/torch/_utils/validation_utils.py | 172 ++ .../optimize/torch/_utils/version_utils.py | 6 +- .../optimize/torch/base_model_optimizer.py | 86 +- .../torch/layerwise_compression/__init__.py | 98 + .../torch/layerwise_compression/_quant.py | 211 ++ .../torch/layerwise_compression/algorithms.py | 686 +++++ .../layerwise_compression/input_cacher.py | 184 ++ .../layerwise_compressor.py | 424 +++ .../optimize/torch/optimization_config.py | 113 +- .../optimize/torch/palettization/__init__.py | 35 +- .../torch/palettization/_custom_conversion.py | 112 +- .../torch/palettization/_efficient_kmeans.py | 348 +-- .../_fake_palettizer_tensor_hook.py | 338 ++- .../torch/palettization/_partitioner.py | 368 ++- .../torch/palettization/_supported_modules.py | 40 +- .../optimize/torch/palettization/_utils.py | 52 + .../torch/palettization/fake_palettize.py | 688 +++-- .../palettization/palettization_config.py | 175 +- .../torch/palettization/palettizer.py | 49 +- .../post_training_palettization.py | 327 +++ .../torch/palettization/sensitive_k_means.py | 680 +++++ .../optimize/torch/pruning/__init__.py | 5 +- .../optimize/torch/pruning/_base_pruner.py | 18 +- .../torch/pruning/_base_pruning_method.py | 2 +- coremltools/optimize/torch/pruning/_utils.py | 76 +- .../torch/pruning/magnitude_pruner.py | 10 +- .../torch/pruning/pruning_scheduler.py | 2 +- .../optimize/torch/quantization/__init__.py | 19 +- .../quantization/_annotation_handler_utils.py | 726 +++++ .../torch/quantization/_backend_config.py | 4 +- .../quantization/_backend_config_utils.py | 20 +- .../optimize/torch/quantization/_configure.py | 5 +- .../quantization/_coreml_quantizer_utils.py | 11 +- .../torch/quantization/_qconfig_mapping.py | 15 +- .../optimize/torch/quantization/_utils.py | 29 +- .../torch/quantization/modules/__init__.py | 2 +- .../quantization/modules/fused_modules.py | 3 +- .../torch/quantization/modules/qat_modules.py | 29 +- .../quantization/modules/quantized_modules.py | 14 +- .../post_training_quantization.py | 460 ++++ .../torch/quantization/quantization_config.py | 67 +- .../optimize/torch/quantization/quantizer.py | 38 +- coremltools/proto/FeatureTypes_pb2.py | 493 ++-- coremltools/proto/MIL_pb2.py | 2334 ++++++++++------- coremltools/proto/Model_pb2.py | 1864 +++++++++---- coremltools/test/api/test_api_visibilities.py | 7 + coremltools/test/blob/test_weights.py | 236 +- coremltools/test/ml_program/test_utils.py | 894 +++++++ .../test/modelpackage/test_modelpackage.py | 74 +- .../test/optimize/coreml/test_passes.py | 926 ++++++- .../coreml/test_post_training_quantization.py | 1201 ++++++++- .../test/optimize/coreml/test_utils.py | 173 ++ coremltools/test/optimize/torch/__init__.py | 2 +- coremltools/test/optimize/torch/conftest.py | 65 +- .../optimize/torch/conversion/__init__.py | 4 + .../torch/conversion/conversion_utils.py | 109 + .../torch/conversion/joint/__init__.py | 4 + .../test_joint_compression_conversion.py | 99 + .../conversion/palettization/__init__.py | 4 + .../test_palettization_conversion.py | 399 +++ .../torch/conversion/pruning/__init__.py | 4 + .../pruning/test_pruning_conversion.py | 88 + .../torch/conversion/quantization/__init__.py | 4 + .../test_quantization_conversion.py | 173 ++ .../torch/layerwise_compression/__init__.py | 4 + .../layerwise_compression/test_algorithms.py | 285 ++ .../test/optimize/torch/models/__init__.py | 2 +- .../test/optimize/torch/models/mnist.py | 129 +- .../optimize/torch/palettization/__init__.py | 2 +- .../palettization/palettization_utils.py | 13 +- .../palettization/test_palettization_api.py | 120 +- .../torch/palettization/test_palettizer.py | 109 + .../test_post_training_palettization.py | 229 ++ .../palettization/test_sensitive_k_means.py | 358 +++ .../test/optimize/torch/pruning/__init__.py | 2 +- .../optimize/torch/pruning/pruning_utils.py | 61 +- .../torch/pruning/test_magnitude_pruner.py | 103 +- .../torch/pruning/test_pruning_scheduler.py | 2 +- .../optimize/torch/quantization/__init__.py | 3 +- .../torch/quantization/test_configure.py | 81 +- .../test_post_training_quantization.py | 235 ++ .../torch/quantization/test_quantizer.py | 122 +- .../optimize/torch/quantization/test_utils.py | 39 + coremltools/test/optimize/torch/smoke_test.py | 31 + .../test/optimize/torch/test_api_surface.py | 49 +- .../optimize/torch/test_base_optimizer.py | 41 +- .../optimize/torch/test_utils/__init__.py | 4 + .../torch/test_utils/test_fsdp_utils.py | 28 + .../optimize/torch/test_utils/test_k_means.py | 262 ++ .../torch/test_utils/test_metadata_utils.py | 158 ++ .../torch/test_utils/test_report_utils.py | 314 +++ .../torch/test_utils/test_validation_utils.py | 134 + coremltools/test/optimize/torch/utils.py | 83 +- coremltools/version.py | 2 +- ...coremltools.converters.mil.input_types.rst | 6 + ...oremltools.converters.mil.mil.ops.defs.rst | 28 + milstoragepython/MilStorage.cpp | 100 +- milstoragepython/MilStorage.hpp | 16 + milstoragepython/MilStoragePython.cpp | 18 + mlmodel/CMakeLists.txt | 3 + mlmodel/build/format/CategoricalMapping.pb.h | 3 + .../format/ClassConfidenceThresholding.pb.h | 3 + mlmodel/build/format/DataStructures.pb.h | 43 +- mlmodel/build/format/DictVectorizer.pb.h | 3 + mlmodel/build/format/FeatureTypes.pb.cc | 413 ++- mlmodel/build/format/FeatureTypes.pb.h | 267 +- mlmodel/build/format/FeatureTypes_enums.h | 19 + mlmodel/build/format/GLMClassifier.pb.h | 7 +- mlmodel/build/format/Gazetteer.pb.h | 37 +- mlmodel/build/format/Imputer.pb.h | 3 + .../format/ItemSimilarityRecommender.pb.h | 123 +- mlmodel/build/format/LinkedModel.pb.h | 19 +- mlmodel/build/format/MIL.pb.cc | 520 +++- mlmodel/build/format/MIL.pb.h | 423 ++- mlmodel/build/format/MIL_enums.h | 11 + mlmodel/build/format/Model.pb.cc | 1183 +++++++-- mlmodel/build/format/Model.pb.h | 853 +++++- mlmodel/build/format/NearestNeighbors.pb.h | 23 +- mlmodel/build/format/NeuralNetwork.pb.h | 1429 +++++----- .../build/format/NonMaximumSuppression.pb.h | 105 +- mlmodel/build/format/OneHotEncoder.pb.h | 7 +- mlmodel/build/format/Parameters.pb.h | 25 +- mlmodel/build/format/SVM.pb.h | 45 +- mlmodel/build/format/TextClassifier.pb.h | 37 +- mlmodel/build/format/TreeEnsemble.pb.h | 47 +- mlmodel/build/format/WordEmbedding.pb.h | 37 +- mlmodel/build/format/WordTagger.pb.h | 101 +- mlmodel/format/DataStructures.proto | 2 +- mlmodel/format/FeatureTypes.proto | 8 +- mlmodel/format/Gazetteer.proto | 6 +- mlmodel/format/LinkedModel.proto | 2 - mlmodel/format/MIL.proto | 14 + mlmodel/format/Model.proto | 142 +- mlmodel/format/NeuralNetwork.proto | 48 +- mlmodel/format/NonMaximumSuppression.proto | 2 +- mlmodel/format/README.rst | 6 +- mlmodel/src/Comparison.cpp | 242 +- mlmodel/src/DataType.cpp | 49 +- mlmodel/src/Globals.hpp | 5 +- mlmodel/src/MILBlob/Blob/BlobDataType.hpp | 68 +- mlmodel/src/MILBlob/Blob/StorageFormat.hpp | 25 +- mlmodel/src/MILBlob/Blob/StorageReader.cpp | 114 +- mlmodel/src/MILBlob/Blob/StorageReader.hpp | 30 +- mlmodel/src/MILBlob/Blob/StorageWriter.cpp | 75 +- mlmodel/src/MILBlob/Blob/StorageWriter.hpp | 22 + mlmodel/src/MILBlob/Fp8.cpp | 188 ++ mlmodel/src/MILBlob/Fp8.hpp | 107 + mlmodel/src/MILBlob/SubByteTypeList.hpp | 13 + mlmodel/src/MILBlob/SubByteTypes.cpp | 209 ++ mlmodel/src/MILBlob/SubByteTypes.hpp | 134 + mlmodel/src/MILBlob/Util/Span.hpp | 192 +- mlmodel/src/MILBlob/Util/SpanCast.hpp | 37 + .../MILBlob/Util/SubByteConversionUtils.hpp | 41 + mlmodel/src/Model.cpp | 97 +- mlmodel/src/Model.hpp | 38 +- mlmodel/src/ResultType.hpp | 11 +- mlmodel/src/TreeEnsembleCommon.cpp | 52 +- mlmodel/src/Utils.cpp | 74 +- mlmodel/src/Utils.hpp | 5 +- .../CategoricalMappingValidator.cpp | 16 +- .../Validation/FeatureVectorizerValidator.cpp | 16 +- .../src/Validation/InterfaceValidators.cpp | 210 +- .../KNearestNeighborsClassifierValidator.cpp | 16 +- .../NeuralNetworkLayerValidator.cpp | 18 +- .../NeuralNetwork/NeuralNetworkValidator.cpp | 150 +- .../NeuralNetworkValidatorUtils.hpp | 31 +- mlmodel/src/Validation/ScalarValidator.cpp | 50 +- mlmodel/src/Validation/ValidatorUtils-inl.hpp | 51 +- mlmodel/src/Validation/Validators.hpp | 126 +- mlmodel/src/transforms/LinearModel.hpp | 6 +- mlmodel/src/transforms/LogisticModel.hpp | 2 +- mlmodel/src/transforms/TreeEnsemble.hpp | 14 +- mlmodel/tests/InterfaceTests.cpp | 351 ++- mlmodel/tests/MILBlob/AutoDeleteTempFile.cpp | 4 +- mlmodel/tests/MILBlob/AutoDeleteTempFile.hpp | 3 +- mlmodel/tests/MILBlob/BlobUtils.cpp | 192 ++ mlmodel/tests/MILBlob/FileWriterTests.cpp | 1 - mlmodel/tests/MILBlob/MMapFileReaderTests.cpp | 1 - mlmodel/tests/MILBlob/SpanCastTests.cpp | 34 + mlmodel/tests/MILBlob/SpanTests.cpp | 216 ++ mlmodel/tests/MILBlob/StorageReaderTests.cpp | 168 +- mlmodel/tests/MILBlob/StorageWriterTests.cpp | 253 +- mlmodel/tests/MLModelTests.hpp | 28 + mlmodel/tests/NNValidatorTests.cpp | 220 +- mlmodel/tests/UtilsTests.cpp | 84 + mlmodel/tests/framework/TestUtils.hpp | 4 +- reqs/test.pip | 15 +- scripts/test.sh | 8 +- 325 files changed, 41191 insertions(+), 6710 deletions(-) create mode 100644 coremltools/converters/mil/backend/mil/test_load.py delete mode 100644 coremltools/converters/mil/backend/mil/test_model_input_params.py create mode 100644 coremltools/converters/mil/frontend/torch/test/test_torch_stateful_model.py create mode 100644 coremltools/converters/mil/mil/ops/defs/coreml_dialect/__init__.py create mode 100644 coremltools/converters/mil/mil/ops/defs/coreml_dialect/ops.py create mode 100644 coremltools/converters/mil/mil/ops/defs/iOS18/__init__.py create mode 100644 coremltools/converters/mil/mil/ops/defs/iOS18/compression.py create mode 100644 coremltools/converters/mil/mil/ops/defs/iOS18/recurrent.py create mode 100644 coremltools/converters/mil/mil/ops/defs/iOS18/states.py create mode 100644 coremltools/converters/mil/mil/ops/defs/iOS18/tensor_transformation.py create mode 100644 coremltools/converters/mil/mil/ops/defs/iOS18/transformers.py create mode 100644 coremltools/converters/mil/mil/ops/tests/coreml_dialect/__init__.py create mode 100644 coremltools/converters/mil/mil/ops/tests/coreml_dialect/test_coreml_dialect.py create mode 100644 coremltools/converters/mil/mil/ops/tests/iOS18/__init__.py create mode 100644 coremltools/converters/mil/mil/ops/tests/iOS18/test_compression.py create mode 100644 coremltools/converters/mil/mil/ops/tests/iOS18/test_recurrent.py create mode 100644 coremltools/converters/mil/mil/ops/tests/iOS18/test_states.py create mode 100644 coremltools/converters/mil/mil/ops/tests/iOS18/test_tensor_transformation.py create mode 100644 coremltools/converters/mil/mil/ops/tests/iOS18/test_transformers.py create mode 100644 coremltools/converters/mil/mil/passes/defs/optimize_activation_quantization.py create mode 100644 coremltools/converters/mil/mil/passes/defs/randomize.py create mode 100644 coremltools/converters/mil/mil/types/type_state.py create mode 100644 coremltools/optimize/coreml/experimental/__init__.py create mode 100644 coremltools/optimize/coreml/experimental/_config.py create mode 100644 coremltools/optimize/coreml/experimental/_model_debugger.py create mode 100644 coremltools/optimize/coreml/experimental/_post_training_quantization.py create mode 100644 coremltools/optimize/coreml/experimental/_quantization_passes.py create mode 100644 coremltools/optimize/torch/_utils/dist_utils.py create mode 100644 coremltools/optimize/torch/_utils/fsdp_utils.py create mode 100644 coremltools/optimize/torch/_utils/k_means.py create mode 100644 coremltools/optimize/torch/_utils/metadata_utils.py create mode 100644 coremltools/optimize/torch/_utils/registry.py create mode 100644 coremltools/optimize/torch/_utils/report_utils.py create mode 100644 coremltools/optimize/torch/_utils/validation_utils.py create mode 100644 coremltools/optimize/torch/layerwise_compression/__init__.py create mode 100644 coremltools/optimize/torch/layerwise_compression/_quant.py create mode 100644 coremltools/optimize/torch/layerwise_compression/algorithms.py create mode 100644 coremltools/optimize/torch/layerwise_compression/input_cacher.py create mode 100644 coremltools/optimize/torch/layerwise_compression/layerwise_compressor.py create mode 100644 coremltools/optimize/torch/palettization/_utils.py create mode 100644 coremltools/optimize/torch/palettization/post_training_palettization.py create mode 100644 coremltools/optimize/torch/palettization/sensitive_k_means.py create mode 100644 coremltools/optimize/torch/quantization/_annotation_handler_utils.py create mode 100644 coremltools/optimize/torch/quantization/post_training_quantization.py create mode 100644 coremltools/test/ml_program/test_utils.py create mode 100644 coremltools/test/optimize/coreml/test_utils.py create mode 100644 coremltools/test/optimize/torch/conversion/__init__.py create mode 100644 coremltools/test/optimize/torch/conversion/conversion_utils.py create mode 100644 coremltools/test/optimize/torch/conversion/joint/__init__.py create mode 100644 coremltools/test/optimize/torch/conversion/joint/test_joint_compression_conversion.py create mode 100644 coremltools/test/optimize/torch/conversion/palettization/__init__.py create mode 100644 coremltools/test/optimize/torch/conversion/palettization/test_palettization_conversion.py create mode 100644 coremltools/test/optimize/torch/conversion/pruning/__init__.py create mode 100644 coremltools/test/optimize/torch/conversion/pruning/test_pruning_conversion.py create mode 100644 coremltools/test/optimize/torch/conversion/quantization/__init__.py create mode 100644 coremltools/test/optimize/torch/conversion/quantization/test_quantization_conversion.py create mode 100644 coremltools/test/optimize/torch/layerwise_compression/__init__.py create mode 100644 coremltools/test/optimize/torch/layerwise_compression/test_algorithms.py create mode 100644 coremltools/test/optimize/torch/palettization/test_palettizer.py create mode 100644 coremltools/test/optimize/torch/palettization/test_post_training_palettization.py create mode 100644 coremltools/test/optimize/torch/palettization/test_sensitive_k_means.py create mode 100644 coremltools/test/optimize/torch/quantization/test_post_training_quantization.py create mode 100644 coremltools/test/optimize/torch/quantization/test_utils.py create mode 100644 coremltools/test/optimize/torch/smoke_test.py create mode 100644 coremltools/test/optimize/torch/test_utils/__init__.py create mode 100644 coremltools/test/optimize/torch/test_utils/test_fsdp_utils.py create mode 100644 coremltools/test/optimize/torch/test_utils/test_k_means.py create mode 100644 coremltools/test/optimize/torch/test_utils/test_metadata_utils.py create mode 100644 coremltools/test/optimize/torch/test_utils/test_report_utils.py create mode 100644 coremltools/test/optimize/torch/test_utils/test_validation_utils.py create mode 100644 mlmodel/src/MILBlob/Fp8.cpp create mode 100644 mlmodel/src/MILBlob/Fp8.hpp create mode 100644 mlmodel/src/MILBlob/SubByteTypeList.hpp create mode 100644 mlmodel/src/MILBlob/SubByteTypes.cpp create mode 100644 mlmodel/src/MILBlob/SubByteTypes.hpp create mode 100644 mlmodel/src/MILBlob/Util/SubByteConversionUtils.hpp diff --git a/NOTICE.txt b/NOTICE.txt index ad9356512..df5f1b310 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -23,3 +23,26 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + + +This project contains content in the files coremltools/optimize/torch/layerwise_compression/_quant.py, +coremltools/optimize/torch/layerwise_compression/algorithms.py, +and coremltools/optimize/torch/layerwise_compression/layerwise_compressor.py which are adapted from +gtpq (https://github.com/IST-DASLab/gptq/). It also contains content in the file coremltools/optimize/torch/layerwise_compression/algorithms.py which is adapted from sparsegpt (https://github.com/IST-DASLab/sparsegpt). The license for these follows: + +Apache License 2.0 + +Copyright 2023 IST Austria Distributed Algorithms and Systems Lab + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/coremlpython/CoreMLPython.h b/coremlpython/CoreMLPython.h index 320ef6f6e..6bd6554f5 100644 --- a/coremlpython/CoreMLPython.h +++ b/coremlpython/CoreMLPython.h @@ -15,11 +15,33 @@ #import + +#ifndef BUILT_WITH_MACOS15_SDK +#define BUILT_WITH_MACOS15_SDK \ + !(TARGET_OS_OSX && (!defined(__MAC_15_0) || __MAC_OS_X_VERSION_MAX_ALLOWED < __MAC_15_0)) +#endif + +// Print BUILT_WITH_MACOS15_SDK value +#if BUILT_WITH_MACOS15_SDK +#pragma message ("Building with macOS 15+ SDK") +#else +#pragma message ("Building without macOS 15 SDK") +#endif + + namespace py = pybind11; namespace CoreML { namespace Python { + + struct State { +#if BUILT_WITH_MACOS15_SDK + // MLState must be wrapped in a C++ class for PyBind. + MLState* m_state = nil; +#endif + }; + class Model { private: MLModel *m_model = nil; @@ -35,13 +57,19 @@ namespace CoreML { Model(const Model&) = delete; Model& operator=(const Model&) = delete; ~Model(); - explicit Model(const std::string& urlStr, const std::string& computeUnits); + explicit Model(const std::string& urlStr, const std::string& computeUnits, const std::string& functionName); explicit Model(MLModel* m_model, NSURL* compiledUrl, bool deleteCompiledModelOnExit); - py::dict predict(const py::dict& input) const; py::list batchPredict(const py::list& batch) const; py::str getCompiledModelPath() const; + + py::dict predict(const py::dict& input, State* state=NULL) const; + +#if BUILT_WITH_MACOS15_SDK + State newState() const; +#endif + }; } } diff --git a/coremlpython/CoreMLPython.mm b/coremlpython/CoreMLPython.mm index 7f65f3af1..f818f4985 100644 --- a/coremlpython/CoreMLPython.mm +++ b/coremlpython/CoreMLPython.mm @@ -42,7 +42,7 @@ bool usingMacOS13OrHigher() { } } -Model::Model(const std::string& urlStr, const std::string& computeUnits) { +Model::Model(const std::string& urlStr, const std::string& computeUnits, const std::string& functionName) { @autoreleasepool { NSError *error = nil; @@ -80,6 +80,12 @@ bool usingMacOS13OrHigher() { MLModelConfiguration *configuration = [MLModelConfiguration new]; setComputeUnit(configuration, computeUnits); + if (!functionName.empty()) { +#if BUILT_WITH_MACOS15_SDK + configuration.functionName = [NSString stringWithUTF8String:functionName.c_str()]; +#endif + } + // Create MLModel m_model = [MLModel modelWithContentsOfURL:compiledUrl configuration:configuration error:&error]; Utils::handleError(error); @@ -94,13 +100,28 @@ bool usingMacOS13OrHigher() { { } -py::dict Model::predict(const py::dict& input) const { + +py::dict Model::predict(const py::dict& input, State* state) const { @autoreleasepool { NSError *error = nil; MLDictionaryFeatureProvider *inFeatures = Utils::dictToFeatures(input, &error); Utils::handleError(error); - id outFeatures = [m_model predictionFromFeatures:static_cast(inFeatures) - error:&error]; + + id outFeatures; +#if BUILT_WITH_MACOS15_SDK + if (state == NULL) { + outFeatures = [m_model predictionFromFeatures:static_cast(inFeatures) + error:&error]; + } else { + outFeatures = [m_model predictionFromFeatures:static_cast(inFeatures) + usingState:state->m_state + error:&error]; + } +#else + outFeatures = [m_model predictionFromFeatures:static_cast(inFeatures) + error:&error]; +#endif + Utils::handleError(error); return Utils::featuresToDict(outFeatures); } @@ -163,6 +184,15 @@ bool usingMacOS13OrHigher() { } +#if BUILT_WITH_MACOS15_SDK +State Model::newState() const { + State result; + result.m_state = [m_model newState]; + return result; +} +#endif + + py::bytes Model::autoSetSpecificationVersion(const py::bytes& modelBytes) { CoreML::Specification::Model model; @@ -207,14 +237,20 @@ bool usingMacOS13OrHigher() { py::module m("libcoremlpython", "CoreML.Framework Python bindings"); py::class_(m, "_MLModelProxy") - .def(py::init()) + .def(py::init()) .def("predict", &Model::predict) .def("batchPredict", &Model::batchPredict) .def("get_compiled_model_path", &Model::getCompiledModelPath) .def_static("auto_set_specification_version", &Model::autoSetSpecificationVersion) .def_static("maximum_supported_specification_version", &Model::maximumSupportedSpecificationVersion) +#if BUILT_WITH_MACOS15_SDK + .def("newState", &Model::newState) +#endif .def_static("compileModel", &Model::compileModel); + + py::class_(m, "_State", py::module_local()); + return m.ptr(); } diff --git a/coremltools/__init__.py b/coremltools/__init__.py index 30130e1ba..db16e8bf6 100644 --- a/coremltools/__init__.py +++ b/coremltools/__init__.py @@ -64,6 +64,9 @@ # New versions for iOS 17.0 _SPECIFICATION_VERSION_IOS_17 = 8 +# New versions for iOS 18.0 +_SPECIFICATION_VERSION_IOS_18 = 9 + class ComputeUnit(_Enum): ''' @@ -82,6 +85,7 @@ class ComputeUnit(_Enum): _SPECIFICATION_VERSION_IOS_15: "CoreML5", _SPECIFICATION_VERSION_IOS_16: "CoreML6", _SPECIFICATION_VERSION_IOS_17: "CoreML7", + _SPECIFICATION_VERSION_IOS_18: "CoreML8", } # Default specification version for each backend @@ -94,7 +98,7 @@ class ComputeUnit(_Enum): # expose unified converter in coremltools package level from .converters import ClassifierConfig from .converters import ColorLayout as colorlayout -from .converters import EnumeratedShapes, ImageType, RangeDim, Shape, TensorType, convert +from .converters import EnumeratedShapes, ImageType, RangeDim, Shape, StateType, TensorType, convert from .converters.mil._deployment_compatibility import AvailableTarget as target from .converters.mil.mil.passes.defs import quantization as transform from .converters.mil.mil.passes.defs.quantization import ComputePrecision as precision diff --git a/coremltools/_deps/__init__.py b/coremltools/_deps/__init__.py index 304a0a8df..a191f1700 100644 --- a/coremltools/_deps/__init__.py +++ b/coremltools/_deps/__init__.py @@ -153,7 +153,7 @@ def __get_sklearn_version(version): # --------------------------------------------------------------------------------------- _HAS_TORCH = True -_TORCH_MAX_VERSION = "2.2.0" +_TORCH_MAX_VERSION = "2.3.0" _HAS_TORCH_EXPORT_API = False try: import torch diff --git a/coremltools/converters/__init__.py b/coremltools/converters/__init__.py index bca49bbbc..3fc7f5d74 100644 --- a/coremltools/converters/__init__.py +++ b/coremltools/converters/__init__.py @@ -4,16 +4,15 @@ # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause # expose directories as imports -from . import libsvm -from . import sklearn -from . import xgboost +from . import libsvm, sklearn, xgboost from ._converters_entry import convert from .mil import ( ClassifierConfig, ColorLayout, - TensorType, + EnumeratedShapes, ImageType, RangeDim, Shape, - EnumeratedShapes, + StateType, + TensorType, ) diff --git a/coremltools/converters/_converters_entry.py b/coremltools/converters/_converters_entry.py index 7507c70e9..cc0d505dd 100644 --- a/coremltools/converters/_converters_entry.py +++ b/coremltools/converters/_converters_entry.py @@ -29,6 +29,7 @@ InputType, RangeDim, Shape, + StateType, TensorType, ) from coremltools.converters.mil.mil import Program, types @@ -73,6 +74,7 @@ def convert( package_dir=None, debug=False, pass_pipeline: Optional[PassPipeline] = None, + states=None, ): """ Convert a TensorFlow or PyTorch model to the Core ML model format as either @@ -403,7 +405,7 @@ def skip_real_div_ops(op): returned. An enum with the following possible values: - + * ``coremltools.ComputeUnit.ALL``: Use all compute units available, including the neural engine. * ``coremltools.ComputeUnit.CPU_ONLY``: Limit the model to only use the CPU. * ``coremltools.ComputeUnit.CPU_AND_GPU``: Use both the CPU and GPU, but not the neural engine. @@ -477,6 +479,50 @@ def skip_real_div_ops(op): mlmodel = ct.convert(model, pass_pipeline=ct.PassPipeline.DEFAULT_PALETTIZATION) + states: + Create a stateful ``mlprogram`` model + by providing the ``StateType`` in the ``states`` argument (for details see `MIL Input Types `_). + The stateful model is useful when converting a large language model with KV-Cache. + The name of ``StateType`` must match the key of the PyTorch ``named_buffers()`` method in the source traced model. + + The following example converts a torch model with a buffer called ``state_1``. + + .. sourcecode:: python + + class UpdateBufferModel(torch.nn.Module): + def __init__(self): + super(UpdateBufferModel, self).__init__() + self.register_buffer( + "state_1", torch.tensor(np.array([0, 0, 0], dtype=np.float32)) + ) + + def forward(self, x): + # In place update of the model state + self.state_1.add_(x) + return self.state_1 + + + model = UpdateBufferModel() + traced_model = torch.jit.trace(model, torch.tensor([1, 2, 3], dtype=torch.float32)) + + inputs = [ + ct.TensorType(shape=(1, 2)), + ] + states = [ + ct.StateType( + wrapped_type=ct.TensorType( + shape=(1, 2), + ), + name="state_1", + ), + ] + mlmodel = ct.convert( + traced_model, + inputs=inputs, + states=states, + minimum_deployment_target=ct.target.iOS18, + ) + Returns ------- @@ -526,8 +572,7 @@ def skip_real_div_ops(op): >>> results = mlmodel.predict({"input": example_input.numpy()}) >>> print(results['1651']) # 1651 is the node name given by PyTorch's JIT - See `Conversion Options `_ for - more advanced options. + For more options see `Conversion Options `_. """ _check_deployment_target(minimum_deployment_target) outputs_as_strings, outputs_as_tensor_or_image_types = _validate_outputs_argument(outputs) @@ -578,6 +623,15 @@ def skip_real_div_ops(op): and need_fp16_cast_pass ) + # Verify the inputs cannot contains state + if states is None: + states = [] + _verify_inputs_doesnot_contains_states(inputs) + + # states can only passed if the source is pytorch + if len(states) > 0 and exact_source != "pytorch": + raise ValueError("'states' can only be passed with pytorch source model.") + mlmodel = mil_convert( model, convert_from=exact_source, @@ -592,6 +646,7 @@ def skip_real_div_ops(op): specification_version=specification_version, main_pipeline=pass_pipeline, use_default_fp16_io=use_default_fp16_io, + states=states, ) if exact_target == "mlprogram" and mlmodel._input_has_infinite_upper_bound(): @@ -658,6 +713,20 @@ def _check_deployment_target(minimum_deployment_target): raise TypeError(msg.format(minimum_deployment_target)) +def _verify_inputs_doesnot_contains_states( + inputs: List[InputType], +) -> None: + """ + Verify that StateType is not present in the inputs. + """ + if inputs is None: + return + + for val in inputs: + if isinstance(val, StateType): + raise ValueError("'inputs' cannot contain an instance of StateType.") + + def _validate_outputs_argument(outputs): """ - validate properties that the "outputs" argument must satisfy, for instance, it should either be a list @@ -848,9 +917,9 @@ def _flatten_list(_inputs): elif exact_source == "pytorch": if _HAS_TORCH_EXPORT_API and isinstance(model, ExportedProgram): - if model.dialect != "EDGE": + if model.dialect not in ("ATEN", "EDGE"): raise NotImplementedError( - f"Conversion for models with only EDGE dialect is supported/tested. Provided Dialect: {model.dialect}" + f"Conversion for models with only ATEN or EDGE dialect is supported/tested. Provided Dialect: {model.dialect}" ) # TODO: rdar://115845792 ([Executorch] Handle user provided inputs/outputs in the convert API) diff --git a/coremltools/converters/mil/Makefile b/coremltools/converters/mil/Makefile index e7234e48d..ee605e450 100644 --- a/coremltools/converters/mil/Makefile +++ b/coremltools/converters/mil/Makefile @@ -8,7 +8,6 @@ SRC_PACKAGES=. TF_IOS13_TEST=../tensorflow/test MIL_TEST="." -MIL_TEST_INTERNAL="../../../../coremltools-internal/coremltools_internal/converters/mil" .PHONY: all lint test style checkstyle @@ -26,7 +25,7 @@ lint: ${PYTHON} -m pylint -j 0 ${SRC_PACKAGES} test: - ${PYTHON} -m pytest -W ignore::DeprecationWarning ${MIL_TEST} ${MIL_TEST_INTERNAL} + ${PYTHON} -m pytest -W ignore::DeprecationWarning ${MIL_TEST} test_ref: ${PYTHON} -m pytest -W ignore::DeprecationWarning ${TF_IOS13_TEST} diff --git a/coremltools/converters/mil/__init__.py b/coremltools/converters/mil/__init__.py index 64a17d126..91337b788 100644 --- a/coremltools/converters/mil/__init__.py +++ b/coremltools/converters/mil/__init__.py @@ -11,6 +11,6 @@ get_existing_symbol, get_new_symbol, get_new_variadic_symbol, mil_list, register_op) from .input_types import (ClassifierConfig, ColorLayout, EnumeratedShapes, - ImageType, InputType, RangeDim, Shape, TensorType) + ImageType, InputType, RangeDim, Shape, TensorType, StateType) from .frontend.tensorflow.tf_op_registry import register_tf_op from .frontend.torch import register_torch_op diff --git a/coremltools/converters/mil/_deployment_compatibility.py b/coremltools/converters/mil/_deployment_compatibility.py index db0122111..450a752c2 100644 --- a/coremltools/converters/mil/_deployment_compatibility.py +++ b/coremltools/converters/mil/_deployment_compatibility.py @@ -11,6 +11,7 @@ _SPECIFICATION_VERSION_IOS_15, _SPECIFICATION_VERSION_IOS_16, _SPECIFICATION_VERSION_IOS_17, + _SPECIFICATION_VERSION_IOS_18, ) @@ -21,6 +22,7 @@ class AvailableTarget(IntEnum): iOS15 = _SPECIFICATION_VERSION_IOS_15 iOS16 = _SPECIFICATION_VERSION_IOS_16 iOS17 = _SPECIFICATION_VERSION_IOS_17 + iOS18 = _SPECIFICATION_VERSION_IOS_18 # macOS versions (aliases of iOS versions) macOS10_15 = _SPECIFICATION_VERSION_IOS_13 @@ -29,6 +31,7 @@ class AvailableTarget(IntEnum): macOS12 = _SPECIFICATION_VERSION_IOS_15 macOS13 = _SPECIFICATION_VERSION_IOS_16 macOS14 = _SPECIFICATION_VERSION_IOS_17 + macOS15 = _SPECIFICATION_VERSION_IOS_18 # watchOS versions (aliases of iOS versions) watchOS6 = _SPECIFICATION_VERSION_IOS_13 @@ -36,6 +39,7 @@ class AvailableTarget(IntEnum): watchOS8 = _SPECIFICATION_VERSION_IOS_15 watchOS9 = _SPECIFICATION_VERSION_IOS_16 watchOS10 = _SPECIFICATION_VERSION_IOS_17 + watchOS11 = _SPECIFICATION_VERSION_IOS_18 # tvOS versions (aliases of iOS versions) tvOS13 = _SPECIFICATION_VERSION_IOS_13 @@ -43,6 +47,7 @@ class AvailableTarget(IntEnum): tvOS15 = _SPECIFICATION_VERSION_IOS_15 tvOS16 = _SPECIFICATION_VERSION_IOS_16 tvOS17 = _SPECIFICATION_VERSION_IOS_17 + tvOS18 = _SPECIFICATION_VERSION_IOS_18 # customized __str__ def __str__(self): diff --git a/coremltools/converters/mil/backend/mil/helper.py b/coremltools/converters/mil/backend/mil/helper.py index c123e0ece..481dae351 100644 --- a/coremltools/converters/mil/backend/mil/helper.py +++ b/coremltools/converters/mil/backend/mil/helper.py @@ -10,7 +10,18 @@ from coremltools.converters.mil.mil import types # For immediate values, those types are stored in bytes (MIL parser reads those types from bytes). -IMMEDIATE_VALUE_TYPES_IN_BYTES = (types.fp16, types.int8, types.uint8, types.uint32) +IMMEDIATE_VALUE_TYPES_IN_BYTES = ( + types.fp16, + types.int4, + types.int8, + types.uint1, + types.uint2, + types.uint3, + types.uint4, + types.uint6, + types.uint8, + types.uint32, +) def create_valuetype_scalar(data_type): @@ -251,8 +262,21 @@ def types_to_proto_primitive(valuetype): ) return types.BUILTIN_TO_PROTO_TYPES[valuetype] + def _get_offset_by_writing_data(output_var, blob_writer): - if output_var.val.dtype.kind == 'f' and output_var.val.dtype.itemsize == 4: + if output_var.dtype == types.int4: + offset = blob_writer.write_int4_data(np.ascontiguousarray(output_var.val.flatten())) + elif output_var.dtype == types.uint1: + offset = blob_writer.write_uint1_data(np.ascontiguousarray(output_var.val.flatten())) + elif output_var.dtype == types.uint2: + offset = blob_writer.write_uint2_data(np.ascontiguousarray(output_var.val.flatten())) + elif output_var.dtype == types.uint3: + offset = blob_writer.write_uint3_data(np.ascontiguousarray(output_var.val.flatten())) + elif output_var.dtype == types.uint4: + offset = blob_writer.write_uint4_data(np.ascontiguousarray(output_var.val.flatten())) + elif output_var.dtype == types.uint6: + offset = blob_writer.write_uint6_data(np.ascontiguousarray(output_var.val.flatten())) + elif output_var.val.dtype.kind == "f" and output_var.val.dtype.itemsize == 4: offset = blob_writer.write_float_data(np.ascontiguousarray(output_var.val.flatten())) elif output_var.val.dtype.kind == "f" and output_var.val.dtype.itemsize == 2: output_var_fp16_to_bytes_to_uint16 = np.frombuffer( @@ -269,6 +293,10 @@ def _get_offset_by_writing_data(output_var, blob_writer): offset = blob_writer.write_uint16_data(np.ascontiguousarray(output_var.val.flatten())) elif output_var.val.dtype.kind == "i" and output_var.val.dtype.itemsize == 2: offset = blob_writer.write_int16_data(np.ascontiguousarray(output_var.val.flatten())) + elif output_var.val.dtype.kind == "i" and output_var.val.dtype.itemsize == 4: + offset = blob_writer.write_int32_data(np.ascontiguousarray(output_var.val.flatten())) + elif output_var.val.dtype.kind == "u" and output_var.val.dtype.itemsize == 4: + offset = blob_writer.write_uint32_data(np.ascontiguousarray(output_var.val.flatten())) else: raise TypeError("Unsupported type, {}, for net buffer serialization.".format(output_var.val.dtype)) diff --git a/coremltools/converters/mil/backend/mil/load.py b/coremltools/converters/mil/backend/mil/load.py index 2d4742a2a..d59f64f7e 100644 --- a/coremltools/converters/mil/backend/mil/load.py +++ b/coremltools/converters/mil/backend/mil/load.py @@ -6,11 +6,16 @@ import os import warnings from collections import OrderedDict -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import numpy as np -from coremltools import _OPSET, _SPECIFICATION_VERSION_IOS_15, _SPECIFICATION_VERSION_IOS_17 +from coremltools import ( + _OPSET, + _SPECIFICATION_VERSION_IOS_15, + _SPECIFICATION_VERSION_IOS_17, + _SPECIFICATION_VERSION_IOS_18, +) from coremltools import _logger as logger from coremltools import proto from coremltools.converters.mil import mil @@ -65,6 +70,7 @@ def should_use_weight_file(val): and val.dtype in ['float16', 'float32', 'uint8', 'int8'] ) + class MILProtoExporter: """ An utility class to export a pymil program to milproto. @@ -78,6 +84,7 @@ def __init__( self.prog = prog self.weights_dir = weights_dir self.blob_writers = {} + self.weight_id_to_file_value = {} # mapping from weight_id to file value self.prog.validate(check_essential_scope=True) def translate_program_attributes(self) -> Dict[str, Any]: @@ -107,20 +114,44 @@ def get_blob_writer(self, weight_path: str) -> BlobWriter: def create_file_value(self, var: Var) -> proto.MIL_pb2.Value: """ Returns the mil proto file value of a var. + If weight_id is in self.weight_id_to_file_value, we return the value. """ - weight_path = self.get_weight_path(var.op) - blob_writer = self.get_blob_writer(weight_path) - offset = helper._get_offset_by_writing_data(var, blob_writer) - weight_file_name = os.path.basename(weight_path) - - return create_file_value_tensor( - file_name=os.path.join( - os.path.join("@model_path", _WEIGHTS_DIR_NAME), weight_file_name - ), - offset=offset, - dim=var.val.shape, - data_type=types_to_proto_primitive(var.sym_type.get_primitive()), - ) + + def create_file_value_helper(): + weight_path = self.get_weight_path(var.op) + blob_writer = self.get_blob_writer(weight_path) + offset = helper._get_offset_by_writing_data(var, blob_writer) + weight_file_name = os.path.basename(weight_path) + + # Get proto type for the primitive + if hasattr(var.sym_type, "get_primitive"): # tensor + primitive = var.sym_type.get_primitive() + else: # scalar + primitive = var.sym_type + proto_primitive = types_to_proto_primitive(primitive) + + return create_file_value_tensor( + file_name=os.path.join( + os.path.join("@model_path", _WEIGHTS_DIR_NAME), weight_file_name + ), + offset=offset, + dim=var.val.shape, + data_type=proto_primitive, + ) + + # use the cached file value + weight_id = var.op.weight_id + if weight_id is None: + return create_file_value_helper() + + if weight_id in self.weight_id_to_file_value: + assert weight_id is not None, "invalid weight_id" + return self.weight_id_to_file_value[weight_id] + + file_value = create_file_value_helper() + self.weight_id_to_file_value[weight_id] = file_value + + return file_value def get_milproto_value(self, var: Var) -> proto.MIL_pb2.Value: """ @@ -241,9 +272,65 @@ def types_to_proto(self, valuetype: type) -> proto.MIL_pb2.ValueType: return create_valuetype_list(length=length, elem_shape=elem_shape, dtype=dtype) elif types.is_dict(valuetype): return self.create_valuetype_dict(valuetype.T[0], valuetype.T[1]) + elif types.is_state(valuetype): + wrapped_type = valuetype.wrapped_type() + v_type = proto.MIL_pb2.ValueType() + v_type.stateType.wrappedType.CopyFrom(self.types_to_proto(wrapped_type)) + return v_type else: return create_valuetype_scalar(types_to_proto_primitive(valuetype)) + def translate_coreml_update_state_op(self, op: Operation) -> List[proto.MIL_pb2.Operation]: + """ + ``coreml_update_state`` is decomposed into ``write_state`` and ``read_state``. + """ + + def get_input_binding(param_name: str) -> proto.MIL_pb2.Argument: + arguments = [proto.MIL_pb2.Argument.Binding(name=op.inputs[param_name].name)] + args = proto.MIL_pb2.Argument() + args.arguments.extend(arguments) + return args + + res = [] + + # write_state + write_state_attrs = {"name": create_scalar_value(op.name + "_write_state")} + write_state_inputs = { + "input": get_input_binding("state"), + "data": get_input_binding("value"), + } + res.append( + proto.MIL_pb2.Operation( + type="write_state", + inputs=write_state_inputs, + attributes=write_state_attrs, + ) + ) + + # If the coreml_update_state is not feed into any ops or is not block outputs, + # we don't need the read_state op + if len(op.outputs[0].child_ops) == 0 and len(op.outputs[0].consuming_blocks) == 0: + return res + + # read_state + read_state_attrs = {"name": create_scalar_value(op.name)} + read_state_inputs = { + "input": get_input_binding("state"), + } + outputs = [ + proto.MIL_pb2.NamedValueType(name=v.name, type=self.types_to_proto(v.sym_type)) + for v in op.outputs + ] + res.append( + proto.MIL_pb2.Operation( + type="read_state", + inputs=read_state_inputs, + attributes=read_state_attrs, + outputs=outputs, + ) + ) + return res + def translate_generic_op( self, op: Operation, literal_params: Optional[List[str]] = None ) -> proto.MIL_pb2.Operation: @@ -358,12 +445,10 @@ def feeds_to_only_constexprs(op: Operation) -> bool: # rdar://98689808 (Reshape_like should also accept const value from non literal input) literal_params = ["begins", "ends", "end_masks"] proto_ops.append(self.translate_generic_op(op, literal_params)) + elif op_cls_name == "coreml_update_state": + proto_ops.extend(self.translate_coreml_update_state_op(op)) else: - # A single pymil op might be decomposed into multiple ops - ops = self.translate_generic_op(op) - if not isinstance(ops, list): - ops = [ops] - proto_ops.extend(ops) + proto_ops.append(self.translate_generic_op(op)) inputs = [] if not isinstance(block, Function): @@ -491,8 +576,6 @@ class CoreMLProtoExporter: An utility class to export a pymil program to coreml model. """ - _DEFAULT_FUNCTION_NAME = "main" - def __init__( self, prog: mil.Program, @@ -502,6 +585,7 @@ def __init__( classifier_config: ClassifierConfig, convert_to: str, convert_from: str, + export_multi_functions: bool, ): self.prog = prog self.mil_proto = mil_proto @@ -510,23 +594,26 @@ def __init__( self.classifier_config = classifier_config self.convert_to = convert_to self.convert_from = convert_from + self.export_multi_functions = export_multi_functions self.prog.validate(check_essential_scope=True) @staticmethod - def get_additional_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]: + def _decouple_state_and_input( + input_features: List[proto.Model_pb2.FeatureDescription], + ) -> Tuple[List[proto.Model_pb2.FeatureDescription], List[proto.Model_pb2.FeatureDescription]]: """ - Get additional coreml proto related kwargs. + Utils seperates state input from non-state input features. """ - return {} + state_features = [] + non_state_input_features = [] - @staticmethod - def _try_convert_other_input_type( - input_var: Var, input_features: List[proto.Model_pb2.FeatureDescription] - ) -> bool: - """ - Try to convert an input var with additional type. - """ - return False + for input in input_features: + if input.type.WhichOneof("Type") == "stateType": + state_features.append(input) + else: + non_state_input_features.append(input) + + return state_features, non_state_input_features def get_func_input(self, func: mil.Function) -> List[proto.Model_pb2.FeatureDescription]: """ @@ -615,7 +702,36 @@ def get_func_input(self, func: mil.Function) -> List[proto.Model_pb2.FeatureDesc input_features.append( proto.Model_pb2.FeatureDescription(name=var.name, type=input_feature_type) ) - elif not self._try_convert_other_input_type(var, input_features): + elif types.is_state(var.sym_type): + # shape for state input cannot be symbolic + shape = var.sym_type.wrapped_type().get_shape() + if any_variadic(shape): + raise ValueError("Variable rank model states are not supported!") + if any_symbolic(shape): + raise ValueError("Flexible shape model states are not supported!") + + # Core ML only support fp16 for state + if not var.dtype == types.fp16: + raise ValueError( + f"State only support fp16 dtype. Got input var {var.name} with dtype {types.builtin_to_string(var.dtype)}." + ) + + # create the input feature type + array_type = proto.FeatureTypes_pb2.ArrayFeatureType( + shape=shape, dataType=cast_to_framework_io_dtype(var, False) + ) + + state_feature_type = proto.FeatureTypes_pb2.StateFeatureType() + state_feature_type.arrayType.CopyFrom(array_type) + + input_feature_type = proto.FeatureTypes_pb2.FeatureType() + input_feature_type.stateType.CopyFrom(state_feature_type) + + # append feature to the input features list + input_features.append( + proto.Model_pb2.FeatureDescription(name=var.name, type=input_feature_type) + ) + else: raise NotImplementedError(f"Unsupported input type {var.sym_type}.") if not is_input_shape_symbolic: @@ -713,7 +829,6 @@ def get_func_input(self, func: mil.Function) -> List[proto.Model_pb2.FeatureDesc 'There is "None" dim in TF input placeholder. Please consider specifying ' 'input shapes by using the "inputs" param in ct.convert().' ) - return input_features def get_func_output(self, func: mil.Function) -> List[proto.Model_pb2.FeatureDescription]: @@ -807,16 +922,6 @@ def get_func_output(self, func: mil.Function) -> List[proto.Model_pb2.FeatureDes return output_features - def create_model_description( - self, - input_features: List[proto.Model_pb2.FeatureDescription], - output_features: List[proto.Model_pb2.FeatureDescription], - ) -> proto.Model_pb2.ModelDescription: - """ - Create model description from input and output features - """ - return proto.Model_pb2.ModelDescription(input=input_features, output=output_features) - def get_coreml_model( self, input: Dict[str, List[proto.Model_pb2.FeatureDescription]], @@ -825,25 +930,64 @@ def get_coreml_model( ) -> proto.Model_pb2.Model: """ Utils to get a coreml model description. + For the multifunction export, we utilize the FunctionDescription proto message. """ - # Model description - input_features = input[self._DEFAULT_FUNCTION_NAME] - output_features = output[self._DEFAULT_FUNCTION_NAME] - desc = self.create_model_description(input_features, output_features) - - if self.classifier_config is not None: - desc.predictedFeatureName = self.predicted_feature_name - desc.predictedProbabilitiesName = self.predicted_probabilities_name - - # Manually edit output type of predictedFeatureName. - # It doesn't use MLMultiArray and really uses a "primitive" type. - for output in desc.output: - if output.name == self.predicted_feature_name: - if type(self.classifier_config.class_labels[0]) == int: - output.type.int64Type.MergeFromString(b"") - else: - output.type.stringType.MergeFromString(b"") - break + if self.export_multi_functions: + # For multifunction export, we use the FunctionDescription + if specification_version < _SPECIFICATION_VERSION_IOS_18: + raise ValueError( + "minimum_deployment_target for multi-functions export should be iOS18+." + ) + + if self.classifier_config is not None: + # TODO: This should be fixed in rdar://123660416 ([New Feature][Multi-functions] Enable classifier for multi-functions CoreML model) + raise NotImplementedError("classifier model not supported in multi-functions export.") + + function_desc = [] + for func_name in input.keys(): + state_features, non_state_input_features = self._decouple_state_and_input( + input[func_name] + ) + desc = proto.Model_pb2.FunctionDescription( + name=func_name, + input=non_state_input_features, + output=output[func_name], + state=state_features, + ) + function_desc.append(desc) + + desc = proto.Model_pb2.ModelDescription( + functions=function_desc, + defaultFunctionName=self.prog.default_function_name, + ) + + else: + # single function export + input_features = input[self.prog.default_function_name] + output_features = output[self.prog.default_function_name] + state_features, non_state_input_features = self._decouple_state_and_input( + input_features + ) + + desc = proto.Model_pb2.ModelDescription( + input=non_state_input_features, + output=output_features, + state=state_features, + ) + + if self.classifier_config is not None: + desc.predictedFeatureName = self.predicted_feature_name + desc.predictedProbabilitiesName = self.predicted_probabilities_name + + # Manually edit output type of predictedFeatureName. + # It doesn't use MLMultiArray and really uses a "primitive" type. + for output in desc.output: + if output.name == self.predicted_feature_name: + if type(self.classifier_config.class_labels[0]) == int: + output.type.int64Type.MergeFromString(b"") + else: + output.type.stringType.MergeFromString(b"") + break # Create ML Model model = proto.Model_pb2.Model(description=desc, specificationVersion=specification_version) @@ -871,7 +1015,8 @@ def export( ) # Set optional inputs for main function - _set_optional_inputs(model, self.prog.functions["main"].input_types) + if "main" in self.prog.functions: + _set_optional_inputs(model, self.prog.functions["main"].input_types) return model @@ -883,8 +1028,8 @@ def load( specification_version: Optional[int] = _SPECIFICATION_VERSION_IOS_15, **kwargs, ) -> proto.Model_pb2.Model: - if "main" not in prog.functions: - raise ValueError("main function not found in program") + if prog.default_function_name not in prog.functions: + raise ValueError(f"Default function {prog.default_function_name} not found in program") # if user has specified "ClassifierConfig", then add the "classify" op to the prog classifier_config = kwargs.get("classifier_config", None) @@ -914,7 +1059,6 @@ def load( return model # create a CoreML model protobuf - exporter_kwargs = CoreMLProtoExporter.get_additional_kwargs(kwargs) coreml_proto_exporter = CoreMLProtoExporter( prog, mil_proto, @@ -923,6 +1067,6 @@ def load( classifier_config=kwargs.get("classifier_config", None), convert_to=kwargs.get("convert_to", None), convert_from=kwargs.get("convert_from", None), - **exporter_kwargs, + export_multi_functions=kwargs.get("export_multi_functions", False), ) return coreml_proto_exporter.export(specification_version) diff --git a/coremltools/converters/mil/backend/mil/passes/adjust_io_to_supported_types.py b/coremltools/converters/mil/backend/mil/passes/adjust_io_to_supported_types.py index acbc729d8..8a107b518 100644 --- a/coremltools/converters/mil/backend/mil/passes/adjust_io_to_supported_types.py +++ b/coremltools/converters/mil/backend/mil/passes/adjust_io_to_supported_types.py @@ -165,9 +165,12 @@ def _adjust_main_outputs(func): new_outputs = [] for output_var in func.outputs: output_type = output_var.sym_type + # classify outputs contains type int64 output variables, which should not be casted. if ( - types.is_tensor(output_type) or types.is_scalar(output_type) - ) and output_var.dtype not in _IO_SUPPORTED_TYPES: + (types.is_tensor(output_type) or types.is_scalar(output_type)) + and output_var.dtype not in _IO_SUPPORTED_TYPES + and output_var.op.op_type != "classify" + ): output_dtype_str = types.builtin_to_string(output_var.dtype) target_dtype = "int32" if types.is_int(output_var.dtype) else "fp32" logger.warning( diff --git a/coremltools/converters/mil/backend/mil/passes/sanitize_name_strings.py b/coremltools/converters/mil/backend/mil/passes/sanitize_name_strings.py index b5704ceab..50f595029 100644 --- a/coremltools/converters/mil/backend/mil/passes/sanitize_name_strings.py +++ b/coremltools/converters/mil/backend/mil/passes/sanitize_name_strings.py @@ -19,6 +19,11 @@ def apply(self, prog): for f in prog.functions.values(): sanitizer_vars = NameSanitizer(prefix="var_") sanitizer_ops = NameSanitizer(prefix="op_") - NameSanitizer.sanitize_block( - f, sanitizer_vars, sanitizer_ops, prog.functions["main"].input_types - ) + # TODO: rdar://126498947 ([Infra] Investigate the name sanitizer on multifunction model) + if "main" in prog.functions: + NameSanitizer.sanitize_block( + f, + sanitizer_vars, + sanitizer_ops, + prog.functions["main"].input_types, + ) diff --git a/coremltools/converters/mil/backend/mil/passes/test_passes.py b/coremltools/converters/mil/backend/mil/passes/test_passes.py index 84b7cf5ca..41c80c355 100644 --- a/coremltools/converters/mil/backend/mil/passes/test_passes.py +++ b/coremltools/converters/mil/backend/mil/passes/test_passes.py @@ -10,13 +10,15 @@ import pytest import coremltools as ct -from coremltools.converters.mil._deployment_compatibility import \ - AvailableTarget as target +from coremltools.converters.mil._deployment_compatibility import AvailableTarget as target from coremltools.converters.mil.mil import Builder as mb -from coremltools.converters.mil.mil import types +from coremltools.converters.mil.mil import mil_list, types from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY from coremltools.converters.mil.testing_utils import ( - apply_pass_and_basic_check, assert_model_is_valid, get_op_types_in_program) + apply_pass_and_basic_check, + assert_model_is_valid, + get_op_types_in_program, +) class TestAdjustToSupportedTypes: @@ -454,6 +456,22 @@ def prog(x): assert get_op_types_in_program(prev_prog) == ['cast'] assert get_op_types_in_program(prog) == [] + @staticmethod + def test_classify_no_affected(): + """ + If the outputs are from a classify op, it should not be affected by this graph pass. + """ + + @mb.program(input_specs=[mb.TensorSpec(shape=(3,))]) + def prog(x): + classes = [np.int64(x) for x in range(3)] + classes_var = mb.const(val=mil_list(classes)) + return mb.classify(probabilities=x, classes=classes_var) + + apply_pass_and_basic_check(prog, "mil_backend::adjust_io_to_supported_types") + assert get_op_types_in_program(prog) == ["classify"] + + class TestImagePreprocessingPass: def test_program_grayscale(self): diff --git a/coremltools/converters/mil/backend/mil/test_load.py b/coremltools/converters/mil/backend/mil/test_load.py new file mode 100644 index 000000000..35d7db812 --- /dev/null +++ b/coremltools/converters/mil/backend/mil/test_load.py @@ -0,0 +1,1098 @@ +# Copyright (c) 2021, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import math +import platform +import shutil +import tempfile +from typing import List, Union + +import numpy as np +import pytest + +import coremltools as ct +from coremltools import _SPECIFICATION_VERSION_IOS_18, proto +from coremltools.converters.mil import mil +from coremltools.converters.mil.converter import mil_convert as _mil_convert +from coremltools.converters.mil.mil import get_new_symbol, types +from coremltools.converters.mil.mil.builder import Builder as mb +from coremltools.converters.mil.mil.ops.tests.iOS18.test_compression import ( + TestConstexprLut as _TestConstexprLut, +) +from coremltools.converters.mil.mil.program import Symbol +from coremltools.models.utils import _macos_version + + +class TestMILFlexibleShapes: + @mb.program(input_specs=[mb.TensorSpec(shape=[1, 3, Symbol("H"), Symbol("W")])]) + def basic_network(x): + return mb.relu(x=x) + + def test_mil_enumerated_multiarray(self): + enumerated_shapes = tuple([(1, 3, 10, 10), (1, 3, 10, 20), (1, 3, 10, 30)]) + input_shape = [ct.TensorType(name="x", shape=ct.EnumeratedShapes(shapes=enumerated_shapes))] + mlmodel = ct.convert( + self.basic_network, source="milinternal", convert_to="mlprogram", inputs=input_shape + ) + input_spec = mlmodel.get_spec().description.input + assert len(input_spec) == 1, "1 input expected, got {} instead".format(len(input_spec)) + assert input_spec[0].name == "x", "input name in MLModel is {}, 'x' is expected".format( + input_spec[0].name + ) + assert ( + input_spec[0].type.WhichOneof("Type") == "multiArrayType" + ), "Expected multiArrayType, got {}".format(input_spec[0].type.WhichOneof("Type")) + assert ( + input_spec[0].type.multiArrayType.WhichOneof("ShapeFlexibility") == "enumeratedShapes" + ), "Expected enumeratedShapes in ShapeFlexibility" + + spec_default_shape = [s for s in input_spec[0].type.multiArrayType.shape] + spec_enumerated_shapes = set() + for enumerated in input_spec[0].type.multiArrayType.enumeratedShapes.shapes: + spec_enumerated_shapes.add(tuple([s for s in enumerated.shape])) + assert spec_default_shape == [ + 1, + 3, + 10, + 10, + ], "Expected default shape to be [1, 3, 10, 10], got {} instead".format( + str(spec_default_shape) + ) + assert spec_enumerated_shapes == set(enumerated_shapes), "Enumerated shape mismatch" + + def test_mil_enumerated_multiarray_with_default(self): + enumerated_shapes = tuple([(1, 3, 10, 10), (1, 3, 10, 20), (1, 3, 10, 30)]) + input_shape = [ + ct.TensorType( + name="x", + shape=ct.EnumeratedShapes(shapes=enumerated_shapes, default=(1, 3, 10, 30)), + ) + ] + mlmodel = ct.convert( + self.basic_network, source="milinternal", convert_to="mlprogram", inputs=input_shape + ) + input_spec = mlmodel.get_spec().description.input + assert len(input_spec) == 1, "1 input expected, got {} instead".format(len(input_spec)) + assert input_spec[0].name == "x", "input name in MLModel is {}, 'x' is expected".format( + input_spec[0].name + ) + assert ( + input_spec[0].type.WhichOneof("Type") == "multiArrayType" + ), "Expected multiArrayType, got {}".format(input_spec[0].type.WhichOneof("Type")) + assert ( + input_spec[0].type.multiArrayType.WhichOneof("ShapeFlexibility") == "enumeratedShapes" + ), "Expected enumeratedShapes in ShapeFlexibility" + + spec_default_shape = [s for s in input_spec[0].type.multiArrayType.shape] + spec_enumerated_shapes = set() + for enumerated in input_spec[0].type.multiArrayType.enumeratedShapes.shapes: + spec_enumerated_shapes.add(tuple([s for s in enumerated.shape])) + assert spec_default_shape == [ + 1, + 3, + 10, + 30, + ], "Expected default shape to be [1, 3, 10, 10], got {} instead".format( + str(spec_default_shape) + ) + assert spec_enumerated_shapes == set(enumerated_shapes), "Enumerated shape mismatch" + + def test_mil_enumerated_image(self): + enumerated_shapes = tuple([(1, 3, 10, 10), (1, 3, 10, 20), (1, 3, 10, 30)]) + input_shape = [ct.ImageType(name="x", shape=ct.EnumeratedShapes(shapes=enumerated_shapes))] + mlmodel = ct.convert( + self.basic_network, source="milinternal", convert_to="mlprogram", inputs=input_shape + ) + input_spec = mlmodel.get_spec().description.input + assert len(input_spec) == 1, "1 input expected, got {} instead".format(len(input_spec)) + assert input_spec[0].name == "x", "input name in MLModel is {}, 'x' is expected".format( + input_spec[0].name + ) + assert ( + input_spec[0].type.WhichOneof("Type") == "imageType" + ), "Expected imageType, got {}".format(input_spec[0].type.WhichOneof("Type")) + assert ( + input_spec[0].type.imageType.WhichOneof("SizeFlexibility") == "enumeratedSizes" + ), "Expected enumeratedShapes in ShapeFlexibility" + + spec_H = input_spec[0].type.imageType.height + spec_W = input_spec[0].type.imageType.width + assert ( + spec_H == 10 and spec_W == 10 + ), "expected [H, W] == [10, 10], got [{}, {}] instead".format(spec_H, spec_W) + + spec_enumerated_shapes = set() + for enumerated in input_spec[0].type.imageType.enumeratedSizes.sizes: + spec_enumerated_shapes.add(tuple([1, 3, enumerated.height, enumerated.width])) + assert spec_enumerated_shapes == set(enumerated_shapes), "Enumerated shape mismatch" + + def test_mil_enumerated_image_with_default(self): + enumerated_shapes = tuple([(1, 3, 10, 10), (1, 3, 10, 20), (1, 3, 10, 30)]) + input_shape = [ + ct.ImageType( + name="x", + shape=ct.EnumeratedShapes(shapes=enumerated_shapes, default=(1, 3, 10, 30)), + ) + ] + mlmodel = ct.convert( + self.basic_network, source="milinternal", convert_to="mlprogram", inputs=input_shape + ) + input_spec = mlmodel.get_spec().description.input + assert len(input_spec) == 1, "1 input expected, got {} instead".format(len(input_spec)) + assert input_spec[0].name == "x", "input name in MLModel is {}, 'x' is expected".format( + input_spec[0].name + ) + assert ( + input_spec[0].type.WhichOneof("Type") == "imageType" + ), "Expected imageType, got {}".format(input_spec[0].type.WhichOneof("Type")) + assert ( + input_spec[0].type.imageType.WhichOneof("SizeFlexibility") == "enumeratedSizes" + ), "Expected enumeratedShapes in ShapeFlexibility" + + spec_H = input_spec[0].type.imageType.height + spec_W = input_spec[0].type.imageType.width + assert ( + spec_H == 10 and spec_W == 30 + ), "expected [H, W] == [10, 30], got [{}, {}] instead".format(spec_H, spec_W) + + spec_enumerated_shapes = set() + for enumerated in input_spec[0].type.imageType.enumeratedSizes.sizes: + spec_enumerated_shapes.add(tuple([1, 3, enumerated.height, enumerated.width])) + assert spec_enumerated_shapes == set(enumerated_shapes), "Enumerated shape mismatch" + + def test_mil_ranged_multiarray(self): + input_shape = [ct.TensorType(name="x", shape=(1, 3, 10, ct.RangeDim(10, 30)))] + mlmodel = ct.convert( + self.basic_network, source="milinternal", convert_to="mlprogram", inputs=input_shape + ) + input_spec = mlmodel.get_spec().description.input + assert len(input_spec) == 1, "1 input expected, got {} instead".format(len(input_spec)) + assert input_spec[0].name == "x", "input name in MLModel is {}, 'x' is expected".format( + input_spec[0].name + ) + assert ( + input_spec[0].type.WhichOneof("Type") == "multiArrayType" + ), "Expected multiArrayType, got {}".format(input_spec[0].type.WhichOneof("Type")) + assert ( + input_spec[0].type.multiArrayType.WhichOneof("ShapeFlexibility") == "shapeRange" + ), "Expected shapeRange in ShapeFlexibility" + + spec_default_shape = [s for s in input_spec[0].type.multiArrayType.shape] + ranged_shapes = [(1, 1), (3, 3), (10, 10), (10, 30)] + spec_ranged_shapes = [] + for range_dim in input_spec[0].type.multiArrayType.shapeRange.sizeRanges: + spec_ranged_shapes.append(tuple([range_dim.lowerBound, range_dim.upperBound])) + assert spec_default_shape == [ + 1, + 3, + 10, + 10, + ], "Expected default shape to be [1, 3, 10, 10], got {} instead".format( + str(spec_default_shape) + ) + assert spec_ranged_shapes == ranged_shapes, "Enumerated shape mismatch" + + def test_mil_ranged_multiarray_with_default(self): + input_shape = [ct.TensorType(name="x", shape=(1, 3, 10, ct.RangeDim(10, 30, default=20)))] + mlmodel = ct.convert( + self.basic_network, source="milinternal", convert_to="mlprogram", inputs=input_shape + ) + input_spec = mlmodel.get_spec().description.input + assert len(input_spec) == 1, "1 input expected, got {} instead".format(len(input_spec)) + assert input_spec[0].name == "x", "input name in MLModel is {}, 'x' is expected".format( + input_spec[0].name + ) + assert ( + input_spec[0].type.WhichOneof("Type") == "multiArrayType" + ), "Expected multiArrayType, got {}".format(input_spec[0].type.WhichOneof("Type")) + assert ( + input_spec[0].type.multiArrayType.WhichOneof("ShapeFlexibility") == "shapeRange" + ), "Expected shapeRange in ShapeFlexibility" + + spec_default_shape = [s for s in input_spec[0].type.multiArrayType.shape] + ranged_shapes = [(1, 1), (3, 3), (10, 10), (10, 30)] + spec_ranged_shapes = [] + for range_dim in input_spec[0].type.multiArrayType.shapeRange.sizeRanges: + spec_ranged_shapes.append(tuple([range_dim.lowerBound, range_dim.upperBound])) + assert spec_default_shape == [ + 1, + 3, + 10, + 20, + ], "Expected default shape to be [1, 3, 10, 20], got {} instead".format( + str(spec_default_shape) + ) + assert spec_ranged_shapes == ranged_shapes, "Enumerated shape mismatch" + + def test_mil_ranged_image(self): + input_shape = [ct.ImageType(name="x", shape=(1, 3, 10, ct.RangeDim(10, 30)))] + mlmodel = ct.convert( + self.basic_network, source="milinternal", convert_to="mlprogram", inputs=input_shape + ) + input_spec = mlmodel.get_spec().description.input + assert len(input_spec) == 1, "1 input expected, got {} instead".format(len(input_spec)) + assert input_spec[0].name == "x", "input name in MLModel is {}, 'x' is expected".format( + input_spec[0].name + ) + assert ( + input_spec[0].type.WhichOneof("Type") == "imageType" + ), "Expected imageType, got {}".format(input_spec[0].type.WhichOneof("Type")) + assert ( + input_spec[0].type.imageType.WhichOneof("SizeFlexibility") == "imageSizeRange" + ), "Expected imageSizeRange in ShapeFlexibility" + + spec_H = input_spec[0].type.imageType.height + spec_W = input_spec[0].type.imageType.width + assert ( + spec_H == 10 and spec_W == 10 + ), "expected [H, W] == [10, 10], got [{}, {}] instead".format(spec_H, spec_W) + + spec_H_range = [ + input_spec[0].type.imageType.imageSizeRange.heightRange.lowerBound, + input_spec[0].type.imageType.imageSizeRange.heightRange.upperBound, + ] + spec_W_range = [ + input_spec[0].type.imageType.imageSizeRange.widthRange.lowerBound, + input_spec[0].type.imageType.imageSizeRange.widthRange.upperBound, + ] + assert spec_H_range == [10, 10], "Ranged height mismatch" + assert spec_W_range == [10, 30], "Ranged width mismatch" + + def test_mil_ranged_image_with_default(self): + input_shape = [ct.ImageType(name="x", shape=(1, 3, 10, ct.RangeDim(10, 30, default=20)))] + mlmodel = ct.convert( + self.basic_network, source="milinternal", convert_to="mlprogram", inputs=input_shape + ) + input_spec = mlmodel.get_spec().description.input + assert len(input_spec) == 1, "1 input expected, got {} instead".format(len(input_spec)) + assert input_spec[0].name == "x", "input name in MLModel is {}, 'x' is expected".format( + input_spec[0].name + ) + assert ( + input_spec[0].type.WhichOneof("Type") == "imageType" + ), "Expected imageType, got {}".format(input_spec[0].type.WhichOneof("Type")) + assert ( + input_spec[0].type.imageType.WhichOneof("SizeFlexibility") == "imageSizeRange" + ), "Expected imageSizeRange in ShapeFlexibility" + + spec_H = input_spec[0].type.imageType.height + spec_W = input_spec[0].type.imageType.width + assert ( + spec_H == 10 and spec_W == 20 + ), "expected [H, W] == [10, 20], got [{}, {}] instead".format(spec_H, spec_W) + + spec_H_range = [ + input_spec[0].type.imageType.imageSizeRange.heightRange.lowerBound, + input_spec[0].type.imageType.imageSizeRange.heightRange.upperBound, + ] + spec_W_range = [ + input_spec[0].type.imageType.imageSizeRange.widthRange.lowerBound, + input_spec[0].type.imageType.imageSizeRange.widthRange.upperBound, + ] + assert spec_H_range == [10, 10], "Ranged height mismatch" + assert spec_W_range == [10, 30], "Ranged width mismatch" + + +class TestMILDefaultValues: + @mb.program(input_specs=[mb.TensorSpec(shape=[1]), mb.TensorSpec(shape=[1])]) + def basic_network(x, y): + return mb.add(x=x, y=y, name="output") + + def test_mil_default_value_to_proto(self): + program_input_spec = [ + ct.TensorType(name="x", shape=[1], default_value=np.array([1.0]).astype(np.float32)), + ct.TensorType(name="y", shape=[1]), + ] + mlmodel = ct.convert(self.basic_network, convert_to="mlprogram", inputs=program_input_spec) + input_spec = mlmodel.get_spec().description.input + assert len(input_spec) == 2, "2 input expected, got {} instead".format(len(input_spec)) + assert input_spec[0].name == "x", "input name in MLModel is {}, 'x' is expected".format( + input_spec[0].name + ) + assert ( + input_spec[0].type.WhichOneof("Type") == "multiArrayType" + ), "Expected multiArrayType, got {}".format(input_spec[0].type.WhichOneof("Type")) + assert ( + input_spec[0].type.multiArrayType.WhichOneof("defaultOptionalValue") + == "floatDefaultValue" + ), "Expected floatDefaultValue, got {} instead".format( + input_spec[0].type.multiArrayType.WhichOneof("defaultOptionalValue") + ) + assert input_spec[0].type.multiArrayType.floatDefaultValue == 1.0 + + def test_mil_default_value_runtime(self): + program_input_spec = [ + ct.TensorType(name="x", shape=[1], default_value=np.array([1.0]).astype(np.float32)), + ct.TensorType(name="y", shape=[1]), + ] + mlmodel = ct.convert(self.basic_network, convert_to="mlprogram", inputs=program_input_spec) + + if _macos_version() < (12, 0): + # Can only get predictions for ml program on macOS 12+ + return + + res = mlmodel.predict({"x": np.array([3.0]), "y": np.array([2.0])}) + assert res["output"][0] == 5.0 + + res = mlmodel.predict({"y": np.array([2.0])}) + assert res["output"][0] == 3.0 + + +class TestMILProtoLoad: + """Verify that the MIL Proto in mlmodel is correctly loaded in iOS18+.""" + + @staticmethod + @pytest.mark.parametrize("opset_version", [ct.target.iOS17, ct.target.iOS18]) + def test_constexpr_use_inputs_instead_of_attributes(opset_version): + """Test the constexpr uses inputs instead of attributes starting from iOS18.""" + + @mb.program(input_specs=[], opset_version=ct.target.iOS17) + def prog_ios17(): + return mb.constexpr_lut_to_dense( + lut=np.array([1.0, 2.0, 3.0, 4.0]), + indices=np.array([10, 4]).astype(np.uint8), + shape=np.array([5]).astype(np.uint32), + ) + + @mb.program(input_specs=[], opset_version=ct.target.iOS18) + def prog_ios18(): + return mb.constexpr_lut_to_dense( + indices=np.array([4, 8, 10, 13, 24, 5, 6, 9, 13, 31, 17, 7, 2, 8, 3, 1]) + .reshape((2, 4, 2)) + .astype(np.uint8), + lut=_TestConstexprLut._generate_lut(shape=(1, 2, 1, 256, 3)), + vector_axis=1, + ) + + mlmodel = ct.convert( + prog_ios17 if opset_version == ct.target.iOS17 else prog_ios18, + convert_to="mlprogram", + minimum_deployment_target=opset_version, + ) + + # Iterates the milproto in mlmodel to make sure lut op uses inputs instead of attributes. + mil = mlmodel.get_spec().mlProgram + for function in mil.functions.values(): + for block in function.block_specializations.values(): + for op in block.operations: + if op.type == "constexpr_lut_to_dense": + # The "attributes" field has at least one value for "name". + expected_attributes_num = 1 + expected_inputs_num = 0 + if opset_version >= ct.target.iOS18: + # Since iOS18, constexpr ops use inputs instead of attributes in milproto. + expected_inputs_num += 3 + else: + expected_attributes_num += 3 + + assert len(op.attributes.values()) == expected_attributes_num + assert len(op.inputs.values()) == expected_inputs_num + + @staticmethod + def test_constexpr_multiple_outputs(): + """Starting from iOS18 there are constexpr ops that have multiple outputs.""" + + @mb.program(input_specs=[], opset_version=ct.target.iOS18) + def prog(): + return mb.constexpr_sparse_blockwise_shift_scale( + data_mask=np.array([[1, 1, 0, 0], [1, 1, 1, 0], [0, 0, 1, 1], [1, 1, 0, 0]]).astype( + types.np_uint1_dtype + ), + nonzero_data=np.array([10, 11, 3, 4, 5, 6, 7, 8, 9]).astype(np.int8), + scale=np.array([[0.1, 0.2, 0.3, 0.4]]), + offset=np.array([[1, 2, 3, 4]]).astype(np.int8), + )[1] + + mlmodel = ct.convert( + prog, + convert_to="mlprogram", + minimum_deployment_target=ct.target.iOS18, + ) + + mil = mlmodel.get_spec().mlProgram + for function in mil.functions.values(): + for block in function.block_specializations.values(): + for op in block.operations: + if op.type == "constexpr_sparse_blockwise_shift_scale": + assert len(op.outputs) == 2 + + @staticmethod + def test_sub_byte_immediate_value(): + """ + Test the sub-byte immediate value tensor is exported as packed bytes. + + The sub-byte file value is tested in `coremltools/test/blob/test_weights.py` which + is not in the scope of this test. + """ + + @mb.program(input_specs=[], opset_version=ct.target.iOS18) + def prog(): + return mb.constexpr_blockwise_shift_scale( + data=np.array([-8, 7]).reshape((1, 2, 1)).astype(types.np_int4_dtype), + scale=np.array([4]).reshape((1, 1, 1)).astype(np.float16), + offset=np.array([4]).reshape((1, 1, 1)).astype(types.np_int4_dtype), + ) + + mlmodel = ct.convert( + prog, + convert_to="mlprogram", + minimum_deployment_target=ct.target.iOS18, + ) + + mil = mlmodel.get_spec().mlProgram + for function in mil.functions.values(): + for block in function.block_specializations.values(): + for op in block.operations: + if op.type == "constexpr_blockwise_shift_scale": + bytes_val = ( + op.inputs["data"].arguments[0].value.immediateValue.tensor.bytes.values + ) + # The two 4-bit values should be packed into a single byte. + assert len(bytes_val) == 1 + + @staticmethod + def check_functions_description( + mlmodel: ct.models.MLModel, + expect_function_names: List[str], + expected_default_function_name: str, + ) -> None: + spec = mlmodel.get_spec() + desc = spec.description + assert len(desc.functions) == len(expect_function_names) + for i in range(len(expect_function_names)): + assert desc.functions[i].name == expect_function_names[i] + assert desc.defaultFunctionName == expected_default_function_name + + @staticmethod + def convert_and_save(prog: mil.Program) -> str: + mlmodel = _mil_convert( + prog, + convert_to="mlprogram", + convert_from="milinternal", + specification_version=_SPECIFICATION_VERSION_IOS_18, + compute_units=ct.ComputeUnit.CPU_ONLY, + export_multi_functions=True, + skip_model_load=True, + ) + package_path = tempfile.mkdtemp(suffix=".mlpackage") + mlmodel.save(package_path) + return package_path + + @staticmethod + def check_relu(model: Union[ct.models.MLModel, ct.models.CompiledMLModel]) -> None: + x = np.array([-1.0, 0.0, 1.0], dtype=np.float32) + y_relu = [0, 0, 1] + y = model.predict({"x": x}) + assert all(y["relu_0"] == y_relu) + + @staticmethod + def check_sin(model: Union[ct.models.MLModel, ct.models.CompiledMLModel]) -> None: + x = np.array([-1.0, 0.0, 1.0], dtype=np.float32) + y_sin = list(map(math.sin, x)) + y = model.predict({"x": x}) + np.testing.assert_allclose(y["sin_0"], y_sin, rtol=5e-04, atol=5e-04) + + @staticmethod + def check_cos(model: Union[ct.models.MLModel, ct.models.CompiledMLModel]) -> None: + x = np.array([-1.0, 0.0, 1.0], dtype=np.float32) + y_sin = list(map(math.cos, x)) + y = model.predict({"x": x}) + np.testing.assert_allclose(y["cos_0"], y_sin, rtol=5e-04, atol=5e-04) + + + @pytest.mark.skipif(ct.utils._macos_version() < (15, 0), + reason="Multi-function only supported on macOS 15+") + def test_multi_functions(self): + """ + Test multi-functions program can be exported into multi-functions Core ML proto. + """ + + @mb.function( + input_specs=[mb.TensorSpec((3,))], + opset_version=ct.target.iOS18, + ) + def func(x): + return mb.relu(x=x) + + @mb.function( + input_specs=[mb.TensorSpec((3,))], + opset_version=ct.target.iOS18, + ) + def func_1(x): + return mb.sin(x=x) + + @mb.function( + input_specs=[mb.TensorSpec((3,))], + opset_version=ct.target.iOS18, + ) + def func_2(x): + return mb.cos(x=x) + + prog = mil.Program() + prog.add_function("main", func) + prog.add_function("sin", func_1) + prog.add_function("cos", func_2) + + package_path = self.convert_and_save(prog) + + # Test the proto can be loaded back and validate the spec + mlmodel = ct.models.MLModel(package_path, function_name="main") + self.check_functions_description( + mlmodel, + expect_function_names=["main", "sin", "cos"], + expected_default_function_name="main", + ) + + # Validate MLModel predictions for all three functions + self.check_relu(mlmodel) + self.check_sin( + ct.models.MLModel( + package_path, function_name="sin", compute_units=ct.ComputeUnit.CPU_ONLY + ) + ) + self.check_cos( + ct.models.MLModel( + package_path, function_name="cos", compute_units=ct.ComputeUnit.CPU_ONLY + ) + ) + + # Validate MLModel function_name property + assert mlmodel.function_name == "main" + assert ct.models.MLModel(package_path, function_name="sin").function_name == "sin" + assert ct.models.MLModel(package_path, function_name="cos").function_name == "cos" + + # Invalid function_name + with pytest.raises(ValueError, match="function_name invalid not found in the model"): + mlmodel = ct.models.MLModel(package_path, function_name="invalid") + + # Validate CompiledMLModel predictions for all three functions + compiled_path = mlmodel.get_compiled_model_path() + self.check_relu(ct.models.CompiledMLModel(compiled_path, function_name="main")) + self.check_sin(ct.models.CompiledMLModel(compiled_path, function_name="sin")) + self.check_cos(ct.models.CompiledMLModel(compiled_path, function_name="cos")) + + # clean up + shutil.rmtree(package_path) + + + @pytest.mark.skipif(ct.utils._macos_version() < (15, 0), + reason="Multi-function only supported on macOS 15+") + def test_multi_functions_default_function(self): + """ + Test if no function_name passes to MLModel, default function name will be picked up. + """ + + @mb.function( + input_specs=[mb.TensorSpec((3,))], + opset_version=ct.target.iOS18, + ) + def func(x): + return mb.relu(x=x) + + @mb.function( + input_specs=[mb.TensorSpec((3,))], + opset_version=ct.target.iOS18, + ) + def func_1(x): + return mb.sin(x=x) + + prog = mil.Program() + prog.add_function("main_1", func) + prog.add_function("sin", func_1) + prog.default_function_name = "main_1" + + package_path = self.convert_and_save(prog) + + # With no function_name passed, mlmodel.function_name defaults to defaultFunctionName + mlmodel = ct.models.MLModel(package_path) + self.check_functions_description( + mlmodel, + expect_function_names=["main_1", "sin"], + expected_default_function_name="main_1", + ) + assert mlmodel.function_name == "main_1" + + # Validate the prediction runs on default function + self.check_relu(mlmodel) + + # Validate CompiledMLModel predictions for default function + compiled_path = mlmodel.get_compiled_model_path() + self.check_relu(ct.models.CompiledMLModel(compiled_path)) + + # clean up + shutil.rmtree(package_path) + + + @pytest.mark.skipif(ct.utils._macos_version() < (15, 0), + reason="Multi-function only supported on macOS 15+") + def test_single_function_in_multifunction_format(self): + @mb.function( + input_specs=[mb.TensorSpec((3,))], + opset_version=ct.target.iOS18, + ) + def func(x): + return mb.relu(x=x) + + prog = mil.Program() + prog.add_function("main_1", func) + prog.default_function_name = "main_1" + + package_path = self.convert_and_save(prog) + + # No function_name is passed, default function name is picked up + mlmodel = ct.models.MLModel(package_path) + self.check_functions_description( + mlmodel, + expect_function_names=["main_1"], + expected_default_function_name="main_1", + ) + + # Validate MLModel predictions + self.check_relu(mlmodel) + self.check_relu(ct.models.MLModel(package_path, function_name="main_1")) + + # Validate CompiledMLModel predictions + compiled_path = mlmodel.get_compiled_model_path() + self.check_relu(ct.models.CompiledMLModel(compiled_path)) + self.check_relu(ct.models.CompiledMLModel(compiled_path, function_name="main_1")) + + # clean up + shutil.rmtree(package_path) + + + @pytest.mark.skipif(ct.utils._macos_version() < (15, 0), + reason="Multi-function only supported on macOS 15+") + def test_multi_functions_backward_compatibility(self): + # Test the new MLModel class can load pre-iOS17 single function model + @mb.program(input_specs=[mb.TensorSpec((3,))], opset_version=ct.target.iOS16) + def prog(x): + return mb.relu(x=x) + + mlmodel = ct.convert( + prog, + convert_to="mlprogram", + minimum_deployment_target=ct.target.iOS16, + ) + + # Test the proto can be saved and loaded back + package_path = tempfile.mkdtemp(suffix=".mlpackage") + mlmodel.save(package_path) + + # Validate the MLModel predictions + self.check_relu(ct.models.MLModel(package_path)) + self.check_relu(ct.models.MLModel(package_path, function_name="main")) + + # Validate the MLModel function_name property + assert ct.models.MLModel(package_path).function_name is None + assert ct.models.MLModel(package_path, function_name="main").function_name == "main" + + # Other function_name will error out + with pytest.raises( + ValueError, match='function_name must be "main" for non multifunction model' + ): + mlmodel = ct.models.MLModel(package_path, function_name="invalid") + + # Validate the CompiledMLModel predictions + compiled_path = mlmodel.get_compiled_model_path() + self.check_relu(ct.models.CompiledMLModel(compiled_path)) + self.check_relu(ct.models.CompiledMLModel(compiled_path, function_name="main")) + + # invalid function error at runtime + with pytest.raises(RuntimeError): + compiled_model = ct.models.CompiledMLModel(compiled_path, function_name="invalid") + + # clean up + shutil.rmtree(package_path) + + +@pytest.mark.skipif( + ct.utils._macos_version() < (15, 0), reason="Tests are for deployment target iOS18/macos15" +) +class TestStateModelLoad: + """ + Verify stateful model can be loaded via milproto. + """ + + @staticmethod + def verify_stateful_model(mlmodel, expected_output, input=None): + def verify_numerical(mlmodel, state, expected_output, input=None): + if input is None: + input_dict = {} + else: + input_dict = {"y": input} + output = mlmodel.predict(input_dict, state=state)["output"] + np.testing.assert_allclose(expected_output, output, rtol=5e-04, atol=5e-04) + + # verify the model can be ran + state_1 = mlmodel.make_state() + verify_numerical(mlmodel, state_1, expected_output, input) + verify_numerical(mlmodel, state_1, expected_output, input) + + # create a new state, and make sure the model can run prediction on both old and new state + state_2 = mlmodel.make_state() + verify_numerical(mlmodel, state_2, expected_output, input) + verify_numerical(mlmodel, state_1, expected_output, input) + + def test_export_state_input_feature(self): + """ + Test milproto can export model with state type. + """ + + @mb.program( + input_specs=[ + mb.StateTensorSpec((2, 3), dtype=types.fp16), + ], + opset_version=ct.target.iOS18, + ) + def prog(x): + return mb.read_state(input=x, name="output") + + # verify the model can be converted + mlmodel = ct.convert( + prog, + convert_to="mlprogram", + minimum_deployment_target=ct.target.iOS18, + compute_units=ct.ComputeUnit.CPU_ONLY, + ) + + # verify the state feature + spec = mlmodel.get_spec() + state = spec.description.state + assert len(state) == 1 + assert state[0].name == "x" + assert state[0].type.WhichOneof("Type") == "stateType" + assert state[0].type.stateType.WhichOneof("Type") == "arrayType" + + array_type = state[0].type.stateType.arrayType + assert array_type.shape == [2, 3] + assert array_type.dataType == proto.FeatureTypes_pb2.ArrayFeatureType.FLOAT16 + + # verify the model + expected_output = np.zeros((2, 3)) + self.verify_stateful_model(mlmodel, expected_output) + + def test_export_mixed_state_input_features(self): + """ + Test milproto can export model with states and inputs. + """ + + @mb.program( + input_specs=[ + mb.StateTensorSpec((2, 3), dtype=types.fp16), + mb.TensorSpec((2, 3), dtype=types.fp16), + ], + opset_version=ct.target.iOS18, + ) + def prog(x, y): + x = mb.read_state(input=x) + return mb.add(x=x, y=y, name="output") + + # verify the model can be converted + mlmodel = ct.convert( + prog, + convert_to="mlprogram", + minimum_deployment_target=ct.target.iOS18, + compute_units=ct.ComputeUnit.CPU_ONLY, + ) + + # verify the state feature + spec = mlmodel.get_spec() + state = spec.description.state + assert len(state) == 1 + assert state[0].name == "x" + assert state[0].type.WhichOneof("Type") == "stateType" + assert state[0].type.stateType.WhichOneof("Type") == "arrayType" + + array_type = state[0].type.stateType.arrayType + assert array_type.shape == [2, 3] + assert array_type.dataType == proto.FeatureTypes_pb2.ArrayFeatureType.FLOAT16 + + # verify the input + input = spec.description.input + assert len(input) == 1 + assert input[0].name == "y" + assert input[0].type.WhichOneof("Type") == "multiArrayType" + + array_type = input[0].type.multiArrayType + assert array_type.shape == [2, 3] + assert array_type.dataType == proto.FeatureTypes_pb2.ArrayFeatureType.FLOAT16 + + # verify the model + input = np.random.rand(2, 3) + self.verify_stateful_model(mlmodel, input, input) + + + def test_multi_functions_state_model(self): + """ + Make sure multi-functions Core ML models support state. + """ + + @mb.function( + input_specs=[mb.StateTensorSpec((3,), dtype=types.fp16)], + opset_version=ct.target.iOS18, + ) + def func(x): + return mb.read_state(input=x, name="output") + + @mb.function( + input_specs=[mb.StateTensorSpec((2,), dtype=types.fp16)], + opset_version=ct.target.iOS18, + ) + def func_1(y): + return mb.read_state(input=y, name="output") + + prog = mil.Program() + prog.add_function("main", func) + prog.add_function("func_1", func_1) + + mlmodel = _mil_convert( + prog, + convert_to="mlprogram", + convert_from="milinternal", + specification_version=_SPECIFICATION_VERSION_IOS_18, + compute_units=ct.ComputeUnit.CPU_ONLY, + export_multi_functions=True, + ) + + spec = mlmodel.get_spec() + desc = spec.description + assert len(desc.functions) == 2 + assert desc.functions[0].name == "main" + assert len(desc.functions[0].state) == 1 + assert desc.functions[0].state[0].name == "x" + assert desc.functions[1].name == "func_1" + assert len(desc.functions[1].state) == 1 + assert desc.functions[1].state[0].name == "y" + + # main function is the default function + self.verify_stateful_model(mlmodel, np.zeros((3,))) + + # save the mlmodel on disk, and load "main" and "func_1" seperately + package_path = tempfile.mkdtemp(suffix=".mlpackage") + mlmodel.save(package_path) + + # test "main" function + mlmodel_main = ct.models.MLModel( + package_path, compute_units=ct.ComputeUnit.CPU_ONLY, function_name="main" + ) + self.verify_stateful_model(mlmodel_main, np.zeros((3,))) + + # test "func_1" function + mlmodel_func_1 = ct.models.MLModel( + package_path, compute_units=ct.ComputeUnit.CPU_ONLY, function_name="func_1" + ) + self.verify_stateful_model(mlmodel_func_1, np.zeros((2,))) + + # cleanup mlpackage + shutil.rmtree(package_path) + + def test_export_coreml_update_state(self): + """ + The ``coreml_update_state`` dialect op is decomposed into: + write_state -> read_state + """ + + @mb.program( + input_specs=[ + mb.StateTensorSpec((2, 3), dtype=types.fp16), + mb.TensorSpec((2, 3), dtype=types.fp16), + ], + opset_version=ct.target.iOS18, + ) + def prog(x, y): + return mb.coreml_update_state(state=x, value=y, name="output") + + mlmodel = ct.convert( + prog, + convert_to="mlprogram", + minimum_deployment_target=ct.target.iOS18, + compute_units=ct.ComputeUnit.CPU_ONLY, + ) + + mil = mlmodel.get_spec().mlProgram + for function in mil.functions.values(): + for block in function.block_specializations.values(): + ops = list(block.operations) + assert ops[0].type == "write_state" + assert len(ops[0].outputs) == 0 + assert ops[1].type == "read_state" + + # verify the model + input = np.random.rand(2, 3) + self.verify_stateful_model(mlmodel, input, input) + + + @staticmethod + def test_invalid_state_input(): + """ + Test unsupported input state modes. + """ + # state only supports fp16 + @mb.program( + input_specs=[ + mb.StateTensorSpec((2, 3), dtype=types.fp32), + ], + opset_version=ct.target.iOS18, + ) + def prog(x): + return mb.read_state(input=x) + + with pytest.raises( + ValueError, + match="State only support fp16 dtype. Got input var x with dtype fp32.", + ): + mlmodel = ct.convert( + prog, + convert_to="mlprogram", + minimum_deployment_target=ct.target.iOS18, + ) + + # state doesn't support flexible shape + @mb.program( + input_specs=[ + mb.StateTensorSpec((2, get_new_symbol()), dtype=types.fp32), + ], + opset_version=ct.target.iOS18, + ) + def prog(x): + return mb.read_state(input=x) + + with pytest.raises(ValueError, match="Flexible shape model states are not supported!"): + mlmodel = ct.convert( + prog, + convert_to="mlprogram", + minimum_deployment_target=ct.target.iOS18, + ) + + @staticmethod + def test_coreml_update_state_lowering(): + """ + If the output of coreml_update_state is not a block output and + it is not fed into any other ops, the op should be translated into + a single write_state. + """ + + @mb.program( + input_specs=[ + mb.StateTensorSpec((1,), dtype=types.fp16), + mb.TensorSpec((1,), dtype=types.fp16), + mb.TensorSpec((1,), dtype=types.fp16), + ], + opset_version=ct.target.iOS18, + ) + def prog(state, x, y): + mb.coreml_update_state(state=state, value=x) + mb.coreml_update_state(state=state, value=y) + return x, mb.coreml_update_state(state=state, value=y) + + mlmodel = ct.convert( + prog, + convert_to="mlprogram", + minimum_deployment_target=ct.target.iOS18, + ) + + mil = mlmodel.get_spec().mlProgram + for function in mil.functions.values(): + for block in function.block_specializations.values(): + ops = list(block.operations) + expected_ops = [ + "write_state", + "write_state", + "write_state", + "read_state", + ] + assert [val.type for val in ops] == expected_ops + + @staticmethod + @pytest.mark.skipif(ct.utils._macos_version() < (15, 0), + reason="State only supported on macOS 15+") + def test_prediction_state(): + """ + Test prediction from a stateful model + """ + + def extract_value(y): + return list(y.values())[0][0] + + def test_state_model(mlmodel, multiplier): + # Using first state + state1 = mlmodel.make_state() + for i in range(1, 5): + y = mlmodel.predict({}, state=state1) + assert extract_value(y) == multiplier * i + + # Use a new state + state2 = mlmodel.make_state() + for i in range(1, 5): + y = mlmodel.predict({}, state=state2) + assert extract_value(y) == multiplier * i + + # Go back to using the first state + for i in range(5, 10): + y = mlmodel.predict({}, state=state1) + assert extract_value(y) == multiplier * i + + @mb.program( + input_specs=[ + mb.StateTensorSpec((1,), dtype=types.fp16), + ], + opset_version=ct.target.iOS18, + ) + def increment(x): + # Read + y = mb.read_state(input=x) + # Update + y = mb.add(x=y, y=np.array([1.0]).astype("float16")) + # Write + y = mb.coreml_update_state(state=x, value=y) + # Return + return y + + @mb.program( + input_specs=[ + mb.StateTensorSpec((1,), dtype=types.fp16), + ], + opset_version=ct.target.iOS18, + ) + def increment_by_2(x): + # Read + y = mb.read_state(input=x) + # Update + y = mb.add(x=y, y=np.array([1.0]).astype("float16")) + # Write + y = mb.coreml_update_state(state=x, value=y) + # Update + y = mb.add(x=y, y=np.array([1.0]).astype("float16")) + # Write + mb.coreml_update_state(state=x, value=y) + # Return + return y + + for model, multiplier in [(increment, 1), (increment_by_2, 2)]: + mlmodel = ct.convert( + model, + convert_to="mlprogram", + minimum_deployment_target=ct.target.iOS18, + ) + + # The test is failing on x86_64 machines + # rdar://126957030 ([State][Bug][Intel] Stateful model prediction is wrong on Intel laptop) + if platform.machine() == "arm64": + test_state_model(mlmodel, multiplier) + + # save the model and load it back + package_path = tempfile.mkdtemp(suffix=".mlpackage") + mlmodel.save(package_path) + + # Load with CPU + test_state_model( + ct.models.MLModel(package_path, compute_units=ct.ComputeUnit.CPU_ONLY), multiplier + ) + + # Load with ALL + if platform.machine() == "arm64": + test_state_model(ct.models.MLModel(package_path), multiplier) + + shutil.rmtree(package_path) diff --git a/coremltools/converters/mil/backend/mil/test_model_input_params.py b/coremltools/converters/mil/backend/mil/test_model_input_params.py deleted file mode 100644 index 5847e1726..000000000 --- a/coremltools/converters/mil/backend/mil/test_model_input_params.py +++ /dev/null @@ -1,195 +0,0 @@ -# Copyright (c) 2021, Apple Inc. All rights reserved. -# -# Use of this source code is governed by a BSD-3-clause license that can be -# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause -import numpy as np - -import coremltools as ct -from coremltools.converters.mil.mil.builder import Builder as mb -from coremltools.converters.mil.mil.program import Symbol -from coremltools.models.utils import _macos_version - - -class TestMILFlexibleShapes: - - @mb.program( - input_specs = [ - mb.TensorSpec(shape=[1, 3, Symbol("H"), Symbol("W")]) - ]) - def basic_network(x): - return mb.relu(x=x) - - def test_mil_enumerated_multiarray(self): - enumerated_shapes = tuple([(1, 3, 10, 10), (1, 3, 10, 20), (1, 3, 10, 30)]) - input_shape = [ct.TensorType(name="x", shape=ct.EnumeratedShapes(shapes=enumerated_shapes))] - mlmodel = ct.convert(self.basic_network, source="milinternal", convert_to="mlprogram", inputs=input_shape) - input_spec = mlmodel.get_spec().description.input - assert len(input_spec) == 1, "1 input expected, got {} instead".format(len(input_spec)) - assert input_spec[0].name == "x", "input name in MLModel is {}, 'x' is expected".format(input_spec[0].name) - assert input_spec[0].type.WhichOneof("Type") == "multiArrayType", "Expected multiArrayType, got {}".format(input_spec[0].type.WhichOneof("Type")) - assert input_spec[0].type.multiArrayType.WhichOneof("ShapeFlexibility") == "enumeratedShapes", "Expected enumeratedShapes in ShapeFlexibility" - - spec_default_shape = [s for s in input_spec[0].type.multiArrayType.shape] - spec_enumerated_shapes = set() - for enumerated in input_spec[0].type.multiArrayType.enumeratedShapes.shapes: - spec_enumerated_shapes.add(tuple([s for s in enumerated.shape])) - assert spec_default_shape == [1, 3, 10, 10], "Expected default shape to be [1, 3, 10, 10], got {} instead".format(str(spec_default_shape)) - assert spec_enumerated_shapes == set(enumerated_shapes), "Enumerated shape mismatch" - - def test_mil_enumerated_multiarray_with_default(self): - enumerated_shapes = tuple([(1, 3, 10, 10), (1, 3, 10, 20), (1, 3, 10, 30)]) - input_shape = [ct.TensorType(name="x", shape=ct.EnumeratedShapes(shapes=enumerated_shapes, default=(1, 3, 10, 30)))] - mlmodel = ct.convert(self.basic_network, source="milinternal", convert_to="mlprogram", inputs=input_shape) - input_spec = mlmodel.get_spec().description.input - assert len(input_spec) == 1, "1 input expected, got {} instead".format(len(input_spec)) - assert input_spec[0].name == "x", "input name in MLModel is {}, 'x' is expected".format(input_spec[0].name) - assert input_spec[0].type.WhichOneof("Type") == "multiArrayType", "Expected multiArrayType, got {}".format(input_spec[0].type.WhichOneof("Type")) - assert input_spec[0].type.multiArrayType.WhichOneof("ShapeFlexibility") == "enumeratedShapes", "Expected enumeratedShapes in ShapeFlexibility" - - spec_default_shape = [s for s in input_spec[0].type.multiArrayType.shape] - spec_enumerated_shapes = set() - for enumerated in input_spec[0].type.multiArrayType.enumeratedShapes.shapes: - spec_enumerated_shapes.add(tuple([s for s in enumerated.shape])) - assert spec_default_shape == [1, 3, 10, 30], "Expected default shape to be [1, 3, 10, 10], got {} instead".format(str(spec_default_shape)) - assert spec_enumerated_shapes == set(enumerated_shapes), "Enumerated shape mismatch" - - def test_mil_enumerated_image(self): - enumerated_shapes = tuple([(1, 3, 10, 10), (1, 3, 10, 20), (1, 3, 10, 30)]) - input_shape = [ct.ImageType(name="x", shape=ct.EnumeratedShapes(shapes=enumerated_shapes))] - mlmodel = ct.convert(self.basic_network, source="milinternal", convert_to="mlprogram", inputs=input_shape) - input_spec = mlmodel.get_spec().description.input - assert len(input_spec) == 1, "1 input expected, got {} instead".format(len(input_spec)) - assert input_spec[0].name == "x", "input name in MLModel is {}, 'x' is expected".format(input_spec[0].name) - assert input_spec[0].type.WhichOneof("Type") == "imageType", "Expected imageType, got {}".format(input_spec[0].type.WhichOneof("Type")) - assert input_spec[0].type.imageType.WhichOneof("SizeFlexibility") == "enumeratedSizes", "Expected enumeratedShapes in ShapeFlexibility" - - spec_H = input_spec[0].type.imageType.height - spec_W = input_spec[0].type.imageType.width - assert spec_H == 10 and spec_W == 10, "expected [H, W] == [10, 10], got [{}, {}] instead".format(spec_H, spec_W) - - spec_enumerated_shapes = set() - for enumerated in input_spec[0].type.imageType.enumeratedSizes.sizes: - spec_enumerated_shapes.add(tuple([1, 3, enumerated.height, enumerated.width])) - assert spec_enumerated_shapes == set(enumerated_shapes), "Enumerated shape mismatch" - - def test_mil_enumerated_image_with_default(self): - enumerated_shapes = tuple([(1, 3, 10, 10), (1, 3, 10, 20), (1, 3, 10, 30)]) - input_shape = [ct.ImageType(name="x", shape=ct.EnumeratedShapes(shapes=enumerated_shapes, default=(1, 3, 10, 30)))] - mlmodel = ct.convert(self.basic_network, source="milinternal", convert_to="mlprogram", inputs=input_shape) - input_spec = mlmodel.get_spec().description.input - assert len(input_spec) == 1, "1 input expected, got {} instead".format(len(input_spec)) - assert input_spec[0].name == "x", "input name in MLModel is {}, 'x' is expected".format(input_spec[0].name) - assert input_spec[0].type.WhichOneof("Type") == "imageType", "Expected imageType, got {}".format(input_spec[0].type.WhichOneof("Type")) - assert input_spec[0].type.imageType.WhichOneof("SizeFlexibility") == "enumeratedSizes", "Expected enumeratedShapes in ShapeFlexibility" - - spec_H = input_spec[0].type.imageType.height - spec_W = input_spec[0].type.imageType.width - assert spec_H == 10 and spec_W == 30, "expected [H, W] == [10, 30], got [{}, {}] instead".format(spec_H, spec_W) - - spec_enumerated_shapes = set() - for enumerated in input_spec[0].type.imageType.enumeratedSizes.sizes: - spec_enumerated_shapes.add(tuple([1, 3, enumerated.height, enumerated.width])) - assert spec_enumerated_shapes == set(enumerated_shapes), "Enumerated shape mismatch" - - def test_mil_ranged_multiarray(self): - input_shape = [ct.TensorType(name="x", shape=(1, 3, 10, ct.RangeDim(10, 30)))] - mlmodel = ct.convert(self.basic_network, source="milinternal", convert_to="mlprogram", inputs=input_shape) - input_spec = mlmodel.get_spec().description.input - assert len(input_spec) == 1, "1 input expected, got {} instead".format(len(input_spec)) - assert input_spec[0].name == "x", "input name in MLModel is {}, 'x' is expected".format(input_spec[0].name) - assert input_spec[0].type.WhichOneof("Type") == "multiArrayType", "Expected multiArrayType, got {}".format(input_spec[0].type.WhichOneof("Type")) - assert input_spec[0].type.multiArrayType.WhichOneof("ShapeFlexibility") == "shapeRange", "Expected shapeRange in ShapeFlexibility" - - spec_default_shape = [s for s in input_spec[0].type.multiArrayType.shape] - ranged_shapes = [(1, 1), (3, 3), (10, 10), (10, 30)] - spec_ranged_shapes = [] - for range_dim in input_spec[0].type.multiArrayType.shapeRange.sizeRanges: - spec_ranged_shapes.append(tuple([range_dim.lowerBound, range_dim.upperBound])) - assert spec_default_shape == [1, 3, 10, 10], "Expected default shape to be [1, 3, 10, 10], got {} instead".format(str(spec_default_shape)) - assert spec_ranged_shapes == ranged_shapes, "Enumerated shape mismatch" - - def test_mil_ranged_multiarray_with_default(self): - input_shape = [ct.TensorType(name="x", shape=(1, 3, 10, ct.RangeDim(10, 30, default=20)))] - mlmodel = ct.convert(self.basic_network, source="milinternal", convert_to="mlprogram", inputs=input_shape) - input_spec = mlmodel.get_spec().description.input - assert len(input_spec) == 1, "1 input expected, got {} instead".format(len(input_spec)) - assert input_spec[0].name == "x", "input name in MLModel is {}, 'x' is expected".format(input_spec[0].name) - assert input_spec[0].type.WhichOneof("Type") == "multiArrayType", "Expected multiArrayType, got {}".format(input_spec[0].type.WhichOneof("Type")) - assert input_spec[0].type.multiArrayType.WhichOneof("ShapeFlexibility") == "shapeRange", "Expected shapeRange in ShapeFlexibility" - - spec_default_shape = [s for s in input_spec[0].type.multiArrayType.shape] - ranged_shapes = [(1, 1), (3, 3), (10, 10), (10, 30)] - spec_ranged_shapes = [] - for range_dim in input_spec[0].type.multiArrayType.shapeRange.sizeRanges: - spec_ranged_shapes.append(tuple([range_dim.lowerBound, range_dim.upperBound])) - assert spec_default_shape == [1, 3, 10, 20], "Expected default shape to be [1, 3, 10, 20], got {} instead".format(str(spec_default_shape)) - assert spec_ranged_shapes == ranged_shapes, "Enumerated shape mismatch" - - def test_mil_ranged_image(self): - input_shape = [ct.ImageType(name="x", shape=(1, 3, 10, ct.RangeDim(10, 30)))] - mlmodel = ct.convert(self.basic_network, source="milinternal", convert_to="mlprogram", inputs=input_shape) - input_spec = mlmodel.get_spec().description.input - assert len(input_spec) == 1, "1 input expected, got {} instead".format(len(input_spec)) - assert input_spec[0].name == "x", "input name in MLModel is {}, 'x' is expected".format(input_spec[0].name) - assert input_spec[0].type.WhichOneof("Type") == "imageType", "Expected imageType, got {}".format(input_spec[0].type.WhichOneof("Type")) - assert input_spec[0].type.imageType.WhichOneof("SizeFlexibility") == "imageSizeRange", "Expected imageSizeRange in ShapeFlexibility" - - spec_H = input_spec[0].type.imageType.height - spec_W = input_spec[0].type.imageType.width - assert spec_H == 10 and spec_W == 10, "expected [H, W] == [10, 10], got [{}, {}] instead".format(spec_H, spec_W) - - spec_H_range = [input_spec[0].type.imageType.imageSizeRange.heightRange.lowerBound, input_spec[0].type.imageType.imageSizeRange.heightRange.upperBound] - spec_W_range = [input_spec[0].type.imageType.imageSizeRange.widthRange.lowerBound, input_spec[0].type.imageType.imageSizeRange.widthRange.upperBound] - assert spec_H_range == [10, 10], "Ranged height mismatch" - assert spec_W_range == [10, 30], "Ranged width mismatch" - - def test_mil_ranged_image_with_default(self): - input_shape = [ct.ImageType(name="x", shape=(1, 3, 10, ct.RangeDim(10, 30, default=20)))] - mlmodel = ct.convert(self.basic_network, source="milinternal", convert_to="mlprogram", inputs=input_shape) - input_spec = mlmodel.get_spec().description.input - assert len(input_spec) == 1, "1 input expected, got {} instead".format(len(input_spec)) - assert input_spec[0].name == "x", "input name in MLModel is {}, 'x' is expected".format(input_spec[0].name) - assert input_spec[0].type.WhichOneof("Type") == "imageType", "Expected imageType, got {}".format(input_spec[0].type.WhichOneof("Type")) - assert input_spec[0].type.imageType.WhichOneof("SizeFlexibility") == "imageSizeRange", "Expected imageSizeRange in ShapeFlexibility" - - spec_H = input_spec[0].type.imageType.height - spec_W = input_spec[0].type.imageType.width - assert spec_H == 10 and spec_W == 20, "expected [H, W] == [10, 20], got [{}, {}] instead".format(spec_H, spec_W) - - spec_H_range = [input_spec[0].type.imageType.imageSizeRange.heightRange.lowerBound, input_spec[0].type.imageType.imageSizeRange.heightRange.upperBound] - spec_W_range = [input_spec[0].type.imageType.imageSizeRange.widthRange.lowerBound, input_spec[0].type.imageType.imageSizeRange.widthRange.upperBound] - assert spec_H_range == [10, 10], "Ranged height mismatch" - assert spec_W_range == [10, 30], "Ranged width mismatch" - -class TestMILDefaultValues: - @mb.program( - input_specs = [ - mb.TensorSpec(shape=[1]), - mb.TensorSpec(shape=[1]) - ]) - def basic_network(x, y): - return mb.add(x=x, y=y, name="output") - - def test_mil_default_value_to_proto(self): - program_input_spec = [ct.TensorType(name="x", shape=[1], default_value=np.array([1.0]).astype(np.float32)), ct.TensorType(name="y", shape=[1])] - mlmodel = ct.convert(self.basic_network, convert_to="mlprogram", inputs=program_input_spec) - input_spec = mlmodel.get_spec().description.input - assert len(input_spec) == 2, "2 input expected, got {} instead".format(len(input_spec)) - assert input_spec[0].name == "x", "input name in MLModel is {}, 'x' is expected".format(input_spec[0].name) - assert input_spec[0].type.WhichOneof("Type") == "multiArrayType", "Expected multiArrayType, got {}".format(input_spec[0].type.WhichOneof("Type")) - assert input_spec[0].type.multiArrayType.WhichOneof("defaultOptionalValue") == "floatDefaultValue", "Expected floatDefaultValue, got {} instead".format(input_spec[0].type.multiArrayType.WhichOneof("defaultOptionalValue")) - assert input_spec[0].type.multiArrayType.floatDefaultValue == 1.0 - - def test_mil_default_value_runtime(self): - program_input_spec = [ct.TensorType(name="x", shape=[1], default_value=np.array([1.0]).astype(np.float32)), ct.TensorType(name="y", shape=[1])] - mlmodel = ct.convert(self.basic_network, convert_to="mlprogram", inputs=program_input_spec) - - if _macos_version() < (12, 0): - # Can only get predictions for ml program on macOS 12+ - return - - res = mlmodel.predict({"x": np.array([3.]), "y": np.array([2.])}) - assert res["output"][0] == 5.0 - - res = mlmodel.predict({"y": np.array([2.])}) - assert res["output"][0] == 3.0 diff --git a/coremltools/converters/mil/backend/nn/load.py b/coremltools/converters/mil/backend/nn/load.py index a4c449e73..590e38531 100644 --- a/coremltools/converters/mil/backend/nn/load.py +++ b/coremltools/converters/mil/backend/nn/load.py @@ -16,8 +16,7 @@ ) from coremltools.converters.mil.mil import types from coremltools.converters.mil.mil.types.symbolic import any_symbolic, any_variadic, is_symbolic -from coremltools.models import MLModel -from coremltools.models import neural_network as neural_network +from coremltools.models import model, neural_network from coremltools.models.datatypes import Array from coremltools.models.neural_network import flexible_shape_utils from coremltools.models.neural_network.flexible_shape_utils import ( @@ -31,7 +30,7 @@ def _convert_to_image_input(proto, inputs, skip_model_load=False): - tmp_model = MLModel(proto, skip_model_load=skip_model_load) + tmp_model = model.MLModel(proto, skip_model_load=skip_model_load) for input_type in inputs: if isinstance(input_type, ImageType): if input_type.color_layout in (ColorLayout.GRAYSCALE, ColorLayout.GRAYSCALE_FLOAT16): @@ -58,7 +57,7 @@ def _convert_to_image_input(proto, inputs, skip_model_load=False): def _convert_to_classifier(proto, classifier_config, skip_model_load=False): - tmp_model = MLModel(proto, skip_model_load=skip_model_load) + tmp_model = model.MLModel(proto, skip_model_load=skip_model_load) tmp_model = neural_network.utils.make_nn_classifier( tmp_model, classifier_config.class_labels, diff --git a/coremltools/converters/mil/converter.py b/coremltools/converters/mil/converter.py index f9421b1ad..02f453883 100644 --- a/coremltools/converters/mil/converter.py +++ b/coremltools/converters/mil/converter.py @@ -8,8 +8,8 @@ from typing import Optional, Text, Tuple from coremltools.converters._profile_utils import _profile -from coremltools.converters.mil import Program from coremltools.converters.mil.mil import Builder as mb +from coremltools.converters.mil.mil import Program from coremltools.converters.mil.mil.types.symbolic import k_num_internal_syms, k_used_symbols from coremltools.models import MLModel from coremltools.models.model import _create_mlpackage diff --git a/coremltools/converters/mil/debugging_utils.py b/coremltools/converters/mil/debugging_utils.py index 30282659b..01792f0da 100644 --- a/coremltools/converters/mil/debugging_utils.py +++ b/coremltools/converters/mil/debugging_utils.py @@ -24,8 +24,8 @@ def extract_submodel( """ This utility function lets you extract a submodel from a Core ML model. - For a NeuralNetwork model, the function extracts only in-memory Core ML models. - You should always call this function to a model directly from ``ct.convert``. It is not + For a neural network model, the function extracts only in-memory Core ML models. + You should always call this function for a model directly from :py:class:`~coremltools.converters._converters_entry.convert`. It is not allowed to load the model from disk and then call this API. For an ML program model, both cases (in-memory and from disk) are supported. @@ -43,19 +43,19 @@ def extract_submodel( If not provided, the inputs from the original model are used. function_name: str (Optional) - Name of the function where the subgraph is extracted. Default ``main``. + Name of the function where the subgraph is extracted. Default is ``main``. Examples -------- - NeuralNetwork: + Neural network: >>> from coremltools.converters.mil.debugging_utils import extract_submodel >>> mlmodel = ct.convert(model, convert_to="neuralnetwork") >>> outputs = ["output_0", "output_1"] >>> submodel = extract_submodel(mlmodel, outputs) - ML Program: + ML program: >>> from coremltools.converters.mil.debugging_utils import extract_submodel >>> mlmodel = ct.convert(model, convert_to="mlprogram") diff --git a/coremltools/converters/mil/frontend/_utils.py b/coremltools/converters/mil/frontend/_utils.py index 1a4a11a69..4da82cee3 100644 --- a/coremltools/converters/mil/frontend/_utils.py +++ b/coremltools/converters/mil/frontend/_utils.py @@ -7,7 +7,7 @@ import math as math from typing import List, Optional, Union -import numpy as _np +import numpy as np from coremltools.converters.mil._deployment_compatibility import AvailableTarget as target from coremltools.converters.mil.input_types import InputType @@ -512,7 +512,7 @@ def _concat_dims(dims, none_if_empty=False): return ab -def _lower_scaled_dot_product_attention( +def _decompose_scaled_dot_product_attention( q: Var, k: Var, v: Var, mask: Var, name: str, before_op: Optional[Operation] = None ) -> Var: # scale the query input @@ -526,7 +526,7 @@ def _lower_scaled_dot_product_attention( q, k, v = promote_input_dtypes([q, k, v]) multiplicative_scale_factor = 1 / math.sqrt(embed_size) if types.builtin_to_string(q.dtype) == "fp16": - multiplicative_scale_factor = _np.float16(multiplicative_scale_factor) + multiplicative_scale_factor = np.float16(multiplicative_scale_factor) q = mb.mul(x=q, y=multiplicative_scale_factor, before_op=before_op) # multiply query and key input tensors @@ -545,48 +545,96 @@ def _lower_scaled_dot_product_attention( return res -def _construct_constexpr_affine_op( - quantized_weights: _np.ndarray, - zero_point: Optional[Union[Var, _np.ndarray, _np.generic]], - scale: Union[Var, _np.ndarray, _np.generic], +def _construct_constexpr_dequant_op( + quantized_weights: np.ndarray, + zero_point: Optional[Union[Var, np.ndarray, np.generic]], + scale: Union[Var, np.ndarray, np.generic], axis: Optional[Union[Var, int]] = None, name: Optional[str] = None, before_op: Optional[Operation] = None, ) -> Var: - """Constructs the constexpr op to represent the dequantized weight from PyTorch's data.""" - # The constexpr_affine_dequantize op requires axis. - if axis is None: - # Infer the axis based on scale's shape. - non_single_dim = [dim for dim, dim_size in enumerate(scale.shape) if dim_size > 1] - if len(non_single_dim) > 2: + """ + Constructs the constexpr op to represent the quantized weight. + + Use constexpr_affine_dequantize for pre-iOS18 and constexpr_blockwise_shift_scale for others. + """ + if not is_current_opset_version_compatible_with(target.iOS18): + # The constexpr_affine_dequantize op requires axis. + if axis is None: + # Infer the axis based on scale's shape. + non_single_dim = [dim for dim, dim_size in enumerate(scale.shape) if dim_size > 1] + if len(non_single_dim) > 2: + raise ValueError( + "The constexpr_affine_dequantize op doesn't support scale which " + "have more than one non-single dimensions. Got scale with shape " + f"{scale.shape}" + ) + # Empty non_single_dim means per-tensor quantization, just use a dummy axis. + axis = 0 if len(non_single_dim) == 0 else non_single_dim[0] + if isinstance(axis, int): + axis = np.int32(axis) + + # The constexpr_affine_dequantize op requires zero_point. + if zero_point is None: + zero_point = np.zeros_like(scale).astype(quantized_weights.dtype) + + # The constexpr_affine_dequantize op requires scale and zero_point to have rank 0 or 1. + if isinstance(scale, (np.ndarray, np.generic)): + scale = np.squeeze(scale) + if isinstance(zero_point, (np.ndarray, np.generic)): + zero_point = np.squeeze(zero_point) + + kwargs = { + "quantized_data": quantized_weights, + "zero_point": zero_point, + "scale": scale, + "axis": axis, + } + if name is not None: + kwargs["name"] = name + if before_op is not None: + kwargs["before_op"] = before_op + return mb.constexpr_affine_dequantize(**kwargs) + + # For iOS18 constexpr_blockwise_shift_scale op, the data/scale/offset need to have same rank. + if len(quantized_weights.shape) != len(scale.shape): + if axis is not None: + target_shape = [1] * len(quantized_weights.shape) + target_shape[axis] = quantized_weights.shape[axis] + else: + target_shape = list(scale.shape) + [1] * ( + len(quantized_weights.shape) - len(scale.shape) + ) + if np.prod(scale.shape) != np.prod(target_shape): raise ValueError( - "The constexpr_affine_dequantize op doesn't support scale which " - "have more than one non-single dimensions. Got scale with shape " - f"{scale.shape}" + "Unable to infer scale's shape. Please provide a scale that has the " + "same rank as the weight." ) - # If non_single_dim is empty, it means it's per-tensor quantization, just use a dummy axis. - axis = 0 if len(non_single_dim) == 0 else non_single_dim[0] - if isinstance(axis, int): - axis = _np.int32(axis) - - # The constexpr_affine_dequantize op requires zero_point. - if zero_point is None: - zero_point = _np.zeros_like(scale).astype(quantized_weights.dtype) + scale = scale.reshape(target_shape) - # The constexpr_affine_dequantize op requires scale and zero_point to have rank 0 or 1. - if isinstance(scale, (_np.ndarray, _np.generic)): - scale = _np.squeeze(scale) - if isinstance(zero_point, (_np.ndarray, _np.generic)): - zero_point = _np.squeeze(zero_point) + # Check the value range to determine the true data type (such as int4/uint4). + sub_byte_type = ( + types.uint4 + if types.numpy_type_to_builtin_type(quantized_weights.dtype).is_unsigned() + else types.int4 + ) + sub_byte_range = types.type_mapping._TYPES_TO_RANGE[sub_byte_type] + if ( + np.max(quantized_weights) <= sub_byte_range.high + and np.min(quantized_weights) >= sub_byte_range.low + ): + quantized_weights = quantized_weights.astype(types.nptype_from_builtin(sub_byte_type)) kwargs = { - "quantized_data": quantized_weights, - "zero_point": zero_point, + "data": quantized_weights, "scale": scale, - "axis": axis, } + if zero_point is not None and np.any(zero_point): + # Only pass the offset parameter when not all elements in `zero_point` are zeroes. + zero_point = zero_point.reshape(scale.shape).astype(quantized_weights.dtype) + kwargs["offset"] = zero_point if name is not None: kwargs["name"] = name if before_op is not None: kwargs["before_op"] = before_op - return mb.constexpr_affine_dequantize(**kwargs) + return mb.constexpr_blockwise_shift_scale(**kwargs) diff --git a/coremltools/converters/mil/frontend/milproto/helper.py b/coremltools/converters/mil/frontend/milproto/helper.py index 6d7bed661..e26004e72 100644 --- a/coremltools/converters/mil/frontend/milproto/helper.py +++ b/coremltools/converters/mil/frontend/milproto/helper.py @@ -61,5 +61,10 @@ def proto_to_types(valuetype): valuetype = proto_to_types(dicttype.valueType) return types.dict(keytype, valuetype) + + elif valuetype.WhichOneof("type") == "stateType": + wrapped_type = proto_to_types(valuetype.stateType.wrappedType) + + return types.state(wrapped_type) else: raise NotImplementedError("Types {} not yet implemented".format(valuetype.WhichOneof("type"))) diff --git a/coremltools/converters/mil/frontend/milproto/load.py b/coremltools/converters/mil/frontend/milproto/load.py index b6e39e406..037646360 100644 --- a/coremltools/converters/mil/frontend/milproto/load.py +++ b/coremltools/converters/mil/frontend/milproto/load.py @@ -26,6 +26,7 @@ ) from coremltools.converters.mil.mil.block import curr_block from coremltools.converters.mil.mil.ops.registry import SSAOpRegistry as _SSAOpRegistry +from coremltools.converters.mil.mil.program import StateTensorPlaceholder from .helper import proto_to_types @@ -113,8 +114,20 @@ def _load_file_value(context, filevalue_spec, dtype): blob_reader = BlobReader(filename) context.blob_reader_from_filename[filename] = blob_reader - if dtype == types.uint8: + if dtype == types.uint1: + np_value = blob_reader.read_uint1_data(offset) + elif dtype == types.uint2: + np_value = blob_reader.read_uint2_data(offset) + elif dtype == types.uint3: + np_value = blob_reader.read_uint3_data(offset) + elif dtype == types.uint4: + np_value = blob_reader.read_uint4_data(offset) + elif dtype == types.uint6: + np_value = blob_reader.read_uint6_data(offset) + elif dtype == types.uint8: np_value = blob_reader.read_uint8_data(offset) + elif dtype == types.int4: + np_value = blob_reader.read_int4_data(offset) elif dtype == types.int8: np_value = blob_reader.read_int8_data(offset) elif dtype == types.uint16: @@ -126,6 +139,10 @@ def _load_file_value(context, filevalue_spec, dtype): np_value = np.frombuffer(np_value_uint16.tobytes(), np.float16) elif dtype == types.fp32: np_value = blob_reader.read_float_data(offset) + elif dtype == types.int32: + np_value = blob_reader.read_int32_data(offset) + elif dtype == types.uint32: + np_value = blob_reader.read_uint32_data(offset) else: raise ValueError("Invalid dtype for blob file value type") @@ -133,6 +150,19 @@ def _load_file_value(context, filevalue_spec, dtype): def _restore_np_from_bytes_value(value: bytes, dtype: types, shape: Tuple[int]) -> np.ndarray: + # Import _utils here to avoid circular import. + from coremltools.optimize.coreml import _utils as optimize_utils + + if types.is_sub_byte(dtype) and isinstance(value, bytes): + result = np.frombuffer(value, types.nptype_from_builtin(dtype)) + # For sub-byte data, the np array restored from bytes is packed, so we need to unpack it. + nbits = dtype.get_bitwidth() + element_num = np.prod(shape) + are_packed_values_signed = not dtype.is_unsigned() + return optimize_utils.restore_elements_from_packed_bits( + result, nbits, element_num, are_packed_values_signed + ).reshape(shape) + return np.frombuffer(value, types.nptype_from_builtin(dtype)).reshape(shape) @@ -359,18 +389,32 @@ def _load_operation(context: TranscriptionContext, op_spec: proto.MIL_pb2.Operat vars.append(var) else: raise NotImplementedError("Binding {} not yet implemented".format(binding_type)) - op_cls = _SSAOpRegistry._get_core_op_cls(op_type) - if len(vars) == 1 and not isinstance( - op_cls.input_spec.input_types[param_name], TupleInputType - ): + + if op_type == "write_state": inputs[param_name] = vars[0] else: - inputs[param_name] = vars + op_cls = _SSAOpRegistry._get_core_op_cls(op_type) + if len(vars) == 1 and not isinstance( + op_cls.input_spec.input_types[param_name], TupleInputType + ): + inputs[param_name] = vars[0] + else: + inputs[param_name] = vars blocks = _create_nested_blocks(context, op_spec) _set_inputs_for_control_flow_op(inputs, blocks, op_type) - output_var = getattr(mb, op_type)(**inputs) + # write_state is translated into coreml_update_state + if op_type == "write_state": + new_inputs = { + "state": inputs["input"], + "value": inputs["data"], + } + getattr(mb, "coreml_update_state")(**new_inputs) + return + else: + output_var = getattr(mb, op_type)(**inputs) + if not isinstance(output_var, (tuple, list)): output_var = [output_var] @@ -397,6 +441,33 @@ def _load_operation(context: TranscriptionContext, op_spec: proto.MIL_pb2.Operat def _load_block(context, block_spec): + def _try_to_merge_state_ops(): + """ + We detect the pattern of: + + %1 = coreml_update_state(state=%state, value=%value) + %2 = read_state(input=%state) + + and transform it into: + + %2 = coreml_update_state(state=%state, value=%value) + """ + block = curr_block() + + if len(block.operations) < 2: + return + + op_1, op_2 = block.operations.end.prev.op, block.operations.end.op + if op_1.op_type != "coreml_update_state" or op_2.op_type != "read_state": + return + if op_1.state != op_2.input: + return + + var_1, var_2 = op_1.outputs[0], op_2.outputs[0] + var_1.name = var_2.name + context.register_var_with_name(var_1.name, var_1) + block.remove_ops([op_2]) + if not isinstance(block_spec, proto.MIL_pb2.Block): raise TypeError("Invalid Block spec object") @@ -407,6 +478,7 @@ def _load_block(context, block_spec): output_vars = [] for op_spec in block_spec.operations: _load_operation(context, op_spec) + _try_to_merge_state_ops() for proto_output_name in block_outputs: output_vars.append(context.get_var_from_name(proto_output_name)) @@ -428,11 +500,19 @@ def _load_function(context, func_spec, spec_version): name = named_value_type.name valuetype = proto_to_types(named_value_type.type) - if not types.is_tensor(valuetype): - raise ValueError("Functions inputs can only be tensors") - func_inputs[name] = Placeholder( - sym_shape=valuetype.get_shape(), dtype=valuetype.get_primitive(), name=name - ) + if types.is_tensor(valuetype): + func_inputs[name] = Placeholder( + sym_shape=valuetype.get_shape(), dtype=valuetype.get_primitive(), name=name + ) + elif types.is_state(valuetype): + func_inputs[name] = StateTensorPlaceholder( + sym_shape=valuetype.wrapped_type().get_shape(), + dtype=valuetype.wrapped_type().get_primitive(), + name=name, + ) + else: + raise ValueError(f"Functions input of type {valuetype} not supported.") + context.register_var_with_name(name, func_inputs[name].outputs[0]) opset = func_spec.opset diff --git a/coremltools/converters/mil/frontend/milproto/test_load.py b/coremltools/converters/mil/frontend/milproto/test_load.py index 9e3e10c1b..acd38560b 100644 --- a/coremltools/converters/mil/frontend/milproto/test_load.py +++ b/coremltools/converters/mil/frontend/milproto/test_load.py @@ -3,32 +3,31 @@ # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause +import itertools + import numpy as np import pytest import coremltools as ct -from coremltools import ComputeUnit +from coremltools import _SPECIFICATION_VERSION_IOS_18, ComputeUnit from coremltools._deps import _HAS_TF_2, _HAS_TORCH from coremltools.converters._converters_entry import _get_metadata_from_mlmodel from coremltools.converters.mil import Builder as mb from coremltools.converters.mil.converter import mil_convert -from coremltools.converters.mil.frontend.milproto.load import \ - load as milproto_to_pymil -from coremltools.converters.mil.frontend.tensorflow.test.test_ops import \ - TestTensorArray -from coremltools.converters.mil.frontend.tensorflow.test.testing_utils import \ - run_compare_tf -from coremltools.converters.mil.mil.ops.tests.testing_utils import \ - compare_backend +from coremltools.converters.mil.frontend.milproto.load import load as milproto_to_pymil +from coremltools.converters.mil.frontend.tensorflow.test.test_ops import TestTensorArray +from coremltools.converters.mil.frontend.tensorflow.test.testing_utils import run_compare_tf +from coremltools.converters.mil.mil import Program, types +from coremltools.converters.mil.mil.ops.tests.testing_utils import compare_backend from coremltools.converters.mil.testing_utils import ( get_op_names_in_program, - get_op_types_in_program + get_op_types_in_program, ) if _HAS_TORCH: import torch - from coremltools.converters.mil.frontend.torch.test.test_torch_ops import \ - TestScriptedModels + + from coremltools.converters.mil.frontend.torch.test.test_torch_ops import TestScriptedModels def get_pymil_prog_from_mlmodel(mlmodel): @@ -169,6 +168,62 @@ def prog(x): loaded_pymil_prog = get_pymil_prog_from_mlmodel(mlmodel) assert get_op_types_in_program(loaded_pymil_prog) == get_op_types_in_program(prog) + @pytest.mark.parametrize( + "immediate_value, dtype", + itertools.product( + (True, False), + (types.int4, types.uint4, types.int8, types.uint8), + ), + ) + def test_milproto_load_to_pymil_sub_byte(self, immediate_value: bool, dtype: types): + """Test if value in milproto (especially sub-byte) could be corrected loaded into pymil.""" + dtype_range = types.type_mapping.builtin_to_range(dtype) + data_val = [dtype_range.low, dtype_range.high] + if immediate_value: + # Tensors with less than 10 elements will be stored as immediate values. + data = np.array(data_val).reshape((1, 2, 1)) + else: + data = np.array(data_val * 20).reshape((1, 40, 1)) + + offset_val = dtype_range.high if dtype.is_unsigned() else -1 + offset = np.array([offset_val]).reshape((1, 1, 1)) + + np_dtype = types.nptype_from_builtin(dtype) + + @mb.program(input_specs=[], opset_version=ct.target.iOS18) + def prog(): + return mb.constexpr_blockwise_shift_scale( + data=data.astype(np_dtype), + scale=np.array([4]).reshape((1, 1, 1)).astype(np.float16), + offset=offset.astype(np_dtype), + ) + + mlmodel = ct.convert( + prog, + convert_to="mlprogram", + compute_units=ct.ComputeUnit.CPU_ONLY, + minimum_deployment_target=ct.target.iOS18, + ) + pymil_prog: Program = milproto_to_pymil( + model_spec=mlmodel.get_spec(), + specification_version=ct.target.iOS18, + file_weights_dir=mlmodel.weights_dir, + ) + assert get_op_types_in_program(pymil_prog) == get_op_types_in_program(prog) + + original_ops = mlmodel._mil_program.functions["main"].find_ops( + op_type="constexpr_blockwise_shift_scale" + ) + load_back_ops = pymil_prog.functions["main"].find_ops( + op_type="constexpr_blockwise_shift_scale" + ) + for (original_op, load_back_op) in zip(original_ops, load_back_ops): + assert original_op.data.dtype == load_back_op.data.dtype + assert original_op.offset.dtype == load_back_op.offset.dtype + np.testing.assert_array_equal(original_op.data.val, load_back_op.data.val) + np.testing.assert_array_equal(original_op.offset.val, load_back_op.offset.val) + + @pytest.mark.skipif(ct.utils._macos_version() < (12, 0), reason="mlprogram predict available only on macOS12+") class TestE2ENumericalCorrectness: @@ -252,3 +307,132 @@ def test_list(self): backend=("mlprogram", "fp16") ) roundtrip_and_compare_mlmodel(mlmodel, {"Placeholder": input_values[0]}) + + +class TestStatefulModelLoad: + @staticmethod + def convert_and_load_back(prog): + mlmodel = ct.convert( + prog, + convert_to="mlprogram", + minimum_deployment_target=ct.target.iOS18, + ) + + return milproto_to_pymil( + mlmodel.get_spec(), + specification_version=_SPECIFICATION_VERSION_IOS_18, + file_weights_dir=mlmodel.weights_dir, + ) + + @staticmethod + def check_update_prog(prog, output_name): + # check i/o types + assert len(prog.functions) == 1 + func = prog.functions["main"] + + assert len(func.inputs) == 2 + in_var = func.inputs["state_workaround"] + assert types.is_state(in_var.sym_type) + assert in_var.name == "state_workaround" + assert in_var.shape == (2, 3) + assert in_var.dtype == types.fp16 + + in_var_2 = func.inputs["x"] + assert in_var_2.name == "x" + assert in_var_2.shape == (2, 3) + assert in_var_2.dtype == types.fp16 + + assert len(func.outputs) == 1 + out_var = func.outputs[0] + assert out_var.name == output_name + assert out_var.shape == (2, 3) + assert out_var.dtype == types.fp16 + + # check op + get_op_types_in_program(prog) == ["coreml_update_state"] + + def test_load_read_state(self): + @mb.program( + input_specs=[ + mb.StateTensorSpec((2, 3), dtype=types.fp16), + ], + opset_version=ct.target.iOS18, + ) + def prog(x): + return mb.read_state(input=x, name="out") + + new_prog = self.convert_and_load_back(prog) + + # check i/o types + assert len(new_prog.functions) == 1 + func = new_prog.functions["main"] + + assert len(func.inputs) == 1 + in_var = func.inputs["x"] + assert types.is_state(in_var.sym_type) + assert in_var.name == "x" + assert in_var.shape == (2, 3) + assert in_var.dtype == types.fp16 + + assert len(func.outputs) == 1 + out_var = func.outputs[0] + assert out_var.name == "out" + assert out_var.shape == (2, 3) + assert out_var.dtype == types.fp16 + + # check op + get_op_types_in_program(new_prog) == ["read_state"] + + def test_load_coreml_update_state(self): + @mb.program( + input_specs=[ + mb.StateTensorSpec((2, 3), dtype=types.fp16), + mb.TensorSpec((2, 3), dtype=types.fp16), + ], + opset_version=ct.target.iOS18, + ) + def prog(state, x): + return mb.coreml_update_state(state=state, value=x, name="out") + + new_prog = self.convert_and_load_back(prog) + self.check_update_prog(new_prog, "out") + + def test_load_coreml_update_state_singular(self): + @mb.program( + input_specs=[ + mb.StateTensorSpec((2, 3), dtype=types.fp16), + mb.TensorSpec((2, 3), dtype=types.fp16), + ], + opset_version=ct.target.iOS18, + ) + def prog(state, x): + mb.coreml_update_state(state=state, value=x) + return x + + new_prog = self.convert_and_load_back(prog) + self.check_update_prog(new_prog, "x") + + def test_load_state_complex(self): + @mb.program( + input_specs=[ + mb.StateTensorSpec((2, 3), dtype=types.fp16), + mb.TensorSpec((2, 3), dtype=types.fp16), + ], + opset_version=ct.target.iOS18, + ) + def prog(state, x): + read_state = mb.read_state(input=state) + add = mb.add(x=read_state, y=np.float16([0.1])) + value = mb.coreml_update_state(state=state, value=add) + add = mb.add(x=value, y=x) + mb.coreml_update_state(state=state, value=add) + return add + + new_prog = self.convert_and_load_back(prog) + assert get_op_types_in_program(new_prog) == [ + "read_state", + "add", + "coreml_update_state", + "add", + "coreml_update_state", + ] diff --git a/coremltools/converters/mil/frontend/tensorflow/ops.py b/coremltools/converters/mil/frontend/tensorflow/ops.py index 632b4b43a..056ce33b5 100644 --- a/coremltools/converters/mil/frontend/tensorflow/ops.py +++ b/coremltools/converters/mil/frontend/tensorflow/ops.py @@ -869,6 +869,7 @@ def Neg(context, node): def NotEqual(context, node): x = context[node.inputs[0]] y = context[node.inputs[1]] + x, y = promote_input_dtypes([x, y]) x = mb.not_equal(x=x, y=y, name=node.name) context.add(node.name, x) diff --git a/coremltools/converters/mil/frontend/tensorflow/test/test_ops.py b/coremltools/converters/mil/frontend/tensorflow/test/test_ops.py index 8a9eff1cc..c879df64b 100644 --- a/coremltools/converters/mil/frontend/tensorflow/test/test_ops.py +++ b/coremltools/converters/mil/frontend/tensorflow/test/test_ops.py @@ -5616,12 +5616,6 @@ def test_top_k(self, compute_unit, backend, rank, k, sort): pytest.skip("iOS16 version topk needed for sort = False") if not sort and _macos_version() < (13, 0): pytest.skip("New functionality in macOS13/iOS16") - if rank == 5 and k is None and sort and ( - backend[0] == "neuralnetwork" or ( - platform.machine() == "x86_64" and _macos_version() < (15, 0) - ) - ): - pytest.xfail("rdar://120891130: TopK failing randomly") # TensorFlow only supports last dimension (axis = -1). shape = np.random.randint(low=3, high=4, size=rank) diff --git a/coremltools/converters/mil/frontend/tensorflow2/test/test_tf2_conversion_api.py b/coremltools/converters/mil/frontend/tensorflow2/test/test_tf2_conversion_api.py index a867bbe31..966457597 100644 --- a/coremltools/converters/mil/frontend/tensorflow2/test/test_tf2_conversion_api.py +++ b/coremltools/converters/mil/frontend/tensorflow2/test/test_tf2_conversion_api.py @@ -209,9 +209,6 @@ def teardown_class(self): backends, ) def test_convert_tf_keras_h5_file(backend): - if platform.machine() == "arm64": - pytest.xfail("rdar://101162740 ([CI] [TF] The tf_keras_h5_file API testing is failing on M1 with new OS)") - for file_extension in ("h5", "hdf5"): x = tf.keras.Input(shape=(32,), name="input") y = tf.keras.layers.Dense(16, activation="softmax")(x) @@ -225,7 +222,11 @@ def test_convert_tf_keras_h5_file(backend): test_input = np.random.rand(2, 32) expected_val = keras_model(test_input) results = mlmodel.predict({"input": test_input}) - np.testing.assert_allclose(results["Identity"], expected_val, rtol=1e-2, atol=1e-2) + + # We should check the numerical on Rosetta after the radar is fixed: + # rdar://126185417 ([CI][TF] Two TF2 API testing is failing on Rosetta with numerical issues) + if platform.machine() == "arm64": + np.testing.assert_allclose(results["Identity"], expected_val, rtol=1e-2, atol=1e-2) @staticmethod @pytest.mark.parametrize( @@ -242,7 +243,11 @@ def test_convert_tf_keras_model(backend): test_input = np.random.rand(2, 32) expected_val = keras_model(test_input) results = mlmodel.predict({"input": test_input}) - np.testing.assert_allclose(results["Identity"], expected_val, rtol=0.005) + + # We should check the numerical on Rosetta after the radar is fixed: + # rdar://126185417 ([CI][TF] Two TF2 API testing is failing on Rosetta with numerical issues) + if platform.machine() == "arm64": + np.testing.assert_allclose(results["Identity"], expected_val, rtol=0.005) @staticmethod @pytest.mark.parametrize( diff --git a/coremltools/converters/mil/frontend/torch/converter.py b/coremltools/converters/mil/frontend/torch/converter.py index a58401131..5acdca2de 100644 --- a/coremltools/converters/mil/frontend/torch/converter.py +++ b/coremltools/converters/mil/frontend/torch/converter.py @@ -3,9 +3,12 @@ # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause +import math from collections import OrderedDict -from typing import List, Optional, Union +from enum import Enum +from typing import Dict, List, Optional, Union +import attrs import numpy as np import torch as torch from torch.jit._script import RecursiveScriptModule @@ -14,15 +17,17 @@ from coremltools._deps import _HAS_TORCH_EXPORT_API from coremltools.converters.mil import mil from coremltools.converters.mil._deployment_compatibility import AvailableTarget as _target -from coremltools.converters.mil.input_types import ImageType, InputType, TensorType +from coremltools.converters.mil.frontend import _utils as frontend_utils +from coremltools.converters.mil.input_types import ImageType, InputType, StateType, TensorType from coremltools.converters.mil.mil import Builder as mb from coremltools.converters.mil.mil import Function, Placeholder, Program, types from coremltools.converters.mil.mil.block import is_current_opset_version_compatible_with from coremltools.converters.mil.mil.scope import ScopeInfo, ScopeSource -from coremltools.converters.mil.mil.types import is_float +from coremltools.converters.mil.mil.types import builtin_to_string, is_float from coremltools.converters.mil.mil.var import Var +from coremltools.optimize.coreml import _utils as optimize_utils +from coremltools.optimize.coreml._quantization_passes import prune_weights -from .._utils import get_output_names from .internal_graph import InternalTorchIRGraph, InternalTorchIRNode from .ops import convert_nodes from .quantization_ops import _dequantized_weight @@ -42,6 +47,68 @@ from torch.export import ExportedProgram +# The compression info is stored in state_dict with the prefix, e.g. "dense2._COREML_n_bits". +_COMPRESSION_INFO_PREFIX = "_COREML_" + + +# TODO: Share the enum between cto.coreml and cto.torch (rdar://124409664). +class CompressionType(Enum): + PRUNING = 1 + PALETTIZATION = 2 + QUANTIZATION = 3 + + +@attrs.define(kw_only=True) +class CompressionInfo: + """ + This class stores the compression info carried by the traced torch model. + """ + + # Quantization related fields. + quantization_n_bits: Optional[int] = attrs.field( + default=None, + validator=attrs.validators.optional([attrs.validators.instance_of(int)]), + converter=attrs.converters.optional(int), + ) + quantization_scale: Optional[torch.Tensor] = attrs.field( + default=None, + validator=attrs.validators.optional([attrs.validators.instance_of(torch.Tensor)]), + ) + zero_point: Optional[torch.Tensor] = attrs.field( + default=None, + validator=attrs.validators.optional([attrs.validators.instance_of(torch.Tensor)]), + ) + + # Palettization related fields. + lut: Optional[torch.Tensor] = attrs.field( + default=None, + validator=attrs.validators.optional([attrs.validators.instance_of(torch.Tensor)]), + ) + palettization_scale: Optional[torch.Tensor] = attrs.field( + default=None, + validator=attrs.validators.optional([attrs.validators.instance_of(torch.Tensor)]), + ) + + # Compression type indication fields. + compression_type: Optional[List[int]] = attrs.field( + default=None, + converter=attrs.converters.optional(lambda tensor: tensor.tolist()), + ) + + @quantization_n_bits.validator + def check_n_bits(self, attribute, n_bits): + if n_bits is not None and not 1 <= n_bits <= 8: + raise ValueError(f"Only support quantization_n_bits between 1 and 8, but got {n_bits}") + + @compression_type.validator + def check_compression_type(self, attribute, compression_type): + if compression_type is not None: + if not all(isinstance(type_val, int) for type_val in compression_type): + raise ValueError( + f"Only support int compression_type, but got {type(compression_type)}" + ) + + def _convert_to_torch_inputtype(inputs: List[TensorType]) -> List[TensorType]: input_type = [] for _input in inputs: @@ -215,6 +282,8 @@ def __init__( self._torch_graph = None if frontend == TorchFrontend.TORCHSCRIPT: self._quant_context = QuantizationContext(self) + # Dict to map a var's name into its corresponding source state var. + self.name_to_source_state = dict() @property def torch_graph(self): @@ -232,13 +301,139 @@ def torch_graph(self, graph: InternalTorchIRGraph): def prepare_for_conversion(self, node: InternalTorchIRNode) -> None: """ - Perform any preparation necessary before node-specific frontend conversion - is invoked. + Perform any preparation necessary before node-specific frontend conversion is invoked. + + This utility check if the input is a function state input, and + convert it into a tensor type. + + For instance, given the following torchscript graph: + + %x(state, fp16), %y(tensor, fp32) -> { + %1 = add(%x, %y) + } + + The graph is translated into: + + %x(state, fp16), %y(tensor, fp32) -> { + %read_x = read_state(%x) + %read_x_cast = cast(%read_x, "fp32") + %1 = add(%read_x_cast, %y) + } + + ``%read_x_cast`` is cached in ``name_to_source_state``, to make sure one + state feeds into only one ``read_state`` op. """ + for val in node.inputs: + if val is None: + continue + if val not in self: + continue + in_node = self[val] + if in_node is None or not isinstance(in_node, Var): + continue + if types.is_state(in_node.sym_type): + self.name_to_source_state[val] = self[val] + assert ( + in_node.op is None + ), f"A state type var must come from a placeholder. Got parent op {in_node.op.op_type} instead." + read_state = mb.read_state(input=in_node) + read_state_fp32 = mb.cast(x=read_state, dtype="fp32") + self.add(read_state_fp32, torch_name=val, override=True) return - def process_inplace_op(self, node: InternalTorchIRNode): - return + def process_inplace_op(self, node: InternalTorchIRNode) -> None: + """ + This utility: + + 1. adds ``mb.coreml_update_state`` after each torch inplace ops. + 2. adjusts the dtype across state / tensor. + + In torch, inplaces ops have the following properties: + + 1. op type has the suffix of ``_``. For instance, ``add_``, ``mul_``, etc. + 2. The op does an inplace update for the first input tensor. + + For instance, the following syntax of a TorchScript: + + %3 = add_(%1, %2) + + denotes an inplace ``add`` operation on the ``%1`` tensor. The memory buffer + of ``%1`` is updated and returned as a reference ``%3``. + + Here are the steps what this utility does, lets use the above + simple torch script as an example, after adding the ``add_`` in the context, + we currently have a MIL graph as ``%3 = add(x=%1, y=%2)``: + + 1. Validate the first input (``%1``) comes from a state source by checking if the tensor's name ``1`` is in ``name_to_source_state``. If not, this utility does nothing. + 2. Say ``name_to_source_state["1"] = %state``. ``%state, %3`` can potentially has different dtype. + For instance, the user could specify ``%state`` in fp16, while + the MIL program in the front end conversion stage is + still in fp32. Hence we cast ``%3`` into ``%state``'s dtype: + + (%state: fp16) -> { + ... + %3_ = add(x=%1, y=%2) + %3_cast = cast(x=%3_, dtype="fp16") + } + 3. Insert a ``coreml_update_state`` and cast the output back to ``%3``'s original dtype: + + (%state: fp16) -> { + ... + %3_ = add(x=%1, y=%2) + %3_cast = cast(x=%3_, dtype="fp16") + %3_update = coreml_update_state(state=%state, value=%3_cast) + %3 = cast(x=%3_update, dtype="fp32") + } + 4. Set ``name_to_source_state["3"] = %state``, so the state chain can be used in the downstream. + + The below Torch Script model, + + (%state: fp16) -> { + ... + %3 = add_(%1, %2) + %out = sub_(%3, %4) + } + + will result in: + + (%state: fp16) -> { + %1_ = read_state(%state) + %1 = cast(x=%1_, dtype="fp32") + %3_ = add(x=%1, y=%2) + %3_cast = cast(x=%3_, dtype="fp16") + %3_update = coreml_update_state(state=%state, value=%3_cast) + %3 = cast(x=%3_update, dtype="fp32") + %out_ = sub(x=%3, y=%4) + %out_cast = cast(x=%out_, dtype="fp16") + %out_update = coreml_update_state(state=%state, value=%out_cast) + %out = cast(x=%out_update, dtype="fp32") + } + + Please note that, the intermediate ``cast`` ops would be removed + by the ``add_fp16_cast`` + ``cast_optimization`` graph passes: + + (%state: fp16) -> { + %1 = read_state(%state) + %3_ = add(x=%1, y=%2) + %3 = coreml_update_state(state=%state, value=%3_) + %out_ = sub(x=%3, y=%4) + %out = coreml_update_state(state=%state, value=%out_) + } + + """ + if len(node.inputs) == 0: + return + + if node.inputs[0] not in self.name_to_source_state: + return + + source_state = self.name_to_source_state[node.inputs[0]] + self.name_to_source_state[node.name] = source_state + value_node = self[node.name] + cast_value = mb.cast(x=value_node, dtype=builtin_to_string(source_state.dtype)) + update = mb.coreml_update_state(state=source_state, value=cast_value) + cast_update = mb.cast(x=update, dtype=builtin_to_string(value_node.dtype), name=node.name) + self.add(cast_update, torch_name=node.name, override=True) def add(self, ssa_var: Var, torch_name: Optional[str] = None, override=False) -> None: """ @@ -324,6 +519,7 @@ def __init__( cut_at_symbols: Optional[List[str]] = None, opset_version: Optional[int] = None, use_default_fp16_io: bool = False, + states: Optional[List[StateType]] = None, ) -> None: """ Arguments: @@ -343,20 +539,26 @@ def __init__( """ self.use_default_fp16_io = use_default_fp16_io - if inputs is not None: - inputs = _convert_to_torch_inputtype(inputs) - self.inputs = inputs - for idx, inp in enumerate(self.inputs): - if isinstance(inp, ImageType) and self.inputs[idx].channel_first is None: - self.inputs[idx].channel_first = True + # process inputs + if inputs is None: + inputs = [] + self.inputs = _convert_to_torch_inputtype(inputs) + for idx, inp in enumerate(self.inputs): + if isinstance(inp, ImageType) and self.inputs[idx].channel_first is None: + self.inputs[idx].channel_first = True - if self.use_default_fp16_io: - # If the input type is not specified by the user and use_default_fp16_io - # is True. Make the default input type to fp16 - self._adjust_default_input_to_fp16() + # process states + if states is None: + states = [] + self.states = states + + if self.use_default_fp16_io: + # If the input type is not specified by the user and use_default_fp16_io + # is True. Make the default input type to fp16 + self._adjust_default_input_to_fp16() self.outputs = outputs - self.output_names = get_output_names(self.outputs) + self.output_names = frontend_utils.get_output_names(self.outputs) self.opset_version = _target(opset_version) if opset_version is not None else None self._prog = mil.Program() @@ -388,18 +590,64 @@ def __init__( ) self.context.torch_graph = self.graph - self.inputs = list(self.graph.inputs.values()) + self._validate_states() + + # Store the mapping from parameter name (such as "dense1.weight") to the compression info. + self.param_to_compression_info: Dict[str, CompressionInfo] = dict() + if self.opset_version is not None and self.opset_version >= _target.iOS16: + # Notice that even the compression info in registered buffer is kept in self.graph, + # we still want to explicitly construct it here, to make it useful for both TorchScript + # and ExportedProgram. + state_dict = loaded_model.state_dict + self.param_to_compression_info = self._construct_compression_info( + state_dict() if callable(state_dict) else state_dict + ) - def _adjust_default_input_to_fp16(self): + def _validate_states(self) -> None: + """ + Validate that the user provided states is consistent with the + registered buffer in the torchscript model. + """ + if len(self.states) > 0: + for state in self.states: + if state.name is None or state.name not in self.graph.buffers: + raise ValueError( + f"StateType named {state.name} not provided or " + "not found in the source torch model. " + "Please make sure the name in " + "'ct.StateType(name=..., wrapped_type=ct.TensorType(...))' " + f"match the 'named_buffers()' in the source torch model: {list(self.graph.buffers.keys())}" + ) + + state_shape = state.shape.shape + buffer_shape = tuple(self.graph.buffers[state.name].size()) + if state_shape != buffer_shape: + raise ValueError( + f"StateType shape {state_shape} must matched the torch buffer shape {buffer_shape}." + ) + + if self.opset_version is None or self.opset_version < _target.iOS18: + raise ValueError( + "State model is supported only >= iOS18. " + "Please update the minimum_deployment_target to at least coremltools.target.iOS18" + ) + self.inputs.extend(self.states) + + def _adjust_default_input_to_fp16(self) -> None: """ An utility function that sets the default input dtype to fp16 """ - assert isinstance(self.inputs, list), "inputs must be type of list" - # Adjust inputs dtype to fp16 - for val in self.inputs: - if isinstance(val, TensorType) and val.dtype is None: - val.dtype = types.fp16 + + def _adjust_default_input_to_fp16_helper(inputs: InputType): + assert isinstance(inputs, list), "inputs must be type of list" + # Adjust inputs dtype to fp16 + for val in inputs: + if isinstance(val, (StateType, TensorType)) and val.dtype is None: + val.dtype = types.fp16 + + _adjust_default_input_to_fp16_helper(self.inputs) + _adjust_default_input_to_fp16_helper(self.states) def _adjust_default_output_to_fp16(self, graph_outputs): """ @@ -444,13 +692,12 @@ def _check_ops(graph): return implemented_ops, missing_ops @staticmethod - def _create_placeholder( - _input: TensorType, - ) -> Placeholder: + def _create_placeholder(_input: InputType) -> Placeholder: """ Converts an InputType into a Placeholder. - _input: TensorType + 1. ``StateType`` into ``mb.state_tensor_placeholder``. + 2. ``TensorType`` and ``ImageType`` into ``mb.placeholder``. """ shape = _input.shape.symbolic_shape dtype = _input.dtype @@ -459,10 +706,364 @@ def _create_placeholder( dtype = types.int32 elif dtype == types.fp64: dtype = types.fp32 + + if isinstance(_input, StateType): + return mb.state_tensor_placeholder(shape, dtype=dtype) + return mb.placeholder(shape, dtype=dtype) + @staticmethod + def _construct_compression_info( + state_dict: Dict[str, torch.Tensor], + ) -> Dict[str, CompressionInfo]: + """ + Construct compression info from the traced model's state_dict. + + The state_dict of the traced model is something like + { + 'dense1.weight': xxx, 'dense1.bias': xxx, + 'dense1._COREML_/weight/quantization_n_bits': tensor(4), + 'dense1._COREML_/weight/quantization_scale': xxx, + 'dense1._COREML_/weight/zero_point': xxx, + 'dense1._COREML_/weight/compression_type': tensor([3]), + 'dense2.weight': xxx, + ... + } + + We extract the compression info and store it as a dict + { + 'dense1.weight': CompressionInfo(quantization_n_bits=4, quantization_scale=xxx, + zero_point=xxx, compression_type=[QUANTIZATION]), + 'dense2.weight': ... + } + """ + compression_info = dict() + for torch_key_name in state_dict.keys(): + if torch_key_name == f"{_COMPRESSION_INFO_PREFIX}/metadata_version": + # TODO: rdar://124707382 ([Compression] Support versioning in CompressionInfo) + continue + + if _COMPRESSION_INFO_PREFIX in torch_key_name: + module_name = None + buffer_name = torch_key_name + if not torch_key_name.startswith(_COMPRESSION_INFO_PREFIX): + module_name, buffer_name = torch_key_name.rsplit(".", 1) + _, param_name, compression_key = buffer_name.rsplit("/", 2) + if module_name: + param_name = f"{module_name}.{param_name}" + + if param_name not in compression_info: + compression_info[param_name] = CompressionInfo() + setattr( + compression_info[param_name], + compression_key, + state_dict[torch_key_name], + ) + + return compression_info + + def _has_compression_info(self, param_name: str) -> bool: + """Check if the parameter carries compression info.""" + return param_name in self.param_to_compression_info + + def _construct_quantization_op( + self, + weight: np.ndarray, + compression_info: CompressionInfo, + name: str, + compressed_var: Optional[Var] = None, + ) -> Var: + """ + The weight is constructed by `weight = scale * (quantized_data - zero_point)`. + We need to restore the quantized_data to construct the quantization op. + + If compressed_var is not None, it's the var constructed by a previous compression function, + which means this is a joint compression. For example, if the compression_info.compression_type + is [CompressionType.PRUNING, CompressionType.QUANTIZATION], the compressed_var is the var + produced by the pruning. + """ + if compression_info.quantization_n_bits is None: + raise ValueError("quantization_n_bits must be specified in quantization.") + if compression_info.quantization_scale is None: + raise ValueError("quantization_scale must be specified in quantization.") + + scale = compression_info.quantization_scale.detach().numpy() + zero_point: Optional[np.ndarray] = None + if compression_info.zero_point is not None: + zero_point = compression_info.zero_point.detach().numpy() + # For conv/conv_transpose, the weight has rank=4, so we auto-expand scale and zero-point if + # it only has two elements. + if len(weight.shape) == 4 and len(scale.shape) == 2: + scale = np.expand_dims(np.expand_dims(scale, axis=-1), axis=-1) + if zero_point is not None: + zero_point = np.expand_dims(np.expand_dims(zero_point, axis=-1), axis=-1) + + if len(weight.shape) != len(scale.shape): + raise ValueError( + f"In {name}, the `weight` should have same rank as `scale`, but got {weight.shape} vs {scale.shape}" + ) + if zero_point is not None: + if len(weight.shape) != len(zero_point.shape): + raise ValueError( + f"In {name}, the `weight` should have same rank as `zero_point`, but got {weight.shape} vs {zero_point.shape}" + ) + + # The scale has shape [.., block_num, ..], which means each scale is for one block. As + # weight has shape [.., block_num*block_size, ..], we need to interleave repeat it. + scale_repeated = scale + zero_point_repeated = zero_point + for axis, weight_dim_size in enumerate(weight.shape): + scale_dim_size = scale.shape[axis] + if weight_dim_size != scale_dim_size and scale_dim_size != 1: + # Only repeat axis where dim size is not 1, because 1 will be auto-broadcast by np. + block_size = weight_dim_size // scale.shape[axis] + scale_repeated = np.repeat(scale_repeated, block_size, axis=axis) + if zero_point_repeated is not None: + zero_point_repeated = np.repeat(zero_point_repeated, block_size, axis=axis) + + quantized_data = np.round(weight / scale_repeated) + if zero_point_repeated is not None: + quantized_data += zero_point_repeated + + # Adjust dtype based on nbits. + dtype_str_prefix = "int" + if quantized_data.min() >= 0 and (zero_point is None or zero_point.min() >= 0): + dtype_str_prefix = "uint" + dtype_str = dtype_str_prefix + str(compression_info.quantization_n_bits) + builtin_dtype = types.string_to_builtin(dtype_str) + np_dtype = types.nptype_from_builtin(builtin_dtype) + + builtin_range = types.type_mapping.builtin_to_range(builtin_dtype) + quantized_data = np.clip(quantized_data, builtin_range.low, builtin_range.high).astype( + np_dtype + ) + if zero_point is not None: + zero_point = zero_point.astype(np_dtype) + + if compressed_var is None: + return frontend_utils._construct_constexpr_dequant_op( + quantized_data, zero_point, scale, name=name + ) + else: + # Specially handles joint compression, such as using sparse op if joint with pruning. + if compressed_var.op.op_type == "constexpr_sparse_to_dense": + mask, nonzero_data = mb.constexpr_sparse_blockwise_shift_scale( + data_mask=compressed_var.op.mask, + nonzero_data=quantized_data[compressed_var.op.mask.val != 0].flatten(), + scale=scale, + offset=zero_point, + before_op=compressed_var.op, + name=compressed_var.op.name + "_quantized", + ) + return mb.constexpr_sparse_to_dense(nonzero_data=nonzero_data, mask=mask, name=name) + else: + raise ValueError( + "Unsupported joint compression combination. The quantization can only be joint " + f"with pruning, but got {compressed_var.op.op_type}. Please check the value of " + "'compression_type' in your registered buffers." + ) + + @staticmethod + def _construct_palettization_op( + weight: np.ndarray, + compression_info: CompressionInfo, + name: str, + compressed_var: Optional[Var] = None, + ) -> Var: + """ + The weight is constructed by 2**nbits unique values in each group. + + When `palettization_scale` is provided, it means the weight has scales before got palettized. + More specifically, the diagram is: + + lut(fp16) \ + -> constexpr_lut_to_dense -> dense(fp16) -> constexpr_blockwise_shift_scale -> dense(fp16) + indices / + + If compressed_var is not None, it's the var constructed by a previous compression function, + which means this is a joint compression. For example, if the compression_info.compression_type + is [CompressionType.PRUNING, CompressionType.PALETTIZATION], the compressed_var is the var + produced by the pruning. + """ + if compression_info.lut is None: + raise ValueError("Missing lut in compression info. Please register a buffer for lut.") + + lut = compression_info.lut.detach().numpy() + if len(lut.shape) == len(weight.shape) + 2: + if lut.shape[-1] > 1: + raise NotImplementedError( + "Doesn't support Vector Palettization (last dim in lut > 1). " + "Implementation is tracked in rdar://124474258" + ) + elif len(lut.shape) == len(weight.shape) + 1: + # The last dim to indicate vector size is by default 1 for scalar palettization. + lut = np.expand_dims(lut, axis=-1) + else: + raise ValueError( + "The rank of lut is invalid. It should match the weight dimension. " + f"Got {len(lut.shape)} vs {len(weight.shape)}" + ) + + assert len(lut.shape) == len(weight.shape) + 2 + num_palettes = lut.shape[-2] + nbits = int(math.ceil(math.log2(num_palettes))) + if 2**nbits != num_palettes: + # Padding lut to make it has 2**nbits dim size on -2 axis. + padding_shape = list(lut.shape) + padding_shape[-2] = 2**nbits - num_palettes + lut = np.concatenate([lut, np.zeros(padding_shape, dtype=lut.dtype)], axis=-2) + num_palettes = lut.shape[-2] + + if compression_info.palettization_scale is not None: + # The weight has scales, which means the palettization is on the pre-scale data. + scale = compression_info.palettization_scale.detach().numpy() + # For conv/conv_transpose, the weight has rank=4, so we auto-expand scale and zero-point if + # it only has two elements. + if len(weight.shape) == 4 and len(scale.shape) == 2: + scale = np.expand_dims(np.expand_dims(scale, axis=-1), axis=-1) + if len(scale.shape) != len(weight.shape): + raise ValueError( + f"In {name}, the scale should have the same rank as weight, but got " + f"{scale.shape} vs {weight.shape}." + ) + weight = weight / scale + + indices = optimize_utils.find_indices_for_lut(weight, lut) + + if compressed_var is None: + if is_current_opset_version_compatible_with(_target.iOS18): + result = mb.constexpr_lut_to_dense(indices=indices, lut=lut, name=name) + else: + if np.prod(lut.shape[:-2]) > 1: + raise ValueError( + "More than one look-up-table (lut) per tensor is only supported in iOS18+. " + "Please set the minimum_deployment_target to iOS18 or later." + ) + # Convert iOS18 lut params to pre-iOS18 compatible format. + lut = lut.reshape([num_palettes]) + result = mb.constexpr_lut_to_dense( + indices=optimize_utils.pack_elements_into_bits(indices, nbits), + lut=lut, + shape=np.uint32(indices.shape), + name=name, + ) + else: + # Specially handles joint compression, such as using sparse op if joint with pruning. + if compressed_var.op.op_type == "constexpr_sparse_to_dense": + mask, nonzero_data = mb.constexpr_lut_to_sparse( + indices_mask=compressed_var.op.mask, + indices_nonzero_data=indices[compressed_var.op.mask.val != 0].flatten(), + lut=lut, + before_op=compressed_var.op, + name=compressed_var.op.name + "_palettized", + ) + result = mb.constexpr_sparse_to_dense( + nonzero_data=nonzero_data, mask=mask, name=name + ) + else: + raise ValueError( + "Unsupported joint compression combination. The palettization can only be joint " + f"with pruning, but got {compressed_var.op.op_type}. Please check the value of " + "'compression_type' in your registered buffers." + ) + + if compression_info.palettization_scale is not None: + if not is_current_opset_version_compatible_with(_target.iOS18): + raise ValueError( + "The palettization with per-channel-scale is only supported in iOS18+. Please " + "set the minimum_deployment_target to iOS18 or later." + ) + result = mb.constexpr_blockwise_shift_scale( + data=result, scale=scale, offset=None, name=name + ) + return result + + @staticmethod + def _construct_sparsification_op( + weight: np.ndarray, + compression_info: CompressionInfo, + name: str, + compressed_var: Optional[Var] = None, + ) -> Var: + sparse_params = prune_weights.compress_by_threshold( + weight, threshold=np.finfo(np.float16).eps, minimum_sparsity_percentile=0 + ) + if sparse_params is None: + raise ValueError( + f"Unable to construct sparsified op. Please check if the weight {name} " + "is sparse." + ) + if is_current_opset_version_compatible_with(_target.iOS18): + sparse_params_ios18 = optimize_utils.ios16_sparse_params_to_ios18(sparse_params) + return mb.constexpr_sparse_to_dense( + nonzero_data=sparse_params_ios18.nonzero_data, + mask=sparse_params_ios18.mask, + name=name, + ) + else: + return mb.constexpr_sparse_to_dense( + nonzero_data=sparse_params.nonzero_data, + mask=sparse_params.mask, + shape=np.uint32(sparse_params.shape), + name=name, + ) + + def _construct_compression_op(self, val: np.ndarray, param_name: str) -> Var: + """Construct the compression op based on the compression info.""" + compression_info: CompressionInfo = self.param_to_compression_info[param_name] + + shared_msg = ( + "There are coreml compression related buffers registered in the torch " + f"model (with {_COMPRESSION_INFO_PREFIX} in the buffer's name) for {param_name}" + ) + if not compression_info.compression_type: + raise ValueError( + shared_msg + ", but the 'compression_type' is not set. Please set it to indicate " + "the type of compression used on the weight." + ) + if len(compression_info.compression_type) > 3: + raise ValueError( + shared_msg + ", but the 'compression_type' has too many values. Support at most 3 " + "values." + ) + + if len(compression_info.compression_type) > 1: + if not is_current_opset_version_compatible_with(_target.iOS18): + raise ValueError( + "The joint compression (more than one values in 'compression_type') is only " + "supported in iOS18+. Please set minimum_deployment_target to iOS18 or later." + ) + + result: Optional[Var] = None + for idx, type_val in enumerate(compression_info.compression_type): + if CompressionType(type_val) == CompressionType.QUANTIZATION: + result = self._construct_quantization_op(val, compression_info, param_name, result) + elif CompressionType(type_val) == CompressionType.PALETTIZATION: + result = self._construct_palettization_op(val, compression_info, param_name, result) + else: + assert CompressionType(type_val) == CompressionType.PRUNING + result = self._construct_sparsification_op( + val, compression_info, param_name, result + ) + + if result is None: + raise AssertionError(shared_msg + f", but unable to compress weight {param_name}") + return result + def _add_const(self, name: str, val: Union[torch.Tensor, torch._C.ScriptObject]) -> None: """Create a const op and add it to the graph.""" + if isinstance(val, torch.Tensor) and self._has_compression_info(name): + try: + compression_op = self._construct_compression_op(val.detach().numpy(), name) + self.context.add(compression_op) + return + except NotImplementedError as e: + logger.warning( + "Failed to create a compression op based on the compression info " + f"carried by {name} in the torch model. Ignored the compression info " + f"and constructed a normal const. Detailed error message:\n{e}" + ) + if isinstance(val, torch._C.ScriptObject): logger.info(f"Encountered constant {name} of type _torch._C.ScriptObject") return @@ -505,6 +1106,7 @@ def convert_const(self) -> None: # since inputs/constants will not contribute to debugging/profiling # TODO (rdar://125572392): Support torch.export IO metadata with mb.scope( + ScopeInfo(source=ScopeSource.EXIR_STACK_TRACE, data=[None]), ScopeInfo(source=ScopeSource.EXIR_DEBUG_HANDLE, data=[None]), ): self._add_const(name, val) @@ -578,7 +1180,8 @@ def convert(self) -> Program: self.convert_const() # Add the rest of the operations - convert_nodes(self.context, self.graph) + has_states = len(getattr(self, "states", [])) > 0 + convert_nodes(self.context, self.graph, early_exit=not has_states) graph_outputs = [self.context[name] for name in self.graph.outputs] @@ -620,7 +1223,10 @@ def convert(self) -> Program: ScopeSource.TORCHSCRIPT_MODULE_TYPE, ] elif self.context.frontend == TorchFrontend.EXIR: - essential_scope_sources = [ScopeSource.EXIR_DEBUG_HANDLE] + essential_scope_sources = [ + ScopeSource.EXIR_STACK_TRACE, + ScopeSource.EXIR_DEBUG_HANDLE, + ] else: raise ValueError(f"Invalid PyTorch frontend {self.context.frontend}") prog._add_essential_scope_source(essential_scope_sources) diff --git a/coremltools/converters/mil/frontend/torch/internal_graph.py b/coremltools/converters/mil/frontend/torch/internal_graph.py index 06f48a20a..8803f1908 100644 --- a/coremltools/converters/mil/frontend/torch/internal_graph.py +++ b/coremltools/converters/mil/frontend/torch/internal_graph.py @@ -257,10 +257,10 @@ def get_arguments(alist): raise AssertionError(f"Unhandled type of the node: {type(i)}") return tuple(args) + # TODO (rdar://128768037) handle kwargs inputs = get_arguments(node.args) - outputs = [ - node.name - ] # TODO: rdar://115846125 ([Executorch] Handle Models/Layers with Multiple outputs) + # TODO: rdar://115846125 ([Executorch] Handle Models/Layers with Multiple outputs) + outputs = [node.name] try: kind = node.target.name() diff --git a/coremltools/converters/mil/frontend/torch/load.py b/coremltools/converters/mil/frontend/torch/load.py index a2a2c5008..3593103cc 100644 --- a/coremltools/converters/mil/frontend/torch/load.py +++ b/coremltools/converters/mil/frontend/torch/load.py @@ -12,7 +12,7 @@ from coremltools import _logger as logger from coremltools._deps import _HAS_TORCH_EXPORT_API from coremltools.converters.mil.frontend.torch.converter import TorchConverter -from coremltools.converters.mil.input_types import TensorType +from coremltools.converters.mil.input_types import StateType, TensorType from coremltools.converters.mil.mil.program import Program from .converter import TorchConverter @@ -29,6 +29,7 @@ def load( outputs: Optional[List[TensorType]] = None, cut_at_symbols: Optional[List[str]] = None, use_default_fp16_io: bool = False, + states: Optional[List[StateType]] = None, **kwargs ) -> Program: """ @@ -77,6 +78,7 @@ def load( cut_at_symbols, specification_version, use_default_fp16_io, + states, ) return _perform_torch_convert(converter, debug) diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py index ed415e5da..36f69d386 100644 --- a/coremltools/converters/mil/frontend/torch/ops.py +++ b/coremltools/converters/mil/frontend/torch/ops.py @@ -18,6 +18,7 @@ from coremltools import _logger as logger from coremltools.converters.mil._deployment_compatibility import AvailableTarget as target from coremltools.converters.mil.frontend import _utils +from coremltools.converters.mil.frontend.milproto.load import TranscriptionContext from coremltools.converters.mil.mil import Builder as mb from coremltools.converters.mil.mil import Symbol, types from coremltools.converters.mil.mil.block import is_current_opset_version_compatible_with @@ -33,6 +34,7 @@ from coremltools.converters.mil.mil.var import ListVar, Var from .._utils import build_einsum_mil, value_at +from .internal_graph import InternalTorchIRGraph from .torch_op_registry import _TORCH_OPS_REGISTRY, register_torch_op from .utils import ( NUM_TO_DTYPE_STRING, @@ -68,7 +70,11 @@ def _all_outputs_present(context, graph): return True -def convert_nodes(context, graph): +def convert_nodes( + context: TranscriptionContext, + graph: InternalTorchIRGraph, + early_exit: Optional[bool] = True, +) -> None: """ Iterate over the nodes of a graph or block and convert to MIL. @@ -86,7 +92,7 @@ def convert_nodes(context, graph): logger.error(f"\n\nERROR - converting '{node.kind}' op (located at: '{op_location}'):\n") raise e # re-raise exception - if _all_outputs_present(context, graph): + if early_exit and _all_outputs_present(context, graph): # We've generated all the outputs the graph needs, terminate conversion. break @@ -124,7 +130,10 @@ def convert_single_node(context, node): ScopeInfo(source=ScopeSource.TORCHSCRIPT_MODULE_NAME, data=scope_name), ] elif context.frontend == TorchFrontend.EXIR: - scopes = [ScopeInfo(source=ScopeSource.EXIR_DEBUG_HANDLE, data=[node.meta["debug_handle"]])] + scopes = [ + ScopeInfo(source=ScopeSource.EXIR_STACK_TRACE, data=[node.meta.get("stack_trace")]), + ScopeInfo(source=ScopeSource.EXIR_DEBUG_HANDLE, data=[node.meta.get("debug_handle")]), + ] else: raise ValueError(f"Invalid PyTorch frontend {context.frontend}") @@ -545,20 +554,28 @@ def norm(context, node): def _vector_norm(x, order, dim, keep_dims, name): + # 0 norm is special if order.val == 0: # sum(x!=0) x = mb.cast(x=x, dtype="fp32") temp = mb.not_equal(x=x, y=0.) temp = mb.cast(x=temp, dtype='int32') temp = mb.reduce_sum(x=temp, axes=dim, keep_dims=keep_dims, name=name) + # infinity norm is special elif order.val > VALUE_CLOSE_TO_INFINITY: # max(abs(x)) temp = mb.abs(x=x) temp = mb.reduce_max(x=temp, axes=dim, keep_dims=keep_dims, name=name) + # -infinity norm is special elif order.val < -VALUE_CLOSE_TO_INFINITY: # min(abs(x)) temp = mb.abs(x=x) temp = mb.reduce_min(x=temp, axes=dim, keep_dims=keep_dims, name=name) + # Although 2 norm can fit in the general formula, + # since it is very common, we have tailored kernel for it + elif order.val == 2: + temp = mb.reduce_l2_norm(x=x, axes=dim, keep_dims=keep_dims, name=name) + # use general formula to compute all other norms else: # sum(abs(x)^{order})^{(1 / order)} temp = mb.abs(x=x) @@ -568,6 +585,7 @@ def _vector_norm(x, order, dim, keep_dims, name): temp = mb.pow(x=temp, y=1.0 / order.val, name=name) return temp + @register_torch_op def _weight_norm(context, node): v, g, dim = _get_inputs(context, node, expected=3) @@ -594,7 +612,6 @@ def _weight_norm(context, node): context.add(result) - def _matrix_norm(x, order, dim, keep_dims, name): if order.val == 1: # min(sum(abs(x), dim=0)) @@ -810,7 +827,7 @@ def gt(context, node): context.add(greater) -@register_torch_op(torch_alias=["t", "numpy_t"]) +@register_torch_op(torch_alias=["t", "numpy_t", "transpose.int"]) def transpose(context, node): assert len(node.outputs) == 1 inputs = _get_inputs(context, node) @@ -1624,7 +1641,7 @@ def mean(context, node): context.add(res) -@register_torch_op(torch_alias=["squeeze_copy.dim", "squeeze_copy.dims"]) +@register_torch_op(torch_alias=["squeeze.dim", "squeeze_copy.dim", "squeeze_copy.dims"]) def squeeze(context, node): inputs = _get_inputs(context, node) if len(inputs) == 1: @@ -3512,7 +3529,7 @@ def _false_path(): context.add(output_var, torch_name=output_name) -@register_torch_op(torch_alias=["select_copy.int"]) +@register_torch_op(torch_alias=["select.int", "select_copy.int"]) def select(context, node): inputs = _get_inputs(context, node, expected=3) _input = inputs[0] @@ -3566,9 +3583,6 @@ def select(context, node): def getitem(context, node): inputs = _get_inputs(context, node, expected=2) - if not isinstance(inputs[0], (list, tuple)): - raise AssertionError("Item selection is supported only on python list/tuple objects") - if inputs[1].val is None: raise AssertionError("Only static item selection supported") @@ -3579,6 +3593,15 @@ def getitem(context, node): f"Index into python list/tuple needs to be integer. Provided value: {inputs[1].val}" ) + if not isinstance(inputs[0], (list, tuple)): + # For single object with index 0, return this object + if index == 0: + context.add(inputs[0], torch_name=node.name) + return + # Otherwise undefined + else: + raise AssertionError("Item selection is supported only on python list/tuple objects") + out = inputs[0][index] if out is None: @@ -3691,17 +3714,73 @@ def _translate_torch_tensor_assign( squeeze_mask, name, ): - return mb.torch_tensor_assign( - x=x, - updates=updates, - begin=begin, - end=end, - stride=stride, - begin_mask=begin_mask, - end_mask=end_mask, - squeeze_mask=squeeze_mask, - name=name, - ) + + def torch_tensor_assign_implementation() -> Var: + return mb.torch_tensor_assign( + x=x, + updates=updates, + begin=begin, + end=end, + stride=stride, + begin_mask=begin_mask, + end_mask=end_mask, + squeeze_mask=squeeze_mask, + name=name, + ) + + if is_current_opset_version_compatible_with(target.iOS18): + # slice_update is not supporting scalar update at runtime. + # Until this radar is fixed: rdar://128221986 ([Feature][Slice_update] The backend is not supporting scalar update for the slice_update op), + # we have a workaround to expand scalar update to a 1-D tensor. + if updates.rank == 0: + # Since the workaround uses the compile-time value of begin and end, + # so we do the validation first. + is_begin_or_end_dynamic = False + for var in [begin, end]: + if isinstance(var, Var) and var.val is None: + is_begin_or_end_dynamic = True + if is_begin_or_end_dynamic or any_symbolic(x.shape): + return torch_tensor_assign_implementation() + + # First pick up the ``dim`` in which ``squeeze_mask[dim] = True``, + # and do the following transformation: + # 1. set ``squeeze_mask[dim] = False`` + # 2. set both ``begin_mask`` and ``end_mask`` to ``False`` + # 3. make ``end = begin + 1`` + dim = None + for i, val in enumerate(squeeze_mask): + if val is True: + dim = i + break + squeeze_mask[dim] = False + begin_mask = [False] * x.rank + end_mask = [False] * x.rank + + if isinstance(begin, Var): + begin = begin.val + if isinstance(end, Var): + end = end.val + + # convert negative indexes to positive indexes + begin = [val if val >= 0 else val + x.shape[i] for i, val in enumerate(begin)] + end = mb.add(x=begin, y=1) + + # expand updates to 1D tensor + updates = mb.expand_dims(x=updates, axes=[0]) + + return mb.slice_update( + x=x, + update=updates, + begin=begin, + end=end, + stride=stride, + begin_mask=begin_mask, + end_mask=end_mask, + squeeze_mask=squeeze_mask, + name=name, + ) + + return torch_tensor_assign_implementation() @register_torch_op @@ -3807,15 +3886,32 @@ def select_scatter(context, node): def slice_scatter(context, node): inputs = _get_inputs(context, node, min_expected=2) x, updates = promote_input_dtypes(inputs[0:2]) + + # sanitize and validate dim dim = 0 if len(inputs) <= 2 else inputs[2].val if dim is None: raise ValueError("Only compile time known dim supported yet") + if dim < 0: + dim = dim + x.rank + assert 0 <= dim and dim < x.rank, f"invalid dim: {dim}" + + # sanitize start start = 0 if len(inputs) <= 3 else inputs[3] - end = x.shape[dim] if len(inputs) <= 4 else mb.minimum(x=inputs[4], y=x.shape[dim]) - step = 1 if len(inputs) <= 5 else inputs[5] + if start is None: + start = 0 + + # sanitize end + if len(inputs) <= 4: + end = x.shape[dim] + else: + end = inputs[4] + if end is not None: + end = mb.minimum(x=inputs[4], y=x.shape[dim]) + else: + end = x.shape[dim] - assert dim is not None, "slice dim must be known at compile time" - assert 0 <= dim and dim < x.rank + # get step given different number of inputs + step = 1 if len(inputs) <= 5 else inputs[5] # mb.torch_tensor_assign handles multi-dim slicing # so we need to pad start, end, step from scalar to x.rank @@ -4160,7 +4256,12 @@ def index(context, node): @register_torch_op def ones(context, node): - inputs = _get_inputs(context, node, expected=[5, 6]) + inputs = _get_inputs( + context, + node, + expected={TorchFrontend.TORCHSCRIPT: [5, 6]}, + min_expected={TorchFrontend.EXIR: 1} + ) size = inputs[0] # dtype = NUM_TO_TORCH_DTYPE[inputs[1].val] unused # layout = inputs[2] unused @@ -4206,11 +4307,16 @@ def full(context, node): size = inputs[0] - dtype = ( - np.float32 - if len(inputs) < 3 or inputs[2] is None - else NUM_TO_NUMPY_DTYPE[TORCH_DTYPE_TO_NUM[inputs[2].val]] - ) + # dtype could be torch.dtype or an integer that maps to a numpy.dtype + dtype = None + if len(inputs) < 3 or inputs[2] is None: + dtype = np.float32 + elif isinstance(inputs[2].val, torch.dtype): + dtype = NUM_TO_NUMPY_DTYPE[TORCH_DTYPE_TO_NUM[inputs[2].val]] + elif isinstance(inputs[2].val, (int, np.generic)): + dtype = NUM_TO_NUMPY_DTYPE[inputs[2].val] + else: + raise ValueError(f"unsupported type {type(inputs[2].val)}.") val = dtype(inputs[1].val) @@ -4595,11 +4701,21 @@ def split(context, node): context.add(res, torch_name=node.name) -@register_torch_op +@register_torch_op(torch_alias=["unbind.int"]) def unbind(context, node): - inputs = _get_inputs(context, node, expected=2) + inputs = _get_inputs( + context, + node, + expected={ + TorchFrontend.TORCHSCRIPT: 2, + TorchFrontend.EXIR: [1, 2], + }, + ) x = inputs[0] - dim = inputs[1].val + if len(inputs) == 1: + dim = 0 + else: + dim = inputs[1].val split_sizes = [1] * x.shape[dim] if len(split_sizes) == 1: res = [mb.squeeze(x=x, axes=[dim])] @@ -4938,10 +5054,15 @@ def argmax(context, node): @register_torch_op(torch_alias=["empty_like"]) def zeros_like(context, node): - inputs = _get_inputs(context, node, expected=6) + inputs = _get_inputs( + context, + node, + expected={TorchFrontend.TORCHSCRIPT: 6}, + min_expected={TorchFrontend.EXIR: 1}, + ) x = inputs[0] shape = mb.shape(x=x) - if inputs[1] and inputs[1].val: + if len(inputs) > 1 and inputs[1] and inputs[1].val: dtype = inputs[1].val np_type = NUM_TO_NUMPY_DTYPE[dtype] else: @@ -5401,7 +5522,10 @@ def triu(context, node): inputs = _get_inputs(context, node, expected=2) x = inputs[0] diagonal = inputs[1] - diagonal = 0 if diagonal is None else diagonal.val + if diagonal is not None and diagonal.val is not None: + diagonal = diagonal.val + else: + diagonal = 0 if diagonal <= 0: res = mb.band_part(x=x, lower=-diagonal, upper=-1, name=node.name) else: @@ -5415,7 +5539,10 @@ def tril(context, node): inputs = _get_inputs(context, node, expected=2) x = inputs[0] diagonal = inputs[1] - diagonal = 0 if diagonal is None else diagonal.val + if diagonal is not None and diagonal.val is not None: + diagonal = diagonal.val + else: + diagonal = 0 if diagonal >= 0: res = mb.band_part(x=x, lower=-1, upper=diagonal, name=node.name) else: @@ -6034,22 +6161,21 @@ def replication_pad2d(context, node): pad = _np.pad(pad_flipped, (len(x.shape) * 2 - len(pad_flipped), 0)) context.add(mb.pad(x=x, pad=pad, mode='replicate'), node.name) +def _solve_broadcast_shape(shapes: List[List[int]]) -> List[np.ndarray]: + rank = _np.max([len(shape) for shape in shapes]) + shapes = [[1] * (rank - len(shape)) + shape for shape in shapes] + result_shape = [] + for i in range(rank): + dims = [shapes[j][i] for j in range(len(shapes))] + if any_symbolic(dims): + # rdar://85559497 (Handle dynamic shapes inputs broadcast for pytorch) + raise NotImplementedError( + "Only static shaped inputs are supported for torch.broadcast_tensors conversion." + ) + result_shape.append(_np.max(dims)) + return result_shape def _broadcast_tensors(tensors): - def _solve_broadcast_shape(shapes): - rank = _np.max([len(shape) for shape in shapes]) - shapes = [[1] * (rank - len(shape)) + shape for shape in shapes] - result_shape = [] - for i in range(rank): - dims = [shapes[j][i] for j in range(len(tensors))] - if any_symbolic(dims): - # rdar://85559497 (Handle dynamic shapes inputs broadcast for pytorch) - raise NotImplementedError( - "Only static shaped inputs are supported for torch.broadcast_tensors conversion." - ) - result_shape.append(_np.max(dims)) - return result_shape - if len(tensors) == 1: return tensors @@ -6729,7 +6855,7 @@ def _cast_bool_attn_mask(attn_mask: Var, query_var: Var) -> Var: ) return mb.mul(x=-3e4, y=compliment_of_mask) -@register_torch_op +@register_torch_op(torch_alias=["_scaled_dot_product_flash_attention_for_cpu"]) def scaled_dot_product_attention(context, node): """ Input shapes/types: @@ -6750,6 +6876,14 @@ def scaled_dot_product_attention(context, node): See details at: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html """ + + def _get_batch_dims(x: Var) -> List[int]: + return list(x.shape)[:-2] + + def _broadcast_tensor_to_same_batch_dims(x: Var, batch_dims: List[int]) -> Var: + broadcast_shape = batch_dims + list(x.shape[-2:]) + return _broadcast(x.name + "_broadcast_same_batch_dims", x, broadcast_shape) + inputs = _get_inputs(context, node, min_expected=3) q, k, v = inputs[:3] attn_mask = None if len(inputs) < 4 else inputs[3] @@ -6767,8 +6901,20 @@ def scaled_dot_product_attention(context, node): "scaled_dot_product_attention op: attn_mask cannot be provided when is_causal is set to True." ) - if dropout is not None and (dropout.val is None or dropout.val != 0.0): - raise ValueError("scaled_dot_product_attention op: dropout is not supported yet") + if dropout is not None: + if isinstance(dropout, Var): + if dropout.val is None: + raise NotImplementedError( + "A variable dropout probability is specified. Since Core ML " + "does not support dropout yet, we cowardly refuse to convert it" + ) + else: + dropout = dropout.val + if dropout != 0.0: + raise ValueError( + "A non-zero dropout probability is specified. Since Core ML " + "does not support dropout yet, we cannot convert it" + ) # check that ranks of q, k, v and attn_mask match if k.rank != q.rank: @@ -6784,12 +6930,46 @@ def scaled_dot_product_attention(context, node): if is_causal: mask = _get_causal_attn_mask(is_causal, q, k) elif attn_mask is not None: - if is_bool(attn_mask.dtype): + # For ios18-, bool attention mask has to be cast to equivalent floating point attention mask + if is_bool(attn_mask.dtype) and not is_current_opset_version_compatible_with(target.iOS18): mask = _cast_bool_attn_mask(attn_mask, q) else: mask = attn_mask - res = _utils._lower_scaled_dot_product_attention(q, k, v, mask, node.name) + # Since ios18, Core ML supports scaled_dot_product_attention op + if is_current_opset_version_compatible_with(target.iOS18): + # ios18 scaled_dot_product_attention only supports rank >= 3 + is_rank_2 = q.rank == 2 + + if is_rank_2: + q = mb.expand_dims(x=q, axes=[0]) + k = mb.expand_dims(x=k, axes=[0]) + v = mb.expand_dims(x=v, axes=[0]) + + # broadcast the batch_dims to the same shape + # note that, we only support the broadcast if the batch_dim is static + q_batch = _get_batch_dims(q) + k_batch = _get_batch_dims(k) + v_batch = _get_batch_dims(v) + + if not any_symbolic(q_batch + k_batch + v_batch): + b_dims = _solve_broadcast_shape([q_batch, k_batch, v_batch]) + q = _broadcast_tensor_to_same_batch_dims(q, b_dims) + k = _broadcast_tensor_to_same_batch_dims(k, b_dims) + v = _broadcast_tensor_to_same_batch_dims(v, b_dims) + + # directly translated into iOS18 sdpa op + res = mb.scaled_dot_product_attention( + query=q, key=k, value=v, attn_mask=mask, name=node.name + ) + + if is_rank_2: + res = mb.squeeze(x=res, axes=[0], name=node.name) + + # For ios18-, scaled_dot_product_attention has to be decomposed + else: + res = _utils._decompose_scaled_dot_product_attention(q, k, v, mask, node.name) + context.add(res) @@ -6816,5 +6996,6 @@ def multinomial(context, node): raise ValueError("In torch.multinomial op, num_samples must be const") if num_samples > 1 and not replacement: raise ValueError("When num_samples is larger than 1, only replacement=True is supported.") - x = mb.random_categorical(x=x, size=num_samples, name=node.name) + # Based on PyTorch documentations, the input to `torch.multinomial` is probability, not logit. + x = mb.random_categorical(x=x, size=num_samples, mode="probs", name=node.name) context.add(x) diff --git a/coremltools/converters/mil/frontend/torch/quantization_ops.py b/coremltools/converters/mil/frontend/torch/quantization_ops.py index 236f428e4..bf082d041 100644 --- a/coremltools/converters/mil/frontend/torch/quantization_ops.py +++ b/coremltools/converters/mil/frontend/torch/quantization_ops.py @@ -197,7 +197,7 @@ def _dequantized_weight(qweight, name: str = None): scale = _np.float32(qweight.q_scale()) zero_point = quant_dtype_np(qweight.q_zero_point()) quantized_weights = _torch.int_repr(qweight).numpy() - dequant_weights = _utils._construct_constexpr_affine_op( + dequant_weights = _utils._construct_constexpr_dequant_op( quantized_weights, zero_point, scale, axis=None, name=name ) # per_channel_affine_float_qparams is same as per_channel_affine except that it @@ -223,7 +223,7 @@ def _dequantized_weight(qweight, name: str = None): zero_point = quant_dtype_np(val) quantized_weights = _torch.int_repr(qweight).numpy() axis = _np.int32(qweight.q_per_channel_axis()) - dequant_weights = _utils._construct_constexpr_affine_op( + dequant_weights = _utils._construct_constexpr_dequant_op( quantized_weights, zero_point, scale, axis=axis, name=name ) else: diff --git a/coremltools/converters/mil/frontend/torch/test/test_executorch_e2e.py b/coremltools/converters/mil/frontend/torch/test/test_executorch_e2e.py index 143876948..13ecd301d 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_executorch_e2e.py +++ b/coremltools/converters/mil/frontend/torch/test/test_executorch_e2e.py @@ -6,17 +6,17 @@ import itertools import pytest -from coremltools._deps import _HAS_EXECUTORCH, _HAS_TORCH_VISION +from coremltools._deps import _HAS_EXECUTORCH -if not (_HAS_EXECUTORCH and _HAS_TORCH_VISION): - pytest.skip(allow_module_level=True, reason="executorch and torchvision are required") +if not _HAS_EXECUTORCH: + pytest.skip(allow_module_level=True, reason="executorch is required") -import torch -import torchvision -import torchaudio -import torchsr +torch = pytest.importorskip("torch") +torchvision = pytest.importorskip("torchvision") +torchaudio = pytest.importorskip("torchaudio") +torchsr = pytest.importorskip("torchsr") +timm = pytest.importorskip("timm") -import timm import transformers from coremltools.converters.mil import testing_reqs @@ -29,8 +29,11 @@ class TestExecutorchExampleModels(TorchBaseTest): - @pytest.mark.parametrize("compute_unit, backend", itertools.product(compute_units, backends)) - def test_mul(self, compute_unit, backend): + @pytest.mark.parametrize( + "compute_unit, backend, use_edge_dialect,", + itertools.product(compute_units, backends, (True, False)), + ) + def test_mul(self, compute_unit, backend, use_edge_dialect): class MulModule(torch.nn.Module): def forward(self, input, other): return input * other @@ -41,50 +44,58 @@ def forward(self, input, other): compute_unit=compute_unit, backend=backend, frontend=TorchFrontend.EXIR, + use_edge_dialect=use_edge_dialect, ) mil_program = coreml_model._mil_program mul = mil_program.functions["main"].find_ops(op_type="mul")[0] - debug_handle = mul.scopes[ScopeSource.EXIR_DEBUG_HANDLE][0] - assert isinstance(debug_handle, int) + stack_trace = mul.scopes[ScopeSource.EXIR_STACK_TRACE][0] + assert stack_trace.split("\n")[-2].strip() == "return input * other" - debug_handle_to_ops_mapping = mil_program.construct_debug_handle_to_ops_mapping() - assert debug_handle_to_ops_mapping.keys() == {debug_handle} + if use_edge_dialect: + debug_handle = mul.scopes[ScopeSource.EXIR_DEBUG_HANDLE][0] + assert isinstance(debug_handle, int) + + debug_handle_to_ops_mapping = mil_program.construct_debug_handle_to_ops_mapping() + assert debug_handle_to_ops_mapping.keys() == {debug_handle} - ops = debug_handle_to_ops_mapping[debug_handle] - index_mul = 0 - indices_const = () - indices_cast = () - if backend[1] == "fp32": - assert len(ops) == 1 + ops = debug_handle_to_ops_mapping[debug_handle] index_mul = 0 - else: - # fp16 introduces additional io casts - # each cast introduces 1 const to store destination dtype - assert len(ops) == 7 - index_mul = 4 - indices_const = (0, 1, 5) - indices_cast = (2, 3, 6) - assert ops[index_mul] == [ - {"Type": "Program"}, - {"Type": "Function", "Name": "main"}, - {"Type": "Block"}, - {"Type": "Operation", "Operator": "mul", "Output": mul.outputs[0].name}, - ] - for index_const_cast in indices_const + indices_cast: - assert ops[index_const_cast][:-1] == [ + indices_const = () + indices_cast = () + if backend[1] == "fp32": + assert len(ops) == 1 + index_mul = 0 + else: + # fp16 introduces additional io casts + # each cast introduces 1 const to store destination dtype + assert len(ops) == 7 + index_mul = 4 + indices_const = (0, 1, 5) + indices_cast = (2, 3, 6) + assert ops[index_mul] == [ {"Type": "Program"}, {"Type": "Function", "Name": "main"}, {"Type": "Block"}, + {"Type": "Operation", "Operator": "mul", "Output": mul.outputs[0].name}, ] - for index_const in indices_const: - assert ops[index_const][-1]["Operator"] == "const" - for index_cast in indices_cast: - assert ops[index_cast][-1]["Operator"] == "cast" + for index_const_cast in indices_const + indices_cast: + assert ops[index_const_cast][:-1] == [ + {"Type": "Program"}, + {"Type": "Function", "Name": "main"}, + {"Type": "Block"}, + ] + for index_const in indices_const: + assert ops[index_const][-1]["Operator"] == "const" + for index_cast in indices_cast: + assert ops[index_cast][-1]["Operator"] == "cast" - @pytest.mark.parametrize("compute_unit, backend", itertools.product(compute_units, backends)) - def test_linear(self, compute_unit, backend): + @pytest.mark.parametrize( + "compute_unit, backend, use_edge_dialect,", + itertools.product(compute_units, backends, (True, False)), + ) + def test_linear(self, compute_unit, backend, use_edge_dialect): class LinearModule(torch.nn.Module): def __init__(self): super().__init__() @@ -99,51 +110,59 @@ def forward(self, arg): compute_unit=compute_unit, backend=backend, frontend=TorchFrontend.EXIR, + use_edge_dialect=use_edge_dialect, ) mil_program = coreml_model._mil_program linear = mil_program.functions["main"].find_ops(op_type="linear")[0] - debug_handle = linear.scopes[ScopeSource.EXIR_DEBUG_HANDLE][0] - assert isinstance(debug_handle, int) - - debug_handle_to_ops_mapping = mil_program.construct_debug_handle_to_ops_mapping() - assert debug_handle_to_ops_mapping.keys() == {debug_handle} - - ops = debug_handle_to_ops_mapping[debug_handle] - index_linear = 0 - indices_const = () - indices_cast = () - if backend[1] == "fp32": - assert len(ops) == 3 - index_linear = 2 - indices_const = (0, 1) - else: - # fp16 introduces additional io casts - # each cast introduces 1 const to store destination dtype - assert len(ops) == 7 - index_linear = 4 - indices_const = (0, 1, 2, 5) - indices_cast = (3, 6) - assert ops[index_linear] == [ - {"Type": "Program"}, - {"Type": "Function", "Name": "main"}, - {"Type": "Block"}, - {"Type": "Operation", "Operator": "linear", "Output": linear.outputs[0].name}, - ] - for index_const_cast in indices_const + indices_cast: - assert ops[index_const_cast][:-1] == [ + stack_trace = linear.scopes[ScopeSource.EXIR_STACK_TRACE][0] + assert stack_trace.split("\n")[-2].strip() == "return self.linear(arg)" + + if use_edge_dialect: + debug_handle = linear.scopes[ScopeSource.EXIR_DEBUG_HANDLE][0] + assert isinstance(debug_handle, int) + + debug_handle_to_ops_mapping = mil_program.construct_debug_handle_to_ops_mapping() + assert debug_handle_to_ops_mapping.keys() == {debug_handle} + + ops = debug_handle_to_ops_mapping[debug_handle] + index_linear = 0 + indices_const = () + indices_cast = () + if backend[1] == "fp32": + assert len(ops) == 3 + index_linear = 2 + indices_const = (0, 1) + else: + # fp16 introduces additional io casts + # each cast introduces 1 const to store destination dtype + assert len(ops) == 7 + index_linear = 4 + indices_const = (0, 1, 2, 5) + indices_cast = (3, 6) + assert ops[index_linear] == [ {"Type": "Program"}, {"Type": "Function", "Name": "main"}, {"Type": "Block"}, + {"Type": "Operation", "Operator": "linear", "Output": linear.outputs[0].name}, ] - for index_const in indices_const: - assert ops[index_const][-1]["Operator"] == "const" - for index_cast in indices_cast: - assert ops[index_cast][-1]["Operator"] == "cast" + for index_const_cast in indices_const + indices_cast: + assert ops[index_const_cast][:-1] == [ + {"Type": "Program"}, + {"Type": "Function", "Name": "main"}, + {"Type": "Block"}, + ] + for index_const in indices_const: + assert ops[index_const][-1]["Operator"] == "const" + for index_cast in indices_cast: + assert ops[index_cast][-1]["Operator"] == "cast" - @pytest.mark.parametrize("compute_unit, backend", itertools.product(compute_units, backends)) - def test_add(self, compute_unit, backend): + @pytest.mark.parametrize( + "compute_unit, backend, use_edge_dialect,", + itertools.product(compute_units, backends, (True, False)), + ) + def test_add(self, compute_unit, backend, use_edge_dialect): class AddModule(torch.nn.Module): def forward(self, x, y): z = x + y @@ -158,59 +177,74 @@ def forward(self, x, y): compute_unit=compute_unit, backend=backend, frontend=TorchFrontend.EXIR, + use_edge_dialect=use_edge_dialect, ) mil_program = coreml_model._mil_program adds = mil_program.functions["main"].find_ops(op_type="add") - debug_handles = [add.scopes[ScopeSource.EXIR_DEBUG_HANDLE][0] for add in adds] - for debug_handle in debug_handles: - assert isinstance(debug_handle, int) + stack_traces = [add.scopes[ScopeSource.EXIR_STACK_TRACE][0] for add in adds] + source_codes = [ + "z = x + y", + "z = z + x", + "z = z + x", + "z = z + z", + ] + for i, stack_trace in enumerate(stack_traces): + assert stack_trace.split("\n")[-2].strip() == source_codes[i] - debug_handle_to_ops_mapping = mil_program.construct_debug_handle_to_ops_mapping() - assert debug_handle_to_ops_mapping.keys() == set(debug_handles) + if use_edge_dialect: + debug_handles = [add.scopes[ScopeSource.EXIR_DEBUG_HANDLE][0] for add in adds] + for debug_handle in debug_handles: + assert isinstance(debug_handle, int) - for add_index, debug_handle in enumerate(debug_handles): - add = adds[add_index] - ops = debug_handle_to_ops_mapping[debug_handle] - index_add = 0 - indices_const = () - indices_cast = () - if backend[1] == "fp32": - assert len(ops) == 1 + debug_handle_to_ops_mapping = mil_program.construct_debug_handle_to_ops_mapping() + assert debug_handle_to_ops_mapping.keys() == set(debug_handles) + + for add_index, debug_handle in enumerate(debug_handles): + add = adds[add_index] + ops = debug_handle_to_ops_mapping[debug_handle] index_add = 0 - else: - # fp16 introduces additional io casts - # each cast introduces 1 const to store destination dtype - ADD_INDEX_TO_NUM_OPS = {0: 5, 1: 1, 2: 1, 3: 3} - ADD_INDEX_TO_OP_INDEX = {0: -1, 1: 0, 2: 0, 3: 0} - assert len(ops) == ADD_INDEX_TO_NUM_OPS[add_index] - index_add = ADD_INDEX_TO_OP_INDEX[add_index] - if add_index == 0: - indices_const = (0, 1) - indices_cast = (2, 3) - elif add_index == 3: - indices_const = (1,) - indices_cast = (2,) - assert ops[index_add] == [ - {"Type": "Program"}, - {"Type": "Function", "Name": "main"}, - {"Type": "Block"}, - {"Type": "Operation", "Operator": "add", "Output": add.outputs[0].name}, - ] - for index_const_cast in indices_const + indices_cast: - assert ops[index_const_cast][:-1] == [ + indices_const = () + indices_cast = () + if backend[1] == "fp32": + assert len(ops) == 1 + index_add = 0 + else: + # fp16 introduces additional io casts + # each cast introduces 1 const to store destination dtype + ADD_INDEX_TO_NUM_OPS = {0: 5, 1: 1, 2: 1, 3: 3} + ADD_INDEX_TO_OP_INDEX = {0: -1, 1: 0, 2: 0, 3: 0} + assert len(ops) == ADD_INDEX_TO_NUM_OPS[add_index] + index_add = ADD_INDEX_TO_OP_INDEX[add_index] + if add_index == 0: + indices_const = (0, 1) + indices_cast = (2, 3) + elif add_index == 3: + indices_const = (1,) + indices_cast = (2,) + assert ops[index_add] == [ {"Type": "Program"}, {"Type": "Function", "Name": "main"}, {"Type": "Block"}, + {"Type": "Operation", "Operator": "add", "Output": add.outputs[0].name}, ] - for index_const in indices_const: - assert ops[index_const][-1]["Operator"] == "const" - for index_cast in indices_cast: - assert ops[index_cast][-1]["Operator"] == "cast" - - @pytest.mark.parametrize("compute_unit, backend", itertools.product(compute_units, backends)) - def test_add_mul(self, compute_unit, backend): + for index_const_cast in indices_const + indices_cast: + assert ops[index_const_cast][:-1] == [ + {"Type": "Program"}, + {"Type": "Function", "Name": "main"}, + {"Type": "Block"}, + ] + for index_const in indices_const: + assert ops[index_const][-1]["Operator"] == "const" + for index_cast in indices_cast: + assert ops[index_cast][-1]["Operator"] == "cast" + + @pytest.mark.parametrize( + "compute_unit, backend, use_edge_dialect,", + itertools.product(compute_units, backends, (True, False)), + ) + def test_add_mul(self, compute_unit, backend, use_edge_dialect): class AddMulModule(torch.nn.Module): def forward(self, a, x, b): y = torch.mm(a, x) @@ -223,6 +257,7 @@ def forward(self, a, x, b): compute_unit=compute_unit, backend=backend, frontend=TorchFrontend.EXIR, + use_edge_dialect=use_edge_dialect, ) mil_program = coreml_model._mil_program @@ -230,56 +265,72 @@ def forward(self, a, x, b): for op_type in ("matmul", "add"): matmul_or_add[op_type] = mil_program.functions["main"].find_ops(op_type=op_type)[0] - debug_handle = { - k: v.scopes[ScopeSource.EXIR_DEBUG_HANDLE][0] for k, v in matmul_or_add.items() + stack_traces = { + k: v.scopes[ScopeSource.EXIR_STACK_TRACE][0] for k, v in matmul_or_add.items() + } + source_codes = { + "matmul": "y = torch.mm(a, x)", + "add": "z = torch.add(y, b)", } - for v in debug_handle.values(): - assert isinstance(v, int) - - debug_handle_to_ops_mapping = mil_program.construct_debug_handle_to_ops_mapping() - assert debug_handle_to_ops_mapping.keys() == set(debug_handle.values()) - - ops = {} - for op_type in ("matmul", "add"): - ops[op_type] = debug_handle_to_ops_mapping[debug_handle[op_type]] - index = {"matmul": 0, "add": 0} - indices_const = {"matmul": (), "add": ()} - indices_cast = {"matmul": (), "add": ()} - if backend[1] == "fp32": - assert len(ops["matmul"]) == 3 and len(ops["add"]) == 1 - index = {"matmul": 2, "add": 0} - indices_const["matmul"] = (0, 1) - else: - # fp16 introduces additional io casts - # each cast introduces 1 const to store destination dtype - assert len(ops["matmul"]) == 7 and len(ops["add"]) == 5 - index = {"matmul": 6, "add": 2} - indices_const = {"matmul": (0, 1, 2, 3), "add": (0, 3)} - indices_cast = {"matmul": (4, 5), "add": (1, 4)} for op_type in ("matmul", "add"): - assert ops[op_type][index[op_type]] == [ - {"Type": "Program"}, - {"Type": "Function", "Name": "main"}, - {"Type": "Block"}, - { - "Type": "Operation", - "Operator": op_type, - "Output": matmul_or_add[op_type].outputs[0].name, - }, - ] - for index_const_cast in indices_const[op_type] + indices_cast[op_type]: - assert ops[op_type][index_const_cast][:-1] == [ + stack_trace = stack_traces[op_type] + source_code = source_codes[op_type] + assert stack_trace.split("\n")[-2].strip() == source_code + + if use_edge_dialect: + debug_handle = { + k: v.scopes[ScopeSource.EXIR_DEBUG_HANDLE][0] for k, v in matmul_or_add.items() + } + for v in debug_handle.values(): + assert isinstance(v, int) + + debug_handle_to_ops_mapping = mil_program.construct_debug_handle_to_ops_mapping() + assert debug_handle_to_ops_mapping.keys() == set(debug_handle.values()) + + ops = {} + for op_type in ("matmul", "add"): + ops[op_type] = debug_handle_to_ops_mapping[debug_handle[op_type]] + index = {"matmul": 0, "add": 0} + indices_const = {"matmul": (), "add": ()} + indices_cast = {"matmul": (), "add": ()} + if backend[1] == "fp32": + assert len(ops["matmul"]) == 3 and len(ops["add"]) == 1 + index = {"matmul": 2, "add": 0} + indices_const["matmul"] = (0, 1) + else: + # fp16 introduces additional io casts + # each cast introduces 1 const to store destination dtype + assert len(ops["matmul"]) == 7 and len(ops["add"]) == 5 + index = {"matmul": 6, "add": 2} + indices_const = {"matmul": (0, 1, 2, 3), "add": (0, 3)} + indices_cast = {"matmul": (4, 5), "add": (1, 4)} + for op_type in ("matmul", "add"): + assert ops[op_type][index[op_type]] == [ {"Type": "Program"}, {"Type": "Function", "Name": "main"}, {"Type": "Block"}, + { + "Type": "Operation", + "Operator": op_type, + "Output": matmul_or_add[op_type].outputs[0].name, + }, ] - for index_const in indices_const[op_type]: - assert ops[op_type][index_const][-1]["Operator"] == "const" - for index_cast in indices_cast[op_type]: - assert ops[op_type][index_cast][-1]["Operator"] == "cast" - - @pytest.mark.parametrize("compute_unit, backend", itertools.product(compute_units, backends)) - def test_softmax(self, compute_unit, backend): + for index_const_cast in indices_const[op_type] + indices_cast[op_type]: + assert ops[op_type][index_const_cast][:-1] == [ + {"Type": "Program"}, + {"Type": "Function", "Name": "main"}, + {"Type": "Block"}, + ] + for index_const in indices_const[op_type]: + assert ops[op_type][index_const][-1]["Operator"] == "const" + for index_cast in indices_cast[op_type]: + assert ops[op_type][index_cast][-1]["Operator"] == "cast" + + @pytest.mark.parametrize( + "compute_unit, backend, use_edge_dialect,", + itertools.product(compute_units, backends, (True, False)), + ) + def test_softmax(self, compute_unit, backend, use_edge_dialect): class SoftmaxModule(torch.nn.Module): def __init__(self): super().__init__() @@ -294,74 +345,101 @@ def forward(self, x): compute_unit=compute_unit, backend=backend, frontend=TorchFrontend.EXIR, + use_edge_dialect=use_edge_dialect, ) mil_program = coreml_model._mil_program softmax = mil_program.functions["main"].find_ops(op_type="softmax")[0] - debug_handle = softmax.scopes[ScopeSource.EXIR_DEBUG_HANDLE][0] - assert isinstance(debug_handle, int) - - debug_handle_to_ops_mapping = mil_program.construct_debug_handle_to_ops_mapping() - assert debug_handle_to_ops_mapping.keys() == {debug_handle} - - ops = debug_handle_to_ops_mapping[debug_handle] - index_softmax = 0 - indices_const = () - indices_cast = () - if backend[1] == "fp32": - assert len(ops) == 2 - index_softmax = 1 - indices_const = (0,) - else: - # fp16 introduces additional io casts - # each cast introduces 1 const to store destination dtype - assert len(ops) == 6 - index_softmax = 3 - indices_const = (0, 1, 4) - indices_cast = (2, 5) - assert ops[index_softmax] == [ - {"Type": "Program"}, - {"Type": "Function", "Name": "main"}, - {"Type": "Block"}, - {"Type": "Operation", "Operator": "softmax", "Output": softmax.outputs[0].name}, - ] - for index_const_cast in indices_const + indices_cast: - assert ops[index_const_cast][:-1] == [ + stack_trace = softmax.scopes[ScopeSource.EXIR_STACK_TRACE][0] + assert stack_trace.split("\n")[-2].strip() == "return self.softmax(x)" + + if use_edge_dialect: + debug_handle = softmax.scopes[ScopeSource.EXIR_DEBUG_HANDLE][0] + assert isinstance(debug_handle, int) + + debug_handle_to_ops_mapping = mil_program.construct_debug_handle_to_ops_mapping() + assert debug_handle_to_ops_mapping.keys() == {debug_handle} + + ops = debug_handle_to_ops_mapping[debug_handle] + index_softmax = 0 + indices_const = () + indices_cast = () + if backend[1] == "fp32": + assert len(ops) == 2 + index_softmax = 1 + indices_const = (0,) + else: + # fp16 introduces additional io casts + # each cast introduces 1 const to store destination dtype + assert len(ops) == 6 + index_softmax = 3 + indices_const = (0, 1, 4) + indices_cast = (2, 5) + assert ops[index_softmax] == [ {"Type": "Program"}, {"Type": "Function", "Name": "main"}, {"Type": "Block"}, + {"Type": "Operation", "Operator": "softmax", "Output": softmax.outputs[0].name}, ] - for index_const in indices_const: - assert ops[index_const][-1]["Operator"] == "const" - for index_cast in indices_cast: - assert ops[index_cast][-1]["Operator"] == "cast" + for index_const_cast in indices_const + indices_cast: + assert ops[index_const_cast][:-1] == [ + {"Type": "Program"}, + {"Type": "Function", "Name": "main"}, + {"Type": "Block"}, + ] + for index_const in indices_const: + assert ops[index_const][-1]["Operator"] == "const" + for index_cast in indices_cast: + assert ops[index_cast][-1]["Operator"] == "cast" @pytest.mark.xfail(reason="numerical error") - @pytest.mark.parametrize("compute_unit, backend", itertools.product(compute_units, backends)) - def test_deeplab_v3(self, compute_unit, backend): + @pytest.mark.parametrize( + "compute_unit, backend, use_edge_dialect,", + itertools.product(compute_units, backends, (True, False)), + ) + def test_deeplab_v3(self, compute_unit, backend, use_edge_dialect): + try: + torch_model = torchvision.models.segmentation.deeplabv3_resnet50( + weights=torchvision.models.segmentation.deeplabv3.DeepLabV3_ResNet50_Weights.DEFAULT + ) + except: + pytest.xfail("Torch model download may fail due to network fluctuation") + self.run_compare_torch( [(1, 3, 224, 224)], - torchvision.models.segmentation.deeplabv3_resnet50( - weights=torchvision.models.segmentation.deeplabv3.DeepLabV3_ResNet50_Weights.DEFAULT - ), + torch_model, compute_unit=compute_unit, backend=backend, frontend=TorchFrontend.EXIR, + use_edge_dialect=use_edge_dialect, ) - @pytest.mark.parametrize("compute_unit, backend", itertools.product(compute_units, backends)) - def test_edsr(self, compute_unit, backend): + @pytest.mark.parametrize( + "compute_unit, backend, use_edge_dialect,", + itertools.product(compute_units, backends, (True, False)), + ) + def test_edsr(self, compute_unit, backend, use_edge_dialect): + try: + torch_model = torchsr.models.edsr_r16f64(2, True) + except: + pytest.xfail("Torch model download may fail due to network fluctuation") + self.run_compare_torch( [(1, 3, 224, 224)], - torchsr.models.edsr_r16f64(2, True), + torch_model, compute_unit=compute_unit, backend=backend, frontend=TorchFrontend.EXIR, + use_edge_dialect=use_edge_dialect, ) - @pytest.mark.parametrize("compute_unit, backend", itertools.product(compute_units, backends)) - def test_emformer_transcribe(self, compute_unit, backend): + @pytest.mark.xfail(reason="rdar://125514139 emformer transcribe is too huge for Lightning") + @pytest.mark.parametrize( + "compute_unit, backend, use_edge_dialect,", + itertools.product(compute_units, backends, (True, False)), + ) + def test_emformer_transcribe(self, compute_unit, backend, use_edge_dialect): class EmformerRnntTranscriberExample(torch.nn.Module): """ This is a wrapper for validating transcriber for the Emformer RNN-T architecture. @@ -377,19 +455,25 @@ def __init__(self) -> None: def forward(self, sources, source_lengths): return self.rnnt.transcribe(sources, source_lengths) - if backend[0] == "neuralnetwork": - pytest.xfail("rdar://125514139 emformer transcribe fails on neuralnetwork") + try: + torch_model = EmformerRnntTranscriberExample() + except: + pytest.xfail("Torch model download may fail due to network fluctuation") self.run_compare_torch( [(1, 128, 80), (128,)], - EmformerRnntTranscriberExample(), + torch_model, compute_unit=compute_unit, backend=backend, frontend=TorchFrontend.EXIR, + use_edge_dialect=use_edge_dialect, ) - @pytest.mark.parametrize("compute_unit, backend", itertools.product(compute_units, backends)) - def test_emformer_predict(self, compute_unit, backend): + @pytest.mark.parametrize( + "compute_unit, backend, use_edge_dialect,", + itertools.product(compute_units, backends, (True, False)), + ) + def test_emformer_predict(self, compute_unit, backend, use_edge_dialect): class EmformerRnntPredictorExample(torch.nn.Module): """ This is a wrapper for validating predictor for the Emformer RNN-T architecture. @@ -405,18 +489,30 @@ def __init__(self) -> None: def forward(self, targets, target_lengths): return self.rnnt.predict(targets, target_lengths, None) + if backend[0] == "neuralnetwork": + pytest.xfail("rdar://125514139 emformer predict is too huge on neuralnetwork") + + try: + torch_model = EmformerRnntPredictorExample() + except: + pytest.xfail("Torch model download may fail due to network fluctuation") + self.run_compare_torch( [torch.zeros([1, 128], dtype=int), torch.tensor([128], dtype=int)], - EmformerRnntPredictorExample(), + torch_model, input_as_shape=False, compute_unit=compute_unit, backend=backend, frontend=TorchFrontend.EXIR, + use_edge_dialect=use_edge_dialect, ) @pytest.mark.xfail(reason="numerical error") - @pytest.mark.parametrize("compute_unit, backend", itertools.product(compute_units, backends)) - def test_emformer_join(self, compute_unit, backend): + @pytest.mark.parametrize( + "compute_unit, backend, use_edge_dialect,", + itertools.product(compute_units, backends, (True, False)), + ) + def test_emformer_join(self, compute_unit, backend, use_edge_dialect): class EmformerRnntJoinerExample(torch.nn.Module): """ This is a wrapper for validating joiner for the Emformer RNN-T architecture. @@ -432,112 +528,202 @@ def __init__(self) -> None: def forward(self, source_encodings, source_lengths, target_encodings, target_lengths): return self.rnnt.join(source_encodings, source_lengths, target_encodings, target_lengths) + try: + torch_model = EmformerRnntJoinerExample() + except: + pytest.xfail("Torch model download may fail due to network fluctuation") + self.run_compare_torch( [(1, 128, 1024), (128,), (1, 128, 1024), (128,)], - EmformerRnntJoinerExample(), + torch_model, compute_unit=compute_unit, backend=backend, frontend=TorchFrontend.EXIR, + use_edge_dialect=use_edge_dialect, ) - @pytest.mark.parametrize("compute_unit, backend", itertools.product(compute_units, backends)) - def test_mobilebert(self, compute_unit, backend): + @pytest.mark.parametrize( + "compute_unit, backend, use_edge_dialect,", + itertools.product(compute_units, backends, (True, False)), + ) + def test_mobilebert(self, compute_unit, backend, use_edge_dialect): if backend[1] == "fp16": pytest.skip("Mobile Bert overflows fp16") - tokenizer = transformers.AutoTokenizer.from_pretrained("google/mobilebert-uncased") - token = tokenizer("Hello, my dog is cute", return_tensors="pt")["input_ids"] + try: + tokenizer = transformers.AutoTokenizer.from_pretrained("google/mobilebert-uncased") + token = tokenizer("Hello, my dog is cute", return_tensors="pt")["input_ids"] + torch_model = transformers.MobileBertModel.from_pretrained( + "google/mobilebert-uncased", return_dict=False + ) + except: + pytest.xfail("Torch model download may fail due to network fluctuation") self.run_compare_torch( token, - transformers.MobileBertModel.from_pretrained( - "google/mobilebert-uncased", return_dict=False - ), + torch_model, input_as_shape=False, compute_unit=compute_unit, backend=backend, frontend=TorchFrontend.EXIR, + use_edge_dialect=use_edge_dialect, rtol=0.005, ) - @pytest.mark.parametrize("compute_unit, backend", itertools.product(compute_units, backends)) - def test_mobilenet_v2(self, compute_unit, backend): + @pytest.mark.parametrize( + "compute_unit, backend, use_edge_dialect,", + itertools.product(compute_units, backends, (True, False)), + ) + def test_mobilenet_v2(self, compute_unit, backend, use_edge_dialect): + try: + torch_model = torchvision.models.mobilenet_v2( + weights=torchvision.models.mobilenetv2.MobileNet_V2_Weights.DEFAULT + ) + except: + pytest.xfail("Torch model download may fail due to network fluctuation") + self.run_compare_torch( [(1, 3, 224, 224)], - torchvision.models.mobilenet_v2( - weights=torchvision.models.mobilenetv2.MobileNet_V2_Weights.DEFAULT - ), + torch_model, compute_unit=compute_unit, backend=backend, frontend=TorchFrontend.EXIR, + use_edge_dialect=use_edge_dialect, ) - @pytest.mark.parametrize("compute_unit, backend", itertools.product(compute_units, backends)) - def test_mobilenet_v3(self, compute_unit, backend): + @pytest.mark.parametrize( + "compute_unit, backend, use_edge_dialect,", + itertools.product(compute_units, backends, (True, False)), + ) + def test_mobilenet_v3(self, compute_unit, backend, use_edge_dialect): + try: + torch_model = torchvision.models.mobilenet_v3_small(pretrained=True) + except: + pytest.xfail("Torch model download may fail due to network fluctuation") + self.run_compare_torch( [(1, 3, 224, 224)], - torchvision.models.mobilenet_v3_small(pretrained=True), + torch_model, compute_unit=compute_unit, backend=backend, frontend=TorchFrontend.EXIR, + use_edge_dialect=use_edge_dialect, ) - @pytest.mark.parametrize("compute_unit, backend", itertools.product(compute_units, backends)) - def test_vit(self, compute_unit, backend): + @pytest.mark.parametrize( + "compute_unit, backend, use_edge_dialect,", + itertools.product(compute_units, backends, (True, False)), + ) + def test_vit(self, compute_unit, backend, use_edge_dialect): + try: + torch_model = torchvision.models.vit_b_16(weights="IMAGENET1K_V1") + except: + pytest.xfail("Torch model download may fail due to network fluctuation") + self.run_compare_torch( [(1, 3, 224, 224)], - torchvision.models.vit_b_16(weights="IMAGENET1K_V1"), + torch_model, compute_unit=compute_unit, backend=backend, frontend=TorchFrontend.EXIR, + use_edge_dialect=use_edge_dialect, ) - @pytest.mark.parametrize("compute_unit, backend", itertools.product(compute_units, backends)) - def test_wav2letter(self, compute_unit, backend): + @pytest.mark.parametrize( + "compute_unit, backend, use_edge_dialect,", + itertools.product(compute_units, backends, (True, False)), + ) + def test_wav2letter(self, compute_unit, backend, use_edge_dialect): + try: + torch_model = torchaudio.models.Wav2Letter(num_classes=4096) + except: + pytest.xfail("Torch model download may fail due to network fluctuation") + self.run_compare_torch( [(10, 1, 700)], - torchaudio.models.Wav2Letter(num_classes=4096), + torch_model, compute_unit=compute_unit, backend=backend, frontend=TorchFrontend.EXIR, + use_edge_dialect=use_edge_dialect, ) - @pytest.mark.parametrize("compute_unit, backend", itertools.product(compute_units, backends)) - def test_inception_v3(self, compute_unit, backend): + @pytest.mark.parametrize( + "compute_unit, backend, use_edge_dialect,", + itertools.product(compute_units, backends, (True, False)), + ) + def test_inception_v3(self, compute_unit, backend, use_edge_dialect): + try: + torch_model = torchvision.models.inception_v3(weights="IMAGENET1K_V1") + except: + pytest.xfail("Torch model download may fail due to network fluctuation") + self.run_compare_torch( [(1, 3, 224, 224)], - torchvision.models.inception_v3(weights="IMAGENET1K_V1"), + torch_model, compute_unit=compute_unit, backend=backend, frontend=TorchFrontend.EXIR, + use_edge_dialect=use_edge_dialect, ) - @pytest.mark.parametrize("compute_unit, backend", itertools.product(compute_units, backends)) - def test_inception_v4(self, compute_unit, backend): + @pytest.mark.parametrize( + "compute_unit, backend, use_edge_dialect,", + itertools.product(compute_units, backends, (True, False)), + ) + def test_inception_v4(self, compute_unit, backend, use_edge_dialect): + try: + torch_model = timm.models.inception_v4(pretrained=True) + except: + pytest.xfail("Torch model download may fail due to network fluctuation") + self.run_compare_torch( [(1, 3, 299, 299)], - timm.models.inception_v4(pretrained=True), + torch_model, compute_unit=compute_unit, backend=backend, frontend=TorchFrontend.EXIR, + use_edge_dialect=use_edge_dialect, ) - @pytest.mark.parametrize("compute_unit, backend", itertools.product(compute_units, backends)) - def test_resnet18(self, compute_unit, backend): + @pytest.mark.parametrize( + "compute_unit, backend, use_edge_dialect,", + itertools.product(compute_units, backends, (True, False)), + ) + def test_resnet18(self, compute_unit, backend, use_edge_dialect): + try: + torch_model = torchvision.models.resnet18( + weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1 + ) + except: + pytest.xfail("Torch model download may fail due to network fluctuation") + self.run_compare_torch( [(1, 3, 224, 224)], - torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1), + torch_model, compute_unit=compute_unit, backend=backend, frontend=TorchFrontend.EXIR, + use_edge_dialect=use_edge_dialect, ) - @pytest.mark.parametrize("compute_unit, backend", itertools.product(compute_units, backends)) - def test_resnet50(self, compute_unit, backend): + @pytest.mark.parametrize( + "compute_unit, backend, use_edge_dialect,", + itertools.product(compute_units, backends, (True, False)), + ) + def test_resnet50(self, compute_unit, backend, use_edge_dialect): + try: + torch_model = torchvision.models.resnet50( + weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1 + ) + except: + pytest.xfail("Torch model download may fail due to network fluctuation") + self.run_compare_torch( [(1, 3, 224, 224)], - torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1), + torch_model, compute_unit=compute_unit, backend=backend, frontend=TorchFrontend.EXIR, + use_edge_dialect=use_edge_dialect, ) diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_conversion_api.py b/coremltools/converters/mil/frontend/torch/test/test_torch_conversion_api.py index 77a45f0d0..6ae95f11b 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_conversion_api.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_conversion_api.py @@ -454,7 +454,7 @@ def forward(self, x): assert isinstance(model, ct.converters.mil.Program) @staticmethod - def test_torch_classifier(): + def _get_classifier_model(): class Net(torch.nn.Module): def __init__(self): super(Net, self).__init__() @@ -475,23 +475,36 @@ def forward(self, img): # convert + flatten traced_model = torch.jit.trace(model, example_input) traced_model.eval() + return traced_model, example_input + + @staticmethod + def _convert_classifier_model(traced_model, example_input, class_type, backend="mlprogram"): + label = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + if class_type == "str": + label = list(map(lambda x: str(x), label)) + classifier_config = ct.ClassifierConfig(label) + return ct.convert( + traced_model, + source="pytorch", + convert_to=backend, + inputs=[ + ct.TensorType( + name="input", + shape=example_input.shape, + dtype=example_input.numpy().dtype, + ) + ], + classifier_config=classifier_config, + ) + + @staticmethod + def test_torch_classifier(): def _test_classifier(traced_model, example_input, class_type, backend): - label = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] - if class_type == "str": - label = list(map(lambda x: str(x), label)) - classifier_config = ct.ClassifierConfig(label) - mlmodel = ct.convert( + mlmodel = TestPyTorchConverterExamples._convert_classifier_model( traced_model, - source='pytorch', - convert_to=backend, - inputs=[ - ct.TensorType( - name="input", - shape=example_input.shape, - dtype=example_input.numpy().dtype, - ) - ], - classifier_config=classifier_config + example_input, + class_type, + backend, ) if ct.utils._is_macos(): coreml_out = mlmodel.predict({"input": example_input.detach().numpy()}) @@ -500,6 +513,7 @@ def _test_classifier(traced_model, example_input, class_type, backend): assert isinstance(coreml_out["classLabel"], key_type) for class_type in ("str", "int"): + traced_model, example_input = TestPyTorchConverterExamples._get_classifier_model() _test_classifier(traced_model, example_input, class_type, "neuralnetwork") if ct.utils._macos_version() >= (12, 0): _test_classifier(traced_model, example_input, class_type, "mlprogram") diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py index ecb656107..d8d453dbb 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py @@ -1570,20 +1570,6 @@ def test_convolution2d( bias, groups=1, ): - if ( - backend == ('neuralnetwork', 'fp32') and - padding == 1 and - stride == 2 and - height == 7 and - width == 5 and - in_channels == 3 and - out_channels == 3 and - kernel_size == 2 and - dilation == 3 and - not bias - ): - pytest.xfail("rdar://121954894: Conv2d starts to fail") - if padding == "same" and stride != 1: return model = nn.Conv2d( @@ -7703,13 +7689,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class TestTensorAssign(TorchBaseTest): @pytest.mark.parametrize( - "compute_unit, backend", + "compute_unit, backend, minimum_deployment_target", itertools.product( compute_units, backends, + [None, ct.target.iOS18], ), ) - def test_tensor_assign_case_1(self, compute_unit, backend): + def test_tensor_assign_scalar(self, compute_unit, backend, minimum_deployment_target): # single dimension assignment for a 1D tensor class TensorAssignModel(torch.nn.Module): def forward(self, x): @@ -7721,16 +7708,24 @@ def forward(self, x): shape = (5,) model = TensorAssignModel() - self.run_compare_torch(shape, model, backend=backend, compute_unit=compute_unit) + self.run_compare_torch( + shape, + model, + backend=backend, + compute_unit=compute_unit, + minimum_deployment_target=minimum_deployment_target, + ) @pytest.mark.parametrize( - "compute_unit, backend", - itertools.product( - compute_units, - backends, - ), + "compute_unit, backend, minimum_deployment_target", + itertools.product(compute_units, backends, [None, ct.target.iOS18]), ) - def test_tensor_assign_case_2(self, compute_unit, backend): + def test_tensor_assign_case_scalar_case_2( + self, compute_unit, backend, minimum_deployment_target + ): + """ + A little bit more complicated scalar tensor assignment test. + """ # single dimension assignment for two 1D tensors class TensorAssignModel(torch.nn.Module): def forward(self, x, y): @@ -7746,11 +7741,15 @@ def forward(self, x, y): shape = (5,) model = TensorAssignModel() self.run_compare_torch( - [shape, shape], model, backend=backend, compute_unit=compute_unit + [shape, shape], + model, + backend=backend, + compute_unit=compute_unit, + minimum_deployment_target=minimum_deployment_target, ) @pytest.mark.parametrize( - "compute_unit, backend, shape", + "compute_unit, backend, shape, minimum_deployment_target", itertools.product( compute_units, backends, @@ -7758,10 +7757,18 @@ def forward(self, x, y): (5, 4), (5, 4, 3), ], + [None, ct.target.iOS18], ), ) - def test_tensor_assign_case_3(self, compute_unit, backend, shape): + def test_tensor_assign_case_broadcast( + self, compute_unit, backend, shape, minimum_deployment_target + ): # broadcast assignment for two n-D tensors + if compute_unit != ct.ComputeUnit.CPU_ONLY: + pytest.xfail( + "rdar://128024502 ([Bug][iOS18] slice_update failing test on backends beside CPU_ONLY)" + ) + class TensorAssignModel(torch.nn.Module): def __init__(self): super(TensorAssignModel, self).__init__() @@ -7773,18 +7780,28 @@ def forward(self, x, y): return x model = TensorAssignModel() - self.run_compare_torch( - [shape, shape], model, backend=backend, compute_unit=compute_unit + res = self.run_compare_torch( + [shape, shape], + model, + backend=backend, + compute_unit=compute_unit, + minimum_deployment_target=minimum_deployment_target, ) + # check slice_update is used + if minimum_deployment_target == ct.target.iOS18: + prog = res[1]._mil_program + assert "slice_update" in get_op_types_in_program(prog) + @pytest.mark.parametrize( - "compute_unit, backend", + "compute_unit, backend, minimum_deployment_target", itertools.product( compute_units, backends, + [None, ct.target.iOS18], ), ) - def test_itensor_assign_case_4(self, compute_unit, backend): + def test_tensor_assign_nd_tensor(self, compute_unit, backend, minimum_deployment_target): # single dimension assignment for two n-D tensors class TensorAssignModel(torch.nn.Module): def forward(self, x, y): @@ -7795,18 +7812,28 @@ def forward(self, x, y): shape = (5, 4) model = TensorAssignModel() - self.run_compare_torch( - [shape, shape], model, backend=backend, compute_unit=compute_unit + res = self.run_compare_torch( + [shape, shape], + model, + backend=backend, + compute_unit=compute_unit, + minimum_deployment_target=minimum_deployment_target, ) + # check slice_update is used + if minimum_deployment_target == ct.target.iOS18: + prog = res[1]._mil_program + assert "slice_update" in get_op_types_in_program(prog) + @pytest.mark.parametrize( - "compute_unit, backend", + "compute_unit, backend, minimum_deployment_target", itertools.product( compute_units, backends, + [None, ct.target.iOS18], ), ) - def test_tensor_assign_case_5(self, compute_unit, backend): + def test_tensor_assign_slice(self, compute_unit, backend, minimum_deployment_target): # slice dimension assignment class TensorAssignModel(torch.nn.Module): def forward(self, x): @@ -7815,16 +7842,28 @@ def forward(self, x): shape = (2, 10) model = TensorAssignModel() - self.run_compare_torch(shape, model, backend=backend, compute_unit=compute_unit) + res = self.run_compare_torch( + shape, + model, + backend=backend, + compute_unit=compute_unit, + minimum_deployment_target=minimum_deployment_target, + ) + + # check slice_update is used + if minimum_deployment_target == ct.target.iOS18: + prog = res[1]._mil_program + assert "slice_update" in get_op_types_in_program(prog) @pytest.mark.parametrize( - "compute_unit, backend", + "compute_unit, backend, minimum_deployment_target", itertools.product( compute_units, backends, + [None, ct.target.iOS18], ), ) - def test_tensor_assign_case_6(self, compute_unit, backend): + def test_tensor_assign_slice_case_2(self, compute_unit, backend, minimum_deployment_target): # a more complicated slice dimension assignment class TensorAssignModel(torch.nn.Module): def forward(self, x): @@ -7833,17 +7872,31 @@ def forward(self, x): shape = (2, 10, 3) model = TensorAssignModel() - self.run_compare_torch(shape, model, backend=backend, compute_unit=compute_unit) + res = self.run_compare_torch( + shape, + model, + backend=backend, + compute_unit=compute_unit, + minimum_deployment_target=minimum_deployment_target, + ) + + # check slice_update is used + if minimum_deployment_target == ct.target.iOS18: + prog = res[1]._mil_program + assert "slice_update" in get_op_types_in_program(prog) @pytest.mark.parametrize( - "compute_unit, backend, dynamic", + "compute_unit, backend, dynamic, minimum_deployment_target", itertools.product( compute_units, backends, [True, False], + [None, ct.target.iOS18], ), ) - def test_tensor_assign_case_7(self, compute_unit, backend, dynamic): + def test_tensor_assign_complex_slice( + self, compute_unit, backend, dynamic, minimum_deployment_target + ): # general case class TensorAssignModel(torch.nn.Module): def forward(self, x): @@ -7868,19 +7921,34 @@ def forward(self, x): ] else: converter_input_type = None - self.run_compare_torch( + res = self.run_compare_torch( shape, model, converter_input_type=converter_input_type, backend=backend, - compute_unit=compute_unit + compute_unit=compute_unit, + minimum_deployment_target=minimum_deployment_target, ) + # check slice_update is used + if minimum_deployment_target == ct.target.iOS18: + prog = res[1]._mil_program + assert "slice_update" in get_op_types_in_program(prog) + @pytest.mark.parametrize( - "compute_unit, backend, dynamic, mixed_rank", - itertools.product(compute_units, backends, [True, False], [True, False]), + "compute_unit, backend, dynamic, mixed_rank, minimum_deployment_target", + itertools.product( + compute_units, backends, [True, False], [True, False], [None, ct.target.iOS18] + ), ) - def test_tensor_assign_case_8(self, compute_unit, backend, dynamic, mixed_rank): + def test_tensor_assign_dynamic_slice( + self, compute_unit, backend, dynamic, mixed_rank, minimum_deployment_target + ): + if compute_unit != ct.ComputeUnit.CPU_ONLY: + pytest.xfail( + "rdar://128024502 ([Bug][iOS18] slice_update failing test on backends beside CPU_ONLY)" + ) + # general case with dynamic begin and end class TensorAssignModel(torch.nn.Module): def forward(self, x, begin_0, begin_1, end_1): @@ -7933,7 +8001,8 @@ def forward(self, x, begin_0, begin_1, end_1): input_as_shape=False, converter_input_type=converter_input_type, backend=backend, - compute_unit=compute_unit + compute_unit=compute_unit, + minimum_deployment_target=minimum_deployment_target, ) if not mixed_rank: @@ -7943,14 +8012,22 @@ def forward(self, x, begin_0, begin_1, end_1): assert "squeeze" not in get_op_types_in_program(prog) assert "expand_dims" not in get_op_types_in_program(prog) + # check slice_update is used + if minimum_deployment_target == ct.target.iOS18: + prog = res[1]._mil_program + assert "slice_update" in get_op_types_in_program(prog) + @pytest.mark.parametrize( - "compute_unit, backend", + "compute_unit, backend, minimum_deployment_target", itertools.product( compute_units, backends, + [None, ct.target.iOS18], ), ) - def test_tensor_assign_type_compatibility(self, compute_unit, backend): + def test_tensor_assign_type_compatibility( + self, compute_unit, backend, minimum_deployment_target + ): class TensorAssignModel(torch.nn.Module): def forward(self, x): x[:, 1] = torch.tensor([1, 2], dtype=torch.int32) @@ -7958,7 +8035,138 @@ def forward(self, x): shape = (2, 3) model = TensorAssignModel() - self.run_compare_torch(shape, model, backend=backend, compute_unit=compute_unit) + res = self.run_compare_torch( + shape, + model, + backend=backend, + compute_unit=compute_unit, + minimum_deployment_target=minimum_deployment_target, + ) + + # check slice_update is used + if minimum_deployment_target == ct.target.iOS18: + prog = res[1]._mil_program + assert "slice_update" in get_op_types_in_program(prog) + + +class TestSelectScatter(TorchBaseTest): + @pytest.mark.parametrize( + "compute_unit, backend, minimum_deployment_target, input_shape", + itertools.product( + compute_units, + backends, + [None, ct.target.iOS18], + [(1,), (4,), (3, 4), (1, 2, 4)], + ), + ) + def test_select_scatter(self, compute_unit, backend, minimum_deployment_target, input_shape): + rank = len(input_shape) + + if ( + input_shape == (1, 2, 4) + and minimum_deployment_target == ct.target.iOS18 + and compute_unit != ct.ComputeUnit.CPU_ONLY + ): + pytest.xfail( + "rdar://128024502 ([Bug][iOS18] slice_update failing test on backends beside CPU_ONLY)" + ) + + def test_model(src_shape, dim, index): + + class SelectScatterModel(torch.nn.Module): + def forward(self, x, y): + return torch.select_scatter( + input=x, + src=y, + dim=dim, + index=index, + ) + + class Rank0SelectScatterModel(torch.nn.Module): + def forward(self, x, y): + y = y[0] + return torch.select_scatter( + input=x, + src=y, + dim=dim, + index=index, + ) + + if len(src_shape) == 0: + src_shape = [1] + model = Rank0SelectScatterModel() + else: + model = SelectScatterModel() + + res = self.run_compare_torch( + [input_shape, src_shape], + model, + backend=backend, + compute_unit=compute_unit, + minimum_deployment_target=minimum_deployment_target, + ) + + # check slice_update is used + if minimum_deployment_target == ct.target.iOS18: + prog = res[1]._mil_program + assert "slice_update" in get_op_types_in_program(prog) + + for dim in range(-rank, rank): + for index in range(-input_shape[dim], input_shape[dim]): + dim_val = dim + rank if dim < 0 else dim + src_shape = list(input_shape) + src_shape = src_shape[:dim_val] + src_shape[dim_val + 1 :] + test_model(src_shape, dim, index) + + +class TestSliceScatter(TorchBaseTest): + @pytest.mark.parametrize( + "compute_unit, backend, minimum_deployment_target, input_shape", + itertools.product( + compute_units, + backends, + [None, ct.target.iOS18], + [(1,), (4,), (3, 4), (1, 2, 4)], + ), + ) + def test_slice_scatter(self, compute_unit, backend, minimum_deployment_target, input_shape): + rank = len(input_shape) + + def test_model(src_shape, dim, start, end, step): + class SliceScatterModel(torch.nn.Module): + def forward(self, x, y): + return torch.slice_scatter( + input=x, + src=y, + dim=dim, + start=start, + end=end, + step=step, + ) + + res = self.run_compare_torch( + [input_shape, src_shape], + SliceScatterModel(), + backend=backend, + compute_unit=compute_unit, + minimum_deployment_target=minimum_deployment_target, + ) + + # check slice_update is used + if minimum_deployment_target == ct.target.iOS18: + prog = res[1]._mil_program + assert "slice_update" in get_op_types_in_program(prog) + + for dim in range(-rank, rank): + for start in list(range(0, input_shape[dim])) + [None]: + start_val = start if start is not None else 0 + for end in list(range(start_val + 1, input_shape[dim] + 1)) + [None]: + end_val = end if end is not None else input_shape[dim] + for step in range(1, end_val - start_val + 1): + src_shape = list(input_shape) + src_shape[dim] = 1 + (end_val - start_val - 1) // step + src_shape = tuple(src_shape) + test_model(src_shape, dim, start, end, step) class TestIndexPut(TorchBaseTest): @@ -8126,10 +8334,12 @@ def forward(self, x, indices, values): ) @pytest.mark.parametrize( - "compute_unit, backend, frontend", - itertools.product(compute_units, backends, frontends), + "compute_unit, backend, frontend, minimum_deployment_target", + itertools.product(compute_units, backends, frontends, [None, ct.target.iOS18]), ) - def test_index_put_int_index_case_2(self, compute_unit, backend, frontend): + def test_index_put_int_index_case_2( + self, compute_unit, backend, frontend, minimum_deployment_target + ): class IndexPutModel(torch.nn.Module): def forward(self, x): box_corner = x.new(x.shape) @@ -8137,38 +8347,56 @@ def forward(self, x): box_corner[:, :, 1] = x[:, :, 1] return box_corner[:, :, :2] - self.run_compare_torch( + res = self.run_compare_torch( (2, 3, 4), IndexPutModel(), frontend=frontend, backend=backend, compute_unit=compute_unit, + minimum_deployment_target=minimum_deployment_target, ) + # check slice_update is used + if minimum_deployment_target == ct.target.iOS18: + prog = res[1]._mil_program + assert "slice_update" in get_op_types_in_program(prog) + @pytest.mark.parametrize( - "compute_unit, backend, frontend", - itertools.product(compute_units, backends, frontends), + "compute_unit, backend, frontend, minimum_deployment_target", + itertools.product(compute_units, backends, frontends, [None, ct.target.iOS18]), ) - def test_index_put_int_index_case_3(self, compute_unit, backend, frontend): + def test_index_put_int_index_case_3( + self, compute_unit, backend, frontend, minimum_deployment_target + ): class IndexPutModel(torch.nn.Module): def forward(self, x): y = x.clone() y[:, 0] = 1.0 return y - self.run_compare_torch( + res = self.run_compare_torch( (2, 3), IndexPutModel(), frontend=frontend, backend=backend, compute_unit=compute_unit, + minimum_deployment_target=minimum_deployment_target, ) + # check slice_update is used + if minimum_deployment_target == ct.target.iOS18: + prog = res[1]._mil_program + assert "slice_update" in get_op_types_in_program(prog) + @pytest.mark.parametrize( - "compute_unit, backend, frontend, val_shape", - itertools.product(compute_units, backends, frontends, ((2, 1), (1,))), + "compute_unit, backend, frontend, val_shape, minimum_deployment_target", + itertools.product( + compute_units, backends, frontends, ((2, 1), (1,)), [None, ct.target.iOS18] + ), ) - def test_index_put_dynamic_int_index_case_1(self, compute_unit, backend, frontend, val_shape): + def test_index_put_dynamic_int_index_case_1( + self, compute_unit, backend, frontend, val_shape, minimum_deployment_target + ): if frontend == TorchFrontend.TORCHSCRIPT: pytest.xfail( "https://github.com/apple/coremltools/issues/2188: " @@ -8181,7 +8409,7 @@ def forward(self, x, position, val): y[:, position] = val return y - self.run_compare_torch( + res = self.run_compare_torch( [(2, 3), (1,), val_shape], IndexPutModel(), input_dtype=np.int32, @@ -8189,13 +8417,21 @@ def forward(self, x, position, val): frontend=frontend, backend=backend, compute_unit=compute_unit, + minimum_deployment_target=minimum_deployment_target, ) + # check slice_update is used + if minimum_deployment_target == ct.target.iOS18: + prog = res[1]._mil_program + assert "slice_update" in get_op_types_in_program(prog) + @pytest.mark.parametrize( - "compute_unit, backend, frontend", - itertools.product(compute_units, backends, frontends), + "compute_unit, backend, frontend, minimum_deployment_target", + itertools.product(compute_units, backends, frontends, [None, ct.target.iOS18]), ) - def test_index_put_dynamic_int_index_case_2(self, compute_unit, backend, frontend): + def test_index_put_dynamic_int_index_case_2( + self, compute_unit, backend, frontend, minimum_deployment_target + ): if frontend == TorchFrontend.TORCHSCRIPT: pytest.xfail( "https://github.com/apple/coremltools/issues/2188: " @@ -8208,7 +8444,7 @@ def forward(self, x, position, val): y[position, 1:4] = val return y - self.run_compare_torch( + res = self.run_compare_torch( [(2, 4), (1,), (1,)], IndexPutModel(), input_dtype=np.int32, @@ -8216,8 +8452,14 @@ def forward(self, x, position, val): frontend=frontend, backend=backend, compute_unit=compute_unit, + minimum_deployment_target=minimum_deployment_target, ) + # check slice_update is used + if minimum_deployment_target == ct.target.iOS18: + prog = res[1]._mil_program + assert "slice_update" in get_op_types_in_program(prog) + @pytest.mark.parametrize( "compute_unit, backend, frontend, accumulate, minimum_deployment_target", itertools.product( @@ -10799,17 +11041,72 @@ class TestScaledDotProductAttention(TorchBaseTest): """ @pytest.mark.parametrize( - "compute_unit, backend, frontend, rank, dynamic", + "compute_unit, backend, frontend, minimum_deployment_target", + itertools.product( + compute_units, + backends, + frontends, + [None, ct.target.iOS18], + ), + ) + def test_different_batch_dims(self, compute_unit, backend, frontend, minimum_deployment_target): + """ + The query/key/value inputs can have different batch_dims. + """ + q_shape = [1, 2, 10, 3] + k_shape = [2, 1, 10, 3] + v_shape = [2, 2, 10, 3] + input_shape = [ + q_shape, + k_shape, + v_shape, + ] + + model = ModuleWrapper( + function=nn.functional.scaled_dot_product_attention, + kwargs={ + "attn_mask": None, + "dropout_p": 0.0, + "is_causal": False, + }, + ) + + res = self.run_compare_torch( + input_shape, + model, + frontend=frontend, + backend=backend, + compute_unit=compute_unit, + minimum_deployment_target=minimum_deployment_target, + ) + + # Only iOS 18 with torch script can have mb.sdpa, because + # 1. mb.sdpa is introduced in iOS 18, so before iOS 18 we would decompose sdpa + # 2. torch.sdpa is not a core aten op, so EXIR would decompose sdpa + if minimum_deployment_target == ct.target.iOS18 and frontend == TorchFrontend.TORCHSCRIPT: + if backend == ("mlprogram", "fp16"): + assert get_op_types_in_program(res[1]._mil_program) == [ + "cast", + "tile", + "cast", + "tile", + "cast", + "scaled_dot_product_attention", + ] + + @pytest.mark.parametrize( + "compute_unit, backend, frontend, minimum_deployment_target, rank, dynamic", itertools.product( compute_units, backends, frontends, + [None, ct.target.iOS18], [2, 3, 4, 5], [True, False], ), ) def test_different_input_ranks_no_mask( - self, compute_unit, backend, frontend, rank, dynamic, minimum_deployment_target=None + self, compute_unit, backend, frontend, minimum_deployment_target, rank, dynamic ): """ The query/key/value inputs can be any rank 2 or greater. @@ -10845,7 +11142,7 @@ def test_different_input_ranks_no_mask( else: converter_input_type = None - return self.run_compare_torch( + _, coreml_model, _, _, _, _ = self.run_compare_torch( [input_shape] * 3, model, frontend=frontend, @@ -10853,14 +11150,56 @@ def test_different_input_ranks_no_mask( converter_input_type=converter_input_type, compute_unit=compute_unit, minimum_deployment_target=minimum_deployment_target, - )[1] + ) + + # Only iOS 18 with torch script can have mb.sdpa, because + # 1. mb.sdpa is introduced in iOS 18, so before iOS 18 we would decompose sdpa + # 2. torch.sdpa is not a core aten op, so EXIR would decompose sdpa + if minimum_deployment_target == ct.target.iOS18 and frontend == TorchFrontend.TORCHSCRIPT: + if backend == ("mlprogram", "fp16"): + if rank == 2: + if dynamic: + expected_ops = [ + "expand_dims", + "expand_dims", + "expand_dims", + "scaled_dot_product_attention", + "squeeze", + ] + else: + expected_ops = [ + "cast", + "expand_dims", + "cast", + "expand_dims", + "cast", + "expand_dims", + "scaled_dot_product_attention", + "squeeze", + ] + assert get_op_types_in_program(coreml_model._mil_program) == expected_ops + + else: + if dynamic: + expected_ops = [ + "scaled_dot_product_attention", + ] + else: + expected_ops = [ + "cast", + "cast", + "cast", + "scaled_dot_product_attention", + ] + assert get_op_types_in_program(coreml_model._mil_program) == expected_ops @pytest.mark.parametrize( - "compute_unit, backend, frontend, seq_lengths, include_heads, dynamic", + "compute_unit, backend, frontend, minimum_deployment_target, seq_lengths, include_heads, dynamic", itertools.product( compute_units, backends, frontends, + [None, ct.target.iOS18], [(5, 5), (5, 7), (6, 4)], [False, True], [True, False], @@ -10871,10 +11210,10 @@ def test_is_causal_flag( compute_unit, backend, frontend, + minimum_deployment_target, seq_lengths, include_heads, dynamic, - minimum_deployment_target=None, ): if frontend == TorchFrontend.EXIR: pytest.xfail( @@ -10920,11 +11259,12 @@ def test_is_causal_flag( assert len(mil_prog.find_ops(op_type="band_part")) == 0 @pytest.mark.parametrize( - "compute_unit, backend, frontend, seq_lengths, bool_mask, dynamic", + "compute_unit, backend, frontend, minimum_deployment_target, seq_lengths, bool_mask, dynamic", itertools.product( compute_units, backends, frontends, + [None, ct.target.iOS18], [(5, 5), (7, 5)], [False, True], [False, True], @@ -10935,10 +11275,10 @@ def test_attn_mask( compute_unit, backend, frontend, + minimum_deployment_target, seq_lengths, bool_mask, dynamic, - minimum_deployment_target=None, ): if frontend == TorchFrontend.TORCHSCRIPT and bool_mask: pytest.xfail( @@ -10985,11 +11325,12 @@ def test_attn_mask( ) @pytest.mark.parametrize( - "compute_unit, backend, frontend, mask_as_input, dynamic", + "compute_unit, backend, frontend, minimum_deployment_target, mask_as_input, dynamic", itertools.product( compute_units, backends, frontends, + [None, ct.target.iOS18], [True, False], [True, False], ), @@ -10999,9 +11340,9 @@ def test_toy_xformer_with_sdpa( compute_unit, backend, frontend, + minimum_deployment_target, mask_as_input, dynamic, - minimum_deployment_target=None, ): if frontend == TorchFrontend.EXIR and not mask_as_input: pytest.xfail( @@ -11139,7 +11480,10 @@ def test_dropout_early_error_out(self): with pytest.raises( ValueError, - match=r"scaled_dot_product_attention op: dropout is not supported yet", + match=( + r"A non-zero dropout probability is specified. Since Core ML " + r"does not support dropout yet, we cannot convert it" + ), ): model = ModuleWrapper( function=nn.functional.scaled_dot_product_attention, @@ -11204,7 +11548,7 @@ def forward(self, x): # As sampling is random, we make one element significantly larger than others to make # outputs consistent. - input_data = torch.tensor([0, 1e5, 0, 0, 1, 1, 1], dtype=torch.float) + input_data = torch.tensor([0, 5e4, 0, 0, 1, 1, 1], dtype=torch.float) self.run_compare_torch( input_data, TestModel(), @@ -11213,6 +11557,42 @@ def forward(self, x): input_as_shape=False, ) + @pytest.mark.parametrize( + "compute_unit, backend", + itertools.product(compute_units, backends), + ) + def test_multinomial_probs_instead_of_logits(self, compute_unit, backend): + """ + Verify the input to multinomial is probs instead of logits. + + When the number of drawing is large, the drawing results could tell us if the input is probs + or logits. In this test we use only 2 classes, so we can compare the number of `1` in results + to verify if the input is taken a logarithm or not. + """ + + class TestModel(nn.Module): + def forward(self, x): + return torch.multinomial(x, 1000, replacement=True) + + input_data = torch.tensor([0.01, 0.1], dtype=torch.float) + torch_model = TestModel() + torch_model.eval() + traced_model = torch.jit.trace(torch_model, input_data) + mlmodel = ct.convert( + traced_model, + inputs=[ct.TensorType(name="input", shape=input_data.shape, dtype=np.float16)], + outputs=[ct.TensorType(name="output", dtype=np.float16)], + convert_to="mlprogram", + compute_units=ct.ComputeUnit.CPU_ONLY, + minimum_deployment_target=ct.target.iOS16, + ) + + if ct.utils._is_macos(): + mlmodel_out = mlmodel.predict({"input": input_data.numpy()})["output"] + torch_out = torch_model(input_data).numpy() + # The counting of 1 in PyTorch and CoreML output should be similar. + assert np.abs(np.sum(mlmodel_out) - np.sum(torch_out)) / mlmodel_out.size < 0.05 + @pytest.mark.parametrize( "compute_unit, backend", itertools.product(compute_units, backends), diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_quantization_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_quantization_ops.py index d41de1e77..dcd297ba9 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_quantization_ops.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_quantization_ops.py @@ -12,6 +12,7 @@ import torchvision import coremltools as ct +import coremltools.optimize as cto from coremltools import TensorType from coremltools._deps import ( _HAS_TORCH, @@ -19,7 +20,16 @@ MSG_TORCH_NOT_FOUND, MSG_TORCH_VISION_NOT_FOUND, ) +from coremltools.converters.mil import testing_reqs +from coremltools.converters.mil.mil import types from coremltools.converters.mil.testing_utils import get_op_types_in_program +from coremltools.optimize.coreml import _quantization_passes +from coremltools.test.ml_program.test_compression import get_test_model_and_data +from coremltools.test.optimize.coreml.test_post_training_quantization import ( + create_quantize_friendly_weight, + create_sparse_weight, + create_unique_weight, +) from .testing_utils import TorchBaseTest @@ -28,6 +38,7 @@ torch.manual_seed(30) np.random.seed(30) torch.backends.quantized.engine = "qnnpack" +compute_units = testing_reqs.compute_units def _force_quantize_model( @@ -384,10 +395,17 @@ def forward(self, x): self.run_compare_torch(input_shape, model) @pytest.mark.parametrize( - "quant_dtype, channel_axis", - itertools.product([torch.quint8, torch.qint8], [0, 1, None]), + "compute_unit, quant_dtype, channel_axis, minimum_deployment_target", + itertools.product( + compute_units, + [torch.quint8, torch.qint8], + [0, 1, None], + [ct.target.iOS16, ct.target.iOS17, ct.target.iOS18], + ), ) - def test_quantized_params(self, quant_dtype, channel_axis): + def test_quantized_params( + self, compute_unit, quant_dtype, channel_axis, minimum_deployment_target + ): class Model(torch.nn.Module): def __init__(self): super().__init__() @@ -400,9 +418,17 @@ def forward(self, x): model = Model() model = _force_quantize_model(model, q_dtype=quant_dtype, channel_axis=channel_axis) input_shape = [(3, 5)] - res = self.run_compare_torch(input_shape, model) + res = self.run_compare_torch( + input_shape, + model, + minimum_deployment_target=minimum_deployment_target, + compute_unit=compute_unit, + ) prog = res[1]._mil_program - assert get_op_types_in_program(prog) == ["constexpr_affine_dequantize", "linear"] + if minimum_deployment_target < ct.target.iOS18: + assert get_op_types_in_program(prog) == ["constexpr_affine_dequantize", "linear"] + else: + assert get_op_types_in_program(prog) == ["constexpr_blockwise_shift_scale", "matmul"] @pytest.mark.skipif(not _HAS_TORCH_VISION, reason=MSG_TORCH_VISION_NOT_FOUND) @@ -418,3 +444,485 @@ class TestTorchvisionQuantizedModels(TorchQuantizationBaseTest): def test_quantized_mobilenetv2(self): model = torchvision.models.quantization.mobilenet_v2(pretrained=True, quantize=True) self.run_compare_torch([(1, 3, 224, 224)], model, atol=1.0) + + +class TestPytorchCarryCompressionInfo(TorchQuantizationBaseTest): + """Test compressed PyTorch models which use register_buffer to carry compression info.""" + + @pytest.mark.parametrize( + "compute_unit, n_bits, signed, minimum_deployment_target", + itertools.product( + compute_units, + [4, 8], + [True, False], + [ct.target.iOS16, ct.target.iOS18], + ), + ) + def test_quantization(self, compute_unit, n_bits, signed, minimum_deployment_target): + if n_bits == 4 and minimum_deployment_target < ct.target.iOS18: + pytest.skip("Sub-byte quantization is only supported since iOS18.") + + model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data( + quantize_config=cto.coreml.OpLinearQuantizerConfig( + mode="linear_symmetric", + dtype=types.get_nbits_int_builtin_type(n_bits, signed), + granularity="per_tensor", + ) + ) + + scale = np.array([2.0], dtype=np.float32).reshape(1, 1, 1, 1) + zero_point = np.array( + [0 if signed else 2 ** (n_bits - 1)], dtype=np.int8 if signed else np.uint8 + ).reshape(1, 1, 1, 1) + + model.register_buffer("_COREML_/metadata_version", torch.tensor(2)) + model.register_buffer("_COREML_/weight/compression_type", torch.tensor([3])) + model.register_buffer("_COREML_/weight/quantization_n_bits", torch.tensor(n_bits)) + model.register_buffer("_COREML_/weight/quantization_scale", torch.from_numpy(scale)) + model.register_buffer("_COREML_/weight/zero_point", torch.from_numpy(zero_point)) + + traced_model = torch.jit.trace(model, torch_input_values) + input_shape = [input.shape.to_list() for input in inputs] + res = self.run_compare_torch( + input_shape, + traced_model, + minimum_deployment_target=minimum_deployment_target, + compute_unit=compute_unit, + converter=ct.convert, + rtol=1e-04, + atol=1e-03, + ) + main_func = res[1]._mil_program.functions["main"] + + target_dtype_str = ("int" if signed else "uint") + str(n_bits) + if minimum_deployment_target >= ct.target.iOS18: + quantize_ops = main_func.find_ops(op_type="constexpr_blockwise_shift_scale") + assert len(quantize_ops) > 0 + for quantize_op in quantize_ops: + assert types.builtin_to_string(quantize_op.data.dtype) == target_dtype_str + if not signed: + assert types.builtin_to_string(quantize_op.offset.dtype) == target_dtype_str + else: + quantize_ops = main_func.find_ops(op_type="constexpr_affine_dequantize") + assert len(quantize_ops) > 0 + for quantize_op in quantize_ops: + assert types.builtin_to_string(quantize_op.quantized_data.dtype) == target_dtype_str + assert types.builtin_to_string(quantize_op.zero_point.dtype) == target_dtype_str + + @pytest.mark.parametrize( + "compute_unit, n_bits, minimum_deployment_target", + itertools.product(compute_units, [4, 8], [ct.target.iOS16, ct.target.iOS18]), + ) + def test_multiple_parameters_in_same_layer( + self, compute_unit, n_bits, minimum_deployment_target + ): + """Test one layer has multiple parameters (such as weight and bias in a linear layer)""" + if n_bits == 4 and minimum_deployment_target < ct.target.iOS18: + pytest.skip("Sub-byte quantization is only supported since iOS18.") + + class Model(torch.nn.Module): + def __init__(self): + super(Model, self).__init__() + self.linear_1 = torch.nn.Linear(16, 32) + self.linear_2 = torch.nn.Linear(32, 64) + + def forward(self, x): + return self.linear_2(self.linear_1(x)) + + model = Model().eval() + with torch.no_grad(): + fake_weight_scale = 2 if n_bits == 4 else 40 + model.linear_2.weight = torch.nn.Parameter( + torch.from_numpy( + np.ones_like(model.linear_2.weight.detach().numpy()) * fake_weight_scale + ).float() + ) + model.linear_2.bias = torch.nn.Parameter( + torch.from_numpy( + np.ones_like(model.linear_2.bias.detach().numpy()) * fake_weight_scale + ).float() + ) + + # Register buffers for both weight and bias for linear_2 layer. + weight_scale = np.array([2.0], dtype=np.float32).reshape(1, 1) + bias_scale = np.array([2.0], dtype=np.float32) + model.linear_2.register_buffer("_COREML_/weight/compression_type", torch.tensor([3])) + model.linear_2.register_buffer("_COREML_/weight/quantization_n_bits", torch.tensor(n_bits)) + model.linear_2.register_buffer( + "_COREML_/weight/quantization_scale", torch.from_numpy(weight_scale) + ) + model.linear_2.register_buffer("_COREML_/bias/compression_type", torch.tensor([3])) + model.linear_2.register_buffer("_COREML_/bias/quantization_n_bits", torch.tensor(n_bits)) + model.linear_2.register_buffer( + "_COREML_/bias/quantization_scale", torch.from_numpy(bias_scale) + ) + model.register_buffer("_COREML_/metadata_version", torch.tensor(2)) + + torch_input_values = torch.rand((8, 16)) + traced_model = torch.jit.trace(model, torch_input_values) + res = self.run_compare_torch( + [(8, 16)], + traced_model, + minimum_deployment_target=minimum_deployment_target, + compute_unit=compute_unit, + converter=ct.convert, + ) + main_func = res[1]._mil_program.functions["main"] + + quantize_op_type = ( + "constexpr_blockwise_shift_scale" + if minimum_deployment_target >= ct.target.iOS18 + else "constexpr_affine_dequantize" + ) + # Only the linear_2 layer got quantized based on registered buffers. + linear_ops = main_func.find_ops(op_type="linear") + assert linear_ops[0].weight.op.op_type == "const" + assert linear_ops[0].bias.op.op_type == "const" + assert linear_ops[1].weight.op.op_type == quantize_op_type + assert linear_ops[1].bias.op.op_type == quantize_op_type + + quantize_ops = main_func.find_ops(op_type=quantize_op_type) + assert len(quantize_ops) == 2 + for quantize_op in quantize_ops: + if minimum_deployment_target >= ct.target.iOS18: + assert types.builtin_to_string(quantize_op.data.dtype) == f"uint{n_bits}" + else: + assert types.builtin_to_string(quantize_op.quantized_data.dtype) == f"uint{n_bits}" + + def test_invalid_compression_info(self): + model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data() + + # Invalid key combination (didn't specify compression schema) + model.register_buffer("_COREML_/weight/quantization_n_bits", torch.tensor(4)) + with pytest.raises( + ValueError, + match="There are coreml compression related buffers registered in the torch .* but " + "the 'compression_type' is not set", + ): + self.run_compare_torch( + [input.shape.to_list() for input in inputs], + torch.jit.trace(model, torch_input_values), + minimum_deployment_target=ct.target.iOS18, + compute_unit=ct.ComputeUnit.CPU_ONLY, + converter=ct.convert, + ) + + # Invalid key names. + model.register_buffer("_COREML_/weight/invalid_key", torch.tensor(4)) + with pytest.raises(AttributeError, match="has no attribute 'invalid_key'"): + self.run_compare_torch( + [input.shape.to_list() for input in inputs], + torch.jit.trace(model, torch_input_values), + minimum_deployment_target=ct.target.iOS18, + compute_unit=ct.ComputeUnit.CPU_ONLY, + converter=ct.convert, + ) + + # The lut must be specified for palettization. + model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data() + model.register_buffer("_COREML_/weight/compression_type", torch.tensor([2])) + with pytest.raises( + ValueError, match="Missing lut in compression info. Please register a buffer for lut." + ): + self.run_compare_torch( + [input.shape.to_list() for input in inputs], + torch.jit.trace(model, torch_input_values), + minimum_deployment_target=ct.target.iOS18, + compute_unit=ct.ComputeUnit.CPU_ONLY, + converter=ct.convert, + ) + + @pytest.mark.parametrize( + "compute_unit, n_bits, group_size, channel_axis, minimum_deployment_target", + itertools.product( + compute_units, + [4, 8], + [0, 1, 2], + [0, 1], + [ct.target.iOS16, ct.target.iOS18], + ), + ) + def test_palettization( + self, compute_unit, n_bits, group_size, channel_axis, minimum_deployment_target + ): + model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data( + multi_layer=True + ) + + # per-channel scales for the [32, 64, 2, 2] and [64, 32, 2, 2] weight. + scale_1 = np.array([2.0] * 32, dtype=np.float32).reshape(32, 1, 1, 1) + scale_2 = np.array([3.0] * 64, dtype=np.float32).reshape(64, 1, 1, 1) + + unique_weight_1 = create_unique_weight(model.conv_1.weight, nbits=n_bits) + unique_weight_2 = create_unique_weight(model.conv_2.weight, nbits=n_bits) + + # Use grouped-channel-wise lut for conv1 for iOS18+. + block_sizes = [0] * len(unique_weight_1.shape) + if minimum_deployment_target >= ct.target.iOS18: + block_sizes[channel_axis] = group_size + lut_1_params = _quantization_passes.palettize_weights.blockwise_compress( + unique_weight_1, + "UNIQUE", + nbits=n_bits, + block_sizes=block_sizes, + ) + + # Use per-tensor lut for conv2. + lut_2_params = _quantization_passes.palettize_weights.blockwise_compress( + unique_weight_2, "UNIQUE", nbits=n_bits, block_sizes=[0] * len(unique_weight_2.shape) + ) + + if minimum_deployment_target >= ct.target.iOS18: + # Only do per-channel-scale for iOS18+. + unique_weight_1 *= scale_1 + unique_weight_2 *= scale_2 + + with torch.no_grad(): + model.conv_1.weight = torch.nn.Parameter(torch.Tensor(unique_weight_1)) + model.conv_2.weight = torch.nn.Parameter(torch.Tensor(unique_weight_2)) + + model.register_buffer("_COREML_/metadata_version", torch.tensor(1)) + if minimum_deployment_target >= ct.target.iOS18: + model.conv_1.register_buffer("_COREML_/weight/compression_type", torch.tensor([2])) + model.conv_1.register_buffer("_COREML_/weight/lut", torch.tensor(lut_1_params.lut)) + model.conv_1.register_buffer( + "_COREML_/weight/palettization_scale", torch.from_numpy(scale_1) + ) + model.conv_2.register_buffer("_COREML_/weight/compression_type", torch.tensor([2])) + model.conv_2.register_buffer("_COREML_/weight/lut", torch.tensor(lut_2_params.lut)) + if minimum_deployment_target >= ct.target.iOS18: + model.conv_2.register_buffer( + "_COREML_/weight/palettization_scale", torch.from_numpy(scale_2) + ) + + traced_model = torch.jit.trace(model, torch_input_values) + input_shape = [input.shape.to_list() for input in inputs] + res = self.run_compare_torch( + input_shape, + traced_model, + minimum_deployment_target=minimum_deployment_target, + compute_unit=compute_unit, + converter=ct.convert, + ) + main_func = res[1]._mil_program.functions["main"] + + if minimum_deployment_target >= ct.target.iOS18: + expected_dtype = f"uint{n_bits}" + expected_quantize_ops_num = 0 # The scale is moved to post-conv. + expected_palettize_ops_num = 2 + # The lut op is directly fed into conv because the quant scale is no longer there. + palettize_op_child_op_type = "conv" + else: + expected_dtype = "uint8" + expected_quantize_ops_num = 0 + expected_palettize_ops_num = 1 + # The iOS16 doesn't have per-channel-scale, so lut output is directly fed into conv. + palettize_op_child_op_type = "conv" + + quantize_ops = main_func.find_ops(op_type="constexpr_blockwise_shift_scale") + assert len(quantize_ops) == expected_quantize_ops_num + palettize_ops = main_func.find_ops(op_type="constexpr_lut_to_dense") + assert len(palettize_ops) == expected_palettize_ops_num + for palettize_op in palettize_ops: + assert types.builtin_to_string(palettize_op.indices.dtype) == expected_dtype + assert palettize_op.outputs[0].child_ops[0].op_type == palettize_op_child_op_type + + @pytest.mark.parametrize( + "compute_unit, sparse_ratio, minimum_deployment_target", + itertools.product( + compute_units, + [0.01, 0.5, 0.99], + [ct.target.iOS16, ct.target.iOS18], + ), + ) + def test_pruning(self, compute_unit, sparse_ratio, minimum_deployment_target): + model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data( + multi_layer=True + ) + with torch.no_grad(): + model.conv_1.weight = torch.nn.Parameter( + torch.Tensor( + create_sparse_weight(model.conv_1.weight, target_sparsity=sparse_ratio) + ) + ) + model.conv_2.weight = torch.nn.Parameter( + torch.Tensor( + create_sparse_weight(model.conv_2.weight, target_sparsity=sparse_ratio) + ) + ) + + model.register_buffer("_COREML_/metadata_version", torch.tensor(1)) + model.conv_1.register_buffer("_COREML_/weight/compression_type", torch.tensor([1])) + model.conv_2.register_buffer("_COREML_/weight/compression_type", torch.tensor([1])) + + traced_model = torch.jit.trace(model, torch_input_values) + input_shape = [input.shape.to_list() for input in inputs] + res = self.run_compare_torch( + input_shape, + traced_model, + minimum_deployment_target=minimum_deployment_target, + compute_unit=compute_unit, + converter=ct.convert, + ) + main_func = res[1]._mil_program.functions["main"] + sparse_ops = main_func.find_ops(op_type="constexpr_sparse_to_dense") + assert len(sparse_ops) == 2 + + for sparse_op in sparse_ops: + assert sparse_op.outputs[0].child_ops[0].op_type == "conv" + assert types.builtin_to_string(sparse_op.nonzero_data.dtype) == "fp32" + if minimum_deployment_target >= ct.target.iOS18: + assert types.builtin_to_string(sparse_op.mask.dtype) == "uint1" + else: + assert types.builtin_to_string(sparse_op.mask.dtype) == "uint8" + assert types.builtin_to_string(sparse_op.shape.dtype) == "uint32" + + @pytest.mark.parametrize( + "compute_unit, n_bits, signed", + itertools.product( + compute_units, + [4, 8], + [True, False], + ), + ) + def test_joint_pruning_quantization(self, compute_unit, n_bits, signed): + model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data( + multi_layer=True, + ) + + # Make the weight sparse and also quantization-friendly. + weight_1, scale_1, zero_point_1 = create_quantize_friendly_weight( + model.conv_1.weight.detach().numpy(), nbits=n_bits, signed=signed + ) + weight_1 *= np.random.randint(low=0, high=2, size=model.conv_1.weight.shape) + weight_2, scale_2, zero_point_2 = create_quantize_friendly_weight( + model.conv_2.weight.detach().numpy(), nbits=n_bits, signed=signed + ) + weight_2 *= np.random.randint(low=0, high=2, size=model.conv_2.weight.shape) + with torch.no_grad(): + model.conv_1.weight = torch.nn.Parameter(torch.Tensor(weight_1)) + model.conv_2.weight = torch.nn.Parameter(torch.Tensor(weight_2)) + + model.register_buffer("_COREML_/metadata_version", torch.tensor(2)) + model.conv_1.register_buffer("_COREML_/weight/compression_type", torch.tensor([1, 3])) + model.conv_1.register_buffer("_COREML_/weight/quantization_n_bits", torch.tensor(n_bits)) + model.conv_1.register_buffer( + "_COREML_/weight/quantization_scale", torch.from_numpy(scale_1) + ) + model.conv_1.register_buffer("_COREML_/weight/zero_point", torch.from_numpy(zero_point_1)) + model.conv_2.register_buffer("_COREML_/weight/compression_type", torch.tensor([1, 3])) + model.conv_2.register_buffer("_COREML_/weight/quantization_n_bits", torch.tensor(n_bits)) + model.conv_2.register_buffer( + "_COREML_/weight/quantization_scale", torch.from_numpy(scale_2) + ) + model.conv_2.register_buffer("_COREML_/weight/zero_point", torch.from_numpy(zero_point_2)) + + traced_model = torch.jit.trace(model, torch_input_values) + input_shape = [input.shape.to_list() for input in inputs] + res = self.run_compare_torch( + input_shape, + traced_model, + minimum_deployment_target=ct.target.iOS18, + compute_unit=compute_unit, + converter=ct.convert, + atol=1e-2, + ) + main_func = res[1]._mil_program.functions["main"] + + sparse_quantize_ops = main_func.find_ops(op_type="constexpr_sparse_blockwise_shift_scale") + assert len(sparse_quantize_ops) == 2 + for sparse_quantize_op in sparse_quantize_ops: + expected_dtype = f"int{n_bits}" if signed else f"uint{n_bits}" + assert types.builtin_to_string(sparse_quantize_op.nonzero_data.dtype) == expected_dtype + assert types.builtin_to_string(sparse_quantize_op.data_mask.dtype) == "uint1" + assert types.builtin_to_string(sparse_quantize_op.scale.dtype) == "fp32" + assert sparse_quantize_op.outputs[1].child_ops[0].op_type == "constexpr_sparse_to_dense" + + sparse_ops = main_func.find_ops(op_type="constexpr_sparse_to_dense") + assert len(sparse_ops) == 2 + for sparse_op in sparse_ops: + assert types.builtin_to_string(sparse_op.mask.dtype) == "uint1" + assert types.builtin_to_string(sparse_op.nonzero_data.dtype) == "fp32" + assert sparse_op.outputs[0].child_ops[0].op_type == "conv" + + @pytest.mark.parametrize( + "compute_unit, n_bits, group_size", + itertools.product( + compute_units, + [4, 8], + [0, 1, 2], + ), + ) + def test_joint_pruning_palettization(self, compute_unit, n_bits, group_size): + model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data( + multi_layer=True + ) + + # Make the weight sparse and also can be represented by lut. + weight_1 = create_unique_weight(model.conv_1.weight, nbits=n_bits) * np.random.randint( + low=0, high=2, size=model.conv_1.weight.shape + ) + weight_2 = create_unique_weight(model.conv_2.weight, nbits=n_bits) * np.random.randint( + low=0, high=2, size=model.conv_2.weight.shape + ) + + with torch.no_grad(): + model.conv_1.weight = torch.nn.Parameter(torch.Tensor(weight_1)) + model.conv_2.weight = torch.nn.Parameter(torch.Tensor(weight_2)) + + lut_1_params = _quantization_passes.palettize_weights.blockwise_compress( + weight_1, + "UNIQUE", + nbits=n_bits, + block_sizes=[group_size] + [0] * (len(weight_1.shape) - 1), + ) + lut_2_params = _quantization_passes.palettize_weights.blockwise_compress( + weight_2, + "UNIQUE", + nbits=n_bits, + block_sizes=[group_size] + [0] * (len(weight_2.shape) - 1), + ) + + model.register_buffer("_COREML_/metadata_version", torch.tensor(1)) + model.conv_1.register_buffer("_COREML_/weight/compression_type", torch.tensor([1, 2])) + model.conv_1.register_buffer("_COREML_/weight/lut", torch.tensor(lut_1_params.lut)) + model.conv_2.register_buffer("_COREML_/weight/compression_type", torch.tensor([1, 2])) + model.conv_2.register_buffer("_COREML_/weight/lut", torch.tensor(lut_2_params.lut)) + + traced_model = torch.jit.trace(model, torch_input_values) + input_shape = [input.shape.to_list() for input in inputs] + res = self.run_compare_torch( + input_shape, + traced_model, + minimum_deployment_target=ct.target.iOS18, + compute_unit=compute_unit, + converter=ct.convert, + ) + main_func = res[1]._mil_program.functions["main"] + + sparse_palettize_ops = main_func.find_ops(op_type="constexpr_lut_to_sparse") + assert len(sparse_palettize_ops) == 2 + for sparse_palettize_op in sparse_palettize_ops: + assert ( + types.builtin_to_string(sparse_palettize_op.indices_nonzero_data.dtype) + == f"uint{n_bits}" + ) + assert types.builtin_to_string(sparse_palettize_op.indices_mask.dtype) == "uint1" + assert types.builtin_to_string(sparse_palettize_op.lut.dtype) == "fp32" + assert ( + sparse_palettize_op.outputs[1].child_ops[0].op_type == "constexpr_sparse_to_dense" + ) + # As both palettization and pruning is on the original weight, the shape of lut should + # match the original weight's shape except on the output channel. + weight_shape = sparse_palettize_op.outputs[1].child_ops[0].outputs[0].shape + expected_lut_shape = [1] * len(weight_shape) + [2**n_bits] + [1] + if group_size > 0: + expected_lut_shape[0] = weight_shape[0] // group_size + assert sparse_palettize_op.lut.shape == tuple(expected_lut_shape) + + sparse_ops = main_func.find_ops(op_type="constexpr_sparse_to_dense") + assert len(sparse_ops) == 2 + for sparse_op in sparse_ops: + assert types.builtin_to_string(sparse_op.mask.dtype) == "uint1" + assert types.builtin_to_string(sparse_op.nonzero_data.dtype) == "fp32" + assert sparse_op.outputs[0].child_ops[0].op_type == "conv" diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_stateful_model.py b/coremltools/converters/mil/frontend/torch/test/test_torch_stateful_model.py new file mode 100644 index 000000000..721643ea8 --- /dev/null +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_stateful_model.py @@ -0,0 +1,1181 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import itertools + +import numpy as np +import pytest + +import coremltools as ct +from coremltools.converters.mil.mil import types +from coremltools.converters.mil.mil.types.symbolic import any_symbolic +from coremltools.converters.mil.testing_reqs import compute_units +from coremltools.converters.mil.testing_utils import ( + assert_output_dtype, + assert_prog_output_type, + assert_spec_input_image_type, + assert_spec_output_image_type, + get_op_types_in_program, + verify_prediction, +) +from coremltools.proto import FeatureTypes_pb2 as ft + +torch = pytest.importorskip("torch") + + +@pytest.fixture +def float16_buffer_model(): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("state", torch.tensor(np.array([7, 5, 6], dtype=np.float16))) + + def forward(self, x): + x = x.type(torch.float16) + self.state.mul_(x) + self.state.add_(torch.tensor(np.array([1, 2, 3], dtype=np.float16))) + return self.state * 9 + + example_input = torch.randint(0, 100, (3,), dtype=torch.int32) + return torch.jit.trace(Model().eval(), example_input) + + +@pytest.fixture +def float32_buffer_model(): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("state", torch.tensor(np.array([7, 5, 6], dtype=np.float32))) + + def forward(self, x): + self.state.add_(x) + return self.state * 5 + + example_input = torch.randint(0, 100, (3,), dtype=torch.int32) + return torch.jit.trace(Model().eval(), example_input) + + +@pytest.fixture +def float32_non_persistent_buffer_model(): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer( + "state", torch.tensor(np.array([7, 5, 6], dtype=np.float32)), persistent=False + ) + + def forward(self, x): + self.state.add_(x) + return self.state * 5 + + example_input = torch.randint(0, 100, (3,), dtype=torch.int32) + return torch.jit.trace(Model().eval(), example_input) + + +@pytest.fixture +def float32_buffer_not_returned_model(): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("state_1", torch.tensor(np.array([7, 5, 6], dtype=np.float32))) + self.register_buffer("state_2", torch.tensor(np.array([7, 5, 6], dtype=np.float32))) + + def forward(self, x): + self.state_1.add_(x) + self.state_2.add_(x) + return x + + example_input = torch.randint(0, 100, (3,), dtype=torch.int32) + return torch.jit.trace(Model().eval(), example_input) + + +@pytest.fixture +def float32_buffer_not_returned_model_2(): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("state_1", torch.tensor(np.array([7, 5, 6], dtype=np.float32))) + self.register_buffer("state_2", torch.tensor(np.array([7, 5, 6], dtype=np.float32))) + + def forward(self, x): + self.state_1.add_(x) + self.state_2.add_(x) + self.state_1.add_(x) + return x + + example_input = torch.randint(0, 100, (3,), dtype=torch.int32) + return torch.jit.trace(Model().eval(), example_input) + + +@pytest.fixture +def float32_buffer_model_with_two_inputs(): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("state", torch.tensor(np.array([7, 5, 6], dtype=np.float32))) + + def forward(self, x, y): + self.state.add_(x) + self.state.add_(y) + return self.state * 5 + + example_input = [ + torch.randint(0, 100, (3,), dtype=torch.int32), + torch.randint(0, 100, (3,), dtype=torch.int32), + ] + return torch.jit.trace(Model().eval(), example_input) + + +@pytest.fixture +def float32_buffer_model_two_inputs_two_states(): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("state_1", torch.tensor(np.array([1, 2, 3], dtype=np.float32))) + self.register_buffer("state_2", torch.tensor(np.array([4, 5, 6], dtype=np.float32))) + + def forward(self, x, y): + self.state_1.add_(x) + self.state_2.add_(y) + return self.state_1 * self.state_2 + + example_input = [ + torch.randint(0, 100, (3,), dtype=torch.int32), + torch.randint(0, 100, (3,), dtype=torch.int32), + ] + return torch.jit.trace(Model().eval(), example_input) + + +def float32_buffer_sequantial_model(): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("state", torch.tensor(np.array([7, 5, 6], dtype=np.float32))) + + def forward(self, x): + res = self.state + 8 + self.state[0] = 9.0 + x = self.state * x + self.state.mul_(self.state) + self.state.sub_(x) + return torch.relu(self.state) + + example_input = torch.randint(0, 100, (3,), dtype=torch.int32) + return torch.jit.trace(Model().eval(), example_input) + + +@pytest.fixture +def float32_two_buffers_model(): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("state_1", torch.tensor(np.array([1, 2, 3], dtype=np.float32))) + self.register_buffer("state_2", torch.tensor(np.array([4, 5, 6], dtype=np.float32))) + + def forward(self, x): + v1 = self.state_2 - x + self.state_2.mul_(self.state_1) + self.state_1.mul_(v1) + self.state_1.add_(self.state_2) + return self.state_1 + x + + example_input = torch.randint(0, 100, (3,), dtype=torch.int32) + return torch.jit.trace(Model().eval(), example_input) + + +@pytest.fixture +def rank4_input_model_with_buffer(): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer( + "state_1", torch.tensor(np.zeros((1, 3, 10, 20), dtype=np.float32)) + ) + + def forward(self, x): + x = x + 5.5 + self.state_1.add_(x) + self.state_1[0, 0, 0, 0:1] = torch.tensor([1.0]) + return x + + example_input = torch.randint(0, 100, (1, 3, 10, 20), dtype=torch.float32) + return torch.jit.trace(Model().eval(), example_input) + + +@pytest.fixture +def rank4_grayscale_input_model_with_buffer(): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer( + "state_1", torch.tensor(np.zeros((1, 1, 10, 20), dtype=np.float32)) + ) + + def forward(self, x): + x = x + 5 + self.state_1.add_(x) + self.state_1[0, 0, 0, 0:1] = torch.tensor([1.0]) + return x + + example_input = torch.randint(0, 100, (1, 1, 10, 20), dtype=torch.float32) + return torch.jit.trace(Model().eval(), example_input) + + +@pytest.mark.skipif( + ct.utils._macos_version() < (15, 0), reason="Tests are for deployment target iOS18/macos15" +) +class TestStateConversionAPI: + @pytest.mark.parametrize( + "compute_unit", + compute_units, + ) + def test_state_model_api_example(self, compute_unit): + """ + Test the public API example. + """ + + class UpdateBufferModel(torch.nn.Module): + def __init__(self): + super(UpdateBufferModel, self).__init__() + self.register_buffer("state_1", torch.tensor(np.array([0, 0, 0], dtype=np.float32))) + + def forward(self, x): + # In place update of the model state + self.state_1.add_(x) + return self.state_1 + + model = UpdateBufferModel() + traced_model = torch.jit.trace(model, torch.tensor([1, 2, 3], dtype=torch.float32)) + + inputs = [ + ct.TensorType(shape=(3,)), + ] + states = [ + ct.StateType( + wrapped_type=ct.TensorType( + shape=(3,), + ), + name="state_1", + ), + ] + mlmodel = ct.convert( + traced_model, + inputs=inputs, + states=states, + minimum_deployment_target=ct.target.iOS18, + convert_to="mlprogram", + compute_units=compute_unit, + ) + + verify_prediction(mlmodel) + + @pytest.mark.parametrize( + "compute_unit", + compute_units, + ) + def test_single_state_single_input( + self, float32_buffer_model, float32_non_persistent_buffer_model, compute_unit + ): + """ + Tests for different combination of input dtypes. + """ + + def test_valid_prog(prog, expected_ops=None): + block = prog.functions["main"] + assert types.is_tensor(block.inputs["x"].sym_type) + assert types.is_state(block.inputs["state_workaround"].sym_type) + assert len(block.outputs) == 1 + assert types.is_tensor(block.outputs[0].sym_type) + if expected_ops is None: + expected_ops = [ + "read_state", + "add", + "coreml_update_state", + "mul", + ] + assert get_op_types_in_program(prog) == expected_ops + + """ + fp32 state / input (default with compute_precision=fp32), + with both persistent and non-persistent buffer. + fp32 state is not supported through runtime. + + (%x: Tensor(fp32), %state: State(fp32)) -> { + %read_state(fp32) = read_state(%state) + %add(fp32) = add(%read_state, %x) + %update(fp32) = coreml_update_state(%state, %add) + %mul(fp32) = mul(%update, 5) + } -> (%mul) + """ + for model in [float32_buffer_model, float32_non_persistent_buffer_model]: + prog = ct.convert( + model, + inputs=[ + ct.TensorType(shape=(3,)), + ], + states=[ + ct.StateType( + wrapped_type=ct.TensorType( + shape=(3,), + ), + name="state", + ), + ], + minimum_deployment_target=ct.target.iOS18, + compute_precision=ct.precision.FLOAT32, + convert_to="milinternal", + ) + test_valid_prog(prog) + block = prog.functions["main"] + assert block.inputs["x"].sym_type.get_primitive() == types.fp32 + assert ( + block.inputs["state_workaround"].sym_type.wrapped_type().get_primitive() + == types.fp32 + ) + assert block.outputs[0].dtype == types.fp32 + + """ + fp16 state / input (user specify) + + (%x: Tensor(fp16), %state: State(fp16)) -> { + %read_state(fp16) = read_state(%state) + %add(fp16) = add(%read_state, %x) + %update(fp16) = coreml_update_state(%state, %add) + %mul(fp16) = mul(%update, 5) + } -> (%mul) + """ + mlmodel = ct.convert( + float32_buffer_model, + inputs=[ + ct.TensorType(shape=(3,), dtype=np.float16), + ], + states=[ + ct.StateType( + wrapped_type=ct.TensorType( + shape=(3,), + dtype=np.float16, + ), + name="state", + ), + ], + minimum_deployment_target=ct.target.iOS18, + convert_to="mlprogram", + compute_units=compute_unit, + ) + # check the pymil program + prog = mlmodel._mil_program + test_valid_prog(prog) + block = prog.functions["main"] + assert block.inputs["x"].sym_type.get_primitive() == types.fp16 + assert ( + block.inputs["state_workaround"].sym_type.wrapped_type().get_primitive() == types.fp16 + ) + assert block.outputs[0].dtype == types.fp16 + + # check the mil proto + mil = mlmodel.get_spec().mlProgram + for function in mil.functions.values(): + for block in function.block_specializations.values(): + ops = list(block.operations) + expected_ops = [ + "read_state", + "add", + "write_state", + "read_state", + "const", + "mul", + ] + assert [val.type for val in ops] == expected_ops + assert len(ops[2].outputs) == 0 + + verify_prediction(mlmodel) + + """ + fp16 state / input (default with compute_precision=fp16) + + (%x: Tensor(fp16), %state: State(fp16)) -> { + %read_state(fp16) = read_state(%state) + %add(fp16) = add(%read_state, %x) + %update(fp16) = coreml_update_state(%state, %add) + %mul(fp16) = mul(%update, 5) + } -> (%mul) + """ + mlmodel = ct.convert( + float32_buffer_model, + inputs=[ + ct.TensorType(shape=(3,)), + ], + states=[ + ct.StateType( + wrapped_type=ct.TensorType( + shape=(3,), + ), + name="state", + ), + ], + minimum_deployment_target=ct.target.iOS18, + compute_units=compute_unit, + ) + prog = mlmodel._mil_program + test_valid_prog(prog) + block = prog.functions["main"] + assert block.inputs["x"].sym_type.get_primitive() == types.fp16 + assert ( + block.inputs["state_workaround"].sym_type.wrapped_type().get_primitive() == types.fp16 + ) + assert block.outputs[0].dtype == types.fp16 + verify_prediction(mlmodel) + + + """ + fp16 state and fp32 input + + (%x: Tensor(fp32), %state: State(fp16)) -> { + %read_state(fp16) = read_state(%state) + %x_cast(fp16) = cast(%x) + %add(fp16) = add(%read_state, %x_cast) + %update(fp16) = coreml_update_state(%state, %add) + %mul(fp16) = mul(%update, 5) + } -> (%mul) + """ + mlmodel = ct.convert( + float32_buffer_model, + inputs=[ + ct.TensorType(shape=(3,), dtype=np.float32), + ], + states=[ + ct.StateType( + wrapped_type=ct.TensorType(shape=(3,), dtype=np.float16), name="state" + ), + ], + minimum_deployment_target=ct.target.iOS18, + compute_units=compute_unit, + ) + prog = mlmodel._mil_program + expected_ops = [ + "read_state", + "cast", + "add", + "coreml_update_state", + "mul", + ] + test_valid_prog(prog, expected_ops) + block = prog.functions["main"] + assert block.inputs["x"].sym_type.get_primitive() == types.fp32 + assert ( + block.inputs["state_workaround"].sym_type.wrapped_type().get_primitive() == types.fp16 + ) + assert prog.find_ops("cast")[0].x.op is None + assert block.outputs[0].dtype == types.fp16 + # This model is failing due to a bug in Espresso: + # rdar://128478924 ([Bug][Stateful model][E5] Stateful model triggers a State Operation dependencies error in E5) + # After the above radar is fixed, we should be able to run the mlmodel + # verify_prediction(mlmodel) + + """ + fp32 state and fp16 input. This is a rare corner case that shouldn't + happend often. + fp32 state is not supported through runtime. + + (%x: Tensor(fp16), %state: State(fp32)) -> { + %read_state(fp32) = read_state(%state) + %read_state_cast(fp16) = cast(read_state) + %add(fp16) = add(%read_state_casr, %x) + %add_cast(fp32) = cast(%add) + %update(fp32) = coreml_update_state(%state, %add_cast) + %update_cast(fp16) = cast(%update) + %mul(fp16) = mul(%update_cast, 5) + } -> (%mul) + """ + prog = ct.convert( + float32_buffer_model, + inputs=[ + ct.TensorType(shape=(3,), dtype=np.float16), + ], + states=[ + ct.StateType( + wrapped_type=ct.TensorType(shape=(3,), dtype=np.float32), name="state" + ), + ], + minimum_deployment_target=ct.target.iOS18, + convert_to="milinternal", + ) + expected_ops = [ + "read_state", + "cast", + "add", + "cast", + "coreml_update_state", + "cast", + "mul", + ] + test_valid_prog(prog, expected_ops) + block = prog.functions["main"] + assert block.inputs["x"].sym_type.get_primitive() == types.fp16 + assert ( + block.inputs["state_workaround"].sym_type.wrapped_type().get_primitive() == types.fp32 + ) + assert prog.find_ops("cast")[0].x.op.op_type == "read_state" + assert prog.find_ops("cast")[1].x.op.op_type == "add" + assert prog.find_ops("cast")[2].x.op.op_type == "coreml_update_state" + assert block.outputs[0].dtype == types.fp16 + + @pytest.mark.parametrize( + "compute_unit", + compute_units, + ) + def test_single_state_single_input_model_fp16(self, float16_buffer_model, compute_unit): + """ + Tests conversion of a stateful torch model defined in fp16. + This will be common in model with large size. + """ + # fp16 state / input + mlmodel = ct.convert( + float16_buffer_model, + inputs=[ + ct.TensorType(shape=(3,), dtype=np.float16), + ], + states=[ + ct.StateType(wrapped_type=ct.TensorType(shape=(3,), dtype=np.float16), name="state") + ], + minimum_deployment_target=ct.target.iOS18, + convert_to="mlprogram", + compute_units=compute_unit, + ) + prog = mlmodel._mil_program + assert get_op_types_in_program(prog) == [ + "read_state", + "mul", + "coreml_update_state", + "add", + "coreml_update_state", + "mul", + ] + if compute_unit in (ct.ComputeUnit.CPU_ONLY, ct.ComputeUnit.CPU_AND_GPU): + # rdar://128446982 ([Bug][Stateful model][ANE] Stateful model fails to run on ANE) + verify_prediction(mlmodel) + + # force state / input to be fp32 (intented stress test) + prog = ct.convert( + float16_buffer_model, + inputs=[ + ct.TensorType(shape=(3,), dtype=np.float32), + ], + states=[ + ct.StateType( + wrapped_type=ct.TensorType(shape=(3,), dtype=np.float32), name="state" + ), + ], + minimum_deployment_target=ct.target.iOS18, + convert_to="milinternal", + ) + assert get_op_types_in_program(prog) == [ + "read_state", + "cast", + "cast", + "mul", + "cast", + "coreml_update_state", + "cast", + "add", + "cast", + "coreml_update_state", + "cast", + "mul", + ] + + + @pytest.mark.parametrize( + "compute_unit", + compute_units, + ) + def test_multiple_states_model(self, float32_two_buffers_model, compute_unit): + """ + Tests for a model with multiple buffers. + """ + mlmodel = ct.convert( + float32_two_buffers_model, + inputs=[ + ct.TensorType(shape=(3,)), + ], + states=[ + ct.StateType( + wrapped_type=ct.TensorType( + shape=(3,), + ), + name="state_1", + ), + ct.StateType( + wrapped_type=ct.TensorType( + shape=(3,), + ), + name="state_2", + ), + ], + minimum_deployment_target=ct.target.iOS18, + convert_to="mlprogram", + compute_units=compute_unit, + ) + prog = mlmodel._mil_program + assert get_op_types_in_program(prog) == [ + "read_state", + "sub", + "read_state", + "mul", + "coreml_update_state", + "mul", + "coreml_update_state", + "add", + "coreml_update_state", + "add", + ] + if compute_unit in (ct.ComputeUnit.CPU_ONLY, ct.ComputeUnit.CPU_AND_GPU): + # rdar://128446982 ([Bug][Stateful model][ANE] Stateful model fails to run on ANE) + verify_prediction(mlmodel) + + def test_convert_buffer_model_without_state_type(self, float32_buffer_model): + """ + If the users don't specify StateType for buffer states, + they will be treated as const tensors. + We should modify this unittest after we fix this radar: + rdar://116489054 ([Infra] Have a more sophisticated handling for torch buffer state when not declared as StateType) + """ + prog = ct.convert( + float32_buffer_model, + inputs=[ + ct.TensorType(shape=(3,)), + ], + minimum_deployment_target=ct.target.iOS17, + convert_to="milinternal", + ) + assert get_op_types_in_program(prog) == [ + "add", + "mul", + ] + + @pytest.mark.parametrize( + "compute_unit", + compute_units, + ) + def test_tensor_state_inputs_interleave( + self, float32_buffer_model_two_inputs_two_states, compute_unit + ): + """ + We allow the user to interleave tensor / state input types. + """ + mlmodel = ct.convert( + float32_buffer_model_two_inputs_two_states, + inputs=[ + ct.TensorType(shape=(3,)), + ct.TensorType(shape=(3,)), + ], + states=[ + ct.StateType( + wrapped_type=ct.TensorType( + shape=(3,), + ), + name="state_1", + ), + ct.StateType( + wrapped_type=ct.TensorType( + shape=(3,), + ), + name="state_2", + ), + ], + minimum_deployment_target=ct.target.iOS18, + convert_to="mlprogram", + compute_units=compute_unit, + ) + prog = mlmodel._mil_program + assert get_op_types_in_program(prog) == [ + "read_state", + "add", + "coreml_update_state", + "read_state", + "add", + "coreml_update_state", + "mul", + ] + verify_prediction(mlmodel) + + def test_invalid_deployment_target_error_out(self, float32_buffer_model): + """ + The conversion should error out if the user tries to convert it + into deployment target < ioS18. + """ + with pytest.raises( + ValueError, + match="State model is supported only >= iOS18. Please update the minimum_deployment_target to at least coremltools.target.iOS18", + ): + prog = ct.convert( + float32_buffer_model, + inputs=[ + ct.TensorType(shape=(3,)), + ], + states=[ + ct.StateType( + wrapped_type=ct.TensorType( + shape=(3,), + ), + name="state", + ), + ], + minimum_deployment_target=ct.target.iOS17, + ) + + with pytest.raises( + ValueError, + match="State model is supported only >= iOS18. Please update the minimum_deployment_target to at least coremltools.target.iOS18", + ): + prog = ct.convert( + float32_buffer_model, + inputs=[ + ct.TensorType(shape=(3,)), + ], + states=[ + ct.StateType( + wrapped_type=ct.TensorType( + shape=(3,), + ), + name="state", + ), + ], + convert_to="neuralnetwork", + ) + + def test_invalid_state_name_error_out(self, float32_buffer_model): + """ + The conversion should error out if the user doesn't provide / + or provides wrong name of the buffer + """ + with pytest.raises( + ValueError, + match="StateType named None not provided or not found in the source torch model. Please make sure the name in 'ct.StateType\(name=..., wrapped_type=ct.TensorType\(...\)\)' match the 'named_buffers\(\)' in the source torch model.", + ): + prog = ct.convert( + float32_buffer_model, + inputs=[ + ct.TensorType(shape=(3,)), + ], + states=[ + ct.StateType( + wrapped_type=ct.TensorType( + shape=(3,), + ) + ), + ], + minimum_deployment_target=ct.target.iOS18, + compute_precision=ct.precision.FLOAT32, + convert_to="milinternal", + ) + + with pytest.raises( + ValueError, + match="StateType named invalid not provided or not found in the source torch model. Please make sure the name in 'ct.StateType\(name=..., wrapped_type=ct.TensorType\(...\)\)' match the 'named_buffers\(\)' in the source torch model: \['state'\]", + ): + prog = ct.convert( + float32_buffer_model, + inputs=[ + ct.TensorType(shape=(3,)), + ], + states=[ + ct.StateType(wrapped_type=ct.TensorType(shape=(3,)), name="invalid"), + ], + minimum_deployment_target=ct.target.iOS18, + compute_precision=ct.precision.FLOAT32, + convert_to="milinternal", + ) + + def test_invalid_state_shape_out(self, float32_buffer_model): + """ + The conversion should error out if the provided StateType has + a different shape than the registered buffer. + """ + with pytest.raises( + ValueError, + match="StateType shape \(2,\) must matched the torch buffer shape \(3,\)", + ): + prog = ct.convert( + float32_buffer_model, + inputs=[ + ct.TensorType(shape=(3,)), + ], + states=[ + ct.StateType( + wrapped_type=ct.TensorType( + shape=(2,), + ), + name="state", + ), + ], + minimum_deployment_target=ct.target.iOS18, + compute_precision=ct.precision.FLOAT32, + convert_to="milinternal", + ) + + def test_invalid_input_numbers_error_out(self, float32_buffer_model_with_two_inputs): + """ + The checking for the tensor inputs should not be affected by + the new added StateType inputs + """ + with pytest.raises( + ValueError, + match="Number of TorchScript inputs \(2\) must match the user provided inputs \(1\).", + ): + prog = ct.convert( + float32_buffer_model_with_two_inputs, + inputs=[ + ct.TensorType(shape=(3,)), + ], + states=[ + ct.StateType( + wrapped_type=ct.TensorType( + shape=(3,), + ), + name="state", + ), + ], + minimum_deployment_target=ct.target.iOS18, + compute_precision=ct.precision.FLOAT32, + convert_to="milinternal", + ) + + def test_invalid_inputs_contains_states_error_out(self, float32_buffer_model_with_two_inputs): + """ + The checking for the inputs should not contain StateType. + """ + with pytest.raises( + ValueError, + match="'inputs' cannot contain an instance of StateType", + ): + prog = ct.convert( + float32_buffer_model_with_two_inputs, + inputs=[ + ct.TensorType(shape=(3,)), + ct.StateType( + wrapped_type=ct.TensorType( + shape=(3,), + ), + name="state", + ), + ], + minimum_deployment_target=ct.target.iOS18, + compute_precision=ct.precision.FLOAT32, + convert_to="milinternal", + ) + + @staticmethod + def convert_state_model(model, backend, compute_unit=ct.ComputeUnit.CPU_ONLY): + return ct.convert( + model, + inputs=[ + ct.TensorType(shape=(3,)), + ], + states=[ + ct.StateType( + wrapped_type=ct.TensorType( + shape=(3,), + ), + name="state_1", + ), + ct.StateType( + wrapped_type=ct.TensorType( + shape=(3,), + ), + name="state_2", + ), + ], + minimum_deployment_target=ct.target.iOS18, + convert_to=backend, + compute_units=compute_unit, + ) + + @staticmethod + def check_state_model(mlmodel, expected_ops, run_prediction=True): + mil = mlmodel.get_spec().mlProgram + for function in mil.functions.values(): + for block in function.block_specializations.values(): + ops = list(block.operations) + assert [val.type for val in ops] == expected_ops + if run_prediction: + verify_prediction(mlmodel) + + @pytest.mark.parametrize( + "compute_unit", + compute_units, + ) + def test_state_ops_cannot_removed( + self, + float32_buffer_not_returned_model, + float32_buffer_not_returned_model_2, + compute_unit, + ): + """ + Check the coreml_update_state should not be removed by dead_code_elimination pass. + """ + # Test case 1 + prog = self.convert_state_model(float32_buffer_not_returned_model, "milinternal") + assert get_op_types_in_program(prog) == [ + "identity", + "read_state", + "add", + "coreml_update_state", + "read_state", + "add", + "coreml_update_state", + ] + mlmodel = self.convert_state_model( + float32_buffer_not_returned_model, "mlprogram", compute_unit + ) + expected_ops = [ + "identity", + "read_state", + "add", + "write_state", + "read_state", + "add", + "write_state", + ] + # The model is failing besides CPU, we should set the run_prediction = True after the following radar is fixed: + # rdar://128481009 ([Bug][Stateful model][E5] Stateful mdel triggers a No OperationBuilder in this Block produces output x error) + run_prediction = compute_units == ct.ComputeUnit.CPU_ONLY + self.check_state_model(mlmodel, expected_ops, run_prediction) + + # Test case 2 + prog = self.convert_state_model(float32_buffer_not_returned_model_2, "milinternal") + assert get_op_types_in_program(prog) == [ + "identity", + "read_state", + "add", + "coreml_update_state", + "read_state", + "add", + "coreml_update_state", + "add", + "coreml_update_state", + ] + mlmodel = self.convert_state_model( + float32_buffer_not_returned_model_2, "mlprogram", compute_unit + ) + expected_ops = [ + "identity", + "read_state", + "add", + "write_state", + "read_state", + "read_state", + "add", + "write_state", + "add", + "write_state", + ] + # The model is failing besides CPU, we should set the run_prediction = True after the following radar is fixed: + # rdar://128481009 ([Bug][Stateful model][E5] Stateful mdel triggers a No OperationBuilder in this Block produces output x error) + run_prediction = compute_units == ct.ComputeUnit.CPU_ONLY + self.check_state_model(mlmodel, expected_ops, run_prediction) + + @pytest.mark.parametrize( + "compute_unit, dtype", + itertools.product( + compute_units, + [np.float16, np.float32], + ), + ) + def test_single_state_single_input_sequential_model(self, compute_unit, dtype): + """ + Tests for a model with a sequence of inplace ops. + """ + + def get_stateful_model(): + # fp32 state is not supported through runtime + convert_to = "milinternal" if dtype == np.float32 else "mlprogram" + compute_precision_mapping = { + np.float16: ct.precision.FLOAT16, + np.float32: ct.precision.FLOAT32, + } + model = ct.convert( + float32_buffer_sequantial_model(), + inputs=[ + ct.TensorType(shape=(3,), dtype=dtype), + ], + states=[ + ct.StateType(wrapped_type=ct.TensorType(shape=(3,), dtype=dtype), name="state"), + ], + minimum_deployment_target=ct.target.iOS18, + compute_precision=compute_precision_mapping[dtype], + convert_to=convert_to, + compute_units=compute_unit, + ) + + if dtype == np.float32: + return None, model + assert dtype == np.float16 + return model, model._mil_program + + mlmodel, prog = get_stateful_model() + assert get_op_types_in_program(prog) == [ + "read_state", + "slice_update", + "coreml_update_state", + "mul", + "mul", + "coreml_update_state", + "sub", + "coreml_update_state", + "relu", + ] + + if mlmodel is not None: + # rdar://128446982 ([Bug][Stateful model][ANE] Stateful model fails to run on ANE) + if compute_unit in (ct.ComputeUnit.CPU_ONLY, ct.ComputeUnit.CPU_AND_GPU): + verify_prediction(mlmodel) + + @pytest.mark.parametrize( + "compute_unit", + compute_units, + ) + def test_color_input_with_buffer(self, rank4_input_model_with_buffer, compute_unit): + mlmodel = ct.convert( + rank4_input_model_with_buffer, + inputs=[ct.ImageType(shape=(1, 3, 10, 20), color_layout=ct.colorlayout.RGB)], + states=[ct.StateType(wrapped_type=ct.TensorType(shape=(1, 3, 10, 20)), name="state_1")], + outputs=[ct.TensorType(dtype=np.float32)], + minimum_deployment_target=ct.target.iOS18, + compute_units=compute_unit, + ) + assert_spec_input_image_type(mlmodel._spec, expected_feature_type=ft.ImageFeatureType.RGB) + assert_prog_output_type(mlmodel._mil_program, expected_dtype_str="fp32") + verify_prediction(mlmodel) + + @pytest.mark.parametrize( + "compute_unit", + compute_units, + ) + def test_color_output_with_buffer(self, rank4_input_model_with_buffer, compute_unit): + # image input / image output + mlmodel = ct.convert( + rank4_input_model_with_buffer, + inputs=[ct.ImageType(shape=(1, 3, 10, 20), color_layout=ct.colorlayout.BGR)], + states=[ct.StateType(wrapped_type=ct.TensorType(shape=(1, 3, 10, 20)), name="state_1")], + outputs=[ct.ImageType(color_layout=ct.colorlayout.RGB)], + minimum_deployment_target=ct.target.iOS18, + compute_units=compute_unit, + ) + assert_spec_input_image_type(mlmodel._spec, expected_feature_type=ft.ImageFeatureType.BGR) + assert_spec_output_image_type(mlmodel._spec, expected_feature_type=ft.ImageFeatureType.RGB) + assert_prog_output_type(mlmodel._mil_program, expected_dtype_str="fp32") + verify_prediction(mlmodel) + + # tensor input / image output + # check mlprogram can have image output, both static and dynamic case are tested + for is_dynamic in [True, False]: + shape = ( + ct.Shape((1, 3, ct.RangeDim(5, 10, default=10), ct.RangeDim(5, 20, default=20))) + if is_dynamic + else ct.Shape((1, 3, 10, 20)) + ) + mlmodel = ct.convert( + rank4_input_model_with_buffer, + inputs=[ct.TensorType(shape=shape, dtype=np.float32)], + states=[ + ct.StateType(wrapped_type=ct.TensorType(shape=(1, 3, 10, 20)), name="state_1") + ], + outputs=[ct.ImageType(name="output_image", color_layout=ct.colorlayout.RGB)], + minimum_deployment_target=ct.target.iOS18, + compute_units=compute_unit, + ) + assert_spec_output_image_type( + mlmodel._spec, expected_feature_type=ft.ImageFeatureType.RGB + ) + assert_prog_output_type(mlmodel._mil_program, expected_dtype_str="fp32") + if is_dynamic: + assert any_symbolic(mlmodel._mil_program.functions["main"].outputs[0].shape) + if not is_dynamic or compute_unit != ct.ComputeUnit.CPU_ONLY: + # rdar://128491187 ([Bug][Stateful model][Classic CPU] Stateful model is throwing a Failed to reshape error) + verify_prediction(mlmodel) + + @pytest.mark.parametrize( + "compute_unit", + compute_units, + ) + def test_grayscale_input_with_buffer( + self, rank4_grayscale_input_model_with_buffer, compute_unit + ): + # test with GRAYSCALE + mlmodel = ct.convert( + rank4_grayscale_input_model_with_buffer, + inputs=[ct.ImageType(shape=(1, 1, 10, 20), color_layout=ct.colorlayout.GRAYSCALE)], + states=[ct.StateType(wrapped_type=ct.TensorType(shape=(1, 1, 10, 20)), name="state_1")], + outputs=[ct.TensorType(dtype=np.float32)], + minimum_deployment_target=ct.target.iOS18, + compute_units=compute_unit, + ) + assert_spec_input_image_type( + mlmodel._spec, expected_feature_type=ft.ImageFeatureType.GRAYSCALE + ) + assert_prog_output_type(mlmodel._mil_program, expected_dtype_str="fp32") + verify_prediction(mlmodel) + + # test with GRAYSCALE_FLOAT16 + mlmodel = ct.convert( + rank4_grayscale_input_model_with_buffer, + inputs=[ + ct.ImageType(shape=(1, 1, 10, 20), color_layout=ct.colorlayout.GRAYSCALE_FLOAT16) + ], + states=[ct.StateType(wrapped_type=ct.TensorType(shape=(1, 1, 10, 20)), name="state_1")], + outputs=[ct.TensorType(dtype=np.float16)], + minimum_deployment_target=ct.target.iOS18, + compute_units=compute_unit, + ) + assert_spec_input_image_type( + mlmodel._spec, expected_feature_type=ft.ImageFeatureType.GRAYSCALE_FLOAT16 + ) + assert_output_dtype(mlmodel, expected_type_str="fp16") + verify_prediction(mlmodel) + + @pytest.mark.parametrize( + "compute_unit", + compute_units, + ) + def test_grayscale_output_with_buffer( + self, rank4_grayscale_input_model_with_buffer, compute_unit + ): + # grayscale fp16 input and output + mlmodel = ct.convert( + rank4_grayscale_input_model_with_buffer, + inputs=[ + ct.ImageType(shape=(1, 1, 10, 20), color_layout=ct.colorlayout.GRAYSCALE_FLOAT16) + ], + states=[ct.StateType(wrapped_type=ct.TensorType(shape=(1, 1, 10, 20)), name="state_1")], + outputs=[ct.ImageType(color_layout=ct.colorlayout.GRAYSCALE_FLOAT16)], + minimum_deployment_target=ct.target.iOS18, + compute_units=compute_unit, + ) + assert_spec_input_image_type( + mlmodel._spec, expected_feature_type=ft.ImageFeatureType.GRAYSCALE_FLOAT16 + ) + assert_spec_output_image_type( + mlmodel._spec, expected_feature_type=ft.ImageFeatureType.GRAYSCALE_FLOAT16 + ) + assert_prog_output_type(mlmodel._mil_program, expected_dtype_str="fp16") + verify_prediction(mlmodel) + + # grayscale input and grayscale fp16 output + mlmodel = ct.convert( + rank4_grayscale_input_model_with_buffer, + inputs=[ct.ImageType(shape=(1, 1, 10, 20), color_layout=ct.colorlayout.GRAYSCALE)], + outputs=[ct.ImageType(color_layout=ct.colorlayout.GRAYSCALE_FLOAT16)], + minimum_deployment_target=ct.target.iOS18, + compute_units=compute_unit, + ) + assert_spec_input_image_type( + mlmodel._spec, expected_feature_type=ft.ImageFeatureType.GRAYSCALE + ) + assert_spec_output_image_type( + mlmodel._spec, expected_feature_type=ft.ImageFeatureType.GRAYSCALE_FLOAT16 + ) + assert_prog_output_type(mlmodel._mil_program, expected_dtype_str="fp16") + verify_prediction(mlmodel) diff --git a/coremltools/converters/mil/frontend/torch/test/testing_utils.py b/coremltools/converters/mil/frontend/torch/test/testing_utils.py index d62d6a002..2543daad7 100644 --- a/coremltools/converters/mil/frontend/torch/test/testing_utils.py +++ b/coremltools/converters/mil/frontend/torch/test/testing_utils.py @@ -259,6 +259,9 @@ def run_compare_torch( backend=("neuralnetwork", "fp32"), rand_range=(-1.0, 1.0), use_scripting=False, + # TODO (rdar://128768037): Once we fully figure out torch.export converter, + # we may default the tests to ATen dialect + use_edge_dialect=True, converter_input_type=None, compute_unit=ct.ComputeUnit.CPU_ONLY, minimum_deployment_target=None, @@ -298,8 +301,9 @@ def run_compare_torch( input_data_clone = tuple(input_data_clone) elif isinstance(input_data_clone, torch.Tensor): input_data_clone = (input_data_clone,) - exir_program_aten = torch.export.export(model, input_data_clone) - model_spec = executorch.exir.to_edge(exir_program_aten).exported_program() + model_spec = torch.export.export(model, input_data_clone) + if use_edge_dialect: + model_spec = executorch.exir.to_edge(model_spec).exported_program() else: raise ValueError( f"Unknown value of frontend. Needs to be either TorchFrontend.TORCHSCRIPT or TorchFrontend.EXIR. Provided: {frontend}" diff --git a/coremltools/converters/mil/input_types.py b/coremltools/converters/mil/input_types.py index 7f3c66b39..5b8c598f8 100644 --- a/coremltools/converters/mil/input_types.py +++ b/coremltools/converters/mil/input_types.py @@ -10,7 +10,6 @@ from coremltools.converters.mil.mil import types from coremltools.converters.mil.mil.types.symbolic import is_symbolic -from coremltools.converters.mil.mil.types.type_mapping import is_builtin, numpy_type_to_builtin_type class ColorLayout(Enum): @@ -208,7 +207,7 @@ def __init__(self, name=None, shape=None, dtype=None, default_value=None): """ super(TensorType, self).__init__(name, shape) if dtype is not None: - if is_builtin(dtype): + if types.is_builtin(dtype): self.dtype = dtype if dtype not in ( types.int8, @@ -226,7 +225,7 @@ def __init__(self, name=None, shape=None, dtype=None, default_value=None): else: # Assume dtype is numpy type try: - self.dtype = numpy_type_to_builtin_type(dtype) + self.dtype = types.numpy_type_to_builtin_type(dtype) except TypeError: raise TypeError("dtype={} is unsupported".format(dtype)) if dtype not in (np.float16, np.float32, np.float64, float, @@ -247,20 +246,19 @@ def __init__(self, name=None, shape=None, dtype=None, default_value=None): msg = 'TensorType {} default_value can only have ' +\ 'same entries' raise ValueError(msg.format(name)) - if not self.shape.has_symbolic and \ - list(default_value.shape) != list(self.shape.symbolic_shape): - msg = 'TensorType {} default_value shape {} != ' +\ - 'TensorType.shape {}' - raise ValueError(msg.format(name, default_value.shape, - self.shape.to_list())) - if self.dtype is not None and \ - numpy_type_to_builtin_type(default_value.dtype) != self.dtype: - msg = 'TensorType {} default_value dtype {} != ' +\ - 'TensorType.dtype {}' - raise ValueError(msg.format(name, default_value.dtype, - self.dtype.__type_info__())) + if not self.shape.has_symbolic and list(default_value.shape) != list( + self.shape.symbolic_shape + ): + msg = "TensorType {} default_value shape {} != " + "TensorType.shape {}" + raise ValueError(msg.format(name, default_value.shape, self.shape.to_list())) + if ( + self.dtype is not None + and types.numpy_type_to_builtin_type(default_value.dtype) != self.dtype + ): + msg = "TensorType {} default_value dtype {} != " + "TensorType.dtype {}" + raise ValueError(msg.format(name, default_value.dtype, self.dtype.__type_info__())) else: - self.dtype = numpy_type_to_builtin_type(default_value.dtype) + self.dtype = types.numpy_type_to_builtin_type(default_value.dtype) self.default_value = default_value @@ -272,6 +270,58 @@ def __str__(self): self.shape, self.dtype) +class StateType(InputType): + SUPPORTED_WRAPPER_TYPE = ( + TensorType, + ) + + def __init__( + self, + wrapped_type: type, + name: Optional[str] = None, + ): + """ + Specify a model state as a wrapper of a ``TensorType``. + For example, you can use the following code to create a + state type input that wraps a fp16 tensor with shape ``(2, 3)``:: + + ct.StateType( + wrapped_type=ct.TensorType( + shape=(2, 3), + dtype=np.float16 + ), + name="state", + ) + + Parameters + ---------- + wrapped_type: coremltools.converters.mil.input_types.InputType + - The type wrapped in the state. + - Can be ``TensorType``. + Note that the ``name`` and ``default_value`` of the wrapped ``TensorType`` must not be provided. + + name: str + The name of the state. + It must match the key of ``named_buffers()`` in the source TorchScript model. + """ + if not isinstance(wrapped_type, StateType.SUPPORTED_WRAPPER_TYPE): + raise ValueError( + f"StateType only supports {StateType.SUPPORTED_WRAPPER_TYPE}. Got {type(wrapped_type)}." + ) + # name and default_value cannot be set + if wrapped_type.name is not None: + raise ValueError("name cannot be set in the state wrapped_type.") + if wrapped_type.default_value is not None: + raise ValueError("default_value cannot be set in the state wrapped_type.") + + super(StateType, self).__init__(name, wrapped_type.shape, wrapped_type.dtype) + self.wrapped_type = wrapped_type + + def __repr__(self): + return self.__str__() + + def __str__(self): + return f"StateType[{self.wrapped_type}]" class RangeDim: def __init__( @@ -448,12 +498,7 @@ def __init__(self, shapes, default=None): .. sourcecode:: python sample_shape = ct.EnumeratedShapes( - shapes=[ - (2, 4, 64, 64), - (2, 4, 48, 48), - (2, 4, 32, 32) - ], - default=(2, 4, 64, 64) + shapes=[(2, 4, 64, 64), (2, 4, 48, 48), (2, 4, 32, 32)], default=(2, 4, 64, 64) ) my_core_ml_model = ct.convert( diff --git a/coremltools/converters/mil/mil/block.py b/coremltools/converters/mil/mil/block.py index 9c5e88209..54b03e2f0 100644 --- a/coremltools/converters/mil/mil/block.py +++ b/coremltools/converters/mil/mil/block.py @@ -721,6 +721,8 @@ def _copy_scope_info(src: Var, dst: Var) -> None: res = ["__COREML__::TORCHSCRIPT_PLACEHOLDER"] elif val == ScopeSource.TORCHSCRIPT_MODULE_NAME: res = [f"__COREML__::TORCHSCRIPT_PLACEHOLDER_{src.name}"] + elif val == ScopeSource.EXIR_STACK_TRACE: + res = [None] elif val == ScopeSource.EXIR_DEBUG_HANDLE: res = [None] else: diff --git a/coremltools/converters/mil/mil/builder.py b/coremltools/converters/mil/mil/builder.py index 95c74890c..9d55079c8 100644 --- a/coremltools/converters/mil/mil/builder.py +++ b/coremltools/converters/mil/mil/builder.py @@ -16,7 +16,7 @@ from .block import Function, curr_block from .input_type import InternalInputType, ListOrTensorInputType, TensorInputType, TupleInputType -from .program import Placeholder +from .program import Placeholder, StateTensorPlaceholder from .scope import ( SCOPE_STACK, VALID_OPS_TO_COPY_SCOPE_INFO, @@ -217,6 +217,14 @@ def placeholder( def TensorSpec(shape, dtype=None): return Placeholder(shape, dtype) + @staticmethod + def StateTensorSpec(shape, dtype=None): + return StateTensorPlaceholder(shape, dtype) + + @staticmethod + def state_tensor_placeholder(shape, dtype=None): + return StateTensorPlaceholder(shape, dtype) + @staticmethod def _create_function( main_block: Callable, @@ -338,7 +346,7 @@ def scope( Examples -------- - Here is an example of creating a scope for torchscript module heirarchy with type and name information. + The following is an example of creating a scope for torchscript module heirarchy with type and name information. .. sourcecode:: python @@ -351,11 +359,11 @@ def prog(x): return mb.add(x=x, y=4.3, name="add_1") - In the above example, the "add_1" op will have two scope attributes, for torchscipt module type and name: + In the previous example, the "add_1" op will have two scope attributes, for torchscipt module type and name: * TORCHSCRIPT_MODULE_TYPE: ["Module1"] * TORCHSCRIPT_MODULE_NAME: ["module_1"] - Here is an example of creating nested scopes: + The following is an example of creating nested scopes: .. sourcecode:: python @@ -371,7 +379,7 @@ def prog(x): ): return mb.add(x=x, y=3.2, name="add_2") - In the above example, the "add_1" op would have a scope attribute: + In the previous example, the "add_1" op would have a scope attribute: * TORCHSCRIPT_MODULE_TYPE: ["Module1"] while the "add_2" op would have scope attributes: diff --git a/coremltools/converters/mil/mil/input_type.py b/coremltools/converters/mil/mil/input_type.py index 51d5e9351..05a3cd1f7 100644 --- a/coremltools/converters/mil/mil/input_type.py +++ b/coremltools/converters/mil/mil/input_type.py @@ -23,7 +23,7 @@ types.int16, types.int32, types.int64, -] +] + list(types._SUB_BYTE_TYPES) SUPPORT_COMPLEX_TYPES = [ types.complex64, @@ -35,6 +35,7 @@ + SUPPORT_INT_TYPES + SUPPORT_COMPLEX_TYPES + [types.bool, types.str] + + list(types._SUB_BYTE_TYPES) ) @@ -217,7 +218,7 @@ def __str__(self): @property def type_str(self): """Descriptive string describing expected mil types""" - return self.__str__(self) + return self.__str__() class TensorInputType(_InputType): @@ -355,6 +356,13 @@ class InternalInputType(_InputType): def _is_compatible(self, v): return True # skip type check by default for InternalInputType. +class StateInputType(_InputType): + """ + StateInputType allows inputs of type types.state + """ + + def _is_compatible(self, v): + return types.is_state(v.sym_type) class PyFunctionInputType(InternalInputType): """ diff --git a/coremltools/converters/mil/mil/operation.py b/coremltools/converters/mil/mil/operation.py index 8f3536f06..b2cc8d859 100644 --- a/coremltools/converters/mil/mil/operation.py +++ b/coremltools/converters/mil/mil/operation.py @@ -8,7 +8,6 @@ import numpy as np from coremltools.converters.mil.mil import types -from coremltools.converters.mil.mil.types import is_compatible_type from coremltools.converters.mil.mil.types.symbolic import any_symbolic, is_symbolic from . import SPACES @@ -326,7 +325,7 @@ def type_value_inference(self, overwrite_output=False): # Check type inference if overwrite_output: out_var._sym_type = sym_type - elif not is_compatible_type(sym_type, out_var.sym_type): + elif not types.is_compatible_type(sym_type, out_var.sym_type): msg = "Output Var {} in op {} type changes with new input Vars" raise ValueError(msg.format(out_var.name, self.name)) @@ -494,7 +493,10 @@ def _validate_and_set_inputs(self, input_kvs, no_check_var_types=False): def check_and_detach(v_new, v_old, op, no_check_var_types): # Check new var's sym_type is compatible with the # existing's sym_type. - if not is_compatible_type(v_new.sym_type, v_old.sym_type) and not no_check_var_types: + if ( + not types.is_compatible_type(v_new.sym_type, v_old.sym_type) + and not no_check_var_types + ): raise ValueError( f"New var type `{v_new.sym_type}` not a " f"subtype of existing var type `{v_old.sym_type}`." diff --git a/coremltools/converters/mil/mil/ops/defs/__init__.py b/coremltools/converters/mil/mil/ops/defs/__init__.py index edde1bdfe..d0b67a5bf 100644 --- a/coremltools/converters/mil/mil/ops/defs/__init__.py +++ b/coremltools/converters/mil/mil/ops/defs/__init__.py @@ -3,4 +3,4 @@ # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause -from . import complex_dialect_ops, iOS15, iOS16, iOS17 +from . import complex_dialect_ops, coreml_dialect, iOS15, iOS16, iOS17, iOS18 diff --git a/coremltools/converters/mil/mil/ops/defs/_utils.py b/coremltools/converters/mil/mil/ops/defs/_utils.py index 3f084dc84..953228c02 100644 --- a/coremltools/converters/mil/mil/ops/defs/_utils.py +++ b/coremltools/converters/mil/mil/ops/defs/_utils.py @@ -665,90 +665,3 @@ def solve_slice_by_index_shape(x_shape, begin, end, stride, begin_mask, end_mask ret_shape.append(max(0, num)) return ret_shape - - -def pack_elements_into_bits(elements: np.ndarray, nbits: int) -> np.ndarray: - """ - Pack elements into nbits representation, by starting with the least significant bit (LSB) and - moving upward to the most significant bit (MSB). - - Returns packed elements as np.uint8. - """ - if not np.issubdtype(elements.dtype, np.integer): - raise ValueError(f"Only support packing integers elements, but got {elements.dtype}") - - # Adjust allowed value range based on if the input is signed or unsigned. - if np.issubdtype(elements.dtype, np.signedinteger): - max_val = 2 ** (nbits - 1) - 1 - min_val = -max_val - 1 - else: - max_val = 2**nbits - 1 - min_val = 0 - if np.max(elements) > max_val: - raise ValueError( - f"To pack elements into {nbits}-bit, the max value is {max_val}, but got {np.max(elements)}" - ) - if np.min(elements) < min_val: - raise ValueError( - f"To pack elements into {nbits}-bit, the min value is {min_val}, but got {np.min(elements)}" - ) - - # As np.unpackbits only supports uint8, convert to uint8 first. - # Notice that it will not lose information, because the bits are unchanged when converting int8 - # to uint8. For example, the signed int -6 has bit representation '11111010', and when we unpackbits - # we get [0, 1, 0, 1, 1, 1, 1, 1], where only first 4 elements are needed for 4-bit representation. - elements = elements.astype(np.uint8) - bitarray = np.unpackbits(elements.reshape(-1, 1), bitorder="little", axis=-1)[:, :nbits] - return np.packbits(bitarray.flatten(), bitorder="little") - - -def restore_elements_from_packed_bits( - packed_values: np.ndarray, nbits: int, element_num: int, are_packed_values_signed: bool = False -) -> np.ndarray: - """ - Restore elements from packed bits. Requires values that are packed by starting with the - least significant bit (LSB) and moving upward to the most significant bit (MSB), which is the - method used in `pack_elements_into_bits`. - - are_packed_values_signed: Indicates if the packed_values were packed from signed integers. If - True, the n-bit number unpacked from packed_values will be interpreted as signed integers, - and the returned ndarray will have dtype np.int8. Otherwise, np.uint8 will be used. - """ - if len(packed_values.shape) != 1: - raise NotImplementedError( - f"Only support 1-rank packed_values. But got {len(packed_values.shape)}" - ) - - if packed_values.dtype == np.int8: - # As np.unpackbits only supports uint8, need to convert first. - packed_values = packed_values.astype(np.uint8) - elif packed_values.dtype != np.uint8: - raise NotImplementedError( - f"Only support int8 or uint8 packed_values, but got {packed_values.dtype}" - ) - - bitarray = np.unpackbits(packed_values, bitorder="little") - pad_required = bitarray.size % nbits != 0 - if pad_required: - bitarray = np.concatenate([bitarray, np.zeros(nbits - bitarray.size % nbits)]).astype( - bitarray.dtype - ) - if bitarray.size % nbits != 0: - raise ValueError( - f"The length of bitarray ({bitarray.size}) should be divisible by " - f"nbits ({nbits})." - ) - bitarray = bitarray.reshape(-1, nbits)[:element_num, :] - # The np.packbits doesn't work well for signed int if we feed `bitarray` to it directly. - # For example, the original signed int is -6, which is packed as 1010 for 4-bit representation, - # and here `bitarray` is [[0, 1, 0, 1]], where the value will be interpreted as 10 (b'1010') - # by np.packbits. - # To make np.packbits work correctly, we need to repeat the sign bit. For example, 1010 will - # become 11111010, where np.packbits can correctly handle and after converting to int8 it's -6. - if are_packed_values_signed: - # Repeat the sign bit to make uint8 to int8 works. - bitarray = np.repeat(bitarray, [1] * (nbits - 1) + [8 - nbits + 1], axis=1) - restored_elements = np.packbits(bitarray, bitorder="little", axis=-1).reshape(-1) - if are_packed_values_signed: - restored_elements = restored_elements.astype(np.int8) - return restored_elements diff --git a/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py b/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py index e19bf1757..96225b301 100644 --- a/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py @@ -33,12 +33,8 @@ def fft_fft(context, nodes): import numpy as np -from coremltools.converters.mil.mil import Operation, types -from coremltools.converters.mil.mil.input_type import ( - DefaultInputs, - InputSpec, - TensorInputType, -) +from coremltools.converters.mil.mil import operation, types +from coremltools.converters.mil.mil.input_type import DefaultInputs, InputSpec, TensorInputType from coremltools.converters.mil.mil.ops.registry import SSAOpRegistry from coremltools.converters.mil.mil.types.symbolic import any_symbolic, is_symbolic from coremltools.converters.mil.mil.types.type_mapping import ( @@ -143,7 +139,7 @@ def fft_canonicalize_shapes_dims( @register_op(namespace="complex") -class complex(Operation): +class complex(operation.Operation): """ Dialect op for constructing a complex data from real and imaginary data. """ @@ -170,7 +166,7 @@ def type_inference(self): @register_op(namespace="complex") -class complex_real(Operation): +class complex_real(operation.Operation): """Dialect op for extracting real part of complex data.""" input_spec = InputSpec( @@ -188,7 +184,7 @@ def type_inference(self): @register_op(namespace="complex") -class complex_imag(Operation): +class complex_imag(operation.Operation): """Dialect op for extracting imaginary part of complex data.""" input_spec = InputSpec( @@ -206,7 +202,7 @@ def type_inference(self): @register_op(namespace="complex") -class complex_fft(Operation): +class complex_fft(operation.Operation): """ Dialect op for 1-D FFT. As PyTorch's FFT API has a much more fine-grained control than TensorFlow's, the parameters of this dialect op mainly follows `torch.fft.fft`. @@ -280,7 +276,7 @@ def type_inference(self): @register_op(namespace="complex") -class complex_fftn(Operation): +class complex_fftn(operation.Operation): """ Dialect op for N-D FFT. As PyTorch's FFT API has a much more fine-grained control than TensorFlow's, the parameters of this dialect op mainly follows `torch.fft.fftn`. @@ -358,7 +354,7 @@ def type_inference(self): @register_op(namespace="complex") -class complex_rfft(Operation): +class complex_rfft(operation.Operation): """ Dialect op for 1-D RFFT. It's similar to 1-D FFT, but the input is real number. The FFT of a real signal is Hermitian-symmetric, ``X[i] = conj(X[-i])``, so the output contains only the @@ -416,7 +412,7 @@ def type_inference(self): @register_op(namespace="complex") -class complex_rfftn(Operation): +class complex_rfftn(operation.Operation): """ Dialect op for N-D RFFT (rfftn). The FFT of a real signal is Hermitian-symmetric, X[i_1, ..., i_n] = conj(X[-i_1, ..., -i_n]) so the full ``complex_fftn`` output contains @@ -472,7 +468,7 @@ def type_inference(self): @register_op(namespace="complex") -class complex_ifft(Operation): +class complex_ifft(operation.Operation): """ Dialect op for IFFT. Computes the one dimensional inverse discrete Fourier transform of input. @@ -533,7 +529,7 @@ def type_inference(self): @register_op(namespace="complex") -class complex_ifftn(Operation): +class complex_ifftn(operation.Operation): """ Dialect op for N-D IFFT (ifftn). @@ -595,7 +591,7 @@ def type_inference(self): @register_op(namespace="complex") -class complex_irfft(Operation): +class complex_irfft(operation.Operation): """ Dialect op for IRFFT. Computes the inverse of RFFT. The input is interpreted as a one-sided Hermitian signal in the Fourier domain, as produced by rfft(). By the Hermitian property, the @@ -648,7 +644,7 @@ def type_inference(self): @register_op(namespace="complex") -class complex_irfftn(Operation): +class complex_irfftn(operation.Operation): """ Dialect op for N-D IRFFT (irfftn). @@ -703,7 +699,7 @@ def type_inference(self): @register_op(namespace="complex") -class complex_shape(Operation): +class complex_shape(operation.Operation): """ Returns a 1-dimensional tensor with the shape of the input complex tensor. @@ -729,7 +725,7 @@ class complex_shape(Operation): "T": (types.complex64,), } - # If type_inference or value_inference is invoked when the graph is being constructed, + # If type_inference or value_inference is invoked when the graph is being constructed, # x.real and x.imag may not be set since the complex lowering pass hasn't yet been invoked. # self.x should already have the shape set, so use that instead. @@ -748,14 +744,14 @@ def value_inference(self): return np.array(self.x.shape).astype(np.int32) @register_op(namespace="complex") -class complex_abs(Operation): +class complex_abs(operation.Operation): """ Returns the absolute value of a complex tensor. Parameters ---------- x: tensor<[*d], T> (Required) - + Returns ------- tensor<[*d], fp32> @@ -778,7 +774,7 @@ def type_inference(self): return types.tensor(infer_fp_dtype_from_complex(self.x.dtype), self.x.shape) @register_op(namespace="complex") -class complex_stft(Operation): +class complex_stft(operation.Operation): """ Dialect op for 1-D STFT. @@ -838,7 +834,7 @@ def default_inputs(self): def type_inference(self): output_type = (types.complex64) - + # STFT shape is [B x N x T], where N is the number of frequency bins # and T is the number of windows # B is 1 for a time series or 2 for a batch of time series @@ -858,6 +854,5 @@ def type_inference(self): # add back rank if needed if self.input.rank == 2: output_shape = [self.input.shape[0]] + output_shape - - return types.tensor(output_type, tuple(output_shape)) + return types.tensor(output_type, tuple(output_shape)) diff --git a/coremltools/converters/mil/mil/ops/defs/coreml_dialect/__init__.py b/coremltools/converters/mil/mil/ops/defs/coreml_dialect/__init__.py new file mode 100644 index 000000000..47a437bce --- /dev/null +++ b/coremltools/converters/mil/mil/ops/defs/coreml_dialect/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +from .ops import coreml_update_state diff --git a/coremltools/converters/mil/mil/ops/defs/coreml_dialect/ops.py b/coremltools/converters/mil/mil/ops/defs/coreml_dialect/ops.py new file mode 100644 index 000000000..873625046 --- /dev/null +++ b/coremltools/converters/mil/mil/ops/defs/coreml_dialect/ops.py @@ -0,0 +1,67 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +from coremltools.converters.mil.mil import types +from coremltools.converters.mil.mil.input_type import InputSpec, StateInputType, TensorInputType +from coremltools.converters.mil.mil.operation import Operation +from coremltools.converters.mil.mil.ops.defs._op_reqs import register_op + + +@register_op(namespace="coreml") +class coreml_update_state(Operation): + """ + Copy the content of a variable into a state and return the copy of the variable. + The type of the variable must match the type that is wrapped inside the state. + This is a coreml dialect op to simplify the program. When + loading into MIL, the following transformation is done: + + .. code-block:: + + %x = coreml_update_state(state=%state, value=%value) + + --> + + write_state(state=%state, value=%value) + %x = read_state(input=%state) + + Parameters + ---------- + state: state (Required) + value: ST (Required) + + Returns + ------- + ST + + Attributes + ---------- + ST: tensor + """ + + input_spec = InputSpec( + state=StateInputType(), + value=TensorInputType(type_domain="T"), + ) + + type_domains = { + "T": ( + types.fp16, + types.fp32, + types.int8, + types.int16, + types.int32, + types.uint8, + types.uint16, + types.bool, + ), + } + + def type_inference(self): + state_wrapped_type = self.state._sym_type.wrapped_type() + if not state_wrapped_type == self.value.sym_type: + raise ValueError( + f"State wrapped type {state_wrapped_type.__type_info__()} not matched with the value's sym_type {self.value.sym_type.__type_info__()}." + ) + return self.value.sym_type diff --git a/coremltools/converters/mil/mil/ops/defs/iOS15/control_flow.py b/coremltools/converters/mil/mil/ops/defs/iOS15/control_flow.py index 40e5d3126..b220ac00b 100644 --- a/coremltools/converters/mil/mil/ops/defs/iOS15/control_flow.py +++ b/coremltools/converters/mil/mil/ops/defs/iOS15/control_flow.py @@ -163,6 +163,10 @@ class Const(Operation): val=InternalInputType(const=True), ) + def __init__(self, **kwargs): + super(Const, self).__init__(**kwargs) + self._weight_id = None + def type_inference(self): builtin_type, _ = self._get_type_val(self.val.val) return builtin_type @@ -217,6 +221,24 @@ def _get_type_val(self, value): _, builtin_type = numpy_val_to_builtin_val(value) return builtin_type, value + @property + def weight_id(self) -> int: + """ + Weight id for the const. It is used for weight sharing across multiple functions. + Constants sharing the same weight_id will use the same blob file value when + lowering to milproto. + """ + return self._weight_id + + @weight_id.setter + def weight_id(self, val: int) -> None: + """ + Set weight id for the const. + """ + assert isinstance(val, int), f"weight_id must be type of int. Got {type(val)}." + assert self._weight_id is None, f"cannot set {self.name} weight_id twice." + self._weight_id = val + @register_op class const(Const): diff --git a/coremltools/converters/mil/mil/ops/defs/iOS15/recurrent.py b/coremltools/converters/mil/mil/ops/defs/iOS15/recurrent.py index b3b5d25a6..c9f194641 100644 --- a/coremltools/converters/mil/mil/ops/defs/iOS15/recurrent.py +++ b/coremltools/converters/mil/mil/ops/defs/iOS15/recurrent.py @@ -51,17 +51,14 @@ class gru(Operation): * ``weigh_ih = [W_{ir} | W_{io} | W_{iz}]`` where ``[a|b]`` denotes column concatenation and ``[a, b]`` denotes row concatenation. ``W_{ir}``, ``W_{io}``, and ``W_{iz}`` have shape ``(H, I)``. - * This is used when direction="forward" or "reverse". weight_hh: const<3*H, H, T> (Required) - Weight matrix * ``weight_hh = [W_{hr} | W_{ho} | W_{hz}]``: ``W_{hr}``, ``W_{ho}``, and ``W_{hz}`` have shape ``(H, H)``. - * This is used when direction="forward" or "reverse". bias: const<3*H, T> (Optional) [Default all 0s] * ``bias[0]`` are input-hidden and hidden-hidden bias. * ``3*H`` are biases for ``[b_{ir} | b_{io} | b_{hz}]``. - * This is used when direction="forward" or "reverse". direction: const (Optional) [Default=forward] * Either ``forward`` or ``reverse``. diff --git a/coremltools/converters/mil/mil/ops/defs/iOS15/tensor_operation.py b/coremltools/converters/mil/mil/ops/defs/iOS15/tensor_operation.py index a32eeefb3..5aa4402bc 100644 --- a/coremltools/converters/mil/mil/ops/defs/iOS15/tensor_operation.py +++ b/coremltools/converters/mil/mil/ops/defs/iOS15/tensor_operation.py @@ -7,25 +7,15 @@ import numpy as np -from coremltools.converters.mil.mil import ( - get_new_symbol, - get_new_variadic_symbol, - types, -) +from coremltools.converters.mil.mil import get_new_symbol, get_new_variadic_symbol, types from coremltools.converters.mil.mil.input_type import ( DefaultInputs, InputSpec, - ListOrTensorInputType, + InternalInputType, TensorInputType, TupleInputType, ) -from coremltools.converters.mil.mil.operation import ( - NONE, - SYMBOL, - VALUE, - Operation, - precondition, -) +from coremltools.converters.mil.mil.operation import NONE, SYMBOL, VALUE, Operation, precondition from coremltools.converters.mil.mil.ops.defs._op_reqs import register_op from coremltools.converters.mil.mil.ops.defs._utils import MAX_SIZE_CONSTANT_FOLDING from coremltools.converters.mil.mil.types.symbolic import ( @@ -692,8 +682,13 @@ def type_inference(self): if rep <= 0: raise ValueError("All entries of reps parameter must be greater than 0") - if is_symbolic(rep) or is_symbolic(x_shape[i]): + if is_symbolic(rep): out_shape.append(get_new_symbol()) + elif is_symbolic(x_shape[i]): + if rep == 1: + out_shape.append(x_shape[i]) + else: + out_shape.append(get_new_symbol()) else: out_shape.append(rep * x_shape[i]) @@ -1339,9 +1334,7 @@ class identity(Operation): T: fp16, fp32, i32, bool """ - input_spec = InputSpec( - x=ListOrTensorInputType() - ) + input_spec = InputSpec(x=InternalInputType()) def type_inference(self): return self.x.sym_type diff --git a/coremltools/converters/mil/mil/ops/defs/iOS16/constexpr_ops.py b/coremltools/converters/mil/mil/ops/defs/iOS16/constexpr_ops.py index 925ec149d..1c8ab1774 100644 --- a/coremltools/converters/mil/mil/ops/defs/iOS16/constexpr_ops.py +++ b/coremltools/converters/mil/mil/ops/defs/iOS16/constexpr_ops.py @@ -9,7 +9,6 @@ from coremltools.converters.mil.mil.input_type import InputSpec, TensorInputType from coremltools.converters.mil.mil.operation import Operation from coremltools.converters.mil.mil.ops.defs._op_reqs import register_op -from coremltools.converters.mil.mil.ops.defs._utils import restore_elements_from_packed_bits from coremltools.converters.mil.mil.ops.defs.iOS16 import _IOS16_TARGET @@ -289,8 +288,11 @@ def materialized_val_inference(self): @staticmethod def decompress(lut, indices, shape): + # Import here to avoid circular import. + from coremltools.optimize.coreml import _utils as optimize_utils + nbits = np.log2(lut.size).astype(np.int32) - indices = restore_elements_from_packed_bits(indices, nbits, np.prod(shape)) + indices = optimize_utils.restore_elements_from_packed_bits(indices, nbits, np.prod(shape)) flatten_val = lut[indices] return flatten_val.reshape(shape) diff --git a/coremltools/converters/mil/mil/ops/defs/iOS18/__init__.py b/coremltools/converters/mil/mil/ops/defs/iOS18/__init__.py new file mode 100644 index 000000000..25d994ebb --- /dev/null +++ b/coremltools/converters/mil/mil/ops/defs/iOS18/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +from coremltools.converters.mil._deployment_compatibility import AvailableTarget as target +from coremltools.converters.mil.mil.ops.registry import SSAOpRegistry + +# Ensure op registrations recognize the new opset. +_IOS18_TARGET = target.iOS18 + +from .compression import ( + constexpr_blockwise_shift_scale, + constexpr_lut_to_dense, + constexpr_lut_to_sparse, + constexpr_sparse_blockwise_shift_scale, + constexpr_sparse_to_dense, +) +from .recurrent import gru +from .states import read_state +from .tensor_transformation import slice_update +from .transformers import scaled_dot_product_attention diff --git a/coremltools/converters/mil/mil/ops/defs/iOS18/compression.py b/coremltools/converters/mil/mil/ops/defs/iOS18/compression.py new file mode 100644 index 000000000..663fce2d4 --- /dev/null +++ b/coremltools/converters/mil/mil/ops/defs/iOS18/compression.py @@ -0,0 +1,791 @@ +# Copyright (c) 2023, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +from typing import List, Optional + +import numpy as np + +from coremltools.converters.mil.mil import types +from coremltools.converters.mil.mil.input_type import InputSpec, TensorInputType +from coremltools.converters.mil.mil.operation import Operation +from coremltools.converters.mil.mil.ops.defs._op_reqs import register_op +from coremltools.converters.mil.mil.ops.defs.iOS16.constexpr_ops import ( + constexpr_cast as _constexpr_cast_iOS16, +) +from coremltools.converters.mil.mil.ops.defs.iOS18 import _IOS18_TARGET +from coremltools.converters.mil.mil.var import Var + + +@register_op(opset_version=_IOS18_TARGET) +class constexpr_blockwise_shift_scale(Operation): + """ + A compile-time operation that returns a constant output value upon dequantizing its constant inputs. + + It's similar to iOS 16 :py:class:`~.iOS16.constexpr_ops.constexpr_affine_dequantize`, but supports + block-wise quantization for int4 and int8. + + Although all parameters of this op are constants, this op is not constant-folded to a single + const op at the time of model serialization. The unquantized output will be decompressed later, + based on the implementation detail (either at model load time or runtime). + + Generic expression: output = scale * (data - offset) + + Algorithm: + Assuming Rank 3 scenario: + output_data[i, j, k] = scale[i0, j0, k0] * (data[i, j, k] - offset[i0, j0, k0]) + where + i0 = floor(i/block_size[0]), + j0 = floor(j/block_size[1]), + k0 = floor(k/block_size[2]) + The block size is implied by block_size[m] = data.shape[m] / scale.shape[m] + + Constraints: + - All tensors: scale, data, offset and output have same rank. + - Inputs: scale and offset (if provided) have same shape. + - Output shape is same as the shape of input argument: `data`. + - Number of scales along each dimension should be a factor of corresponding dimension size of + `data`. That is, block_size[i] should be an integer where block_size[i] = data.shape[i] / scale.shape[i] + + Parameters + ---------- + data: const tensor (Required) + + scale: const tensor (Required) + + offset: const tensor (Optional) + * If provided, must have the same shape as the ``scale``. + * If dtype is not fp16 or fp32, it must be the same as SrcT. + + Returns + ------- + const tensor + + Attributes + ---------- + SrcT: int4, uint4, int8, uint8, fp16, fp32 + DstT: fp16, fp32 + OffsetT: int4, uint4, int8, uint8, fp16, fp32 + """ + + input_spec = InputSpec( + data=TensorInputType(const=True, type_domain="SrcT"), + scale=TensorInputType(const=True, type_domain="DstT"), + offset=TensorInputType(const=True, optional=True, type_domain="OffsetT"), + ) + + type_domains = { + "SrcT": (types.int4, types.uint4, types.int8, types.uint8, types.fp16, types.fp32), + "DstT": (types.fp16, types.fp32), + "OffsetT": (types.int4, types.uint4, types.int8, types.uint8, types.fp16, types.fp32), + } + + @staticmethod + def _validate_shift_scale_inputs( + data_shape: List[int], data_dtype: types, scale: Var, offset: Var + ): + data_rank = len(data_shape) + if data_rank != scale.rank: + raise ValueError( + f"Parameter 'data' and 'scale' need to have the same rank, but got {data_rank} vs {scale.rank}." + ) + if data_rank < 1: + raise ValueError("Parameter 'data' needs to have at least rank 1, but got scalar.") + for rank_idx in range(data_rank): + data_dim = data_shape[rank_idx] + scale_dim = scale.shape[rank_idx] + if data_dim % scale_dim != 0: + raise ValueError( + f"Number of scales along each dimension should be a factor of " + f"corresponding dimension size of 'data'. However, at dim " + f"{rank_idx}, the 'data' has {data_dim} while 'scale' has {scale_dim}." + ) + + if offset is not None: + if offset.shape != scale.shape: + raise ValueError( + "Invalid parameter 'offset'; the shape of 'offset' should match the shape of " + f"'scale', but got ({offset.shape}) vs ({scale.shape})." + ) + if not types.is_float(offset.dtype) and offset.dtype != data_dtype: + raise ValueError( + "Invalid parameter 'offset'; the dtype of 'offset' should match the dtype of " + f"'data', but got ({types.builtin_to_string(offset.dtype)}) vs " + f"({types.builtin_to_string(data_dtype)})." + ) + + def _validate_inputs(self): + self._validate_shift_scale_inputs(self.data.shape, self.data.dtype, self.scale, self.offset) + + def type_inference(self): + self._validate_inputs() + return types.tensor(self.scale.dtype, self.data.shape) + + def materialized_val_inference(self): + return self.decompress( + self.data.val, + self.scale.val, + None if self.offset is None else self.offset.val, + ) + + @staticmethod + def decompress( + data: np.ndarray, + scale: np.ndarray, + offset: Optional[np.ndarray], + ): + # Adjust dtype to avoid overflow in the quantized dtype. + data = data.astype(scale.dtype) + + # Interleaved repeat scale and offset to make it match the shape of data. + block_sizes = [ + data_shape // scale_shape for (data_shape, scale_shape) in zip(data.shape, scale.shape) + ] + for axis, block_size in enumerate(block_sizes): + if block_size > 1: + scale = np.repeat(scale, block_size, axis) + if offset is not None: + offset = np.repeat(offset, block_size, axis) + + if offset is not None: + data = data - offset + data = scale * data + + return data + + +@register_op(opset_version=_IOS18_TARGET) +class constexpr_lut_to_dense(Operation): + """ + A compile-time operation that returns a constant output value upon dequantizing its constant inputs. + + This operator is used to store constant weights in lookup tables format (aka palettized weights). + It's similar to iOS 16 :py:class:`~.iOS16.constexpr_ops.constexpr_lut_to_dense`, but supports + block-wise / vector palettization. + + LUT's rank is K + 2, where K is the rank of indices. + Each dimension of LUT's first K dimensions should be divisible by each corresponding dimension + of the decompressed tensor. + e.g., when indices_shape = [2, 3, 4], lut_shape[:3] = [1, 1, 2], it means that there are two + lookup tables over the last axis. And each of them have their own LUT values. + See Case 1 below for details. + + VECTOR_SIZE is added to support vector palettization. + - When VECTOR_SIZE is 1, it is scalar palettization. + - When VECTOR_SIZE is larger than 1, it retrieves a vector instead of a single value from the + lookup table, and fill the result continuously. + The vector_axis is used to define which axis the vectored elements in the lookup table be filled + across the output tensor. vector_axis is only optional if VECTOR_SIZE is 1. + As a result: + output_shape[i] = indices_shape[i] , i != vector_axis + output_shape[i] = indices_shape[i] * VECTOR_SIZE, i == vector_axis + See Case 2 below for details. + + Examples: + + Case 1: per-group scalar palettization: + e.g.: + - indices = tensor>([2, 3, 3, 0, 1, 0, 3, 0, 2, 1, 0, 3]) + - lut = tensor([1.0, 5.0, 9.0, 13.0, 2.0, 10.0, 18.0, 26.0]) + + It is effectively a 2-group 2-bit scalar palettization. + The output shape would be [6, 2], which is the same as the indices shape. + The output tensor values are: + [[lut0[2]->9.0, lut0[3]->13.0], + [lut0[3]->13.0, lut0[0]->1.0], + [lut0[1]->5.0, lut0[0]->1.0], + [lut1[3]->26.0, lut1[0]->2.0], + [lut1[2]->18.0, lut1[1]->10.0], + [lut1[0]->2.0, lut1[3]->26.0]] + where lut0 is the first lookup table (lut[0, :, :, :]) and lut1 is the second lookup table. + + Case 2: per-tensor vector palettization: + e.g.: + - indices = tensor>. + The indices values are: + [ + [ + [0, 0], + [1, 0] + ], + [ + [1, 1], + [0, 0] + ] + ] + - lut = tensor([a0, a1, a2, + b0, b1, b2]) + which means the two centroids are [a1, a2, a3] and [b1, b2, b3]. + + Case 2.1: vector_axis = 1 + It is effectively a 1-bit vector palettization. + The output shape would be [2, 2*3, 2], where each index in the indices would be effectively replaced with + the 3 elements in the vector over the 1st dimension to construct the output tensor. + The output values are: + [ + [ + [a0, a0], + [a1, a1], + [a2, a2], + [b0, a0], + [b1, a1], + [b2, a2], + ], + [ + [b0, b0], + [b1, b1], + [b2, b2], + [a0, a0], + [a1, a1], + [a2, a2], + ] + ] + + Case 2.2: vector_axis = 2 + The output shape would be [2, 2, 2*3], where each index in the indices would be effectively replaced with + the 3 elements in the vector over the last dimension to construct the output tensor. + The output values are: + [ + [ + [a0, a1, a2, a0, a1, a2], + [b0, b1, b2, a0, a1, a2], + ], + [ + [b0, b1, b2, b0, b1, b2], + [a0, a1, a2, a0, a1, a2], + ] + ] + + Parameters + ---------- + indices: const tensor (Required) + + lut: const tensor (Required) + * NUM_PALETTES needs to be 2^nbits where nbits is indicated by IndicesT. + + vector_axis: const tensor (Optional) + * vector_axis can be optional if VECTOR_SIZE is 1. + + Returns + ------- + const tensor + * output_shape = indices_shape * [1..1, VECTOR_SIZE, 1..1] (all 1 but VECTOR_SIZE at vector_axis dimension). + + Attributes + ---------- + IndicesT: uint1, uint2, uint3, uint4, uint6, uint8 + T: uint8, int8, fp16, fp32 + """ + + input_spec = InputSpec( + indices=TensorInputType(const=True, type_domain="IndicesT"), + lut=TensorInputType(const=True, type_domain="T"), + vector_axis=TensorInputType(const=True, optional=True, type_domain=types.int32), + ) + + type_domains = { + "IndicesT": (types.uint1, types.uint2, types.uint3, types.uint4, types.uint6, types.uint8), + "T": (types.int8, types.uint8, types.fp16, types.fp32), + } + + @staticmethod + def _validate_lut_inputs( + indices_shape: List[int], indices_dtype: types, lut_shape: List[int], vector_axis: Var + ): + indices_rank = len(indices_shape) + lut_rank = len(lut_shape) + + if indices_rank < 1: + raise ValueError("Parameter 'indices' needs to have at least rank 1, but got scalar.") + + if lut_rank != indices_rank + 2: + raise ValueError( + f"Parameter 'lut' need to have 2 more dim than 'indices', but got " + f"{lut_rank}-rank 'lut' and {indices_rank}-rank 'indices'." + ) + + for rank_idx in range(indices_rank): + indices_dim = indices_shape[rank_idx] + lut_dim = lut_shape[rank_idx] + if indices_dim % lut_dim != 0: + raise ValueError( + f"Each dimension of 'indices' should be divisible by each corresponding " + f"dimension of the 'lut'. However, at dim {rank_idx}, the 'indices' has " + f"{indices_dim} while 'lut' has {lut_dim}." + ) + + nbits = indices_dtype.get_bitwidth() + if lut_shape[-2] != 2**nbits: + raise ValueError( + "Invalid parameter 'lut'; the second last dim should have size " + f"2^nbits, where nbits is {nbits}, but got {lut_shape[-2]}." + ) + + if vector_axis is not None: + if vector_axis.rank > 0: + raise ValueError( + "Invalid parameter 'vector_axis'; It should be a scalar, but got " "a tensor." + ) + if not -indices_rank <= vector_axis.val < indices_rank: + raise ValueError( + f"Invalid parameter 'vector_axis'; The valid range is between " + f"{-indices_rank} and {indices_rank}, but got {vector_axis.val}." + ) + else: + if lut_shape[-1] > 1: + raise ValueError( + "When lut's last dim (VECTOR_SIZE) > 1, the parameter " + "'vector_axis' need to be provided." + ) + + def _validate_inputs(self): + self._validate_lut_inputs( + self.indices.shape, self.indices.dtype, self.lut.shape, self.vector_axis + ) + + def type_inference(self): + self._validate_inputs() + output_shape = self.indices.shape + vector_size = self.lut.shape[-1] + if vector_size > 1: + output_shape = list(output_shape) + output_shape[self.vector_axis.val] *= vector_size + output_shape = tuple(output_shape) + return types.tensor(self.lut.dtype, output_shape) + + def materialized_val_inference(self): + return self.decompress( + self.indices.val, + self.lut.val, + None if self.vector_axis is None else self.vector_axis.val, + ) + + @staticmethod + def decompress( + indices: np.ndarray, + lut: np.ndarray, + vector_axis: Optional[np.generic], + ): + num_palettes = lut.shape[-2] + vector_size = lut.shape[-1] + original_lut_shape = lut.shape + block_size = [indices.shape[idx] // lut.shape[idx] for idx in range(len(indices.shape))] + + if vector_axis is not None and vector_axis < 0: + vector_axis += len(indices.shape) + + lut = lut.reshape(-1, num_palettes, vector_size) + decompressed_res = indices.astype(lut.dtype) + if vector_size > 1: + # Tile the vector_axis to make room for the vector retrieved from lut. + decompressed_res = np.repeat(decompressed_res, vector_size, axis=vector_axis) + else: + lut = np.squeeze(lut, axis=-1) + + # TODO (rdar://115061946): Vectorize the computation. + for table_idx in range(lut.shape[0]): + # Get the corresponding idx in indices for the current table. + # For example, if table coord is (1, 3), the corresponding indices should be + # [1*block_size[0] : 2*block_size[0], 3*block_size[1], 4*block_size[1]]. + original_table_coord = np.unravel_index(table_idx, original_lut_shape[:-2]) + slice_idxes = tuple( + slice(coord * block_size[idx], (coord + 1) * block_size[idx]) + for idx, coord in enumerate(original_table_coord) + ) + unquantized_values = lut[table_idx][indices[slice_idxes]] + if vector_size > 1: + if vector_axis is None: + raise ValueError("vector_axis must be provided for vector lut.") + # Merge the vector dim into the decompressed values (flatten the vector). + unquantized_values = np.swapaxes(unquantized_values, vector_axis, -2) + unquantized_values = unquantized_values.reshape( + unquantized_values.shape[:-2] + (-1,) + ) + unquantized_values = np.swapaxes(unquantized_values, vector_axis, -1) + # Resize the slice to make room for the merged vector dequantized values. + slice_idxes = list(slice_idxes) + resized_slice = slice( + slice_idxes[vector_axis].start * vector_size, + slice_idxes[vector_axis].stop * vector_size, + slice_idxes[vector_axis].step, + ) + slice_idxes[vector_axis] = resized_slice + decompressed_res[tuple(slice_idxes)] = unquantized_values + + return decompressed_res + + +@register_op(opset_version=_IOS18_TARGET) +class constexpr_sparse_to_dense(Operation): + """ + A compile-time operation that returns a constant output value upon de-sparsification of its constant inputs. + + The differences from iOS16 :py:class:`~.iOS16.constexpr_ops.constexpr_sparse_to_dense` are: + - In iOS16, the mask parameter is 'const tensor', which is a flat tensor with length + M, so it requires a parameter `shape` to determine the output shape. + In iOS18, we use uint1 (0 or 1) to represent bitmask, which packs the bitmask data and costs + the same memory as the uint8 mask in iOS16, but can explicitly tell the tensor shape. We use + uint1 instead of bool because bool in MIL uses uint8 as the storage dtype, which costs 8x + memory compared to uint1. + - Support more dtypes (int4 and uint4) for the input/output data. + + Parameters + ---------- + nonzero_data: const tensor (Required) + + mask: const tensor (Required) + + Returns + ------- + const tensor + + Attributes + ---------- + T: int4, uint4, int8, uint8, fp16, fp32 + """ + + input_spec = InputSpec( + nonzero_data=TensorInputType(const=True, type_domain="T"), + mask=TensorInputType(const=True, type_domain=types.uint1), + ) + + type_domains = {"T": (types.int4, types.uint4, types.int8, types.uint8, types.fp16, types.fp32)} + + @staticmethod + def decompress(nonzero_data: np.ndarray, mask: np.ndarray) -> np.ndarray: + decompressed_val = np.zeros_like(mask, dtype=nonzero_data.dtype) + decompressed_val[mask != 0] = nonzero_data + return decompressed_val + + @staticmethod + def _validate_sparse_inputs(nonzero_data: Var, mask: Var): + if nonzero_data.rank != 1: + raise ValueError( + f"Parameter nonzero_data needs to have rank 1, but got {nonzero_data.rank}" + ) + if mask.val is not None and np.count_nonzero(mask.val) != nonzero_data.shape[0]: + raise AssertionError( + "Number of 1s in mask not match number of elements in parameter nonzero_data" + ) + + def type_inference(self): + self._validate_sparse_inputs(self.nonzero_data, self.mask) + return types.tensor(self.nonzero_data.dtype, self.mask.shape) + + def materialized_val_inference(self): + nonzero_data = self.nonzero_data.val + mask = self.mask.val + if nonzero_data is None and self.nonzero_data.op.op_type.startswith("constexpr_"): + nonzero_data = self.nonzero_data.op.materialized_val_inference() + if isinstance(nonzero_data, tuple) and len(nonzero_data) > 0: + # For sparse constexpr ops they have two outputs, one for mask and one for val. + nonzero_data = nonzero_data[1] + if mask is None and self.mask.op.op_type.startswith("constexpr_"): + mask = self.mask.op.materialized_val_inference() + if isinstance(mask, tuple) and len(mask) > 0: + mask = mask[0] + return self.decompress(nonzero_data, mask) + + +@register_op(opset_version=_IOS18_TARGET) +class constexpr_lut_to_sparse(Operation): + """ + A compile-time operation that returns a constant output value upon de-palettizing its constant inputs. + + This op is a sparse-to-sparse op to support `constexpr_lut_to_dense` on sparse data, where the + de-palettization is only applied on the nonzero data. Usually it would be followed by a + `constexpr_sparse_to_dense` op to get the dense tensor. So, parameters of this op are similar to + `constexpr_sparse_to_dense` and `constexpr_lut_to_dense`. For detailed descriptions + about its parameters, please refer to iOS 18 :py:class:`~.iOS18.constexpr_ops.constexpr_sparse_to_dense` + and :py:class:`~.iOS18.constexpr_ops.constexpr_lut_to_dense`. + + This op has two outputs: + 1. the mask of the de-palettized nonzero_data. + 2. the de-palettized nonzero_data. + + Parameters + ---------- + indices_mask: const tensor (Required) + + indices_nonzero_data: const tensor (Required) + + lut: const tensor (Required) + * NUM_PALETTES needs to be 2^nbits where nbits is indicated by IndicesT. + + vector_axis: const tensor (Optional) + * vector_axis can be optional if VECTOR_SIZE is 1. + + Returns + ------- + const tensor + * the mask of the de-palettized nonzero_data. + For scalar palettization, it's the same as the input indices_mask. + For vector palettization, it's expanded of the indices_mask over axis=vector_axis. + const tensor + * the de-palettized nonzero_data. + For scalar palettization, VD=D (same size as indices_nonzero_data). + For vector palettization, VD=VECTOR_SIZE * D (each entry is expanded by a vector). + + Attributes + ---------- + IndicesT: uint1, uint2, uint3, uint4, uint6, uint8 + T: uint8, int8, fp16, fp32 + + Examples + ---------- + Assume we have the following inputs: + indices_mask = [[1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 1], + [0, 1, 1, 0, 1, 0], + [0, 0, 0, 1, 0, 0]] + indices_nonzero_data = [0, 1, 1, 0, 1, 1, 0, 0, 1] + + Notice that: + - The uint1 in `indices_mask` and `indices_nonzero_data` has different meanings. For + `indices_mask` the dtype is always uint1 to represent bit mask. For `indices_nonzero_data` + the uint1 means the LUT only has two entries, so only 1 bit is needed to represent indices. + - The 0 in `indices_mask` and `indices_nonzero_data` has different meanings. For + `indices_mask` the 0 means empty entry in sparse representation. For `indices_nonzero_data` + the 0 means index 0 in LUT. + + With the given indices_mask and indices_nonzero_data, an example for "Scalar Palettization": + lut = [2.0, 3.0] (indices-to-values mapping is {0: 2.0, 1: 3.0}) + + The sparse indices in the dense layout would look like: + 0 1 . . . . + 1 0 . . . 1 + . 1 0 . 0 . + . . . 1 . . + (here "." means spare elements in sparse representation) + + When we apply per-tensor de-palettization with this sparse indices, the `indices_nonzero_data` + is used to read the values from the LUT as in the dense layout. The output sparse tensor in + the dense layout would be: + 2.0 3.0 . . . . + 3.0 2.0 . . . 3.0 + . 3.0 2.0 . 2.0 . + . . . 3.0 . . + The first output would be the same as the indices_mask. + The second output would be [2.0, 3.0, 3.0, 2.0, 3.0, 3.0, 2.0, 2.0, 3.0] + + With the given indices_mask and indices_nonzero_data, an example for "Vector Palettization": + lut = [ + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 1], + [1, 1, 0, 0, 0, 1], + [0, 1, 1, 0, 1, 0], + [0, 1, 1, 0, 1, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 0, 0], + ] + The second output in the dense layout would be: + 2.0 3.0 . . . . + 2.0 3.0 . . . . + 3.0 2.0 . . . 3.0 + 3.0 2.0 . . . 3.0 + . 3.0 2.0 . 2.0 . + . 3.0 2.0 2.0 . + . . . 3.0 . . + . . . 3.0 . . + It is created by fetching the vector entry from the lut for every bit 1 in the data_mask, + and filling the vector over axis=0. + + Those two outputs of this op could be passed as inputs to a following `sparse_to_dense` op + in order to recover the dense weights. + """ + + input_spec = InputSpec( + indices_mask=TensorInputType(const=True, type_domain=types.uint1), + indices_nonzero_data=TensorInputType(const=True, type_domain="IndicesT"), + lut=TensorInputType(const=True, type_domain="T"), + vector_axis=TensorInputType(const=True, optional=True, type_domain=types.int32), + ) + + type_domains = { + "IndicesT": (types.uint1, types.uint2, types.uint3, types.uint4, types.uint6, types.uint8), + "T": (types.int8, types.uint8, types.fp16, types.fp32), + } + + def _validate_inputs(self): + constexpr_sparse_to_dense._validate_sparse_inputs( + self.indices_nonzero_data, self.indices_mask + ) + constexpr_lut_to_dense._validate_lut_inputs( + self.indices_mask.shape, + self.indices_nonzero_data.dtype, + self.lut.shape, + self.vector_axis, + ) + + def type_inference(self): + self._validate_inputs() + output_mask_shape = self.indices_mask.shape + output_nonzero_data_shape = self.indices_nonzero_data.shape + vector_size = self.lut.shape[-1] + if vector_size > 1: + output_mask_shape = list(output_mask_shape) + output_mask_shape[self.vector_axis.val] *= vector_size + output_mask_shape = tuple(output_mask_shape) + output_nonzero_data_shape = tuple( + [dim * vector_size for dim in output_nonzero_data_shape] + ) + + output_mask_type = types.tensor(self.indices_mask.dtype, output_mask_shape) + output_nonzero_data_type = types.tensor(self.lut.dtype, output_nonzero_data_shape) + return output_mask_type, output_nonzero_data_type + + @staticmethod + def decompress( + indices_mask: np.ndarray, + indices_nonzero_data: np.ndarray, + lut: np.ndarray, + vector_axis: Optional[np.generic], + ): + indices = constexpr_sparse_to_dense.decompress(indices_nonzero_data, indices_mask) + output_nonzero_data = constexpr_lut_to_dense.decompress(indices, lut, vector_axis) + output_mask = indices_mask + if vector_axis is not None: + vector_size = lut.shape[-1] + output_mask = np.repeat(output_mask, vector_size, axis=vector_axis) + output_nonzero_data = output_nonzero_data[output_mask != 0].flatten() + + return output_mask, output_nonzero_data + + def materialized_val_inference(self): + vector_axis = self.vector_axis.val if self.vector_axis is not None else None + return self.decompress( + self.indices_mask.val, self.indices_nonzero_data.val, self.lut.val, vector_axis + ) + + +@register_op(opset_version=_IOS18_TARGET) +class constexpr_sparse_blockwise_shift_scale(Operation): + """ + A compile-time operation that returns a constant output value upon de-quantize (shift-scale) its + constant inputs. + This op is a sparse-to-sparse op to support `constexpr_blockwise_shift_scale` on sparse data, + where the de-quantization is only applied on the nonzero data. Usually it would be followed by a + `constexpr_sparse_to_dense` op to get the dense tensor. So, parameters of this op are similar to + `constexpr_sparse_to_dense` and `constexpr_blockwise_shift_scale`. For detailed descriptions + about its parameters, please refer to iOS 18 :py:class:`~.iOS18.constexpr_ops.constexpr_sparse_to_dense` + and :py:class:`~.iOS18.constexpr_ops.constexpr_blockwise_shift_scale`. + + This op has two outputs: + 1. the mask of the de-quantized nonzero_data. + 2. the de-quantized nonzero_data. + + Parameters + ------- + data_mask: const tensor (Required) + + nonzero_data: const tensor (Required) + + scale: const tensor (Required) + + offset: const tensor (Optional) + * If provided, must have the same shape as the ``scale``. + + Returns + ------- + const tensor + * the mask of the shift-scaled nonzero_data. + const tensor + * the shift-scaled nonzero_data. + + Attributes + ------- + SrcT: int4, uint4, int8, uint8, fp16, fp32 + DstT: fp16, fp32 + OffsetT: int4, uint4, int8, uint8, fp16, fp32 + + Examples + ------- + For example: + data_mask = [[1, 1, 0, 0], [1, 1, 1, 0], [0, 0, 1, 1], [1, 1, 0, 0]] + nonzero_data = [10, 11, 3, 4, 5, 6, 7, 8, 9] + The sparse tensor in the dense layout would look like: + 10 11 . . + 3 4 5 . + . . 6 7 + 8 9 . . + + When we apply per-channel de-quantization on this sparse tensor, where: + scale = [[0.1, 0.2, 0.3, 0.4]] + offset = [[1, 2, 3, 4]] + The input `nonzero_data` would be dequantized per-column as in the dense layout, and the + output sparse tensor in the dense layout would be: + (10-1)*0.1 (11-2)*0.2 . . + (10-1)*0.1 (11-2)*0.2 . . + (3-1)*0.1 (4-2)*0.2 (5-3)*0.3 . + . . (6-3)*0.3 (7-4)*0.4 + (8-1)*0.1 (9-2)*0.2 . . + + The first output would be the same as the `data_mask`, + The second output would be [0.9, 1.8, 0.2, 0.4, 0.6, 0.9, 1.2, 0.7, 1.4]. + The two outputs could be passed as inputs to the following `sparse_to_dense` op in order to + get the dense weights. + """ + + input_spec = InputSpec( + data_mask=TensorInputType(const=True, type_domain=types.uint1), + nonzero_data=TensorInputType(const=True, type_domain="SrcT"), + scale=TensorInputType(const=True, type_domain="DstT"), + offset=TensorInputType(const=True, optional=True, type_domain="OffsetT"), + ) + + type_domains = { + "SrcT": (types.int4, types.uint4, types.int8, types.uint8, types.fp16, types.fp32), + "DstT": (types.fp16, types.fp32), + "OffsetT": (types.int4, types.uint4, types.int8, types.uint8, types.fp16, types.fp32), + } + + def _validate_inputs(self): + constexpr_sparse_to_dense._validate_sparse_inputs(self.nonzero_data, self.data_mask) + constexpr_blockwise_shift_scale._validate_shift_scale_inputs( + self.data_mask.shape, self.nonzero_data.dtype, self.scale, self.offset + ) + + def type_inference(self): + self._validate_inputs() + output_mask_shape = self.data_mask.shape + output_nonzero_data_shape = self.nonzero_data.shape + output_mask_type = types.tensor(self.data_mask.dtype, output_mask_shape) + output_nonzero_data_type = types.tensor(self.scale.dtype, output_nonzero_data_shape) + return output_mask_type, output_nonzero_data_type + + @staticmethod + def decompress( + data_mask: np.ndarray, + nonzero_data: np.ndarray, + scale: np.ndarray, + offset: Optional[np.ndarray], + ): + data = constexpr_sparse_to_dense.decompress(nonzero_data, data_mask) + dequantized_data = constexpr_blockwise_shift_scale.decompress(data, scale, offset) + output_nonzero_data = dequantized_data[data_mask != 0].flatten() + return data_mask, output_nonzero_data + + def materialized_val_inference(self): + offset = self.offset.val if self.offset is not None else None + return self.decompress(self.data_mask.val, self.nonzero_data.val, self.scale.val, offset) + + +@register_op(opset_version=_IOS18_TARGET) +class constexpr_cast(_constexpr_cast_iOS16): + """ + A compile-time operation that returns a constant output value upon casting its constant input. + + The only difference between this version and the iOS 16 :py:class:`~.iOS16.constexpr_ops.constexpr_cast` is + the parameters are treated as inputs, instead of attributes in the MIL backend framework. + """ + + input_spec = InputSpec( + source_val=TensorInputType(const=True, type_domain=types.fp16), + output_dtype=TensorInputType(const=True, type_domain=types.str), + ) diff --git a/coremltools/converters/mil/mil/ops/defs/iOS18/recurrent.py b/coremltools/converters/mil/mil/ops/defs/iOS18/recurrent.py new file mode 100644 index 000000000..0aecfe369 --- /dev/null +++ b/coremltools/converters/mil/mil/ops/defs/iOS18/recurrent.py @@ -0,0 +1,35 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + + +from coremltools.converters.mil.mil import types +from coremltools.converters.mil.mil.input_type import InputSpec, TensorInputType +from coremltools.converters.mil.mil.ops.defs._op_reqs import register_op +from coremltools.converters.mil.mil.ops.defs.iOS17.recurrent import gru as _gru_iOS17 +from coremltools.converters.mil.mil.ops.defs.iOS18 import _IOS18_TARGET + + +@register_op(opset_version=_IOS18_TARGET) +class gru(_gru_iOS17): + """ + Gated Recurrent Unit (GRU) + + The only difference between this version and the iOS 17 :py:class:`~.iOS17.recurrent.gru` is + the reset_after parameter. This parameter is optional and defaults to False. When True, the + reset gate is applied before the elementwise matrix multiplication. + """ + input_spec = InputSpec( + x=TensorInputType(type_domain="T"), + initial_h=TensorInputType(type_domain="T"), + weight_ih=TensorInputType(const=True, type_domain="T"), + weight_hh=TensorInputType(const=True, type_domain="T"), + bias=TensorInputType(const=True, optional=True, type_domain="T"), + direction=TensorInputType(const=True, optional=True, type_domain=types.str), + output_sequence=TensorInputType(const=True, optional=True, type_domain=types.bool), + recurrent_activation=TensorInputType(const=True, optional=True, type_domain=types.str), + activation=TensorInputType(const=True, optional=True, type_domain=types.str), + reset_after=TensorInputType(const=True, optional=True, type_domain=types.bool), + input_bias=TensorInputType(const=True, optional=True, type_domain="T"), + ) diff --git a/coremltools/converters/mil/mil/ops/defs/iOS18/states.py b/coremltools/converters/mil/mil/ops/defs/iOS18/states.py new file mode 100644 index 000000000..8ec3cc7a6 --- /dev/null +++ b/coremltools/converters/mil/mil/ops/defs/iOS18/states.py @@ -0,0 +1,43 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +from coremltools.converters.mil.mil import types +from coremltools.converters.mil.mil.input_type import InputSpec, StateInputType +from coremltools.converters.mil.mil.operation import Operation +from coremltools.converters.mil.mil.ops.defs._op_reqs import register_op +from coremltools.converters.mil.mil.ops.defs.iOS18 import _IOS18_TARGET + + +@register_op(opset_version=_IOS18_TARGET) +class read_state(Operation): + """ + Read a state, copy its content into a new variable, and return the variable. + The type of the output variable depends on the type that is wrapped inside the state, + which could be ``types.tensor``. + + Parameters + ---------- + input: state (Required) + + Returns + ------- + ST + + Attributes + ---------- + ST: tensor + """ + + input_spec = InputSpec( + input=StateInputType(), + ) + + def type_inference(self): + sym_type = self.input.sym_type.wrapped_type() + if not types.is_tensor(sym_type): + raise ValueError( + f"State only supports wrapped type of types.tensor. Got {sym_type.__type_info__()}." + ) + return sym_type diff --git a/coremltools/converters/mil/mil/ops/defs/iOS18/tensor_transformation.py b/coremltools/converters/mil/mil/ops/defs/iOS18/tensor_transformation.py new file mode 100644 index 000000000..96d7d0c0c --- /dev/null +++ b/coremltools/converters/mil/mil/ops/defs/iOS18/tensor_transformation.py @@ -0,0 +1,152 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import numpy as np + +from coremltools.converters.mil.mil import Operation, types +from coremltools.converters.mil.mil.input_type import DefaultInputs, InputSpec, TensorInputType +from coremltools.converters.mil.mil.operation import Operation +from coremltools.converters.mil.mil.ops.defs._op_reqs import register_op +from coremltools.converters.mil.mil.ops.defs._utils import ( + get_param_val, + solve_slice_by_index_shape, + solve_slice_by_index_slice, +) +from coremltools.converters.mil.mil.ops.defs.iOS18 import _IOS18_TARGET +from coremltools.converters.mil.mil.types.symbolic import is_compatible_symbolic_vector + + +@register_op(opset_version=_IOS18_TARGET) +class slice_update(Operation): + """ + Update a custom slice of a source tensor with another tensor of + the same shape, as dictated by the slice. + + For example, if you have a tensor ``x``, this method produces the following:: + + x[begin[0]: end[0]: stride[0], begin[1]: end[1]: stride[1], ...] = value + + The arguments defining the slice (``begin``, ``end``, ``stride``, ``masks``, and so on) should be + treated the same way as iOS15 :py:class:`~.iOS15.tensor_transformation.slice_by_index`. + + + Parameters + ---------- + x: tensor<*?, T> (Required) + * Input tensor. + update: tensor<\*K, T> (Required) + * Value tensor to be inserted. + * The shape of the update tensor must match the slicing result of the input data. + * rank-0 update is not supported. + begin: tensor<[rank], U> (Required) + * Starting index for the dimension of slicing. + end: tensor<[rank(x)], U> (Required) + * Ending index for the dimension of slicing. + stride: tensor<[rank(x)], U> (Optional) + * Default as all ``1``. + * Stride for the dimension of slicing. + begin_mask: tensor<[rank(x)], bool> (Optional) + * Default to all ``False``. + * If ``begin_mask[i]==True``, neglect ``begin[i]``, and set ``begin[i]`` to ``0``. + end_mask: tensor<[rank(x)], bool> (Optional) + * Default to all ``False``. + * If ``end_mask[i]==True``, neglect ``end[i]``, and set ``end[i]`` to ``x.shape[i]``. + squeeze_mask: tensor<[rank(x)], bool> (Optional) + * Default to all ``False``. + * If ``squeeze_mask[i]==True``, neglect ``end[i]``, and do the pure index at ``begin[i]``. + + Returns + ------- + tensor<\*?, T> + - Scalar or tensor. + + Attributes + ---------- + T: fp16, fp32, int8, int16, int32, uint8, uint16, bool + U: int8, int16, int32 + """ + + input_spec = InputSpec( + x=TensorInputType(type_domain="T"), + update=TensorInputType(type_domain="T"), + begin=TensorInputType(type_domain="U"), + end=TensorInputType(type_domain="U"), + stride=TensorInputType(const=True, optional=True, type_domain="U"), + begin_mask=TensorInputType(const=True, optional=True, type_domain=types.bool), + end_mask=TensorInputType(const=True, optional=True, type_domain=types.bool), + squeeze_mask=TensorInputType(const=True, optional=True, type_domain=types.bool), + ) + + type_domains = { + "T": ( + types.fp16, + types.fp32, + types.int8, + types.int16, + types.int32, + types.uint8, + types.uint16, + types.bool, + ), + "U": (types.int8, types.int16, types.int32), + } + + def default_inputs(self): + return DefaultInputs( + stride=None, + begin_mask=None, + end_mask=None, + squeeze_mask=None, + ) + + def type_inference(self): + # solve shape + ret_shape = solve_slice_by_index_shape( + self.x.shape, + self.begin.val, + self.end.val, + get_param_val(self.stride), + get_param_val(self.begin_mask), + get_param_val(self.end_mask), + get_param_val(self.squeeze_mask), + ) + + if not is_compatible_symbolic_vector(ret_shape, self.update.shape): + raise ValueError( + "The update tensor should have shape {}. Got {}".format( + ret_shape, self.update.shape + ) + ) + + if self.update.rank == 0: + # rdar://128221986 ([Feature][Slice_update] The backends is not supporting scalar update for the slice_update op) + raise ValueError(f"rank-0 'update' is not supported in 'slice_update' op {self.name}.") + + return self.x.sym_type + + def value_inference(self): + if ( + self.x.sym_val is None + or self.update.sym_val is None + or self.begin.val is None + or self.end.val is None + ): + return None + + # solve the data slices + slices = solve_slice_by_index_slice( + self.x.shape, + self.begin.val, + self.end.val, + get_param_val(self.stride), + get_param_val(self.begin_mask), + get_param_val(self.end_mask), + get_param_val(self.squeeze_mask), + ) + + # copy the data and do the inplace update + copy_x_val = np.copy(self.x.sym_val) + copy_x_val[slices] = np.reshape(self.update.sym_val, copy_x_val[slices].shape) + return copy_x_val diff --git a/coremltools/converters/mil/mil/ops/defs/iOS18/transformers.py b/coremltools/converters/mil/mil/ops/defs/iOS18/transformers.py new file mode 100644 index 000000000..1c7783def --- /dev/null +++ b/coremltools/converters/mil/mil/ops/defs/iOS18/transformers.py @@ -0,0 +1,166 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import numpy as np + +from coremltools.converters.mil.mil import types +from coremltools.converters.mil.mil.input_type import InputSpec, TensorInputType +from coremltools.converters.mil.mil.operation import Operation +from coremltools.converters.mil.mil.ops.defs._op_reqs import register_op +from coremltools.converters.mil.mil.ops.defs._utils import broadcast_shapes +from coremltools.converters.mil.mil.ops.defs.iOS18 import _IOS18_TARGET +from coremltools.converters.mil.mil.types.symbolic import any_symbolic, is_symbolic + + +@register_op(opset_version=_IOS18_TARGET) +class scaled_dot_product_attention(Operation): + """ + Source: `PyTorch scaled dot product attention `_. + Computes the scaled dot product attention on query, key, and value tensors, using an optional attention mask if passed. + In PyTorch, this is equivalent to:: + + attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf')) if attn_mask.dtype==torch.bool else attn_mask + attn_weight = torch.softmax((Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))) + attn_mask, dim=-1) + return attn_weight @ V + + Shape key: + - ``B`` = Batch size + - ``S`` = Source sequence length + - ``L`` = Target sequence length + - ``E`` = Query/Key embedding dimension + - ``EV`` = Value embedding dimension + + Numerical values can differ due to floating point fusion/accumulation between backends. + Note: We currently do not support the ``dropout_p`` and ``is_causal``. + + Mask can either be bool or float matching query, key, or value. For bool, it indicates + whether the element should take part in the attention. Floats are added to the attention score. + Mask shape must be broadcastable to ``[B, \*?, L, S]``. + + Parameters + ---------- + query: tensor<[B, \*?, L, E], T> (Required) + key: tensor<[B, \*?, S, E], T> (Required) + value: tensor<[B, \*?, S, EV], T> (Required) + attn_mask: tensor<[\*?, S], M> (Optional) + + Returns + ------- + tensor<[B, \*?, L, EV], T> + + Attributes + ---------- + T: fp16, fp32 + M: bool, fp16, fp32 + """ + + input_spec = InputSpec( + query=TensorInputType(type_domain="T"), + key=TensorInputType(type_domain="T"), + value=TensorInputType(type_domain="T"), + attn_mask=TensorInputType(optional=True, type_domain="M"), + ) + + type_domains = { + "T": (types.fp16, types.fp32), + "M": (types.bool, types.fp16, types.fp32), + } + + def _validate_inputs(self): + query_rank = self.query.rank + key_rank = self.key.rank + value_rank = self.value.rank + if query_rank != key_rank or query_rank != value_rank: + raise ValueError( + f"query, key, value must have a same rank, got\n" + f"* query rank = {query_rank}\n" + f"* key rank = {key_rank}\n" + f"* value rank = {value_rank}" + ) + if query_rank < 3: + raise ValueError( + f"query, key, value must have at lease rank 3 " + f"for batch, sequence length, embedding, got rank {query_rank}" + ) + + query_shape = self.query.shape + key_shape = self.key.shape + value_shape = self.value.shape + B_query = query_shape[:-2] + E_query = query_shape[-1] + B_key = key_shape[:-2] + S_key = key_shape[-2] + E_key = key_shape[-1] + B_value = value_shape[:-2] + S_value = value_shape[-2] + + batch_dims = [B_query, B_key, B_value] + batch_dims = [batch_dim for batch_dim in batch_dims if not any_symbolic(batch_dims)] + if len(set(batch_dims)) > 1: + raise ValueError( + "query, key, value must have a same batch dimension, got\n" + f"* query batch = {B_query}\n" + f"* key batch = {B_key}\n" + f"* value batch = {B_value}" + ) + if not is_symbolic(E_query) and not is_symbolic(E_key) and E_query != E_key: + raise ValueError( + "query and key must have a same embedding dimension, got\n" + f"* query embedding = {E_query}\n" + f"* key embedding = {E_key}" + ) + if not is_symbolic(S_key) and not is_symbolic(S_value) and S_key != S_value: + raise ValueError( + "key and value must have a same sequence length, got\n" + f"* key sequence = {S_key}\n" + f"* value sequence = {S_value}" + ) + + if self.attn_mask is not None: + mask_shape = self.attn_mask.shape + S_mask = mask_shape[-1] + if not is_symbolic(S_mask) and not is_symbolic(S_key) and S_mask != S_key: + raise ValueError( + "key and mask must have a same sequence length, got\n" + f"* key sequence = {S_key}\n" + f"* mask sequence = {S_mask}" + ) + # If shapes are inconsistent, then `broadcast_shapes` would raise exception + broadcast_shapes(query_shape[:-1], mask_shape[:-1]) + + def type_inference(self): + self._validate_inputs() + + shape = list(self.query.shape[:-1]) + [self.value.shape[-1]] + return types.tensor(self.query.dtype, shape) + + def value_inference(self): + query = self.query.val + key = self.key.val + value = self.value.val + if query is None or key is None or value is None: + return None + + float_mask = None + if self.attn_mask is not None and self.attn_mask.val is not None: + mask = self.attn_mask.val + if mask.dtype == bool: + float_mask = np.zeros(mask.shape) + float_mask[np.where(np.logical_not(mask))] = -np.inf + else: + float_mask = mask + + similarity = np.matmul(query, key.swapaxes(-2, -1)) / np.sqrt(query.shape[-1]) + if float_mask is not None: + similarity += float_mask + attention_weight = self.numpy_softmax_last_dim(similarity) + attention = np.matmul(attention_weight, value) + return attention + + @staticmethod + def numpy_softmax_last_dim(x: np.ndarray) -> np.ndarray: + exps = np.exp(x - np.max(x, axis=-1)[..., None]) + softmax = exps / np.sum(exps, axis=-1)[..., None] + return softmax diff --git a/coremltools/converters/mil/mil/ops/registry.py b/coremltools/converters/mil/mil/ops/registry.py index 244aaf166..a64feb0b7 100644 --- a/coremltools/converters/mil/mil/ops/registry.py +++ b/coremltools/converters/mil/mil/ops/registry.py @@ -58,6 +58,7 @@ class SSAOpRegistry: target.iOS15, target.iOS16, target.iOS17, + target.iOS18, ) core_ops = defaultdict(dict) dialect_ops = {} diff --git a/coremltools/converters/mil/mil/ops/tests/coreml_dialect/__init__.py b/coremltools/converters/mil/mil/ops/tests/coreml_dialect/__init__.py new file mode 100644 index 000000000..a097a9c9d --- /dev/null +++ b/coremltools/converters/mil/mil/ops/tests/coreml_dialect/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause diff --git a/coremltools/converters/mil/mil/ops/tests/coreml_dialect/test_coreml_dialect.py b/coremltools/converters/mil/mil/ops/tests/coreml_dialect/test_coreml_dialect.py new file mode 100644 index 000000000..ab0b4015a --- /dev/null +++ b/coremltools/converters/mil/mil/ops/tests/coreml_dialect/test_coreml_dialect.py @@ -0,0 +1,67 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import pytest + +from coremltools.converters.mil.mil import types +from coremltools.converters.mil.mil.builder import Builder as mb +from coremltools.converters.mil.mil.ops.defs.iOS18 import _IOS18_TARGET +from coremltools.converters.mil.testing_utils import get_op_types_in_program + + +class TestCoreMLUpdateState: + @staticmethod + def test_update_tensor_state_builder(): + @mb.program( + input_specs=[mb.StateTensorSpec((2, 3)), mb.TensorSpec((2, 3))], + opset_version=_IOS18_TARGET, + ) + def prog(x, value): + return mb.coreml_update_state(state=x, value=value) + + update_state_op = prog.find_ops("coreml_update_state")[0] + assert types.is_state(update_state_op.state._sym_type) + assert types.is_tensor(update_state_op.outputs[0]._sym_type) + + @staticmethod + def test_update_tensor_state_builder_invalid(): + # Update state with value of different shape + with pytest.raises( + ValueError, + match="State wrapped type tensor\[2,3,fp32\] not matched with the value's sym_type tensor\[3,2,fp32\]", + ): + + @mb.program( + input_specs=[mb.StateTensorSpec((2, 3)), mb.TensorSpec((3, 2))], + opset_version=_IOS18_TARGET, + ) + def prog(x, value): + return mb.coreml_update_state(state=x, value=value) + + # Update state with value of different dtype + with pytest.raises( + ValueError, + match="State wrapped type tensor\[2,3,fp32\] not matched with the value's sym_type tensor\[2,3,fp16\]", + ): + + @mb.program( + input_specs=[mb.StateTensorSpec((2, 3)), mb.TensorSpec((2, 3), dtype=types.fp16)], + opset_version=_IOS18_TARGET, + ) + def prog(x, value): + return mb.coreml_update_state(state=x, value=value) + + @staticmethod + def test_simple_stateful_model_builder(): + @mb.program( + input_specs=[mb.StateTensorSpec((2, 3)), mb.TensorSpec((2, 3))], + opset_version=_IOS18_TARGET, + ) + def prog(x, value): + read_val = mb.read_state(input=x) + add = mb.add(x=read_val, y=value) + return mb.coreml_update_state(state=x, value=add) + + assert get_op_types_in_program(prog) == ["read_state", "add", "coreml_update_state"] diff --git a/coremltools/converters/mil/mil/ops/tests/iOS14/test_control_flow.py b/coremltools/converters/mil/mil/ops/tests/iOS14/test_control_flow.py index bcb8278a0..c3e9ba75d 100644 --- a/coremltools/converters/mil/mil/ops/tests/iOS14/test_control_flow.py +++ b/coremltools/converters/mil/mil/ops/tests/iOS14/test_control_flow.py @@ -588,3 +588,28 @@ def prog(): else: target_dtype = dtype assert const_op.outputs[0].dtype == types.numpy_type_to_builtin_type(target_dtype) + + @pytest.mark.parametrize( + "compute_unit, backend, dtype_str", + itertools.product( + compute_units, backends, ("int4", "uint1", "uint2", "uint3", "uint4", "uint6") + ), + ) + def test_const_sub_byte_dtype(self, compute_unit, backend, dtype_str): + builtin_dtype = types.string_to_builtin(dtype_str) + upper_bound = types.type_mapping.builtin_to_range(builtin_dtype).high + original_data = np.random.randint(0, upper_bound + 1, (2, 3)) + np_dtype = types.nptype_from_builtin(builtin_dtype) + + @mb.program(input_specs=[], opset_version=backend.opset_version) + def prog(): + return mb.const(val=original_data.astype(np_dtype)) + + const_op = prog.functions["main"].find_ops(op_type="const")[0] + assert types.builtin_to_string(const_op.outputs[0].dtype) == dtype_str + expected_underlying_dtype = np.int8 if dtype_str.startswith("i") else np.uint8 + assert const_op.outputs[0].val.dtype == expected_underlying_dtype + assert const_op.outputs[0].val.dtype.metadata["true_dtype"] == types.string_to_builtin( + dtype_str + ) + np.testing.assert_equal(const_op.outputs[0].val, original_data) diff --git a/coremltools/converters/mil/mil/ops/tests/iOS14/test_conv.py b/coremltools/converters/mil/mil/ops/tests/iOS14/test_conv.py index 57010f313..c85e9fe3b 100644 --- a/coremltools/converters/mil/mil/ops/tests/iOS14/test_conv.py +++ b/coremltools/converters/mil/mil/ops/tests/iOS14/test_conv.py @@ -449,7 +449,9 @@ def test_builder_to_backend_stress( "symbolic": True, } ): - pytest.xfail("rdar://121954894: Conv2d starts to fail") + pytest.xfail( + "rdar://129121584: NN Conv Fail when Run Multiple Faulty Models at Same Time" + ) padding = config["padding"] DHWKdKhKw = config["DHWKdKhKw"] @@ -638,20 +640,6 @@ def test_builder_to_backend_stress_weights_input( conv_dim, config, ): - if ( - conv_dim == "conv2d" and - config == { - 'padding': (1, 1, 1), - 'DHWKdKhKw': (5, 5, 5, 2, 2, 2), - 'stride': (2, 2, 2), - 'dilation': (2, 1, 1), - 'has_bias': True, - 'groups': 1, - 'symbolic': True, - } - ): - pytest.xfail("rdar://121954894: Conv2d starts to fail") - padding = config["padding"] DHWKdKhKw = config["DHWKdKhKw"] stride = config["stride"] diff --git a/coremltools/converters/mil/mil/ops/tests/iOS14/test_image_resizing.py b/coremltools/converters/mil/mil/ops/tests/iOS14/test_image_resizing.py index 20ab1ac0d..04ff5ec65 100644 --- a/coremltools/converters/mil/mil/ops/tests/iOS14/test_image_resizing.py +++ b/coremltools/converters/mil/mil/ops/tests/iOS14/test_image_resizing.py @@ -395,12 +395,13 @@ def build(x): class TestCropResize: @mark_api_breaking(breaking_opset_version=ct.target.iOS17) @pytest.mark.parametrize( - "compute_unit, backend, is_symbolic", - itertools.product(compute_units, backends, [True, False]), + "compute_unit, backend, is_symbolic, mode", + itertools.product(compute_units, backends, [True, False], list(range(5))), ) - def test_builder_to_backend_smoke(self, compute_unit, backend, is_symbolic): + def test_builder_to_backend_smoke(self, compute_unit, backend, is_symbolic, mode): if backend.backend == "mlprogram" and compute_unit != ct.ComputeUnit.CPU_ONLY: pytest.xfail("rdar://97398582 (TestCropResize failing on mlprogram + GPU)") + x = np.array( [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float32, @@ -536,13 +537,12 @@ def build(x, mode=0): np.array([11, 10, 7, 6], dtype=np.float32).reshape(1, 1, 1, 2, 2), ] - for mode in range(5): - run_compare_builder( - functools.partial(build, mode=mode), - input_placeholder_dict, - input_value_dict, - expected_output_type[mode], - expected_output[mode], - compute_unit=compute_unit, - backend=backend, - ) + run_compare_builder( + functools.partial(build, mode=mode), + input_placeholder_dict, + input_value_dict, + expected_output_type[mode], + expected_output[mode], + compute_unit=compute_unit, + backend=backend, + ) diff --git a/coremltools/converters/mil/mil/ops/tests/iOS14/test_tensor_operation.py b/coremltools/converters/mil/mil/ops/tests/iOS14/test_tensor_operation.py index a9f43104c..f0c7f189e 100644 --- a/coremltools/converters/mil/mil/ops/tests/iOS14/test_tensor_operation.py +++ b/coremltools/converters/mil/mil/ops/tests/iOS14/test_tensor_operation.py @@ -1170,6 +1170,19 @@ def test_builder_eval(self): class TestDynamicTile: + @staticmethod + def test_dynamic_shape_tile_type_inference(): + reps = [1, 2] + input_shape = [get_new_symbol(), get_new_symbol()] + + @mb.program(input_specs=[mb.TensorSpec(shape=input_shape)]) + def prog(x): + x = mb.tile(x=x, reps=[1, 2]) + assert x.shape[0] == input_shape[0] + assert is_symbolic(x.shape[1]) + assert x.shape[1] != input_shape[1] + return x + @pytest.mark.parametrize( "compute_unit, backend", itertools.product( diff --git a/coremltools/converters/mil/mil/ops/tests/iOS14/test_tensor_transformation.py b/coremltools/converters/mil/mil/ops/tests/iOS14/test_tensor_transformation.py index 6ae83e6aa..bd703c8ba 100644 --- a/coremltools/converters/mil/mil/ops/tests/iOS14/test_tensor_transformation.py +++ b/coremltools/converters/mil/mil/ops/tests/iOS14/test_tensor_transformation.py @@ -1058,7 +1058,7 @@ def prog(x): ) return x - x = np.random.rand(*INPUT_SHAPE) + x = np.float16(np.random.rand(*INPUT_SHAPE)) # slice by index is x[begin[0]: end[0]: stride[0], begin[1]: end[1]: stride[1], ...] y_numpy = x[0:1:1, 0:2:1, 0:8:2, 0:12:2] diff --git a/coremltools/converters/mil/mil/ops/tests/iOS16/test_constexpr_ops.py b/coremltools/converters/mil/mil/ops/tests/iOS16/test_constexpr_ops.py index 24358362d..58e21c3c9 100644 --- a/coremltools/converters/mil/mil/ops/tests/iOS16/test_constexpr_ops.py +++ b/coremltools/converters/mil/mil/ops/tests/iOS16/test_constexpr_ops.py @@ -14,7 +14,10 @@ from coremltools.converters.mil.mil import types from coremltools.converters.mil.mil.ops.defs.iOS16 import constexpr_ops from coremltools.converters.mil.mil.ops.tests.iOS16 import backends -from coremltools.converters.mil.mil.ops.tests.testing_utils import run_compare_builder +from coremltools.converters.mil.mil.ops.tests.testing_utils import ( + mark_api_breaking, + run_compare_builder, +) from coremltools.converters.mil.testing_utils import get_op_types_in_program, ssa_fn compute_units = testing_reqs.compute_units @@ -332,6 +335,7 @@ def build(x): assert "constexpr_cast" in get_op_types_in_program(prog) class TestConstexprLutToDense: + @mark_api_breaking(breaking_opset_version=ct.target.iOS18) @pytest.mark.parametrize("compute_unit, backend", itertools.product(compute_units, backends)) def test_builder_to_backend_smoke(self, compute_unit, backend): @@ -385,6 +389,7 @@ def build(x): prog = mlmodel._mil_program assert "constexpr_lut_to_dense" in get_op_types_in_program(prog) + @mark_api_breaking(breaking_opset_version=ct.target.iOS18) @pytest.mark.parametrize("backend", backends) def test_shape_of_constexpr_is_replaceable(self, backend): @mb.program(input_specs=[], opset_version=backend.opset_version) @@ -473,6 +478,7 @@ def lut_config_generator(): } yield params + @mark_api_breaking(breaking_opset_version=ct.target.iOS18) @pytest.mark.parametrize( "compute_unit, backend, config", itertools.product(compute_units, backends, lut_config_generator.__func__()), @@ -522,6 +528,7 @@ def build(x): raise AssertionError("Invalidated: Test Failed") class TestConstexprSparseToDense: + @mark_api_breaking(breaking_opset_version=ct.target.iOS18) @pytest.mark.parametrize("compute_unit, backend", itertools.product(compute_units, backends)) def test_builder_to_backend_smoke(self, compute_unit, backend): @@ -608,6 +615,7 @@ def sparse_config_generator(): } yield params + @mark_api_breaking(breaking_opset_version=ct.target.iOS18) @pytest.mark.parametrize( "compute_unit, backend, config", itertools.product(compute_units, backends, sparse_config_generator.__func__()), diff --git a/coremltools/converters/mil/mil/ops/tests/iOS17/test_quantization.py b/coremltools/converters/mil/mil/ops/tests/iOS17/test_quantization.py index cae1135c5..78585be93 100644 --- a/coremltools/converters/mil/mil/ops/tests/iOS17/test_quantization.py +++ b/coremltools/converters/mil/mil/ops/tests/iOS17/test_quantization.py @@ -24,13 +24,6 @@ np.random.seed(1042) -def _set_backend_precision(backend, precision): - return BackendConfig( - backend=backend.backend, - precision=precision, - opset_version=backend.opset_version, - ) - class TestQuantizationBase: @staticmethod def get_random_quantization_params( @@ -238,14 +231,14 @@ def build(x): @pytest.mark.skipif(not _HAS_TORCH, reason=MSG_TORCH_NOT_FOUND) @pytest.mark.parametrize( - "compute_unit, backend, float_dtype, quant_dtype, compute_precision, input_rank, is_zp_present", + "compute_unit, backend, float_dtype, quant_dtype, input_rank, axis, is_zp_present", itertools.product( compute_units, backends, (np.float32, np.float16), (np.int8, np.uint8), - ("fp32", "fp16"), - (1, 2, 3, 4, 5), + list(range(1, 6)), + [None] + list(range(-5, 5)), (True, False), ), ) @@ -255,10 +248,13 @@ def test_stress_builder_to_backend_quantize_all_possibilities( backend, float_dtype, quant_dtype, - compute_precision, input_rank, + axis, is_zp_present, ): + if axis is not None and (axis < -input_rank or axis >= input_rank): + pytest.skip("axis should either be None or in [-input_rank, input_rank)") + def build(x): x = mb.cast(x=x, dtype=builtin_to_string(numpy_type_to_builtin_type(float_dtype))) quantized = mb.quantize( @@ -275,34 +271,33 @@ def build(x): ) return dequantized - for axis in [None] + [i for i in range(-input_rank, input_rank)]: - x_fp, scale, zero_point = self.get_random_quantization_params( - float_dtype, quant_dtype, input_rank, is_zp_present, axis - ) + x_fp, scale, zero_point = self.get_random_quantization_params( + float_dtype, quant_dtype, input_rank, is_zp_present, axis + ) - input_placeholders = { - "x": mb.placeholder( - shape=x_fp.shape, - dtype=numpy_type_to_builtin_type(float_dtype), - ), - } - input_values = {"x": x_fp} - - output_torch = self.torch_quantize(x_fp, scale, zero_point, axis, quant_dtype) - output_torch_val = output_torch.int_repr().numpy() - output_type = output_torch_val.shape + (numpy_type_to_builtin_type(np.float32),) - expected_outputs = [output_torch_val] - expected_output_types = [output_type] - - run_compare_builder( - build, - input_placeholders, - input_values, - expected_output_types, - expected_outputs=expected_outputs, - compute_unit=compute_unit, - backend=_set_backend_precision(backend, compute_precision), - ) + input_placeholders = { + "x": mb.placeholder( + shape=x_fp.shape, + dtype=numpy_type_to_builtin_type(float_dtype), + ), + } + input_values = {"x": x_fp} + + output_torch = self.torch_quantize(x_fp, scale, zero_point, axis, quant_dtype) + output_torch_val = output_torch.int_repr().numpy() + output_type = output_torch_val.shape + (numpy_type_to_builtin_type(np.float32),) + expected_outputs = [output_torch_val] + expected_output_types = [output_type] + + run_compare_builder( + build, + input_placeholders, + input_values, + expected_output_types, + expected_outputs=expected_outputs, + compute_unit=compute_unit, + backend=backend, + ) class TestDequantize(TestQuantizationBase): @@ -374,7 +369,13 @@ def build(x): expected_output_types=[expected_output_type], expected_outputs=[expected_output], compute_unit=compute_unit, - backend=_set_backend_precision(backend, "fp32"), + # Other test cases are mostly testing fp16 precision, + # so this one we explicitly test fp32 precision + backend=BackendConfig( + backend=backend.backend, + precision="fp32", + opset_version=backend.opset_version, + ), atol=1e-3, rtol=1e-3, ) @@ -409,21 +410,27 @@ def build(x): expected_output_types=[expected_output_type], expected_outputs=[expected_output], compute_unit=compute_unit, - backend=_set_backend_precision(backend, "fp32"), + # Other test cases are mostly testing fp16 precision, + # so this one we explicitly test fp32 precision + backend=BackendConfig( + backend=backend.backend, + precision="fp32", + opset_version=backend.opset_version, + ), atol=1e-3, rtol=1e-3, ) @pytest.mark.skipif(not _HAS_TORCH, reason=MSG_TORCH_NOT_FOUND) @pytest.mark.parametrize( - "compute_unit, backend, float_dtype, quant_dtype, compute_precision, input_rank, is_zp_present", + "compute_unit, backend, float_dtype, quant_dtype, input_rank, axis, is_zp_present", itertools.product( compute_units, backends, (np.float32, np.float16), (np.int8, np.uint8), - ("fp32", "fp16"), - (1, 2, 3, 4, 5), + list(range(1, 6)), + [None] + list(range(-5, 5)), (True, False), ), ) @@ -433,10 +440,13 @@ def test_stress_builder_to_backend_dequantize_all_possibilities( backend, float_dtype, quant_dtype, - compute_precision, input_rank, + axis, is_zp_present, ): + if axis is not None and (axis < -input_rank or axis >= input_rank): + pytest.skip("axis should either be None or in [-input_rank, input_rank)") + def build(x): x = mb.cast(x=x, dtype=builtin_to_string(numpy_type_to_builtin_type(float_dtype))) # TODO(rdar://107430678): Replace scale=1 zero_point=0 quantize/dequantize with cast @@ -453,33 +463,32 @@ def build(x): ) return dequantized - for axis in [None] + [i for i in range(-input_rank, input_rank)]: - x_fp, scale, zero_point = self.get_random_quantization_params( - float_dtype, quant_dtype, input_rank, is_zp_present, axis - ) + x_fp, scale, zero_point = self.get_random_quantization_params( + float_dtype, quant_dtype, input_rank, is_zp_present, axis + ) - x_q = self.torch_quantize(x_fp, scale, zero_point, axis, quant_dtype) - - output_torch_val = torch.dequantize(x_q).numpy() - output_type = output_torch_val.shape + (numpy_type_to_builtin_type(np.float32),) - - input_placeholders = { - "x": mb.placeholder( - shape=x_fp.shape, - dtype=numpy_type_to_builtin_type(float_dtype), - ), - } - input_values = {"x": x_q.int_repr().numpy()} - - expected_outputs = [output_torch_val] - expected_output_types = [output_type] - run_compare_builder( - build, - input_placeholders, - input_values, - expected_output_types, - expected_outputs=expected_outputs, - compute_unit=compute_unit, - backend=_set_backend_precision(backend, compute_precision), - rtol=1e-3, - ) + x_q = self.torch_quantize(x_fp, scale, zero_point, axis, quant_dtype) + + output_torch_val = torch.dequantize(x_q).numpy() + output_type = output_torch_val.shape + (numpy_type_to_builtin_type(np.float32),) + + input_placeholders = { + "x": mb.placeholder( + shape=x_fp.shape, + dtype=numpy_type_to_builtin_type(float_dtype), + ), + } + input_values = {"x": x_q.int_repr().numpy()} + + expected_outputs = [output_torch_val] + expected_output_types = [output_type] + run_compare_builder( + build, + input_placeholders, + input_values, + expected_output_types, + expected_outputs=expected_outputs, + compute_unit=compute_unit, + backend=backend, + rtol=1e-3, + ) diff --git a/coremltools/converters/mil/mil/ops/tests/iOS17/test_scatter_gather.py b/coremltools/converters/mil/mil/ops/tests/iOS17/test_scatter_gather.py index b4b4edca5..5207b84cc 100644 --- a/coremltools/converters/mil/mil/ops/tests/iOS17/test_scatter_gather.py +++ b/coremltools/converters/mil/mil/ops/tests/iOS17/test_scatter_gather.py @@ -8,6 +8,7 @@ import numpy as np import pytest +import coremltools as ct from coremltools.converters.mil.mil import Builder as mb from coremltools.converters.mil.mil import types from coremltools.converters.mil.mil.ops.tests.iOS14.test_scatter_gather import ( @@ -41,6 +42,34 @@ class TestScatter: def test_ios17_invalid_indices( self, compute_unit, backend, indices_val, validate_indices, dynamic ): + if ( + indices_val == [10, 0] + and backend.opset_version == ct.target.iOS18 + and not validate_indices + ): + pytest.xfail( + "rdar://128089254 ([Bug][Regression] iOS18 scatter ops has unexpected behavior than iOS17)" + ) + + if ( + indices_val == [-1, 0] + and backend.opset_version == ct.target.iOS18 + and validate_indices + and dynamic + ): + pytest.xfail( + "rdar://128089254 ([Bug][Regression] iOS18 scatter ops has unexpected behavior than iOS17)" + ) + + if ( + indices_val == [-1, 0] + and backend.opset_version == ct.target.iOS18 + and not validate_indices + ): + pytest.xfail( + "rdar://128089254 ([Bug][Regression] iOS18 scatter ops has unexpected behavior than iOS17)" + ) + def build_static(data, updates): return ( mb.scatter( @@ -82,6 +111,7 @@ def build_dynamic(data, indices, updates): expected_error_msg = ( "Error computing NN outputs", "Unable to compute the prediction using a neural network model", + "Unable to compute the prediction using ML Program", ) else: # The negative or out-of-bound indices will error out when validate_indices is set. @@ -98,10 +128,10 @@ def build_dynamic(data, indices, updates): compute_unit=compute_unit, backend=backend, ) - if not isinstance(expected_error_msg, tuple): - expected_error_msg = expected_error_msg - assert any([err in str(excinfo.value) for err in expected_error_msg]) + if not isinstance(expected_error_msg, tuple): + expected_error_msg = expected_error_msg + assert any([err in str(excinfo.value) for err in expected_error_msg]) class TestScatterAlongAxis: @pytest.mark.parametrize( @@ -127,6 +157,15 @@ def test_builder_to_backend_programmatic(self, compute_unit, backend, rank_axis) ), ) def test_ios17_invalid_indices(self, compute_unit, backend, indices_val, dynamic): + if ( + indices_val == [[-1, 0, 1], [1, 1, 0]] + and dynamic + and backend.opset_version == ct.target.iOS18 + ): + pytest.xfail( + "rdar://128089254 ([Bug][Regression] iOS18 scatter ops has unexpected behavior than iOS17)" + ) + def build_static(data, updates): return ( mb.scatter_along_axis( @@ -164,6 +203,7 @@ def build_dynamic(data, indices, updates): expected_error_msg = ( "Error computing NN outputs", "Unable to compute the prediction using a neural network model", + "Unable to compute the prediction using ML Program", ) else: # The negative or out-of-bound indices will error out when validate_indices is set. @@ -181,9 +221,10 @@ def build_dynamic(data, indices, updates): compute_unit=compute_unit, backend=backend, ) - if not isinstance(expected_error_msg, tuple): - expected_error_msg = expected_error_msg - assert any([err in str(excinfo.value) for err in expected_error_msg]) + + if not isinstance(expected_error_msg, tuple): + expected_error_msg = expected_error_msg + assert any([err in str(excinfo.value) for err in expected_error_msg]) class TestScatterNd: @@ -194,6 +235,15 @@ class TestScatterNd: ), ) def test_ios17_invalid_indices(self, compute_unit, backend, indices_val, dynamic): + if ( + indices_val == [[1, 0], [0, -1]] + and dynamic + and backend.opset_version == ct.target.iOS18 + ): + pytest.xfail( + "rdar://128089254 ([Bug][Regression] iOS18 scatter ops has unexpected behavior than iOS17)" + ) + def build_static(data, updates): return ( mb.scatter_nd( @@ -226,6 +276,7 @@ def build_dynamic(data, indices, updates): expected_error_msg = ( "Error computing NN outputs", "Unable to compute the prediction using a neural network model", + "Unable to compute the prediction using ML Program", ) else: # The negative or out-of-bound indices will error out when validate_indices is set. @@ -242,9 +293,9 @@ def build_dynamic(data, indices, updates): compute_unit=compute_unit, backend=backend, ) - if not isinstance(expected_error_msg, tuple): - expected_error_msg = expected_error_msg - assert any([err in str(excinfo.value) for err in expected_error_msg]) + if not isinstance(expected_error_msg, tuple): + expected_error_msg = expected_error_msg + assert any([err in str(excinfo.value) for err in expected_error_msg]) class TestGather(_TestGatherIOS16): diff --git a/coremltools/converters/mil/mil/ops/tests/iOS18/__init__.py b/coremltools/converters/mil/mil/ops/tests/iOS18/__init__.py new file mode 100644 index 000000000..bd4e048a2 --- /dev/null +++ b/coremltools/converters/mil/mil/ops/tests/iOS18/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import coremltools as ct +from coremltools.converters.mil.testing_reqs import backends_internal, clean_up_backends + +backends = clean_up_backends(backends_internal, ct.target.iOS18) diff --git a/coremltools/converters/mil/mil/ops/tests/iOS18/test_compression.py b/coremltools/converters/mil/mil/ops/tests/iOS18/test_compression.py new file mode 100644 index 000000000..bd6ad8f5b --- /dev/null +++ b/coremltools/converters/mil/mil/ops/tests/iOS18/test_compression.py @@ -0,0 +1,1961 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import itertools +import math +import re +from typing import List, Tuple + +import numpy as np +import pytest + +from coremltools.converters.mil.mil import Builder as mb +from coremltools.converters.mil.mil import types +from coremltools.converters.mil.mil.ops.defs._utils import promote_input_dtypes +from coremltools.converters.mil.mil.ops.defs.iOS18 import ( + _IOS18_TARGET, + constexpr_blockwise_shift_scale, + constexpr_lut_to_dense, + constexpr_lut_to_sparse, + constexpr_sparse_blockwise_shift_scale, + constexpr_sparse_to_dense, +) +from coremltools.converters.mil.mil.ops.tests.iOS18 import backends +from coremltools.converters.mil.mil.ops.tests.testing_utils import run_compare_builder +from coremltools.converters.mil.testing_reqs import compute_units + + +def _convert_to_sub_byte_dtype(data: np.ndarray, sub_byte_dtype: type) -> np.ndarray: + """Convert data to a specific sub-byte dtype, including shift between signed and unsigned range.""" + if not np.issubdtype(data.dtype, np.integer): + raise ValueError("Input data must be integer.") + if not types.is_sub_byte(sub_byte_dtype): + raise ValueError("Target dtype must be a sub-byte dtype.") + + original_signed = np.issubdtype(data.dtype, np.signedinteger) + target_signed = not sub_byte_dtype.is_unsigned() + if original_signed != target_signed: + shift = 2 ** (sub_byte_dtype.get_bitwidth() - 1) + if original_signed: + data += shift + else: + data -= shift + + dtype_range = types.type_mapping.builtin_to_range(sub_byte_dtype) + if np.max(data) > dtype_range.high: + raise ValueError( + f"Data has element {np.max(data)}, which is larger than the lower-bound {dtype_range.high}" + ) + if np.min(data) < dtype_range.low: + raise ValueError( + f"Data has element {np.min(data)}, which is smaller than the lower-bound {dtype_range.low}" + ) + + return data.astype(types.nptype_from_builtin(sub_byte_dtype)) + + +def _infer_lut_shape( + indices_shape: Tuple[int, ...], block_sizes: Tuple[int, ...], nbits: int, vector_size: int +): + """Infer the shape of look-up-table (LUT).""" + lut_shape = [] + for axis, dim_size in enumerate(indices_shape): + lut_dim_size = 1 if block_sizes[axis] == 0 else dim_size // block_sizes[axis] + lut_shape.append(lut_dim_size) + lut_shape.extend([2**nbits, vector_size]) + return lut_shape + + +class TestConstexprBlockwiseDequantize: + def test_builder_eval_basic_8bit(self): + @mb.program(input_specs=[], opset_version=_IOS18_TARGET) + def prog(): + return mb.constexpr_blockwise_shift_scale( + data=np.array([4, 8, 10, 13, 24, 5, 6, 9]).reshape((1, 2, 4)).astype(np.int8), + scale=np.array([4, 8]).reshape((1, 1, 2)).astype(np.float16), + offset=np.array([4, 0]).reshape((1, 1, 2)).astype(np.int8), + ) + + main_func = prog.functions["main"] + constexpr_blockwise_shift_scale_op = main_func.find_ops( + op_type="constexpr_blockwise_shift_scale" + )[0] + decompressed_res = ( + np.array([0, 16, 80, 104, 80, 4, 48, 72]).reshape((1, 2, 4)).astype(np.float16) + ) + np.testing.assert_allclose( + decompressed_res, + constexpr_blockwise_shift_scale_op.outputs[0].op.materialized_val_inference(), + ) + + @pytest.mark.parametrize( + "scale_shape_output, quantized_dtype", + itertools.product( + [ + ((1, 1, 2), [0, -16, -64, 0, -40, -16, -24, 0]), + ((1, 2, 1), [0, -16, -48, -16, -48, 0, -24, 0]), + ], + ["int4", "uint4"], + ), + ) + def test_builder_eval_basic_4bit( + self, scale_shape_output: Tuple[Tuple[int], List[int]], quantized_dtype: str + ): + quantized_dtype = types.string_to_builtin(quantized_dtype) + scale_shape, expected_output = scale_shape_output + + @mb.program(input_specs=[], opset_version=_IOS18_TARGET) + def prog(): + quantized_data = _convert_to_sub_byte_dtype( + np.array([4, 0, -8, 0, -6, 0, -3, 0]).reshape((1, 2, 4)), quantized_dtype + ) + offset = _convert_to_sub_byte_dtype( + np.array([4, 0]).reshape(scale_shape), quantized_dtype + ) + quantized_data = mb.const(val=quantized_data, name="quantized_data") + offset = mb.const(val=offset, name="offset") + return mb.constexpr_blockwise_shift_scale( + data=quantized_data, + scale=np.array([4, 8]).reshape(scale_shape).astype(np.float32), + offset=offset, + ) + + constexpr_blockwise_shift_scale_op = prog.functions["main"].find_ops( + op_type="constexpr_blockwise_shift_scale" + )[0] + np.testing.assert_allclose( + np.array(expected_output).reshape((1, 2, 4)).astype(np.float32), + constexpr_blockwise_shift_scale_op.outputs[0].op.materialized_val_inference(), + ) + + def test_builder_eval_basic_no_offset(self): + @mb.program(input_specs=[], opset_version=_IOS18_TARGET) + def prog(): + quantized_data = mb.const( + val=np.array([4, 0, -8, 0, -6, 0, -3, 0]) + .reshape((1, 2, 4)) + .astype(types.np_int4_dtype), + name="quantized_data", + ) + return mb.constexpr_blockwise_shift_scale( + data=quantized_data, + scale=np.array([4, 8]).reshape((1, 1, 2)).astype(np.float32), + ) + + constexpr_blockwise_shift_scale_op = prog.functions["main"].find_ops( + op_type="constexpr_blockwise_shift_scale" + )[0] + np.testing.assert_allclose( + np.array([16, 0, -64, 0, -24, 0, -24, 0]).reshape((1, 2, 4)).astype(np.float32), + constexpr_blockwise_shift_scale_op.outputs[0].op.materialized_val_inference(), + ) + + @pytest.mark.parametrize( + "nbits, block_size, mode", + itertools.product( + (4, 8), + (1, 2, 4), + ("linear_symmetric", "linear"), + ), + ) + def test_builder_eval_numerical_stress(self, nbits, block_size, mode): + nbits_range_max = 2 ** (nbits - 1) - 1 + nbits_range_min = -nbits_range_max + if mode == "linear": + nbits_range_min -= 1 + + nbits_range = nbits_range_max - nbits_range_min + # As small-bit quantization has a lot of information loss, we use int input to make the + # information loss less critical when comparing the dequantized data with original data. + original_data = ( + np.random.randn(2, 3, 8) + if block_size == 1 + else np.random.randint(nbits_range_min, nbits_range_max, (2, 3, 8)) + ) + + scaled_data = original_data.flatten() + scales = [] + zero_points = [] + for i in range(0, scaled_data.size, block_size): + block_data = scaled_data[i : i + block_size] + offset = 0 + + if mode == "linear_symmetric": + block_range = np.max(np.abs(block_data)) * 2 + else: + assert mode == "linear" + # For the linear mode, we need to make sure the data range contains `0`. + block_max = np.maximum(0.0, np.max(block_data)) + block_min = np.minimum(0.0, np.min(block_data)) + block_range = block_max - block_min + offset = ( + (nbits_range_min * block_max - nbits_range_max * block_min) / block_range + if block_range != 0.0 + else 0.0 + ) + zero_points.append(offset) + + block_scale = block_range / nbits_range + scales.append(block_scale) + scaled_data[i : i + block_size] = np.round(block_data / block_scale + offset) + scaled_data = np.minimum(scaled_data, nbits_range_max) + scaled_data = np.maximum(scaled_data, nbits_range_min) + scaled_data = scaled_data.reshape(original_data.shape).astype(np.int8) + scales_shape = original_data.shape[:-1] + (original_data.shape[-1] // block_size,) + + @mb.program(input_specs=[], opset_version=_IOS18_TARGET) + def prog(): + quantized_data = scaled_data + if nbits == 4: + quantized_data = mb.const(val=quantized_data.astype(types.np_int4_dtype)) + return mb.constexpr_blockwise_shift_scale( + data=quantized_data, + scale=np.array(scales).reshape(scales_shape).astype(np.float32), + offset=None + if mode == "linear_symmetric" + else np.array(zero_points).reshape(scales_shape).astype(np.float32), + ) + + constexpr_blockwise_shift_scale_op = prog.functions["main"].find_ops( + op_type="constexpr_blockwise_shift_scale" + )[0] + + if block_size == 1: + # With block_size==1, the quantization will not have information loss. + atol, rtol = 1e-06, 1e-06 + elif nbits > 4 and block_size < 3: + # When block size is small and nbits is large, the information loss is limited. + atol, rtol = 1e-04, 1e-04 + else: + atol, rtol = 1e-02, 1e-02 + + dequantized_data = constexpr_blockwise_shift_scale_op.outputs[ + 0 + ].op.materialized_val_inference() + if np.issubdtype(original_data.dtype, np.integer): + dequantized_data = np.round(dequantized_data) + np.testing.assert_allclose( + original_data, + dequantized_data, + atol=atol, + rtol=rtol, + ) + + def test_builder_eval_invalid_parameter(self): + with pytest.raises( + ValueError, match=r"Parameter 'data' needs to have at least rank 1, but got scalar." + ): + + @mb.program(input_specs=[], opset_version=_IOS18_TARGET) + def prog(): + return mb.constexpr_blockwise_shift_scale( + data=np.int8(10), + scale=np.float32(2.0), + ) + + with pytest.raises( + ValueError, match=r"Parameter 'data' and 'scale' need to have the same rank" + ): + + @mb.program(input_specs=[], opset_version=_IOS18_TARGET) + def prog(): + return mb.constexpr_blockwise_shift_scale( + data=np.int8(10), + scale=np.array([1, 2]).astype(np.float32), + ) + + with pytest.raises( + ValueError, + match=r"Number of scales along each dimension should be a " + r"factor of corresponding dimension size of 'data'.", + ): + + @mb.program(input_specs=[], opset_version=_IOS18_TARGET) + def prog(): + return mb.constexpr_blockwise_shift_scale( + data=np.array([1, 2]).reshape((1, 2)).astype(np.int8), + scale=np.array([1, 2]).reshape((2, 1)).astype(np.float16), + ) + + with pytest.raises( + ValueError, + match=r"Invalid parameter 'offset'; the shape of 'offset' " + r"should match the shape of 'scale'", + ): + + @mb.program(input_specs=[], opset_version=_IOS18_TARGET) + def prog(): + return mb.constexpr_blockwise_shift_scale( + data=np.array([1, 2]).astype(np.int8), + scale=np.array([1, 2]).astype(np.float16), + offset=np.array([1, 2]).reshape((1, 2)).astype(np.int8), + ) + + with pytest.raises( + ValueError, + match=r"Invalid parameter 'offset'; the dtype of 'offset' " + r"should match the dtype of 'data'", + ): + + @mb.program(input_specs=[], opset_version=_IOS18_TARGET) + def prog(): + return mb.constexpr_blockwise_shift_scale( + data=np.array([1, 2]).astype(types.nptype_from_builtin(types.int4)), + scale=np.array([1, 2]).astype(np.float16), + offset=np.array([1, 2]).astype(np.int8), + ) + + # When the offset is float, it doesn't need to have the same dtype as data. + @mb.program(input_specs=[], opset_version=_IOS18_TARGET) + def prog(): + return mb.constexpr_blockwise_shift_scale( + data=np.array([1, 2]).astype(types.nptype_from_builtin(types.int4)), + scale=np.array([1, 2]).astype(np.float16), + offset=np.array([1, 2]).astype(np.float32), + ) + + @pytest.mark.parametrize( + "compute_unit, backend, nbits, has_offset", + itertools.product(compute_units, backends, [4, 8], [True, False]), + ) + def test_builder_to_backend_smoke(self, compute_unit, backend, nbits, has_offset): + x_val = np.ones(1).astype(np.float32) + input_placeholders = {"x": mb.placeholder(shape=x_val.shape)} + np_dtype = types.nptype_from_builtin(types.string_to_builtin(f"int{nbits}")) + + if nbits == 8: + data_val = [4, 8, 10, 13, 24, 5, 6, 9] + elif nbits == 4: + data_val = [2, 3, 5, 7, 6, 5, 3, 1] + data = np.array(data_val).reshape((1, 2, 4)).astype(np_dtype) + + if has_offset is True: + if nbits == 8: + offset_val = [4, 0] + elif nbits == 4: + offset_val = [1, 0] + else: + offset_val = [0, 0] + offset = np.array(offset_val).reshape((1, 1, 2)).astype(np_dtype) + + scale = np.array([1, 2]).reshape((1, 1, 2)).astype(np.float32) + + # Calculate expected output based on op definition. + expected_output = np.zeros(data.shape) + for n in range(0, 1): + for i in range(0, data.shape[0]): + for j in range(0, data.shape[1]): + for k in range(0, data.shape[2]): + i0 = math.floor(i / (data.shape[0] / scale.shape[0])) + j0 = math.floor(j / (data.shape[1] / scale.shape[1])) + k0 = math.floor(k / (data.shape[2] / scale.shape[2])) + expected_output[i][j][k] = ( + scale[i0][j0][k0] * (data[i][j][k] - offset[i0][j0][k0]) + 1 + ) + + def build(x): + output = mb.constexpr_blockwise_shift_scale( + data=data, + scale=scale, + offset=offset, + ) + return mb.add(x=x, y=output) + + run_compare_builder( + build, + input_placeholders, + input_values={"x": x_val}, + expected_output_types=expected_output.shape + (types.fp32,), + expected_outputs=expected_output, + compute_unit=compute_unit, + backend=backend, + ) + + @pytest.mark.parametrize( + "compute_unit, backend, dtype, block_sizes, has_offset", + itertools.product( + compute_units, + backends, + ["int4", "uint4", "int8", "uint8", "fp16"], + [(0, 1, 1, 1), (0, 0, 0, 2), (0, 0, 0, 0), (1, 1, 1, 1), (0, 4, 2, 0), (4, 8, 16, 8)], + [True, False], + ), + ) + def test_builder_to_backend_stress(self, compute_unit, backend, dtype, block_sizes, has_offset): + """ + Use constexpr_blockwise_shift_scale op's value inference to check backends outputs. + + Following combinations will fail if enable BNNS (rdar://125854036). + - dtype = 'uint4'/'int4', block_sizes = (1, 1, 1, 1) + - dtype = 'uint4'/'int4', block_sizes = (0, 1, 1, 1) + """ + quantized_data_shape = (4, 8, 16, 8) + builtin_dtype = types.string_to_builtin(dtype) + np_dtype = types.nptype_from_builtin(builtin_dtype) + + if types.is_int(builtin_dtype): + data_range = types.type_mapping.builtin_to_range(builtin_dtype) + quantized_data = np.random.randint( + low=data_range.low, high=data_range.high + 1, size=quantized_data_shape + ).astype(np_dtype) + else: + quantized_data = np.random.rand(*quantized_data_shape).astype(np_dtype) + + scale_shape = [ + 1 if block_sizes[axis] == 0 else dim_size // block_sizes[axis] + for axis, dim_size in enumerate(quantized_data.shape) + ] + scale = np.random.rand(*scale_shape) + offset = None + if has_offset: + if types.is_int(builtin_dtype): + offset = np.random.randint( + low=data_range.low, high=data_range.high + 1, size=scale.shape + ).astype(np_dtype) + else: + offset = np.random.rand(*scale.shape).astype(np_dtype) + + def build(x): + output = mb.constexpr_blockwise_shift_scale( + data=quantized_data, + scale=scale, + offset=offset, + ) + return mb.add(x=x, y=output) + + x_val = np.ones_like(quantized_data).astype(np.float32) + input_placeholders = {"x": mb.placeholder(shape=x_val.shape)} + expected_output = ( + constexpr_blockwise_shift_scale.decompress(quantized_data, scale, offset) + 1 + ) + + run_compare_builder( + build, + input_placeholders, + input_values={"x": x_val}, + expected_output_types=expected_output.shape + (types.fp32,), + expected_outputs=expected_output, + compute_unit=compute_unit, + backend=backend, + ) + + +class TestConstexprLut: + @staticmethod + def _pad_lut_for_nbits_requirements(lut: np.ndarray, nbits: int): + """ + Make the number of palettes in lut size (second last dim) meet the 2^nbits requirement. + + This util function is needed before we add all uint sub-byte dtypes. + """ + pad_shape = lut.shape[:-2] + (2**nbits - lut.shape[-2], lut.shape[-1]) + return np.concatenate((lut, np.zeros(pad_shape)), axis=-2) + + @staticmethod + def _generate_lut(shape: Tuple[int, ...]): + """It follows the MIL test cases.""" + total_num = np.prod(shape) + lut = np.arange(min(total_num, 128)) + if total_num > lut.size: + lut = np.concatenate((lut, np.ones(total_num - lut.size) * 127)) + return lut.reshape(shape) + + @pytest.mark.parametrize("nbits", [1, 2, 3, 4, 6, 8]) + def test_builder_eval_channelwise_lut(self, nbits): + """ + Test channel-wise lut with first axis as channel axis (the first dim of lut has size > 1). + + indices = tensor>([2, 3, 3, 0, 1, 0, 3, 0, 2, 1, 0, 3]) + lut = tensor([1, 5, 9, 13, 2, 10, 18, 26]) + + It is effectively a 2-group 2-bit scalar palettization. + The output shape would be [6, 2], which is the same as the indices shape. + The output tensor values are: + [[lut0[2]->9, lut0[3]->13], + [lut0[3]->13, lut0[0]->1], + [lut0[1]->5, lut0[0]->1], + [lut1[3]->26, lut1[0]->2], + [lut1[2]->18, lut1[1]->10], + [lut1[0]->2, lut1[3]->26]] + """ + + @mb.program(input_specs=[], opset_version=_IOS18_TARGET) + def prog(): + if nbits == 1: + indices = np.array([0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1]).reshape((6, 2)) + lut = np.array([1, 5, 9, 13]).reshape((2, 1, 2, 1)).astype(np.int8) + else: + indices = np.array([2, 3, 3, 0, 1, 0, 3, 0, 2, 1, 0, 3]).reshape((6, 2)) + lut = self._pad_lut_for_nbits_requirements( + np.array([1, 5, 9, 13, 2, 10, 18, 26]).reshape((2, 1, 4, 1)).astype(np.int8), + nbits=nbits, + ) + indices_np_dtype = types.nptype_from_builtin(types.string_to_builtin(f"uint{nbits}")) + indices = indices.astype(indices_np_dtype) + return mb.constexpr_lut_to_dense(indices=indices, lut=lut) + + constexpr_lut_to_dense_op = prog.functions["main"].find_ops( + op_type="constexpr_lut_to_dense" + )[0] + if nbits == 1: + decompressed_res = np.array([1, 5, 5, 1, 5, 1, 13, 9, 13, 13, 9, 13]) + else: + decompressed_res = np.array([9, 13, 13, 1, 5, 1, 26, 2, 18, 10, 2, 26]) + decompressed_res = decompressed_res.reshape((6, 2)).astype(np.int8) + np.testing.assert_allclose( + decompressed_res, constexpr_lut_to_dense_op.outputs[0].op.materialized_val_inference() + ) + + @pytest.mark.parametrize("vector_axis", (0, 1, 2, -1)) + def test_builder_eval_vector_lut(self, vector_axis): + """ + Test vector lut on different axis. + + indices = [ + [ + [4, 8], -> group 0 + [10, 13], -> group 0 + [24, 5], -> group 1 + [6, 9] -> group 1 + ], + [ + [13, 31], -> group 0 + [17, 7], -> group 0 + [2, 8], -> group 1 + [3, 1] -> group 1 + ] + ] + """ + + @mb.program(input_specs=[], opset_version=_IOS18_TARGET) + def prog(): + return mb.constexpr_lut_to_dense( + indices=np.array([4, 8, 10, 13, 24, 5, 6, 9, 13, 31, 17, 7, 2, 8, 3, 1]) + .reshape((2, 4, 2)) + .astype(np.uint8), + lut=self._generate_lut(shape=(1, 2, 1, 256, 3)), + vector_axis=vector_axis, + ) + + constexpr_lut_to_dense_op = prog.functions["main"].find_ops( + op_type="constexpr_lut_to_dense" + )[0] + if vector_axis == 0: + decompressed_res = ( + np.array( + [ + 12, + 24, + 30, + 39, + 127, + 127, + 127, + 127, + 13, + 25, + 31, + 40, + 127, + 127, + 127, + 127, + 14, + 26, + 32, + 41, + 127, + 127, + 127, + 127, + 39, + 93, + 51, + 21, + 127, + 127, + 127, + 127, + 40, + 94, + 52, + 22, + 127, + 127, + 127, + 127, + 41, + 95, + 53, + 23, + 127, + 127, + 127, + 127, + ] + ) + .reshape((2 * 3, 4, 2)) + .astype(np.int8) + ) + elif vector_axis == 1: + decompressed_res = ( + np.array( + [ + 12, + 24, + 13, + 25, + 14, + 26, + 30, + 39, + 31, + 40, + 32, + 41, + 127, + 127, + 127, + 127, + 127, + 127, + 127, + 127, + 127, + 127, + 127, + 127, + 39, + 93, + 40, + 94, + 41, + 95, + 51, + 21, + 52, + 22, + 53, + 23, + 127, + 127, + 127, + 127, + 127, + 127, + 127, + 127, + 127, + 127, + 127, + 127, + ] + ) + .reshape((2, 4 * 3, 2)) + .astype(np.int8) + ) + else: + decompressed_res = ( + np.array( + [ + 12, + 13, + 14, + 24, + 25, + 26, + 30, + 31, + 32, + 39, + 40, + 41, + 127, + 127, + 127, + 127, + 127, + 127, + 127, + 127, + 127, + 127, + 127, + 127, + 39, + 40, + 41, + 93, + 94, + 95, + 51, + 52, + 53, + 21, + 22, + 23, + 127, + 127, + 127, + 127, + 127, + 127, + 127, + 127, + 127, + 127, + 127, + 127, + ] + ) + .reshape((2, 4, 2 * 3)) + .astype(np.int8) + ) + np.testing.assert_allclose( + decompressed_res, constexpr_lut_to_dense_op.outputs[0].op.materialized_val_inference() + ) + + @pytest.mark.parametrize( + "compute_unit, backend, nbits", itertools.product(compute_units, backends, [2, 3, 4, 6, 8]) + ) + def test_builder_to_backend_smoke(self, compute_unit, backend, nbits): + x_val = np.ones(12).astype(np.float32).reshape(6, 2) + input_placeholders = {"x": mb.placeholder(shape=x_val.shape)} + + def build(x): + indices = np.array([2, 3, 3, 0, 1, 0, 3, 0, 2, 1, 0, 3]).reshape((6, 2)) + lut = self._pad_lut_for_nbits_requirements( + np.array([1, 5, 9, 13, 2, 10, 18, 26]).reshape((2, 1, 4, 1)).astype(np.int8), + nbits=nbits, + ) + indices_np_dtype = types.nptype_from_builtin(types.string_to_builtin(f"uint{nbits}")) + indices = indices.astype(indices_np_dtype) + + output = mb.constexpr_lut_to_dense( + indices=indices, + lut=lut, + ) + return mb.add(x=x, y=output) + + expected_output = np.array([9, 13, 13, 1, 5, 1, 26, 2, 18, 10, 2, 26]).reshape(6, 2) + 1 + run_compare_builder( + build, + input_placeholders, + input_values={"x": x_val}, + expected_output_types=expected_output.shape + (types.fp32,), + expected_outputs=expected_output, + compute_unit=compute_unit, + backend=backend, + ) + + @pytest.mark.parametrize( + "compute_unit, backend, nbits, block_sizes, vector_size, lut_dtype", + itertools.product( + compute_units, + backends, + [2, 3, 4, 6, 8], + [(0, 2, 0, 0), (2, 0, 0, 0), (0, 0, 0, 0), (1, 1, 1, 1), (4, 2, 0, 0), (4, 8, 16, 8)], + [1, 4], + ["fp16", "fp32"], # TODO (rdar://125859751): Add "int8" and "uint8". + ), + ) + def test_builder_to_backend_stress( + self, compute_unit, backend, nbits, block_sizes, vector_size, lut_dtype + ): + """Use constexpr_lut_to_dense op's value inference to check backends outputs.""" + indices_shape = (4, 8, 16, 8) + builtin_dtype = types.string_to_builtin(f"uint{nbits}") + np_dtype = types.nptype_from_builtin(builtin_dtype) + indices = np.random.randint(low=0, high=2**nbits, size=indices_shape).astype(np_dtype) + + lut_np_dtype = types.nptype_from_builtin(types.string_to_builtin(lut_dtype)) + lut_shape = _infer_lut_shape(indices_shape, block_sizes, nbits, vector_size) + lut = np.random.rand(*lut_shape).astype(lut_np_dtype) + + vector_axis = 0 if vector_size > 1 else None + + def build(x): + output = mb.constexpr_lut_to_dense( + indices=indices, + lut=lut, + vector_axis=vector_axis, + ) + x, output = promote_input_dtypes([x, output]) + return mb.add(x=x, y=output) + + output_shape = list(indices.shape) + if vector_size > 1: + output_shape[vector_axis] *= vector_size + x_val = np.ones(output_shape).astype(np.float32) + input_placeholders = {"x": mb.placeholder(shape=x_val.shape)} + expected_output = ( + constexpr_lut_to_dense.decompress(indices, lut, vector_axis=vector_axis) + 1 + ) + + run_compare_builder( + build, + input_placeholders, + input_values={"x": x_val}, + expected_output_types=expected_output.shape + (types.fp32,), + expected_outputs=expected_output, + compute_unit=compute_unit, + backend=backend, + ) + + +class TestConstexprSparseToDense: + def test_builder_eval_basic(self): + @mb.program(input_specs=[], opset_version=_IOS18_TARGET) + def prog(): + return mb.constexpr_sparse_to_dense( + nonzero_data=np.array([3.0, 5.0, 4.0]), + mask=np.array([1, 0, 1, 0, 1, 0]).reshape((2, 3)).astype(types.np_uint1_dtype), + ) + + constexpr_sparse_to_dense_op = prog.functions["main"].find_ops( + op_type="constexpr_sparse_to_dense" + )[0] + decompressed_res = np.array([[3.0, 0.0, 5.0], [0.0, 4.0, 0.0]]) + np.testing.assert_allclose( + decompressed_res, + constexpr_sparse_to_dense_op.outputs[0].op.materialized_val_inference(), + ) + + @pytest.mark.parametrize( + "shape, data_dtype", + itertools.product( + ((2, 3, 4), (3, 8), (24,)), + (types.int4, types.uint4, types.int8, types.uint8, types.fp16, types.fp32), + ), + ) + def test_builder_eval_numerical_stress(self, shape, data_dtype): + np_dtype = types.nptype_from_builtin(data_dtype) + + @mb.program(input_specs=[], opset_version=_IOS18_TARGET) + def prog(): + return mb.constexpr_sparse_to_dense( + nonzero_data=np.array([3.0, 5.0, 4.0]).astype(np_dtype), + mask=np.array([1, 0, 1, 0, 1, 0] + [0] * 18) + .reshape(shape) + .astype(types.np_uint1_dtype), + ) + + constexpr_sparse_to_dense_op = prog.functions["main"].find_ops( + op_type="constexpr_sparse_to_dense" + )[0] + decompressed_res = np.array([3, 0, 5, 0, 4, 0] + [0] * 18).reshape(shape).astype(np_dtype) + np.testing.assert_allclose( + decompressed_res, + constexpr_sparse_to_dense_op.outputs[0].op.materialized_val_inference(), + ) + + def test_builder_eval_invalid_parameter(self): + with pytest.raises( + ValueError, match="Parameter nonzero_data needs to have rank 1, but got 2" + ): + + @mb.program(input_specs=[], opset_version=_IOS18_TARGET) + def prog(): + return mb.constexpr_sparse_to_dense( + nonzero_data=np.array([1.0, 5.0, 4.0]).reshape((3, 1)), + mask=np.array([1, 1, 1, 0, 0, 0]).reshape((2, 3)).astype(types.np_uint1_dtype), + ) + + with pytest.raises( + AssertionError, + match="Number of 1s in mask not match number of elements in parameter nonzero_data", + ): + + @mb.program(input_specs=[], opset_version=_IOS18_TARGET) + def prog(): + return mb.constexpr_sparse_to_dense( + nonzero_data=np.array([1.0, 5.0, 4.0]), + mask=np.array([1, 1, 1, 0, 1, 0]).reshape((2, 3)).astype(types.np_uint1_dtype), + ) + + @pytest.mark.parametrize( + "compute_unit, backend, data_dtype", + itertools.product( + compute_units, + backends, + ("fp16", "fp32"), # TODO (rdar://125859751): Add "int8" and "uint8". + ), + ) + def test_builder_to_backend_smoke(self, compute_unit, backend, data_dtype): + builtin_dtype = types.string_to_builtin(data_dtype) + np_dtype = types.nptype_from_builtin(builtin_dtype) + x_val = np.array([1, 1, 1, 1, 1, 1], dtype=np_dtype).reshape((2, 3)) + input_placeholders = {"x": mb.placeholder(shape=x_val.shape, dtype=builtin_dtype)} + + def build(x): + nonzero_data = np.array([3.0, 5.0, 4.0]).astype(np_dtype) + mask = np.array([1, 0, 1, 0, 1, 0]).reshape((2, 3)).astype(types.np_uint1_dtype) + + output = mb.constexpr_sparse_to_dense( + nonzero_data=nonzero_data, + mask=mask, + ) + return mb.add(x=x, y=output) + + expected_output = np.array([[3.0, 0.0, 5.0], [0.0, 4.0, 0.0]]).astype(np_dtype) + 1 + run_compare_builder( + build, + input_placeholders, + input_values={"x": x_val}, + expected_output_types=expected_output.shape + (builtin_dtype,), + expected_outputs=expected_output, + compute_unit=compute_unit, + backend=backend, + ) + + @pytest.mark.parametrize( + "compute_unit, backend, sparse_ratio, data_dtype", + itertools.product( + compute_units, + backends, + [0.01, 0.5, 0.99], + ["fp16", "fp32"], # TODO (rdar://125859751): Add "int8" and "uint8". + ), + ) + def test_builder_to_backend_stress(self, compute_unit, backend, sparse_ratio, data_dtype): + """Use constexpr_sparse_to_dense op's value inference to check backends outputs.""" + dense_data_shape = (4, 8, 16, 8) + mask = np.random.choice( + [0, 1], size=dense_data_shape, p=[sparse_ratio, 1.0 - sparse_ratio] + ).astype(types.np_uint1_dtype) + non_zero_element_num = np.sum(mask) + data_np_dtype = types.nptype_from_builtin(types.string_to_builtin(data_dtype)) + nonzero_data = np.random.rand(non_zero_element_num).astype(data_np_dtype) + + def build(x): + output = mb.constexpr_sparse_to_dense( + nonzero_data=nonzero_data, + mask=mask, + ) + x, output = promote_input_dtypes([x, output]) + return mb.add(x=x, y=output) + + x_val = np.ones_like(mask).astype(np.float32) + input_placeholders = {"x": mb.placeholder(shape=x_val.shape)} + expected_output = constexpr_sparse_to_dense.decompress(nonzero_data, mask) + 1 + + run_compare_builder( + build, + input_placeholders, + input_values={"x": x_val}, + expected_output_types=expected_output.shape + (types.fp32,), + expected_outputs=expected_output, + compute_unit=compute_unit, + backend=backend, + ) + + +class TestConstexprLutToSparse: + def test_builder_eval_scalar_lut(self): + """ + indices_mask = + [[1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 1], + [0, 1, 1, 0, 1, 0], + [0, 0, 0, 1, 0, 0]] + indices_nonzero_data = [0, 1, 1, 0, 1, 1, 0, 0, 1] + lut = [2.0, 3.0] + + The output mask is the same as input indices_mask. + The output sparse tensor in the dense layout is: + 2.0 3.0 + 3.0 2.0 3.0 + 3.0 2.0 2.0 + 3.0 + So the output nonzero_data is [2.0, 3.0, 3.0, 2.0, 3.0, 3.0, 2.0, 2.0, 3.0]. + """ + + @mb.program(input_specs=[], opset_version=_IOS18_TARGET) + def prog(): + return mb.constexpr_lut_to_sparse( + indices_mask=np.array( + [[1, 1, 0, 0, 0, 0], [1, 1, 0, 0, 0, 1], [0, 1, 1, 0, 1, 0], [0, 0, 0, 1, 0, 0]] + ).astype(types.np_uint1_dtype), + indices_nonzero_data=np.array([0, 1, 1, 0, 1, 1, 0, 0, 1]).astype( + types.np_uint1_dtype + ), + lut=np.array([2.0, 3.0]).reshape((1, 1, 2, 1)), + ) + + constexpr_lut_to_sparse_op = prog.functions["main"].find_ops( + op_type="constexpr_lut_to_sparse" + )[0] + expected_output_mask = np.array( + [[1, 1, 0, 0, 0, 0], [1, 1, 0, 0, 0, 1], [0, 1, 1, 0, 1, 0], [0, 0, 0, 1, 0, 0]] + ) + expected_output_nonzero_data = np.array([2.0, 3.0, 3.0, 2.0, 3.0, 3.0, 2.0, 2.0, 3.0]) + output_mask, output_nonzero_data = constexpr_lut_to_sparse_op.outputs[ + 0 + ].op.materialized_val_inference() + np.testing.assert_allclose(output_mask, expected_output_mask) + np.testing.assert_allclose(output_nonzero_data, expected_output_nonzero_data) + + def test_builder_eval_vector_lut(self): + """ + indices_mask = + [[1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 1], + [0, 1, 1, 0, 1, 0], + [0, 0, 0, 1, 0, 0]] + indices_nonzero_data = [0, 1, 1, 0, 1, 1, 0, 0, 1] + lut = [ + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 1], + [1, 1, 0, 0, 0, 1], + [0, 1, 1, 0, 1, 0], + [0, 1, 1, 0, 1, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 0, 0], + ] + The second output in the dense layout would be: + 2.0 3.0 + 2.0 3.0 + 3.0 2.0 3.0 + 3.0 2.0 3.0 + 3.0 2.0 2.0 + 3.0 2.0 2.0 + 3.0 + 3.0 + It is created by fetching the vector entry from the lut for every bit 1 in the data_mask, + and filling the vector over axis=0. + """ + + @mb.program(input_specs=[], opset_version=_IOS18_TARGET) + def prog(): + return mb.constexpr_lut_to_sparse( + indices_mask=np.array( + [[1, 1, 0, 0, 0, 0], [1, 1, 0, 0, 0, 1], [0, 1, 1, 0, 1, 0], [0, 0, 0, 1, 0, 0]] + ).astype(types.np_uint1_dtype), + indices_nonzero_data=np.array([0, 1, 1, 0, 1, 1, 0, 0, 1]).astype( + types.np_uint1_dtype + ), + lut=np.array([[2.0, 2.0], [3.0, 3.0]]).reshape((1, 1, 2, 2)), + vector_axis=0, + ) + + constexpr_lut_to_sparse_op = prog.functions["main"].find_ops( + op_type="constexpr_lut_to_sparse" + )[0] + expected_output_mask = np.array( + [ + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 1], + [1, 1, 0, 0, 0, 1], + [0, 1, 1, 0, 1, 0], + [0, 1, 1, 0, 1, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 0, 0], + ] + ) + expected_output_nonzero_data = np.array( + [ + 2.0, + 3.0, + 2.0, + 3.0, + 3.0, + 2.0, + 3.0, + 3.0, + 2.0, + 3.0, + 3.0, + 2.0, + 2.0, + 3.0, + 2.0, + 2.0, + 3.0, + 3.0, + ] + ) + output_mask, output_nonzero_data = constexpr_lut_to_sparse_op.outputs[ + 0 + ].op.materialized_val_inference() + np.testing.assert_allclose(output_mask, expected_output_mask) + np.testing.assert_allclose(output_nonzero_data, expected_output_nonzero_data) + + def test_builder_eval_invalid_parameter(self): + with pytest.raises( + AssertionError, + match="Number of 1s in mask not match number of elements in parameter nonzero_data", + ): + + @mb.program(input_specs=[], opset_version=_IOS18_TARGET) + def prog(): + return mb.constexpr_lut_to_sparse( + indices_mask=np.array([1, 1, 1, 0, 1, 0]) + .reshape((2, 3)) + .astype(types.np_uint1_dtype), + indices_nonzero_data=np.array([0, 1, 0]).astype(types.np_uint1_dtype), + lut=np.array([2.0, 3.0]).reshape((1, 1, 2, 1)), + ) + + with pytest.raises( + ValueError, + match=re.escape( + "When lut's last dim (VECTOR_SIZE) > 1, the parameter " + "'vector_axis' need to be provided." + ), + ): + + @mb.program(input_specs=[], opset_version=_IOS18_TARGET) + def prog(): + return mb.constexpr_lut_to_sparse( + indices_mask=np.array([1, 1, 1, 0, 1, 0]) + .reshape((2, 3)) + .astype(types.np_uint1_dtype), + indices_nonzero_data=np.array([0, 1, 0, 1]).astype(types.np_uint1_dtype), + lut=np.array([2.0, 3.0, 2.0, 3.0]).reshape((1, 1, 2, 2)), + ) + + @pytest.mark.parametrize("compute_unit, backend", itertools.product(compute_units, backends)) + def test_builder_to_backend_smoke(self, compute_unit, backend): + x_val = np.ones(18).astype(np.float32) + input_placeholders = {"x": mb.placeholder(shape=x_val.shape)} + + def build(x): + indices_mask = np.array( + [[1, 1, 0, 0, 0, 0], [1, 1, 0, 0, 0, 1], [0, 1, 1, 0, 1, 0], [0, 0, 0, 1, 0, 0]] + ).astype(types.np_uint1_dtype) + indices_nonzero_data = np.array([0, 1, 1, 0, 1, 1, 0, 0, 1]).astype( + types.np_uint1_dtype + ) + lut = np.array([[2.0, 2.0], [3.0, 3.0]]).reshape((1, 1, 2, 2)) + vector_axis = 0 + + output_mask, output_nonzero_data = mb.constexpr_lut_to_sparse( + indices_mask=indices_mask, + indices_nonzero_data=indices_nonzero_data, + lut=lut, + vector_axis=vector_axis, + ) + return mb.add(x=x, y=output_nonzero_data) + + expected_output = 1 + np.array( + [ + 2.0, + 3.0, + 2.0, + 3.0, + 3.0, + 2.0, + 3.0, + 3.0, + 2.0, + 3.0, + 3.0, + 2.0, + 2.0, + 3.0, + 2.0, + 2.0, + 3.0, + 3.0, + ] + ) + + run_compare_builder( + build, + input_placeholders, + input_values={"x": x_val}, + expected_output_types=expected_output.shape + (types.fp32,), + expected_outputs=expected_output, + compute_unit=compute_unit, + backend=backend, + ) + + @pytest.mark.parametrize( + "compute_unit, backend, nbits, block_sizes, vector_size, sparse_ratio, lut_dtype", + itertools.product( + compute_units, + backends, + [2, 3, 4, 6, 8], + [(0, 1, 1, 1), (0, 0, 0, 2), (0, 0, 0, 0), (1, 1, 1, 1), (0, 4, 2, 0), (4, 8, 16, 8)], + [1, 4], + [0.01, 0.5, 0.99], + ["fp16", "fp32"], # TODO (rdar://125859751): Add "int8" and "uint8". + ), + ) + def test_builder_to_backend_stress( + self, compute_unit, backend, nbits, block_sizes, vector_size, sparse_ratio, lut_dtype + ): + """Use constexpr_lut_to_sparse op's value inference to check backends outputs.""" + indices_shape = (4, 8, 16, 8) + indices_mask = np.random.choice( + [0, 1], size=indices_shape, p=[sparse_ratio, 1.0 - sparse_ratio] + ).astype(types.np_uint1_dtype) + indices_nonzero_element_num = np.sum(indices_mask) + indices_np_dtype = types.nptype_from_builtin(types.string_to_builtin(f"uint{nbits}")) + indices_nonzero_data = np.random.randint( + low=0, high=2**nbits, size=indices_nonzero_element_num + ).astype(indices_np_dtype) + + lut_np_dtype = types.nptype_from_builtin(types.string_to_builtin(lut_dtype)) + lut_shape = _infer_lut_shape(indices_shape, block_sizes, nbits, vector_size) + lut = np.random.rand(*lut_shape).astype(lut_np_dtype) + vector_axis = 0 if vector_size > 1 else None + + def build(x): + output_mask, output_nonzero_data = mb.constexpr_lut_to_sparse( + indices_mask=indices_mask, + indices_nonzero_data=indices_nonzero_data, + lut=lut, + vector_axis=vector_axis, + ) + x, output_nonzero_data = promote_input_dtypes([x, output_nonzero_data]) + return mb.add(x=x, y=output_nonzero_data) + + output_shape = int(indices_nonzero_element_num * vector_size) + x_val = np.ones(output_shape).astype(np.float32) + input_placeholders = {"x": mb.placeholder(shape=x_val.shape)} + expected_output = ( + constexpr_lut_to_sparse.decompress( + indices_mask, indices_nonzero_data, lut, vector_axis + )[1] + + 1 + ) + + run_compare_builder( + build, + input_placeholders, + input_values={"x": x_val}, + expected_output_types=expected_output.shape + (types.fp32,), + expected_outputs=expected_output, + compute_unit=compute_unit, + backend=backend, + ) + + +class TestConstexprSparseBlockwiseShiftScale: + def test_builder_eval_sparse_per_channel(self): + """ + Test per-channel de-quantization on sparse tensor. + + data_mask = [[1, 1, 0, 0], [1, 1, 1, 0], [0, 0, 1, 1], [1, 1, 0, 0]] + nonzero_data = [10, 11, 3, 4, 5, 6, 7, 8, 9] + scale = [[0.1, 0.2, 0.3, 0.4]] + offset = [[1, 2, 3, 4]] + The sparse tensor in the dense layout would look like: + 10 11 + 3 4 5 + 6 7 + 8 9 + + The input `nonzero_data` would be dequantized per-column as in the dense layout, and the + output sparse tensor in the dense layout would be: + (10-1)*0.1 (11-2)*0.2 + (3-1)*0.1 (4-2)*0.2 (5-3)*0.3 + (6-3)*0.3 (7-4)*0.4 + (8-1)*0.1 (9-2)*0.2 + + The first output would be the same as the `data_mask`, + The second output would be [0.9, 1.8, 0.2, 0.4, 0.6, 0.9, 1.2, 0.7, 1.4] + """ + + @mb.program(input_specs=[], opset_version=_IOS18_TARGET) + def prog(): + return mb.constexpr_sparse_blockwise_shift_scale( + data_mask=np.array([[1, 1, 0, 0], [1, 1, 1, 0], [0, 0, 1, 1], [1, 1, 0, 0]]).astype( + types.np_uint1_dtype + ), + nonzero_data=np.array([10, 11, 3, 4, 5, 6, 7, 8, 9]).astype(np.int8), + scale=np.array([[0.1, 0.2, 0.3, 0.4]]), + offset=np.array([[1, 2, 3, 4]]).astype(np.int8), + ) + + constexpr_sparse_blockwise_shift_scale_op = prog.functions["main"].find_ops( + op_type="constexpr_sparse_blockwise_shift_scale" + )[0] + expected_output_mask = np.array([[1, 1, 0, 0], [1, 1, 1, 0], [0, 0, 1, 1], [1, 1, 0, 0]]) + expected_output_nonzero_data = np.array([0.9, 1.8, 0.2, 0.4, 0.6, 0.9, 1.2, 0.7, 1.4]) + output_mask, output_nonzero_data = constexpr_sparse_blockwise_shift_scale_op.outputs[ + 0 + ].op.materialized_val_inference() + np.testing.assert_allclose(output_mask, expected_output_mask) + np.testing.assert_allclose(output_nonzero_data, expected_output_nonzero_data) + + def test_builder_eval_sparse_per_block(self): + """ + Test per-block de-quantization on sparse tensor with block size 2. + + data_mask = [[1, 1, 0, 0], [1, 1, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]] # shape [4, 4] + nonzero_data = [10, 11, 3, 4, 5, 6, 7, 8, 9, 2] + scale = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]] # shape [4, 2] because block size is [1, 2] + offset = [[1, 2], [3, 4], [5, 6], [7, 8]] + The sparse tensor in the dense layout would look like: + 10 11 + 3 4 5 + 6 7 + 8 9 2 + + The input `nonzero_data` would be dequantized per-column as in the dense layout, and the + output sparse tensor in the dense layout would be: + (10-1)*0.1 (11-1)*0.1 + (3-3)*0.3 (4-3)*0.3 (5-4)*0.4 + (6-6)*0.6 (7-6)*0.6 + (8-7)*0.7 (9-7)*0.7 (2-8)*0.8 + + The first output would be the same as the `data_mask`, + The second output would be [0.9, 1.0, 0.0, 0.3, 0.4, 0.0, 0.6, 0.7, 1.4, -4.8] + """ + + @mb.program(input_specs=[], opset_version=_IOS18_TARGET) + def prog(): + return mb.constexpr_sparse_blockwise_shift_scale( + data_mask=np.array([[1, 1, 0, 0], [1, 1, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]]).astype( + types.np_uint1_dtype + ), + nonzero_data=np.array([10, 11, 3, 4, 5, 6, 7, 8, 9, 2]).astype(np.int8), + scale=np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]), + offset=np.array([[1, 2], [3, 4], [5, 6], [7, 8]]).astype(np.int8), + ) + + constexpr_sparse_blockwise_shift_scale_op = prog.functions["main"].find_ops( + op_type="constexpr_sparse_blockwise_shift_scale" + )[0] + expected_output_mask = np.array([[1, 1, 0, 0], [1, 1, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]]) + expected_output_nonzero_data = np.array([0.9, 1.0, 0.0, 0.3, 0.4, 0.0, 0.6, 0.7, 1.4, -4.8]) + output_mask, output_nonzero_data = constexpr_sparse_blockwise_shift_scale_op.outputs[ + 0 + ].op.materialized_val_inference() + np.testing.assert_allclose(output_mask, expected_output_mask) + np.testing.assert_allclose(output_nonzero_data, expected_output_nonzero_data) + + def test_builder_eval_invalid_parameter(self): + with pytest.raises( + AssertionError, + match="Number of 1s in mask not match number of elements in parameter nonzero_data", + ): + + @mb.program(input_specs=[], opset_version=_IOS18_TARGET) + def prog(): + return mb.constexpr_sparse_blockwise_shift_scale( + data_mask=np.array([1, 1, 1, 0, 1, 0]) + .reshape((2, 3)) + .astype(types.np_uint1_dtype), + nonzero_data=np.array([0, 1, 0]).astype(np.int8), + scale=np.array([[0.1, 0.2, 0.3]]), + ) + + with pytest.raises( + ValueError, + match=re.escape("the shape of 'offset' should match the shape of 'scale'"), + ): + + @mb.program(input_specs=[], opset_version=_IOS18_TARGET) + def prog(): + return mb.constexpr_sparse_blockwise_shift_scale( + data_mask=np.array([1, 1, 1, 0, 1, 0]) + .reshape((2, 3)) + .astype(types.np_uint1_dtype), + nonzero_data=np.array([0, 1, 0, 1]).astype(np.int8), + scale=np.array([[0.1, 0.2, 0.3]]), + offset=np.array([[1, 2, 3, 4]]).astype(np.int8), + ) + + @pytest.mark.parametrize( + "compute_unit, backend, per_block, data_dtype", + itertools.product( + compute_units, + backends, + (True, False), + (types.uint4, types.int8, types.uint8, types.fp32), + ), + ) + def test_builder_to_backend_smoke(self, compute_unit, backend, per_block, data_dtype): + x_val = np.ones(10).astype(np.float32) + input_placeholders = {"x": mb.placeholder(shape=x_val.shape)} + np_dtype = types.nptype_from_builtin(data_dtype) + + def build(x): + data_mask_val = np.array( + [[1, 1, 0, 0], [1, 1, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]] + ).astype(types.np_uint1_dtype) + nonzero_data_val = np.array([10, 11, 3, 4, 5, 6, 7, 8, 9, 2]).astype(np_dtype) + + if per_block: + scale_val = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]) + else: + scale_val = np.array([[0.1, 0.2, 0.3, 0.4]]) + + if per_block: + offset_val = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]).astype(np_dtype) + else: + offset_val = np.array([[1, 2, 3, 4]]).astype(np_dtype) + + output_mask, output_nonzero_data = mb.constexpr_sparse_blockwise_shift_scale( + data_mask=data_mask_val, + nonzero_data=nonzero_data_val, + scale=scale_val, + offset=offset_val, + ) + return mb.add(x=x, y=output_nonzero_data) + + if per_block: + expected_output = np.array([0.9, 1.0, 0.0, 0.3, 0.4, 0.0, 0.6, 0.7, 1.4, -4.8]) + 1 + else: + expected_output = np.array([0.9, 1.8, 0.2, 0.4, 0.6, 0.9, 1.2, 0.7, 1.4, -0.8]) + 1 + run_compare_builder( + build, + input_placeholders, + input_values={"x": x_val}, + expected_output_types=expected_output.shape + (types.fp32,), + expected_outputs=expected_output, + compute_unit=compute_unit, + backend=backend, + atol=1e-3, + rtol=1e-3, + ) + + @pytest.mark.parametrize("compute_unit, backend", itertools.product(compute_units, backends)) + def test_builder_to_backend_corner_case(self, compute_unit, backend): + """ + This test case uses the real data from a conv model. + + It's for testing the scale/offset is correctly repeated and the joint ops + materialized_val_inference work as expected. + """ + + def build_weight(): + data_mask = np.array( + [ + [[[0, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 0], [1, 1]]], + [[[1, 1], [1, 1]], [[1, 1], [0, 0]], [[1, 1], [1, 1]], [[1, 0], [0, 0]]], + ] + ).astype(types.np_uint1_dtype) + data_mask, nonzero_data = mb.constexpr_sparse_blockwise_shift_scale( + data_mask=data_mask, + nonzero_data=np.array( + [ + -8, + -2, + 7, + -4, + -7, + -6, + -5, + 2, + -6, + 7, + -5, + 2, + -8, + -6, + -7, + -8, + -5, + -8, + 6, + 7, + 6, + -7, + 7, + 2, + -8, + ] + ).astype(np.int8), + scale=np.array([[[[0.01955]], [[0.02809]]], [[[0.02898]], [[0.02487]]]]), + offset=np.array([[[[3]], [[-1]]], [[[-2]], [[-3]]]]).astype(np.int8), + ) + return mb.constexpr_sparse_to_dense(nonzero_data=nonzero_data, mask=data_mask) + + def build(x): + return mb.add(x=x, y=build_weight()) + + # Get the const expected weight by decompressing val inference from the joint constexpr ops. + weight_prog = mb.program(input_specs=[], opset_version=_IOS18_TARGET)(build_weight) + result_op = weight_prog.functions["main"].find_ops(op_type="constexpr_sparse_to_dense")[0] + expected_weight = result_op.outputs[0].op.materialized_val_inference() + + x_val = np.ones(2 * 4 * 2 * 2).reshape((2, 4, 2, 2)).astype(np.float32) + expected_output = expected_weight + 1 + input_placeholders = {"x": mb.placeholder(shape=x_val.shape)} + # With joint quant + sparse ops, the backend prediction should match the expected_weight. + run_compare_builder( + build, + input_placeholders, + input_values={"x": x_val}, + expected_output_types=expected_output.shape + (types.fp32,), + expected_outputs=expected_output, + compute_unit=compute_unit, + backend=backend, + ) + + # Test conv using joint constexpr ops weight matches using the decompressed const weight. + def build_conv_with_joint_constexpr_weight(x): + return mb.conv(x=x, weight=build_weight()) + + def build_conv_with_const_weight(x): + return mb.conv(x=x, weight=expected_weight) + + x_val = np.random.rand(1, 4, 10, 10).astype(np.float32) + input_placeholders = {"x": mb.placeholder(shape=x_val.shape)} + mlmodel_conv_with_joint_constexpr_weight = run_compare_builder( + build_conv_with_joint_constexpr_weight, + input_placeholders, + input_values={"x": x_val}, + expected_output_types=(1, 2, 9, 9) + (types.fp32,), + frontend_only=True, + compute_unit=compute_unit, + backend=backend, + ) + mlmodel_conv_with_const_weight = run_compare_builder( + build_conv_with_const_weight, + input_placeholders, + input_values={"x": x_val}, + expected_output_types=(1, 2, 9, 9) + (types.fp32,), + frontend_only=True, + compute_unit=compute_unit, + backend=backend, + ) + result_1 = mlmodel_conv_with_joint_constexpr_weight.predict({"x": x_val}) + result_2 = mlmodel_conv_with_const_weight.predict({"x": x_val}) + + np.testing.assert_allclose(result_1["conv_0"], result_2["conv_0"], rtol=3e-3, atol=3e-4) + + @pytest.mark.parametrize("compute_unit, backend", itertools.product(compute_units, backends)) + def test_builder_to_backend_no_offset(self, compute_unit, backend): + """ + Test per-channel de-quantization on sparse tensor without offset. + + data_mask = [[1, 1, 0, 0], [1, 1, 1, 0], [0, 0, 1, 1], [1, 1, 0, 0]] + nonzero_data = [10, 11, 3, 4, 5, 6, 7, 8, 9] + scale = [[0.1, 0.2, 0.3, 0.4]] + The sparse tensor in the dense layout would look like: + 10 11 + 3 4 5 + 6 7 + 8 9 + + The input `nonzero_data` would be dequantized per-column as in the dense layout, and the + output sparse tensor in the dense layout would be: + (10)*0.1 (11)*0.2 + (3)*0.1 (4)*0.2 (5)*0.3 + (6)*0.3 (7)*0.4 + (8)*0.1 (9)*0.2 + + The first output would be the same as the `data_mask`, + The second output would be [1.0, 1.1, 0.3, 0.8, 1.5, 1.8, 2.8, 0.8, 1.8] + """ + data_dtype = types.int8 + x_val = np.ones(9).astype(np.float32) + input_placeholders = {"x": mb.placeholder(shape=x_val.shape)} + np_dtype = types.nptype_from_builtin(data_dtype) + + def build(x): + data_mask_val = np.array( + [[1, 1, 0, 0], [1, 1, 1, 0], [0, 0, 1, 1], [1, 1, 0, 0]] + ).astype(types.np_uint1_dtype) + nonzero_data_val = np.array([10, 11, 3, 4, 5, 6, 7, 8, 9]).astype(np_dtype) + scale_val = np.array([[0.1, 0.2, 0.3, 0.4]]) + + output_mask, output_nonzero_data = mb.constexpr_sparse_blockwise_shift_scale( + data_mask=data_mask_val, + nonzero_data=nonzero_data_val, + scale=scale_val, + ) + return mb.add(x=x, y=output_nonzero_data) + + expected_output = np.array([1.0, 2.2, 0.3, 0.8, 1.5, 1.8, 2.8, 0.8, 1.8]) + 1 + run_compare_builder( + build, + input_placeholders, + input_values={"x": x_val}, + expected_output_types=expected_output.shape + (types.fp32,), + expected_outputs=expected_output, + compute_unit=compute_unit, + backend=backend, + ) + + @pytest.mark.parametrize( + "compute_unit, backend, dtype, block_sizes, has_offset, sparse_ratio", + itertools.product( + compute_units, + backends, + ["int4", "uint4", "int8", "uint8", "fp16"], + [(0, 1, 1, 1), (0, 0, 0, 2), (0, 0, 0, 0), (1, 1, 1, 1), (0, 4, 2, 0), (4, 8, 16, 8)], + [True, False], + [0.01, 0.5, 0.99], + ), + ) + def test_builder_to_backend_stress( + self, compute_unit, backend, dtype, block_sizes, has_offset, sparse_ratio + ): + """ + Use constexpr_sparse_blockwise_shift_scale op's value inference to check backends outputs. + """ + quantized_data_shape = (4, 8, 16, 8) + builtin_dtype = types.string_to_builtin(dtype) + np_dtype = types.nptype_from_builtin(builtin_dtype) + + data_mask = np.random.choice( + [0, 1], size=quantized_data_shape, p=[sparse_ratio, 1.0 - sparse_ratio] + ).astype(types.np_uint1_dtype) + data_nonzero_element_num = int(np.sum(data_mask)) + + if types.is_int(builtin_dtype): + data_range = types.type_mapping.builtin_to_range(builtin_dtype) + quantized_data = np.random.randint( + low=data_range.low, high=data_range.high + 1, size=data_nonzero_element_num + ).astype(np_dtype) + else: + quantized_data = np.random.rand(data_nonzero_element_num).astype(np_dtype) + + scale_shape = [ + 1 if block_sizes[axis] == 0 else dim_size // block_sizes[axis] + for axis, dim_size in enumerate(quantized_data_shape) + ] + scale = np.random.rand(*scale_shape) + offset = None + if has_offset: + if types.is_int(builtin_dtype): + offset = np.random.randint( + low=data_range.low, high=data_range.high + 1, size=scale.shape + ).astype(np_dtype) + else: + offset = np.random.rand(*scale.shape).astype(np_dtype) + + def build(x): + output_mask, output_nonzero_data = mb.constexpr_sparse_blockwise_shift_scale( + data_mask=data_mask, + nonzero_data=quantized_data, + scale=scale, + offset=offset, + ) + return mb.add(x=x, y=output_nonzero_data) + + x_val = np.ones_like(quantized_data).astype(np.float32) + input_placeholders = {"x": mb.placeholder(shape=x_val.shape)} + expected_output = ( + constexpr_sparse_blockwise_shift_scale.decompress( + data_mask, quantized_data, scale, offset + )[1] + + 1 + ) + + run_compare_builder( + build, + input_placeholders, + input_values={"x": x_val}, + expected_output_types=expected_output.shape + (types.fp32,), + expected_outputs=expected_output, + compute_unit=compute_unit, + backend=backend, + ) + + +class TestJointCompressionOps: + @pytest.mark.parametrize( + "compute_unit, backend, nbits, block_sizes, vector_size, lut_dtype, quant_dtype", + itertools.product( + compute_units, + backends, + [2, 3, 4, 8], + [(0, 2, 0, 0), (2, 0, 0, 0), (4, 2, 0, 0)], + [1, 4], + ["fp16", "fp32"], + ["int4", "uint4", "int8", "uint8"], + ), + ) + def test_quant_lut( + self, compute_unit, backend, nbits, block_sizes, vector_size, lut_dtype, quant_dtype + ): + """ + Test lut with quantized (int8) entries, which is represented as + lut(int8) -> constexpr_blockwise_shift_scale -> lut(fp) \ + constexpr_lut_to_dense -> dense(fp) + indices / + """ + indices_shape = (4, 8, 16, 8) + builtin_dtype = types.string_to_builtin(f"uint{nbits}") + np_dtype = types.nptype_from_builtin(builtin_dtype) + indices = np.random.randint(low=0, high=2**nbits, size=indices_shape).astype(np_dtype) + + lut_np_dtype = types.nptype_from_builtin(types.string_to_builtin(lut_dtype)) + lut_shape = _infer_lut_shape(indices_shape, block_sizes, nbits, vector_size) + vector_axis = 0 if vector_size > 1 else None + + quant_builtin_dtype = types.string_to_builtin(quant_dtype) + quant_np_dtype = types.nptype_from_builtin(quant_builtin_dtype) + quant_data_range = types.type_mapping.builtin_to_range(quant_builtin_dtype) + quantized_data = np.random.randint( + low=quant_data_range.low, high=quant_data_range.high + 1, size=lut_shape + ).astype(quant_np_dtype) + scale_shape = tuple([1] * len(lut_shape)) + scale = np.array([2.0]).reshape(scale_shape).astype(lut_np_dtype) + offset = np.array([3]).reshape(scale_shape).astype(quant_np_dtype) + + def build(x): + lut = mb.constexpr_blockwise_shift_scale( + data=quantized_data, + scale=scale, + offset=offset, + ) + output = mb.constexpr_lut_to_dense( + indices=indices, + lut=lut, + vector_axis=vector_axis, + ) + x, output = promote_input_dtypes([x, output]) + return mb.add(x=x, y=output) + + output_shape = list(indices.shape) + if vector_size > 1: + output_shape[vector_axis] *= vector_size + x_val = np.ones(output_shape).astype(np.float32) + input_placeholders = {"x": mb.placeholder(shape=x_val.shape)} + lut = constexpr_blockwise_shift_scale.decompress(quantized_data, scale, offset) + expected_output = ( + constexpr_lut_to_dense.decompress(indices, lut, vector_axis=vector_axis) + 1 + ) + + run_compare_builder( + build, + input_placeholders, + input_values={"x": x_val}, + expected_output_types=expected_output.shape + (types.fp32,), + expected_outputs=expected_output, + compute_unit=compute_unit, + backend=backend, + ) + + @pytest.mark.parametrize( + "compute_unit, backend, nbits, block_sizes, vector_size, sparse_ratio, lut_dtype", + itertools.product( + compute_units, + backends, + [2, 3, 4, 8], + [(0, 2, 0, 0), (2, 0, 0, 0), (1, 1, 1, 1), (4, 2, 0, 0)], + [1, 4], + [0.01, 0.5, 0.99], + ["fp16", "fp32"], # TODO (rdar://125859751): Add "int8" and "uint8". + ), + ) + def test_sparse_lut( + self, compute_unit, backend, nbits, block_sizes, vector_size, sparse_ratio, lut_dtype + ): + """Joint constexpr_lut_to_sparse + constexpr_sparse_to_dense.""" + indices_shape = (4, 8, 16, 8) + indices_mask = np.random.choice( + [0, 1], size=indices_shape, p=[sparse_ratio, 1.0 - sparse_ratio] + ).astype(types.np_uint1_dtype) + indices_nonzero_element_num = np.sum(indices_mask) + indices_np_dtype = types.nptype_from_builtin(types.string_to_builtin(f"uint{nbits}")) + indices_nonzero_data = np.random.randint( + low=0, high=2**nbits, size=indices_nonzero_element_num + ).astype(indices_np_dtype) + + lut_np_dtype = types.nptype_from_builtin(types.string_to_builtin(lut_dtype)) + lut_shape = _infer_lut_shape(indices_shape, block_sizes, nbits, vector_size) + lut = np.random.rand(*lut_shape).astype(lut_np_dtype) + vector_axis = 0 if vector_size > 1 else None + + def build(x): + output_mask, output_nonzero_data = mb.constexpr_lut_to_sparse( + indices_mask=indices_mask, + indices_nonzero_data=indices_nonzero_data, + lut=lut, + vector_axis=vector_axis, + ) + output = mb.constexpr_sparse_to_dense( + nonzero_data=output_nonzero_data, + mask=output_mask, + ) + x, output = promote_input_dtypes([x, output]) + return mb.add(x=x, y=output) + + output_mask, output_nonzero_data = constexpr_lut_to_sparse.decompress( + indices_mask, indices_nonzero_data, lut, vector_axis + ) + expected_output = constexpr_sparse_to_dense.decompress(output_nonzero_data, output_mask) + 1 + + x_val = np.ones(expected_output.shape).astype(np.float32) + input_placeholders = {"x": mb.placeholder(shape=x_val.shape)} + run_compare_builder( + build, + input_placeholders, + input_values={"x": x_val}, + expected_output_types=expected_output.shape + (types.fp32,), + expected_outputs=expected_output, + compute_unit=compute_unit, + backend=backend, + ) + + @pytest.mark.parametrize( + "compute_unit, backend, dtype, block_sizes, has_offset, sparse_ratio", + itertools.product( + compute_units, + backends, + ["int4", "uint4", "int8", "uint8", "fp16"], + [(0, 2, 0, 0), (2, 0, 0, 0), (1, 1, 1, 1), (4, 2, 0, 0)], + [True, False], + [0.01, 0.5, 0.99], + ), + ) + def test_sparse_quant( + self, compute_unit, backend, dtype, block_sizes, has_offset, sparse_ratio + ): + """Joint constexpr_sparse_blockwise_shift_scale + constexpr_sparse_to_dense.""" + quantized_data_shape = (4, 8, 16, 8) + builtin_dtype = types.string_to_builtin(dtype) + np_dtype = types.nptype_from_builtin(builtin_dtype) + + data_mask = np.random.choice( + [0, 1], size=quantized_data_shape, p=[sparse_ratio, 1.0 - sparse_ratio] + ).astype(types.np_uint1_dtype) + data_nonzero_element_num = int(np.sum(data_mask)) + + if types.is_int(builtin_dtype): + data_range = types.type_mapping.builtin_to_range(builtin_dtype) + quantized_data = np.random.randint( + low=data_range.low, high=data_range.high + 1, size=data_nonzero_element_num + ).astype(np_dtype) + else: + quantized_data = np.random.rand(data_nonzero_element_num).astype(np_dtype) + + scale_shape = [ + 1 if block_sizes[axis] == 0 else dim_size // block_sizes[axis] + for axis, dim_size in enumerate(quantized_data_shape) + ] + scale = np.random.rand(*scale_shape) + offset = None + if has_offset: + if types.is_int(builtin_dtype): + offset = np.random.randint( + low=data_range.low, high=data_range.high + 1, size=scale.shape + ).astype(np_dtype) + else: + offset = np.random.rand(*scale.shape).astype(np_dtype) + + def build(x): + output_mask, output_nonzero_data = mb.constexpr_sparse_blockwise_shift_scale( + data_mask=data_mask, + nonzero_data=quantized_data, + scale=scale, + offset=offset, + ) + output = mb.constexpr_sparse_to_dense( + nonzero_data=output_nonzero_data, + mask=output_mask, + ) + return mb.add(x=x, y=output) + + output_mask, output_nonzero_data = constexpr_sparse_blockwise_shift_scale.decompress( + data_mask, quantized_data, scale, offset + ) + expected_output = constexpr_sparse_to_dense.decompress(output_nonzero_data, output_mask) + 1 + + x_val = np.ones(expected_output.shape).astype(np.float32) + input_placeholders = {"x": mb.placeholder(shape=x_val.shape)} + run_compare_builder( + build, + input_placeholders, + input_values={"x": x_val}, + expected_output_types=expected_output.shape + (types.fp32,), + expected_outputs=expected_output, + compute_unit=compute_unit, + backend=backend, + ) + + @pytest.mark.parametrize( + "compute_unit, backend, nbits, block_sizes, vector_size, sparse_ratio, lut_dtype, quant_dtype", + itertools.product( + compute_units, + backends, + [2, 3, 4, 8], + [(0, 2, 0, 0), (2, 0, 0, 0), (4, 2, 0, 0)], + [1, 4], + [0.01, 0.5, 0.99], + ["fp16", "fp32"], + ["int4", "uint4", "int8", "uint8"], + ), + ) + def test_quant_sparse_lut( + self, + compute_unit, + backend, + nbits, + block_sizes, + vector_size, + sparse_ratio, + lut_dtype, + quant_dtype, + ): + """ + Test sparse lut with quantized (int8) entries, which is represented as + constexpr_blockwise_shift_scale + constexpr_lut_to_sparse + constexpr_sparse_to_dense + """ + indices_shape = (4, 8, 16, 8) + indices_mask = np.random.choice( + [0, 1], size=indices_shape, p=[sparse_ratio, 1.0 - sparse_ratio] + ).astype(types.np_uint1_dtype) + indices_nonzero_element_num = np.sum(indices_mask) + indices_np_dtype = types.nptype_from_builtin(types.string_to_builtin(f"uint{nbits}")) + indices_nonzero_data = np.random.randint( + low=0, high=2**nbits, size=indices_nonzero_element_num + ).astype(indices_np_dtype) + + lut_np_dtype = types.nptype_from_builtin(types.string_to_builtin(lut_dtype)) + lut_shape = _infer_lut_shape(indices_shape, block_sizes, nbits, vector_size) + vector_axis = 0 if vector_size > 1 else None + + quant_builtin_dtype = types.string_to_builtin(quant_dtype) + quant_np_dtype = types.nptype_from_builtin(quant_builtin_dtype) + quant_data_range = types.type_mapping.builtin_to_range(quant_builtin_dtype) + quantized_data = np.random.randint( + low=quant_data_range.low, high=quant_data_range.high + 1, size=lut_shape + ).astype(quant_np_dtype) + scale_shape = tuple([1] * len(lut_shape)) + scale = np.array([2.0]).reshape(scale_shape).astype(lut_np_dtype) + offset = np.array([3]).reshape(scale_shape).astype(quant_np_dtype) + + def build(x): + lut = mb.constexpr_blockwise_shift_scale( + data=quantized_data, + scale=scale, + offset=offset, + ) + output_mask, output_nonzero_data = mb.constexpr_lut_to_sparse( + indices_mask=indices_mask, + indices_nonzero_data=indices_nonzero_data, + lut=lut, + vector_axis=vector_axis, + ) + output = mb.constexpr_sparse_to_dense( + nonzero_data=output_nonzero_data, + mask=output_mask, + ) + x, output = promote_input_dtypes([x, output]) + return mb.add(x=x, y=output) + + lut = constexpr_blockwise_shift_scale.decompress(quantized_data, scale, offset) + output_mask, output_nonzero_data = constexpr_lut_to_sparse.decompress( + indices_mask, indices_nonzero_data, lut, vector_axis + ) + expected_output = constexpr_sparse_to_dense.decompress(output_nonzero_data, output_mask) + 1 + + x_val = np.ones(expected_output.shape).astype(np.float32) + input_placeholders = {"x": mb.placeholder(shape=x_val.shape)} + run_compare_builder( + build, + input_placeholders, + input_values={"x": x_val}, + expected_output_types=expected_output.shape + (types.fp32,), + expected_outputs=expected_output, + compute_unit=compute_unit, + backend=backend, + ) diff --git a/coremltools/converters/mil/mil/ops/tests/iOS18/test_recurrent.py b/coremltools/converters/mil/mil/ops/tests/iOS18/test_recurrent.py new file mode 100644 index 000000000..299c18486 --- /dev/null +++ b/coremltools/converters/mil/mil/ops/tests/iOS18/test_recurrent.py @@ -0,0 +1,178 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import itertools + +import numpy as np +import pytest +import torch + +import coremltools as ct +from coremltools.converters.mil.mil import Builder as mb, types +from coremltools.converters.mil.mil.ops.tests.iOS17.test_recurrent import TestGRU as _TestGRU_iOS17 +from coremltools.converters.mil.mil.ops.tests.iOS18 import backends +from coremltools.converters.mil.mil.ops.tests.testing_utils import run_compare_builder +from coremltools.converters.mil.testing_reqs import compute_units + + +class TestGRU(_TestGRU_iOS17): + # Test functionality from previous opset version + @pytest.mark.parametrize( + argnames=[ + "compute_unit", + "backend", + "seq_len", + "batch_size", + "input_size", + "hidden_size", + "has_bias", + "output_sequence", + "direction", + "activation_functions", + "symbolic", + "dtype", + ], + argvalues=itertools.product( + compute_units, + backends, + [1, 3], + [1], + [1, 2], + [1, 2], + [True, False], + [True, False], + ["forward", "reverse"], + [ + ["tanh", "sigmoid"], + ["sigmoid", "tanh"], + ], + [True, False], + [np.float16, np.float32], + ), + ) + def test_builder_to_backend_smoke( + self, + compute_unit, + backend, + seq_len, + batch_size, + input_size, + hidden_size, + has_bias, + output_sequence, + direction, + activation_functions, + symbolic, + dtype, + ): + super().test_builder_to_backend_smoke( + compute_unit, + backend, + seq_len, + batch_size, + input_size, + hidden_size, + has_bias, + output_sequence, + direction, + activation_functions, + symbolic, + dtype, + ) + + + @pytest.mark.xfail(reason="rdar://128479517") + @pytest.mark.parametrize( + argnames=[ + "compute_units", + "backend", + "sequence_length", + "num_features", # also called "input_size" + "hidden_size", + "batch_size", + ], + argvalues=itertools.product( + compute_units, + backends, + [1, 3], + [1, 2], + [1], + [1, 2], + ), + ) + def test_pytorch_parity(self, backend, compute_units, sequence_length, num_features, hidden_size, batch_size): + + def get_weight_i_tensor(): + return np.random.rand(hidden_size, num_features).astype('float32') + + def get_weight_h_tensor(): + return np.random.rand(hidden_size, hidden_size).astype('float32') + + def get_bias_tensor(): + return np.random.rand(hidden_size).astype('float32') + + W_ir, W_iz, W_in = get_weight_i_tensor(), get_weight_i_tensor(), get_weight_i_tensor() + W_hr, W_hz, W_hn = get_weight_h_tensor(), get_weight_h_tensor(), get_weight_h_tensor() + + b_ir, b_iz, b_in = get_bias_tensor(), get_bias_tensor(), get_bias_tensor() + b_hr, b_hz, b_hn = get_bias_tensor(), get_bias_tensor(), get_bias_tensor() + + # MIL op only supports single direction and single layer + x = np.random.rand(sequence_length, batch_size, num_features).astype('float16') + initial_h = np.random.rand(1, batch_size, hidden_size).astype('float16') + + # Set up PyTorch model + m_t = torch.nn.GRU(num_features, hidden_size) + t_state = m_t.state_dict() + t_state['weight_ih_l0'] = torch.Tensor(np.concatenate((W_ir, W_iz, W_in))) + t_state['weight_hh_l0'] = torch.Tensor(np.concatenate((W_hr, W_hz, W_hn))) + t_state['bias_ih_l0'] = torch.Tensor(np.concatenate((b_ir, b_iz, b_in))) + t_state['bias_hh_l0'] = torch.Tensor(np.concatenate((b_hr, b_hz, b_hn))) + m_t.load_state_dict(t_state) + + # Get PyTorch results + (out_t, h_t) = m_t(torch.Tensor(x), torch.Tensor(initial_h)) + out_t = out_t.detach().numpy() + h_t = h_t.detach().numpy() + + # MIL op only support num_layers=1 and D=1, so hidden state only has rank 2 + initial_h = initial_h.squeeze(0) + + # MIL program + @mb.program( + [ + mb.TensorSpec(shape=x.shape, dtype=types.fp32), + mb.TensorSpec(shape=initial_h.shape, dtype=types.fp32) + ], + opset_version=backend.opset_version + ) + def prog(x, initial_h): + return mb.gru( + x=x, + initial_h=initial_h, + weight_ih=np.concatenate((W_ir, W_in, W_iz)), + weight_hh=np.concatenate((W_hr, W_hn, W_hz)), + input_bias=np.concatenate((b_ir, b_in, b_iz)), + bias=np.concatenate((b_hr, b_hn, b_hz)), + reset_after=True, + output_sequence=True, + ) + + mlmodel = ct.convert( + prog, + source="milinternal", + convert_to=backend.backend, + minimum_deployment_target=backend.opset_version, + compute_units=compute_units, + pass_pipeline=ct.PassPipeline.EMPTY, + ) + + # Core ML ouput + y_cm = mlmodel.predict({'x': x, 'initial_h': initial_h}) + out_cm, h_cm = y_cm['gru_0_0'], y_cm['gru_0_1'] + + # Check outputs + np.testing.assert_allclose(out_cm, out_t, atol=0.01, rtol=0.1) + np.testing.assert_allclose([h_cm], h_t, atol=0.01, rtol=0.1) diff --git a/coremltools/converters/mil/mil/ops/tests/iOS18/test_states.py b/coremltools/converters/mil/mil/ops/tests/iOS18/test_states.py new file mode 100644 index 000000000..c7204af6a --- /dev/null +++ b/coremltools/converters/mil/mil/ops/tests/iOS18/test_states.py @@ -0,0 +1,342 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import itertools + +import numpy as np +import pytest + +import coremltools as ct +from coremltools.converters.mil.mil import types +from coremltools.converters.mil.mil.builder import Builder as mb +from coremltools.converters.mil.mil.ops.defs.iOS18 import _IOS18_TARGET +from coremltools.converters.mil.mil.ops.tests.iOS18 import backends +from coremltools.converters.mil.mil.ops.tests.testing_utils import run_compare_builder +from coremltools.converters.mil.testing_reqs import compute_units +from coremltools.converters.mil.testing_utils import random_gen + + +class TestCoreMLUpdateState: + @pytest.mark.parametrize( + "compute_unit, backend", + itertools.product( + compute_units, + backends, + ), + ) + def test_coreml_update_state_smoke(self, compute_unit, backend): + def build(state, value): + return mb.coreml_update_state( + state=state, + value=value, + ) + + input_placeholders = { + "state": mb.state_tensor_placeholder( + shape=(2,), + dtype=types.fp16, + ), + "value": mb.placeholder(shape=(2,), dtype=types.fp16), + } + value = random_gen((2,)) + input_values = {"value": value} + + run_compare_builder( + build, + input_placeholders, + input_values, + expected_output_types=[(2, types.fp16)], + expected_outputs=[value], + compute_unit=compute_unit, + backend=backend, + ) + + @pytest.mark.parametrize( + "compute_unit, backend, shape", + itertools.product( + compute_units, + backends, + [(1,), (2, 3), (4, 5, 6)], + ), + ) + def test_coreml_update_stress(self, compute_unit, backend, shape): + if not compute_unit in (ct.ComputeUnit.CPU_ONLY, ct.ComputeUnit.CPU_AND_GPU): + pytest.xfail( + "rdar://128446982 ([Bug][Stateful model] Stateful model fails to run on ANE)" + ) + + def build(x_in, y_in, z_in): + def increase_val_by_one(state, input): + v = mb.add(x=input, y=np.float16(1)) + return mb.coreml_update_state(state=state, value=v) + + x = mb.read_state(input=x_in) + y = mb.read_state(input=y_in) + z = mb.read_state(input=z_in) + + for i in range(10): + x = increase_val_by_one(x_in, x) + y = increase_val_by_one(y_in, y) + z = increase_val_by_one(z_in, z) + + return mb.read_state(input=x_in), mb.read_state(input=y_in), mb.read_state(input=z_in) + + input_placeholders = { + "x_in": mb.state_tensor_placeholder( + shape=shape, + dtype=types.fp16, + ), + "y_in": mb.state_tensor_placeholder( + shape=shape, + dtype=types.fp16, + ), + "z_in": mb.state_tensor_placeholder( + shape=shape, + dtype=types.fp16, + ), + } + input_values = {} + + run_compare_builder( + build, + input_placeholders, + input_values, + expected_output_types=[ + ( + *shape, + types.fp16, + ) + ] + * 3, + expected_outputs=[ + [ + 10 + * np.ones( + shape, + ) + ] + * 3, + [ + 20 + * np.ones( + shape, + ) + ] + * 3, + ], + compute_unit=compute_unit, + backend=backend, + pred_iters=2, + ) + + +class TestReadState: + @staticmethod + def test_read_tensor_state_builder(): + @mb.program(input_specs=[mb.StateTensorSpec((2, 3))], opset_version=_IOS18_TARGET) + def prog(x): + return mb.read_state(input=x) + + read_state_op = prog.find_ops("read_state")[0] + assert types.is_state(read_state_op.input._sym_type) + assert types.is_tensor(read_state_op.outputs[0]._sym_type) + + @pytest.mark.parametrize( + "compute_unit, backend", + itertools.product( + compute_units, + backends, + ), + ) + def test_read_state_smoke(self, compute_unit, backend): + def build(state): + return mb.read_state( + input=state, + ) + + input_placeholders = { + "state": mb.state_tensor_placeholder( + shape=(2,), + dtype=types.fp16, + ), + } + input_values = {} + + run_compare_builder( + build, + input_placeholders, + input_values, + expected_output_types=[(2, types.fp16)], + expected_outputs=[ + np.zeros( + 2, + ) + ], + compute_unit=compute_unit, + backend=backend, + ) + + @pytest.mark.parametrize( + "compute_unit, backend, shape", + itertools.product(compute_units, backends, [(1,), (2, 3), (4, 5, 6)]), + ) + def test_read_state_stress(self, compute_unit, backend, shape): + def build(x, y, z): + return ( + mb.read_state( + input=x, + ), + mb.read_state( + input=y, + ), + mb.read_state( + input=z, + ), + ) + + input_placeholders = { + "x": mb.state_tensor_placeholder( + shape=shape, + dtype=types.fp16, + ), + "y": mb.state_tensor_placeholder( + shape=shape, + dtype=types.fp16, + ), + "z": mb.state_tensor_placeholder( + shape=shape, + dtype=types.fp16, + ), + } + input_values = {} + + run_compare_builder( + build, + input_placeholders, + input_values, + expected_output_types=[ + ( + *shape, + types.fp16, + ) + ] + * 3, + expected_outputs=[ + np.zeros( + shape, + ) + ] + * 3, + compute_unit=compute_unit, + backend=backend, + ) + + +class TestStatefulModel: + @pytest.mark.parametrize( + "compute_unit, backend", + itertools.product( + compute_units, + backends, + ), + ) + def test_state_model_with_slice_update(self, compute_unit, backend): + def build(x_in, y_in, z_in, update_1, update_2): + def single_slice_update(state, input): + v = mb.slice_update( + x=input, + update=update_1, + begin=[0, 0], + end=[1, 2], + ) + return mb.coreml_update_state(state=state, value=v) + + def double_slice_update(state, input): + v = mb.slice_update( + x=input, + update=update_1, + begin=[0, 0], + end=[1, 2], + ) + v = mb.slice_update( + x=input, + update=update_2, + begin=[1, 1], + end=[3, 3], + ) + return mb.coreml_update_state(state=state, value=v) + + x = mb.read_state(input=x_in) + y = mb.read_state(input=y_in) + z = mb.read_state(input=z_in) + + for i in range(10): + # single slice update + x = single_slice_update(x_in, x) + y = single_slice_update(y_in, y) + z = single_slice_update(z_in, z) + + # double slice update + x = double_slice_update(x_in, x) + y = double_slice_update(y_in, y) + z = double_slice_update(z_in, z) + + return mb.read_state(input=x_in), mb.read_state(input=y_in), mb.read_state(input=z_in) + + shape = (8, 9) + + input_placeholders = { + "x_in": mb.state_tensor_placeholder( + shape=shape, + dtype=types.fp16, + ), + "y_in": mb.state_tensor_placeholder( + shape=shape, + dtype=types.fp16, + ), + "z_in": mb.state_tensor_placeholder( + shape=shape, + dtype=types.fp16, + ), + "update_1": mb.placeholder( + shape=(1, 2), + dtype=types.fp16, + ), + "update_2": mb.placeholder( + shape=(2, 2), + dtype=types.fp16, + ), + } + + update_1_val = np.array([[1, 2]], dtype=np.float16) + update_2_val = np.array([[1, 2], [3, 4]], dtype=np.float16) + input_values = { + "update_1": update_1_val, + "update_2": update_2_val, + } + + output = np.zeros(shape, dtype=np.float16) + output[:1, :2] = update_1_val + output[1:3, 1:3] = update_2_val + + run_compare_builder( + build, + input_placeholders, + input_values, + expected_output_types=[ + ( + *shape, + types.fp16, + ) + ] + * 3, + expected_outputs=[ + [output] * 3, + [output] * 3, + ], + compute_unit=compute_unit, + backend=backend, + pred_iters=2, + ) diff --git a/coremltools/converters/mil/mil/ops/tests/iOS18/test_tensor_transformation.py b/coremltools/converters/mil/mil/ops/tests/iOS18/test_tensor_transformation.py new file mode 100644 index 000000000..8c363dbc5 --- /dev/null +++ b/coremltools/converters/mil/mil/ops/tests/iOS18/test_tensor_transformation.py @@ -0,0 +1,545 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import itertools + +import numpy as np +import pytest + +from coremltools.converters.mil._deployment_compatibility import AvailableTarget as target +from coremltools.converters.mil.mil import Builder as mb +from coremltools.converters.mil.mil import types +from coremltools.converters.mil.mil.ops.tests.iOS18 import backends +from coremltools.converters.mil.mil.ops.tests.testing_utils import run_compare_builder +from coremltools.converters.mil.testing_reqs import compute_units + + +def _test_eval( + x, + update, + begin, + end, + stride=None, + begin_mask=None, + end_mask=None, + squeeze_mask=None, + ans=None, + compute_unit=None, + backend=None, + x_builtin_dtype=None, + run_conversion_test=True, +): + # Test the value inference in pymil + @mb.program(input_specs=[], opset_version=target.iOS18) + def prog(): + res = mb.slice_update( + x=x, + update=update, + begin=begin, + end=end, + stride=stride, + begin_mask=begin_mask, + end_mask=end_mask, + squeeze_mask=squeeze_mask, + ) + assert res.shape == ans.shape + np.testing.assert_allclose(ans, res.val, atol=1e-04, rtol=1e-05) + return res + + if not run_conversion_test: + return + + # pymil to backend test + x_val = np.array(x, dtype=np.float32) + update_val = np.array(update, dtype=np.float32) + begin_val = np.array(begin, dtype=np.int32) + end_val = np.array(end, dtype=np.int32) + + input_placeholders = { + "x": mb.placeholder(shape=x_val.shape, dtype=x_builtin_dtype), + "update": mb.placeholder(shape=update_val.shape, dtype=x_builtin_dtype), + "begin": mb.placeholder(shape=begin_val.shape, dtype=types.int32), + "end": mb.placeholder(shape=end_val.shape, dtype=types.int32), + } + + input_values = {"x": x_val, "update": update_val, "begin": begin_val, "end": end_val} + + expected_output_shape = list(ans.shape) + expected_output_types = [expected_output_shape + [types.fp32]] + expected_outputs = [ans] + + def build(x, update, begin, end): + return mb.slice_update( + x=x, + update=update, + begin=begin, + end=end, + begin_mask=begin_mask, + end_mask=end_mask, + squeeze_mask=squeeze_mask, + stride=stride, + ) + + run_compare_builder( + build, + input_placeholders, + input_values, + expected_output_types, + expected_outputs, + compute_unit=compute_unit, + backend=backend, + ) + + +class TestSliceUpdate: + @pytest.mark.parametrize( + "compute_unit, backend, x_dtype, idx_dtype", + itertools.product( + compute_units, + backends, + (np.float16, np.float32, np.int32), + (np.int16, np.int32, np.int8), + ), + ) + def test_builder_to_backend_smoke(self, compute_unit, backend, x_dtype, idx_dtype): + x_builtin_dtype = types.numpy_type_to_builtin_type(x_dtype) + idx_builtin_dtype = types.numpy_type_to_builtin_type(idx_dtype) + + x_val = np.array(list(range(24))).reshape((2, 3, 4)).astype(x_dtype) + update_val = np.array([[[-1, -2], [-3, -4]]]).astype(x_dtype) + begin_val = np.array([1, 1, 1], dtype=idx_dtype) + end_val = np.array([2, 3, 3], dtype=idx_dtype) + + input_placeholders = { + "x": mb.placeholder(shape=x_val.shape, dtype=x_builtin_dtype), + "update": mb.placeholder(shape=update_val.shape, dtype=x_builtin_dtype), + "begin": mb.placeholder(shape=begin_val.shape, dtype=idx_builtin_dtype), + "end": mb.placeholder(shape=end_val.shape, dtype=idx_builtin_dtype), + } + + input_values = {"x": x_val, "update": update_val, "begin": begin_val, "end": end_val} + + expected_output_types = [(2, 3, 4, x_builtin_dtype)] * 2 + copy_x_val = np.array(x_val, dtype=x_dtype) + copy_x_val[1:2, 1:3, 1:3] = update_val + expected_outputs = [copy_x_val, copy_x_val] + + def build(x, update, begin, end): + begin_c = mb.const(val=begin_val) + end_c = mb.const(val=end_val) + update_c = mb.const(val=update_val) + return [ + mb.slice_update(x=x, update=update, begin=begin, end=end), + mb.slice_update(x=x, update=update_c, begin=begin_c, end=end_c), + ] + + run_compare_builder( + build, + input_placeholders, + input_values, + expected_output_types, + expected_outputs, + compute_unit=compute_unit, + backend=backend, + ) + + @pytest.mark.parametrize( + "compute_unit, backend, x_dtype, idx_dtype", + itertools.product( + compute_units, + backends, + (np.float16, np.float32, np.int32), + (np.int16, np.int32, np.int8), + ), + ) + def test_stress(self, compute_unit, backend, x_dtype, idx_dtype): + x_val = np.array(list(range(24))).reshape((2, 3, 4)).astype(x_dtype) + x_builtin_dtype = types.numpy_type_to_builtin_type(x_dtype) + + update = np.random.rand(1, 1, 1).astype(x_dtype) + ans = np.copy(x_val) + ans[1:2, 1:2, 1:2] = update + _test_eval( + x=x_val, + begin=[1, 1, 1], + end=[2, 2, 2], + update=update, + ans=ans, + compute_unit=compute_unit, + backend=backend, + x_builtin_dtype=x_builtin_dtype, + ) + + update = np.random.rand(1, 2, 2).astype(x_dtype) + ans = np.copy(x_val) + ans[1:2, 1:3, 1:4:2] = update + _test_eval( + x=x_val, + begin=[1, 1, 1], + end=[2, 3, 4], + stride=[1, 1, 2], + update=update, + ans=ans, + compute_unit=compute_unit, + backend=backend, + x_builtin_dtype=x_builtin_dtype, + ) + + update = np.random.rand(1, 2, 2).astype(x_dtype) + ans = np.copy(x_val) + ans[-3:-1, -3:-1, -3:-1] = update + _test_eval( + x=x_val, + begin=[-3, -3, -3], + end=[-1, -1, -1], + update=update, + ans=ans, + compute_unit=compute_unit, + backend=backend, + x_builtin_dtype=x_builtin_dtype, + # rdar://128037672 ([Bug][iOS18][Classic CPU] slice_update fails on classic CPU on an unittest) + run_conversion_test=False, + ) + + update = np.random.rand(1, 1, 1).astype(x_dtype) + ans = np.copy(x_val) + ans[0:-1, 0:-2, -3:-2] = update + _test_eval( + x=x_val, + begin=[0, 0, -3], + end=[-1, -2, -2], + update=update, + ans=ans, + compute_unit=compute_unit, + backend=backend, + x_builtin_dtype=x_builtin_dtype, + ) + + update = np.random.rand(1, 1, 1).astype(x_dtype) + ans = np.copy(x_val) + ans[-1:0:-2, -1:1:-1, -1:-3:-3] = update + _test_eval( + x=x_val, + begin=[-1, -1, -1], + end=[0, 1, -3], + stride=[-2, -1, -3], + update=update, + ans=ans, + compute_unit=compute_unit, + backend=backend, + x_builtin_dtype=x_builtin_dtype, + ) + + update = np.random.rand(2, 2, 2).astype(x_dtype) + ans = np.copy(x_val) + ans[:2, 1:3, :4:2] = update + _test_eval( + x=x_val, + begin=[1, 1, 1], + end=[2, 3, 4], + stride=[1, 1, 2], + begin_mask=[True, False, True], + update=update, + ans=ans, + compute_unit=compute_unit, + backend=backend, + x_builtin_dtype=x_builtin_dtype, + ) + + update = np.random.rand(2, 2, 2).astype(x_dtype) + ans = np.copy(x_val) + ans[:, 1:, :4:2] = update + _test_eval( + x=x_val, + begin=[1, 1, 1], + end=[2, 3, 4], + stride=[1, 1, 2], + begin_mask=[True, False, True], + end_mask=[True, True, False], + update=update, + ans=ans, + compute_unit=compute_unit, + backend=backend, + x_builtin_dtype=x_builtin_dtype, + ) + + update = np.random.rand(1, 2).astype(x_dtype) + ans = np.copy(x_val) + ans[1::1, 1, :3:2] = update + _test_eval( + x=x_val, + begin=[1, 1, 1], + end=[2, 3, 3], + stride=[1, 1, 2], + begin_mask=[False, False, True], + end_mask=[True, False, False], + squeeze_mask=[False, True, False], + update=update, + ans=ans, + compute_unit=compute_unit, + backend=backend, + x_builtin_dtype=x_builtin_dtype, + ) + + update = np.random.rand(2, 3, 4).astype(x_dtype) + ans = np.copy(x_val) + ans[:, :, :] = update + _test_eval( + x=x_val, + begin=[0, 0, 0], + end=[0, 0, 0], + stride=[1, 1, 1], + begin_mask=[True, True, True], + end_mask=[True, True, True], + update=update, + ans=ans, + compute_unit=compute_unit, + backend=backend, + x_builtin_dtype=x_builtin_dtype, + ) + + update = np.random.rand(1, 1).astype(x_dtype) + ans = np.copy(x_val) + ans[1:2, 1:2, 1] = update + _test_eval( + x=x_val, + begin=[1, 1, 1], + end=[2, 2, 0], + stride=[1, 1, 1], + squeeze_mask=[False, False, True], + update=update, + ans=ans, + compute_unit=compute_unit, + backend=backend, + x_builtin_dtype=x_builtin_dtype, + ) + + update = np.random.rand(1, 3, 4).astype(x_dtype) + ans = np.copy(x_val) + ans[1:2, ...] = update + _test_eval( + x=x_val, + begin=[1, 0, 0], + end=[2, 0, 0], + stride=[1, 1, 1], + begin_mask=[False, True, True], + end_mask=[False, True, True], + update=update, + ans=ans, + compute_unit=compute_unit, + backend=backend, + x_builtin_dtype=x_builtin_dtype, + ) + + update = np.random.rand(2, 3, 4).astype(x_dtype) + ans = np.copy(x_val) + ans[...] = update + _test_eval( + x=x_val, + begin=[0, 0, 0], + end=[0, 0, 0], + stride=[1, 1, 1], + begin_mask=[True, True, True], + end_mask=[True, True, True], + update=update, + ans=ans, + compute_unit=compute_unit, + backend=backend, + x_builtin_dtype=x_builtin_dtype, + ) + + update = np.random.rand(1, 3, 1).astype(x_dtype) + ans = np.copy(x_val) + ans[1:2, ..., 1:2] = update + _test_eval( + x=x_val, + begin=[1, 0, 1], + end=[2, 0, 2], + stride=[1, 1, 1], + begin_mask=[False, True, False], + end_mask=[False, True, False], + update=update, + ans=ans, + compute_unit=compute_unit, + backend=backend, + x_builtin_dtype=x_builtin_dtype, + ) + + update = np.random.rand(2, 3).astype(x_dtype) + ans = np.copy(x_val) + ans[..., 1] = update + _test_eval( + x=x_val, + begin=[0, 0, 1], + end=[0, 0, 0], + stride=[1, 1, 1], + begin_mask=[True, True, False], + end_mask=[True, True, False], + squeeze_mask=[False, False, True], + update=update, + ans=ans, + compute_unit=compute_unit, + backend=backend, + x_builtin_dtype=x_builtin_dtype, + ) + + update = np.random.rand( + 4, + ).astype(x_dtype) + ans = np.copy(x_val) + ans[0, 0, :] = update + _test_eval( + x=x_val, + begin=[0, 0, 0], + end=[0, 0, 0], + stride=[1, 1, 1], + begin_mask=[False, False, True], + end_mask=[False, False, True], + squeeze_mask=[True, True, False], + update=update, + ans=ans, + compute_unit=compute_unit, + backend=backend, + x_builtin_dtype=x_builtin_dtype, + ) + + update = np.random.rand(1, 3, 4).astype(x_dtype) + ans = np.copy(x_val) + ans[1:2] = update + _test_eval( + x=x_val, + begin=[1, 0, 0], + end=[2, 0, 0], + stride=[1, 1, 1], + begin_mask=[False, True, True], + end_mask=[False, True, True], + update=update, + ans=ans, + compute_unit=compute_unit, + backend=backend, + x_builtin_dtype=x_builtin_dtype, + ) + + update = np.random.rand(1, 1, 4).astype(x_dtype) + ans = np.copy(x_val) + ans[1:2, 1:2] = update + _test_eval( + x=x_val, + begin=[1, 1, 0], + end=[2, 2, 0], + stride=[1, 1, 1], + begin_mask=[False, False, True], + end_mask=[False, False, True], + update=update, + ans=ans, + compute_unit=compute_unit, + backend=backend, + x_builtin_dtype=x_builtin_dtype, + ) + + update = np.random.rand(3, 4).astype(x_dtype) + ans = np.copy(x_val) + ans[1] = update + _test_eval( + x=x_val, + begin=[1, 0, 0], + end=[0, 0, 0], + stride=[1, 1, 1], + begin_mask=[False, True, True], + end_mask=[False, True, True], + squeeze_mask=[True, False, False], + update=update, + ans=ans, + compute_unit=compute_unit, + backend=backend, + x_builtin_dtype=x_builtin_dtype, + ) + + update = np.random.rand(2, 3, 4).astype(x_dtype) + ans = np.copy(x_val) + ans[:] = update + _test_eval( + x=x_val, + begin=[0, 0, 0], + end=[0, 0, 0], + begin_mask=[True, True, True], + end_mask=[True, True, True], + update=update, + ans=ans, + compute_unit=compute_unit, + backend=backend, + x_builtin_dtype=x_builtin_dtype, + ) + + update = np.random.rand(2, 3, 4).astype(x_dtype) + ans = np.copy(x_val) + ans[..., ::-1] = update + _test_eval( + x=x_val, + begin=[0, 0, 0], + end=[0, 0, 0], + stride=[1, 1, -1], + begin_mask=[True, True, True], + end_mask=[True, True, True], + update=update, + ans=ans, + compute_unit=compute_unit, + backend=backend, + x_builtin_dtype=x_builtin_dtype, + ) + + def test_builder_eval_scalar_corner_cases(self): + pytest.xfail( + "rdar://128221986 ([Feature][Slice_update] The backend is not supporting scalar update for the slice_update op)" + ) + # two corner cases + x_val = np.array([2.0]) + update = np.float32(3.14) + ans = np.copy(x_val) + ans[0] = update + _test_eval( + x=x_val, + begin=[0], + end=[0], + squeeze_mask=[True], + update=update, + ans=ans, + run_conversion_test=False, # rank 0 input is not supported + ) + + x_val = np.array([[[[1.0], [3.0]]]]) + update = np.float32(7.78) + ans = np.copy(x_val) + ans[0, 0, 0, 0] = update + _test_eval( + x=x_val, + begin=[0, 0, 0, 0], + end=[0, 0, 0, 0], + squeeze_mask=[True, True, True, True], + update=update, + ans=ans, + run_conversion_test=False, # rank 0 input is not supported + ) + + @staticmethod + def test_rank_0_update_early_error_out(): + """ + Backend does not support rank-0 update for the slice_update op. + coremltools should early error out until this radar is fixed: + rdar://128221986 ([Feature][Slice_update] The backends is not supporting scalar update for the slice_update op) + """ + with pytest.raises( + ValueError, match="rank-0 'update' is not supported in 'slice_update' op" + ): + + @mb.program(input_specs=[], opset_version=target.iOS18) + def prog(): + return mb.slice_update( + x=[0.0, 0.0], + update=0.0, + begin=[0], + end=[1], + squeeze_mask=[True], + ) diff --git a/coremltools/converters/mil/mil/ops/tests/iOS18/test_transformers.py b/coremltools/converters/mil/mil/ops/tests/iOS18/test_transformers.py new file mode 100644 index 000000000..217c2c180 --- /dev/null +++ b/coremltools/converters/mil/mil/ops/tests/iOS18/test_transformers.py @@ -0,0 +1,419 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import itertools + +import numpy as np +import pytest +import torch + +import coremltools as ct +from coremltools.converters.mil.mil import Builder as mb +from coremltools.converters.mil.mil import get_new_symbol, types +from coremltools.converters.mil.mil.ops.tests.iOS18 import backends +from coremltools.converters.mil.mil.ops.tests.testing_utils import run_compare_builder +from coremltools.converters.mil.testing_reqs import compute_units + + +class TestScaledDotProductAttention: + @staticmethod + def _mb_eval_scaled_dot_product_attention( + query: np.ndarray, key: np.ndarray, value: np.ndarray, mask: np.ndarray = None + ) -> np.ndarray: + @mb.program(opset_version=ct.target.iOS18) + def prog(): + return mb.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=mask, + ) + + return ( + prog.functions["main"] + .find_ops(op_type="scaled_dot_product_attention")[0] + .outputs[0] + .val + ) + + @staticmethod + def _torch_scaled_dot_product_attention( + query: np.ndarray, key: np.ndarray, value: np.ndarray, mask: np.ndarray = None + ) -> np.ndarray: + """ + Two things: + 1. torch cannot consume np.ndarray, so need to convert to torch.Tensor + 2. torch cpu kernel has no half-precision support, so need to cast to float + """ + query_torch = torch.tensor(query).to(torch.float32) + key_torch = torch.tensor(key).to(torch.float32) + value_torch = torch.tensor(value).to(torch.float32) + + mask_torch = None + if mask is not None: + mask_torch = torch.tensor(mask) + if mask.dtype != bool: + mask_torch = mask_torch.to(torch.float32) + + return ( + torch.nn.functional.scaled_dot_product_attention( + query_torch, key_torch, value_torch, mask_torch + ) + .numpy() + .astype(query.dtype) + ) + + @pytest.mark.parametrize( + "batches, float_dtype, mask_dtype", + itertools.product( + ([3], [3, 2], [3, 2, 4]), + (np.float16, np.float32), + (None, bool, np.float16, np.float32), + ), + ) + def test_builder_eval_stress(self, batches, float_dtype, mask_dtype): + S = 5 + L = 7 + E = 16 + EV = 32 + + query_shape = batches + [L, E] + key_shape = batches + [S, E] + value_shape = batches + [S, EV] + + query = np.random.rand(*query_shape).astype(float_dtype) + key = np.random.rand(*key_shape).astype(float_dtype) + value = np.random.rand(*value_shape).astype(float_dtype) + mask = None + if mask_dtype is not None: + mask = np.zeros((1, 1, S), dtype=mask_dtype) + mask[:, :, S // 2 :] = False if mask_dtype is bool else -np.inf + + attention_coreml = self._mb_eval_scaled_dot_product_attention(query, key, value, mask) + attention_torch = self._torch_scaled_dot_product_attention(query, key, value, mask) + np.testing.assert_allclose( + attention_coreml, + attention_torch, + atol=1e-6 if float_dtype == np.float32 else 1e-3, + rtol=1e-6 if float_dtype == np.float32 else 1e-3, + ) + + @pytest.mark.parametrize( + "compute_unit, backend, batches, float_dtype, mask_dtype", + itertools.product( + compute_units, + backends, + ([3], [3, 2], [3, 2, 4]), + (np.float16, np.float32), + (None, bool, np.float16, np.float32), + ), + ) + def test_builder_to_backend_stress( + self, compute_unit, backend, batches, float_dtype, mask_dtype + ): + def build(query, key, value): + return mb.scaled_dot_product_attention( + query=query, + key=key, + value=value, + ) + + def build_with_mask(query, key, value, mask): + return mb.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=mask, + ) + + S = 5 + L = 7 + E = 16 + EV = 32 + + query_shape = batches + [L, E] + key_shape = batches + [S, E] + value_shape = batches + [S, EV] + + query = np.random.rand(*query_shape).astype(float_dtype) + key = np.random.rand(*key_shape).astype(float_dtype) + value = np.random.rand(*value_shape).astype(float_dtype) + + input_placeholders = { + "query": mb.placeholder( + shape=query.shape, dtype=types.numpy_type_to_builtin_type(float_dtype) + ), + "key": mb.placeholder( + shape=key.shape, dtype=types.numpy_type_to_builtin_type(float_dtype) + ), + "value": mb.placeholder( + shape=value.shape, dtype=types.numpy_type_to_builtin_type(float_dtype) + ), + } + input_values = { + "query": query, + "key": key, + "value": value, + } + + mask = None + if mask_dtype is not None: + mask = np.zeros((1, 1, S), dtype=mask_dtype) + mask[:, :, S - 1 :] = False if mask_dtype is bool else -np.inf + + input_placeholders["mask"] = mb.placeholder( + shape=mask.shape, dtype=types.numpy_type_to_builtin_type(mask_dtype) + ) + input_values["mask"] = mask + + attention_torch = self._torch_scaled_dot_product_attention(query, key, value, mask) + run_compare_builder( + build if mask_dtype is None else build_with_mask, + input_placeholders, + input_values, + expected_output_types=[attention_torch.shape + (types.fp32,)], + expected_outputs=[attention_torch], + compute_unit=compute_unit, + backend=backend, + atol=1e-6 if backend.precision == "fp32" and float_dtype == np.float32 else 1e-3, + rtol=1e-6 if backend.precision == "fp32" and float_dtype == np.float32 else 1e-3, + ) + + @pytest.mark.parametrize( + "compute_unit, backend, batches, float_dtype, mask_dtype", + itertools.product( + compute_units, + backends, + ([2], [2, 3], [2, 3, 4]), + (np.float16, np.float32), + (None, bool, np.float16, np.float32), + ), + ) + def test_builder_to_backend_dynamic_stress( + self, compute_unit, backend, batches, float_dtype, mask_dtype + ): + def build(query, key, value): + return mb.scaled_dot_product_attention( + query=query, + key=key, + value=value, + ) + + def build_with_mask(query, key, value, mask): + return mb.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=mask, + ) + + S = 2 + L = 2 + E = 4 + EV = 32 + + query_shape = batches + [L, E] + key_shape = batches + [S, E] + value_shape = batches + [S, EV] + + query = np.random.rand(*query_shape).astype(float_dtype) + key = np.random.rand(*key_shape).astype(float_dtype) + value = np.random.rand(*value_shape).astype(float_dtype) + + dynamic_query_shape = query_shape + dynamic_query_shape[0] = get_new_symbol() + dynamic_query_shape[-2] = get_new_symbol() + dynamic_key_shape = key_shape + dynamic_key_shape[-2] = get_new_symbol() + dynamic_value_shape = value_shape + dynamic_value_shape[-2] = get_new_symbol() + + input_placeholders = { + "query": mb.placeholder( + shape=tuple(dynamic_query_shape), + dtype=types.numpy_type_to_builtin_type(float_dtype), + ), + "key": mb.placeholder( + shape=tuple(dynamic_key_shape), dtype=types.numpy_type_to_builtin_type(float_dtype) + ), + "value": mb.placeholder( + shape=tuple(dynamic_value_shape), + dtype=types.numpy_type_to_builtin_type(float_dtype), + ), + } + input_values = { + "query": query, + "key": key, + "value": value, + } + + mask = None + if mask_dtype is not None: + mask = np.zeros((1, S), dtype=mask_dtype) + mask[:, S - 1 :] = False if mask_dtype is bool else -np.inf + + dynamic_mask_shape = [] + for i in range(len(mask.shape)): + dynamic_mask_shape.append(get_new_symbol()) + + input_placeholders["mask"] = mb.placeholder( + shape=tuple(dynamic_mask_shape), dtype=types.numpy_type_to_builtin_type(mask_dtype) + ) + input_values["mask"] = mask + + attention_torch = self._torch_scaled_dot_product_attention(query, key, value, mask) + output_shape = list(attention_torch.shape) + output_shape[0] = query_shape[0] + output_shape[-2] = query_shape[-2] + run_compare_builder( + build if mask_dtype is None else build_with_mask, + input_placeholders, + input_values, + expected_output_types=[tuple(output_shape) + (types.fp32,)], + expected_outputs=[attention_torch], + compute_unit=compute_unit, + backend=backend, + atol=1e-6 if backend.precision == "fp32" and float_dtype == np.float32 else 1e-3, + rtol=1e-6 if backend.precision == "fp32" and float_dtype == np.float32 else 1e-3, + ) + + def test_builder_invalid_shape(self): + B = 3 + S = 5 + L = 7 + E = 16 + EV = 32 + + with pytest.raises( + ValueError, + match=( + r"query, key, value must have a same rank, got\n" + r"\* query rank = [0-9]+\n" + r"\* key rank = [0-9]+\n" + r"\* value rank = [0-9]+" + ), + ): + query_shape = [B, L, E] + key_shape = [S, E] + value_shape = [S, EV] + + query = np.random.rand(*query_shape) + key = np.random.rand(*key_shape) + value = np.random.rand(*value_shape) + + self._mb_eval_scaled_dot_product_attention(query, key, value) + + with pytest.raises( + ValueError, + match=( + r"query, key, value must have at lease rank 3 " + r"for batch, sequence length, embedding, got rank [0-9]+" + ), + ): + query_shape = [L, E] + key_shape = [S, E] + value_shape = [S, EV] + + query = np.random.rand(*query_shape) + key = np.random.rand(*key_shape) + value = np.random.rand(*value_shape) + + self._mb_eval_scaled_dot_product_attention(query, key, value) + + with pytest.raises( + ValueError, + match=( + r"query, key, value must have a same batch dimension, got\n" + r"\* query batch = \((?:\s*\d+\s*,)+\s*\d*\)\n" + r"\* key batch = \((?:\s*\d+\s*,)+\s*\d*\)\n" + r"\* value batch = \((?:\s*\d+\s*,)+\s*\d*\)" + ), + ): + query_shape = [B + 1, L, E] + key_shape = [B, S, E] + value_shape = [B, S, EV] + + query = np.random.rand(*query_shape) + key = np.random.rand(*key_shape) + value = np.random.rand(*value_shape) + + self._mb_eval_scaled_dot_product_attention(query, key, value) + + with pytest.raises( + ValueError, + match=( + r"query and key must have a same embedding dimension, got\n" + r"\* query embedding = [0-9]+\n" + r"\* key embedding = [0-9]+" + ), + ): + query_shape = [B, L, E + 1] + key_shape = [B, S, E] + value_shape = [B, S, EV] + + query = np.random.rand(*query_shape) + key = np.random.rand(*key_shape) + value = np.random.rand(*value_shape) + + self._mb_eval_scaled_dot_product_attention(query, key, value) + + with pytest.raises( + ValueError, + match=( + r"key and value must have a same sequence length, got\n" + r"\* key sequence = [0-9]+\n" + r"\* value sequence = [0-9]+" + ), + ): + query_shape = [B, L, E] + key_shape = [B, S + 1, E] + value_shape = [B, S, EV] + + query = np.random.rand(*query_shape) + key = np.random.rand(*key_shape) + value = np.random.rand(*value_shape) + + self._mb_eval_scaled_dot_product_attention(query, key, value) + + with pytest.raises( + ValueError, + match=( + r"key and mask must have a same sequence length, got\n" + r"\* key sequence = [0-9]+\n" + r"\* mask sequence = [0-9]+" + ), + ): + query_shape = [B, L, E] + key_shape = [B, S, E] + value_shape = [B, S, EV] + + query = np.random.rand(*query_shape) + key = np.random.rand(*key_shape) + value = np.random.rand(*value_shape) + + mask = np.zeros(S + 1, dtype=bool) + mask[-1] = True + + self._mb_eval_scaled_dot_product_attention(query, key, value, mask) + + with pytest.raises( + ValueError, + match=( + r"Incompatible dim [0-9]+ in shapes " + r"\((?:\s*\d+\s*,)+\s*\d*\) vs\. \((?:\s*\d+\s*,)+\s*\d*\)" + ), + ): + query_shape = [B, L, E] + key_shape = [B, S, E] + value_shape = [B, S, EV] + + query = np.random.rand(*query_shape) + key = np.random.rand(*key_shape) + value = np.random.rand(*value_shape) + + mask = np.zeros((B + 1, L - 1, S), dtype=bool) + mask[:, :, -1] = True + + self._mb_eval_scaled_dot_product_attention(query, key, value, mask) diff --git a/coremltools/converters/mil/mil/ops/tests/test_utils.py b/coremltools/converters/mil/mil/ops/tests/test_utils.py index 69a36c124..e73fe0c00 100644 --- a/coremltools/converters/mil/mil/ops/tests/test_utils.py +++ b/coremltools/converters/mil/mil/ops/tests/test_utils.py @@ -3,16 +3,12 @@ # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause -import itertools import numpy as np -import pytest from coremltools.converters.mil.mil.ops.defs._utils import ( aggregated_pad, effective_kernel, - pack_elements_into_bits, - restore_elements_from_packed_bits, spatial_dimensions_out_shape, ) @@ -268,47 +264,3 @@ def test_same_padding_shape_dilation_2(self): expected = [5, 5] np.testing.assert_equal(actual, expected) - - -class TestPackUnpackBits: - def test_pack_basic(self): - """ - Original data: [-8, 7, 3, 4, -2]. - The 4-bit binary representation for those elements are: - -8: 1000; - 7: 0111; - 3: 0011 - 4: 0100 - -2: 1110 - Hence the packed quantized_data will be 3 bytes long, i.e., 24 bits long, which is: - 0111 1000 0100 0011 0000 1110 - So the packed data is represented by 3 uint8 values: [120, 67, 14]. - """ - original_data = np.array([-8, 7, 3, 4, -2], dtype=np.int8) - expected_packed_data = np.array([120, 67, 14], dtype=np.uint8) - packed_data = pack_elements_into_bits(original_data, nbits=4) - np.testing.assert_array_equal(packed_data, expected_packed_data) - - def test_pack_basic_2(self): - original_data = np.array([1, 2, 3, 4, 5], dtype=np.int8) - expected_packed_data = np.array([33, 67, 5], dtype=np.uint8) - packed_data = pack_elements_into_bits(original_data, nbits=4) - np.testing.assert_array_equal(packed_data, expected_packed_data) - - @pytest.mark.parametrize( - "nbits, data_dtype, element_num", - itertools.product(list(range(1, 9)), [np.int8, np.uint8], [1, 3, 20]), - ) - def test_round_trip_pack_unpack(self, nbits, data_dtype, element_num): - is_data_signed = np.issubdtype(data_dtype, np.signedinteger) - low, high = 0, 2**nbits - if is_data_signed: - low, high = -(2 ** (nbits - 1)), 2 ** (nbits - 1) - original_data = np.random.randint(low=low, high=high, size=(element_num,)).astype( - data_dtype - ) - packed_data = pack_elements_into_bits(original_data, nbits) - restored_data = restore_elements_from_packed_bits( - packed_data, nbits, element_num, are_packed_values_signed=is_data_signed - ) - np.testing.assert_array_equal(restored_data, original_data) diff --git a/coremltools/converters/mil/mil/ops/tests/testing_utils.py b/coremltools/converters/mil/mil/ops/tests/testing_utils.py index 2bce2e551..cbed314d5 100644 --- a/coremltools/converters/mil/mil/ops/tests/testing_utils.py +++ b/coremltools/converters/mil/mil/ops/tests/testing_utils.py @@ -68,6 +68,7 @@ def run_compare_builder( also_compare_shapes=True, converter=ct.convert, pass_pipeline: Optional[PassPipeline] = None, + pred_iters: Optional[int] = None, ): """ Inputs: @@ -100,8 +101,11 @@ def run_compare_builder( - backend: A BackendConfig that specifies the compute backend, precision and minimum_deployment_target + - pred_iters: Number of prediction to run the mlmodel. For a stateful model, + each prediction can have different numerical results. Can only be provided when mlmodel is stateful. + Returns: - The converted mlmodel + The converted mlmodel (MLModel), or Tuple[MLModel, MLState]. """ if backend is None: backend = BackendConfig( @@ -180,27 +184,37 @@ def run_compare_builder( if frontend_only: return mlmodel - if expected_outputs: - assert len(output_vars) == len(expected_outputs), ( - "Provided expected_outputs {}" - " should match number of output" - " variables {}".format(len(expected_outputs), len(output_vars)) + state = mlmodel.make_state() if mlmodel._is_stateful() else None + + if pred_iters is not None: + assert state is not None, "pred_iters can only be provided with stateful model." + else: + pred_iters = 1 + + for i in range(pred_iters): + # get the expected outputs from each prediction iteration + outputs = None + if expected_outputs is not None: + outputs = expected_outputs if pred_iters == 1 else expected_outputs[i] + assert len(output_vars) == len(outputs), ( + f"Provided expected_outputs {len(outputs)}" + " should match number of output" + f" variables {len(output_vars)}" + ) + outputs = {name: val for name, val in zip(output_names, outputs)} + + # run the mlmodel and compare the output numerical + compare_backend( + mlmodel=mlmodel, + input_key_values=input_values, + expected_outputs=outputs, + atol=atol, + rtol=rtol, + also_compare_shapes=also_compare_shapes, + dtype=backend[1], + state=state, ) - expected_outputs = { - name: val for name, val in zip(output_names, expected_outputs) - } - - compare_backend( - mlmodel=mlmodel, - input_key_values=input_values, - expected_outputs=expected_outputs, - atol=atol, - rtol=rtol, - also_compare_shapes=also_compare_shapes, - dtype=backend[1], - ) - return mlmodel diff --git a/coremltools/converters/mil/mil/passes/__init__.py b/coremltools/converters/mil/mil/passes/__init__.py index 8d47e5d90..0f6cd5622 100644 --- a/coremltools/converters/mil/mil/passes/__init__.py +++ b/coremltools/converters/mil/mil/passes/__init__.py @@ -34,6 +34,7 @@ cleanup, lower_complex_dialect_ops, optimize_activation, + optimize_activation_quantization, optimize_conv, optimize_elementwise_binary, optimize_linear, diff --git a/coremltools/converters/mil/mil/passes/defs/cleanup/const_deduplication.py b/coremltools/converters/mil/mil/passes/defs/cleanup/const_deduplication.py index 020ab294c..fab1da013 100644 --- a/coremltools/converters/mil/mil/passes/defs/cleanup/const_deduplication.py +++ b/coremltools/converters/mil/mil/passes/defs/cleanup/const_deduplication.py @@ -8,7 +8,7 @@ import numpy as np -from coremltools.converters.mil.mil import Block, Var, ListVar, types +from coremltools.converters.mil.mil import Block, ListVar, Program, Var, types from coremltools.converters.mil.mil.passes.graph_pass import AbstractGraphPass from coremltools.converters.mil.mil.passes.helper import block_context_manager from coremltools.converters.mil.mil.passes.pass_registry import register_pass @@ -58,6 +58,35 @@ def apply(self, prog) -> None: for f in prog.functions.values(): self._constant_deduplication_block(f) + def _deduplicate_const_across_functions(self, prog: Program) -> None: + """ + When there are duplicated consts across functions, we cannot create a common const op to be shared. + Instead, we set the weight_id to the consts, to allow them share the same blob file value when lowering into milproto. + """ + # We first make sure that consts are deduplicated within each function, + # to make sure we can maximize the weight sharing. + self.apply(prog) + + # check no weight_id is set yet in the program + for block in prog.functions.values(): + for op in block.operations: + if op.op_type != "const": + continue + if op.weight_id is not None: + raise ValueError(f"const op {op.name} already has weight_id {op.weight_id}") + + # deduplication across functions + blocks = list(prog.functions.values()) + unique2duplicates_const = self.find_constants(blocks) + for i, (k, v) in enumerate(unique2duplicates_const.items()): + if len(v) == 0: + continue + # There could be cases where two functions are pointing to the same block + all_vars = [k] + list(v) + all_vars = list(set(all_vars)) + for duplicate in all_vars: + duplicate.op.weight_id = i + def remove_duplicate_ops( self, block: Block, unique2duplicates: Dict[Var, List[Var]], force_replace: bool ) -> None: @@ -80,7 +109,7 @@ def _constant_deduplication_block(self, block: Block) -> None: self._constant_deduplication_block(b) # Deduplication of ``const`` op - unique2duplicates_const = self.find_constants(block) + unique2duplicates_const = self.find_constants([block]) self.remove_duplicate_ops(block, unique2duplicates_const, force_replace=False) # Deduplication of ``constexpr_*`` op @@ -88,24 +117,27 @@ def _constant_deduplication_block(self, block: Block) -> None: # Since after the above two functions, ``const`` ops with identical values are # deduplicated into a single ``Var`` object, which allows ``find_constexpr`` to # directly compare the ``const`` input attr pointers instead of the actual values. - unique2duplicates_constexpr = self.find_constexprs(block) + unique2duplicates_constexpr = self.find_constexprs([block]) self.remove_duplicate_ops(block, unique2duplicates_constexpr, force_replace=True) - def find_constexprs(self, block: Block) -> Dict[Var, List[Var]]: + @staticmethod + def find_constexprs(blocks: List[Block]) -> Dict[Var, List[Var]]: """ - Given a block, return all constexpr in the block in such a format: + Given a list of blocks, return all constexpr in the blocks in such a format: {unique_var_0: [duplicated_var_0_0, duplicated_var_0_1, ...], unique_var_1: [duplicated_var_1_0, duplicated_var_1_1, ...], ... } """ hashkey_2_duplicates: Dict[Tuple, List[Var]] = {} - for op in list(block.operations): - if "constexpr" in op.op_type: + for block in blocks: + for op in list(block.operations): + if "constexpr" not in op.op_type: + continue hash_key = [op.op_type] for v in op.inputs.values(): hash_key.append(v.dtype) - if np.prod(v.shape) < self.NUMEL_THRESH: + if np.prod(v.shape) < const_deduplication.NUMEL_THRESH: hash_key.append(str(v.val)) else: hash_key.append(v) @@ -117,9 +149,10 @@ def find_constexprs(self, block: Block) -> Dict[Var, List[Var]]: return {v[0]: v[1:] for v in hashkey_2_duplicates.values()} - def find_constants(self, block: Block) -> Dict[Var, List[Var]]: + @staticmethod + def find_constants(blocks: List[Block]) -> Dict[Var, List[Var]]: """ - Given a block, return all constants in the block in such a format: + Given a list of blocks, return all constants in the blocks in such a format: {unique_var_0: [duplicated_var_0_0, duplicated_var_0_1, ...], unique_var_1: [duplicated_var_1_0, duplicated_var_1_1, ...], ... @@ -129,21 +162,24 @@ def find_constants(self, block: Block) -> Dict[Var, List[Var]]: # instead of brute-force C_N^2 comparison, use a hash map to be O(N) constant_dict: Dict[Tuple[str, types.type, Tuple[int], str], List[Var]] = {} - for op in list(block.operations): - if op.op_type == "const": + for block in blocks: + for op in list(block.operations): + if op.op_type != "const": + continue + constant_var = op.outputs[0] if isinstance(constant_var, ListVar): continue shape = constant_var.shape numel = np.prod(shape) - if numel < self.NUMEL_THRESH: + if numel < const_deduplication.NUMEL_THRESH: continue dtype = constant_var.dtype value = constant_var.val hash = hashlib.sha1( - np.ascontiguousarray(value.reshape(-1)[: self.NUMEL_THRESH]) + np.ascontiguousarray(value.reshape(-1)[: const_deduplication.NUMEL_THRESH]) ).hexdigest() key = (dtype, shape, hash) @@ -159,7 +195,7 @@ def find_constants(self, block: Block) -> Dict[Var, List[Var]]: value, var.val, rtol=0.0, - atol=self.DTYPE2ATOL.get(dtype, 1e-12), + atol=const_deduplication.DTYPE2ATOL.get(dtype, 1e-12), ): existing_constant_var = var break diff --git a/coremltools/converters/mil/mil/passes/defs/cleanup/dead_code_elimination.py b/coremltools/converters/mil/mil/passes/defs/cleanup/dead_code_elimination.py index b7aa1f6ff..fe6932479 100644 --- a/coremltools/converters/mil/mil/passes/defs/cleanup/dead_code_elimination.py +++ b/coremltools/converters/mil/mil/passes/defs/cleanup/dead_code_elimination.py @@ -57,6 +57,11 @@ def _dead_code_elimination_block(block): # mark block's outputs to used used_vars.update(block.outputs) + # mark outputs from coreml_update_state to used + for op in block.operations: + if op.op_type == "coreml_update_state": + used_vars.update(op.outputs) + for op in reversed(block.operations): # if none of op's output is used, delete op if not set(op.outputs).intersection(used_vars): diff --git a/coremltools/converters/mil/mil/passes/defs/cleanup/expand_dynamic_linear.py b/coremltools/converters/mil/mil/passes/defs/cleanup/expand_dynamic_linear.py index 55ca9df71..0a8a06641 100644 --- a/coremltools/converters/mil/mil/passes/defs/cleanup/expand_dynamic_linear.py +++ b/coremltools/converters/mil/mil/passes/defs/cleanup/expand_dynamic_linear.py @@ -16,13 +16,12 @@ @register_pass(namespace="common") class expand_dynamic_linear(AbstractGraphPass): """ - ``Linear`` requires const or constexpr ``weight`` and ``bias``. In op translation, - we ambitiously prefer ``linear`` whenever possible, i.e. translate to ``linear`` - when operand is descendant of const, since such operand may be folded / fused into - const or constexpr later on by graph passes. + Translate to ``linear`` when the operand is a descendant of const, since such an operand + may be folded into const or fused into constexpr later by graph passes. In op translation, + we prefer ``linear`` whenever possible because it requires const or constexpr ``weight`` and ``bias``. - If such const folding / constexpr fusion did not happen, this pass would clean up - those too ambitious ``linear``s by replacing them with ``matmul``s + If such const folding or constexpr fusion did not happen, this pass would clean up + the too-ambitious ``linear`` ops by replacing them with ``matmul`` ops. """ def apply(self, prog: Program) -> None: diff --git a/coremltools/converters/mil/mil/passes/defs/cleanup/topological_reorder.py b/coremltools/converters/mil/mil/passes/defs/cleanup/topological_reorder.py index 6e12a2c27..c324abe78 100644 --- a/coremltools/converters/mil/mil/passes/defs/cleanup/topological_reorder.py +++ b/coremltools/converters/mil/mil/passes/defs/cleanup/topological_reorder.py @@ -114,6 +114,10 @@ def _move_operations_to_the_end_block(block, op_type_to_move): if not isinstance(new_var, (list, tuple)): new_var = [new_var] + # the new var should have the same name as the old var + for i, old_var in enumerate(op.outputs): + new_var[i].name = old_var.name + # Override current_op to be newly created op to ensure `first_use` # points to newly created op instead of old one. current_op = new_var[0].op diff --git a/coremltools/converters/mil/mil/passes/defs/optimize_activation_quantization.py b/coremltools/converters/mil/mil/passes/defs/optimize_activation_quantization.py new file mode 100644 index 000000000..7cd260624 --- /dev/null +++ b/coremltools/converters/mil/mil/passes/defs/optimize_activation_quantization.py @@ -0,0 +1,419 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + + +import numpy as np + +from coremltools.converters.mil.mil import Block +from coremltools.converters.mil.mil import Builder as mb +from coremltools.converters.mil.mil import Operation, types +from coremltools.converters.mil.mil.passes.graph_pass import AbstractGraphPass +from coremltools.converters.mil.mil.passes.helper import _check_child_op_type, block_context_manager +from coremltools.converters.mil.mil.passes.pass_registry import register_pass + + +@register_pass(namespace="compression") +class insert_suffix_quantize_dequantize_pair(AbstractGraphPass): + """ + Insert trailing quantize and dequantize operation pairs after valid patterns. + + .. code-block:: + Pattern 1: + dequantize -> conv + Given: + %2 = dequantize(%1) + %3 = conv(%2) + ... + Result: + %2 = dequantize(%1) + %3 = conv(%2) + %4 = quantize(%3) + %5 = dequantize(%4) + ... + + Pattern 2: + dequantize ->| + |-> add + dequantize ->| + Given: + %2 = dequantize(%1) + %4 = dequantize(%3) + %5 = add(%2,%4) + ... + Result: + %2 = dequantize(%1) + %4 = dequantize(%3) + %5 = add(%2,%4) + %6 = quantize(%5) + %7 = dequantize(%6) + ... + """ + + _allowed_activations = { + "leaky_relu", + "tanh", + "scaled_tanh", + "sigmoid", + "hard_sigmoid", + "relu", + "relu6", + } + + # Graph pass option for setting compression config. + _config = None + + @property + def config(self): + return self._config + + @config.setter + def config(self, value): + self._config = value + if value._op_selector is not None: + self.op_selector = value._op_selector + + def apply(self, prog): + visited_ops = set() + for f in prog.functions.values(): + self._insert_quantize_dequantize(f, self._config, visited_ops) + + @block_context_manager + def _insert_quantize_dequantize(self, block: Block, config, visited_ops: set): + def help_insert_quantize_dequantize(block: Block) -> bool: + fusion_occurred = False + + for op in list(block.operations): + if op.enclosing_block is None: + continue + + if op in visited_ops: + continue + visited_ops.add(op) + + for b in op.blocks: + self._insert_quantize_dequantize(b) + + # Must start with "dequantize" op. + if op.op_type != "dequantize": + continue + + # Try matching valid patterns. + if self._try_match_and_transform_pattern(op, block, config, visited_ops): + fusion_occurred = True + + return fusion_occurred + + block_changed = True + while block_changed: + block_changed = help_insert_quantize_dequantize(block) + + def _try_match_and_transform_pattern( + self, dequantize_op: Operation, block: Block, config, visited_ops: set + ) -> bool: + """ + This function performs the pattern match for all target patterns. + It priorizes longer patterns to shorter ones for more fusions on hardware. + Reject if the trailing `quantize` and `dequantize` pair already existed. + + A list of valid patterns. + - conv + - conv, activation + - add + - add, activation + + E.g. Identify valid patterns: + - (`quantize` ->) dequantize` -> `conv` + - (`quantize` ->) dequantize` -> `conv` -> `relu` + E.g. Reject if trailing `quantize` -> `dequantize` exist: + - (`quantize` ->) dequantize` -> `conv` -> `quantize` -> `dequantize` + - (`quantize` ->) dequantize` -> `conv` -> `relu` -> `quantize` -> `dequantize` + """ + + # Reject if 1st operation is not `conv` or `add`. + if _check_child_op_type(dequantize_op, "conv") or _check_child_op_type( + dequantize_op, "add" + ): + pass + else: + return False + + core_op = dequantize_op.outputs[0].child_ops[0] + last_op = core_op + + # For operations with two inputs, both need to be `dequantize`. + if core_op.op_type == "add": + # Check both inputs + in_var_x = core_op.inputs["x"] + in_var_y = core_op.inputs["y"] + in_x_prev_op = in_var_x.op + in_y_prev_op = in_var_y.op + if not (in_x_prev_op.op_type == "dequantize" and in_y_prev_op.op_type == "dequantize"): + return False + + # Checking op-level config. Skip if we disable compression on certain operations. + op_config = config._get_op_config(core_op) + if op_config is None: + return False + + # Reject if trailing `quantize` -> `dequantize` pair exist. + if _check_child_op_type(core_op, "quantize"): + return False + + _child_op = None + if len(core_op.outputs[0].child_ops) > 0: + _child_op = core_op.outputs[0].child_ops[0] + + # Check if 2nd operation is part of a valid pattern. + # E.g. `dequantize` -> `conv` -> activation -> `quantize`. + if _child_op is not None: + if _child_op.op_type in self._allowed_activations: + if len(_child_op.outputs[0].child_ops) > 0: + if _check_child_op_type(_child_op, "quantize"): + return False + + _child_child_op = _child_op.outputs[0].child_ops[0] + last_op = _child_op + _child_op = _child_child_op + + return self._try_apply_transform(last_op, _child_op, block, visited_ops) + + @staticmethod + def _try_apply_transform( + last_op: Operation, + _child_op: Operation, + block: Block, + visited_ops: set, + ) -> bool: + """ + last_op: last op of a valid pattern. + E.g. in `conv` -> `relu`, last_op is `relu`; in `conv`, last_op is `conv`. + _child_op: the child op of the last_op. + block: current block. + visited_ops: a dict + + Pattern: + Given: + |-> child_op_1 + last_op -> |-> child_op_2 + |-> ... + Result: + |-> child_op_1 + last_op -> quantize -> dequantize -> |-> child_op_2 + |-> ... + """ + if _child_op is None: + return False + + scale_dtype = np.float16 if last_op.outputs[0].dtype == types.fp16 else np.float32 + + new_last_op = getattr(mb, last_op.op_type) + kargs = {} + for k, v in last_op.inputs.items(): + kargs[k] = v + kargs["name"] = last_op.name + kargs["before_op"] = last_op + new_last_op = new_last_op(**kargs) + + new_quantize_op = mb.quantize( + input=new_last_op, + scale=np.array(1).astype(scale_dtype), + zero_point=np.int8(0), + output_dtype="int8", + before_op=last_op, + ) + new_dequantize_op = mb.dequantize( + input=new_quantize_op, + scale=np.array(1).astype(scale_dtype), + zero_point=np.int8(0), + before_op=last_op, + ) + ops_to_remove = [last_op] + + last_op_var_name = last_op.outputs[0].name + # Replace output var of last_op with output of new_dequantize_op. + if last_op.enclosing_block.try_replace_uses_of_var_after_op( + anchor_op=last_op, + end_op=last_op, + old_var=last_op.outputs[0], + new_var=new_dequantize_op, + ): + block.remove_ops(ops_to_remove) + # The name of new quantize/dequantize may change. + # Add the new ones to the visited list to avoid revisiting. + visited_ops.add(new_dequantize_op.op) + visited_ops.add(new_quantize_op.op) + new_dequantize_var_name = new_dequantize_op.name + new_dequantize_op.set_name(f"{new_dequantize_var_name}__post__dequant") + new_last_op.set_name(f"{last_op_var_name}") + return True + + return False + + +@register_pass(namespace="compression") +class update_quantize_dequantize(AbstractGraphPass): + """ + Update scale and zero point values in `quantize` and `dequantize` operations with calibration statistics. + + .. code-block:: + Pattern: + Given: + %2 = quantize(%1) with random scale and zp + %3 = dequantize(%2) with random scale and zp + ... + Result: + %2 = quantize(%1) with calculated scale and zp + %3 = dequantize(%2) with calculated scale and zp + ... + """ + + _activation_stats = None + + @property + def activation_stats(self): + return self._activation_stats + + @activation_stats.setter + def activation_stats(self, value): + self._activation_stats = value + + def apply(self, prog): + visited_ops = set() + for f in prog.functions.values(): + self._update_quantize_dequantize(f, self._activation_stats, visited_ops) + + @block_context_manager + def _update_quantize_dequantize(self, block: Block, activation_stats: dict, visited_ops: set): + def help_update_quantize_dequantize(block: Block, activation_stats: dict) -> bool: + fusion_occurred = False + + for op in list(block.operations): + if op.enclosing_block is None: + continue + + if op in visited_ops: + continue + visited_ops.add(op) + + for b in op.blocks: + self._update_quantize_dequantize(b, activation_stats) + + # Must start with "quantize" op + if op.op_type != "quantize": + continue + + # Try pattern match: `quantize` -> `dequantize`. + if self._try_match_and_transform_pattern(op, block, activation_stats, visited_ops): + fusion_occurred = True + + return fusion_occurred + + block_changed = True + while block_changed: + block_changed = help_update_quantize_dequantize(block, activation_stats) + + def _try_match_and_transform_pattern( + self, quantize_op: Operation, block: Block, activation_stats: dict, visited_ops: set + ) -> bool: + """ + This function performs validation checks for the target pattern: + `quantize` -> `dequantize` + """ + if not _check_child_op_type(quantize_op, "dequantize"): + return False + dequantize_op = quantize_op.outputs[0].child_ops[0] + last_op = dequantize_op + + _child_op = None + if len(dequantize_op.outputs[0].child_ops) > 0: + _child_op = dequantize_op.outputs[0].child_ops[0] + + return self._try_apply_transform( + quantize_op, last_op, _child_op, block, activation_stats, visited_ops + ) + + @staticmethod + def _try_apply_transform( + quantize_op: Operation, + last_op: Operation, + _child_op: Operation, + block: Block, + activation_stats: dict, + visited_ops: set, + ) -> bool: + """ + last_op: last op of a valid pattern. it's 'dequantize' in this case. + _child_op: the child op of the last_op. + block: current block. + """ + ops_to_remove = [quantize_op, last_op] + + if _child_op is None: + return False + + # Name of input var to `quantize`. + in_var_name = quantize_op.inputs["input"].name + val = np.array([0, 0], dtype=np.float16) + + # It's possible there are two ``quantize -> dequantize`` pair in a sequence. + # Two pairs should share the same scale and zero_point values. + # The name of input var to the 2nd `quantize` is newly created and does not exist in the original uncompressed model. + # We make an adjustment by tracing the name of input var of 1st `quantize` to update the 2nd pair. + if in_var_name not in activation_stats: + # Make an adjustment by checking leading `quantize` `dequantize` pair. + prev_dequantize = quantize_op.input.op + prev_quantize = prev_dequantize.input.op + if prev_quantize.inputs["input"].name in activation_stats: + in_var_name = prev_quantize.inputs["input"].name + + val[0], val[1] = ( + activation_stats[in_var_name]["rmin"], + activation_stats[in_var_name]["rmax"], + ) + + # Numerically the scale and zp won't change if the input array only have two elements: + # the min and max of input array. Plus we don't care about quantized values. + # That's the trick to re-use quantize_weight util. + from coremltools.optimize.coreml._utils import quantize_weight + + _, _scale, _zero_point = quantize_weight( + val, + axes=0, + nbits=8, + signed=True, + quantization_mode="LINEAR_SYMMETRIC", + dtype=types.int8, + ) + + # New ``quantize -> dequantize``. + new_quantize_op = mb.quantize( + input=quantize_op.input, + scale=_scale, + zero_point=_zero_point, + output_dtype="int8", + name=quantize_op.name, + before_op=quantize_op, + ) + new_dequantize_op = mb.dequantize( + input=new_quantize_op, + scale=_scale, + zero_point=_zero_point, + name=last_op.name, + before_op=quantize_op, + ) + + # Replace old ``quantize -> dequantize`` with new ``quantize -> dequantize`` to update scale/zero_point. + if last_op.enclosing_block.try_replace_uses_of_var_after_op( + anchor_op=last_op, + end_op=last_op, + old_var=last_op.outputs[0], + new_var=new_dequantize_op, + ): + block.remove_ops(ops_to_remove) + # Add the new ones to the visited list to avoid revisiting. + visited_ops.add(new_quantize_op.op) + visited_ops.add(new_dequantize_op.op) + + return False diff --git a/coremltools/converters/mil/mil/passes/defs/optimize_normalization.py b/coremltools/converters/mil/mil/passes/defs/optimize_normalization.py index 15bfb3c19..8c910084b 100644 --- a/coremltools/converters/mil/mil/passes/defs/optimize_normalization.py +++ b/coremltools/converters/mil/mil/passes/defs/optimize_normalization.py @@ -1019,8 +1019,13 @@ def _try_match_and_transform_pattern_5(self, reduce_op, block) -> bool: if has_beta_and_gamma: beta_var = add_beta_op.y if add_beta_op.x == mul_op.outputs[0] else add_beta_op.x + gamma_var = ( + mul_gamma_op.y if mul_gamma_op.x == add_beta_op.outputs[0] else mul_gamma_op.x + ) + + if beta_var.val is None or gamma_var.val is None: + return False - gamma_var = mul_gamma_op.y if mul_gamma_op.x == add_beta_op.outputs[0] else mul_gamma_op.x gamma_var = mb.const( val=np.squeeze(gamma_var.val), name="_fuse_layernorm_gamma", diff --git a/coremltools/converters/mil/mil/passes/defs/optimize_quantization.py b/coremltools/converters/mil/mil/passes/defs/optimize_quantization.py index 47ac7ac2b..5aedd22b9 100644 --- a/coremltools/converters/mil/mil/passes/defs/optimize_quantization.py +++ b/coremltools/converters/mil/mil/passes/defs/optimize_quantization.py @@ -146,7 +146,7 @@ def _try_to_transform_per_tensor(op: Operation, block: Block) -> bool: ) # after transformation, we create a new constexpr_affine_dequantize op and do the replacement - new_var = _utils._construct_constexpr_affine_op( + new_var = _utils._construct_constexpr_dequant_op( cursor, op.zero_point, op.scale, @@ -799,25 +799,39 @@ def try_to_transform( if new_s_x is None and new_s_z is None: return False + def convert_mil_float_dtype_to_np(mil_dtype): + if mil_dtype == types.fp16 or mil_dtype == "float16": + np_dtype = np.float16 + else: + np_dtype = np.float32 + return np_dtype + + new_s_x_dtype = convert_mil_float_dtype_to_np(dequantize_x.scale.val.dtype) + new_s_y_dtype = convert_mil_float_dtype_to_np(dequantize_y.scale.val.dtype) + new_s_z_dtype = convert_mil_float_dtype_to_np(quantize_z.scale.val.dtype) + # insert normalized new_dequantize_x and new_dequantize_y before op new_dequantize_x = mb.dequantize( input=dequantize_x.input, - scale=new_s_x, + scale=new_s_x_dtype(new_s_x), zero_point=dequantize_x.zero_point, axis=dequantize_x.axis, before_op=op, ) new_dequantize_y = mb.dequantize( input=dequantize_y.input, - scale=1.0 if dequantize_y.axis is None else np.full(dequantize_y.scale.val.shape, 1.0), + scale=new_s_y_dtype(1) + if dequantize_y.axis is None + else np.full(dequantize_y.scale.val.shape, 1.0), zero_point=dequantize_y.zero_point, axis=dequantize_y.axis, before_op=op, ) + # insert normalized new_quantize_z before quantize_z new_quantize_z = mb.quantize( input=quantize_z.input, - scale=new_s_z, + scale=new_s_z_dtype(new_s_z), zero_point=quantize_z.zero_point, axis=quantize_z.axis, output_dtype=quantize_z.output_dtype, @@ -956,7 +970,7 @@ def transform_op(self, op): axis = None if op.axis is None else op.axis.val - new_var = _utils._construct_constexpr_affine_op( + new_var = _utils._construct_constexpr_dequant_op( quantized_data, zero_point, scale, @@ -968,3 +982,119 @@ def transform_op(self, op): block = op.enclosing_block block.replace_uses_of_var_after_op(anchor_op=op, old_var=op.outputs[0], new_var=new_var) block.remove_ops([op]) + + +@register_pass(namespace="common") +class reorder_lut_per_channel_scale(AbstractGraphPass): + """ + The lut with per-channel-scale was represented as the following op combinations: + weight = constexpr_lut_to_dense() + weight = constexpr_blockwise_shift_scale(weight) + output = linear/matmul/conv(x, weight) + However, for ANE, it requires the scale to be after the linear/matmul/conv, which is: + weight = constexpr_lut_to_dense() + unscaled_output = linear/matmul(x, weight) + output = mul(unscaled_output, scale) + This graph pass finds the lut with per-channel-scale and move the scale to be ANE-friendly. + """ + + _OPS_SUPPORT_MOVE_SCALE = {"linear", "matmul", "conv"} + + def apply(self, prog): + @block_context_manager + def apply_block(block: Block): + for op in list(block.operations): + for b in op.blocks: + apply_block(b) + + if op.op_type == "constexpr_lut_to_dense" and len(op.outputs[0].child_ops) == 1: + child_op = op.outputs[0].child_ops[0] + if child_op.op_type == "constexpr_blockwise_shift_scale": + # Can move the scale when the constexpr op is only used to scale the weight. + has_offset = child_op.offset is not None and child_op.offset.val.any() + if types.is_float(child_op.data.dtype) and not has_offset: + self._reorder_lut_per_channel_scale(block, op) + + for f in prog.functions.values(): + apply_block(f) + + def _reorder_lut_per_channel_scale(self, block: Block, lut_op: Operation): + # Lazy import to avoid circular import error. + from coremltools.optimize.coreml import _utils as optimize_utils + + # The original order is lut_op -> scale_op -> output_op. + scale_op = lut_op.outputs[0].child_ops[0] + + # Only move the scale when all ops that consume this scale op support moving. + for output_op in scale_op.outputs[0].child_ops: + if output_op.op_type not in self._OPS_SUPPORT_MOVE_SCALE: + return + + # Only the scale on output axis could be moved to get mathematically equivalent results. + scale_val: np.ndarray = scale_op.scale.val + output_axis = optimize_utils.select_input_output_channel_axis(scale_op)[1] + if output_axis < 0: + output_axis += len(scale_val.shape) + for axis, dim_size in enumerate(scale_val.shape): + if axis != output_axis and dim_size != 1: + return + + for output_op in list(scale_op.outputs[0].child_ops): + self._help_move_scale(block, lut_op, scale_op, output_op) + block.remove_ops([output_op]) + block.remove_ops([scale_op]) + + @staticmethod + def _help_move_scale( + block: Block, lut_op: Operation, scale_op: Operation, output_op: Operation + ): + """Move the scale from `lut_op -> scale_op -> output_op` to `lut_op -> output_op -> mul`.""" + scale_val: np.ndarray = scale_op.scale.val + inputs = output_op.inputs + if output_op.op_type == "linear": + scale_val = scale_val.T + inputs["weight"] = lut_op.outputs[0] + if getattr(output_op, "bias", None) and output_op.bias.val is not None: + original_bias = output_op.bias.val + new_bias = (original_bias / np.squeeze(scale_val)).astype(original_bias.dtype) + inputs["bias"] = new_bias + elif output_op.op_type == "matmul": + # Determine if the scaled weight is used by `x` or `y` in matmul. + if output_op.y == scale_op.outputs[0]: + if output_op.transpose_y.val is True: + scale_val = scale_val.T + inputs["y"] = lut_op.outputs[0] + else: + if output_op.transpose_x.val is True: + scale_val = scale_val.T + inputs["x"] = lut_op.outputs[0] + else: + if output_op.op_type != "conv": + raise AssertionError( + "The scale could only be moved for linear/matmul/conv, " + f"but got {output_op.op_type}" + ) + # The weight of conv has C_out at axis=0, but in output the C_out is at axis=1 + scale_val = np.squeeze(scale_val) + if len(scale_val.shape) > 1: + # The per-channel-scale should only have one axis with larger than 1 dim size. + return + channel_size = 1 if len(scale_val.shape) == 0 else scale_val.shape[0] + scale_val = scale_val.reshape((1, channel_size, 1, 1)) + inputs["weight"] = lut_op.outputs[0] + if getattr(output_op, "bias", None) and output_op.bias.val is not None: + original_bias = output_op.bias.val + new_bias = (original_bias / np.squeeze(scale_val)).astype(original_bias.dtype) + inputs["bias"] = new_bias + + # Reconstruct the unscaled output which uses lut output as weight (skip the original scale). + unscaled_output = getattr(mb, output_op.op_type)(**inputs, before_op=output_op) + scaled_output = mb.mul(x=unscaled_output, y=scale_val, before_op=output_op) + + # Now the order is lut_op -> unscaled_output -> scaled_output. + block.replace_uses_of_var_after_op( + anchor_op=output_op, + old_var=output_op.outputs[0], + new_var=scaled_output, + force_replace=True, # Need to force replace because it involves replacing constexpr op. + ) diff --git a/coremltools/converters/mil/mil/passes/defs/preprocess.py b/coremltools/converters/mil/mil/passes/defs/preprocess.py index e8dd6f899..ef8599983 100644 --- a/coremltools/converters/mil/mil/passes/defs/preprocess.py +++ b/coremltools/converters/mil/mil/passes/defs/preprocess.py @@ -171,6 +171,7 @@ def sanitize_name(self, name): "uint16", "uint32", "uint64", + "state", ] if new_name in reserved_names: new_name += "_workaround" @@ -306,13 +307,15 @@ def apply(self, prog): sanitizer_ops = NameSanitizer(prefix="op_") # sanitize the input/output of the main block - NameSanitizer.sanitize_block( - prog.functions["main"], - sanitizer_vars, - sanitizer_ops, - prog.functions["main"].input_types, - sanitize_model_inputs_outputs_only=True, - ) + # TODO: rdar://126498947 ([Infra] Investigate the name sanitizer on multifunction model) + if "main" in prog.functions: + NameSanitizer.sanitize_block( + prog.functions["main"], + sanitizer_vars, + sanitizer_ops, + prog.functions["main"].input_types, + sanitize_model_inputs_outputs_only=True, + ) # TODO: rdar://122845072 ([Infra] Refactor the transform_function_signatures, adjust_io_to_supported_types and update_output_dtypes using a shared graph pass) diff --git a/coremltools/converters/mil/mil/passes/defs/quantization.py b/coremltools/converters/mil/mil/passes/defs/quantization.py index 6c016a586..ccf1b6445 100644 --- a/coremltools/converters/mil/mil/passes/defs/quantization.py +++ b/coremltools/converters/mil/mil/passes/defs/quantization.py @@ -379,6 +379,8 @@ class FP16ComputePrecision(CastTypeQuantization): "list_read", "list_write", "list_length", + "read_state", + "coreml_update_state", } def __init__(self, op_selector=None): diff --git a/coremltools/converters/mil/mil/passes/defs/randomize.py b/coremltools/converters/mil/mil/passes/defs/randomize.py new file mode 100644 index 000000000..97b30805a --- /dev/null +++ b/coremltools/converters/mil/mil/passes/defs/randomize.py @@ -0,0 +1,59 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import numpy as np + +from coremltools.converters.mil.mil import Builder as mb +from coremltools.converters.mil.mil import Operation +from coremltools.converters.mil.mil.passes.graph_pass import AbstractGraphPass +from coremltools.converters.mil.mil.passes.helper import block_context_manager +from coremltools.converters.mil.mil.passes.pass_registry import register_pass + + +@register_pass(namespace="common") +class WeightRandomizer(AbstractGraphPass): + """ + This graph pass randomizes the weights of each ``const`` op + + """ + + def apply(self, prog): + for f in prog.functions.values(): + self._randomize_weights_block(f) + + @block_context_manager + def _randomize_weights_block(self, block): + for op in list(block.operations): + for b in op.blocks: + self._randomize_weights_block(b) + + if self.is_valid_op(op): + self.transform_op(op) + + def is_valid_op(self, op: Operation): + # lazy import to prevent circular import + from coremltools.converters.mil.backend.mil.load import should_use_weight_file + + if op.op_type == "const" and should_use_weight_file(op.outputs[0].val): + return True + return False + + def transform_op(self, op): + weight = op.outputs[0].val + random_weight = np.random.rand(*weight.shape).astype(weight.dtype) + new_var = mb.const( + val=random_weight, + before_op=op, + name=op.name, + ) + + op.enclosing_block.replace_uses_of_var_after_op( + anchor_op=op, + old_var=op.outputs[0], + new_var=new_var, + no_check_var_types=True, + ) + + op.enclosing_block.remove_ops([op]) diff --git a/coremltools/converters/mil/mil/passes/graph_pass.py b/coremltools/converters/mil/mil/passes/graph_pass.py index ffa61801b..822eb1f0a 100644 --- a/coremltools/converters/mil/mil/passes/graph_pass.py +++ b/coremltools/converters/mil/mil/passes/graph_pass.py @@ -6,8 +6,8 @@ from abc import ABC, abstractmethod from typing import Callable, List, Optional, Text, Union -from coremltools.converters.mil import Operation, Program from coremltools.converters.mil.mil import Builder as mb +from coremltools.converters.mil.mil import Operation, Program from coremltools.converters.mil.mil.scope import ScopeInfo, ScopeSource diff --git a/coremltools/converters/mil/mil/passes/pass_pipeline.py b/coremltools/converters/mil/mil/passes/pass_pipeline.py index caa1f7d4e..33ebe6f5d 100644 --- a/coremltools/converters/mil/mil/passes/pass_pipeline.py +++ b/coremltools/converters/mil/mil/passes/pass_pipeline.py @@ -11,12 +11,13 @@ from coremltools import _logger as logger from coremltools.converters._profile_utils import _profile -from coremltools.converters.mil import Program +from coremltools.converters.mil.mil import Program from coremltools.converters.mil.mil.passes.graph_pass import PassOption from coremltools.converters.mil.mil.passes.helper import classproperty as _classproperty from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY _COMMON_PASSES: List[Text] = [ + "common::reorder_lut_per_channel_scale", "common::lower_complex_dialect_ops", "common::update_output_dtypes", "common::cast_optimization", @@ -382,7 +383,11 @@ def get_pipeline(cls, pipeline_name: Text) -> PassPipeline: f"There is no pipeline for `{pipeline_name}`. " f"Available pipelines: {cls._PIPELINE_NAME_TO_PASSES.keys()}" ) - return PassPipeline(cls._PIPELINE_NAME_TO_PASSES[pipeline_name], pipeline_name) + # We need to copy the pass names when initialize a PassPipeline object, + # to prevent the member functions of PassPipeline from potentially modifying the original + # data in _PIPELINE_NAME_TO_PASSES. + passes = list(cls._PIPELINE_NAME_TO_PASSES[pipeline_name]) + return PassPipeline(passes, pipeline_name) @classmethod def list_available_pipelines(cls) -> List[str]: diff --git a/coremltools/converters/mil/mil/passes/tests/test_cleanup_passes.py b/coremltools/converters/mil/mil/passes/tests/test_cleanup_passes.py index 6269a95fb..9554c8276 100644 --- a/coremltools/converters/mil/mil/passes/tests/test_cleanup_passes.py +++ b/coremltools/converters/mil/mil/passes/tests/test_cleanup_passes.py @@ -26,6 +26,7 @@ assert_op_count_match, assert_same_output_names, get_op_names_in_program, + get_op_types_in_block, get_op_types_in_program, ) @@ -33,6 +34,122 @@ class TestConstDeduplication: + def test_const_deduplication_cross_functions(self): + val_1 = np.random.rand( + 100, + ) + val_2 = np.random.rand( + 100, + ) + val_3 = np.random.rand( + 100, + ) + + @mb.function( + input_specs=[mb.TensorSpec((100,))], + ) + def func(x): + const_1 = mb.const(val=val_1) + const_2 = mb.const(val=val_1) + const_3 = mb.const(val=val_2) + const_4 = mb.const(val=val_3) + + x = mb.add(x=x, y=const_1) + x = mb.add(x=x, y=const_2) + x = mb.add(x=x, y=const_3) + return mb.add(x=x, y=const_4) + + @mb.function( + input_specs=[mb.TensorSpec((100,))], + ) + def func_1(x): + const_5 = mb.const(val=val_1) + const_6 = mb.const(val=val_2) + + x = mb.add(x=x, y=const_5) + return mb.add(x=x, y=const_6) + + prog = mil.Program() + prog.add_function("main", func) + prog.add_function("func_1", func_1) + + # In the above case, const_1 and const_2 in main is going to deduplicated in a single const op first. + # And it will share the same weight_id with const_5 in func_1. + # const_3 / const_6 are going to share the same weight_id across functions. + # While const_6.weight_id remains None. + graph_pass = PASS_REGISTRY["common::const_deduplication"] + graph_pass._deduplicate_const_across_functions(prog) + + # validate the prog + main_func = prog.functions["main"] + expected_ops = ["const", "const", "const", "add", "add", "add", "add"] + assert get_op_types_in_block(main_func, skip_const_ops=False) == expected_ops + const_ops = main_func.find_ops(op_type="const") + assert const_ops[0].weight_id == 0 + assert const_ops[1].weight_id == 1 + assert const_ops[2].weight_id is None + + func_1 = prog.functions["func_1"] + expected_ops = [ + "const", + "const", + "add", + "add", + ] + assert get_op_types_in_block(func_1, skip_const_ops=False) == expected_ops + const_ops = func_1.find_ops(op_type="const") + assert const_ops[0].weight_id == 0 + assert const_ops[1].weight_id == 1 + + def test_const_deduplication_cross_functions_from_same_source(self): + """ + In the case of users copying a source function into two functions, + same weight should be assigned with the same weighr_id as well. + """ + val_1 = np.random.rand( + 100, + ) + val_2 = np.random.rand( + 100, + ) + val_3 = np.random.rand( + 100, + ) + + @mb.function( + input_specs=[mb.TensorSpec((100,))], + ) + def func(x): + const_1 = mb.const(val=val_1) + const_2 = mb.const(val=val_1) + const_3 = mb.const(val=val_2) + const_4 = mb.const(val=val_3) + + x = mb.add(x=x, y=const_1) + x = mb.add(x=x, y=const_2) + x = mb.add(x=x, y=const_3) + return mb.add(x=x, y=const_4) + + prog = mil.Program() + prog.add_function("func_1", func) + prog.add_function("func_2", func) + + graph_pass = PASS_REGISTRY["common::const_deduplication"] + graph_pass._deduplicate_const_across_functions(prog) + + # validate the prog + func_1 = prog.functions["func_1"] + expected_ops = ["const", "const", "const", "add", "add", "add", "add"] + assert get_op_types_in_block(func_1, skip_const_ops=False) == expected_ops + func_2 = prog.functions["func_2"] + assert get_op_types_in_block(func_2, skip_const_ops=False) == expected_ops + + for func in [prog.functions["func_1"], prog.functions["func_2"]]: + const_ops = func.find_ops(op_type="const") + assert const_ops[0].weight_id == 0 + assert const_ops[1].weight_id == 1 + assert const_ops[2].weight_id == 2 + def test_const_deduplication(self): BATCH_DIM = 5 SEQUENCE_LENGTH = 4 @@ -1291,8 +1408,17 @@ def prog(x): expected_output_shapes={block.outputs[0].name: (10, 8)}, ) - def test_tile_elimination(self): - @mb.program(input_specs=[mb.TensorSpec(shape=(2, 4))]) + @pytest.mark.parametrize( + "dynamic", + [True, False], + ) + def test_tile_elimination(self, dynamic): + if dynamic: + input_shape = (get_new_symbol(), get_new_symbol()) + else: + input_shape = (2, 4) + + @mb.program(input_specs=[mb.TensorSpec(shape=input_shape)]) def prog(x): r1 = mb.tile(x=x, reps=[1, 1]) return mb.relu(x=r1) @@ -2162,11 +2288,11 @@ def test_move_sink_casts_to_the_end(self): def prog(x): x = mb.cast(x=x, dtype="fp16") x1 = mb.square(x=x) - x2 = mb.cast(x=x1, dtype="fp32") + x2 = mb.cast(x=x1, dtype="fp32", name="x2") x3 = mb.log(x=x) - x4 = mb.cast(x=x3, dtype="fp32") + x4 = mb.cast(x=x3, dtype="fp32", name="x4") x5 = mb.relu(x=x) - x6 = mb.cast(x=x5, dtype="fp32") + x6 = mb.cast(x=x5, dtype="fp32", name="x6") x7 = mb.relu(x=x6) return x2, x4, x7 @@ -2195,6 +2321,11 @@ def prog(x): "cast", ] + cast_ops = block.find_ops(op_type="cast") + assert cast_ops[1].outputs[0].name == "x6" + assert cast_ops[2].outputs[0].name == "x4" + assert cast_ops[3].outputs[0].name == "x2" + assert_model_is_valid( prog, {"x": (10, 20)}, diff --git a/coremltools/converters/mil/mil/passes/tests/test_pass_pipeline.py b/coremltools/converters/mil/mil/passes/tests/test_pass_pipeline.py index 58687cea8..30baf13e5 100644 --- a/coremltools/converters/mil/mil/passes/tests/test_pass_pipeline.py +++ b/coremltools/converters/mil/mil/passes/tests/test_pass_pipeline.py @@ -117,3 +117,10 @@ def test_list_available_pipelines(self): assert len(available_pipelines) == 12 assert "default" in available_pipelines assert "default_palettization" in available_pipelines + + @staticmethod + def test_get_pipeline_should_use_copy(): + pipeline = PassPipeline.DEFAULT_PRUNING + pipeline.append_pass("compression::palettize_weights") + pipeline_2 = PassPipeline.DEFAULT_PRUNING + assert "compression::palettize_weights" not in pipeline_2.passes diff --git a/coremltools/converters/mil/mil/passes/tests/test_passes.py b/coremltools/converters/mil/mil/passes/tests/test_passes.py index 1187cd915..fed90fd55 100644 --- a/coremltools/converters/mil/mil/passes/tests/test_passes.py +++ b/coremltools/converters/mil/mil/passes/tests/test_passes.py @@ -9,6 +9,7 @@ import numpy as np import pytest +import torch import coremltools as ct import coremltools.optimize as cto @@ -37,6 +38,9 @@ get_op_types_in_program, ) from coremltools.models.utils import _macos_version +from coremltools.test.optimize.coreml.test_post_training_quantization import ( + get_test_model_and_data_complex, +) np.random.seed(1984) _VALIDATE_MODEL = True @@ -5661,6 +5665,26 @@ def test_nn_backend_style_sanitization(self): relu_layer = spec.neuralNetwork.layers[0] assert relu_layer.output[0] == "relu/1" + @staticmethod + def test_sanitize_input_named_state(): + @mb.program( + input_specs=[ + mb.StateTensorSpec((2, 3), dtype=types.fp16), + ], + opset_version=ct.target.iOS18, + ) + def prog(state): + return mb.read_state(input=state) + + _, _, block = apply_pass_and_basic_check( + prog, + "common::sanitize_input_output_names", + skip_input_name_check=True, + ) + + assert len(block.inputs) == 1 + assert "state_workaround" in block.inputs + assert block.inputs["state_workaround"].name == "state_workaround" class TestUpdateOutputDtypes: def test_single_output(self): @@ -5821,8 +5845,10 @@ def prog(x): prog, {"x": shape}, expected_output_shapes={block.outputs[0].name: shape} ) - @pytest.mark.parametrize("with_affine", [True, False]) - def test_ane_layer_norm(self, with_affine): + @pytest.mark.parametrize( + "with_affine, constexpr_beta", itertools.product([True, False], [True, False]) + ) + def test_ane_layer_norm(self, with_affine, constexpr_beta): """ Detect layer norm pattern, found in models based on ml-ane-transformers. @@ -5836,7 +5862,7 @@ def test_ane_layer_norm(self, with_affine): """ shape = (3, 5, 1, 6) - @mb.program(input_specs=[mb.TensorSpec(shape=shape)]) + @mb.program(input_specs=[mb.TensorSpec(shape=shape)], opset_version=ct.target.iOS16) def prog(x): x1 = mb.reduce_mean(x=x, axes=[1], keep_dims=True) # mean x2 = mb.sub(x=x, y=x1) # x - mean @@ -5847,7 +5873,14 @@ def prog(x): y = mb.mul(x=x2, y=x6) # (x - mean) * rsqrt(variance + eps) if with_affine: - y = mb.add(x=y, y=np.random.rand(1, shape[1], 1, 1)) + beta = np.random.rand(1, shape[1], 1, 1) + if constexpr_beta: + beta = mb.constexpr_lut_to_dense( + lut=np.arange(2, dtype=np.float32), + indices=np.ones((int(np.ceil(np.prod(beta.shape) / 8)),), dtype=np.uint8), + shape=np.array(beta.shape, dtype=np.uint32), + ) + y = mb.add(x=y, y=beta) y = mb.mul(x=y, y=np.random.rand(1, shape[1], 1, 1)) return y @@ -5855,7 +5888,7 @@ def prog(x): prev_prog, prev_block, block = apply_pass_and_basic_check( prog, "common::fuse_layernorm_or_instancenorm" ) - assert get_op_types_in_program(prev_prog) == [ + prev_expected_ops = [ "reduce_mean", "sub", "mul", @@ -5863,10 +5896,22 @@ def prog(x): "add", "rsqrt", "mul", - ] + (["add", "mul"] if with_affine else []) - assert get_op_types_in_program(prog) == ["layer_norm"] + ] + if with_affine: + if constexpr_beta: + prev_expected_ops.append("constexpr_lut_to_dense") + prev_expected_ops += ["add", "mul"] + + assert get_op_types_in_program(prev_prog) == prev_expected_ops + if with_affine and constexpr_beta: + assert get_op_types_in_program(prog) == get_op_types_in_program(prev_prog) + else: + assert get_op_types_in_program(prog) == ["layer_norm"] assert_model_is_valid( - prog, {"x": shape}, expected_output_shapes={block.outputs[0].name: shape} + prog, + {"x": shape}, + expected_output_shapes={block.outputs[0].name: shape}, + minimum_deployment_target=ct.target.iOS16, ) @pytest.mark.parametrize("with_affine", [True, False]) @@ -6994,3 +7039,76 @@ def prog(x): assert transpose_op.scopes[ScopeSource.TORCHSCRIPT_MODULE_TYPE] == [ "module_2", ] + + +class TestRandomizeWeights: + @staticmethod + def assert_weights_changed(prog1, prog2): + changed = False + const_ops_before = prog1.find_ops(op_type="const") + const_ops_after = prog2.find_ops(op_type="const") + assert len(const_ops_before) == len(const_ops_after) + for i, op in enumerate(const_ops_before): + weight_before = op.outputs[0].val + weight_after = const_ops_after[i].outputs[0].val + if not np.array_equal(weight_before, weight_after): + changed = True + break + assert changed + + @staticmethod + def test_randomize_weights_pass(): + """ + Test the WeightRandomizer graph pass + + const + | + v + input -----> matmul -----> out + + const needs to large enough that should_use_weight_file==True + + """ + + @mb.program(input_specs=[mb.TensorSpec(shape=(2, 10))]) + def prog(x): + weights_val = np.random.rand(2, 10).T.astype(np.float32) + weights = mb.const(val=weights_val) + + return mb.matmul(x=x, y=weights) + + prev_prog, prev_block, block = apply_pass_and_basic_check(prog, "common::WeightRandomizer") + # check ops haven't changed + assert get_op_types_in_program(prog) == ["matmul"] + + # check the weights have changed + TestRandomizeWeights.assert_weights_changed(prev_prog, prog) + + @staticmethod + def test_utils_randomize_weights(): + """ + Test ct.models.utils.randomize_weights method end to end + """ + model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data_complex() + torchmodel = torch.jit.trace(model, torch_input_values) + mlmodel = ct.convert( + torchmodel, + inputs=inputs, + convert_to="mlprogram", + compute_precision=ct.precision.FLOAT32, + ) + + # randomize weights + randomized_mlmodel = ct.models.utils.randomize_weights(mlmodel) + + # get before/after mil + prog_before = mlmodel._mil_program + prog_after = randomized_mlmodel._mil_program + + # check ops haven't changed + assert get_op_types_in_program(prog_before) == get_op_types_in_program(prog_after) + assert prog_before.find_ops(op_type="conv")[1].weight.op.op_type == "const" + assert prog_after.find_ops(op_type="conv")[1].weight.op.op_type == "const" + + # check the weights have changed + TestRandomizeWeights.assert_weights_changed(prog_before, prog_after) diff --git a/coremltools/converters/mil/mil/passes/tests/test_quantization_passes.py b/coremltools/converters/mil/mil/passes/tests/test_quantization_passes.py index ebcd0d7e5..b4ef9062f 100644 --- a/coremltools/converters/mil/mil/passes/tests/test_quantization_passes.py +++ b/coremltools/converters/mil/mil/passes/tests/test_quantization_passes.py @@ -1799,6 +1799,220 @@ def prog(x): assert get_op_types_in_program(prog) == ["dequantize"] +@pytest.mark.skipif(ct.utils._macos_version() < (15, 0), reason="Only supported on macOS 15+") +class TestReorderLutPerChannelScale: + @staticmethod + def _verify_numerical(prev_prog, prog, block, input_shape, rtol=1e-7, atol=0.0): + # Verify the numerical output matches before and after the reordering. + prev_model = ct.convert( + prev_prog, + pass_pipeline=ct.PassPipeline.EMPTY, + convert_to="mlprogram", + minimum_deployment_target=ct.target.iOS18, + ) + model = ct.convert( + prog, + pass_pipeline=ct.PassPipeline.EMPTY, + convert_to="mlprogram", + minimum_deployment_target=ct.target.iOS18, + ) + output_name = block.outputs[0].name + x_val = np.random.rand(*input_shape).astype(np.float16) + input_dict = {"x": x_val} + prev_output = prev_model.predict(input_dict)[output_name] + output = model.predict(input_dict)[output_name] + np.testing.assert_allclose(prev_output, output, rtol=rtol, atol=atol) + + @staticmethod + def _get_lut_pcs_weight(shape: Tuple[int, ...], nbits=4, scale_axis: int = 0): + """Get a specific shape of weight produced by lut with per-channel-scale (pcs).""" + num_palette = 2**nbits + np_dtype = types.nptype_from_builtin(types.string_to_builtin(f"uint{nbits}")) + indices = np.arange(np.prod(shape)).reshape(shape).astype(np_dtype) + lut_shape = shape + (num_palette, 1) + lut = np.arange(np.prod(lut_shape)).reshape(lut_shape).astype(np.float16) + + lut_op = mb.constexpr_lut_to_dense(indices=indices, lut=lut) + scale_shape = [1] * len(shape) + scale_shape[scale_axis] = shape[scale_axis] + scale_shape = tuple(scale_shape) + scale_val = np.arange(1, np.prod(scale_shape) + 1).reshape(scale_shape).astype(np.float16) + return mb.constexpr_blockwise_shift_scale( + data=lut_op, + scale=scale_val, + ) + + @pytest.mark.parametrize( + "input_shape, has_bias", itertools.product([(4, 3), (2, 3, 2), (1, 2, 3, 4)], [True, False]) + ) + def test_reorder_scale_linear(self, input_shape: Tuple[int, ...], has_bias: bool): + @mb.program( + input_specs=[mb.TensorSpec(shape=input_shape, dtype=types.fp16)], + opset_version=ct.target.iOS18, + ) + def prog(x): + scaled_weight = self._get_lut_pcs_weight((2, input_shape[-1])) + bias = np.array([20, 50], dtype=np.float16) if has_bias else None + output = mb.linear(x=x, weight=scaled_weight, bias=bias) + return mb.add(x=output, y=np.float16(1.0)) + + prev_prog, _, block = apply_pass_and_basic_check( + prog, "common::reorder_lut_per_channel_scale", skip_essential_scope_check=True + ) + assert get_op_types_in_program(prev_prog) == [ + "constexpr_lut_to_dense", + "constexpr_blockwise_shift_scale", + "linear", + "add", + ] + assert get_op_types_in_program(prog) == ["constexpr_lut_to_dense", "linear", "mul", "add"] + self._verify_numerical(prev_prog, prog, block, input_shape) + + @pytest.mark.parametrize( + "use_y_as_weight, transpose_x, transpose_y", + itertools.product([True, False], [True, False], [True, False]), + ) + def test_reorder_scale_matmul(self, use_y_as_weight, transpose_x, transpose_y): + input_shape = (3, 4) + + @mb.program( + input_specs=[mb.TensorSpec(shape=input_shape, dtype=types.fp16)], + opset_version=ct.target.iOS18, + ) + def prog(x): + if use_y_as_weight: + if transpose_x: # x shape is (4, 3) + weight_shape = (2, 3) if transpose_y else (3, 2) + else: # x shape is (3, 4) + weight_shape = (2, 4) if transpose_y else (4, 2) + scaled_weight = self._get_lut_pcs_weight( + weight_shape, scale_axis=0 if transpose_y else 1 + ) + output = mb.matmul( + x=x, y=scaled_weight, transpose_x=transpose_x, transpose_y=transpose_y + ) + else: + if transpose_y: # y shape is (4, 3) + weight_shape = (4, 2) if transpose_x else (2, 4) + else: # y shape is (3, 4) + weight_shape = (3, 2) if transpose_x else (2, 3) + scaled_weight = self._get_lut_pcs_weight( + weight_shape, scale_axis=1 if transpose_x else 0 + ) + output = mb.matmul( + x=scaled_weight, y=x, transpose_x=transpose_x, transpose_y=transpose_y + ) + return mb.add(x=output, y=np.float16(1.0)) + + prev_prog, _, block = apply_pass_and_basic_check( + prog, "common::reorder_lut_per_channel_scale", skip_essential_scope_check=True + ) + assert get_op_types_in_program(prev_prog) == [ + "constexpr_lut_to_dense", + "constexpr_blockwise_shift_scale", + "matmul", + "add", + ] + assert get_op_types_in_program(prog) == ["constexpr_lut_to_dense", "matmul", "mul", "add"] + self._verify_numerical(prev_prog, prog, block, input_shape) + + @pytest.mark.parametrize( + "pad_type, has_bias, has_strides_dilations", + itertools.product(["valid", "same", "same_lower", "custom"], [True, False], [True, False]), + ) + def test_reorder_scale_conv(self, pad_type, has_bias, has_strides_dilations): + input_shape = (4, 3, 4, 3) + + @mb.program( + input_specs=[mb.TensorSpec(shape=input_shape, dtype=types.fp16)], + opset_version=ct.target.iOS18, + ) + def prog(x): + scaled_weight = self._get_lut_pcs_weight((2, 3, 2, 2), nbits=6) + bias = np.array([20, 50], dtype=np.float16) if has_bias else None + pad = [1, 1, 1, 1] if pad_type == "custom" else None + strides = [1, 2] if has_strides_dilations else None + dilations = [1, 2] if has_strides_dilations else None + output = mb.conv( + x=x, + weight=scaled_weight, + strides=strides, + pad_type=pad_type, + pad=pad, + dilations=dilations, + bias=bias, + ) + return mb.add(x=output, y=np.float16(1.0)) + + prev_prog, _, block = apply_pass_and_basic_check( + prog, "common::reorder_lut_per_channel_scale", skip_essential_scope_check=True + ) + assert get_op_types_in_program(prev_prog) == [ + "constexpr_lut_to_dense", + "constexpr_blockwise_shift_scale", + "conv", + "add", + ] + assert get_op_types_in_program(prog) == ["constexpr_lut_to_dense", "conv", "mul", "add"] + self._verify_numerical(prev_prog, prog, block, input_shape) + + @pytest.mark.parametrize( + "input_shape, has_bias", itertools.product([(4, 3), (2, 3, 2), (1, 2, 3, 4)], [True, False]) + ) + def test_reorder_multiple_usages(self, input_shape: Tuple[int, ...], has_bias: bool): + """The scaled weight is used by multiple ops.""" + + @mb.program( + input_specs=[mb.TensorSpec(shape=input_shape, dtype=types.fp16)], + opset_version=ct.target.iOS18, + ) + def prog(x): + scaled_weight = self._get_lut_pcs_weight((2, input_shape[-1])) + bias = np.array([20, 50], dtype=np.float16) if has_bias else None + linear_output = mb.linear(x=x, weight=scaled_weight, bias=bias) + matmul_output = mb.matmul(x=x, y=scaled_weight, transpose_x=False, transpose_y=True) + return mb.add(x=linear_output, y=matmul_output) + + prev_prog, _, block = apply_pass_and_basic_check( + prog, "common::reorder_lut_per_channel_scale", skip_essential_scope_check=True + ) + assert get_op_types_in_program(prev_prog) == [ + "constexpr_lut_to_dense", + "constexpr_blockwise_shift_scale", + "linear", + "matmul", + "add", + ] + assert get_op_types_in_program(prog) == [ + "constexpr_lut_to_dense", + "linear", + "mul", + "matmul", + "mul", + "add", + ] + self._verify_numerical(prev_prog, prog, block, input_shape) + + def test_reorder_not_happen(self): + """The scale won't be moved when the scaled weight is used in unsupported ops.""" + + @mb.program( + input_specs=[mb.TensorSpec(shape=(4, 16), dtype=types.fp16)], + opset_version=ct.target.iOS18, + ) + def prog(x): + scaled_weight = self._get_lut_pcs_weight((2, 16)) + linear_output1 = mb.linear(x=x, weight=scaled_weight) + add_out = mb.add(x=scaled_weight, y=np.float16(1.0)) + linear_output2 = mb.linear(x=x, weight=add_out) + return mb.add(x=linear_output1, y=linear_output2) + + prev_prog, _, block = apply_pass_and_basic_check( + prog, "common::reorder_lut_per_channel_scale", skip_essential_scope_check=True + ) + assert get_op_types_in_program(prog) == get_op_types_in_program(prev_prog) + + class TestFP16CastTransform: def assertEqual(self, first, second): """A convenience method to migrate from unittest (self.assertEqual) to pytest.""" @@ -1932,7 +2146,7 @@ def prog(x): ) mlmodel = ct.convert(prog, compute_units=ct.ComputeUnit.CPU_ONLY) - input_dict = {"x": np.random.rand(10, 20)} + input_dict = {"x": np.random.rand(10, 20) * 1e-3} if _IS_MACOS: prediction = mlmodel.predict(input_dict) diff --git a/coremltools/converters/mil/mil/program.py b/coremltools/converters/mil/mil/program.py index 2d142e9e4..3f1f45924 100644 --- a/coremltools/converters/mil/mil/program.py +++ b/coremltools/converters/mil/mil/program.py @@ -31,6 +31,7 @@ def _get_opset_str_value(op): def __init__(self): self.functions = {} self.skip_all_passes = False + self.default_function_name = "main" def _add_essential_scope_source( self, scope_source: Union[ScopeSource, List[ScopeSource]] @@ -119,8 +120,10 @@ def _check_program_opset_version(self): def _get_runtime_supported_dialect_opset() -> List[str]: """ Return a list of supported dialect opsets at runtime. + Right now, we are allowing ``coreml``, until we fix this radar: + rdar://114737210 ([Infra] Handle control flow mechanism in coremltools) """ - return [] + return ["coreml"] def _check_invalid_opset(self): """ @@ -234,6 +237,27 @@ def validate(self, check_essential_scope: Optional[bool] = False) -> None: for f in self.functions.values(): f.validate(force_validate=True, check_essential_scope=check_essential_scope) + def stringify_stack_trace(self) -> str: + result = "" + for function_name, function in self.functions.items(): + if ScopeSource.EXIR_STACK_TRACE not in function._essential_scope_sources: + raise NotImplementedError( + f"Function ({function_name}) must have EXIR_STACK_TRACE as an essential scope source." + ) + for operation in function.operations: + # TODO (rdar://115846569): Handle multi-block case from EXIR + if len(operation.blocks) > 0: + raise NotImplementedError("Multi-block case has not been supported yet") + stack_trace = operation.scopes[ScopeSource.EXIR_STACK_TRACE] + if stack_trace is None: + continue + stack_trace = stack_trace[0] + result += ( + f"{operation.op_type} : {operation.outputs[0].name}\n" + f"{stack_trace}\n" + ) + return result + def construct_debug_handle_to_ops_mapping(self) -> Dict: """ For PyMIL program translated from ExecuTorch only: Based on scope info inherited from EXIR, @@ -352,6 +376,34 @@ def _infer_output_var(self): # List of output vars (consistent w/ other ops) self.outputs = [Var(self.name, sym_type)] + +class StateTensorPlaceholder(Placeholder): + counter = 0 + + def __init__(self, sym_shape, dtype=None, name=None): + """ + A placeholder with a state wrapping a tensor. + + Parameters + ---------- + sym_shape: list, tuple + * shape of the tensor. + dtype: type + * types.float or other scalar builtin types. + name: str + * name of the placeholder + """ + self.sym_shape = sym_shape + if dtype is None: + dtype = types.fp32 + self.dtype = dtype + self.name = name + self._infer_output_var() + + def type_inference(self): + wrapped_tensor_type = types.tensor(self.dtype, self.sym_shape) + return types.state(wrapped_tensor_type) + def get_new_variadic_symbol(): global k_num_internal_syms s = Symbol("*is" + str(k_num_internal_syms)) diff --git a/coremltools/converters/mil/mil/scope.py b/coremltools/converters/mil/mil/scope.py index cc65b3f93..52744170a 100644 --- a/coremltools/converters/mil/mil/scope.py +++ b/coremltools/converters/mil/mil/scope.py @@ -36,6 +36,10 @@ class ScopeSource(Enum): and then undergoes "add_fp16_cast". # Torch export related: + EXIR_STACK_TRACE: + * The ``stack_trace`` metadata inherited from torch.fx.Node.meta in EXIR + * This metadata traces the MIL op back to original python source code + EXIR_DEBUG_HANDLE: * The ``debug_handle`` metadata inherited from torch.fx.Node.meta in EXIR * This metadata enables post-run analysis in ExecuTorch integration @@ -77,7 +81,8 @@ def forward(self, x): TORCHSCRIPT_MODULE_TYPE = 0 TORCHSCRIPT_MODULE_NAME = 1 COREMLTOOLS_GRAPH_PASS = 2 - EXIR_DEBUG_HANDLE = 3 + EXIR_STACK_TRACE = 3 # no serialization for such debug info should be allowed yet + EXIR_DEBUG_HANDLE = 4 class ScopeStack(defaultdict): diff --git a/coremltools/converters/mil/mil/tests/test_programs.py b/coremltools/converters/mil/mil/tests/test_programs.py index 5230600b6..e6e449c12 100644 --- a/coremltools/converters/mil/mil/tests/test_programs.py +++ b/coremltools/converters/mil/mil/tests/test_programs.py @@ -910,13 +910,17 @@ def test_EXIR_scope_handling(): # default list type @mb.program(input_specs=[mb.TensorSpec(shape=(2, 3))]) def prog(x): - with mb.scope(ScopeInfo(source=ScopeSource.EXIR_DEBUG_HANDLE, data=[1])): + with mb.scope( + ScopeInfo(source=ScopeSource.EXIR_STACK_TRACE, data=["x + 0.0"]), + ScopeInfo(source=ScopeSource.EXIR_DEBUG_HANDLE, data=[1]), + ): return mb.add(x=x, y=0.0) add_op_1 = prog.find_ops(op_type="add")[0] + assert add_op_1.scopes[ScopeSource.EXIR_STACK_TRACE] == ["x + 0.0"] assert add_op_1.scopes[ScopeSource.EXIR_DEBUG_HANDLE] == [1] - # data cannot have len > 1 + # debug handle data cannot have len > 1 @mb.program(input_specs=[mb.TensorSpec(shape=(2, 3))]) def prog(x): with pytest.raises(ValueError, match="EXIR_DEBUG_HANDLE scope cannot have len > 1."): @@ -929,9 +933,12 @@ def prog(x): def prog(x): with mb.scope(ScopeInfo(source=ScopeSource.EXIR_DEBUG_HANDLE, data=[None])): with mb.scope(ScopeInfo(source=ScopeSource.EXIR_DEBUG_HANDLE, data=[0])): - return mb.add(x=x, y=0.0) + with mb.scope(ScopeInfo(source=ScopeSource.EXIR_STACK_TRACE, data=["x + 0.0"])): + with mb.scope(ScopeInfo(source=ScopeSource.EXIR_STACK_TRACE, data=[None])): + return mb.add(x=x, y=0.0) add_op_1 = prog.find_ops(op_type="add")[0] + assert add_op_1.scopes[ScopeSource.EXIR_STACK_TRACE] == ["x + 0.0", None] assert add_op_1.scopes[ScopeSource.EXIR_DEBUG_HANDLE] == [None, 0] @staticmethod @@ -1949,15 +1956,18 @@ def prog(x): ScopeSource.COREMLTOOLS_GRAPH_PASS: ["pass_1"], } - # Case 4: essential scope set to EXIR_DEBUG_HANDLE + # Case 4: essential scope set to EXIR_STACK_TRACE and EXIR_DEBUG_HANDLE @mb.program(input_specs=[mb.TensorSpec(shape=(2, 4))]) def prog(x): with mb.scope( + ScopeInfo(source=ScopeSource.EXIR_STACK_TRACE, data=["torch.sin(x)"]), ScopeInfo(source=ScopeSource.EXIR_DEBUG_HANDLE, data=[1]), ): return mb.sin(x=x) - prog._add_essential_scope_source(ScopeSource.EXIR_DEBUG_HANDLE) + prog._add_essential_scope_source( + [ScopeSource.EXIR_STACK_TRACE, ScopeSource.EXIR_DEBUG_HANDLE] + ) block = prog.functions["main"] ops = list(block.operations) @@ -1971,6 +1981,7 @@ def prog(x): block._replace_var(block.inputs["x"], relu) assert relu.scopes == { + ScopeSource.EXIR_STACK_TRACE: [None], ScopeSource.EXIR_DEBUG_HANDLE: [None], ScopeSource.COREMLTOOLS_GRAPH_PASS: ["pass_1"], } diff --git a/coremltools/converters/mil/mil/tests/test_types.py b/coremltools/converters/mil/mil/tests/test_types.py index 134f64cbc..c358a14a7 100644 --- a/coremltools/converters/mil/mil/tests/test_types.py +++ b/coremltools/converters/mil/mil/tests/test_types.py @@ -3,10 +3,64 @@ # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause +import numpy as np import pytest +from coremltools import ImageType, StateType, TensorType from coremltools.converters.mil.mil import types from coremltools.converters.mil.mil.types import type_mapping +from coremltools.optimize.coreml import _utils as optimize_utils + + +class TestTypes: + def test_sub_byte_type(self): + assert types.is_int(types.int4) + assert types.is_int(types.uint1) + assert types.is_int(types.uint2) + assert types.is_int(types.uint3) + assert types.is_int(types.uint4) + assert types.is_int(types.uint6) + assert types.is_int(types.int8) + + assert types.is_sub_byte(types.int4) + assert types.is_sub_byte(types.uint1) + assert types.is_sub_byte(types.uint2) + assert types.is_sub_byte(types.uint3) + assert types.is_sub_byte(types.uint4) + assert types.is_sub_byte(types.uint6) + assert not types.is_sub_byte(types.int8) + assert not types.is_sub_byte(types.uint8) + + int4_instance = types.int4() + uint1_instance = types.uint1() + uint2_instance = types.uint2() + uint3_instance = types.uint3() + uint4_instance = types.uint4() + uint6_instance = types.uint6() + int8_instance = types.int8() + assert types.is_sub_byte(int4_instance) + assert types.is_sub_byte(uint1_instance) + assert types.is_sub_byte(uint2_instance) + assert types.is_sub_byte(uint3_instance) + assert types.is_sub_byte(uint4_instance) + assert types.is_sub_byte(uint6_instance) + assert not types.is_sub_byte(int8_instance) + + def test_state_type_with_tensor(self): + state_wrapped_type = types.tensor(types.int32, (2, 3)) + state_type = types.state(state_wrapped_type) + assert types.is_state(state_type) + assert state_type.wrapped_type() == state_wrapped_type + + def test_numpy_type_to_builtin_type(self): + assert types.numpy_type_to_builtin_type(np.float32) == types.fp32 + assert types.numpy_type_to_builtin_type(np.float16) == types.fp16 + assert types.numpy_type_to_builtin_type(np.int32) == types.int32 + assert types.numpy_type_to_builtin_type(np.int16) == types.int16 + assert types.numpy_type_to_builtin_type(np.int8) == types.int8 + assert types.numpy_type_to_builtin_type(types.np_int4_dtype) == types.int4 + assert types.numpy_type_to_builtin_type(types.np_uint4_dtype) == types.uint4 + assert types.numpy_type_to_builtin_type(types.np_uint3_dtype) == types.uint3 class TestTypeMapping: @@ -25,3 +79,39 @@ def test_promote_dtypes_different_input_sizes(self, input_size): type_mapping.promote_dtypes([types.int32, types.int64, types.int16] * input_size) == types.int64 ) + + def test_np_val_to_py_type(self): + assert types.type_mapping.np_val_to_py_type(np.array([True, False])) == (True, False) + assert types.type_mapping.np_val_to_py_type(np.array(32, dtype=np.int32)) == 32 + + # Sub-byte conversion. + int4_array = np.array([1, 2]).reshape([1, 2, 1]).astype(types.np_int4_dtype) + py_bytes = types.type_mapping.np_val_to_py_type(int4_array) + assert len(py_bytes) == 1 # Two 4-bit elements should only take 1 byte. + restored_array = optimize_utils.restore_elements_from_packed_bits( + np.frombuffer(py_bytes, dtype=np.uint8), + nbits=4, + element_num=2, + are_packed_values_signed=True, + ) + np.testing.assert_array_equal(restored_array.reshape([1, 2, 1]), int4_array) + + +class TestInputTypes: + def test_state_type(self): + state_type = StateType(name="x", wrapped_type=TensorType(shape=(2, 3), dtype=np.float32)) + assert state_type.name == "x" + assert state_type.shape.shape == (2, 3) + + def test_state_type_invalid_wrapped_type(self): + wrapped_type = ImageType(shape=(1, 3, 3, 3)) + with pytest.raises(ValueError, match="StateType only supports"): + StateType(wrapped_type=wrapped_type) + + with pytest.raises(ValueError, match="name cannot be set in the state wrapped_type"): + StateType(wrapped_type=TensorType(name="x", shape=(2, 3))) + + with pytest.raises( + ValueError, match="default_value cannot be set in the state wrapped_type" + ): + StateType(wrapped_type=TensorType(shape=(3,), default_value=np.array([0.0, 0.0, 0.0]))) diff --git a/coremltools/converters/mil/mil/types/__init__.py b/coremltools/converters/mil/mil/types/__init__.py index b49028e6e..95c3c113b 100644 --- a/coremltools/converters/mil/mil/types/__init__.py +++ b/coremltools/converters/mil/mil/types/__init__.py @@ -11,12 +11,39 @@ from .type_dict import dict, empty_dict from .type_double import double, float, fp16, fp32, fp64, is_float from .type_globals_pseudo_type import globals_pseudo_type -from .type_int import int8, int16, int32, int64, is_int, uint, uint8, uint16, uint32, uint64 +from .type_int import ( + _SUB_BYTE_TYPES, + SUB_BYTE_DTYPE_METADATA_KEY, + int4, + int8, + int16, + int32, + int64, + is_int, + is_sub_byte, + np_int4_dtype, + np_uint1_dtype, + np_uint2_dtype, + np_uint3_dtype, + np_uint4_dtype, + np_uint6_dtype, + uint, + uint1, + uint2, + uint3, + uint4, + uint6, + uint8, + uint16, + uint32, + uint64, +) from .type_list import empty_list, is_list, list from .type_mapping import ( BUILTIN_TO_PROTO_TYPES, PROTO_TO_BUILTIN_TYPE, builtin_to_string, + get_nbits_int_builtin_type, is_builtin, is_dict, is_primitive, @@ -34,6 +61,7 @@ string_to_builtin, type_to_builtin_type, ) +from .type_state import is_state, state from .type_str import str from .type_tensor import ( is_compatible_type, diff --git a/coremltools/converters/mil/mil/types/type_int.py b/coremltools/converters/mil/mil/types/type_int.py index bcecd57a9..7f415b944 100644 --- a/coremltools/converters/mil/mil/types/type_int.py +++ b/coremltools/converters/mil/mil/types/type_int.py @@ -157,19 +157,57 @@ def __neg__(self): return int +int4 = make_int(4, "") int8 = make_int(8, "") int16 = make_int(16, "") int32 = make_int(32, "") int64 = make_int(64, "") +uint1 = make_int(1, "u") +uint2 = make_int(2, "u") +uint3 = make_int(3, "u") +uint4 = make_int(4, "u") +uint6 = make_int(6, "u") uint8 = make_int(8, "u") uint16 = make_int(16, "u") uint32 = make_int(32, "u") uint64 = make_int(64, "u") uint = uint64 -_INT_TYPES = (int8, int16, int32, int64, uint8, uint16, uint32, uint64) +_INT_TYPES = ( + int4, + int8, + int16, + int32, + int64, + uint1, + uint2, + uint3, + uint4, + uint6, + uint8, + uint16, + uint32, + uint64, +) + +# The key name for storing type info in `np.dtype.metadata`. +SUB_BYTE_DTYPE_METADATA_KEY = "true_dtype" +# Uses np.int8/uint8 as np doesn't natively support sub-byte type (such as int4/uint4) yet. +np_int4_dtype = np.dtype(np.int8, metadata={SUB_BYTE_DTYPE_METADATA_KEY: int4}) +np_uint1_dtype = np.dtype(np.uint8, metadata={SUB_BYTE_DTYPE_METADATA_KEY: uint1}) +np_uint2_dtype = np.dtype(np.uint8, metadata={SUB_BYTE_DTYPE_METADATA_KEY: uint2}) +np_uint3_dtype = np.dtype(np.uint8, metadata={SUB_BYTE_DTYPE_METADATA_KEY: uint3}) +np_uint4_dtype = np.dtype(np.uint8, metadata={SUB_BYTE_DTYPE_METADATA_KEY: uint4}) +np_uint6_dtype = np.dtype(np.uint8, metadata={SUB_BYTE_DTYPE_METADATA_KEY: uint6}) + +_SUB_BYTE_TYPES = (int4, uint1, uint2, uint3, uint4, uint6) def is_int(t): return any(t is i or isinstance(t, i) for i in _INT_TYPES) + + +def is_sub_byte(t): + """Determines if a type (or instance) is sub-byte (less than 8-bit data type).""" + return t in _SUB_BYTE_TYPES or isinstance(t, _SUB_BYTE_TYPES) diff --git a/coremltools/converters/mil/mil/types/type_mapping.py b/coremltools/converters/mil/mil/types/type_mapping.py index 78e82a18a..bd278abbb 100644 --- a/coremltools/converters/mil/mil/types/type_mapping.py +++ b/coremltools/converters/mil/mil/types/type_mapping.py @@ -24,11 +24,27 @@ from .type_double import fp32 as types_fp32 from .type_double import fp64 as types_fp64 from .type_double import is_float +from .type_int import SUB_BYTE_DTYPE_METADATA_KEY +from .type_int import int4 as types_int4 from .type_int import int8 as types_int8 from .type_int import int16 as types_int16 from .type_int import int32 as types_int32 from .type_int import int64 as types_int64 -from .type_int import is_int +from .type_int import ( + is_int, + is_sub_byte, + np_int4_dtype, + np_uint1_dtype, + np_uint2_dtype, + np_uint3_dtype, + np_uint4_dtype, + np_uint6_dtype, +) +from .type_int import uint1 as types_uint1 +from .type_int import uint2 as types_uint2 +from .type_int import uint3 as types_uint3 +from .type_int import uint4 as types_uint4 +from .type_int import uint6 as types_uint6 from .type_int import uint8 as types_uint8 from .type_int import uint16 as types_uint16 from .type_int import uint32 as types_uint32 @@ -39,10 +55,16 @@ _TYPES_TO_NPTYPES = { types_bool: np.bool_, + types_int4: np_int4_dtype, types_int8: np.int8, types_int16: np.int16, types_int32: np.int32, types_int64: np.int64, + types_uint1: np_uint1_dtype, + types_uint2: np_uint2_dtype, + types_uint3: np_uint3_dtype, + types_uint4: np_uint4_dtype, + types_uint6: np_uint6_dtype, types_uint8: np.uint8, types_uint16: np.uint16, types_uint32: np.uint32, @@ -75,10 +97,16 @@ _TYPES_TO_STRINGS = { types_bool: "bool", + types_int4: "int4", types_int8: "int8", types_int16: "int16", types_int32: "int32", types_int64: "int64", + types_uint1: "uint1", + types_uint2: "uint2", + types_uint3: "uint3", + types_uint4: "uint4", + types_uint6: "uint6", types_uint8: "uint8", types_uint16: "uint16", types_uint32: "uint32", @@ -93,7 +121,13 @@ _TYPES_TO_RESOLUTION = { types_bool: 1, + types_int4: 1, types_int8: 1, + types_uint1: 1, + types_uint2: 1, + types_uint3: 1, + types_uint4: 1, + types_uint6: 1, types_uint8: 1, types_int16: 1, types_uint16: 1, @@ -108,7 +142,13 @@ _TYPES_TO_RANGE = { types_bool: RangeTuple(0, 1), + types_int4: RangeTuple(np.iinfo(np.int8).min >> 4, np.iinfo(np.int8).max >> 4), types_int8: RangeTuple(np.iinfo(np.int8).min, np.iinfo(np.int8).max), + types_uint1: RangeTuple(np.iinfo(np.uint8).min >> 7, np.iinfo(np.uint8).max >> 7), + types_uint2: RangeTuple(np.iinfo(np.uint8).min >> 6, np.iinfo(np.uint8).max >> 6), + types_uint3: RangeTuple(np.iinfo(np.uint8).min >> 5, np.iinfo(np.uint8).max >> 5), + types_uint4: RangeTuple(np.iinfo(np.uint8).min >> 4, np.iinfo(np.uint8).max >> 4), + types_uint6: RangeTuple(np.iinfo(np.uint8).min >> 2, np.iinfo(np.uint8).max >> 2), types_uint8: RangeTuple(np.iinfo(np.uint8).min, np.iinfo(np.uint8).max), types_int16: RangeTuple(np.iinfo(np.int16).min, np.iinfo(np.int16).max), types_uint16: RangeTuple(np.iinfo(np.uint16).min, np.iinfo(np.uint16).max), @@ -129,7 +169,13 @@ types_fp64: _mil_pm.FLOAT64, # int + types_uint1: _mil_pm.UINT1, + types_uint2: _mil_pm.UINT2, + types_uint3: _mil_pm.UINT3, + types_uint4: _mil_pm.UINT4, + types_uint6: _mil_pm.UINT6, types_uint8: _mil_pm.UINT8, + types_int4: _mil_pm.INT4, types_int8: _mil_pm.INT8, types_uint16: _mil_pm.UINT16, @@ -160,6 +206,16 @@ def np_dtype_to_py_type(np_dtype): PROTO_TO_BUILTIN_TYPE = {v: k for k, v in BUILTIN_TO_PROTO_TYPES.items()} _STRINGS_TO_TYPES = {v: k for k, v in _TYPES_TO_STRINGS.items()} _STRINGS_TO_NPTYPES = {v: k for k, v in _NPTYPES_TO_STRINGS.items()} +_STRINGS_TO_NPTYPES.update( + { + "int4": np_int4_dtype, + "uint1": np_uint1_dtype, + "uint2": np_uint2_dtype, + "uint3": np_uint3_dtype, + "uint4": np_uint4_dtype, + "uint6": np_uint6_dtype, + } +) def string_to_builtin(s): """ @@ -344,6 +400,10 @@ def is_builtin(t): def _numpy_dtype_instance_to_builtin_type(np_dtype: np.dtype) -> Optional[type]: + metadata_dict = np_dtype.metadata + if metadata_dict is not None and SUB_BYTE_DTYPE_METADATA_KEY in metadata_dict: + return metadata_dict[SUB_BYTE_DTYPE_METADATA_KEY] + if np_dtype in _NPTYPES_TO_STRINGS: return string_to_builtin(_NPTYPES_TO_STRINGS[np_dtype]) return None @@ -489,6 +549,13 @@ def is_subtype(type1, type2): def _numpy_val_to_bytes(val: Union[np.ndarray, np.generic]) -> bytes: + # Import here to avoid circular import. + from coremltools.optimize.coreml import _utils as optimize_utils + + builtin_type = numpy_type_to_builtin_type(val.dtype) + if is_sub_byte(builtin_type): + val = optimize_utils.pack_elements_into_bits(val, builtin_type.get_bitwidth()) + return val.tobytes() def np_val_to_py_type(val): @@ -537,3 +604,9 @@ def infer_fp_dtype_from_complex(complex_dtype): return types_fp64 else: raise ValueError(f"Unsupported complex dtype ({complex_dtype}).") + + +def get_nbits_int_builtin_type(nbits: int, signed: True) -> type: + """Get the nbits int built-in type.""" + type_prefix = "u" if not signed else "" + return string_to_builtin(f"{type_prefix}int{nbits}") diff --git a/coremltools/converters/mil/mil/types/type_state.py b/coremltools/converters/mil/mil/types/type_state.py new file mode 100644 index 000000000..f64c5736c --- /dev/null +++ b/coremltools/converters/mil/mil/types/type_state.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +from coremltools.converters.mil.mil.types.get_type_info import get_type_info +from coremltools.converters.mil.mil.types.type_spec import Type + + +def memoize(f): + memo = {} + + def helper(state_type): + if state_type not in memo: + memo[state_type] = f(state_type) + return memo[state_type] + + return helper + + +@memoize +def state(state_type): + class state: + T = [state_type] + + def __init__(self): + self.val = [] + + @property + def val(self): + return self._val + + @classmethod + def wrapped_type(cls): + return state_type + + @classmethod + def __type_info__(cls): + return Type("state", [get_type_info(state_type)], python_class=cls) + + state.__template_name__ = f"state[{get_type_info(state_type).name}]" + return state + + +def is_state(t): + if t is None: + return False + return get_type_info(t).name == "state" diff --git a/coremltools/converters/mil/mil/types/type_tensor.py b/coremltools/converters/mil/mil/types/type_tensor.py index 6782cbf6a..d60489e98 100644 --- a/coremltools/converters/mil/mil/types/type_tensor.py +++ b/coremltools/converters/mil/mil/types/type_tensor.py @@ -201,6 +201,12 @@ def is_compatible_type(type1, type2): """ Return if type1 and type2 are compatible. """ + # For single-element tensor, it's compatible with scalar. + if is_tensor(type1) and len(type1.get_shape()) == 0: + type1 = type1.get_primitive() + if is_tensor(type2) and len(type2.get_shape()) == 0: + type2 = type2.get_primitive() + if not is_subtype(type1, type2): is_comp, _ = is_tensor_and_is_compatible(type1, type2) return is_comp diff --git a/coremltools/converters/mil/mil/var.py b/coremltools/converters/mil/mil/var.py index 0d0756de8..c99a550d8 100644 --- a/coremltools/converters/mil/mil/var.py +++ b/coremltools/converters/mil/mil/var.py @@ -7,12 +7,14 @@ from collections import defaultdict from typing import Dict, List, Optional, Union +import numpy as np + from coremltools.converters.mil.mil import types -from coremltools.converters.mil.mil.types import builtin_to_string from coremltools.converters.mil.mil.types.symbolic import any_symbolic from .scope import ScopeSource + class Var: """ Var represents the outputs of an Operation. Most Vars are derived from an @@ -133,7 +135,21 @@ def __init__( self.is_descendant_of_const = Var._propagate_constness_upstream(self) def _adjust_sym_val(self): - pass + """For sub-byte dtype var, adjust the sym_val to make sure it reflects the true dtype.""" + if types.is_list(self.sym_type): + return + + if not types.is_sub_byte(self.dtype): + return + + if isinstance(self.sym_val, (np.generic, np.ndarray)): + np_val = self._sym_val.val + if ( + np_val.dtype.metadata is None + or types.SUB_BYTE_DTYPE_METADATA_KEY not in np_val.dtype.metadata + ): + target_np_dtype = types.nptype_from_builtin(self.dtype) + self._sym_val.val = np_val.astype(target_np_dtype) @property def nonreplaceable_vars_upstream(self): @@ -214,6 +230,10 @@ def sym_type(self): def shape(self): if types.is_tensor(self._sym_type): return self._sym_type.get_shape() + if types.is_state(self._sym_type): + wrapped_type = self._sym_type.wrapped_type() + assert types.is_tensor(wrapped_type), "only tensor type is supported in state type." + return wrapped_type.get_shape() return tuple() @property @@ -224,6 +244,10 @@ def rank(self): def dtype(self): if types.is_tensor(self._sym_type): return self._sym_type.get_primitive() + if types.is_state(self._sym_type): + wrapped_type = self._sym_type.wrapped_type() + assert types.is_tensor(wrapped_type), "only tensor type is supported in state type." + return wrapped_type.get_primitive() return self._sym_type @property @@ -275,10 +299,13 @@ def shape_str(self): def type_str(self): is_tensor = types.is_tensor(self.sym_type) is_list = types.is_list(self.sym_type) + is_state = types.is_state(self.sym_type) if is_tensor: type_string = "(Tensor)" elif is_list: type_string = "(List)" + elif is_state: + type_string = "(State)" else: type_string = "(Scalar)" return type_string @@ -288,8 +315,10 @@ def set_name(self, name): def is_tensor_or_scalar_of(self, dtype: Union[str, type]): if isinstance(dtype, type): - dtype = builtin_to_string(dtype) - return (types.is_tensor(self.sym_type) or types.is_scalar(self.sym_type)) and builtin_to_string(self.dtype) == dtype + dtype = types.builtin_to_string(dtype) + return ( + types.is_tensor(self.sym_type) or types.is_scalar(self.sym_type) + ) and types.builtin_to_string(self.dtype) == dtype def __str__(self): return "%" + self.name + ": " + self.shape_str() + self.type_str() diff --git a/coremltools/converters/mil/testing_reqs.py b/coremltools/converters/mil/testing_reqs.py index 455a5e5e9..02dfbffaf 100644 --- a/coremltools/converters/mil/testing_reqs.py +++ b/coremltools/converters/mil/testing_reqs.py @@ -17,7 +17,13 @@ _SUPPORTED_BACKENDS = ("neuralnetwork", "mlprogram") _SUPPORTED_PRECISIONS = ("fp32", "fp16") _SUPPORTED_OPSET_VERSIONS_NN = (ct.target.iOS14,) -_SUPPORTED_OPSET_VERSIONS_MLPROGRAM = (ct.target.iOS15, ct.target.iOS16, ct.target.iOS17) +_SUPPORTED_OPSET_VERSIONS_MLPROGRAM = ( + ct.target.iOS15, + ct.target.iOS16, + ct.target.iOS17, + ct.target.iOS18, +) + @define(frozen=True) class BackendConfig: diff --git a/coremltools/converters/mil/testing_utils.py b/coremltools/converters/mil/testing_utils.py index ee7a34908..4ab44c59e 100644 --- a/coremltools/converters/mil/testing_utils.py +++ b/coremltools/converters/mil/testing_utils.py @@ -38,8 +38,23 @@ ct.target.iOS15: 12, ct.target.iOS16: 13, ct.target.iOS17: 14, + ct.target.iOS18: 15, } +_COREMLTOOLS_DEBUG_SAVE_MLMODEL_DIRECTORY = "/tmp/coremltools_debug_save_mlmodel" + +debug_save_mlmodels = set() +debug_save_mlmodel_config_file_name = os.environ.get("DEBUG_SAVE_MLMODEL", "0") +if debug_save_mlmodel_config_file_name != "0": + if not os.path.isfile(debug_save_mlmodel_config_file_name): + raise ValueError("DEBUG_SAVE_MLMODEL must be the name of a config file with tests to save") + with open(debug_save_mlmodel_config_file_name, "r") as f: + lines = f.readlines() + for line in lines: + if line[0] == "#" or line == "\n": + continue + debug_save_mlmodels.add(line[:-1]) + hardcoded_einsum_equations: List[str] = [ # hardcoded cases "abcd,adce->abce", @@ -101,9 +116,36 @@ def macos_compatible_with_deployment_target(minimum_deployment_target): return True def _serialize_current_pytest(mlmodel): - class_name = os.environ.get('PYTEST_CURRENT_TEST').split("::")[1].strip() - test_name = "::".join(os.environ.get('PYTEST_CURRENT_TEST').split("::")[2:]).split("(call)")[0].strip() - mlpackage_path = "/tmp/pytest_failures/{}/{}/model.mlpackage".format(class_name, test_name) + """ + Usually pytest test name is of format file::class::test_function[param0-param1] (call)... + Assume each test produces only one Core ML model, + then file::class::test_function[param0-param1] is enough to determine unique name + {_COREMLTOOLS_DEBUG_SAVE_MLMODEL_DIRECTORY}/file/class/test_function/param0/param1/model.mlpackage + """ + mlpackage_path = _COREMLTOOLS_DEBUG_SAVE_MLMODEL_DIRECTORY + "/" + + PYTEST_CURRENT_TEST = os.environ.get("PYTEST_CURRENT_TEST").split("(call)")[0].strip() + test_name_fragments = PYTEST_CURRENT_TEST.split("::") + + for test_name_fragment in test_name_fragments[:-1]: + mlpackage_path += f"{test_name_fragment.strip()}/" + + test_name = test_name_fragments[-1] + # For a parameterized test, further decompose parameters into directories + if "[" in test_name and test_name[-1] == "]": + # Split test name with [] + bra_index = test_name.index("[") + test_function_name = test_name[:bra_index] + parameters = test_name[bra_index + 1 : -1].split("-") + # Append test function name and parameter to mlpackage path + mlpackage_path += f"{test_function_name}/" + for parameter in parameters: + mlpackage_path += f"{parameter}/" + else: + mlpackage_path += f"{test_name}/" + + mlpackage_path += "model.mlpackage" + Path(mlpackage_path).mkdir(parents=True, exist_ok=True) mlmodel.save(mlpackage_path) @@ -303,7 +345,7 @@ def to_tuple(v): return tuple(v) -def run_core_ml_predict(mlmodel, input_key_values): +def run_core_ml_predict(mlmodel, input_key_values, state=None): for k, v in input_key_values.items(): if isinstance(v, Image.Image): continue @@ -311,7 +353,7 @@ def run_core_ml_predict(mlmodel, input_key_values): input_key_values[k] = v.astype(np.float32) else: input_key_values[k] = np.array([v], dtype=np.float32) - return mlmodel.predict(input_key_values) + return mlmodel.predict(input_key_values, state=state) def _get_coreml_out_from_dict(out_dict, out_name): if out_name in out_dict: @@ -322,9 +364,10 @@ def _get_coreml_out_from_dict(out_dict, out_name): else: raise KeyError(f"{out_name} output not found in Core ML outputs") -def _get_proto_output_shape(spec, out_name): + +def _get_proto_output_shape(desc, out_name): sanitized_out_name = _NameSanitizer._replace_invalid_char_with_underscore(out_name) - for coreml_o in spec.description.output: + for coreml_o in desc.output: if coreml_o.name == sanitized_out_name: return coreml_o.type.multiArrayType.shape raise KeyError(f"{out_name} output not found in Core ML outputs") @@ -337,6 +380,7 @@ def compare_backend( atol=1e-04, rtol=1e-05, also_compare_shapes=True, + state=None, ): """ Inputs: @@ -353,7 +397,7 @@ def compare_backend( if dtype not in ["fp32", "fp16"]: raise ValueError("Unsupported dtype config") - pred = run_core_ml_predict(mlmodel, input_key_values) + pred = run_core_ml_predict(mlmodel, input_key_values, state) if also_compare_shapes: compare_shapes( mlmodel, @@ -419,10 +463,18 @@ def compare_shapes(mlmodel, input_key_values, expected_outputs, pred=None): # the output information in the mlprogram proto. spec = mlmodel.get_spec() if spec.WhichOneof("Type") == "mlProgram": + + if mlmodel._is_multifunction(): + desc = mlmodel._get_function_description(mlmodel.function_name) + else: + desc = spec.description + # The proto output and the runtime outputs are different for classifier - if spec.description.predictedFeatureName != "": + if desc.predictedFeatureName != "": continue - proto_shape = _get_proto_output_shape(spec, o) + + proto_shape = _get_proto_output_shape(desc, o) + if proto_shape != []: assert proto_shape == list( coreml_out.shape @@ -464,6 +516,13 @@ def ct_convert( if target == "neuralnetwork": compute_precision = None + PYTEST_CURRENT_TEST = os.environ.get("PYTEST_CURRENT_TEST").split("(call)")[0].strip() + is_current_test_to_be_debugged = PYTEST_CURRENT_TEST in debug_save_mlmodels + if is_current_test_to_be_debugged: + # If current test is to be debugged, then it is probably buggy in Core ML framework, + # so we skip its load to dodge potential bug which might kill python process + skip_model_load = True + mlmodel = converter( program, source=source, @@ -477,9 +536,9 @@ def ct_convert( **kwargs ) - if os.environ.get("DEBUG_SAVE_MLMODEL", "0") == "1": - from coremltools.converters.mil.testing_utils import _serialize_current_pytest + if is_current_test_to_be_debugged: _serialize_current_pytest(mlmodel) + pytest.xfail("This test is to be debugged") return mlmodel @@ -700,9 +759,10 @@ def verify_prediction(mlmodel, multiarray_type=None): input_dict[input_desc.name] = random_gen_input_feature_type(input_desc) if multiarray_type is not None: input_dict[input_desc.name] = input_dict[input].astype(multiarray_type) - res = mlmodel.predict(input_dict) + state = mlmodel.make_state() if mlmodel._is_stateful() else None + res = mlmodel.predict(input_dict, state=state) assert isinstance(res, dict) - assert len(res) >= 1 + assert len(res) == len(spec.description.output) def assert_spec_input_image_type(spec, expected_feature_type): assert spec.description.input[0].type.imageType.colorSpace == expected_feature_type @@ -730,3 +790,16 @@ def validate_minimum_deployment_target( pytest.skip( f"IOS{minimum_deployment_target} target is not runnable on this macOS {coremltoolsutils._macos_version()}" ) + + +def compute_snr_and_psnr(x, y): + assert len(x) == len(y) + eps = 1e-5 + eps2 = 1e-10 + noise = x - y + noise_var = np.sum(noise**2) / len(noise) + signal_energy = np.sum(y**2) / len(y) + max_signal_energy = np.amax(y**2) + snr = 10 * np.log10((signal_energy + eps) / (noise_var + eps2)) + psnr = 10 * np.log10((max_signal_energy + eps) / (noise_var + eps2)) + return snr, psnr diff --git a/coremltools/models/_compiled_model.py b/coremltools/models/_compiled_model.py index fe9f7a168..0539f4433 100644 --- a/coremltools/models/_compiled_model.py +++ b/coremltools/models/_compiled_model.py @@ -4,10 +4,13 @@ # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause from os.path import expanduser as _expanduser +from typing import Optional as _Optional + +from coremltools import ComputeUnit as _ComputeUnit +from coremltools.models.model import MLState as _MLState from .model import MLModel as _MLModel from .utils import _macos_version -from coremltools import ComputeUnit as _ComputeUnit try: from ..libcoremlpython import _MLModelProxy @@ -18,7 +21,7 @@ class CompiledMLModel: @staticmethod - def _init_check(path: str, compute_units: _ComputeUnit): + def _init_check(path: str, compute_units: _ComputeUnit, function_name: str): if _macos_version() < (10, 13): raise Exception("Loading compiled Core ML models is only support on macOS 10.13 or higher.") if _MLModelProxy is None: @@ -29,9 +32,15 @@ def _init_check(path: str, compute_units: _ComputeUnit): raise TypeError('The "path" parameter must be of type "str".') if not isinstance(compute_units, _ComputeUnit): raise TypeError('The "compute_units" parameter must be of type: "coremltools.ComputeUnit".') - - - def __init__(self, path: str, compute_units: _ComputeUnit =_ComputeUnit.ALL): + if not isinstance(function_name, str): + raise TypeError('The "function_name" parameter must be of type "str".') + + def __init__( + self, + path: str, + compute_units: _ComputeUnit = _ComputeUnit.ALL, + function_name: _Optional[str] = None, + ): """ Loads a compiled Core ML model. @@ -55,19 +64,21 @@ def __init__(self, path: str, compute_units: _ComputeUnit =_ComputeUnit.ALL): .. sourcecode:: python my_compiled_model = ct.models.CompiledMLModel("my_model_path.mlmodelc") - y = my_compiled_model.predict({'x': 3}) + y = my_compiled_model.predict({"x": 3}) See Also -------- predict """ - self._init_check(path, compute_units) + if function_name is None: + function_name = "" - path = _expanduser(path) - self._proxy = _MLModelProxy(path, compute_units.name) + self._init_check(path, compute_units, function_name) + path = _expanduser(path) + self._proxy = _MLModelProxy(path, compute_units.name, function_name) - def predict(self, data): + def predict(self, data, state: _Optional[_MLState] = None): """ Return predictions for the model. @@ -77,6 +88,9 @@ def predict(self, data): Dictionary of data to use for predictions, where the keys are the names of the input features. For batch predictons, use a list of such dictionaries. + state : MLState + Optional state object as returned by ``make_state()``. + Returns ------- dict[str, value] @@ -89,18 +103,35 @@ def predict(self, data): -------- .. sourcecode:: python - data = {'bedroom': 1.0, 'bath': 1.0, 'size': 1240} - predictions = model.predict(data) - - data = [ {'bedroom': 1.0, 'bath': 1.0, 'size': 1240}, - {'bedroom': 4.0, 'bath': 2.5, 'size': 2400} ] - batch_predictions = model.predict(data) + data = {"bedroom": 1.0, "bath": 1.0, "size": 1240} + predictions = model.predict(data) + + data = [ + {"bedroom": 1.0, "bath": 1.0, "size": 1240}, + {"bedroom": 4.0, "bath": 2.5, "size": 2400}, + ] + batch_predictions = model.predict(data) """ _MLModel._check_predict_data(data) return _MLModel._get_predictions( - self._proxy, - _MLModel._update_float16_multiarray_input_to_float32, - data + self._proxy, _MLModel._update_float16_multiarray_input_to_float32, data, state ) + + def make_state(self) -> _MLState: + """ + Returns a new state object, which can be passed to the ``predict`` method. + + Examples + -------- + .. sourcecode:: python + + state = model.make_state() + predictions = model.predict(x, state) + + See Also + -------- + predict + """ + return _MLState(self._proxy.newState()) diff --git a/coremltools/models/_deprecation.py b/coremltools/models/_deprecation.py index 2effb07d2..bce3fee84 100644 --- a/coremltools/models/_deprecation.py +++ b/coremltools/models/_deprecation.py @@ -24,7 +24,7 @@ def wrapped(*args, **kwargs): ) if suffix: msg += f"; {suffix}" - warnings.warn(msg, category=FutureWarning) + warnings.warn(msg, category=DeprecationWarning) return obj(*args, **kwargs) return wrapped diff --git a/coremltools/models/ml_program/compression_utils.py b/coremltools/models/ml_program/compression_utils.py index 7e566cc18..dabd16aab 100644 --- a/coremltools/models/ml_program/compression_utils.py +++ b/coremltools/models/ml_program/compression_utils.py @@ -5,7 +5,7 @@ import numpy as _np -from coremltools.converters.mil import Operation as _Operation +from coremltools.converters.mil.mil import Operation as _Operation from coremltools.models._deprecation import deprecated as _deprecated from coremltools.optimize.coreml import ( OpLinearQuantizerConfig as _OpLinearQuantizerConfig, diff --git a/coremltools/models/model.py b/coremltools/models/model.py index ac3dcebbc..f30e197d7 100644 --- a/coremltools/models/model.py +++ b/coremltools/models/model.py @@ -4,14 +4,13 @@ # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause import atexit as _atexit -from copy import deepcopy as _deepcopy import json import os as _os import shutil as _shutil import tempfile as _tempfile -from typing import Optional as _Optional import warnings as _warnings - +from copy import deepcopy as _deepcopy +from typing import Optional as _Optional import numpy as _np import numpy as _numpy @@ -28,11 +27,11 @@ _MLPACKAGE_AUTHOR_NAME, _MLPACKAGE_EXTENSION, _MODEL_FILE_NAME, - _WEIGHTS_DIR_NAME, _create_mlpackage, _has_custom_layer, _is_macos, _macos_version, + _try_get_weights_dir_path, ) from .utils import load_spec as _load_spec from .utils import save_spec as _save_spec @@ -139,22 +138,18 @@ def __iter__(self): yield f.name -def _try_get_weights_dir_path(mlpackage_path): - """ - Try to find the weights in mlpackage and return the path to the weights directory if found. - Return None if not found. - :param mlpackage_path: str, path to the mlpackage directory - :return: path to the weights directory inside the mlpackage directory - """ - weights_dir = None - try: - if _ModelPackage.isValid(mlpackage_path): - item_info = _ModelPackage(mlpackage_path).findItemByNameAuthor(_WEIGHTS_DIR_NAME, _MLPACKAGE_AUTHOR_NAME) - if item_info is not None: - weights_dir = item_info.path() - except: - pass - return weights_dir +class MLState: + def __init__(self, proxy): + """ + Holds state for an MLModel. + + This is an opaque object. Nothing can be done with it except pass it to MLModel.predict. + + See Also + -------- + ct.MLModel.predict + """ + self.__proxy__ = proxy class MLModel: @@ -167,7 +162,7 @@ class MLModel: - Model parameters: The set of parameters required to represent a specific instance of the model. - Metadata: Information about the origin, license, and author of the model. - With this class, you can inspect a CoreML model, modify metadata, and make + With this class, you can inspect a Core ML model, modify metadata, and make predictions for the purposes of testing (on select platforms). Examples @@ -210,6 +205,9 @@ class MLModel: # if model type is mlprogram, i.e. spec.WhichOneof('Type') == "mlProgram", then: model = MLModel(spec, weights_dir=model.weights_dir) + # Load a non-default function from a multifunction .mlpackage + model = MLModel("MultifunctionModel.mlpackage", function_name="deep_features") + See Also -------- predict @@ -223,6 +221,7 @@ def __init__( skip_model_load=False, compute_units=_ComputeUnit.ALL, weights_dir=None, + function_name=None, ): """ Construct an MLModel from an ``.mlmodel``. @@ -250,7 +249,7 @@ def __init__( mil_program: coremltools.converters.mil.Program Set to the MIL program object, if available. It is available whenever an MLModel object is constructed using - the unified converter API `coremltools.convert() `_. + the unified converter API `coremltools.convert() `_. skip_model_load: bool Set to ``True`` to prevent Core ML Tools from calling into the Core ML framework @@ -279,6 +278,10 @@ def __init__( Path to the weight directory, required when loading an MLModel of type ``mlprogram``, from a spec object, such as when the argument ``model`` is of type ``Model_pb2``. + function_name : str + The name of the function from ``model`` to load. + If not provided, ``function_name`` will be set to the ``defaultFunctionName`` in the proto. + Notes ----- Internally this maintains the following: @@ -296,8 +299,8 @@ def __init__( -------- .. sourcecode:: python - loaded_model = MLModel('my_model.mlmodel') - loaded_model = MLModel("my_model.mlpackage") + loaded_model = MLModel("my_model.mlmodel") + loaded_model = MLModel("my_model.mlpackage") """ @@ -340,6 +343,7 @@ def does_model_contain_mlprogram(model) -> bool: 'coremltools.ComputeUnit.CPU_AND_NE is only available on macOS >= 13.0' ) self.compute_unit = compute_units + self.function_name = function_name self.is_package = False self.is_temp_package = False @@ -395,12 +399,24 @@ def does_model_contain_mlprogram(model) -> bool: if self.is_package and self.is_temp_package: _atexit.register(cleanup, self.package_path) + # If function_name is not passed, self.function_name defaults to defaultFunctionName in the proto. + default_function_name = self._spec.description.defaultFunctionName + if self.function_name is None and len(default_function_name) > 0: + self.function_name = default_function_name + + if self.function_name is not None: + if not self._is_multifunction() and self.function_name != "main": + raise ValueError('function_name must be "main" for non multifunction model') - def _get_proxy_and_spec(self, - filename: str, - compute_units: _ComputeUnit, - skip_model_load: _Optional[bool] = False): + # Updated self._model_input_names_set based on self.function_name. + # self._model_input_names_set defines the allowed input keys for the data dictionary passed to self.predict(). + if self.function_name is not None and self._is_multifunction(): + f = self._get_function_description(self.function_name) + self._model_input_names_set = set([i.name for i in f.input]) + def _get_proxy_and_spec( + self, filename: str, compute_units: _ComputeUnit, skip_model_load: _Optional[bool] = False + ): filename = _os.path.expanduser(filename) specification = _load_spec(filename) @@ -413,8 +429,14 @@ def _get_proxy_and_spec(self, # version of the engine can support so we'll not try to have a proxy object return None, specification, None + function_name = "" if self.function_name is None else self.function_name + try: - return _MLModelProxy(filename, compute_units.name), specification, None + return ( + _MLModelProxy(filename, compute_units.name, function_name), + specification, + None, + ) except RuntimeError as e: _warnings.warn( "You will not be able to run predict() on this Core ML model." @@ -494,8 +516,8 @@ def save(self, save_path: str): -------- .. sourcecode:: python - model.save('my_model_file.mlmodel') - loaded_model = MLModel('my_model_file.mlmodel') + model.save("my_model_file.mlmodel") + loaded_model = MLModel("my_model_file.mlmodel") """ save_path = _os.path.expanduser(save_path) @@ -572,13 +594,13 @@ def get_spec(self): -------- .. sourcecode:: python - spec = model.get_spec() + spec = model.get_spec() """ return _deepcopy(self._spec) - def predict(self, data): + def predict(self, data, state: _Optional[MLState] = None): """ Return predictions for the model. @@ -591,6 +613,9 @@ def predict(self, data): The following dictionary values types are acceptable: list, array, numpy.ndarray, tensorflow.Tensor and torch.Tensor. + state : MLState + Optional state object as returned by ``make_state()``. + Returns ------- dict[str, value] @@ -603,12 +628,14 @@ def predict(self, data): -------- .. sourcecode:: python - data = {'bedroom': 1.0, 'bath': 1.0, 'size': 1240} - predictions = model.predict(data) + data = {"bedroom": 1.0, "bath": 1.0, "size": 1240} + predictions = model.predict(data) - data = [ {'bedroom': 1.0, 'bath': 1.0, 'size': 1240}, - {'bedroom': 4.0, 'bath': 2.5, 'size': 2400} ] - batch_predictions = model.predict(data) + data = [ + {"bedroom": 1.0, "bath": 1.0, "size": 1240}, + {"bedroom": 4.0, "bath": 2.5, "size": 2400}, + ] + batch_predictions = model.predict(data) """ def verify_and_convert_input_dict(d): @@ -624,7 +651,10 @@ def verify_and_convert_input_dict(d): MLModel._check_predict_data(data) if self.__proxy__: - return MLModel._get_predictions(self.__proxy__, verify_and_convert_input_dict, data) + return self._get_predictions(self.__proxy__, + verify_and_convert_input_dict, + data, + state) else: # Error case if _macos_version() < (10, 13): raise Exception( @@ -655,6 +685,7 @@ def verify_and_convert_input_dict(d): else: raise Exception("Unable to load CoreML.framework. Cannot make predictions.") + @staticmethod def _check_predict_data(data): if type(data) not in (list, dict): @@ -664,16 +695,65 @@ def _check_predict_data(data): @staticmethod - def _get_predictions(proxy, preprocess_method, data): + def _get_predictions(proxy, preprocess_method, data, state): if type(data) == dict: preprocess_method(data) - return proxy.predict(data) + state = None if state is None else state.__proxy__ + return proxy.predict(data, state) else: assert type(data) == list + assert state is None, "State can only be used for unbatched predictions" for i in data: preprocess_method(i) return proxy.batchPredict(data) + def _is_stateful(self) -> bool: + model_desc = self._spec.description + + # For a single function model, we check if len(state) > 0 + if len(model_desc.functions) == 0: + return len(model_desc.state) > 0 + + # For a multifunction model, we first get the corresponding function description, + # and check the state field. + f = list(filter(lambda f: f.name == self.function_name, model_desc.functions)) + return len(f.state) > 0 + + def _is_multifunction(self) -> bool: + return len(self._spec.description.functions) > 0 + + def _get_function_description(self, function_name: str) -> _proto.Model_pb2.FunctionDescription: + f = list(filter(lambda f: f.name == function_name, self._spec.description.functions)) + + if len(f) == 0: + raise ValueError(f"function_name {function_name} not found in the model.") + + assert len(f) == 1, f"Invalid proto: two functions with the same name {function_name}." + + return f[0] + + def make_state(self) -> MLState: + """ + Returns a new state object, which can be passed to the ``predict`` method. + + State functionality is only supported on macOS 15+ + + Examples + -------- + .. sourcecode:: python + + state = model.make_state() + predictions = model.predict(x, state) + + See Also + -------- + predict + """ + if not _is_macos() or _macos_version() < (15, 0): + raise Exception("State functionality is only supported on macOS 15+") + + return MLState(self.__proxy__.newState()) + def _input_has_infinite_upper_bound(self) -> bool: """Check if any input has infinite upper bound (-1).""" @@ -714,7 +794,7 @@ def _get_mil_internal(self): """ Get a deep copy of the MIL program object, if available. It's available whenever an MLModel object is constructed using - the unified converter API [`coremltools.convert()`](https://apple.github.io/coremltools/source/coremltools.converters.mil.html#coremltools.converters._converters_entry.convert). + the unified converter API [``coremltools.convert()``](https://apple.github.io/coremltools/source/coremltools.converters.mil.html#coremltools.converters._converters_entry.convert). Returns ------- @@ -724,7 +804,7 @@ def _get_mil_internal(self): -------- .. sourcecode:: python - mil_prog = model._get_mil_internal() + mil_prog = model._get_mil_internal() """ return _deepcopy(self._mil_program) diff --git a/coremltools/models/neural_network/flexible_shape_utils.py b/coremltools/models/neural_network/flexible_shape_utils.py index fbb8d1f03..7ddad70db 100644 --- a/coremltools/models/neural_network/flexible_shape_utils.py +++ b/coremltools/models/neural_network/flexible_shape_utils.py @@ -4,7 +4,7 @@ # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause """ -Utilities to annotate Neural Network Features with flexible shape information. +Utilities to annotate neural network features with flexible shape information. """ from typing import List as _List @@ -43,8 +43,8 @@ def __init__(self, size_value): class NeuralNetworkMultiArrayShape: """ An object representing a shape for a multiArray feature in a - neural network. Valid shapes must have have only the Channel [C] - shape or the Channel, Height and Width [C, H, W] shapes populated + neural network. Valid shapes must have have only the Channel ``[C]`` + shape or the Channel, Height and Width ``[C, H, W]`` shapes populated """ def __init__(self, channel=None, height=None, width=None): @@ -162,8 +162,8 @@ def isFlexible(self): class NeuralNetworkMultiArrayShapeRange: """ An object representing a range of shapes for a multiArray feature in a - neural network. Valid shape ranges must have have only the Channel [C] - range or the Channel, Height and Width [C, H, W] ranges populated. A "-1" + neural network. Valid shape ranges must have have only the Channel ``[C]`` + range or the Channel, Height and Width ``[C, H, W]`` ranges populated. A ``-1`` value in an upper bound represents an unbounded range. """ @@ -252,7 +252,7 @@ def isFlexible(self): class NeuralNetworkImageSizeRange: """ An object representing a range of sizes for an image feature inside a - neural network. Valid ranges for height and width are > 0. A "-1" + neural network. Valid ranges for height and width are > 0. A ``-1`` upper bound value for either width or height represents an unbounded size for that dimension. """ @@ -446,20 +446,20 @@ def _add_enumerated_image_sizes_for_feature( def add_enumerated_multiarray_shapes(spec, feature_name, shapes): """ - Annotate an input or output multiArray feature in a Neural Network spec to - to accommodate a list of enumerated array shapes + Annotate an input or output multiArray feature in a neural network spec to + to accommodate a list of enumerated array shapes. :param spec: MLModel - The MLModel spec containing the feature + The MLModel spec containing the feature. :param feature_name: str The name of the image feature for which to add shape information. If the feature is not found in the input or output descriptions then - an exception is thrown + an exception is thrown. :param shapes: [] | NeuralNetworkMultiArrayShape A single or a list of NeuralNetworkImageSize objects which encode valid - size information for a image feature + size information for a image feature. Examples -------- @@ -533,20 +533,20 @@ def add_enumerated_multiarray_shapes(spec, feature_name, shapes): def add_enumerated_image_sizes(spec, feature_name, sizes): """ - Annotate an input or output image feature in a Neural Network spec to - to accommodate a list of enumerated image sizes + Annotate an input or output image feature in a neural network spec to + to accommodate a list of enumerated image sizes. :param spec: MLModel - The MLModel spec containing the feature + The MLModel spec containing the feature. :param feature_name: str The name of the image feature for which to add size information. If the feature is not found in the input or output descriptions then - an exception is thrown + an exception is thrown. :param sizes: [] | NeuralNetworkImageSize A single or a list of NeuralNetworkImageSize objects which encode valid - size information for a image feature + size information for a image feature. Examples -------- @@ -603,16 +603,16 @@ def add_enumerated_image_sizes(spec, feature_name, sizes): def update_image_size_range(spec, feature_name, size_range): """ - Annotate an input or output Image feature in a Neural Network spec to - to accommodate a range of image sizes + Annotate an input or output Image feature in a neural network spec to + to accommodate a range of image sizes. :param spec: MLModel - The MLModel spec containing the feature + The MLModel spec containing the feature. :param feature_name: str The name of the Image feature for which to add shape information. If the feature is not found in the input or output descriptions then - an exception is thrown + an exception is thrown. :param size_range: NeuralNetworkImageSizeRange A NeuralNetworkImageSizeRange object with the populated image size @@ -647,22 +647,22 @@ def update_image_size_range(spec, feature_name, size_range): def update_multiarray_shape_range(spec, feature_name, shape_range): """ - Annotate an input or output MLMultiArray feature in a Neural Network spec - to accommodate a range of shapes + Annotate an input or output MLMultiArray feature in a neural network spec + to accommodate a range of shapes. :param spec: MLModel - The MLModel spec containing the feature + The MLModel spec containing the feature. :param feature_name: str The name of the feature for which to add shape range information. If the feature is not found in the input or output - descriptions then an exception is thrown + descriptions then an exception is thrown. :param shape_range: NeuralNetworkMultiArrayShapeRange A NeuralNetworkMultiArrayShapeRange object with the populated shape range information. The shape_range object must either contain only shape information for channel or channel, height and width. If - the object is invalid then an exception is thrown + the object is invalid then an exception is thrown. Examples -------- @@ -681,7 +681,7 @@ def update_multiarray_shape_range(spec, feature_name, shape_range): ) :return: - None. The spec is updated + None. The spec is updated. """ if not isinstance(shape_range, NeuralNetworkMultiArrayShapeRange): raise Exception("Shape range should be of type MultiArrayShapeRange") @@ -718,27 +718,27 @@ def update_multiarray_shape_range(spec, feature_name, shape_range): def set_multiarray_ndshape_range(spec, feature_name, lower_bounds, upper_bounds): """ - Annotate an input or output MLMultiArray feature in a Neural Network spec + Annotate an input or output MLMultiArray feature in a neural network spec to accommodate a range of shapes. - This is different from "update_multiarray_shape_range", which works with rank 5 + This is different from ``update_multiarray_shape_range``, which works with rank 5 SBCHW mapping. :param spec: MLModel - The MLModel spec containing the feature + The MLModel spec containing the feature. :param feature_name: str The name of the feature for which to add shape range information. If the feature is not found in the input or output - descriptions then an exception is thrown + descriptions then an exception is thrown. :param lower_bounds: List[int] list of integers specifying the lower bounds of each dimension. - Length must be same as the rank (length of shape) of the feature_name. + Length must be same as the rank (length of shape) of the ``feature_name``. :param upper_bounds: List[int] list of integers specifying the upper bounds of each dimension. - -1 corresponds to unbounded range. - Length must be same as the rank (length of shape) of the feature_name. + ``-1`` corresponds to unbounded range. + Length must be same as the rank (length of shape) of the ``feature_name``. Examples @@ -758,7 +758,7 @@ def set_multiarray_ndshape_range(spec, feature_name, lower_bounds, upper_bounds) ) :return: - None. The spec is updated + None. The spec is updated. """ feature = _get_feature(spec, feature_name) _set_multiarray_ndshape_range_for_feature(feature, lower_bounds, upper_bounds) @@ -770,11 +770,13 @@ def set_multiarray_ndshape_range(spec, feature_name, lower_bounds, upper_bounds) def add_multiarray_ndshape_enumeration(spec, feature_name, enumerated_shapes): """ - Annotate an input or output MLMultiArray feature in a Neural Network spec + Annotate an input or output MLMultiArray feature in a neural network spec to accommodate a range of shapes. + Add provided enumerated shapes to the list of shapes already present. - This method is different from "add_enumerated_multiarray_shapes", which is applicable - for rank 5 mapping, SBCHW, arrays. + + This method is different from ``add_enumerated_multiarray_shapes``, which is applicable + for rank 5 mapping, SBCHW, and arrays. :param spec: MLModel The MLModel spec containing the feature @@ -802,7 +804,7 @@ def add_multiarray_ndshape_enumeration(spec, feature_name, enumerated_shapes): ) :return: - None. The spec is updated + None. The spec is updated. """ feature = _get_feature(spec, feature_name) _add_multiarray_ndshape_enumeration_for_feature(feature, enumerated_shapes) diff --git a/coremltools/models/utils.py b/coremltools/models/utils.py index 851598c16..97408611e 100644 --- a/coremltools/models/utils.py +++ b/coremltools/models/utils.py @@ -7,6 +7,7 @@ Utilities for the entire package. """ +import copy as _copy import math as _math import os as _os import shutil as _shutil @@ -16,15 +17,27 @@ import warnings as _warnings from collections.abc import Iterable as _Iterable from functools import lru_cache as _lru_cache +from typing import Callable as _Callable +from typing import Dict as _Dict from typing import Optional as _Optional +from typing import Tuple as _Tuple from typing import Union as _Union import numpy as _np import coremltools as _ct +from coremltools import _SPECIFICATION_VERSION_IOS_16, _SPECIFICATION_VERSION_IOS_18 from coremltools import ComputeUnit as _ComputeUnit from coremltools import proto as _proto +from coremltools.converters.mil import mil as _mil +from coremltools.converters.mil.frontend.milproto import load as _milproto_to_pymil +from coremltools.converters.mil.mil import Program as _Program from coremltools.converters.mil.mil.passes.defs.preprocess import NameSanitizer as _NameSanitizer +from coremltools.converters.mil.mil.passes.defs.randomize import ( + WeightRandomizer as _WeightRandomizer, +) +from coremltools.converters.mil.mil.passes.graph_pass import AbstractGraphPass as _AbstractGraphPass +from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY as _PASS_REGISTRY from .._deps import _HAS_SCIPY @@ -67,13 +80,21 @@ def _create_mlpackage( package_path: _Optional[str] = None, ) -> str: """ - Args: - proto_spec: The proto spec of the model. - weights_dir: Copy weights from this path to the mlpackage. - package_path: Place the created mlpackage at this path. Error out if this path is a non-empty directory. - Returns: - path to the mlpackage + Parameters + ---------- + proto_spec + The proto spec of the model. + + weights_dir + Copy weights from this path to the ``mlpackage``. + + package_path + Place the created ``mlpackage`` at this path. Error out if this path is a non-empty directory. + + Returns + ------- + path to the ``mlpackage``. """ if package_path is None: package_path = _tempfile.mkdtemp(suffix=_MLPACKAGE_EXTENSION) @@ -121,17 +142,17 @@ def save_spec(spec, filename, auto_set_specification_version=False, weights_dir= Parameters ---------- spec: Model_pb - Protobuf representation of the model + Protobuf representation of the model. filename: str - File path where the spec gets saved. + File path where the spec is saved. auto_set_specification_version: bool - If True, will always try to set specification version automatically. + If ``True``, will always try to set specification version automatically. weights_dir: str Path to the directory containing the weights.bin file. This is required - when the spec has model type mlprogram. If the mlprogram does not contain + when the spec has model type ``mlprogram``. If the ``mlprogram`` does not contain any weights, this path can be an empty directory. Examples @@ -192,7 +213,7 @@ def save_spec(spec, filename, auto_set_specification_version=False, weights_dir= def load_spec(model_path: str) -> _proto.Model_pb2: """ - Load a protobuf model specification from file (mlmodel) or directory (mlpackage). + Load a protobuf model specification from file (``mlmodel``) or directory (``mlpackage``). Parameters ---------- @@ -201,7 +222,7 @@ def load_spec(model_path: str) -> _proto.Model_pb2: Returns ------- model_spec: Model_pb - Protobuf representation of the model + Protobuf representation of the model. Examples -------- @@ -239,7 +260,7 @@ def _get_nn_layers(spec): Returns ------- [NN layer] - list of all layers (including layers from elements of a pipeline + list of all layers (including layers from elements of a pipeline). """ @@ -314,21 +335,21 @@ def _convert_neural_network_spec_weights_to_fp16(fp_spec): def _convert_neural_network_weights_to_fp16(full_precision_model): """ - Utility function to convert a full precision (float) MLModel to a - half precision MLModel (float16). + Utility function to convert a full-precision (float) MLModel to a + half-precision MLModel (float16). Parameters ---------- full_precision_model: MLModel Model which will be converted to half precision. Currently conversion for only neural network models is supported. If a pipeline model is - passed in then all embedded neural network models embedded within + passed in, then all embedded neural network models embedded within will be converted. Returns ------- model: MLModel - The converted half precision MLModel + The converted half precision MLModel. """ spec = full_precision_model.get_spec() @@ -348,19 +369,19 @@ def _get_model(spec, compute_units=_ComputeUnit.ALL): def evaluate_regressor(model, data, target="target", verbose=False): """ - Evaluate a CoreML regression model and compare against predictions + Evaluate a Core ML regression model and compare against predictions from the original framework (for testing correctness of conversion). Parameters ---------- model: MLModel or str - A loaded MLModel or a path to a saved MLModel + A loaded MLModel or a path to a saved MLModel. data: Dataframe - Test data on which to evaluate the models + Test data on which to evaluate the models. target: str - Name of the column in the dataframe to be compared against the prediction + Name of the column in the dataframe to be compared against the prediction. verbose: bool Set to true for a more verbose output. @@ -421,18 +442,18 @@ def evaluate_classifier(model, data, target="target", verbose=False): Parameters ---------- filename: list of str or list of MLModel - File from where to load the model from (OR) a loaded + File to load the model from, or a loaded version of the MLModel. data: list of str or list of Dataframe Test data on which to evaluate the models (dataframe, - or path to a csv file). + or path to a CSV file). target: str - Column to interpret as the target column + Column to interpret as the target column. verbose: bool - Set to true for a more verbose output. + Set to true for more verbose output. See Also -------- @@ -483,15 +504,15 @@ def evaluate_classifier_with_probabilities( Parameters ---------- filename: [str | Model] - File from where to load the model from (OR) a loaded + File to load the model from, or a loaded version of the MLModel. data: [str | Dataframe] Test data on which to evaluate the models (dataframe, - or path to a csv file). + or path to a CSV file). probabilities: str - Column to interpret as the probabilities column + Column to interpret as the probabilities column. verbose: bool Verbosity levels of the predictions. @@ -561,12 +582,12 @@ def rename_feature( New name of the feature. rename_inputs: bool - Search for `current_name` only in the input features (i.e ignore output - features) + Search for ``current_name`` only in the input features (that is, ignore output + features). rename_outputs: bool - Search for `current_name` only in the output features (i.e ignore input - features) + Search for ``current_name`` only in the output features (that is, ignore input + features). Examples -------- @@ -752,8 +773,8 @@ def evaluate_transformer(model, input_data, reference_output, verbose=False): Parameters ---------- spec: list of str or list of MLModel - File from where to load the Model from (OR) a loaded - version of MLModel. + File to load the Model from, or a loaded + version of the MLModel. input_data: list of dict Test data on which to evaluate the models. @@ -825,7 +846,7 @@ def _has_custom_layer(spec): Returns ------- - True if the protobuf specification contains a neural network with a custom layer, False otherwise. + ``True`` if the protobuf specification contains a neural network with a custom layer, ``False`` otherwise. """ @@ -840,7 +861,7 @@ def _has_custom_layer(spec): def _get_custom_layer_names(spec): """ - Returns a list of className fields which appear in the given protobuf spec + Returns a list of ``className`` fields which appear in the given protobuf spec. Parameters ---------- @@ -848,8 +869,8 @@ def _get_custom_layer_names(spec): Returns ------- - - set(str) A set of unique className fields of custom layers that appear in the model. + set(str) + A set of unique ``className`` fields of custom layers that appear in the model. """ layers = _get_nn_layers(spec) @@ -872,8 +893,8 @@ def _get_custom_layers(spec): Returns ------- - - [NN layer] A list of custom layer implementations + [NN layer] + A list of custom layer implementations. """ layers = _get_nn_layers(spec) layers_out = [] @@ -887,20 +908,21 @@ def _get_custom_layers(spec): def _replace_custom_layer_name(spec, oldname, newname): """ - Substitutes newname for oldname in the className field of custom layers. If there are no custom layers, or no - layers with className=oldname, then the spec is unchanged. + Substitutes ``newname`` for ``oldname`` in the ``className`` field of custom layers. If there are no custom layers, or no + layers with ``className`` = ``oldname``, then the spec is unchanged. Parameters ---------- spec: mlmodel spec - oldname: str The custom layer className to be replaced. + oldname: str + The custom layer ``className`` to be replaced. - newname: str The new className value to replace oldname + newname: str + The new ``className`` value to replace ``oldname``. Returns ------- - An mlmodel spec. """ @@ -964,12 +986,12 @@ def _get_input_names(spec): def convert_double_to_float_multiarray_type(spec): """ Convert all double multiarrays feature descriptions (input, output, training input) - to float multiarrays + to float multiarrays. Parameters ---------- spec: Model_pb - The specification containing the multiarrays types to convert + The specification containing the multiarrays types to convert. Examples -------- @@ -1009,7 +1031,7 @@ def compile_model(model: _proto.Model_pb2.Model, destination_path: _Optional[str model: Model_pb2 Spec/protobuf to compile. - Note: an mlprogam which uses a blob file is not supported. + Note: an ``mlprogam`` which uses a blob file is not supported. destination_path: str Path where the compiled model will be saved. @@ -1018,7 +1040,7 @@ def compile_model(model: _proto.Model_pb2.Model, destination_path: _Optional[str ------- str : Path to compiled model directory - If the destination_path is specified, that is the value that will be returned. + If the ``destination_path`` is specified, that is the value that will be returned. Examples -------- @@ -1238,3 +1260,418 @@ def updateBlobFileName(proto_message, new_path): _shutil.copyfile(weight_file_path, dst + f"/{i}-weight.bin") return _ct.models.MLModel(pipeline_spec, compute_units=compute_units, weights_dir=dst) + + +def _convert_model_spec_to_pymil_prog( + mlmodel: "_ct.models.MLModel", + specification_version: int, + pymil_load_func: _Callable, + skip_model_load: bool = False, +) -> _Program: + """ + A utility that converts an ``mlprogram`` model into PyMIL program. + """ + model_spec = mlmodel.get_spec() + model_type = model_spec.WhichOneof("Type") + if model_type in ( + "neuralNetwork", + "neuralNetworkClassifier", + "neuralNetworkRegressor", + "pipeline", + "PipelineClassifier", + "PipelineRegressor", + ): + msg = ( + "coremltools.optimize.coreml are meant to be used only with mlprogram typed coreml models. " + "This model has type {}. Please use coremltools.models.neural_network.quantization_utils.quantize_weights" + "instead to compress the weights of the model." + ) + raise TypeError(msg.format(model_type)) + elif model_type == "mlProgram": + pass + else: + raise TypeError("weight compression not applicable for model type {}".format(model_type)) + + prog = pymil_load_func( + model_spec=model_spec, + specification_version=specification_version, + file_weights_dir=mlmodel.weights_dir, + skip_model_load=skip_model_load, + ) + return prog + + +def _apply_graph_pass( + mlmodel: "_ct.models.MLModel", + graph_pass: _AbstractGraphPass, + spec_version: int = _SPECIFICATION_VERSION_IOS_16, + skip_model_load: bool = False, + pymil_load_func: _Callable = _milproto_to_pymil.load, + return_pymil_prog: bool = False, +) -> _Union["_ct.models.MLModel", _Program]: + # We do the lazy import to prevent circular import + from coremltools.converters.mil.converter import mil_convert as _mil_convert + + # Utility function which compresses a Core ML model + # Converts the full precision mlmodel into a pymil program + model_spec = mlmodel.get_spec() + specification_version = max(model_spec.specificationVersion, spec_version) + prog = _convert_model_spec_to_pymil_prog( + mlmodel, specification_version, pymil_load_func, skip_model_load + ) + + # Apply graph pass. + print(type(graph_pass)) + assert isinstance( + graph_pass, _AbstractGraphPass + ), "graph pass must be an AbstractGraphPass instance" + graph_pass.apply(prog) + + # An early return can prevent running all other optimization paths triggered by _mil_convert. + if return_pymil_prog: + return prog + + # Convert the pymil program back to mlmodel + compressed_mlmodel = _mil_convert( + prog, + convert_to="mlprogram", + convert_from="milinternal", + specification_version=specification_version, + compute_units=mlmodel.compute_unit, + model_description=model_spec.description, + skip_model_load=skip_model_load, + ) + return compressed_mlmodel + + +def _try_get_weights_dir_path(mlpackage_path): + """ + Try to find the weights in mlpackage and return the path to the weights directory if found. + Return None if not found. + :param mlpackage_path: str, path to the mlpackage directory + :return: path to the weights directory inside the mlpackage directory + """ + weights_dir = None + try: + if _ModelPackage.isValid(mlpackage_path): + item_info = _ModelPackage(mlpackage_path).findItemByNameAuthor( + _WEIGHTS_DIR_NAME, _MLPACKAGE_AUTHOR_NAME + ) + if item_info is not None: + weights_dir = item_info.path() + except: + pass + return weights_dir + + +class MultiFunctionDescriptor: + """ + The data class defines how to construct a multifunction model from different model sources. + The users can use the ``add_function`` method to specify the path to the source ``mlpackage``, + along with the source and target function names. + + After setting the ``default_function_name`` to the ``MultiFunctionDescriptor`` instance, + a multifunction model can be exported using the ``save_multifunction`` method. + + Examples + -------- + .. sourcecode:: python + + from coremltools.utils import MultiFunctionDescriptor, save_multifunction + + # Initialize a MultiFunctionDescriptor instance with functions in an existing mlpackage. + # desc will constain all functions in "my_model.mlpackage" + desc = MultiFunctionDescriptor("my_model.mlpackage") + + # Construct a MultiFunctionDescriptor instance from scratch. + # The below code inserts "main" function from "my_model.mlpackage" as "main_1", + # and inserts "main" function from "my_model_2.mlpackage" as "main_2". + desc = MultiFunctionDescriptor() + desc.add_function( + model_path="my_model.mlpackage", + source_function_name="main", + target_function_name="main_1", + ) + desc.add_function( + model_path="my_model_2.mlpackage", + source_function_name="main", + target_function_name="main_2", + ) + + # Each MultiFunctionDescriptor must has a default function name before saved + # as a multifunction mlpackage on disk. + desc.default_function_name = "main_1" + save_multifunction(desc, "my_multifunction_model.mlpackage") + + See Also + -------- + save_multifunction + + """ + + def __init__(self, model_path: _Optional[str] = None): + """ + If ``model_path`` is passed to the constructor, it must be a str pointing to a + mlpackage on disk. The MultiFunctionDescriptor instance will be initiated + with functions in ``model_path``. + """ + self._default_function_name = None + self._name_to_source_function = {} + self._modelpath_to_functions = {} + self._modelpath_to_spec = {} + + if model_path is not None: + self.add_model(model_path) + + def _functions(self) -> _Dict[str, _Tuple[str, str]]: + """ + Returns ``self._name_to_source_function`` + """ + return _copy.copy(self._name_to_source_function) + + def _add_modelpath_to_cache(self, model_path: str) -> None: + """ + Given a mlpackage path ``model_path``, the utils caches related metadata. + """ + if model_path in self._modelpath_to_functions: + return + + try: + spec = load_spec(model_path) + except Exception as err: + raise ValueError(f"invalid model_path {model_path} with error {err} while loading.") + + desc = spec.description + + # for the iOS17 and below protobuf, there were no functions field, + # in which "main" is the only function associated with the model. + if len(desc.functions) == 0: + self._modelpath_to_functions[model_path] = ["main"] + else: + self._modelpath_to_functions[model_path] = [func.name for func in desc.functions] + self._modelpath_to_spec[model_path] = spec + + @property + def default_function_name(self) -> _Union[str, None]: + return self._default_function_name + + @default_function_name.setter + def default_function_name(self, val: str) -> None: + if not isinstance(val, str): + raise ValueError(f"default_function_name must be type of str. Got {val}.") + self._default_function_name = val + + def add_function( + self, model_path: str, src_function_name: str, target_function_name: str + ) -> None: + """ + Insert ``src_function_name`` function from ``model_path`` as ``target_function_name`` + function in the multifunction descriptor. + """ + self._add_modelpath_to_cache(model_path) + + if src_function_name not in self._modelpath_to_functions[model_path]: + raise ValueError(f"src_function_name {src_function_name} not found in {model_path}.") + + if target_function_name in self._name_to_source_function: + raise ValueError(f"function {target_function_name} already exist.") + + self._name_to_source_function[target_function_name] = (model_path, src_function_name) + + def add_model(self, model_path: str) -> None: + """ + Insert all functions in ``model_path`` into the multifunction descriptor. + Same function names in the original model will be applied. + """ + self._add_modelpath_to_cache(model_path) + + for func_name in self._modelpath_to_functions[model_path]: + self.add_function(model_path, func_name, func_name) + + def remove_function(self, function_name: str) -> None: + """ + Remove function ``function_name`` from the multifunction descriptor. + """ + if function_name not in self._name_to_source_function: + raise ValueError(f"function_name {function_name} not found.") + del self._name_to_source_function[function_name] + + +def save_multifunction( + desc: MultiFunctionDescriptor, + destination_path: str, +): + """ + Save a MultiFunctionDescriptor instance into a multifunction ``mlpackage``. + The utility also performs constant deduplication across functions to allow weight sharing. + + Parameters + ---------- + desc : MultiFunctionDescriptor + Multifunction descriptor to save on the disk. + + destination_path : str + The saved ``mlpackage`` model path. + + Examples + -------- + .. sourcecode:: python + + from coremltools.utils import MultiFunctionDescriptor, save_multifunction + + desc = MultiFunctionDescriptor("my_model_1.mlpackage") + desc.add_function("my_model_2.mlpackage", "main", "main_2") + desc.default_function_name = "main_2" + + save_multifunction(desc, "multifunctino_model.mlpackage") + + See Also + -------- + MultiFunctionDescriptor + + """ + # We do the lazy import to prevent circular import + from coremltools.converters.mil.converter import mil_convert as _mil_convert + + def get_function_spec( + spec: _proto.Model_pb2, func_name: str + ) -> _proto.Model_pb2.FunctionDescription: + """ + Utils to construct a FunctionDescription from the source spec. + """ + model_desc = spec.description + # For single function model, we construct the FunctionDescription ourselves + if len(model_desc.functions) == 0: + assert func_name == "main", f"invalid function name {func_name}" + return _proto.Model_pb2.FunctionDescription( + input=model_desc.input, + output=model_desc.output, + state=model_desc.state, + predictedFeatureName=model_desc.predictedFeatureName, + predictedProbabilitiesName=model_desc.predictedProbabilitiesName, + ) + # For multifunction model, we look for the corresponding FunctionDescription + for func_desc in model_desc.functions: + if func_desc.name != func_name: + continue + res = _proto.Model_pb2.FunctionDescription() + res.CopyFrom(func_desc) + res.name = "" + return res + + # compile model information: spec / weight_dir + modelpath_to_spec_and_weightdir = {} + for k, v in desc._name_to_source_function.items(): + model_path = v[0] + if model_path in modelpath_to_spec_and_weightdir: + continue + spec = desc._modelpath_to_spec[model_path] + weight_dir = _try_get_weights_dir_path(model_path) + if weight_dir is None: + raise ValueError(f"weight_dir for model_path {model_path} not found.") + modelpath_to_spec_and_weightdir[model_path] = (spec, weight_dir) + + # min spec version to support multi-functions model is iOS18 + # we also make the target spec version the max among the input models + spec_version = max( + map(lambda val: val[0].specificationVersion, modelpath_to_spec_and_weightdir.values()) + ) + spec_version = max(spec_version, _SPECIFICATION_VERSION_IOS_18) + + # convert spec into pymil program + modelpath_to_pymil = {} + for model_path, (spec, weight_dir) in modelpath_to_spec_and_weightdir.items(): + prog = _milproto_to_pymil.load( + spec, + spec_version, + weight_dir, + ) + modelpath_to_pymil[model_path] = prog + + # construct a multifunction pymil program + multifunction_prog = _mil.Program() + function_to_desc = {} + for target_func_name, v in desc._name_to_source_function.items(): + model_path = v[0] + src_func_name = v[1] + prog = modelpath_to_pymil[model_path] + multifunction_prog.add_function(target_func_name, prog.functions[src_func_name]) + + # get the corresponding function description from the spec + spec = modelpath_to_spec_and_weightdir[model_path][0] + function_spec = get_function_spec(spec, src_func_name) + assert function_spec.name == "", "function_spec should not have name set" + function_spec.name = target_func_name + function_to_desc[target_func_name] = function_spec + + # Here we deduplicate the same weights across functions, to allow consts to use + # the same blob file value when lowered into milproto. + # By weight sharing, we can make the model size as small as we could. + graph_pass = _PASS_REGISTRY["common::const_deduplication"] + graph_pass._deduplicate_const_across_functions(multifunction_prog) + + # set default function name + default_function_name = desc.default_function_name + if default_function_name is None: + raise ValueError( + "default_function_name must be set for the MultiFunctionDescriptor instance before calling save_multifunction." + ) + + if default_function_name not in multifunction_prog.functions: + raise ValueError( + f"default_function_name {default_function_name} not found in the program. Available functions names are {list(multifunction_prog.functions.keys())}" + ) + multifunction_prog.default_function_name = default_function_name + + # export program into multi-functions CoreML model + functions = [] + for func in multifunction_prog.functions: + functions.append(function_to_desc[func]) + model_description = _proto.Model_pb2.ModelDescription( + functions=functions, + defaultFunctionName=default_function_name, + ) + multifunction_prog.skip_all_passes = True + mlmodel = _mil_convert( + multifunction_prog, + convert_to="mlprogram", + convert_from="milinternal", + specification_version=spec_version, + compute_units=_ct.ComputeUnit.CPU_ONLY, + model_description=model_description, + export_multi_functions=True, + skip_model_load=True, + ) + mlmodel.save(destination_path) + + +def randomize_weights(mlmodel: "_ct.models.MLModel"): + """ + Utility function to randomize weights + + Parameters + ---------- + mlmodel: MLModel + Model which will be randomized. + + Returns + ------- + model: MLModel + The MLModel with randomized weights. + + Examples + -------- + .. sourcecode:: python + + import coremltools as ct + + model = ct.models.MLModel("my_model.mlpackage") + randomized_mlmodel = ct.models.utils.randomize_weights(mlmodel) + + """ + + randomized_mlmodel = _apply_graph_pass( + mlmodel, graph_pass=_WeightRandomizer(), skip_model_load=True + ) + + return randomized_mlmodel diff --git a/coremltools/optimize/__init__.py b/coremltools/optimize/__init__.py index 1f35b1199..25c7d28c5 100644 --- a/coremltools/optimize/__init__.py +++ b/coremltools/optimize/__init__.py @@ -2,5 +2,3 @@ # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause - -from . import coreml diff --git a/coremltools/optimize/coreml/__init__.py b/coremltools/optimize/coreml/__init__.py index 061ad56e8..a060992cf 100644 --- a/coremltools/optimize/coreml/__init__.py +++ b/coremltools/optimize/coreml/__init__.py @@ -19,3 +19,4 @@ palettize_weights, prune_weights, ) +from . import experimental \ No newline at end of file diff --git a/coremltools/optimize/coreml/_config.py b/coremltools/optimize/coreml/_config.py index 9835099f8..b5e027dfb 100644 --- a/coremltools/optimize/coreml/_config.py +++ b/coremltools/optimize/coreml/_config.py @@ -3,20 +3,31 @@ # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause +from __future__ import annotations + import sys from abc import ABC, abstractmethod from collections import OrderedDict -from typing import IO, Any, Callable, Dict, Optional, Tuple, Union +from enum import Enum +from typing import IO, Any, Callable, Dict, List, Optional, Tuple, Union import cattrs import numpy as np import yaml from attrs import define, field, validators -from coremltools.converters.mil.mil import Operation, types +from coremltools.converters.mil.mil import operation, types from coremltools.converters.mil.mil.types.type_mapping import is_builtin, numpy_type_to_builtin_type +# TODO: Share the enum between cto.coreml and cto.torch (rdar://124409664). +class CompressionGranularity(Enum): + PER_TENSOR = 1 + PER_GROUPED_CHANNEL = 2 + PER_CHANNEL = 3 + PER_BLOCK = 4 + + class OpCompressorConfig(ABC): """ An abstract class for the compressor configuration @@ -41,10 +52,81 @@ def _check_weight_threshold(instance, attribute, value): if value is not None and value < 0: raise ValueError(f"\"weight_threshold\" must be a non-negative integer. Got {value}.") +def _normalize_dtype(dtype: Union[str, type]) -> type: + if isinstance(dtype, str): + try: + dtype = types.string_to_builtin(dtype) + except KeyError: + raise ValueError(f"Invalid dtype {dtype}. Only support int8/uint8/int4/uint4.") + elif np.issubdtype(dtype, np.integer): + dtype = types.numpy_type_to_builtin_type(dtype) + elif not types.is_builtin(dtype): + raise ValueError(f"dtype={dtype} is unsupported for OpLinearQuantizerConfig.") + return dtype + + """ Linear Quantization configuration """ + +def _normalize_granularity( + granularity: Union[str, CompressionGranularity] +) -> CompressionGranularity: + if isinstance(granularity, CompressionGranularity): + return granularity + + if granularity == "per_tensor": + return CompressionGranularity.PER_TENSOR + elif granularity == "per_grouped_channel": + return CompressionGranularity.PER_GROUPED_CHANNEL + elif granularity == "per_channel": + return CompressionGranularity.PER_CHANNEL + elif granularity == "per_block": + return CompressionGranularity.PER_BLOCK + else: + raise ValueError(f"Invalid granularity={granularity}") + + +def check_block_size(instance, attr, block_size): + """ + Validator for block_size. + + Note the `instance` and `attr` are not used but required by attrs interface. + """ + if block_size is not None: + if isinstance(block_size, int): + if block_size < 0: + raise ValueError( + f"The block_size must be non-negative values, but got {block_size}" + ) + elif isinstance(block_size, (list, tuple)): + for it_block_size in block_size: + if not isinstance(it_block_size, int) or it_block_size < 0: + raise ValueError("All values in block_size must be non-negative values.") + else: + raise ValueError( + f"The block_size should be int or list/tuple of int, but got {type(block_size)}." + ) + + +def _structure_block_size_type(block_size, dtype): + """ + The block_size's type Union[int, List[int], Tuple[int, ...]] need a custom structure hook + for attrs yaml conversion. + + Note the `dtype` parameter is not used but required by attrs interface. + """ + if isinstance(block_size, int): + return block_size + else: + if not isinstance(block_size, (list, tuple)): + raise ValueError( + f'"block_size" must be int or list/tuple of int. Got {type(block_size)}' + ) + return block_size + + @define class OpLinearQuantizerConfig(OpCompressorConfig): """ @@ -59,30 +141,97 @@ class OpLinearQuantizerConfig(OpCompressorConfig): * ``"linear"``: Input data are quantized in the range :math:`[min(w_r), max(w_r)]`. - dtype: np.generic or mil.type type - Determines the quantized data type (int8/uint8). + dtype: str or np.generic or mil.type + Determines the quantized data type (int8/uint8/int4/uint4). * The allowed values are: * ``np.int8`` (the default) * ``np.uint8`` * ``coremltools.converters.mil.mil.types.int8`` * ``coremltools.converters.mil.mil.types.uint8`` + * ``coremltools.converters.mil.mil.types.int4`` + * ``coremltools.converters.mil.mil.types.uint4`` + * strings to specify dtype such as "int4", "uint4", etc + + granularity: str + Granularity for quantization. + + * ``"per_tensor"`` + * ``"per_channel"`` (default) + * ``"per_block"`` + + block_size: int or List/Tuple of int + + * Only effective when granularity is set to "per_block". + * Determines size of the block, where all elements in a block share the same scale and zero_point. + * If it's int, the block size on each axis is auto determined for best performance. More specifially, + the block will have ``block_size`` on input axis and ``1`` on output axis, where input/output + axis is auto picked based on op type. + For example, if weight has shape [Cout, Cin], the block will have shape [1, block_size]; + If the weight has shape [C_out, C_in, KH, KW], the block will has shape [1, block_size, KH, KW]. + * If it's a tuple of int, it must have the same rank as the weight, which specify the block size on each axis. + * The value 0 means block size equal to dim size at the corresponding axis. + * If the dim size on any axis is not divisible by the corresponding block size, the op will be skipped. + + The tuple input of ``block_size`` provides users fully control about the block. + Here are some examples about how different granularities could be achieved: + + Given the weight of a 2D Conv which has shape [C_out, C_in, KH, KW]: + |------------------------|--------------------------|---------------------------|----------------------------| + | Granularity | output_channel_block_size| input_channel_block_size | Weight Shape of Each Block | + |------------------------|--------------------------|---------------------------|----------------------------| + | Per Tensor | 0 | 0 | [C_out, C_in, KH, KW] | + | Per Input Channel | 0 | 1 | [C_out, 1, KH, KW] | + | Per Output Channel | 1 | 0 | [1, C_in, KH, KW] | + | Per Block | 1 | 32 | [1, 32, KH, KW] | + |------------------------|--------------------------|---------------------------|----------------------------| + + Given the weight of a linear layer which has shape [C_out, C_in]: + |------------------------|--------------------------|---------------------------|----------------------------| + | Granularity | output_channel_block_size| input_channel_block_size | Weight Shape of Each Block | + |------------------------|--------------------------|---------------------------|----------------------------| + | Per Tensor | 0 | 0 | [C_out, C_in] | + | Per Input Channel | 0 | 1 | [C_out, 1] | + | Per Output Channel | 1 | 0 | [1, C_in] | + | Per Block | 1 | 32 | [1, 32] | + |------------------------|--------------------------|---------------------------|----------------------------| + + Given the weight of matmul's y (transpose_y=False) which has shape [..., C_in, C_out]: + |------------------------|--------------------------|---------------------------|----------------------------| + | Granularity | output_channel_block_size| input_channel_block_size | Weight Shape of Each Block | + |------------------------|--------------------------|---------------------------|----------------------------| + | Per Tensor | 0 | 0 | [..., C_in, C_out] | + | Per Input Channel | 0 | 1 | [..., 1, C_out] | + | Per Output Channel | 1 | 0 | [..., C_in, 1] | + | Per Block | 1 | 32 | [..., 32, 1] | + |------------------------|--------------------------|---------------------------|----------------------------| weight_threshold: int The size threshold, above which weights are pruned. That is, a weight tensor is pruned only if its total number of elements are greater than ``weight_threshold``. + Default to 2048. For example, if ``weight_threshold = 1024`` and a weight tensor is of shape ``[10, 20, 1, 1]``, hence ``200`` elements, it will not be pruned. - - * If not provided, it will be set to ``2048``, in which weights bigger than ``2048`` elements are compressed. """ mode: str = field(default="linear_symmetric", validator=validators.instance_of(str)) - dtype: type = field(default=np.int8, validator=validators.instance_of(type)) + dtype: Union[str, type] = field(default=types.int8, converter=_normalize_dtype) + granularity: Union[str, CompressionGranularity] = field( + default=CompressionGranularity.PER_CHANNEL, + validator=validators.instance_of(CompressionGranularity), + converter=_normalize_granularity, + ) + block_size: Union[int, List[int], Tuple[int, ...]] = field( + default=32, validator=check_block_size + ) weight_threshold: Optional[int] = field(default=2048, validator=validators.optional([validators.instance_of(int), _check_weight_threshold])) _WEIGHT_AFFINE_QUANTIZATION_MODES = ("LINEAR_SYMMETRIC", "LINEAR") - _WEIGHT_AFFINE_DTYPES = (types.int8, types.uint8) + _VALID_GRANULARITIES = ( + CompressionGranularity.PER_TENSOR, + CompressionGranularity.PER_CHANNEL, + CompressionGranularity.PER_BLOCK, + ) @mode.validator def check_mode(self, attr, mode): @@ -91,37 +240,39 @@ def check_mode(self, attr, mode): @dtype.validator def check_dtype(self, attr, dtype): - msg = f"dtype={dtype} is unsupported for affine_quantize_weights." - if not is_builtin(dtype): - try: - dtype = numpy_type_to_builtin_type(dtype) - except TypeError: - raise ValueError(msg) - - if dtype not in self._WEIGHT_AFFINE_DTYPES: - raise ValueError(msg) + if not types.is_builtin(dtype): + raise ValueError(f"Invalid dtype. Should be builtin dtype, but got {type(dtype)}") + if not (types.is_int(dtype) and dtype.get_bitwidth() in {4, 8}): + raise ValueError( + f"Invalid dtype. Should be int4/8 or uint4/8, but got {types.builtin_to_string(dtype)}" + ) + + @granularity.validator + def check_granularity(self, attr, granularity): + if granularity not in self._VALID_GRANULARITIES: + raise ValueError( + f'"granularity" must be one of {self._VALID_GRANULARITIES}, but got {granularity}' + ) def __attrs_post_init__(self): self.mode = self.mode.upper() if not is_builtin(self.dtype): self.dtype = numpy_type_to_builtin_type(self.dtype) - @classmethod - def _from_dict(cls, config_dict: Dict[str, Any]) -> "OpLinearQuantizerConfig": - def _structure_type(value, dtype): - if isinstance(value, type): - return value - else: - if not isinstance(value, str) or value not in ("int8", "uint8"): - raise ValueError( - f'"dtype" must be type of type or str ["int8", "uint8"]. Got {value}' - ) - return getattr(np, value) + # Set nbits and signed for backward compatibility with existing code. + if types.is_int(self.dtype): + self.nbits = self.dtype.get_bitwidth() + self.signed = not self.dtype.is_unsigned() + @classmethod + def _from_dict(cls, config_dict: Dict[str, Any]) -> OpLinearQuantizerConfig: converter = cattrs.Converter(forbid_extra_keys=True) - converter.register_structure_hook(type, _structure_type) + converter.register_structure_hook( + Union[int, List[int], Tuple[int, ...]], _structure_block_size_type + ) return converter.structure(config_dict, cls) + """ Pruner configurations """ @@ -428,7 +579,7 @@ class OpPalettizerConfig(OpCompressorConfig): nbits: int Number of bits per weight. Required for ``kmeans`` or ``uniform`` mode, but must not be set for ``unique`` or ``custom`` mode. A LUT would have - 2\ :sup:`nbits` entries, where `nbits` can be ``{1, 2, 4, 6, 8}``. + 2\ :sup:`nbits` entries, where `nbits` can be ``{1, 2, 3, 4, 6, 8}``. mode: str Determine how the LUT is constructed by specifying one of the following: @@ -514,6 +665,24 @@ def lut_function(weight): return lut, indices + granularity: str + Granularity for quantization. + * ``"per_tensor"`` (default) + * ``"per_grouped_channel"`` + + group_size: int + * Specify the number of channels in a group. Only effective when granularity is per_grouped_channel. + * Default to 32. + + channel_axis: Optional[int] = None + * Specify the channel axis to form a group of channels. Only effective when granularity is per_grouped_channel. + * Default to None, where the axis is automatically picked based on op type. + + num_kmeans_workers: int + * Number of worker processes to use for performing k-means. It is recommended to use more + than one worker process to parallelize the clustering, especially when multiple CPUs are available. + * Default to 1. + weight_threshold: int The size threshold, above which weights are pruned. That is, a weight tensor is pruned only if its total number of elements are greater than ``weight_threshold``. @@ -526,10 +695,22 @@ def lut_function(weight): mode: str = field(default="kmeans", validator=validators.instance_of(str)) nbits: Optional[int] = field(default=None) lut_function: Optional[Callable] = field(default=None) + granularity: Union[str, CompressionGranularity] = field( + default=CompressionGranularity.PER_TENSOR, + validator=validators.instance_of(CompressionGranularity), + converter=_normalize_granularity, + ) + group_size: int = field(default=32) + channel_axis: Optional[int] = field(default=None) + num_kmeans_workers: int = field(default=1, validator=validators.instance_of(int)) weight_threshold: Optional[int] = field(default=2048, validator=validators.optional([validators.instance_of(int), _check_weight_threshold])) _WEIGHT_PALETTIZATION_MODES = ("KMEANS", "UNIFORM", "UNIQUE", "CUSTOM") - _VALID_NBITS = (1, 2, 4, 6, 8) + _VALID_NBITS = (1, 2, 3, 4, 6, 8) + _VALID_GRANULARITIES = ( + CompressionGranularity.PER_TENSOR, + CompressionGranularity.PER_GROUPED_CHANNEL, + ) @nbits.validator def check_nbits(self, attr, nbits): @@ -551,7 +732,6 @@ def check_mode(self, attr, mode): if not mode.upper() in self._WEIGHT_PALETTIZATION_MODES: raise ValueError(f"Only modes {self._WEIGHT_PALETTIZATION_MODES} are supported for weight palettization. Got \"mode\": \"{mode}\".") - @lut_function.validator def check_lut_function(self, attr, lut_function): mode = self.mode.upper() @@ -565,11 +745,18 @@ def check_lut_function(self, attr, lut_function): if lut_function is not None and not callable(lut_function): raise ValueError(f"A function object must be provided as \"lut_function\". Got a \"lut_function\" as type {type(self.lut_function)}") + @granularity.validator + def check_granularity(self, attr, granularity): + if granularity not in self._VALID_GRANULARITIES: + raise ValueError( + f'"granularity" must be one of {self._VALID_GRANULARITIES}, but got {granularity}' + ) + def __attrs_post_init__(self): self.mode = self.mode.upper() @classmethod - def _from_dict(cls, config_dict: Dict[str, Any]) -> "OpPalettizerConfig": + def _from_dict(cls, config_dict: Dict[str, Any]) -> OpPalettizerConfig: if "lut_function" in config_dict: raise ValueError( "_from_dict method does not support lut_function. Please create the OpPalettizerConfig from scratch." @@ -577,6 +764,7 @@ def _from_dict(cls, config_dict: Dict[str, Any]) -> "OpPalettizerConfig": converter = cattrs.Converter(forbid_extra_keys=True) return converter.structure(config_dict, cls) + @define class OptimizationConfig: """ @@ -763,14 +951,13 @@ def check_global_configs(self, attr, global_config): return self._check_op_config_type(global_config) - - def _get_op_config(self, op: Operation): + def _get_op_config(self, op: operation.Operation): """ - This utility function retrieve the compression config for an non-const Operation instance. + This utility function retrieve the compression config for an non-const operation.Operation instance. The priority is by: op name -> op type -> global """ - if not isinstance(op, Operation): - raise TypeError(f"op must be type of Operation. Got {type(op)}") + if not isinstance(op, operation.Operation): + raise TypeError(f"op must be type of operation.Operation. Got {type(op)}") if op.op_type == "const": raise TypeError("op must not be of type const") @@ -782,13 +969,13 @@ def _get_op_config(self, op: Operation): return self.global_config - def _get_const_op_config(self, op: Operation): + def _get_const_op_config(self, op: operation.Operation): """ - This utility function retrieves the compression config by an const Operation instance. + This utility function retrieves the compression config by an const operation.Operation instance. If the const is fed into multiple operations, an error would be thrown if a conflict is detected. """ - if not isinstance(op, Operation): - raise TypeError(f"op must be type of Operation. Got {type(op)}") + if not isinstance(op, operation.Operation): + raise TypeError(f"op must be type of operation.Operation. Got {type(op)}") if not (op.op_type == "const" or op.op_type.startswith("constexpr_")): raise TypeError(f"op must be of type const or constexpr. Got {op.op_type}") @@ -803,10 +990,14 @@ def _get_const_op_config(self, op: Operation): # If the constant's output is only connected to the block output, we don't do compression # Due to this bug: rdar://108274019 ([Bug] constexpr ops cannot be directly fed to block output) - child_ops = op.outputs[0].child_ops + child_ops = [child_op for op_output in op.outputs for child_op in op_output.child_ops] if len(child_ops) == 0: return None + # If the const is fed into constexpr ops, we follow the chain to get the non-constexpr. + if all(child_op.op_type.startswith("constexpr_") for child_op in child_ops): + return self._get_const_op_config(child_ops[0]) + op_configs = [self._get_op_config(op) for op in child_ops] for i, config in enumerate(op_configs): diff --git a/coremltools/optimize/coreml/_post_training_quantization.py b/coremltools/optimize/coreml/_post_training_quantization.py index 3e4d0ae03..9bcc6e804 100644 --- a/coremltools/optimize/coreml/_post_training_quantization.py +++ b/coremltools/optimize/coreml/_post_training_quantization.py @@ -4,108 +4,76 @@ # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause from collections import OrderedDict -from typing import Callable, Dict, List, Optional +from typing import Dict, List, Optional import numpy as np from attrs import define, field, validators from tqdm import tqdm -from coremltools import _SPECIFICATION_VERSION_IOS_16 -from coremltools.converters.mil.converter import mil_convert as _mil_convert from coremltools.converters.mil.frontend.milproto import load as _milproto_to_pymil -from coremltools.converters.mil.mil.passes.defs.quantization import ( - AbstractQuantizationPass as _AbstractQuantizationPass, -) from coremltools.converters.mil.mil.passes.graph_pass import PassOption from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY -from coremltools.models import MLModel as _MLModel +from coremltools.models import model as _model +from coremltools.models import utils as _model_utils from coremltools.optimize.coreml import OptimizationConfig as _OptimizationConfig from coremltools.optimize.coreml._config import _MetaDataDict from ._quantization_passes import WeightDecompressor as _WeightDecompressor -from ._quantization_passes import linear_quantize_weights as _linear_quantize_weights -from ._quantization_passes import palettize_weights as _palettize_weights -def _convert_model_spec_to_pymil_prog( - mlmodel: _MLModel, specification_version: int, pymil_load_func: Callable -): - """ - An utility that converts a ml program model into PyMIL program. +def _is_valid_const(val, weight_threshold): + return isinstance(val, np.ndarray) and val.size >= weight_threshold + + +def _multifunction_unsupported(func): """ - model_spec = mlmodel.get_spec() - model_type = model_spec.WhichOneof("Type") - if model_type in ("neuralNetwork", "neuralNetworkClassifier", "neuralNetworkRegressor", "pipeline", "PipelineClassifier", "PipelineRegressor"): - msg = ("coremltools.optimize.coreml are meant to be used only with mlprogram typed coreml models. " - "This model has type {}. Please use coremltools.models.neural_network.quantization_utils.quantize_weights" - "instead to compress the weights of the model.") - raise TypeError(msg.format(model_type)) - elif model_type == "mlProgram": - pass - else: - raise TypeError("weight compression not applicable for model type {}".format(model_type)) - - prog = pymil_load_func( - model_spec=model_spec, - specification_version=specification_version, - file_weights_dir=mlmodel.weights_dir, - ) - return prog + The decorator marks the PTQ API that doesn't support the multifunction model. + We should use this decorator until the radar is fixed: + rdar://126084385 ([Infra] Figure out the story of PTQ or other passes operate on loaded Mutli-function model) + Note that the API must take `mlmodel` with type of `MLModel` as an input. + """ -def _apply_graph_pass( - mlmodel: _MLModel, - graph_pass: _AbstractQuantizationPass, - spec_version: int = _SPECIFICATION_VERSION_IOS_16, - skip_model_load: bool = False, - pymil_load_func: Callable = _milproto_to_pymil.load, -): - # Utility function which compresses a Core ML model - # converts the full precision mlmodel into a pymil program - model_spec = mlmodel.get_spec() - specification_version = max(model_spec.specificationVersion, spec_version) - prog = _convert_model_spec_to_pymil_prog(mlmodel, specification_version, pymil_load_func) - - # apply compression graph pass - assert isinstance( - graph_pass, _AbstractQuantizationPass - ), "compression pass must be an AbstractQuantizationPass instance" - graph_pass.apply(prog) - - # convert the pymil program back to mlmodel - compressed_mlmodel = _mil_convert( - prog, - convert_to="mlprogram", - convert_from="milinternal", - specification_version=specification_version, - compute_units=mlmodel.compute_unit, - model_description=model_spec.description, - skip_model_load=skip_model_load, - ) - return compressed_mlmodel + def decorator(*args, **kwargs): + num_args = func.__code__.co_argcount + arg_names = list(func.__code__.co_varnames)[:num_args] + param_dict = {k: v for k, v in zip(arg_names, args)} + model = param_dict.get("mlmodel", None) + if model is None: + raise ValueError( + f'Function {func} decorated with _multifunction_unsupported must takes "mlmodel" as an input.' + ) + if model._is_multifunction(): + raise ValueError(f"{func} is not supported for a multifunction model.") + return func(*args, **kwargs) + return decorator -def _is_valid_const(val, weight_threshold): - return isinstance(val, np.ndarray) and val.size >= weight_threshold -def linear_quantize_weights(mlmodel: _MLModel, config: _OptimizationConfig): +@_multifunction_unsupported +def linear_quantize_weights( + mlmodel: _model.MLModel, config: _OptimizationConfig, joint_compression: bool = False +): """ Utility function to convert a float precision MLModel of type ``mlprogram``, which uses - float-precision weights, into a compressed MLModel that uses 8-bit weights. This is - achieved by converting the float weight values that are stored in the ``const`` op - into the ``constexpr_affine_dequantize`` op. + float-precision weights, into a compressed MLModel that uses n-bit weights (currently only + support n=4 and n=8). This is achieved by converting the float weight values that are stored in + the ``const`` op into the ``constexpr_affine_dequantize`` or ``constexpr_blockwise_shift_scale`` + op (based on model's minimum deployment target). - This function uses linear quantization on the float weights, providing up to 2x + This function uses linear quantization on the float weights, providing up to 4x (for 4-bit) savings in storage compared to float 16, or up to 4x savings compared to float 32. All computation at runtime uses float precision; the precision of the intermediate tensors and the compute precision of the ops are not altered. - For each weight, this utility function converts the weight into the int8 or uint8 type using - either `linear interpolation` (``"linear"`` mode) or `linear symmetric - interpolation` (``"linear_symmetric"`` mode, the default). + For each weight, this utility function converts the weight into the int4/8 or uint4/8 type using + either `linear interpolation` (``"linear"`` mode) or `linear symmetric interpolation` + (``"linear_symmetric"`` mode, the default). **Linear interpolation** + The following description uses 8-bit quantization to illustrate, and 4-bit is similar to it. + Linear interpolation (``"linear"`` mode) maps the min/max of the float range to the 8-bit integer range ``[low, high]`` using a zero point (also called quantization bias, or offset) and a scale factor. For the int8 quantization, ``[low, high] = [-128, 127]``, while uint8 @@ -177,6 +145,16 @@ def linear_quantize_weights(mlmodel: _MLModel, config: _OptimizationConfig): config: OptimizationConfig An :py:class:`OptimizationConfig` object that specifies the parameters for weight quantization. + joint_compression: bool + When it is set, the input mlmodel (should already be compressed) is further quantized to a + jointly compressed mlmodel. For what compression schema that could be futher jointly + quantized, see the `blockwise_quantize_weights` graph pass for details. + + Using "palettize + quantize" as an example, where the input mlmodel is already palettized, + and the palettization's lut will be further quantized. The weight values are represented by + ``constexpr_blockwise_shift_scale`` + ``constexpr_lut_to_dense`` ops: + lut(int8) -> constexpr_blockwise_shift_scale -> lut(fp16) -> constexpr_lut_to_dense -> dense(fp16) + Returns ------- @@ -197,14 +175,20 @@ def linear_quantize_weights(mlmodel: _MLModel, config: _OptimizationConfig): compressed_model = cto.coreml.linear_quantize_weights(model, config) """ + blockwise_weight_quantizer = PASS_REGISTRY["compression::linear_quantize_weights"] + blockwise_weight_quantizer.set_options( + [PassOption("config", config), PassOption("joint_compression", joint_compression)] + ) + return _model_utils._apply_graph_pass(mlmodel, blockwise_weight_quantizer) - linear_weight_quantizer = _linear_quantize_weights(config, fake_compression=False) - return _apply_graph_pass(mlmodel, linear_weight_quantizer) -def palettize_weights(mlmodel: _MLModel, config: _OptimizationConfig): +@_multifunction_unsupported +def palettize_weights( + mlmodel: _model.MLModel, config: _OptimizationConfig, joint_compression: bool = False +): """ Utility function to convert a float precision MLModel of type ``mlprogram`` to a - compressed MLModel by reducing the overall number of weights using a lookup table + compressed MLModel by reducing the overall number of weights using one or more look-up-table (LUT). A LUT contains a list of float values. An `nbit` LUT has 2\ :sup:`nbits` entries. For example, a float weight vector such as ``{0.3, 0.3, 0.5, 0.5}`` can be compressed @@ -245,6 +229,16 @@ def palettize_weights(mlmodel: _MLModel, config: _OptimizationConfig): config: OptimizationConfig An :py:class:`OptimizationConfig` object that specifies the parameters for weight palettization. + joint_compression: bool + When it is set, the input mlmodel (should already be compressed) is further palettized to a + jointly compressed mlmodel. For what compression schema that could be futher jointly + palettized, see the `channelwise_palettize_weights` graph pass for details. + + Using "prune + palettize" as an example, where the input mlmodel is already pruned, + and the non-zero entries will be further palettized. The weight values are represented by + ``constexpr_lut_to_sparse`` + ``constexpr_sparse_to_dense`` ops: + lut(sparse) -> constexpr_lut_to_sparse -> weight(sparse) -> constexpr_sparse_to_dense -> weight(dense) + Returns ------- model: MLModel @@ -264,11 +258,17 @@ def palettize_weights(mlmodel: _MLModel, config: _OptimizationConfig): compressed_model = cto.coreml.palettize_weights(model, config) """ + weight_palettizer = PASS_REGISTRY["compression::palettize_weights"] + weight_palettizer.set_options( + [PassOption("config", config), PassOption("joint_compression", joint_compression)] + ) + return _model_utils._apply_graph_pass(mlmodel, weight_palettizer) - weight_palettizer = _palettize_weights(config, fake_compression=False) - return _apply_graph_pass(mlmodel, weight_palettizer) -def prune_weights(mlmodel: _MLModel, config: _OptimizationConfig): +@_multifunction_unsupported +def prune_weights( + mlmodel: _model.MLModel, config: _OptimizationConfig, joint_compression: bool = False +): """ Utility function to convert a float precision MLModel of type ``mlprogram`` to a compressed MLModel using sparse representation. The ``const`` ops storing weight @@ -301,6 +301,16 @@ def prune_weights(mlmodel: _MLModel, config: _OptimizationConfig): config: OptimizationConfig An :py:class:`OptimizationConfig` object that specifies the parameters for weight pruning. + joint_compression: bool + When it is set, the input mlmodel (should already be compressed) is further pruned to a + jointly compressed mlmodel. For what compression schema that could be futher jointly + pruned, see the `prune_weights` graph pass for details. + + Using "quantize + prune" as an example, where the input mlmodel is already quantized, + and it will be further pruned. The weight values are represented by + ``constexpr_sparse_blockwise_shift_scale`` + ``constexpr_sparse_to_dense`` ops: + quantized(sparse) -> constexpr_sparse_blockwise_shift_scale -> weight(sparse) -> constexpr_sparse_to_dense -> weight(dense) + Returns ------- model: MLModel @@ -321,10 +331,14 @@ def prune_weights(mlmodel: _MLModel, config: _OptimizationConfig): """ weight_pruner = PASS_REGISTRY["compression::prune_weights"] - weight_pruner.set_options([PassOption("config", config)]) - return _apply_graph_pass(mlmodel, weight_pruner) + weight_pruner.set_options( + [PassOption("config", config), PassOption("joint_compression", joint_compression)] + ) + return _model_utils._apply_graph_pass(mlmodel, weight_pruner) + -def decompress_weights(mlmodel: _MLModel): +@_multifunction_unsupported +def decompress_weights(mlmodel: _model.MLModel): """ Utility function to convert weights that are sparse or palettized or affine quantized, back to the float format. That is, convert any of the following three ops to ``mb.const``: @@ -355,11 +369,11 @@ def decompress_weights(mlmodel: _MLModel): """ weight_decompressor = _WeightDecompressor(op_selector=lambda op: True) - return _apply_graph_pass(mlmodel, weight_decompressor) + return _model_utils._apply_graph_pass(mlmodel, weight_decompressor) - -def get_weights_metadata(mlmodel: _MLModel, weight_threshold: int = 2048): +@_multifunction_unsupported +def get_weights_metadata(mlmodel: _model.MLModel, weight_threshold: int = 2048): """ Utility function to get the weights metadata as a dictionary, which maps the weight's name to its corresponding CoreMLWeightMetaData. @@ -471,8 +485,9 @@ def _get_weight_metadata(op): ) return CoreMLWeightMetaData(op.val.val, child_ops=child_ops) - prog = _convert_model_spec_to_pymil_prog(mlmodel, mlmodel.get_spec().specificationVersion, - _milproto_to_pymil.load) + prog = _model_utils._convert_model_spec_to_pymil_prog( + mlmodel, mlmodel.get_spec().specificationVersion, _milproto_to_pymil.load + ) res = _MetaDataDict({}) def get_weights_meta_block(block): @@ -582,7 +597,7 @@ class CoreMLWeightMetaData: print(meta_data) Outputs:: - + [ val: np.ndarray(shape=(2, 2), dtype=float32) sparsity: 0.5 diff --git a/coremltools/optimize/coreml/_quantization_passes.py b/coremltools/optimize/coreml/_quantization_passes.py index a87160d8d..c426dcd93 100644 --- a/coremltools/optimize/coreml/_quantization_passes.py +++ b/coremltools/optimize/coreml/_quantization_passes.py @@ -3,33 +3,41 @@ # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause -from typing import Callable, Optional, Tuple +import atexit +from itertools import repeat +from multiprocessing import Pool +from typing import Callable, List, Optional, Tuple, Union import numpy as np from tqdm import tqdm -import coremltools.converters.mil.frontend._utils as frontend_utils from coremltools import _logger as logger from coremltools.converters.mil._deployment_compatibility import AvailableTarget from coremltools.converters.mil.backend.mil.load import should_use_weight_file from coremltools.converters.mil.mil import Builder as mb from coremltools.converters.mil.mil import Operation, Program, types from coremltools.converters.mil.mil.block import is_current_opset_version_compatible_with -from coremltools.converters.mil.mil.ops.defs._utils import pack_elements_into_bits +from coremltools.converters.mil.mil.ops.defs.iOS16 import constexpr_affine_dequantize from coremltools.converters.mil.mil.ops.defs.iOS16 import ( - constexpr_affine_dequantize, + constexpr_lut_to_dense as constexpr_lut_to_dense_ios16, +) +from coremltools.converters.mil.mil.ops.defs.iOS16 import ( + constexpr_sparse_to_dense as constexpr_sparse_to_dense_ios16, +) +from coremltools.converters.mil.mil.ops.defs.iOS18 import ( + constexpr_blockwise_shift_scale, constexpr_lut_to_dense, constexpr_sparse_to_dense, ) from coremltools.converters.mil.mil.passes.defs.quantization import AbstractQuantizationPass from coremltools.converters.mil.mil.passes.helper import block_context_manager from coremltools.converters.mil.mil.passes.pass_registry import register_pass -from coremltools.converters.mil.mil.types.type_mapping import nptype_from_builtin from coremltools.converters.mil.mil.var import Var -from coremltools.models._deprecation import deprecated as _deprecated +from coremltools.models._deprecation import deprecated from coremltools.models.neural_network.quantization_utils import _get_kmeans_lookup_table_and_weight -from coremltools.optimize.coreml import _utils +from coremltools.optimize.coreml import _utils as optimize_utils from coremltools.optimize.coreml._config import ( + CompressionGranularity, OpLinearQuantizerConfig, OpMagnitudePrunerConfig, OpPalettizerConfig, @@ -38,17 +46,41 @@ ) -""" ------------------------- -Compression graph pass - ------------------------- -""" class AbstractCompressionPass(AbstractQuantizationPass): """ The abstract class for the compression graph passes. """ _MINIMUM_OPSET_VERSION = AvailableTarget.iOS16 + # Graph pass option for setting compression config. + _config: Optional[OptimizationConfig] = None + + # Graph pass option for enabling joint compressions. + _joint_compression: bool = False + + @property + def config(self) -> OptimizationConfig: + return self._config + + @config.setter + def config(self, value: OptimizationConfig): + self._check_config_type(value) + self._config = value + if value._op_selector is not None: + self.op_selector = value._op_selector + + @property + def joint_compression(self): + return self._joint_compression + + @joint_compression.setter + def joint_compression(self, joint_compression: bool): + if not isinstance(joint_compression, bool): + raise ValueError( + f"joint_compression only supports bool, but got {type(joint_compression)}" + ) + self._joint_compression = joint_compression + def __init__(self, config: OptimizationConfig = None, fake_compression: bool = False): if not isinstance(config, (OptimizationConfig, type(None))): raise ValueError(f"config must be of type OptimizationConfig. Got {type(config)}.") @@ -74,6 +106,14 @@ def apply_block(block): f"Skipped the compression pass {self.__class__}.") return + if self._joint_compression and not is_current_opset_version_compatible_with( + AvailableTarget.iOS18 + ): + raise ValueError( + "Joint compression is only supported since iOS18. Please set the " + "minimum deployment target to iOS18 if you want to use it." + ) + valid_consts = [] for op in list(block.operations): for b in op.blocks: @@ -97,17 +137,6 @@ def apply_block(block): for f in prog.functions.values(): apply_block(f) - @property - def config(self) -> OptimizationConfig: - return self._config - - @config.setter - def config(self, value: OptimizationConfig): - self._check_config_type(value) - self._config = value - if value._op_selector is not None: - self.op_selector = value._op_selector - def need_compress_const( self, op: Operation, _is_deprecated: bool, weight_threshold: float ) -> bool: @@ -129,6 +158,13 @@ def need_compress_const( if weight_threshold is None: raise ValueError("weight_threshold cannot be None") + # Disable 1D tensor compression due to MIL 1D Tensor bug (rdar://113860800). + if ( + not op.outputs[0].child_ops[0].op_type.startswith("constexpr_") + and op.outputs[0].rank <= 1 + ): + return False + return ( should_use_weight_file(val) and self._get_weight_to_compress_size(op) > weight_threshold ) @@ -161,40 +197,6 @@ def get_supported_types_as_str(supported_type): supported_type_str = get_supported_types_as_str(self._SUPPORTED_CONFIG_TYPE) raise ValueError(f"{self.__class__.__name__} only accept {supported_type_str} type config. Got {config.__class__.__name__}.") - @staticmethod - def select_input_output_channel_axis(op: Operation) -> Tuple[int, int]: - """ - Here are some representative ops: - - linear: [D_out, D_in] - - matmul's y: [..., D_in, D_out] if transpose_y is False, else [..., D_out, D_in] - - conv: [C_out, C_in_div_group, KH, KW] - - conv_transpose: [C_in, C_out_div_group, KH, KW] - - The input output channel axis selection criteria is: - - For conv_transpose the output channel is 1 and input channel is 0. - - For matmul's y: - - When transpose_y=False, output channel is -1 and input channel is -2 - - When transpose_y=True, output channel is -2 and input channel is -1 - - For all other ops, output channel is 0 and input channel is 1. - """ - output_channel_axis, input_channel_axis = 0, 1 - var = op.outputs[0] - if len(var.child_ops) == 1: - child_op = var.child_ops[0] - if child_op.op_type == "conv_transpose": - output_channel_axis = 1 - input_channel_axis = 0 - if child_op.op_type == "matmul" and child_op.y == var: - if child_op.transpose_y.val: - output_channel_axis = -2 - input_channel_axis = -1 - else: - output_channel_axis = -1 - input_channel_axis = -2 - if child_op.op_type.startswith("constexpr_"): - return AbstractCompressionPass.select_input_output_channel_axis(child_op) - return input_channel_axis, output_channel_axis - def is_valid_op(self, op: Operation): if op.op_type == "const" and should_use_weight_file(self._get_const_value(op)): return True @@ -206,6 +208,31 @@ def _get_const_value(self, op: Operation) -> np.ndarray: return op.outputs[0].val def _get_weight_to_compress_size(self, op: Operation) -> int: + """ + For joint compression, the constexpr op is the intermediate compressed result, so we + need to go along the constexpr op chain to get the op which actually is the weight need + to be compressed. + + For example, the op could be a const feed into constexpr_lut_to_dense as indices, and the + constexpr_lut_to_dense is fed into a conv op. In this case, we need to find the original + weight of the conv op, instead of using the const indices to determine if we want to + compress the op. + """ + if not (op.op_type == "const" or op.op_type.startswith("constexpr_")): + raise ValueError(f"Only support const or constexpr ops, but got {op.op_type}") + + if self.joint_compression: + for op_output in op.outputs: + # If the current const/constexpr is used in multiple ops, we do a depth-first + # search to find the endpoint of the chained const/constexpr ops. + for child_op in op_output.child_ops: + if child_op.op_type.startswith("constexpr_"): + return self._get_weight_to_compress_size(child_op) + else: + # The child op is not constexpr, which means the current op is the real + # weight (not intermediate constexpr) that need compression. + return np.prod(op.outputs[0].shape) + if op.op_type != "const": raise ValueError("Only const weight can be compressed") return np.prod(op.outputs[0].shape) @@ -225,21 +252,65 @@ class prune_weights(AbstractCompressionPass): - If ``fake_compression=False``, the zeroed-out value is encoded using the ``constexpr_sparse_to_dense`` op. - If ``fake_compression=True``, the zeroed-out value is encoded using the ``const`` op. - Old ``const`` is replaced by a new operation with zeroed-out value. + + When the `joint_compression` option is set, for each existing compressed constexpr op, it will + check if the result is sparse. If the result is sparse, it will replace the constexpr op by the + corresponding sparse version to support joint compression. More specifically: + - For quantization, `constexpr_blockwise_shift_scale` is replaced by `constexpr_sparse_blockwise_shift_scale` + + `constexpr_sparse_to_dense` if the dequantized result is sparse. + - For palettization, `constexpr_lut_to_dense` is replaced by `constexpr_lut_to_sparse` + + `constexpr_sparse_to_dense` if the depalettized result is sparse. + + .. code-block:: + + Input graph: + + constexpr_blockwise_shift_scale -> downstream op + + Output graph: + + constexpr_sparse_blockwise_shift_scale -> constexpr_sparse_to_dense -> downstream op + + Support Options: + + - ``joint_compression``: Enable joint compression. Similar to blockwise_quantize_weights and """ _SUPPORTED_CONFIG_TYPE = (OpMagnitudePrunerConfig, OpThresholdPrunerConfig) + # Ops to be further pruned for joint compression. + _JOINT_SUPPORT_OPS = {"constexpr_blockwise_shift_scale", "constexpr_lut_to_dense"} + + def is_valid_op(self, op: Operation): + if not self.joint_compression: + return super().is_valid_op(op) + if op.op_type in self._JOINT_SUPPORT_OPS and should_use_weight_file( + self._get_const_value(op) + ): + return True + return False + + def _get_const_value(self, op: Operation) -> np.ndarray: + if op.op_type == "const" or not self.joint_compression: + return super()._get_const_value(op) + elif op.op_type.startswith("constexpr_"): + # The materialized_val_inference is expensive, so only do it for joint compression, as + # we need to get the de-compressed value and prune it. + return op.materialized_val_inference() + else: + raise ValueError(f"The op {op} is not a const/constexpr.") @staticmethod - def _pack_val_to_sparse_param(val): + def _produce_sparse_param(val) -> optimize_utils.SparseParamsIos16: flattened_val = val.flatten() - params = _utils.SparseParams( + return optimize_utils.SparseParamsIos16( nonzero_data=flattened_val[np.where(flattened_val != 0)], mask=np.packbits(np.where(flattened_val != 0, 1, 0), bitorder="little"), shape=val.shape, ) - return params @staticmethod - def compress_by_threshold(val, threshold, minimum_sparsity_percentile): + def compress_by_threshold( + val, threshold, minimum_sparsity_percentile + ) -> Optional[optimize_utils.SparseParamsIos16]: val = np.where(np.abs(val) <= threshold, 0, val) sparsity_percentile = np.sum(val == 0.0) / val.size if sparsity_percentile < minimum_sparsity_percentile: @@ -248,10 +319,12 @@ def compress_by_threshold(val, threshold, minimum_sparsity_percentile): ) logger.warning(msg) return None - return prune_weights._pack_val_to_sparse_param(val) + return prune_weights._produce_sparse_param(val) @staticmethod - def compress_by_magnitude(val, target_sparsity, block_size=None, dim=None): + def compress_by_magnitude( + val, target_sparsity, block_size=None, dim=None + ) -> Optional[optimize_utils.SparseParamsIos16]: def _apply_block_sparsity(val, block_size, dim): shape = val.shape rank = len(shape) @@ -317,10 +390,10 @@ def _apply_block_sparsity(val, block_size, dim): val = 0 * val elif q != 0: val = np.where(magnitude_map <= np.percentile(magnitude_map, q), 0, val) - return prune_weights._pack_val_to_sparse_param(val) + return prune_weights._produce_sparse_param(val) @staticmethod - def compress_by_nm_sparsity(val, n_m_ratio, dim): + def compress_by_nm_sparsity(val, n_m_ratio, dim) -> Optional[optimize_utils.SparseParamsIos16]: n, m = n_m_ratio assert n <= m shape = val.shape @@ -373,20 +446,59 @@ def compress_by_nm_sparsity(val, n_m_ratio, dim): n_m_mask = np.transpose(n_m_mask, axes=perm_back) val = val * (1 - n_m_mask) - return prune_weights._pack_val_to_sparse_param(val) + return prune_weights._produce_sparse_param(val) @staticmethod - def decompress(params): - if not isinstance(params, _utils.SparseParams): + def decompress( + params: Union[optimize_utils.SparseParamsIos16, optimize_utils.SparseParams] + ) -> np.ndarray: + if isinstance(params, optimize_utils.SparseParamsIos16): + return constexpr_sparse_to_dense_ios16.decompress( + params.nonzero_data, params.mask, params.shape + ) + elif isinstance(params, optimize_utils.SparseParams): + return constexpr_sparse_to_dense.decompress(params.nonzero_data, params.mask) + else: raise ValueError("Invalid type of params") - return constexpr_sparse_to_dense.decompress(params.nonzero_data, params.mask, params.shape) @staticmethod - def _create_constexpr_var(op: Operation, sparse_params: _utils.SparseParams) -> Var: + def _create_constexpr_var( + op: Operation, sparse_params: optimize_utils.SparseParams, joint_compression: bool = False + ) -> Var: + if not is_current_opset_version_compatible_with(AvailableTarget.iOS18): + sparse_params_ios16 = optimize_utils.ios18_sparse_params_to_ios16(sparse_params) + return mb.constexpr_sparse_to_dense( + nonzero_data=sparse_params_ios16.nonzero_data, + mask=sparse_params_ios16.mask, + shape=np.uint32(sparse_params_ios16.shape), + before_op=op, + name=op.name + "_sparsified", + ) + + mask = sparse_params.mask + nonzero_data = sparse_params.nonzero_data + + if joint_compression: + if op.op_type == "constexpr_blockwise_shift_scale": + mask, nonzero_data = mb.constexpr_sparse_blockwise_shift_scale( + data_mask=mask, + nonzero_data=op.data.val[mask != 0].flatten(), + scale=op.scale, + offset=op.offset, + before_op=op, + ) + elif op.op_type == "constexpr_lut_to_dense": + mask, nonzero_data = mb.constexpr_lut_to_sparse( + indices_mask=mask, + indices_nonzero_data=op.indices.val[mask != 0].flatten(), + lut=op.lut, + vector_axis=op.vector_axis, + before_op=op, + ) + return mb.constexpr_sparse_to_dense( - nonzero_data=sparse_params.nonzero_data, - mask=sparse_params.mask, - shape=np.uint32(sparse_params.shape), + nonzero_data=nonzero_data, + mask=mask, before_op=op, name=op.name + "_sparsified", ) @@ -402,6 +514,8 @@ def transform_op(self, op: Operation): if not isinstance(const_val, (np.ndarray, np.generic)): raise ValueError("Only numpy arrays are supported") + sparse_params: Optional[optimize_utils.SparseParamsIos16] = None + skip_msg = f"op named {op.name} not applicable for {op_config} configuration. Skipped." if isinstance(op_config, OpThresholdPrunerConfig): sparse_params = self.compress_by_threshold( val=const_val, @@ -415,7 +529,7 @@ def transform_op(self, op: Operation): # if it is explicitly set by set_op_name, if not op_config._check_const_op_is_valid(op): if op.name not in self.config.op_name_configs: - logger.warning(f"op named {op.name} not applicable for {OpMagnitudePrunerConfig} configuration. Skipped.") + logger.warning(skip_msg) return if op_config.target_sparsity is not None: @@ -433,10 +547,19 @@ def transform_op(self, op: Operation): ) if sparse_params is None: + logger.warning(skip_msg) return + sparse_params: optimize_utils.SparseParams = optimize_utils.ios16_sparse_params_to_ios18( + sparse_params + ) + if not self.fake_compression: - new_var = self._create_constexpr_var(op, sparse_params) + new_var = self._create_constexpr_var( + op, + sparse_params, + joint_compression=self.joint_compression and op.op_type in self._JOINT_SUPPORT_OPS, + ) else: decompressed_val = self.decompress(sparse_params) new_var = mb.const( @@ -450,10 +573,12 @@ def transform_op(self, op: Operation): old_var=op.outputs[0], new_var=new_var, no_check_var_types=True, + force_replace=True, # Need force_replace to replace the constexpr. ) op.enclosing_block.remove_ops([op]) + @register_pass(namespace="compression") class palettize_weights(AbstractCompressionPass): """ @@ -468,9 +593,72 @@ class palettize_weights(AbstractCompressionPass): - If ``fake_compression=False``, compressed value is encoded using the ``constexpr_lut_to_dense`` op. - If ``fake_compression=True``, compressed value is decompressed and then encoded using the ``const`` op. - Old ``const`` op is replaced by a newly created operation. + + Here is an example for input and output graph of this graph pass: + + .. code-block:: + + Input graph: + + const -> downstream op + + Output graph: + + constexpr_lut_to_dense -> downstream op + + + Support Options: + + - ``joint_compression``: + Enable joint compression by quantizing an already compressed model. + What op could be further quantized is in `_validate_child_constexpr_for_compress`. + + Using pruning + palettization as an example, for each existing ``constexpr_sparse_to_dense`` + op, it tries to palettize the non-sparse elements in the spasified data, which could be + represented as: + + + - For each existing ``constexpr_sparse_to_dense`` op, it tries to palettize the + non-sparse elements in the spasified data, which could be represented as: + + + .. code-block:: + + Input graph: + + sparse weight(fp16) -> constexpr_sparse_to_dense -> dense weight(fp16) + + Output graph: + + sparse lut(int8) -> constexpr_lut_to_sparse -> sparse weight(fp16) -> constexpr_sparse_to_dense -> dense weight(fp16) + + For details about different palettization schemas, see `OpPalettizerConfig` for more details. """ _SUPPORTED_CONFIG_TYPE = OpPalettizerConfig - _SUPPORTED_NBITS = (1, 2, 4, 6, 8) + _SUPPORTED_NBITS = (1, 2, 3, 4, 6, 8) + + _compress_pool: Optional[Pool] = None + + def __del__(self): + if palettize_weights._compress_pool is not None: + palettize_weights._compress_pool.close() + + def _validate_child_constexpr_for_compress(self, op: Operation) -> bool: + """Determines which pattern supports joint compression.""" + if ( + is_current_opset_version_compatible_with(AvailableTarget.iOS18) + and self.joint_compression + ): + # In iOS18 joint compression, the sparsified data could be further palettized. + if len(op.outputs[0].child_ops) == 1: + child_op = op.outputs[0].child_ops[0] + if ( + child_op.op_type == "constexpr_sparse_to_dense" + and child_op.nonzero_data == op.outputs[0] + ): + return True + + return super()._validate_child_constexpr_for_compress(op) @staticmethod def _get_nbits_for_unique_mode(val: np.ndarray, allowed_nbits: Tuple[int, ...]) -> int: @@ -552,7 +740,18 @@ def compress_unique(val, nbits): return lut, indices @staticmethod - def compress(val, mode, nbits=None, lut_function=None) -> _utils.LutParams: + @deprecated( + suffix="Please use coremltools.optimize.coreml.palettize_weights.blockwise_compress", + version="8.2", + obj_prefix="coremltools.optimize.coreml.palettize_weights.", + ) + def compress(val, mode, nbits=None, lut_function=None) -> optimize_utils.LutParamsIos16: + """ + [Legacy] Per-tensor palletization. + + This API is for backward compatibility only. It's no longer used inside the coremltools. + It's recommended to use `blockwise_compress` instead, which is more general. + """ def check_lut_parameters_are_valid(val, lut, indices): if not isinstance(lut, np.ndarray) or not isinstance(indices, np.ndarray): raise ValueError("LUT and indices must be type of numpy array.") @@ -582,25 +781,197 @@ def check_lut_parameters_are_valid(val, lut, indices): check_lut_parameters_are_valid(val, lut, indices) - params = _utils.LutParams( + params = optimize_utils.LutParamsIos16( lut=lut, - indices=pack_elements_into_bits(indices, int(np.log2(lut.shape[0]))), + indices=optimize_utils.pack_elements_into_bits(indices, int(np.log2(lut.shape[0]))), shape=val.shape, ) return params @staticmethod - def decompress(params): - if not isinstance(params, _utils.LutParams): + def blockwise_compress( + original_data: np.ndarray, + mode: str, + nbits: Optional[int], + block_sizes: List[int], + lut_function: Optional[Callable] = None, + num_kmeans_workers: int = 1, + ) -> Optional[optimize_utils.LutParams]: + """ + Compress original_data into n-bit representation by palettization. + + Supported nbits: 1, 2, 3, 4, 6, 8 + Supported mode: KMEANS, UNIFORM, UNIQUE, CUSTOM + + block_sizes: Each element is the block size on corresponding axis for original_data. + + Returns None if the weight cannot be compressed (for example, the dim size on an axis is not + divisible by the corresponding block_size). + """ + # TODO (rdar://127342739): Support more general blockwise palettization. + # As general blockwise palettization hasn't been supported yet, we try to infer channel axis + # and channel group size from block_sizes, and use grouped channelwise palettization instead. + channel_axis = None + channel_group_size = 0 + for axis, block_size in enumerate(block_sizes): + if block_size != 0 and block_size != original_data.shape[axis]: + if channel_axis is not None: + raise NotImplementedError( + "General block-wise palettization is not supported. Please use " + "'per_grouped_channel' or 'per_tensor' for the 'granularity' in config." + ) + channel_axis = axis + channel_group_size = block_size + if channel_axis is None: + # Per-tensor compression, just need to pick a dummy axis. + channel_axis = 0 + + return palettize_weights.grouped_channelwise_compress( + original_data, + mode, + nbits, + channel_axis, + channel_group_size, + lut_function, + num_kmeans_workers, + ) + + @staticmethod + def grouped_channelwise_compress( + original_data: np.ndarray, + mode: str, + nbits: Optional[int], + channel_axis: int, + channel_group_size: int, + lut_function: Optional[Callable] = None, + num_kmeans_workers: int = 1, + ) -> Optional[optimize_utils.LutParams]: + """ + Compress original_data into n-bit representation by grouped channelwise palettization. + + Supported nbits: 1, 2, 3, 4, 6, 8 + Supported mode: KMEANS, UNIFORM, UNIQUE, CUSTOM + + block_sizes: Each element is the block size on corresponding axis for original_data. + + Returns None if the weight cannot be compressed (for example, the dim size on an axis is not + divisible by the corresponding channel_group_size). + """ + if not isinstance(original_data, np.ndarray): + raise ValueError(f"Only numpy arrays are supported, but got {type(original_data)}") + if nbits is not None and nbits not in palettize_weights._SUPPORTED_NBITS: + raise ValueError( + f"Invalid nbits. Support {palettize_weights._SUPPORTED_NBITS}, but got {nbits}" + ) + data_rank = len(original_data.shape) + if not (-data_rank <= channel_axis < data_rank): + raise ValueError( + "Invalid channel_axis. Should be in range " + f"[{-data_rank}, {data_rank}), but got {channel_axis}" + ) + + if channel_axis < 0: + channel_axis += len(original_data.shape) + + channel_num = original_data.shape[channel_axis] + if channel_group_size == 0: + channel_group_size = channel_num + if channel_num % channel_group_size != 0: + logger.warning( + f"Can't perform palettization: The number of channels at {channel_axis}th axis " + f"({channel_num}) is not divisible by channel_group_size ({channel_group_size})." + ) + return None + channel_group_num = channel_num // channel_group_size + + if channel_axis != 0: + original_data = np.swapaxes(original_data, 0, channel_axis) + grouped_channel_data = np.split(original_data, channel_group_num, axis=0) + + # If mode is UNIQUE, infer nbits from the number of unique values in each group. + if mode.upper() == "UNIQUE": + try: + for per_group_data in grouped_channel_data: + per_group_nbits = palettize_weights._get_nbits_for_unique_mode( + per_group_data, palettize_weights._SUPPORTED_NBITS + ) + # Pick the largest per-channel nbits to be used as the nbits for the whole op. + if nbits is None or per_group_nbits > nbits: + nbits = per_group_nbits + except ValueError as e: + logger.warning(f"Can't perform palettization:{e}") + return None + + # The subprocesses have overhead, so only use it for expensive computations (k-means). + if mode.upper() == "KMEANS" and num_kmeans_workers > 1: + if palettize_weights._compress_pool is None: + palettize_weights._compress_pool = Pool(processes=num_kmeans_workers) + atexit.register(lambda: palettize_weights._compress_pool.terminate()) + lut, indices = zip( + *palettize_weights._compress_pool.starmap( + palettize_weights._get_lut_and_indices, + zip(grouped_channel_data, repeat(mode), repeat(nbits), repeat(lut_function)), + ) + ) + else: + lut, indices = zip( + *[ + palettize_weights._get_lut_and_indices( + per_channel_group_data, mode, nbits, lut_function + ) + for per_channel_group_data in grouped_channel_data + ] + ) + + lut = np.stack(lut, axis=0) + indices = np.stack(indices, axis=0) + + if mode.upper() == "CUSTOM": + # The custom lut_function provided by users should have nbits info. + nbits = int(np.ceil(np.log2(lut.shape[-1]))) + + # The lut and indices from `_get_lut_and_indices` is flattened. The desired result should be + # `lut` with shape [channel_group_num, palette_num], and `indices` with same shape as the + # original_data. + palette_num = 2**nbits + indices = indices.reshape(original_data.shape) + lut_target_shape = [1] * (len(original_data.shape) + 2) + lut_target_shape[0] = channel_group_num + lut_target_shape[-2] = palette_num + lut = lut.reshape(lut_target_shape) + + if channel_axis != 0: + lut = np.swapaxes(lut, 0, channel_axis) + indices = np.swapaxes(indices, 0, channel_axis) + + indices_np_dtype = types.nptype_from_builtin(types.string_to_builtin(f"uint{nbits}")) + return optimize_utils.LutParams(indices.astype(indices_np_dtype), lut) + + @staticmethod + def decompress(params: Union[optimize_utils.LutParamsIos16, optimize_utils.LutParams]): + if isinstance(params, optimize_utils.LutParamsIos16): + return constexpr_lut_to_dense_ios16.decompress(params.lut, params.indices, params.shape) + elif isinstance(params, optimize_utils.LutParams): + return constexpr_lut_to_dense.decompress(params.indices, params.lut, None) + else: raise ValueError("Invalid type of params") - return constexpr_lut_to_dense.decompress(params.lut, params.indices, params.shape) @staticmethod - def _create_constexpr_var(op: Operation, lut_params: _utils.LutParams) -> Var: + def _create_constexpr_var(op: Operation, lut_params: optimize_utils.LutParams) -> Var: + """Create constexpr lut op based on opset version.""" + if not is_current_opset_version_compatible_with(AvailableTarget.iOS18): + lut_params_ios16 = optimize_utils.ios18_lut_params_to_ios16(lut_params) + return mb.constexpr_lut_to_dense( + indices=lut_params_ios16.indices, + lut=lut_params_ios16.lut, + shape=np.uint32(lut_params_ios16.shape), + before_op=op, + name=op.name + "_palettized", + ) + return mb.constexpr_lut_to_dense( indices=lut_params.indices, lut=lut_params.lut, - shape=np.uint32(lut_params.shape), before_op=op, name=op.name + "_palettized", ) @@ -612,24 +983,49 @@ def transform_op(self, op: Operation): if not self.need_compress_const(op, self.config._is_deprecated, op_config.weight_threshold): return - if op_config.mode == "UNIQUE": - try: - palettize_weights._get_nbits_for_unique_mode( - op.outputs[0].val, self._SUPPORTED_NBITS + weight_to_compress = op.outputs[0].val + if self.joint_compression: + child_op = op.outputs[0].child_ops[0] + if child_op.op_type == "constexpr_sparse_to_dense": + # When the child op is sparse_to_dense op, the weight_to_compress is the sparse + # representation, which need to be restored to dense representation for compression. + weight_to_compress = constexpr_sparse_to_dense.decompress( + weight_to_compress, child_op.mask.val ) - except ValueError as e: - logger.warning(f"Skip op {op.name} for palettization, because {e}") - return - lut_params = self.compress( - op.outputs[0].val, + block_sizes = optimize_utils.infer_block_sizes(op, op_config, weight_to_compress) + lut_params = self.blockwise_compress( + weight_to_compress, op_config.mode, op_config.nbits, - op_config.lut_function + block_sizes, + op_config.lut_function, + num_kmeans_workers=op_config.num_kmeans_workers, ) + if lut_params is None: + logger.warning(f"Cannot perform palettization on {op.name}. Skipped this op.") + return if not self.fake_compression: - new_var = palettize_weights._create_constexpr_var(op, lut_params) + new_var: Optional[Var] = None + + # Specially handle sparse-related compression ops chaining. + if self.joint_compression: + child_op = op.outputs[0].child_ops[0] + if child_op.op_type == "constexpr_sparse_to_dense": + mask, nonzero_data = mb.constexpr_lut_to_sparse( + indices_mask=child_op.mask, + indices_nonzero_data=lut_params.indices[child_op.mask.val != 0].flatten(), + lut=lut_params.lut, + before_op=child_op, + name=op.name + "_palettized", + ) + # Feed the sparse lut's nonzero_data output to the child sparse op. + new_var = nonzero_data + + # For other cases, the new quant var could be constructed directly from lut_params. + if new_var is None: + new_var = self._create_constexpr_var(op, lut_params) else: decompressed_val = self.decompress(lut_params) new_var = mb.const( @@ -658,9 +1054,44 @@ class linear_quantize_weights(AbstractCompressionPass): The transform performs the following: - - Values are linearly quantized into unsigned 8-bits. - - If ``fake_compression=False``, compressed value is encoded using the ``constexpr_affine_dequantize`` op. + - Values are linearly quantized into n-bit. + - If ``fake_compression=False``, compressed value is encoded using the + ``constexpr_affine_dequantize`` op (pre-iOS18) or the ``constexpr_blockwise_shift_scale`` op (iOS18). - If ``fake_compression=True``, compressed value is decompressed and then encoded using the ``const`` op. + + Here is an example for input and output graph of this graph pass: + + .. code-block:: + + Input graph: + + const -> downstream op + + Output graph: + + constexpr_blockwise_shift_scale -> downstream op + + Support Options: + + - ``joint_compression``: + + Enable joint compression by quantizing an already compressed model. + What op could be further quantized is in `_validate_child_constexpr_for_compress`. + + Using palettization + quantization as an example, for each existing ``constexpr_lut_to_dense`` + op, it tries to quantize the elements in the lut, which could be represented as: + + .. code-block:: + + Input graph: + + lut(fp16) -> constexpr_lut_to_dense -> dense(fp16) -> downstream op + + Output graph: + + lut(int8) -> constexpr_blockwise_shift_scale -> lut(fp16) -> constexpr_lut_to_dense -> dense(fp16) -> downstream op + + For details about different quantization schemas, see `OpLinearQuantizerConfig` for more details. """ _SUPPORTED_CONFIG_TYPE = OpLinearQuantizerConfig _MODE_DTYPE_TO_RANGE = { @@ -670,60 +1101,51 @@ class linear_quantize_weights(AbstractCompressionPass): (types.uint8, "LINEAR_SYMMETRIC"): (0, 254), } - @classmethod - @_deprecated( - suffix="Please use _utils.quantize_weight", - version="8.0", - obj_prefix="coremltools.optimize.coreml._quantization_passes.", - ) - def _get_quantized_data( - cls, original_data: np.ndarray, axes: Tuple[int, ...], mode: str, dtype: type - ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]: - """[Deprecated] Get quantized data along with metadata (scale, zero_point).""" - if not np.issubdtype(original_data.dtype, np.floating): - raise ValueError("Only floating numpy arrays are supported.") - - val_min = np.amin(original_data, axis=axes, keepdims=True) - val_max = np.amax(original_data, axis=axes, keepdims=True) - - if mode == "LINEAR_SYMMETRIC": - # For the linear_symmetric mode, the range is symmetrical to 0 - max_abs = np.maximum(np.abs(val_min), np.abs(val_max)) - val_min = -max_abs - val_max = max_abs - else: - assert mode == "LINEAR" - # For the linear mode, we need to make sure the data range contains `0` - val_min = np.minimum(0.0, val_min) - val_max = np.maximum(0.0, val_max) - - q_val_min, q_val_max = cls._MODE_DTYPE_TO_RANGE[(dtype, mode)] - np_dtype = nptype_from_builtin(dtype) - zero_point = None - if mode == "LINEAR_SYMMETRIC": - if dtype.is_unsigned(): - zero_point_shift = q_val_max // 2 - zero_point = zero_point_shift * np.ones(val_min.shape) - else: - assert mode == "LINEAR" - zero_point = (q_val_min * val_max - q_val_max * val_min) / (val_max - val_min) - zero_point = np.round(zero_point) - zero_point = np.clip(zero_point, q_val_min, q_val_max) - - scale = (val_max - val_min) / (q_val_max - q_val_min) - quantized_data = np.round(original_data / scale) - if zero_point is not None: - quantized_data += zero_point - zero_point = zero_point.squeeze().astype(np_dtype) - quantized_data = np.clip(quantized_data, q_val_min, q_val_max).astype(np_dtype) - scale = scale.astype(original_data.dtype).squeeze() + def _validate_child_constexpr_for_compress(self, op: Operation) -> bool: + """ + Overrides external method to support joint compression for iOS18+. - return quantized_data, scale, zero_point + In iOS18 joint compression, the palettized/sparsified data could be further quantized. + For each specific op, we only quantize the specific input: + - constexpr_lut_to_dense's lut + - constexpr_lut_to_sparse's lut + - constexpr_sparse_to_dense's nonzero_data + """ + if ( + is_current_opset_version_compatible_with(AvailableTarget.iOS18) + and self.joint_compression + ): + if len(op.outputs[0].child_ops) == 1: + child_op = op.outputs[0].child_ops[0] + if child_op.op_type == "constexpr_lut_to_dense" and child_op.lut == op.outputs[0]: + return True + elif ( + child_op.op_type == "constexpr_lut_to_sparse" and child_op.lut == op.outputs[0] + ): + return True + elif ( + child_op.op_type == "constexpr_sparse_to_dense" + and child_op.nonzero_data == op.outputs[0] + ): + return True + + return super()._validate_child_constexpr_for_compress(op) @classmethod + @deprecated( + suffix="Please use coremltools.optimize.coreml.linear_quantize_weights.blockwise_compress", + version="8.2", + obj_prefix="coremltools.optimize.coreml.linear_quantize_weights.", + ) def compress( cls, val: np.ndarray, axis: int, mode: str, dtype: type - ) -> _utils.AffineQuantParams: + ) -> optimize_utils.QuantParamsIos16: + """ + [Legacy] Per-channel quantization on axis. + + This API is for backward compatibility only. It's no longer used inside the coremltools. + It's recommended to use `blockwise_compress` instead, which is more general. + """ if not isinstance(val, (np.ndarray, np.generic)): raise ValueError("Only numpy arrays are supported") if isinstance(dtype, np.dtype): @@ -731,56 +1153,164 @@ def compress( if not types.is_builtin(dtype): raise ValueError(f"The input dtype is should be a built-in type, but got {type(dtype)}") - axes = tuple([i for i in range(len(val.shape)) if i != axis]) - quantized_data, scale, zero_point = _utils.quantize_weight( + block_sizes = [0] * len(val.shape) + block_sizes[axis] = 1 + quant_params = cls.blockwise_compress( val, - axes, nbits=dtype.get_bitwidth(), + mode=mode, signed=not dtype.is_unsigned(), - quantization_mode=mode, - dtype=types.nptype_from_builtin(dtype), + block_sizes=block_sizes, ) + if quant_params is None: + raise ValueError("Failed to quantize.") - if zero_point is None: - # The iOS16 constexpr_affine_dequantize op requires zero_point. - zero_point = np.zeros_like(scale).astype(quantized_data.dtype) - return _utils.AffineQuantParams(quantized_data, zero_point, scale, axis) + return optimize_utils.ios18_quant_params_to_ios16(quant_params) + + @classmethod + def blockwise_compress( + cls, + original_data: np.ndarray, + nbits: int, + mode: str, + signed: bool, + block_sizes: List[int], + ) -> Optional[optimize_utils.QuantParams]: + """ + Compress original_data into n-bit representation by quantization. + + block_sizes: Each element is the block size on corresponding axis for original_data. + + Returns None if the weight cannot be compressed (for example, the dim size on an axis is not + divisible by the corresponding block_size). + """ + if not isinstance(original_data, np.ndarray): + raise ValueError("Only numpy arrays are supported") + + result = optimize_utils.compute_qparams( + original_data, + nbits, + signed, + mode, + types.nptype_from_builtin(types.get_nbits_int_builtin_type(nbits, signed)), + block_sizes, + ) + + if result is None: + return None + + quantized_data, scale, zero_point = result + return optimize_utils.QuantParams( + data=quantized_data, scale=scale, offset=zero_point, nbits=np.uint8(nbits) + ) @staticmethod - def decompress(params: _utils.AffineQuantParams) -> np.ndarray: - if not isinstance(params, _utils.AffineQuantParams): + def decompress(params: Union[optimize_utils.QuantParamsIos16, optimize_utils.QuantParams]): + if isinstance(params, optimize_utils.QuantParamsIos16): + return constexpr_affine_dequantize.decompress( + params.quantized_data, params.zero_point, params.scale, params.axis + ) + elif isinstance(params, optimize_utils.QuantParams): + return constexpr_blockwise_shift_scale.decompress( + params.data, + params.scale, + params.offset, + ) + else: raise ValueError("Invalid type of params") - return constexpr_affine_dequantize.decompress( - params.quantized_data, params.zero_point, params.scale, params.axis + + @staticmethod + def _create_constexpr_var(op: Operation, quant_params: optimize_utils.QuantParams) -> Var: + """Create constexpr quant op based on opset version.""" + if not is_current_opset_version_compatible_with(AvailableTarget.iOS18): + quant_params_ios16 = optimize_utils.ios18_quant_params_to_ios16(quant_params) + return mb.constexpr_affine_dequantize( + quantized_data=quant_params_ios16.quantized_data, + zero_point=quant_params_ios16.zero_point, + scale=quant_params_ios16.scale, + axis=quant_params_ios16.axis, + before_op=op, + name=op.name + "_quantized", + ) + + return mb.constexpr_blockwise_shift_scale( + data=quant_params.data, + scale=quant_params.scale, + offset=quant_params.offset, + before_op=op, + name=op.name + "_quantized", ) def transform_op(self, op: Operation): - op_config = self.config._get_const_op_config(op) + op_config: Optional[OpLinearQuantizerConfig] = self.config._get_const_op_config(op) if op_config is None: return if not self.need_compress_const(op, self.config._is_deprecated, op_config.weight_threshold): return - output_channel = self.select_input_output_channel_axis(op)[1] - quant_params = self.compress( - op.outputs[0].val, output_channel, op_config.mode, op_config.dtype + weight_to_compress = op.outputs[0].val + if self.joint_compression: + child_op = op.outputs[0].child_ops[0] + if child_op.op_type == "constexpr_sparse_to_dense": + # When the child op is sparse_to_dense op, the weight_to_compress is the sparse + # representation, which need to be restored to dense representation for compression. + weight_to_compress = constexpr_sparse_to_dense.decompress( + weight_to_compress, child_op.mask.val + ) + elif child_op.op_type.startswith("constexpr_lut_to_"): + if not op_config.granularity == CompressionGranularity.PER_TENSOR: + raise NotImplementedError( + "When use joint compression for palettization-quantization, please make " + "sure to use per-tensor quantization, because the axis for the data to be" + "quantized (palettization's lut) is different from the original weight." + ) + + block_sizes = optimize_utils.infer_block_sizes(op, op_config, weight_to_compress) + quant_params = self.blockwise_compress( + weight_to_compress, + op_config.nbits, + op_config.mode, + op_config.signed, + block_sizes, ) + if quant_params is None: + logger.warning(f"Cannot perform quantization on {op.name}. Skipped this op.") + return + if not self.fake_compression: - new_var = frontend_utils._construct_constexpr_affine_op( - quant_params.quantized_data, - quant_params.zero_point, - quant_params.scale, - quant_params.axis, - name=op.name + "_affine_quantized", - before_op=op, - ) + new_var: Optional[Var] = None + + # Specially handle sparse-related compression ops chaining. + if self.joint_compression: + child_op = op.outputs[0].child_ops[0] + if child_op.op_type == "constexpr_sparse_to_dense": + mask, nonzero_data = mb.constexpr_sparse_blockwise_shift_scale( + data_mask=child_op.mask, + nonzero_data=quant_params.data[child_op.mask.val != 0].flatten(), + scale=quant_params.scale, + offset=quant_params.offset, + before_op=child_op, + name=op.name + "_quantized", + ) + # Feed the sparse quantization op's nonzero_data output to the child sparse op. + new_var = nonzero_data + + elif child_op.op_type == "constexpr_lut_to_sparse": + # Here we only quantize the lut itself, which is a dense data, so we cannot use + # the sparse version of the quant op; instead we just use the dense version of + # the quant op. Will change if backends don't support it. + pass + + # For other cases, the new quant var could be constructed directly from quant_params. + if new_var is None: + new_var = self._create_constexpr_var(op, quant_params) else: decompressed_val = self.decompress(quant_params) new_var = mb.const( val=decompressed_val, before_op=op, - name=op.name + "_fake_affine_quantized", + name=op.name + "_fake_quantized", ) op.enclosing_block.replace_uses_of_var_after_op( @@ -792,6 +1322,7 @@ def transform_op(self, op: Operation): op.enclosing_block.remove_ops([op]) + @register_pass(namespace="compression") class WeightDecompressor(AbstractQuantizationPass): """ diff --git a/coremltools/optimize/coreml/_utils.py b/coremltools/optimize/coreml/_utils.py index 347ee343e..5a39a7767 100644 --- a/coremltools/optimize/coreml/_utils.py +++ b/coremltools/optimize/coreml/_utils.py @@ -3,14 +3,30 @@ # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause +import math from collections import namedtuple -from typing import Optional, Tuple +from typing import List, Optional, Tuple, Union import numpy as np -SparseParams = namedtuple("SparseParams", "nonzero_data mask shape") -LutParams = namedtuple("LutParams", "lut indices shape") -AffineQuantParams = namedtuple("AffineQuantParams", "quantized_data zero_point scale axis") +from coremltools import _getLogger +from coremltools.converters.mil.mil import Operation, types +from coremltools.optimize.coreml import _utils as optimize_utils +from coremltools.optimize.coreml._config import ( + CompressionGranularity, + OpLinearQuantizerConfig, + OpPalettizerConfig, +) + +_logger = _getLogger() + +SparseParamsIos16 = namedtuple("SparseParamsIos16", "nonzero_data mask shape") +LutParamsIos16 = namedtuple("LutParamsIos16", "lut indices shape") +QuantParamsIos16 = namedtuple("QuantParamsIos16", "quantized_data zero_point scale axis") + +SparseParams = namedtuple("SparseParams", "nonzero_data mask") +LutParams = namedtuple("LutParams", "indices lut") +QuantParams = namedtuple("QuantParams", "data scale offset nbits") def get_quant_range(n_bits: int, signed: bool, mode: str) -> Tuple[int, int]: @@ -77,3 +93,401 @@ def quantize_weight( scale = scale.astype(weight.dtype).squeeze() return quantized_data, scale, zero_point + + +def compute_qparams( + weight: np.ndarray, + nbits: int, + signed: bool, + quantization_mode: str, + dtype: np.dtype, + block_sizes: List[int], +) -> Optional[Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]]: + """ + Compress the given weight matrix by quantizing the weights. + Provide different configurations of quantization by specifying a ``block_sizes`` which + is a list containing the block size for each dimension of the weight or 0 otherwise. + + Note that per-tensor, per-channel, channelwise-grouped and per-block are + just variants of specifying the block sizes for each dimension. + """ + if len(block_sizes) != len(weight.shape): + raise AssertionError( + "Each axis should have a block size, which means len(block_sizes) must be " + f"equal to weight's rank, but got {len(block_sizes)} vs {len(weight.shape)}" + ) + + new_shape, scale_shape, axes_to_skip = [], [], [] + for axis, (dim_size, block_size) in enumerate(zip(weight.shape, block_sizes)): + if block_size > 0: + if dim_size % block_size != 0: + _logger.warning( + f"Invalid block_sizes; On {axis}th axis, the dim size {dim_size} is " + f"not divisible by block size {block_size}. Unable to perform " + "structured quantization." + ) + return None + + # Skip this axis while computing min & max + axes_to_skip.append(len(new_shape)) + + # channel dim now will be (num_blocks, block_size) + num_blocks = dim_size // block_size + new_shape.extend([num_blocks, block_size]) + scale_shape.append(num_blocks) + else: + new_shape.append(dim_size) + scale_shape.append(1) + + # Axes to reduce while compute min & max values + axes = tuple(filter(lambda x: x not in axes_to_skip, range(len(new_shape)))) + + quantized_data, scale, zero_point = quantize_weight( + weight.reshape(new_shape), axes, nbits, signed, quantization_mode, dtype + ) + + quantized_data = quantized_data.reshape(weight.shape) + scale = scale.reshape(scale_shape) + if zero_point is not None: + zero_point = zero_point.reshape(scale_shape) + + return quantized_data, scale, zero_point + + +def find_indices_for_lut(data: np.ndarray, lut: np.ndarray) -> np.ndarray: + """ + Given a data and a look-up-table (LUT), find the closest indices in LUT that each element in + data correspond to. It's the reverse process of "Given a LUT and indices, produce data using + indices to fetch elements in LUT". + + Note the elements in data may not exactly match the elements in lut due to numerical instability. + So we use fuzzy match to find the closest one instead of doing exact match. + + Parameters + - data: Arbitrary numpy array. + - lut: [block_num1, ..., 2**nbits, vector_size]. LUT's rank is K + 2, where K is the rank of data. + Each dimension of data should be divisible by each corresponding dimension of the LUT. + e.g., when data's shape is [2, 3, 4], the first three elements in lut's shape is [1, 1, 2], + it means that there are two lookup tables over the last axis, and each of them have their + own LUT values. See details in the iOS18 `constexpr_lut_to_dense` op. + """ + if len(lut.shape) != len(data.shape) + 2: + raise ValueError("The lut's rank should be data's rank + 2. See constexpr_lut_to_dense.") + + # TODO (rdar://124474258): Handle vector palettization. + if lut.shape[-1] > 1: + raise NotImplementedError( + "Not support vector palettization. Progress tracked in rdar://124474258." + ) + + # lut has shape [block_num0, block_num1, ..., 2**nbits, vector_size], so need to interleaved + # repeat it to make each block match the weight. + repeated_lut = lut + for axis, block_num in enumerate(lut.shape[:-2]): + weight_dim_size = data.shape[axis] + if weight_dim_size % block_num != 0: + raise ValueError( + "The weight dim size in each axis must be divisible by the number " + f"of luts. Got invalid lut {lut.shape} for weight shape " + f"{data.shape[axis]} at axis {axis}" + ) + block_size = weight_dim_size // block_num + # Can use np.kron for higher efficiency, but repeat is easier to understand. + if block_size > 1: + repeated_lut = np.repeat(repeated_lut, block_size, axis=axis) + + # Find the closest value for each element. + indices = np.argmin( + np.abs(np.expand_dims(data, axis=-1) - np.squeeze(repeated_lut, axis=-1)), axis=-1 + ) + nbits = int(math.log2(lut.shape[-2])) + indices = indices.astype(types.nptype_from_builtin(types.string_to_builtin(f"uint{nbits}"))) + return indices + + +def infer_block_sizes( + op: "Operation", + op_config: Union[OpLinearQuantizerConfig, OpPalettizerConfig], + weight_to_compress: np.ndarray, +) -> List[int]: + """ + Infer block size on each axis based on the op and compression config. + + For per-channel, the channel axis is auto-picked. + For per-block, the input/output axis is auto-picked if block_size is int. + See the docstring of OpLinearQuantizerConfig for more details. + """ + if op_config.granularity == CompressionGranularity.PER_BLOCK and not isinstance( + op_config.block_size, int + ): + if len(op_config.block_size) != len(weight_to_compress.shape): + raise ValueError( + "The block_size in config must has one element for each axis. However, for op " + f"{op.name}, there are {len(op_config.block_size)} elements in block_size, " + f"but there are {len(weight_to_compress.shape)} axes in the weight." + ) + return list(op_config.block_size) + + input_channel_axis, output_channel_axis = optimize_utils.select_input_output_channel_axis(op) + if ( + op_config.granularity == CompressionGranularity.PER_GROUPED_CHANNEL + and op_config.channel_axis is not None + ): + output_channel_axis = op_config.channel_axis + + block_sizes = [0] * len(weight_to_compress.shape) + if op_config.granularity == CompressionGranularity.PER_TENSOR: + input_channel_block_size = 0 + output_channel_block_size = 0 + elif op_config.granularity == CompressionGranularity.PER_CHANNEL: + input_channel_block_size = 0 + output_channel_block_size = 1 + elif op_config.granularity == CompressionGranularity.PER_GROUPED_CHANNEL: + input_channel_block_size = 0 + output_channel_block_size = op_config.group_size + else: + assert op_config.granularity == CompressionGranularity.PER_BLOCK and isinstance( + op_config.block_size, int + ) + input_channel_block_size = op_config.block_size + output_channel_block_size = 1 + + if input_channel_axis < len(block_sizes): + block_sizes[input_channel_axis] = input_channel_block_size + if output_channel_axis < len(block_sizes): + block_sizes[output_channel_axis] = output_channel_block_size + return block_sizes + + +def select_input_output_channel_axis(op: "Operation") -> Tuple[int, int]: + """ + Here are some representative ops: + - linear: [D_out, D_in] + - matmul's y: [..., D_in, D_out] if transpose_y is False, else [..., D_out, D_in] + - conv: [C_out, C_in_div_group, KH, KW] + - conv_transpose: [C_in, C_out_div_group, KH, KW] + + The input output channel axis selection criteria is: + - For conv_transpose the output channel is 1 and input channel is 0. + - For matmul's y: + - When transpose_y=False, output channel is -1 and input channel is -2 + - When transpose_y=True, output channel is -2 and input channel is -1 + - For matmul's x: + - When transpose_x=False, output channel is -2 and input channel is -1 + - When transpose_y=True, output channel is -1 and input channel is -2 + - For all other ops, output channel is 0 and input channel is 1. + """ + output_channel_axis, input_channel_axis = 0, 1 + var = op.outputs[0] + if len(var.child_ops) == 1: + child_op = var.child_ops[0] + if child_op.op_type == "conv_transpose": + output_channel_axis = 1 + input_channel_axis = 0 + if child_op.op_type == "matmul": + if child_op.y == var: + if child_op.transpose_y.val: + output_channel_axis = -2 + input_channel_axis = -1 + else: + output_channel_axis = -1 + input_channel_axis = -2 + else: # var is used as matmul's x. + if child_op.transpose_x.val: + output_channel_axis = -1 + input_channel_axis = -2 + else: + output_channel_axis = -2 + input_channel_axis = -1 + if child_op.op_type.startswith("constexpr_"): + return select_input_output_channel_axis(child_op) + return input_channel_axis, output_channel_axis + + +def ios16_sparse_params_to_ios18(sparse_params: SparseParamsIos16) -> SparseParams: + """ + The iOS18 constexpr_sparse_to_dense no longer accepts `shape` param. Instead, the `mask` param + has shape info. So we need to convert the old bit-packed `mask` to new uint1 `mask`. + """ + if not isinstance(sparse_params, SparseParamsIos16): + raise ValueError("Invalid type of params") + + mask = ( + np.unpackbits(sparse_params.mask, count=np.prod(sparse_params.shape), bitorder="little") + .reshape(sparse_params.shape) + .astype(types.np_uint1_dtype) + ) + + return SparseParams(nonzero_data=sparse_params.nonzero_data, mask=mask) + + +def ios18_sparse_params_to_ios16(sparse_params: SparseParams) -> SparseParamsIos16: + """The iOS16 sparse params pack mask into bytes, and need a `shape` parameter.""" + return SparseParamsIos16( + nonzero_data=sparse_params.nonzero_data, + mask=np.packbits(sparse_params.mask, bitorder="little"), + shape=sparse_params.mask.shape, + ) + + +def ios16_lut_params_to_ios18(lut_params: LutParamsIos16) -> LutParams: + """ + The iOS18 constexpr_lut_to_dense no longer accepts `shape` param. We need to convert the iOS16 + params to the format acceptable by the iOS18 op. + """ + num_palettes = lut_params.lut.shape[0] + nbits = int(math.log2(num_palettes)) + if 2**nbits != num_palettes: + raise AssertionError( + f"Invalid number of palettes in lut_params. It should be 2**nbits, but got {num_palettes}" + ) + # Notice that the indices in iOS16 is packed, so we need to unpack first. + unpacked_indices = restore_elements_from_packed_bits( + lut_params.indices, nbits, np.prod(lut_params.shape) + ) + indices = unpacked_indices.reshape(lut_params.shape).astype( + types.type_mapping.string_to_nptype(f"uint{nbits}") + ) + lut_shape = [1] * len(lut_params.shape) + [num_palettes, 1] + lut = lut_params.lut.reshape(lut_shape) + return LutParams(indices=indices, lut=lut) + + +def ios18_lut_params_to_ios16(lut_params: LutParams) -> LutParamsIos16: + """The iOS16 lut params pack indices into bytes, and need a `shape` parameter.""" + for idx, dim_size in enumerate(lut_params.lut.shape[:-2]): + if dim_size > 1: + raise AssertionError( + "The iOS16 only supports per-tensor lut, but got more than one " + f"lut on {idx}th axis. LUT shape: {lut_params.lut.shape}" + ) + + num_palettes = lut_params.lut.shape[-2] + nbits = int(math.log2(num_palettes)) + return LutParamsIos16( + lut=lut_params.lut.reshape((num_palettes,)), + indices=pack_elements_into_bits(lut_params.indices, nbits), + shape=lut_params.indices.shape, + ) + + +def ios18_quant_params_to_ios16(quant_params: QuantParams) -> QuantParamsIos16: + """ + Transform iOS18 quant params to iOS16 version. + + The iOS16 constexpr_affine_dequantize op requires axis, and it requires scale and zero_point to + have rank 0 or 1. + """ + # Infer the axis based on scale's shape. + non_single_dim = [dim for dim, dim_size in enumerate(quant_params.scale.shape) if dim_size > 1] + if len(non_single_dim) > 2: + raise AssertionError( + "The constexpr_affine_dequantize op doesn't support scale which " + "have more than one non-single dimensions. Got scale with shape " + f"{quant_params.scale.shape}" + ) + # If non_single_dim is empty, it means it's per-tensor quantization, just use a dummy axis. + axis = 0 if len(non_single_dim) == 0 else non_single_dim[0] + + scale = quant_params.scale + zero_point = quant_params.offset + if zero_point is None: + # The constexpr_affine_dequantize op requires zero_point. + zero_point = np.zeros_like(scale).astype(quant_params.data.dtype) + + # The constexpr_affine_dequantize op requires scale and zero_point to have rank 0 or 1. + if isinstance(scale, (np.ndarray, np.generic)): + scale = np.squeeze(scale) + if isinstance(zero_point, (np.ndarray, np.generic)): + zero_point = np.squeeze(zero_point) + + return QuantParamsIos16( + quantized_data=quant_params.data, zero_point=zero_point, scale=scale, axis=np.int32(axis) + ) + + +def pack_elements_into_bits(elements: np.ndarray, nbits: int) -> np.ndarray: + """ + Pack elements into nbits representation, by starting with the least significant bit (LSB) and + moving upward to the most significant bit (MSB). + + Returns packed elements as np.uint8. + """ + if not np.issubdtype(elements.dtype, np.integer): + raise ValueError(f"Only support packing integers elements, but got {elements.dtype}") + + # Adjust allowed value range based on if the input is signed or unsigned. + if np.issubdtype(elements.dtype, np.signedinteger): + max_val = 2 ** (nbits - 1) - 1 + min_val = -max_val - 1 + else: + max_val = 2**nbits - 1 + min_val = 0 + if np.max(elements) > max_val: + raise ValueError( + f"To pack elements into {nbits}-bit, the max value is {max_val}, but got {np.max(elements)}" + ) + if np.min(elements) < min_val: + raise ValueError( + f"To pack elements into {nbits}-bit, the min value is {min_val}, but got {np.min(elements)}" + ) + + # As np.unpackbits only supports uint8, convert to uint8 first. + # Notice that it will not lose information, because the bits are unchanged when converting int8 + # to uint8. For example, the signed int -6 has bit representation '11111010', and when we unpackbits + # we get [0, 1, 0, 1, 1, 1, 1, 1], where only first 4 elements are needed for 4-bit representation. + elements = elements.astype(np.uint8) + bitarray = np.unpackbits(elements.reshape(-1, 1), bitorder="little", axis=-1)[:, :nbits] + return np.packbits(bitarray.flatten(), bitorder="little") + + +def restore_elements_from_packed_bits( + packed_values: np.ndarray, nbits: int, element_num: int, are_packed_values_signed: bool = False +) -> np.ndarray: + """ + Restore elements from packed bits. Requires values that are packed by starting with the + least significant bit (LSB) and moving upward to the most significant bit (MSB), which is the + method used in `pack_elements_into_bits`. + + are_packed_values_signed: Indicates if the packed_values were packed from signed integers. If + True, the n-bit number unpacked from packed_values will be interpreted as signed integers, + and the returned ndarray will have dtype np.int8. Otherwise, np.uint8 will be used. + """ + if len(packed_values.shape) != 1: + raise NotImplementedError( + f"Only support 1-rank packed_values. But got {len(packed_values.shape)}" + ) + + if packed_values.dtype == np.int8: + # As np.unpackbits only supports uint8, need to convert first. + packed_values = packed_values.astype(np.uint8) + elif packed_values.dtype != np.uint8: + raise NotImplementedError( + f"Only support int8 or uint8 packed_values, but got {packed_values.dtype}" + ) + + bitarray = np.unpackbits(packed_values, bitorder="little") + pad_required = bitarray.size % nbits != 0 + if pad_required: + bitarray = np.concatenate([bitarray, np.zeros(nbits - bitarray.size % nbits)]).astype( + bitarray.dtype + ) + if bitarray.size % nbits != 0: + raise ValueError( + f"The length of bitarray ({bitarray.size}) should be divisible by " + f"nbits ({nbits})." + ) + bitarray = bitarray.reshape(-1, nbits)[:element_num, :] + # The np.packbits doesn't work well for signed int if we feed `bitarray` to it directly. + # For example, the original signed int is -6, which is packed as 1010 for 4-bit representation, + # and here `bitarray` is [[0, 1, 0, 1]], where the value will be interpreted as 10 (b'1010') + # by np.packbits. + # To make np.packbits work correctly, we need to repeat the sign bit. For example, 1010 will + # become 11111010, where np.packbits can correctly handle and after converting to int8 it's -6. + if are_packed_values_signed: + # Repeat the sign bit to make uint8 to int8 works. + bitarray = np.repeat(bitarray, [1] * (nbits - 1) + [8 - nbits + 1], axis=1) + restored_elements = np.packbits(bitarray, bitorder="little", axis=-1).reshape(-1) + if are_packed_values_signed: + restored_elements = restored_elements.astype(np.int8) + return restored_elements diff --git a/coremltools/optimize/coreml/experimental/__init__.py b/coremltools/optimize/coreml/experimental/__init__.py new file mode 100644 index 000000000..f8496d4dd --- /dev/null +++ b/coremltools/optimize/coreml/experimental/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +from ._config import OpActivationLinearQuantizerConfig +from ._post_training_quantization import linear_quantize_activations diff --git a/coremltools/optimize/coreml/experimental/_config.py b/coremltools/optimize/coreml/experimental/_config.py new file mode 100644 index 000000000..0a3defd67 --- /dev/null +++ b/coremltools/optimize/coreml/experimental/_config.py @@ -0,0 +1,97 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +from __future__ import annotations + +from typing import Any, Dict, Optional, Union + +import cattrs +import numpy as np +from attrs import define, field, validators + +from coremltools.converters.mil.mil import types +from coremltools.converters.mil.mil.types.type_mapping import is_builtin, numpy_type_to_builtin_type + +from .._config import OpCompressorConfig, _check_weight_threshold, _normalize_dtype + +""" +Activation Linear Quantization configuration +""" + +# TODO: This should be refactored to reuse OpLinearQuantizerConfig (rdar://129257210). +@define +class OpActivationLinearQuantizerConfig(OpCompressorConfig): + """ + Parameters + ---------- + mode: str + Mode for linear quantization: + + * ``"linear_symmetric"`` (default): Input data are quantized in the range + ``[-R, R]``, where :math:`R = max(abs(w_r))`. + + dtype: str or np.generic or mil.type + Determines the quantized data type. + + * The allowed values are: + * ``np.int8`` (the default) + * ``coremltools.converters.mil.mil.types.int8`` + + weight_threshold: int + If the operation has weight, above which activation are compressed. + + Set the same ``weight_threshold`` for activation as for weight linear quantization can guarantee + valid operations get both weight and activation quantization to improve efficiency. + * If not provided, it will be set to ``2048``, in which operations with weights bigger than ``2048`` + elements are compressed. + """ + + # TODO: enable more modes/dtypes (rdar://129257210). + mode: str = field(default="linear_symmetric", validator=validators.instance_of(str)) + dtype: Union[str, type] = field(default=types.int8, converter=_normalize_dtype) + + # Set the same ``weight_threshold`` for activation linear quantization as for weight linear quantization can guarantee + # valid operations get both the weight (if weight exists) and activation linear quantized to improve efficiency. + weight_threshold: Optional[int] = field( + default=2048, + validator=validators.optional([validators.instance_of(int), _check_weight_threshold]), + ) + + _ACTIVATION_AFFINE_QUANTIZATION_MODES = ("LINEAR_SYMMETRIC",) + + @mode.validator + def check_mode(self, attr, mode): + if not mode.upper() in self._ACTIVATION_AFFINE_QUANTIZATION_MODES: + raise ValueError( + f'Only mode {self._ACTIVATION_AFFINE_QUANTIZATION_MODES} supported for activation affine quantization. Got mode: "{mode}".' + ) + + @dtype.validator + def check_dtype(self, attr, dtype): + if not types.is_builtin(dtype): + raise ValueError(f"Invalid dtype. Should be builtin dtype, but got {type(dtype)}") + if not (types.is_int(dtype) and dtype.get_bitwidth() in {8} and not dtype.is_unsigned()): + raise ValueError( + f"Invalid dtype. Should be int8, but got {types.builtin_to_string(dtype)}" + ) + + def __attrs_post_init__(self): + self.mode = self.mode.upper() + if not is_builtin(self.dtype): + self.dtype = numpy_type_to_builtin_type(self.dtype) + + @classmethod + def _from_dict(cls, config_dict: Dict[str, Any]) -> "OpActivationLinearQuantizerConfig": + def _structure_type(value, dtype): + if isinstance(value, type): + return value + else: + if not isinstance(value, str) or value not in ("int8",): + raise ValueError(f'"dtype" must be type of type or str ["int8"]. Got {value}') + return getattr(np, value) + + converter = cattrs.Converter(forbid_extra_keys=True) + converter.register_structure_hook(type, _structure_type) + return converter.structure(config_dict, cls) diff --git a/coremltools/optimize/coreml/experimental/_model_debugger.py b/coremltools/optimize/coreml/experimental/_model_debugger.py new file mode 100644 index 000000000..80df50e4e --- /dev/null +++ b/coremltools/optimize/coreml/experimental/_model_debugger.py @@ -0,0 +1,332 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import numpy as np + +import coremltools as ct + + +class OperationInfo: + def __init__(self, spec): + self.dependants = [] + self.dependencies = [] + outputs = dict([(output.name, output) for output in spec.outputs]) + self.outputs = outputs + self.spec = spec + + +class BlockInfo: + def __init__(self, name, operations, spec): + self.name = name + self.operations = operations + self.spec = spec + + +class FunctionInfo: + def __init__(self, name, blocks, spec): + self.name = name + self.blocks = blocks + self.spec = spec + + +class ProgramInfo: + def __init__(self, functions, spec): + self.functions = functions + self.spec = spec + + +class ModelInfo: + def __init__(self, program_info, spec): + self.program_info = program_info + self.spec = spec + + +class ModelDebugger: + @classmethod + def batch(cls, iterable, n=1): + l = len(iterable) + for index in range(0, l, n): + yield iterable[index : min(index + n, l)] + + @classmethod + def unique(cls, sequence): + seen = set() + return [x for x in sequence if not (x in seen or seen.add(x))] + + @classmethod + def split_list(cls, list): + half = len(list) // 2 + return list[:half], list[half:] + + @classmethod + def get_block_info(cls, block_name, block_spec): + operations = {} + for operation_spec in block_spec.operations: + operation = OperationInfo(operation_spec) + dependencies = [] + + for input_name in operation_spec.inputs: + arguments = operation_spec.inputs[input_name].arguments + input_dependencies = [ + operations.get(argument.name, None) + for argument in arguments + if argument.name is not None + ] + input_dependencies = [ + input_dependency + for input_dependency in input_dependencies + if input_dependency is not None + ] + dependencies.extend(input_dependencies) + + dependencies = cls.unique(dependencies) + for dependency in dependencies: + dependency.dependants.append(operation) + operation.dependencies = dependencies + + output_names = [output.name for output in operation_spec.outputs] + for output_name in output_names: + operations[output_name] = operation + + return BlockInfo(block_name, operations, block_spec) + + @classmethod + def get_function_info(cls, function_name, function_spec): + blocks = {} + for block_name, block_spec in function_spec.block_specializations.items(): + blocks[block_name] = cls.get_block_info(block_name, block_spec) + + return FunctionInfo(function_name, blocks, function_spec) + + @classmethod + def get_program_info(cls, program_spec): + functions = {} + for function_name, function_spec in program_spec.functions.items(): + functions[function_name] = cls.get_function_info(function_name, function_spec) + + return ProgramInfo(functions, program_spec) + + @classmethod + def get_model_info(cls, model): + model_spec = model.get_spec() + return ModelInfo(cls.get_program_info(model_spec.mlProgram), model_spec) + + @classmethod + def populate_outputs(cls, output_names, all_operations, acc): + if len(output_names) == 0: + return + next_output_names = [] + operations = [all_operations.get(output_name, None) for output_name in output_names] + operations = [operation for operation in operations if operation is not None] + acc.extend([output for operation in operations for output in operation.outputs.values()]) + prev_output_names = [ + output_name + for operation in operations + for dependency in operation.dependencies + for output_name in dependency.outputs.keys() + ] + prev_output_names = cls.unique(prev_output_names) + cls.populate_outputs(prev_output_names, all_operations, acc) + + @classmethod + def get_all_outputs(cls, block_info): + acc = [] + output_names = block_info.spec.outputs + cls.populate_outputs(output_names, block_info.operations, acc) + return acc + + @classmethod + def get_any_function(cls, model_info): + program_info = model_info.program_info + function_name = list(program_info.functions.keys())[0] + return program_info.functions[function_name] + + @classmethod + def get_any_block(cls, model_info): + function_info = cls.get_any_function(model_info) + block_specialization_name = list(function_info.blocks.keys())[0] + return function_info.blocks[block_specialization_name] + + @classmethod + def clone_spec(cls, spec): + spec_class = spec.__class__ + new_spec = spec_class() + new_spec.CopyFrom(spec) + return new_spec + + @classmethod + def get_output_feature_type(cls, output_name, operations): + operation = operations[output_name] + data_type = operation.outputs[output_name].type.tensorType.dataType + data_type_to_feature_type = { + ct.proto.MIL_pb2.DataType.FLOAT16: ct.proto.FeatureTypes_pb2.ArrayFeatureType.FLOAT16, + ct.proto.MIL_pb2.DataType.FLOAT64: ct.proto.FeatureTypes_pb2.ArrayFeatureType.DOUBLE, + ct.proto.MIL_pb2.DataType.FLOAT32: ct.proto.FeatureTypes_pb2.ArrayFeatureType.FLOAT32, + ct.proto.MIL_pb2.DataType.INT32: ct.proto.FeatureTypes_pb2.ArrayFeatureType.INT32, + } + return data_type_to_feature_type[data_type] + + def __init__(self, model): + self.weights_dir = model.weights_dir + self.model_info = self.__class__.get_model_info(model) + self.block_info = self.__class__.get_any_block(self.model_info) + + model_outputs = [output for output in self.model_info.spec.description.output] + output_names = set([output.name for output in model_outputs]) + all_outputs = self.__class__.get_all_outputs(self.block_info) + intermediate_outputs = [output for output in all_outputs if output.name not in output_names] + + self.__model_outputs = model_outputs + self.__all_outputs = all_outputs + self.__intermediate_outputs = intermediate_outputs + self.__intermediate_output_names = self.__class__.unique( + [output_spec.name for output_spec in intermediate_outputs] + ) + self.__cached_models = {} + + @property + def output_names(self): + return self.__class__.unique([output.name for output in self.outputs]) + + def get_intermediate_output_names( + self, op_include_fn=(lambda op: not (op.spec.type == "const")) + ): + all_operations = self.block_info.operations + intermediate_output_names = list( + filter( + lambda name: op_include_fn(all_operations[name]), self.__intermediate_output_names + ) + ) + intermediate_output_names.reverse() + + return self.__class__.unique(intermediate_output_names) + + def get_model_with_intermediate_outputs( + self, intermediate_output_names, compute_units=ct.ComputeUnit.ALL + ): + model_key = frozenset(intermediate_output_names) + model = self.__cached_models.get(model_key) + if model is not None: + # Found cached model. + return model + + cloned_spec = self.__class__.clone_spec(self.model_info.spec) + cloned_model_info = ModelInfo( + ModelDebugger.get_program_info(cloned_spec.mlProgram), cloned_spec + ) + cloned_spec.specificationVersion = max(self.model_info.spec.specificationVersion, 7) + cloned_block_info = self.__class__.get_any_block(cloned_model_info) + for output_name in intermediate_output_names: + cloned_block_info.spec.outputs.append(output_name) + cloned_output = ct.proto.Model_pb2.FeatureDescription() + cloned_output.name = output_name + cloned_output.type.multiArrayType.dataType = self.__class__.get_output_feature_type( + output_name, self.block_info.operations + ) + cloned_model_info.spec.description.output.append(cloned_output) + + model = ct.models.MLModel( + cloned_spec, weights_dir=self.weights_dir, compute_units=compute_units + ) + + self.__cached_models[model_key] = model + + return model + + def get_models_with_intermediate_outputs_safely( + self, intermediate_output_names, compute_units=ct.ComputeUnit.ALL + ): + if len(intermediate_output_names) == 0: + return [] + + models = [] + output_names = [intermediate_output_names] + while len(output_names) > 0: + curr_output_names = output_names[0] + del output_names[0] + model = None + try: + # This could fail compilation + model = self.get_model_with_intermediate_outputs(curr_output_names, compute_units) + except ValueError as ex: + print( + f"Failed to create model with intermediate outputs={intermediate_output_names}, error={ex}" + ) + if len(curr_output_names) > 1: + print("Retrying") + # split in two and then retry + xs = self.__class__.split_list(curr_output_names) + output_names.insert(0, xs[1]) + output_names.insert(0, xs[0]) + + if model is not None: + models.append(model) + + return models + + # Clears all cached models + def clear_cached_models(self): + self.__cached_models.clear() + + # The function will get called for each intermediate output, return `False` if you want to stop the enumeration otherwise `True`. + def check_intermediate_output(output_value, output_name, operation, activation_stats_dict): + tensor_min = np.min(output_value.flatten()) + tensor_max = np.max(output_value.flatten()) + activation_stats_dict[output_name]["rmin"] = tensor_min + activation_stats_dict[output_name]["rmax"] = tensor_max + if output_name in activation_stats_dict: + activation_stats_dict[output_name]["rmin"] = min( + tensor_min, activation_stats_dict[output_name]["rmin"] + ) + activation_stats_dict[output_name]["rmax"] = max( + tensor_max, activation_stats_dict[output_name]["rmax"] + ) + else: + activation_stats_dict[output_name]["rmin"] = tensor_min + activation_stats_dict[output_name]["rmax"] = tensor_max + return True + + def step( + self, + step_fn, + inputs, + activation_stats_dict, + intermediate_output_names=None, + compute_units=ct.ComputeUnit.CPU_ONLY, + batch_size=500, + ): + if intermediate_output_names is None: + intermediate_output_names = self.get_intermediate_output_names() + + model_output_names = [output.name for output in self.__model_outputs] + model_outputs = None + + batch_size = len(intermediate_output_names) + for output_names in self.__class__.batch(intermediate_output_names, batch_size): + models = self.get_models_with_intermediate_outputs_safely(output_names, compute_units) + for model in models: + outputs = model.predict(inputs) + # cache model outputs + if model_outputs is None: + model_outputs = { + key: value for key, value in outputs.items() if key in model_output_names + } + # remove model outputs + outputs = { + key: value for key, value in outputs.items() if key not in model_output_names + } + output_names = list(outputs.keys()) + for output_name in output_names: + output_value = outputs[output_name] + del outputs[output_name] + operation = self.block_info.operations.get(output_name, None) + if not step_fn(output_value, output_name, operation, activation_stats_dict): + return + outputs = {} + + for (output_name, output_value) in model_outputs.items(): + operation = self.block_info.operations.get(output_name, None) + if not step_fn(output_value, output_name, operation, activation_stats_dict): + return diff --git a/coremltools/optimize/coreml/experimental/_post_training_quantization.py b/coremltools/optimize/coreml/experimental/_post_training_quantization.py new file mode 100644 index 000000000..6b254b9e1 --- /dev/null +++ b/coremltools/optimize/coreml/experimental/_post_training_quantization.py @@ -0,0 +1,237 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +from collections import defaultdict +from typing import List + +import numpy as np + +from coremltools import _SPECIFICATION_VERSION_IOS_17 +from coremltools import _logger as logger +from coremltools.converters.mil.converter import mil_convert as _mil_convert +from coremltools.converters.mil.frontend.milproto import load as _milproto_to_pymil +from coremltools.converters.mil.mil.passes.graph_pass import PassOption +from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY +from coremltools.models import MLModel as _MLModel +from coremltools.models import utils as _model_utils +from coremltools.optimize.coreml import OptimizationConfig as _OptimizationConfig + +from ._model_debugger import ModelDebugger +from ._quantization_passes import ( + insert_prefix_quantize_dequantize_pair as _insert_prefix_quantize_dequantize_pair, +) + + +def linear_quantize_activations(mlmodel: _MLModel, config: _OptimizationConfig, sample_data: List): + """ + Utility function to convert a float precision MLModel of type ``mlprogram``, which uses + float-precision activations, into a compressed MLModel that uses n-bit activations (currently only + support n=8). + + This is achieved by calibrating the float activation values that observed by feeding real sample data into + the model, converting calibrated statistics into the ``quantize`` and ``dequantize`` op pairs, and inserted + into where activation get quantized. + + It's recommended to use with linear_quantize_weights for 8-bit activation and 8-bit weight linear quantization. + It's also compatible to use with other weight compression methods. + + Parameters + ---------- + mlmodel: MLModel + Model to be quantized. This MLModel should be of type ``mlprogram``. + + config: OptimizationConfig + An :py:class:`OptimizationConfig` object that specifies the parameters for activation quantization. + + sample_data: List + Data used to characterize statistics of the activation values of the original float precision model. + Expecting a list of sample input dictionaries. + + Returns + ------- + model: MLModel + The activation quantized MLModel instance. + + Examples + -------- + .. sourcecode:: python + + import coremltools as ct + import coremltools.optimize as cto + + model = ct.coreml.models.MLModel("my_model.mlpackage") + activation_config = cto.coreml.OptimizationConfig( + global_config=cto.coreml.experimental.OpActivationLinearQuantizerConfig( + mode="linear_symmetric" + ) + ) + compressed_model_a8 = cto.coreml.experimental.linear_quantize_activations( + model, activation_config, sample_data + ) + + # (Optional) It's recommended to use with linear_quantize_weights. + weight_config = cto.coreml.OptimizationConfig( + global_config=cto.OpLinearQuantizerConfig(mode="linear_symmetric") + ) + compressed_model_w8a8 = cto.linear_quantize_weights(compressed_model_a8, weight_config) + """ + + ### Apply four major graph passes in order. + + # Graph pass I + # Insert prefix quantize/dequantize pairs to valid patterns. + logger.info("Running compression pass linear_quantize_activations phase 1/4 ...") + linear_activation_quantizer = PASS_REGISTRY[ + "compression::insert_prefix_quantize_dequantize_pair" + ] + linear_activation_quantizer = _insert_prefix_quantize_dequantize_pair( + config, fake_compression=False + ) + linear_activation_quantizer.set_options([PassOption("config", config)]) + + prog = _model_utils._apply_graph_pass( + mlmodel, + linear_activation_quantizer, + spec_version=_SPECIFICATION_VERSION_IOS_17, + pymil_load_func=_milproto_to_pymil.load, + skip_model_load=True, # Save memony + return_pymil_prog=True, + ) + + # Graph pass II + # Insert suffix quantize/dequantize pairs to valid patterns. + logger.info("Running compression pass linear_quantize_activations phase 2/4 ...") + graph_pass = PASS_REGISTRY["compression::insert_suffix_quantize_dequantize_pair"] + graph_pass.set_options([PassOption("config", config)]) + graph_pass(prog) + prog.validate() + + # Graph pass III + # Re-use exsiting path to dedup quantize/dequantize operations. + logger.info("Running compression pass linear_quantize_activations phase 3/4 ...") + graph_pass = PASS_REGISTRY["common::dequantize_quantize_pair_elimination"] + graph_pass(prog) + prog.validate() + + # Graph pass IV + # Updating scale/zero_point in all quantize/dequantize ops calculated by calibration data. + logger.info("Running compression pass linear_quantize_activations phase 4/4 ...") + activation_stats = _get_activation_calibration_stats(mlmodel, sample_data) + graph_pass = PASS_REGISTRY["compression::update_quantize_dequantize"] + graph_pass.set_options([PassOption("activation_stats", activation_stats)]) + graph_pass(prog) + prog.validate() + + # Convert the pymil program (prog) back to mlmodel + model_spec = mlmodel.get_spec() + specification_version = max(model_spec.specificationVersion, _SPECIFICATION_VERSION_IOS_17) + mlmodel_activation_quantized = _mil_convert( + prog, + convert_to="mlprogram", + convert_from="milinternal", + specification_version=specification_version, + compute_units=mlmodel.compute_unit, + model_description=model_spec.description, + skip_model_load=False, # Must be False to avoid manually re-load from disk before running prediction. + ) + return mlmodel_activation_quantized + + +def _get_tensor_range(tensor_name, tensor_value, activation_stats_dict): + tensor_min = np.min(np.array(tensor_value).flatten()) + tensor_max = np.max(np.array(tensor_value).flatten()) + activation_stats_dict[tensor_name]["rmin"] = tensor_min + activation_stats_dict[tensor_name]["rmax"] = tensor_max + if tensor_name in activation_stats_dict: + activation_stats_dict[tensor_name]["rmin"] = min( + tensor_min, activation_stats_dict[tensor_name]["rmin"] + ) + activation_stats_dict[tensor_name]["rmax"] = max( + tensor_max, activation_stats_dict[tensor_name]["rmax"] + ) + else: + activation_stats_dict[tensor_name]["rmin"] = tensor_min + activation_stats_dict[tensor_name]["rmax"] = tensor_max + + +def _get_activation_calibration_stats(fpmodel: _MLModel, sample_data: List): + """ + Calibration and store a dict of intermediate tensor stats. + E.g. activation_stats_dict = {tensor_0: {rmin: 0.2, rmax: 3.8}, tensor_1: {rmin: 4.5, rmax: 12.6}}} + Parameters + ---------- + fpmodel: MLModel + Path to fp16/fp32 "model.mlpackage". (Expect the orginal mlmodel, not the one with quantize and dequant op) + sample_data: list[dict] + Data for calibration. + + Returns + ------- + activation_calibration_stats: dict + """ + + logger.warning( + "Running compression pass linear_quantize_activations: start calibrating {} samples".format( + len(sample_data) + ) + ) + logger.warning( + "Running compression pass linear_quantize_activations: calibration may take a while ..." + ) + + analyzed = 0 + tried = 0 + debugger = ModelDebugger(fpmodel) + activation_stats_dict = defaultdict(dict) + intermediate_output_names = debugger.get_intermediate_output_names( + lambda op: (op.spec.type != "const") + ) + + # Get data ranges for all inputs. + for data in sample_data: + for input_name in data: + _get_tensor_range(input_name, data[input_name], activation_stats_dict) + + # The last few elements in intermediate_output_names might be output. + # We don't maintain min/max value for an output tensor. + # If it's an output tensor we exclude it, otherwise include it. + model_spec = fpmodel.get_spec() + output_count = len(fpmodel.get_spec().description.output) + output_names = [] + for i in range(0, output_count): + output_name = model_spec.description.output[i].name + output_names.append(output_name) + + for intermediate_output_name in intermediate_output_names: + if intermediate_output_name in output_names: + intermediate_output_names.remove(intermediate_output_name) + + # Get data ranges for all intermeditate outputs. + for data in sample_data: + tried += 1 + try: + debugger.step( + step_fn=ModelDebugger.check_intermediate_output, + inputs=data, + activation_stats_dict=activation_stats_dict, + intermediate_output_names=intermediate_output_names, + ) + analyzed += 1 + logger.warning( + "Running compression pass linear_quantize_activations: calibrating sample {}/{} succeeds.".format( + tried, len(sample_data) + ) + ) + + except Exception as e: + logger.error(e) + logger.error( + "Running compression pass linear_quantize_activations: calibrating sample {}/{} fails.".format( + tried, len(sample_data) + ) + ) + continue + + return activation_stats_dict diff --git a/coremltools/optimize/coreml/experimental/_quantization_passes.py b/coremltools/optimize/coreml/experimental/_quantization_passes.py new file mode 100644 index 000000000..67021f2aa --- /dev/null +++ b/coremltools/optimize/coreml/experimental/_quantization_passes.py @@ -0,0 +1,251 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + + +import numpy as np +from tqdm import tqdm + +from coremltools import _logger as logger +from coremltools.converters.mil._deployment_compatibility import AvailableTarget +from coremltools.converters.mil.mil import Builder as mb +from coremltools.converters.mil.mil import Operation, Program, types +from coremltools.converters.mil.mil.block import is_current_opset_version_compatible_with +from coremltools.converters.mil.mil.passes.defs.quantization import AbstractQuantizationPass +from coremltools.converters.mil.mil.passes.helper import block_context_manager +from coremltools.converters.mil.mil.passes.pass_registry import register_pass +from coremltools.optimize.coreml._config import OptimizationConfig +from coremltools.optimize.coreml.experimental._config import OpActivationLinearQuantizerConfig + +""" +----------------------------------- +Activation compression graph pass - +----------------------------------- +""" + + +class AbstractActCompressionPass(AbstractQuantizationPass): + """ + The abstract class for the activation compression graph passes. + """ + + _MINIMUM_OPSET_VERSION = AvailableTarget.iOS17 + + def __init__(self, config: OptimizationConfig = None, fake_compression: bool = False): + if not isinstance(config, (OptimizationConfig, type(None))): + raise ValueError(f"config must be of type OptimizationConfig. Got {type(config)}.") + + op_selector = None if config is None else config._op_selector + + super().__init__(op_selector=op_selector) + + self.fake_compression = fake_compression + self._config = config + if config is not None: + self._check_config_type(config) + + def apply(self, prog): + if not isinstance(prog, Program): + raise TypeError('Transform "{}" can only be applied on PyMIL programs.'.format(self)) + + @block_context_manager + def apply_block(block): + if not is_current_opset_version_compatible_with(self._MINIMUM_OPSET_VERSION): + logger.warning( + f"The program's opset is not compatible with {self._MINIMUM_OPSET_VERSION}. " + f"Skipped the compression pass {self.__class__}." + ) + return + + valid_consts = [] + for op in list(block.operations): + for b in op.blocks: + apply_block(b) + + if self.is_valid_op(op): + need_transform = True + if self.op_selector is not None: + need_transform = self.op_selector(op) + + if need_transform: + valid_consts.append(op) + + for op in tqdm( + valid_consts, + desc=f"Running activation compression pass {self.__class__.__name__}", + unit=" ops", + ): + self.transform_op(op) + + for f in prog.functions.values(): + apply_block(f) + + @property + def config(self) -> OptimizationConfig: + return self._config + + @config.setter + def config(self, value: OptimizationConfig): + self._check_config_type(value) + self._config = value + if value._op_selector is not None: + self.op_selector = value._op_selector + + def _check_config_type(self, config: OptimizationConfig): + """ + The utility function is checking the OptimizationConfig is holding correct type of op config. + """ + + def get_supported_types_as_str(supported_type): + if not isinstance(supported_type, (tuple, list)): + supported_type = [supported_type] + return ", ".join([f"{val.__name__}" for val in supported_type]) + + all_configs = [] + if config.global_config is not None: + all_configs.append(config.global_config) + all_configs.extend(list(config.op_type_configs.values())) + all_configs.extend(list(config.op_name_configs.values())) + + for config in all_configs: + if not isinstance(config, self._SUPPORTED_CONFIG_TYPE) and config is not None: + supported_type_str = get_supported_types_as_str(self._SUPPORTED_CONFIG_TYPE) + raise ValueError( + f"{self.__class__.__name__} only accept {supported_type_str} type config. Got {config.__class__.__name__}." + ) + + def is_valid_op(self, op: Operation): + return True + + +@register_pass(namespace="compression") +class insert_prefix_quantize_dequantize_pair(AbstractActCompressionPass): + """ + This graph pass applies transform on each valid activation quantization pattern. + A valid activation quantization pattern should be surrounded by a quantize/dequantize pair before and after this pattern. + This transform adds a quantize/dequantize pair before valid activation quantization patterns. + + .. code-block:: + Input graph: + ... -> downstream op + Output graph: + quantize -> dequantize -> downstream op + """ + + _SUPPORTED_CONFIG_TYPE = OpActivationLinearQuantizerConfig + _MODE_DTYPE_TO_RANGE = { + (types.int8, "LINEAR_SYMMETRIC"): (-127, 127), + } + + def transform_op(self, op: Operation): + if op.op_type not in ("conv", "add"): + return False + + # Checking op-level config. Skip if we disable compression on certain ops. + op_config = self.config._get_op_config(op) + if op_config is None: + return + + scale_dtype = None + if op.inputs["x"].dtype == types.fp16: + scale_dtype = np.float16 + else: + scale_dtype = np.float32 + + if op.op_type in ("conv"): + new_quantize_op = mb.quantize( + input=op.inputs["x"], + scale=np.array(1).astype(scale_dtype), + zero_point=np.int8(0), + output_dtype="int8", + before_op=op, + ) + new_dequantize_op = mb.dequantize( + input=new_quantize_op, + scale=np.array(1).astype(scale_dtype), + zero_point=np.int8(0), + before_op=op, + ) + + kargs = {} + for k, v in op.inputs.items(): + kargs[k] = v + kargs["x"] = new_dequantize_op + kargs["name"] = op.name + kargs["before_op"] = op + new_conv_op = mb.conv(**kargs) + new_conv_op.name = op.outputs[0].name + + if new_conv_op.op.enclosing_block.try_replace_uses_of_var_after_op( + old_var=op.outputs[0], + new_var=new_conv_op, + anchor_op=new_conv_op.op, + end_op=new_conv_op, + ): + pass + new_conv_op.op.enclosing_block.remove_ops([op]) + + if op.op_type in ("add"): + """ + For op with two live inputs (e.g. add): + Input graph: + ... ->| + |-> downstream op + ... ->| + Output graph: + quantize -> dequantize | + |-> downstream op + quantize -> dequantize | + """ + + # Validation check. + # Both inputs x and y need to be non-const. + # Reject when either input is const. + x_is_const = op.inputs["x"].op is not None and op.inputs["x"].op.op_type == "const" + y_is_const = op.inputs["y"].op is not None and op.inputs["y"].op.op_type == "const" + if x_is_const != y_is_const: + return + + new_quantize_op_x = mb.quantize( + input=op.inputs["x"], + scale=np.array(1).astype(scale_dtype), + zero_point=np.int8(0), + output_dtype="int8", + before_op=op, + ) + new_dequantize_op_x = mb.dequantize( + input=new_quantize_op_x, + scale=np.array(1).astype(scale_dtype), + zero_point=np.int8(0), + before_op=op, + ) + new_quantize_op_y = mb.quantize( + input=op.inputs["y"], + scale=np.array(1).astype(scale_dtype), + zero_point=np.int8(0), + output_dtype="int8", + before_op=op, + ) + new_dequantize_op_y = mb.dequantize( + input=new_quantize_op_y, + scale=np.array(1).astype(scale_dtype), + zero_point=np.int8(0), + before_op=op, + ) + new_add_op = mb.add( + x=new_dequantize_op_x, + y=new_dequantize_op_y, + name=op.name, + before_op=op, + ) + new_add_op.name = op.outputs[0].name + + if new_add_op.op.enclosing_block.try_replace_uses_of_var_after_op( + old_var=op.outputs[0], + new_var=new_add_op, + anchor_op=new_add_op.op, + end_op=new_add_op, + ): + pass + new_add_op.op.enclosing_block.remove_ops([op]) diff --git a/coremltools/optimize/torch/__init__.py b/coremltools/optimize/torch/__init__.py index e66d96feb..88f7c3052 100644 --- a/coremltools/optimize/torch/__init__.py +++ b/coremltools/optimize/torch/__init__.py @@ -1,10 +1,11 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause from coremltools.optimize.torch import ( base_model_optimizer, + layerwise_compression, optimization_config, palettization, pruning, diff --git a/coremltools/optimize/torch/_logging.py b/coremltools/optimize/torch/_logging.py index f046b7378..c2dee9105 100644 --- a/coremltools/optimize/torch/_logging.py +++ b/coremltools/optimize/torch/_logging.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause diff --git a/coremltools/optimize/torch/_typing.py b/coremltools/optimize/torch/_typing.py index 227b587e9..90192d94c 100644 --- a/coremltools/optimize/torch/_typing.py +++ b/coremltools/optimize/torch/_typing.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause diff --git a/coremltools/optimize/torch/_utils/__init__.py b/coremltools/optimize/torch/_utils/__init__.py index 25c7d28c5..5dc5e6747 100644 --- a/coremltools/optimize/torch/_utils/__init__.py +++ b/coremltools/optimize/torch/_utils/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause diff --git a/coremltools/optimize/torch/_utils/dist_utils.py b/coremltools/optimize/torch/_utils/dist_utils.py new file mode 100644 index 000000000..2546f3f90 --- /dev/null +++ b/coremltools/optimize/torch/_utils/dist_utils.py @@ -0,0 +1,36 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import os as _os + +import torch as _torch +import torch.distributed as _dist + + +def ddp_setup(rank: int, world_size: int): + """ + Set environment variables which are used for initializing distributed + process group for :py:class:`DistributedDataParallel`. + + Args: + rank: Unique identifier of each process + world_size: Total number of processes + """ + _os.environ["MASTER_ADDR"] = "localhost" + _os.environ["MASTER_PORT"] = "12355" + _os.environ["WORLD_SIZE"] = f"{world_size}" + _os.environ["RANK"] = f"{rank}" + _os.environ["LOCAL_RANK"] = f"{rank}" + _torch.cuda.set_device(f"cuda:{rank}") + _dist.init_process_group("nccl", rank=rank, world_size=world_size) + + +def is_leader(): + """ + Returns ``True`` if the rank of the current process is 0. + """ + if _dist.is_initialized(): + return _dist.get_rank() == 0 + return True diff --git a/coremltools/optimize/torch/_utils/fsdp_utils.py b/coremltools/optimize/torch/_utils/fsdp_utils.py new file mode 100644 index 000000000..706a6b0fc --- /dev/null +++ b/coremltools/optimize/torch/_utils/fsdp_utils.py @@ -0,0 +1,70 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +from abc import ABC as _ABC +from abc import abstractmethod as _abstractmethod +from functools import partial as _partial +from typing import Iterable as _Iterable +from typing import Type as _Type + +import torch as _torch +from attr import define as _define +from torch.distributed.fsdp.wrap import ModuleWrapPolicy as _TorchModuleWrapPolicy +from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy as _size_based_auto_wrap_policy + + +class FSDPAutoWrapPolicy(_ABC): + """ + An abstract base class for implementing an `FSDP `_ auto wrap policy. + + Wrapping a model with ``FSDP`` wrapper, ``FSDP(model)``, results in a single FSDP unit for the entire model. + Thus, during the model's execution, the ``all-gather`` operation collects all the parameters of the model on all + GPUs and hence, parameter sharding doesn't save any CUDA memory. + + To avoid this, one can specify a :py:class:`FSDPAutoWrapPolicy`, which automatically creates multiple FSDP units + nested within the top level FSDP unit, based on certain criteria such as a minimum size limit for each FSDP + unit or based on the class structure of the model. This way, only one FSDP unit needs to collect full + parameters at a time, and one can compute gradients for a much larger model, which wouldn't be possible otherwise. + + For more details, please refer to `FSDP documentation `_ + """ + @_abstractmethod + def get_policy(self): + """ + Return a policy for wrapping different submodules of a model with FSDP wrapper. + """ + + +@_define +class ModuleWrapPolicy(FSDPAutoWrapPolicy): + """ + An auto wrap policy which wraps instances of modules with classes specified by ``module_classes`` into separate + FSDP units. + + This policy is useful for transformer like models which can be naturally split into distinct submodules. + + For example, for a GPT style decoder model, with ``Attention`` and ``FeedForward`` as the two + types of layers in it, one can specify ``module_classes = [Attention, FeedForward]``. This would lead to + each instance of ``Attention`` and ``FeedForward`` layer in the model to be wrapped + into an individual FSDP unit. + """ + + module_classes: _Iterable[_Type[_torch.nn.Module]] + + def get_policy(self): + return _TorchModuleWrapPolicy(self.module_classes) + + +@_define +class SizeBasedWrapPolicy: + """ + An auto wrap policy which creates a new FSDP instances when the number of parameters in the the current FSDP + unit exceeds ``min_num_params``. + """ + + min_num_params: int + + def get_policy(self): + return _partial(_size_based_auto_wrap_policy, min_num_params=self.min_num_params) diff --git a/coremltools/optimize/torch/_utils/k_means.py b/coremltools/optimize/torch/_utils/k_means.py new file mode 100644 index 000000000..b2e90735a --- /dev/null +++ b/coremltools/optimize/torch/_utils/k_means.py @@ -0,0 +1,921 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import logging as _logging +import queue as _queue +from abc import abstractmethod as _abstractmethod +from typing import Any as _Any +from typing import Dict as _Dict +from typing import List as _List +from typing import Optional as _Optional +from typing import Tuple as _Tuple +from typing import Type as _Type +from typing import Union as _Union + +import numpy as _np +import torch as _torch +import torch.multiprocessing as _mp +from attr import define as _define +from sklearn.cluster import KMeans as _kmeans2d # 2d kmeans + +from coremltools._deps import _kmeans1d +from coremltools.converters.mil.mil.ops.defs.iOS18 import ( + constexpr_blockwise_shift_scale as _quantize_op, +) +from coremltools.optimize.coreml._utils import compute_qparams as _compute_qparams +from coremltools.optimize.torch._utils.metadata_utils import ( + CompressionMetadata as _CompressionMetadata, +) +from coremltools.optimize.torch._utils.metadata_utils import ( + register_metadata_version as _register_metadata_version, +) +from coremltools.optimize.torch._utils.python_utils import ClassRegistryMixin as _ClassRegistryMixin +from coremltools.optimize.torch._utils.torch_utils import ( + get_atomic_layers, + get_n_bits_from_dtype, + get_sign_from_dtype, +) + +_logger = _logging.getLogger(__name__) + + +@_define(frozen=True) +class KMeansConfig: + n_bits: int = 4 + axis: int = 0 + lut_dtype: _torch.dtype = None + block_size: _Optional[int] = None + cluster_dim: _Optional[int] = None + enable_per_channel_scale: bool = False + mask: _Optional[_torch.Tensor] = None + importance: _Optional[_torch.Tensor] = None + + +class KMeansSupportedModulesRegistry(_ClassRegistryMixin): + """ + A registry of :py:class:`KMeansModule` classes + """ + + REGISTRY: _Dict[str, _Type["KMeansModule"]] + + @classmethod + def get_kmeans_module(cls, module: _torch.nn.Module) -> _Optional[_Type["KMeansModule"]]: + """ + Returns the :py:class:`KMeansModule` class which implements k-means + for the given module. + """ + for _, layer_cls in cls.REGISTRY.items(): + if layer_cls.is_supported_module(module): + return layer_cls + return None + + @classmethod + def get_supported_modules(cls) -> _Tuple[_Type[_torch.nn.Module]]: + """ + Returns all supported module types for k-means. + """ + return tuple(layer_cls.layer_type for _, layer_cls in cls.REGISTRY.items()) + + +class KMeansModule: + """ + An interface for adding support for a given module class for running + k-means. Implements methods to retrieve parameters which can be clustered + and to update them with new values after clustering. + """ + + layer_type: _Type[ + _torch.nn.Module + ] # The layer type which this interface supports clustering for + parameter_names: _List[str] = [] # List of parameters which are clustered for this layer type + + def __init_subclass__(cls): + KMeansSupportedModulesRegistry.register(cls.__name__)(cls) + + def __init__(self, module: _torch.nn.Module, config: _Dict[str, KMeansConfig]): + self.module = module + self.config = config + self._parameter_metadata = None + self._init_parameter_metadata() + + @_abstractmethod + def _init_parameter_metadata(self): + """ + Initialize metadata for k-means clustering for this layer type. + The metadata is a dictionary from parameter name to a dictionary + of metadata name and its value. This method should add the shape of + the parameters as the metadata for each parameter which + should be clustered. + """ + + @_abstractmethod + def _get_parameters_impl(self) -> _Dict[str, _torch.Tensor]: + """ + Returns a dictionary of parameter name to the parameter tensor + which should be clustered for this layer type. + """ + + @_abstractmethod + def _update_parameters_impl(self, param_name: str, new_value: _torch.Tensor): + """ + Update the parameter corresponding to this parameter name with the + new value after reshaping to original parameter shape. + """ + + @_abstractmethod + def _reshape_for_kmeans(self, param_name: str, value: _torch.Tensor) -> _torch.Tensor: + """ + Reshape any value of original parameter shape to flattened shape for k-means. + """ + + @_abstractmethod + def _reshape_to_original(self, param_name: str, value: _torch.Tensor) -> _torch.Tensor: + """ + Reshape any value flattened for k-means back to original parameter shape. + """ + + def _compute_lut_and_indices(self, param_name: str, param: _torch.Tensor): + """ + Compute LUT and indices from parameter. + For 4-bit palettization and param shape (32, 16, 3, 3), + Case-1: If block_size = 4 and axis = 0, then LUT has shape (8, 1, 1, 1, 16, 1) + Case-2: If block_size = 4 and axis = 1, then LUT has shape (1, 4, 1, 1, 16, 1) + Case-3: If cluster_dim = 4, then LUT has shape (1, 1, 1, 1, 16, 4) + """ + axis = self.config[param_name].axis + num_channels = param.shape[axis] + mask = self.config[param_name].mask + block_size = self.config[param_name].block_size + block_size = num_channels if block_size is None else block_size + cluster_dim = self.config[param_name].cluster_dim + orig_param_shape = self._parameter_metadata[param_name]["shape"] + cluster_dim = 1 if cluster_dim is None else cluster_dim + + lut, indices = [], [] + if cluster_dim == 1: + # Scalar palettization + for block_idx in range(0, num_channels, block_size): + if axis == 0: + lut_idx, ind_idx = _torch.unique( + param[block_idx : block_idx + block_size, :], + return_inverse=True, + ) + else: + lut_idx, ind_idx = _torch.unique( + param[:, block_idx : block_idx + block_size], + return_inverse=True, + ) + + # Ensure param was correctly palettized + # Unless a mask was applied, number of unique values cannot exceed 2^nbits + max_unique_val = 2 ** self.config[param_name].n_bits + assert mask is not None or len(lut_idx) <= max_unique_val, ( + f"Found more than expected unique values in {self.module} " + f"for {param_name}, expected <= {max_unique_val}, found = {len(lut_idx)}" + ) + # Pad lut with zeros if fewer than 2^n_bit unique values are found + if len(lut_idx) < max_unique_val: + padded_lut_idx = _torch.zeros(max_unique_val) + padded_lut_idx[: len(lut_idx)] = lut_idx + lut_idx = padded_lut_idx + + lut.append(lut_idx) + indices.append(ind_idx) + + lut = _torch.stack(lut).unsqueeze(1 - axis).unsqueeze(-1) + indices = _torch.cat(indices, dim=axis) + indices = self._reshape_to_original(param_name, indices) + else: + # Vector palettization + # Reshape param for 2D clustering + if axis == 0: + param_reshaped = param.reshape(-1, cluster_dim) + else: + param_reshaped = param.transpose(0, 1).reshape(-1, cluster_dim) + lut, indices = _torch.unique(param_reshaped, dim=0, return_inverse=True) + + # Undo reshaping in indices done for 2D clustering + if axis == 0: + indices = indices.reshape(param.shape[0], param.shape[1] // cluster_dim) + else: + indices = indices.reshape(param.shape[0] // cluster_dim, param.shape[1]) + + # Incorporate param dimensions in lut shape + for i in range(len(orig_param_shape) - lut.dim() + 2): + lut = lut.unsqueeze(-3) + + return lut, indices + + def _scale_by_per_channel_scale(self, param_name: str, param: _torch.Tensor) -> _torch.Tensor: + """ + Compute per channel scales for scaling the parameter in the range ``[-1, 1]`` + and store them in the parameter metadata. Also scale the parameter using + the computed scales. + """ + if self.config[param_name].enable_per_channel_scale: + flattened_param = param.flatten(1) + per_channel_scale = _torch.max(_torch.abs(flattened_param), dim=1, keepdim=True).values + # Handle zero scales + per_channel_scale[per_channel_scale == 0] = 1 + flattened_param /= per_channel_scale + param = flattened_param.reshape(param.shape) + self._parameter_metadata[param_name]["per_channel_scale"] = per_channel_scale + return param + + def _get_compression_metadata( + self, param_name: str, param: _torch.Tensor + ) -> _CompressionMetadata: + """ + Return compression metadata to be stored in the model for this parameter + """ + metadata = _CompressionMetadata(param_name) + compression_type = ["palettization"] + # LUT + metadata.lut, _ = self._compute_lut_and_indices(param_name, param) + # Per channel scale + if self.config[param_name].enable_per_channel_scale: + per_channel_scale = self._parameter_metadata[param_name]["per_channel_scale"] + reshaped_param = self._reshape_to_original(param_name, param) + for _ in range(reshaped_param.dim() - per_channel_scale.dim()): + per_channel_scale = per_channel_scale.unsqueeze(-1) + metadata.palettization_scale = per_channel_scale + # LUT quantization + if self.config[param_name].lut_dtype is not None: + dtype = self.config[param_name].lut_dtype + compression_type.append("quantization") + metadata.quantization_n_bits = get_n_bits_from_dtype(dtype) + scale = self._parameter_metadata[param_name]["lut_quantization_scale"] + # match scale rank to lut rank + for i in range(metadata.lut.dim() - scale.dim()): + scale = scale.unsqueeze(-1) + metadata.quantization_scale = scale + zp = self._parameter_metadata[param_name]["lut_quantization_zp"] + if zp is not None: + # match zp rank to lut rank + for i in range(metadata.lut.dim() - zp.dim()): + zp = zp.unsqueeze(-1) + metadata.zero_point = zp + # Compression type + metadata.compression_type = compression_type + return metadata + + def _register_compression_metadata(self, param_name: str, param: _torch.Tensor): + """ + Register compression metadata on the model so that it can be serialized. + """ + metadata = self._get_compression_metadata(param_name, param) + metadata.register(self.module) + + def _unscale_by_per_channel_scale(self, param_name: str, param: _torch.Tensor) -> _torch.Tensor: + """ + Re-scale the parameter with ``param_name`` back to its original range by multiplying + per channel scales. + """ + if self.config[param_name].enable_per_channel_scale: + per_channel_scale = self._parameter_metadata[param_name]["per_channel_scale"] + flattened_param = param.flatten(1) + flattened_param *= per_channel_scale + param = flattened_param.reshape(param.shape) + return param + + @classmethod + def is_supported_module(cls, module: _torch.nn.Module) -> bool: + """ + Returns ``True`` if clustering this module is supported by this interface. + """ + return isinstance(module, cls.layer_type) + + def get_parameters(self) -> _Dict[str, _torch.Tensor]: + """ + Returns a dictionary of parameter name to the parameter tensor + which should be clustered for this layer type. Scales the weights + in the range ``[-1, 1]`` if ``per_channel_scale`` is enabled. + """ + return self._get_parameters_impl() + + def update_parameters(self, param_name: str, new_value: _torch.Tensor): + """ + Update the parameter corresponding to this parameter name with the + new value. + """ + self._register_compression_metadata(param_name, new_value) + self._update_parameters_impl(param_name, new_value) + + def get_param_config(self, param_name: str, param: _torch.Tensor) -> KMeansConfig: + """ + Returns KMeansConfig for the specified parameter + """ + config = self.config[param_name] + block_size = param.shape[config.axis] if config.block_size is None else config.block_size + cluster_dim = 1 if config.cluster_dim is None else config.cluster_dim + importance = self._reshape_for_kmeans(param_name, config.importance) + mask = self._reshape_for_kmeans(param_name, config.mask) + + return KMeansConfig( + n_bits=config.n_bits, + axis=config.axis, + lut_dtype=config.lut_dtype, + block_size=block_size, + cluster_dim=cluster_dim, + enable_per_channel_scale=config.enable_per_channel_scale, + mask=mask, + importance=importance, + ) + + +class Linear(KMeansModule): + layer_type: _Type = _torch.nn.Linear + parameter_names: _List[str] = ["weight"] + + def _init_parameter_metadata(self): + self._parameter_metadata = { + "weight": { + "shape": self.module.weight.shape, + } + } + + def _get_parameters_impl(self): + scaled_param = self._scale_by_per_channel_scale("weight", self.module.weight.data) + return {"weight": self._reshape_for_kmeans("weight", scaled_param)} + + def _update_parameters_impl(self, param_name: str, new_value: _torch.Tensor): + param = self._reshape_to_original(param_name, new_value) + self.module.weight.data = self._unscale_by_per_channel_scale(param_name, param) + + def _reshape_for_kmeans(self, param_name: str, value: _torch.Tensor) -> _torch.Tensor: + return value + + def _reshape_to_original(self, param_name: str, value: _torch.Tensor) -> _torch.Tensor: + return value + + +class Embedding(KMeansModule): + layer_type: _Type = _torch.nn.Embedding + parameter_names: _List[str] = ["weight"] + + def _init_parameter_metadata(self): + self._parameter_metadata = { + "weight": { + "shape": self.module.weight.shape, + } + } + + def _get_parameters_impl(self): + scaled_param = self._scale_by_per_channel_scale("weight", self.module.weight.data) + return {"weight": self._reshape_for_kmeans("weight", scaled_param)} + + def _update_parameters_impl(self, param_name: str, new_value: _torch.Tensor): + param = self._reshape_to_original(param_name, new_value) + self.module.weight.data = self._unscale_by_per_channel_scale(param_name, param) + + def _reshape_for_kmeans(self, param_name: str, value: _torch.Tensor) -> _torch.Tensor: + return value + + def _reshape_to_original(self, param_name: str, value: _torch.Tensor) -> _torch.Tensor: + return value + + +class Conv2d(KMeansModule): + layer_type: _Type = _torch.nn.Conv2d + parameter_names: _List[str] = ["weight"] + + def _init_parameter_metadata(self): + self._parameter_metadata = { + "weight": { + "shape": self.module.weight.shape, + } + } + + def _get_parameters_impl(self): + scaled_param = self._scale_by_per_channel_scale("weight", self.module.weight.data) + return {"weight": self._reshape_for_kmeans("weight", scaled_param)} + + def _update_parameters_impl(self, param_name: str, new_value: _torch.Tensor): + param = self._reshape_to_original(param_name, new_value) + self.module.weight.data = self._unscale_by_per_channel_scale(param_name, param) + + def _reshape_for_kmeans(self, param_name: str, value: _torch.Tensor) -> _torch.Tensor: + if value is None: + return value + + if self.config[param_name].axis == 0: + new_value = value.flatten(1) + else: + new_value = value.transpose(0, 1).flatten(1).transpose(0, 1) + + return new_value + + def _reshape_to_original(self, param_name: str, value: _torch.Tensor) -> _torch.Tensor: + if value is None: + return value + + weight_shape = self._parameter_metadata[param_name]["shape"] + if self.config[param_name].axis == 0: + new_value = value.reshape(weight_shape) + else: + new_value = ( + value.transpose(0, 1) + .reshape( + ( + weight_shape[1], + weight_shape[0], + weight_shape[2], + weight_shape[3], + ) + ) + .transpose(0, 1) + ) + return new_value + + +class MultiheadAttention(KMeansModule): + layer_type: _Type = _torch.nn.MultiheadAttention + parameter_names: _List[str] = ["in_proj_weight"] + + def _init_parameter_metadata(self): + self._parameter_metadata = { + "in_proj_weight": { + "shape": self.module.in_proj_weight.shape, + }, + } + + def _get_parameters_impl(self): + scaled_param = self._scale_by_per_channel_scale( + "in_proj_weight", self.module.in_proj_weight.data + ) + return {"in_proj_weight": self._reshape_for_kmeans("in_proj_weight", scaled_param)} + + def _update_parameters_impl(self, param_name: str, new_value: _torch.Tensor): + param = self._reshape_to_original(param_name, new_value) + self.module.in_proj_weight.data = self._unscale_by_per_channel_scale(param_name, param) + + def _reshape_for_kmeans(self, param_name: str, value: _torch.Tensor) -> _torch.Tensor: + return value + + def _reshape_to_original(self, param_name: str, value: _torch.Tensor) -> _torch.Tensor: + return value + + +class KMeans: + @classmethod + @_torch.no_grad() + def _cluster_weights_worker( + cls, + work_q: _Union[_mp.Queue, _queue.Queue], + results_q: _Union[_mp.Queue, _queue.Queue], + ): + while True: + try: + ( + layer_name, + weight_name, + weight, + config, + ) = work_q.get_nowait() + except _queue.Empty: + break + + _logger.info(f"Starting to process layer {layer_name}") + + ( + n_bits, + axis, + lut_dtype, + block_size, + cluster_dim, + mask, + importance, + enable_per_channel_scale, + ) = ( + config.n_bits, + config.axis, + config.lut_dtype, + config.block_size, + config.cluster_dim, + config.mask, + config.importance, + config.enable_per_channel_scale, + ) + + new_weight = _torch.zeros_like(weight, dtype=weight.dtype) + num_clusters = 2**n_bits + + _logger.info( + f"Number of blocks in {layer_name}.{weight_name}: {weight.shape[axis] // block_size}" + ) + + lut_quant_scale = [] + lut_quant_zp = [] + # 1-D clustering (block_size is activated and cluster_dim = 1). + if cluster_dim == 1: + for block_idx in range(0, weight.shape[axis], block_size): + if axis == 0: + block_importance = ( + importance[block_idx : block_idx + block_size, :].flatten() + if importance is not None + else None + ) + block_weight = weight[block_idx : block_idx + block_size, :] + block_mask = ( + mask[block_idx : block_idx + block_size, :].flatten() + if mask is not None + else None + ) + else: + block_importance = ( + importance[:, block_idx : block_idx + block_size].flatten() + if importance is not None + else None + ) + block_weight = weight[:, block_idx : block_idx + block_size] + block_mask = ( + mask[:, block_idx : block_idx + block_size].flatten() + if mask is not None + else None + ) + + block_weight_flatten = block_weight.flatten() + if block_mask is not None: + block_weight_flatten_masked = block_weight_flatten[block_mask] + if len(block_weight_flatten_masked) > 0: + if block_importance is not None: + kmeans_results = _kmeans1d.cluster( + block_weight_flatten_masked.numpy(), + num_clusters, + weights=block_importance[block_mask].numpy(), + ) + else: + kmeans_results = _kmeans1d.cluster( + block_weight_flatten_masked.numpy(), num_clusters + ) + else: + kmeans_results = None + else: + if block_importance is not None: + kmeans_results = _kmeans1d.cluster( + block_weight_flatten.numpy(), + num_clusters, + weights=block_importance.numpy(), + ) + else: + kmeans_results = _kmeans1d.cluster( + block_weight_flatten.numpy(), num_clusters + ) + + centroids = ( + _np.array(kmeans_results.centroids) if kmeans_results is not None else None + ) + clusters = ( + _np.array(kmeans_results.clusters) if kmeans_results is not None else None + ) + + # quantize LUT + if lut_dtype is not None: + centroids, scale, zp = cls._quantize_centroids(lut_dtype, centroids) + lut_quant_scale.append(scale) + if zp: + lut_quant_zp.append(zp) + + if block_mask is not None: + new_block_weight = block_weight_flatten.clone() + if kmeans_results is not None: + new_block_weight[block_mask] = _torch.tensor( + centroids[clusters], dtype=weight.dtype + ) + new_block_weight = new_block_weight.reshape(block_weight.shape) + else: + new_block_weight = _torch.tensor( + centroids[clusters], dtype=weight.dtype + ).reshape(block_weight.shape) + if axis == 0: + new_weight[block_idx : block_idx + block_size, :] = new_block_weight + else: + new_weight[:, block_idx : block_idx + block_size] = new_block_weight + + # 2-D clustering. (cluster_dim is activated and block_size is ignored). + # Not yet support with block_mask/block_importance (ignoring both). + else: + # Convert weight from N-D to 2-D. E.g. (Cin, W, H, Cout) -> (cluster_dim, Cin/cluster_dim * W * H * Cout) + # Apply 2-D kmeans clustering on 2-D weights. + if axis == 0: + weight_2d = weight.reshape(-1, cluster_dim) + else: + weight_2d = weight.transpose(0, 1).reshape(-1, cluster_dim) + + kmeans_results = _kmeans2d(n_clusters=num_clusters).fit(weight_2d.numpy()) + centroids = ( + _np.array(kmeans_results.cluster_centers_) + if kmeans_results is not None + else None + ) + clusters = _np.array(kmeans_results.labels_) if kmeans_results is not None else None + + # quantize LUT + if lut_dtype is not None: + centroids, scale, zp = cls._quantize_centroids(lut_dtype, centroids) + lut_quant_scale.append(scale) + if zp: + lut_quant_zp.append(zp) + + weight_palettized = _torch.tensor(centroids[clusters], dtype=weight.dtype) + if axis == 0: + new_weight = weight_palettized.reshape(weight.shape) + else: + new_weight = weight_palettized.reshape( + weight.shape[1], weight.shape[0] + ).transpose(0, 1) + + if new_weight is not None: + _logger.info( + f"Finished processing {weight_name} in layer {layer_name} successfully" + ) + + # Combine quantization scales / zp for all LUTs into single tensor + scale, zp = None, None + if lut_dtype is not None: + scale = _torch.stack(lut_quant_scale, dim=axis) + if len(lut_quant_zp) > 0: + zp = _torch.stack(lut_quant_zp, dim=axis) + + results_q.put((layer_name, weight_name, new_weight, scale, zp)) + + _logger.info("Process done, work queue is empty") + + @classmethod + def _quantize_centroids(self, dtype: _torch.dtype, centroids: _torch.Tensor): + ret = _compute_qparams( + weight=centroids, + nbits=get_n_bits_from_dtype(dtype), + quantization_mode="LINEAR_SYMMETRIC", + dtype=centroids.dtype, + block_sizes=[0] * centroids.ndim, + signed=get_sign_from_dtype(dtype), + ) + + if ret is None: + _logger.warning(f"Unable to quantize centroids {centroids}") + return + + quant_centroids, scale, zp = ret + dequant_centroids = _quantize_op.decompress( + quant_centroids, + scale, + zp, + ) + + # Convert back to torch tensors + dequant_centroids = _torch.from_numpy(dequant_centroids) + scale = _torch.from_numpy(scale) + if zp is not None: + zp = _torch.from_numpy(zp) + + return dequant_centroids, scale, zp + + @classmethod + def _get_weights_to_cluster( + cls, + model: _torch.nn.Module, + work_q: _Union[_mp.Queue, _queue.Queue], + config: _Union[_Dict[str, _Dict[str, KMeansConfig]], KMeansConfig] = KMeansConfig(), + ) -> _Tuple[_Dict[str, KMeansModule], _Dict[str, _Any]]: + if not isinstance(config, dict): + layers_to_cluster = get_atomic_layers( + model, + layer_types=list(KMeansSupportedModulesRegistry.get_supported_modules()), + name_prefix="", + ) + config_dict = {} + for layer_name, layer in layers_to_cluster.items(): + layer_config = {} + for param_name in KMeansSupportedModulesRegistry.get_kmeans_module( + layer + ).parameter_names: + layer_config[param_name] = config + config_dict[layer_name] = layer_config + else: + layers_to_cluster = { + layer_name: model.get_submodule(layer_name) for layer_name, _ in config.items() + } + config_dict = config + + k_means_module_map = dict() + + param_dict = {} + for layer_name, layer in layers_to_cluster.items(): + layer_config = config_dict[layer_name] + + k_means_module_cls = KMeansSupportedModulesRegistry.get_kmeans_module(layer) + k_means_module: KMeansModule = k_means_module_cls(layer, layer_config) + + k_means_module_map[layer_name] = k_means_module + + for param_name, param in k_means_module.get_parameters().items(): + param_config = k_means_module.get_param_config(param_name, param) + work_q.put((layer_name, param_name, param, param_config)) + param_dict[f"{layer_name}${param_name}"] = (param, param_config) + + return k_means_module_map, param_dict + + @classmethod + def _prepare_worker_processes( + cls, num_workers: int + ) -> _Tuple[ + _Union[_mp.Queue, _queue.Queue], + _Union[_mp.Queue, _queue.Queue], + _Optional[_List[_mp.Process]], + ]: + raise NotImplementedError("This method is not implemented by base class.") + + @classmethod + def _run_worker_processes( + cls, + work_q: _Union[_mp.Queue, _queue.Queue], + results_q: _Union[_mp.Queue, _queue.Queue], + worker_processes: _Optional[_List[_mp.Process]], + ): + raise NotImplementedError("This method is not implemented by base class.") + + @classmethod + def _join_worker_processes(cls, worker_processes: _Optional[_List[_mp.Process]]): + raise NotImplementedError("This method is not implemented by base class.") + + @classmethod + @_torch.no_grad() + def cluster_weights( + cls, + model: _torch.nn.Module, + config: _Union[_Dict[str, _Dict[str, KMeansConfig]], KMeansConfig] = KMeansConfig(), + num_workers: int = 1, + ) -> _torch.nn.Module: + work_q, results_q, worker_processes = cls._prepare_worker_processes(num_workers) + k_means_module_map, param_dict = cls._get_weights_to_cluster( + model=model, + work_q=work_q, + config=config, + ) + + num_params = len(param_dict) + remaining_params = param_dict + + def _worker_loop() -> None: + cls._run_worker_processes(work_q, results_q, worker_processes) + num_params_left = len(remaining_params) + num_errors = 0 + last_chance = False + while remaining_params: + try: + layer_name, param_name, new_value, scale, zp = results_q.get(timeout=10) + except _queue.Empty: + if worker_processes is not None: + # This if path is for ParallelKMeans + # Check if workers are still running, in which case they may still be chewing on data and we + # need to wait. Also identify if any worker died (maybe it has been killed for OOM) and count + # it as an error + for proc in list(worker_processes): + if not proc.is_alive(): + proc.join() + if proc.exitcode != 0: + _logger.error( + f"Process {proc} exited with exit code {proc.exitcode}" + ) + num_errors += 1 + alive_processes = sum(proc.is_alive() for proc in worker_processes) + if not alive_processes: + if last_chance: + _logger.info( + f"All processes are done, but queue is empty, which is unexpected. Expecting to " + f"receive {num_params_left} more param(s). Will end now." + ) + break + else: + last_chance = True + continue + _logger.info( + f"Result queue is empty, but {alive_processes} process(es) is / are still alive, " + f"continuing..." + ) + continue + else: + # This else path is for SequentialKMeans + if not last_chance: + last_chance = True + continue + else: + raise ValueError( + f"Queue is empty, which is unexpected. Expecting to receive {num_params_left} more " + f"param(s)." + ) + else: + _logger.info(f"Progress: {100 * (1.0 - (num_params_left / num_params)):.2f} %") + k_means_module = k_means_module_map[layer_name] + k_means_module._parameter_metadata[param_name]["lut_quantization_scale"] = scale + k_means_module._parameter_metadata[param_name]["lut_quantization_zp"] = zp + k_means_module.update_parameters(param_name, new_value) + remaining_params.pop(f"{layer_name}${param_name}") + # Even though it might not have succeeded + num_params_left -= 1 + + _logger.info("joining worker processes") + cls._join_worker_processes(worker_processes) + + _worker_loop() + + if remaining_params: + _logger.error( + f"The {len(remaining_params)} following params of following layers were not successfully palettized and" + f" a new palettization will be attempted using a single worker: {', '.join(sorted(remaining_params))}" + ) + work_q, results_q, worker_processes = cls._prepare_worker_processes( + num_workers=1 + ) # Running the remaining params with 1 worker as that is more stable + for current_param, param_tuple in remaining_params.items(): + layer_name, param_name = current_param.split("$") + work_q.put((layer_name, param_name, param_tuple[0], param_tuple[1])) + + _worker_loop() + + if remaining_params: + raise RuntimeError( + f"Even after rerunning all failed layers with a single worker, {len(remaining_params)} are " + f"still missing: {', '.join(sorted(remaining_params))}" + ) + else: + _logger.info( + "After rerunning all failed layers with a single worker, all palettizations succeeded" + ) + + _register_metadata_version(model) + return model + + +class ParallelKMeans(KMeans): + @classmethod + def _prepare_worker_processes( + cls, + num_workers: int, + ) -> _Tuple[ + _Union[_mp.Queue, _queue.Queue], + _Union[_mp.Queue, _queue.Queue], + _Optional[_List[_mp.Process]], + ]: + ctx = _mp.get_context("spawn") + manager = ctx.Manager() + work_q = manager.Queue() + results_q = manager.Queue() + + worker_processes = [ + ctx.Process( + target=cls._cluster_weights_worker, + args=(work_q, results_q), + name=f"Process-{rank}", + daemon=True, + ) + for rank in range(num_workers) + ] + return work_q, results_q, worker_processes + + @classmethod + def _run_worker_processes( + cls, + work_q: _Union[_mp.Queue, _queue.Queue], + results_q: _Union[_mp.Queue, _queue.Queue], + worker_processes: _Optional[_List[_mp.Process]], + ): + for worker_process in worker_processes: + worker_process.start() + _logger.info(f"Started {worker_process.name} for clustering weights.") + + @classmethod + def _join_worker_processes(cls, worker_processes: _Optional[_List[_mp.Process]]): + for worker_process in worker_processes: + worker_process.join() + _logger.info(f"Finished {worker_process.name}.") + + +class SequentialKMeans(KMeans): + @classmethod + def _prepare_worker_processes( + cls, num_workers: int + ) -> _Tuple[ + _Union[_mp.Queue, _queue.Queue], + _Union[_mp.Queue, _queue.Queue], + _Optional[_List[_mp.Process]], + ]: + work_q = _queue.Queue() + results_q = _queue.Queue() + return work_q, results_q, None + + @classmethod + def _run_worker_processes( + cls, + work_q: _Union[_mp.Queue, _queue.Queue], + results_q: _Union[_mp.Queue, _queue.Queue], + worker_processes: _Optional[_List[_mp.Process]], + ): + cls._cluster_weights_worker(work_q, results_q) + + @classmethod + def _join_worker_processes(cls, worker_processes: _Optional[_List[_mp.Process]]): + return diff --git a/coremltools/optimize/torch/_utils/math_utils.py b/coremltools/optimize/torch/_utils/math_utils.py index 38038ac50..d6d80fe28 100644 --- a/coremltools/optimize/torch/_utils/math_utils.py +++ b/coremltools/optimize/torch/_utils/math_utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause diff --git a/coremltools/optimize/torch/_utils/metadata_utils.py b/coremltools/optimize/torch/_utils/metadata_utils.py new file mode 100644 index 000000000..e28c63af4 --- /dev/null +++ b/coremltools/optimize/torch/_utils/metadata_utils.py @@ -0,0 +1,138 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +from enum import Enum +from typing import Dict as _Dict +from typing import List as _List +from typing import Optional as _Optional + +import torch as _torch +from attr import define as _define +from attr import field as _field +from attrs import validators as _validators + +from coremltools.optimize.torch._utils.python_utils import DictableDataClass as _DictableDataClass + +STATE_DICT_METADATA_BUFFER_PREFIX = "_COREML_" +BUFFER_NAME_SEPARATOR = "/" +METADATA_VERSION_BUFFER = ( + STATE_DICT_METADATA_BUFFER_PREFIX + BUFFER_NAME_SEPARATOR + "metadata_version" +) +METADATA_VERSION = _torch.tensor(1) + + +class CompressionType(Enum): + pruning = 1 + palettization = 2 + quantization = 3 + + def __str__(self): + return self.name + + +@_define +class CompressionMetadata(_DictableDataClass): + """ + Class to encapsulate and register (store as buffer in state_dict) compression metadata per parameter within a module. + + Args: + param_name (:obj:`str`): Name of parameter corresponding to which metadata is stored. + quantization_n_bits (:obj:`int`): The dtype to use for quantizing the weights. + quantization_scale (:py:class:`torch.Tensor`): Quantization parameters used for scaling weights. + zero_point (:py:class:`torch.Tensor`): Quantization parameters used for translating weights in affine + or unsigned symmetric quantization. + lut (:py:class:`torch.Tensor`): Look up table for palettized weights. + palettization_scale (:py:class:`torch.Tensor`): Per channel scales used to normalize weights before being palettized. + compression_type (:obj:`list` of :py:class:`CompressionType`): List of compression types applied to the parameter + in the order in which they were applied. + """ + + param_name: str = _field(validator=_validators.optional(_validators.instance_of(str))) + quantization_n_bits: _Optional[int] = _field( + default=None, validator=_validators.optional(_validators.instance_of(int)) + ) + quantization_scale: _Optional[_torch.Tensor] = _field( + default=None, + validator=_validators.optional(_validators.instance_of(_torch.Tensor)), + ) + zero_point: _Optional[_torch.Tensor] = _field( + default=None, + validator=_validators.optional(_validators.instance_of(_torch.Tensor)), + ) + lut: _Optional[_torch.Tensor] = _field( + default=None, + validator=_validators.optional(_validators.instance_of(_torch.Tensor)), + ) + palettization_scale: _Optional[_torch.Tensor] = _field( + default=None, + validator=_validators.optional(_validators.instance_of(_torch.Tensor)), + ) + compression_type: _Optional[_List[str]] = _field( + default=None, + converter=lambda lst: [CompressionType[item].value for item in lst] if lst else None, + validator=_validators.optional( + _validators.deep_iterable( + member_validator=_validators.instance_of(int), + iterable_validator=_validators.instance_of(list), + ) + ), + ) + + def register(self, module: _torch.nn.Module): + """ + Register compression metadata as buffers in module's state_dict + """ + for metadata, value in self.as_dict().items(): + if metadata == "param_name" or value is None: + continue + buffer_name = self._get_metadata_buffer_name(metadata) + + # Handle chaining of compression types + if metadata == "compression_type": + try: + current_value = module.get_buffer(buffer_name) + value = current_value.tolist() + value + except AttributeError: + # Previous value doesn't exist + pass + + # Wrap value as a tensor to register as a buffer in module state_dict + if not _torch.is_tensor(value): + value = _torch.tensor(value) + + module.register_buffer(buffer_name, value) + + def _get_metadata_buffer_name(self, metadata_key: str) -> str: + return BUFFER_NAME_SEPARATOR.join( + [STATE_DICT_METADATA_BUFFER_PREFIX, self.param_name, metadata_key] + ) + + @classmethod + def from_state_dict(cls, prefixed_dict) -> _Dict[str, "CompressionMetadata"]: + """ + Initialize per parameter CompressionMetadata from state_dict + """ + param_to_metadata_dict = dict() + for key, value in prefixed_dict.items(): + if key.startswith(STATE_DICT_METADATA_BUFFER_PREFIX) and key != METADATA_VERSION_BUFFER: + prefix, param_name, metadata = key.split(BUFFER_NAME_SEPARATOR) + if param_name not in param_to_metadata_dict: + param_to_metadata_dict[param_name] = {"param_name": param_name} + # For compression type, convert tensor to list of strings + if metadata == "compression_type": + value = [str(CompressionType(x)) for x in value.tolist()] + param_to_metadata_dict[param_name][metadata] = value + + result = { + pname: cls.from_dict(metadata) for pname, metadata in param_to_metadata_dict.items() + } + return result + + +def register_metadata_version(model: _torch.nn.Module): + """ + Register metadata version for the model + """ + model.register_buffer(METADATA_VERSION_BUFFER, METADATA_VERSION) diff --git a/coremltools/optimize/torch/_utils/python_utils.py b/coremltools/optimize/torch/_utils/python_utils.py index ba39c081f..5a40fabfb 100644 --- a/coremltools/optimize/torch/_utils/python_utils.py +++ b/coremltools/optimize/torch/_utils/python_utils.py @@ -5,7 +5,16 @@ import logging as _logging from collections import OrderedDict as _OrderedDict +from typing import IO as _IO from typing import Any as _Any +from typing import Dict as _Dict +from typing import Type as _Type +from typing import Union as _Union + +import cattrs as _cattrs +import torch as _torch +import yaml as _yaml +from attr import asdict as _asdict _logger = _logging.getLogger(__name__) @@ -55,3 +64,58 @@ class FunctionRegistryMixin(RegistryMixin): @classmethod def get_function(cls, name: str): return cls._get_object(name) + + +class DictableDataClass: + """ + Utility class that provides convertors to and from Python dict + """ + + @classmethod + def from_dict(cls, data_dict: _Dict[str, _Any]) -> "DictableDataClass": + """ + Create class from a dictionary of string keys and values. + + Args: + data_dict (:obj:`dict` of :obj:`str` and values): A nested dictionary of strings + and values. + """ + # Explicitly raise exception for unrecognized keys + cls._validate_dict(data_dict) + converter = _cattrs.Converter(forbid_extra_keys=True) + converter.register_structure_hook(_torch.Tensor, lambda obj, type: obj) + return converter.structure_attrs_fromdict(data_dict, cls) + + @classmethod + def from_yaml(cls, yml: _Union[_IO, str]) -> "DictableDataClass": + """ + Create class from a yaml stream. + + Args: + yml: An :py:class:`IO` stream containing yaml or a :obj:`str` + path to the yaml file. + """ + if isinstance(yml, str): + with open(yml, "r") as file: + dict_from_yml = _yaml.safe_load(file) + else: + dict_from_yml = _yaml.safe_load(yml) + if dict_from_yml is None: + dict_from_yml = {} + assert isinstance(dict_from_yml, dict), ( + "Invalid yaml received. yaml stream should return a dict " + f"on parsing. Received type: {type(dict_from_yml)}." + ) + return cls.from_dict(dict_from_yml) + + def as_dict(self) -> _Dict[str, _Any]: + """ + Returns the config as a dictionary. + """ + return _asdict(self) + + @classmethod + def _validate_dict(cls: _Type, config_dict: _Dict[str, _Any]): + for key, _ in config_dict.items(): + if not hasattr(cls, key): + raise ValueError(f"Found unrecognized key {key} in config_dict: {config_dict}.") diff --git a/coremltools/optimize/torch/_utils/registry.py b/coremltools/optimize/torch/_utils/registry.py new file mode 100644 index 000000000..6280ceece --- /dev/null +++ b/coremltools/optimize/torch/_utils/registry.py @@ -0,0 +1,109 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +from abc import ABC as _ABC + + +class BaseRegistry(_ABC): + """ + Base class for registries that register all subclasses automatically for ease-of-use. + """ + + # Maps from child class registry name to child class registry + registry_map = dict() + + def __init_subclass__(cls, *args, **kwargs): + # Adds mapping from child class registry name to empty child class registry + BaseRegistry.registry_map[cls.__name__] = dict() + + @classmethod + def instantiate(cls, subcls, *args, **kwargs): + """ + Instantiates a subclass entry in the registry of the provided class. + The registry is stored as a dictionary that maps from the subclass name + to a freshly created instance of the subclass. + + Args: + cls: The registry class, which is a subclass of BaseRegistry. + subcls: The subclass to be registered in the registry class. + args: The arguments to be used to create an instance of the subclass. + kwargs: The keyword arguments to be used to create an instance of the subclass. + + """ + + subcls_instance = subcls(*args, **kwargs) + cls.register(subcls_instance) + + @classmethod + def instantiate_key(cls, subcls_key, subcls, *args, **kwargs): + """ + Instantiates a subclass entry in the registry of the provided class. + The registry is stored as a dictionary that maps from the subclass key + to a freshly created instance of the subclass. + + Args: + cls: The registry class, which is a subclass of BaseRegistry. + subcls_key: The subclass key to be used for the registry entry. + subcls: The subclass to be registered in the registry class. + args: The arguments to be used to create an instance of the subclass. + kwargs: The keyword arguments to be used to create an instance of the subclass. + """ + + subcls_instance = subcls(*args, **kwargs) + cls.register_key(subcls_key, subcls_instance) + + @classmethod + def register(cls, subcls): + """ + Registers subclass instance in registry of provided class. + Uses the subclass name as the key for the registry entry. + + Args: + cls: The registry class, which is a subclass of BaseRegistry. + subcls: The subclass instance to register in the registry class. + """ + + registry = cls.get_registry() + # Syntax is needed because cannot look up __name__ from class instance + registry[subcls.__class__.__name__] = subcls + + @classmethod + def register_key(cls, subcls_key, subcls): + """ + Registers subclass instance in registry of provided class. + Uses the subclass key as the key for the registry entry. + + Args: + cls: The registry class, which is a subclass of BaseRegistry. + subcls_key: The subclass key to be used for the registry entry. + subcls: The subclass instance to register in the registry class. + """ + registry = cls.get_registry() + registry[subcls_key] = subcls + + @classmethod + def get_registry(cls): + """ + Looks up the registry corresponding to the provided registry class and + returns it. + + Args: + cls: The registry class, which is a subclass of BaseRegistry. + """ + + return BaseRegistry.registry_map[cls.__name__] + + @classmethod + def get_registry_values(cls): + """ + Looks up the registry corresponding to the provided registry class and + returns its values. This is useful for List/Set style registries with + keys generated automatically by this class. + + Args: + cls: The registry class, which is a subclass of BaseRegistry. + """ + + return BaseRegistry.registry_map[cls.__name__].values() diff --git a/coremltools/optimize/torch/_utils/report_utils.py b/coremltools/optimize/torch/_utils/report_utils.py new file mode 100644 index 000000000..97d5668d5 --- /dev/null +++ b/coremltools/optimize/torch/_utils/report_utils.py @@ -0,0 +1,102 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import logging as _logging +from typing import Tuple, Type + +import torch + +from coremltools.optimize.torch._utils.math_utils import rmse_error +from coremltools.optimize.torch._utils.metadata_utils import CompressionMetadata, CompressionType +from coremltools.optimize.torch.base_model_optimizer import _Report +from coremltools.optimize.torch.pruning._utils import ( + block2_sparsity, + structured_sparsity, + unstructured_sparsity, +) + +_logger = _logging.getLogger(__name__) + + +def _normalize_report(report: _Report) -> _Report: + """ + Normalizes the report by making sure all parameter reports have the same number + """ + all_keys = set() + for _, param_report in report.items(): + for key in param_report: + all_keys.add(key) + + for _, param_report in report.items(): + for key in all_keys: + if key not in param_report: + param_report[key] = -1 + return report + + +def compute_post_training_report( + uncompressed_model: torch.nn.Module, + compressed_model: torch.nn.Module, + supported_modules: Tuple[Type[torch.nn.Module]], +) -> _Report: + """ + Computes rmse between compressed and uncompressed parameters + """ + report = _Report() + for name, module in compressed_model.named_modules(): + if not isinstance(module, supported_modules): + continue + + compression_metadata = CompressionMetadata.from_state_dict(module.state_dict()) + + for param_name in compression_metadata: + module_summary = dict() + param_key = f"{name}.{param_name}" if name else param_name + + with torch.no_grad(): + compression_types = [ + CompressionType(x) for x in compression_metadata[param_name].compression_type + ] + + uncompressed_module = uncompressed_model.get_submodule(name) + compressed_param = module.get_parameter(param_name) + uncompressed_param = uncompressed_module.get_parameter(param_name) + + module_summary["error"] = rmse_error(compressed_param, uncompressed_param).item() + + module_summary["#params"] = int(torch.numel(compressed_param)) + + if CompressionType.pruning in compression_types: + sparse_summary = { + "structured_weight_sparsity": structured_sparsity(compressed_param), + "unstructured_weight_sparsity": unstructured_sparsity(compressed_param), + } + + if compressed_param.size(0) % 2 == 0: + sparse_summary["block2_weight_sparsity"] = block2_sparsity(compressed_param) + else: + sparse_summary["block2_weight_sparsity"] = -1 # Not applicable + + module_summary.update(sparse_summary) + + if CompressionType.quantization in compression_types: + quantization_n_bits = compression_metadata[param_name].quantization_n_bits + # FIXME: add sign of dtype here + module_summary["dtype"] = f"dtype=int{quantization_n_bits}" + + if CompressionType.palettization in compression_types: + lut_shape = compression_metadata[param_name].lut.shape + + n_clusters = lut_shape[-2] + cluster_dim = lut_shape[-1] + + module_summary[ + "palettization_mode" + ] = f"num_clusters={n_clusters}, cluster_dim={cluster_dim}" + + report[param_key] = module_summary + + report = _normalize_report(report) + return report diff --git a/coremltools/optimize/torch/_utils/state_dict_utils.py b/coremltools/optimize/torch/_utils/state_dict_utils.py index 08ec1d3c0..31e5b00b7 100644 --- a/coremltools/optimize/torch/_utils/state_dict_utils.py +++ b/coremltools/optimize/torch/_utils/state_dict_utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause diff --git a/coremltools/optimize/torch/_utils/torch_utils.py b/coremltools/optimize/torch/_utils/torch_utils.py index 54acf2575..bc6b33460 100644 --- a/coremltools/optimize/torch/_utils/torch_utils.py +++ b/coremltools/optimize/torch/_utils/torch_utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause @@ -7,12 +7,17 @@ import operator as _operator import re as _re from contextlib import contextmanager +from distutils.version import StrictVersion as _StrictVersion +from typing import Any as _Any +from typing import Dict as _Dict from typing import List as _List from typing import Tuple as _Tuple +from typing import Type as _Type from typing import Union as _Union import numpy as _np import torch as _torch +import torch.nn as _nn _logger = _logging.getLogger(__name__) @@ -45,11 +50,54 @@ def list_or_str_to_tensor(alist: _Union[_List[int], str, _torch.Tensor]) -> _tor ) +def _get_dtype_info(dtype: _torch.dtype): + if dtype.is_floating_point: + info_fn = _torch.finfo + else: + info_fn = _torch.iinfo + + return info_fn(dtype) + + +def get_n_bits_from_dtype(dtype: _Union[str, _torch.dtype]) -> int: + if type(dtype) is _torch.dtype: + dtype_info = _get_dtype_info(dtype) + return dtype_info.bits + elif type(dtype) is str: + return int(_re.search(r"\d+", dtype).group()) + else: + raise TypeError( + "dtype must either be a string or an instance of torch.dtype," f" not {type(dtype)}" + ) + + +def get_sign_from_dtype(dtype: _Union[str, _torch.dtype]) -> int: + if type(dtype) is _torch.dtype: + dtype_info = _get_dtype_info(dtype) + return dtype_info.min < 0 + elif type(dtype) is str: + return not dtype.startswith("u") + else: + raise TypeError( + "dtype must either be a string or an instance of torch.dtype," f" not {type(dtype)}" + ) + + def maybe_convert_str_to_dtype(dtype: _Union[str, _torch.dtype]) -> _torch.dtype: _str_to_dtype_map = { "quint8": _torch.quint8, "qint8": _torch.qint8, "float32": _torch.float32, + "int8": _torch.int8, + "uint8": _torch.uint8, + # Torch doesn't support int4 or int3 + # but we can represent it as int8 + "int4": _torch.int8, + "uint4": _torch.uint8, + "qint4": _torch.qint8, + "quint4": _torch.quint8, + "uint3": _torch.uint8, + "int3": _torch.int8, } if isinstance(dtype, str): dtype = dtype.lower() @@ -81,7 +129,7 @@ def maybe_convert_str_to_mod_type(mod_type: str): @contextmanager -def get_eval_model(model): +def get_eval_model(model: _nn.Module): train_flag = model.training try: yield model.eval() @@ -98,3 +146,63 @@ def get_parent_child_name(name: str) -> _Tuple[str, str]: return "", split[0] else: return split[0], split[1] + + +def get_fully_qualified_name(model: _torch.nn.Module, module: _torch.nn.Module) -> str: + """ + Returns fully qualified name for a module if it exists in the model. The fully qualified + name can be used to fetch the module using ``model.get_submodule``. + """ + for mod_name, mod in model.named_modules(remove_duplicate=True): + if mod == module: + return mod_name + raise ValueError(f"Module: {module} is not a submodule of {model}.") + + +def get_atomic_layers( + module: _nn.Module, layer_types: _List[_Type], name_prefix: str = "" +) -> _Dict[str, _nn.Module]: + """ + Returns a dictionary of layer_name: layer for every layer in the module which + matches the types specified in layers_to_find. + """ + if isinstance(module, tuple(layer_types)): + return {name_prefix: module} + result = {} + for name, child in module.named_children(): + result.update( + get_atomic_layers( + child, + layer_types=layer_types, + name_prefix=name_prefix + "." + name if name_prefix != "" else name, + ) + ) + + return result + + +def clone_tensor_object(obj: _Any): + """ + Clone a nested list, tuple or dict of tensors. + """ + if isinstance(obj, _torch.Tensor): + return obj.clone() + elif isinstance(obj, tuple): + return tuple(clone_tensor_object(item) for item in obj) + elif isinstance(obj, list): + return [clone_tensor_object(item) for item in obj] + elif isinstance(obj, dict): + return {key: clone_tensor_object(val) for key, val in obj.items()} + else: + raise ValueError(f"Cannot clone unrecognized object type: {obj}.") + + +def get_torch_version(version): + """ + returns torch version given a version string. Works for versions like + "2.1.1", "2.1.1+cpu", "2.1.1+rc" etc and would return 2.1.1 for these + cases + """ + version_regex = r"\d+\.\d+\.\d+" + version = _re.search(version_regex, str(version)).group(0) + return _StrictVersion(version) diff --git a/coremltools/optimize/torch/_utils/validation_utils.py b/coremltools/optimize/torch/_utils/validation_utils.py new file mode 100644 index 000000000..51eb06554 --- /dev/null +++ b/coremltools/optimize/torch/_utils/validation_utils.py @@ -0,0 +1,172 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import copy as _copy +import logging as _logging +from typing import List as _List +from typing import Optional as _Optional + +import torch as _torch + +from coremltools.optimize.torch.optimization_config import ( + ModuleOptimizationConfig as _ModuleOptimizationConfig, +) +from coremltools.optimize.torch.optimization_config import ( + PalettizationGranularity as _PalettizationGranularity, +) +from coremltools.optimize.torch.optimization_config import ( + QuantizationGranularity as _QuantizationGranularity, +) + +_logger = _logging.getLogger(__name__) + + +class ConfigValidator: + def __init__( + self, + param_name: str, + param: _torch.Tensor, + config: _Optional[_ModuleOptimizationConfig], + ): + self.param_name = param_name + self.param = param + self.config = _copy.deepcopy(config) + + def validate(self, checks_to_run: _List[str]) -> bool: + for check_name in checks_to_run: + check_method = getattr(self, f"sanitize_{check_name}", None) + assert check_method, f"Check {check_method} not found" + + result = check_method() + if not result: + return result + + return True + + def sanitize_quantization_block_size(self): + """ + Validates and updates block_size attribute in quantization config for specified parameter. + If compression should be skipped for param, returns False. + Else, returns True and updates config inplace. + """ + if self.config.granularity != _QuantizationGranularity.per_block: + return True + + if len(self.config.block_size) > self.param.ndim: + _logger.warning( + f"{self.param_name}: Length of block_size tuple {len(self.config.block_size)} " + f"should not exceed the number of dimensions in the parameter {self.param.ndim}" + ) + return False + + # Verify that for non input or output channel axis, block size is either zero or equal to axis length + for idx, bs in enumerate(self.config.block_size): + if idx > 1: + if bs != 0 and bs != self.param.shape[idx]: + _logger.warning( + f"{self.param_name}: Unsupported block_size={self.config.block_size}. " + "Blocking is currently only supported along input OR output channel axis." + ) + return False + + # Determine whether it is an N-D block or a integer block size + if len(self.config.block_size) >= 2: + bs_output = self.config.block_size[0] + bs_input = self.config.block_size[1] + else: + bs_output = None + bs_input = self.config.block_size[0] + + should_block_output = ( + bs_output > 0 and bs_output < self.param.shape[0] if bs_output else False + ) + should_block_input = bs_input > 0 and bs_input < self.param.shape[1] + + if should_block_input and not should_block_output: + # By default we will always have per-channel on output-channel axis + bs_output = 1 + should_block_output = True + + if not should_block_input and not should_block_output: + _logger.warning( + f"{self.param_name}: Valid block_size={self.config.block_size} not specified for any axis. " + "Use per_channel or per_tensor granularity if blocking is not required." + ) + return False + + # Check if the output-channel block size is divisible by the axis length + if should_block_output and self.param.shape[0] % bs_output != 0: + _logger.warning( + f"{self.param_name}: block_size={bs_output} is not divisible by axis length={self.param.shape[0]}" + ) + return False + + # Check if the input-channel block size is divisible by the axis length + if should_block_input and self.param.shape[1] % bs_input != 0: + _logger.warning( + f"{self.param_name}: block_size={bs_input} is not divisible by axis length={self.param.shape[0]}" + ) + return False + + self.config.block_size = (bs_output, bs_input) + return True + + def sanitize_palettization_group_size(self): + """ + Validates and updates block_size attribute in palettization config for specified parameter. + If compression should be skipped for param, returns False. + Else, returns True and updates config inplace. + """ + if self.config.granularity != _PalettizationGranularity.per_grouped_channel: + return True + + # If block size is not divisible by axis length skip palettizing this param + axis_length = self.param.shape[self.config.channel_axis] + if axis_length % self.config.group_size != 0: + _logger.warning( + f"{self.param_name}: group_size={self.config.group_size} is not divisible by axis length={axis_length}" + ) + return False + + return True + + def sanitize_palettization_cluster_dim(self): + """ + Validates and updates cluster_dim attribute in palettization config for specified parameter. + If compression should be skipped for param, returns False. + Else, returns True and updates config inplace. + """ + if self.config.cluster_dim is None: + self.config.cluster_dim = 1 + return True + + if self.config.cluster_dim > 1: + # By default, vectors are formed along the output channel axis. + # Hence, the size of remaining channels should be divisible by ``cluster_dim`` + dim_size = self.param.flatten(1).shape[1] + if dim_size % self.config.cluster_dim != 0: + _logger.warning( + f"{self.param_name}: The number of elements in non-output channels {dim_size} " + f"is not divisible by cluster_dim={self.config.cluster_dim}" + ) + return False + + return True + + +def validate_param_config( + param_name: str, + param: _torch.Tensor, + config: _Optional[_ModuleOptimizationConfig], + checks_to_run: _List[str], +): + validator = ConfigValidator(param_name, param, config) + is_valid_config = validator.validate(checks_to_run) + if not is_valid_config: + # Skip compression for this param if config is invalid + _logger.info(f"Skipping compression for {param_name}") + return None + + return validator.config diff --git a/coremltools/optimize/torch/_utils/version_utils.py b/coremltools/optimize/torch/_utils/version_utils.py index c8b4fde7d..e0908feb8 100644 --- a/coremltools/optimize/torch/_utils/version_utils.py +++ b/coremltools/optimize/torch/_utils/version_utils.py @@ -1,14 +1,14 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause import torch as _torch -from packaging.version import Version +from packaging import version def version_ge(module, target_version): - return Version(module.__version__) >= Version(target_version) + return version.parse(module.__version__) >= version.parse(target_version) def get_torch_version(): diff --git a/coremltools/optimize/torch/base_model_optimizer.py b/coremltools/optimize/torch/base_model_optimizer.py index 54a9298e2..74b92088e 100644 --- a/coremltools/optimize/torch/base_model_optimizer.py +++ b/coremltools/optimize/torch/base_model_optimizer.py @@ -1,14 +1,17 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause +import copy as _copy import logging as _logging from abc import ABC as _ABC from abc import abstractmethod as _abstractmethod from collections import UserDict as _UserDict +from typing import Iterable as _Iterable from typing import Optional as _Optional from typing import Tuple as _Tuple +from typing import Type as _Type import torch as _torch @@ -19,6 +22,9 @@ class _Report(_UserDict): + """ + A dictionary with pretty printing. + """ def __repr__(self): if len(self.data) < 1: return "" @@ -44,11 +50,36 @@ def __repr__(self): class BaseModelOptimizer(_ABC): - _supported_modules: _Tuple + """ + An abstract base class for implementing optimizers. + """ + _supported_modules: _Tuple[_Type[_torch.nn.Module]] def __init__(self, model: _torch.nn.Module, config: _Optional[_OptimizationConfig] = None): self._model = model self._config = config + + @_abstractmethod + def report(self) -> _Report: + raise NotImplementedError() + + @property + def supported_modules(self) -> _Tuple[_Type[_torch.nn.Module]]: + return self._supported_modules + + def _get_model_for_compression(self, inplace: bool): + return self._model if inplace else _copy.deepcopy(self._model) + + +class BaseTrainingTimeModelOptimizer(BaseModelOptimizer): + """ + An abstract base class for implementing optimization algorithms which + are integrated in model training pipelines. These optimizers simulate + model compression and learn compression parameters during model training. + """ + + def __init__(self, model: _torch.nn.Module, config: _Optional[_OptimizationConfig] = None): + super().__init__(model, config) self._step_count = 0 @_abstractmethod @@ -65,10 +96,49 @@ def finalize( ) -> _torch.nn.Module: raise NotImplementedError() - @_abstractmethod - def report(self) -> _Report: - raise NotImplementedError() - @property - def supported_modules(self): - return self._supported_modules +class BasePostTrainingModelOptimizer(BaseModelOptimizer): + """ + An abstract base class for implementing optimization algorithms which + perform zero-shot compression, after a model has been trained. These + optimizers do no need any data to perform compression. + """ + + def __init__(self, model: _torch.nn.Module, config: _Optional[_OptimizationConfig] = None): + super().__init__(model, config) + self._uncompressed_model = None + + def compress(self, *args, inplace: bool = False, **kwargs) -> _torch.nn.Module: + # if inplace is True: + # self._uncompressed_model -> deep copy of model passed by user + # self._model -> model passed by user + # if inplace is False: + # self._uncompressed_model -> model passed by user + # self._model -> deep copy of model passed by user + self._uncompressed_model = self._get_model_for_compression(inplace=not inplace) + self._model = self._get_model_for_compression(inplace=inplace) + return self._model + + +class BaseDataCalibratedModelOptimizer(BaseModelOptimizer): + """ + An abstract base class for optimization algorithms which use calibration data + to compress models. + """ + + def __init__(self, model: _torch.nn.Module, config: _Optional[_OptimizationConfig] = None): + super().__init__(model, config) + self._uncompressed_model = None + + def compress( + self, dataloader: _Iterable, *args, inplace: bool = False, **kwargs + ) -> _torch.nn.Module: + # if inplace is True: + # self._uncompressed_model -> deep copy of model passed by user + # self._model -> model passed by user + # if inplace is False: + # self._uncompressed_model -> model passed by user + # self._model -> deep copy of model passed by user + self._uncompressed_model = self._get_model_for_compression(inplace=not inplace) + self._model = self._get_model_for_compression(inplace=inplace) + return self._model diff --git a/coremltools/optimize/torch/layerwise_compression/__init__.py b/coremltools/optimize/torch/layerwise_compression/__init__.py new file mode 100644 index 000000000..b72a68337 --- /dev/null +++ b/coremltools/optimize/torch/layerwise_compression/__init__.py @@ -0,0 +1,98 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +""" +.. _coremltools_optimize_torch_layerwise_compression: + +_`LayerwiseCompressor` +================================== + +.. autoclass:: coremltools.optimize.torch.layerwise_compression.LayerwiseCompressorConfig + :members: from_dict, as_dict, from_yaml, get_layers + +.. autoclass:: coremltools.optimize.torch.layerwise_compression.LayerwiseCompressor + :members: compress + +Algorithms +========== + +:obj:`coremltools.optimize.torch.layerwise_compression.algorithms` submodule contains classes +that implement the algorithms to be used with :py:class:`LayerwiseCompressor`, +which can be used to compress LLM-based models + +GPTQ +---- + +.. autoclass:: coremltools.optimize.torch.layerwise_compression.algorithms.ModuleGPTQConfig + :show-inheritance: + +.. autoclass:: coremltools.optimize.torch.layerwise_compression.algorithms.GPTQ + :show-inheritance: + +SparseGPT +--------- + +.. autoclass:: coremltools.optimize.torch.layerwise_compression.algorithms.ModuleSparseGPTConfig + :show-inheritance: + +.. autoclass:: coremltools.optimize.torch.layerwise_compression.algorithms.SparseGPT + :show-inheritance: + + +Base class for layerwise compression algorithms config +------------------------------------------------------ + +.. autoclass:: coremltools.optimize.torch.layerwise_compression.LayerwiseCompressionAlgorithmConfig + :show-inheritance: + :no-members: + +Base class for layerwise compression algorithms +----------------------------------------------- + +.. autoclass:: coremltools.optimize.torch.layerwise_compression.LayerwiseCompressionAlgorithm + :show-inheritance: + :members: add_batch, cleanup, compress + +Input Cacher +============ + +:obj:`coremltools.optimize.torch.layerwise_compression.input_cacher` submodule contains classes +which provide a way of capturing the model's inputs up till the first module set up +to be compressed. + +FirstLayerInputCacher +--------------------- + +.. autoclass:: coremltools.optimize.torch.layerwise_compression.FirstLayerInputCacher + :show-inheritance: + :members: cache + +DefaultInputCacher +------------------ + +.. autoclass:: coremltools.optimize.torch.layerwise_compression.DefaultInputCacher + :show-inheritance: + :members: cache + +GPTFirstLayerInputCacher +------------------------ + +.. autoclass:: coremltools.optimize.torch.layerwise_compression.GPTFirstLayerInputCacher + :show-inheritance: + :members: cache + +""" + + +from .algorithms import ( + GPTQ, + LayerwiseCompressionAlgorithm, + LayerwiseCompressionAlgorithmConfig, + ModuleGPTQConfig, + ModuleSparseGPTConfig, + SparseGPT, +) +from .input_cacher import DefaultInputCacher, FirstLayerInputCacher, GPTFirstLayerInputCacher +from .layerwise_compressor import LayerwiseCompressor, LayerwiseCompressorConfig diff --git a/coremltools/optimize/torch/layerwise_compression/_quant.py b/coremltools/optimize/torch/layerwise_compression/_quant.py new file mode 100644 index 000000000..da282fd50 --- /dev/null +++ b/coremltools/optimize/torch/layerwise_compression/_quant.py @@ -0,0 +1,211 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +# Original implementation from https://github.com/IST-DASLab/sparsegpt +# Copyright 2023 IST Austria Distributed Algorithms and Systems Lab. All Rights Reserved. + +import torch as _torch + +_normal_float_palette = { + # The 4 bit numbers are copied from QLoRA paper: https://arxiv.org/abs/2305.14314 + 4: _torch.tensor( + [ + -1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0, + ] + ), + # The 3 bit numbers are obtained from bitsandbytes: https://github.com/TimDettmers/bitsandbytes/blob/18e827d666fa2b70a12d539ccedc17aa51b2c97c/bitsandbytes/functional.py#L236 + 3: _torch.tensor([-1.0, -0.4786292, -0.21714179, 0.0, 0.1609302, 0.33791524, 0.562617, 1.0]), +} + + +def quantize( + x: _torch.Tensor, + scale: _torch.Tensor, + zero: _torch.Tensor, + max_q: _torch.Tensor, + enable_normal_float: bool, +): + """ + Quantize ``x`` by rounding and clamping the value using specified + quantization parameters. + """ + n_bits = _torch.log2(max_q + 1).item() + if enable_normal_float: + if n_bits not in _normal_float_palette: + raise ValueError(f"Normal float format is not supported for {n_bits}.") + nf_palette = _normal_float_palette[n_bits] + nf_palette = nf_palette.to(x.device) + distances = _torch.cdist((x / scale).view(-1, 1), nf_palette.unsqueeze(0).T) + indices = _torch.min(distances, dim=1).indices + return scale * nf_palette[indices].view(x.shape) + else: + q = _torch.clamp(_torch.round(x / scale) + zero, 0, max_q) + return scale * (q - zero) + + +class Quantizer(_torch.nn.Module): + """ + A module for quantizing tensors by scaling, shifting, rounding and clamping them such that the values + are represented in ``n_bits`` precision. + """ + + def __init__( + self, + n_bits: int, + per_channel: bool = True, + symmetric: bool = False, + enable_normal_float: bool = False, + mse: bool = False, + norm: float = 2.4, + grid: int = 100, + max_shrink: float = 0.8, + group_rows: int = 1, + ): + super().__init__() + self._per_channel = per_channel + self._symmetric = symmetric + self._enable_normal_float = enable_normal_float + self._mse = mse + self._norm = norm + self._grid = grid + self._max_shrink = max_shrink + self._group_rows = group_rows + self.register_buffer("max_q", _torch.tensor(2**n_bits - 1)) + self.register_buffer("scale", _torch.zeros(1)) + self.register_buffer("zero", _torch.zeros(1)) + + def find_params(self, x, weight=False): + """ + Compute quantization parameters. + """ + device = x.device + self.max_q = self.max_q.to(device) + + shape = x.shape + if self._per_channel: + if weight: + x = x.flatten(1) + if self._group_rows > 1: + x = x.reshape((x.shape[0] // self._group_rows, -1)) + else: + if len(shape) == 4: + x = x.permute([1, 0, 2, 3]) + x = x.flatten(1) + if len(shape) == 3: + x = x.reshape((-1, shape[-1])).t() + if len(shape) == 2: + x = x.t() + else: + x = x.flatten().unsqueeze(0) + + tmp = _torch.zeros(x.shape[0], device=device) + x_min = _torch.minimum(x.min(1)[0], tmp) + x_max = _torch.maximum(x.max(1)[0], tmp) + + if self._symmetric: + xmax = _torch.maximum(_torch.abs(x_min), x_max) + tmp = x_min < 0 + if _torch.any(tmp): + x_min[tmp] = -xmax[tmp] + tmp = (x_min == 0) & (x_max == 0) + x_min[tmp] = -1 + x_max[tmp] = +1 + + if self._enable_normal_float: + self.scale = _torch.maximum(x_max, abs(x_min)) + else: + self.scale = (x_max - x_min) / self.max_q + + if self._symmetric: + self.zero_point = _torch.full_like(self.scale, (self.max_q + 1) / 2) + else: + self.zero_point = _torch.round(-x_min / self.scale) + + if self._mse: + best = _torch.full([x.shape[0]], float("inf"), device=device) + for i in range(int(self._max_shrink * self._grid)): + p = 1 - i / self._grid + x_min1 = p * x_min + x_max1 = p * x_max + scale1 = (x_max1 - x_min1) / self.max_q + zero_point1 = ( + _torch.round(-x_min1 / scale1) if not self._symmetric else self.zero_point + ) + q = quantize( + x, + scale1.unsqueeze(1), + zero_point1.unsqueeze(1), + self.max_q, + self._enable_normal_float, + ) + q -= x + q.abs_() + q.pow_(self._norm) + err = _torch.sum(q, 1) + tmp = err < best + if _torch.any(tmp): + best[tmp] = err[tmp] + self.scale[tmp] = scale1[tmp] + self.zero_point[tmp] = zero_point1[tmp] + if not self._per_channel: + if weight: + tmp = shape[0] + else: + tmp = shape[1] if len(shape) != 3 else shape[2] + self.scale = self.scale.repeat(tmp) + self.zero_point = self.zero_point.repeat(tmp) + + if weight: + if self._group_rows > 1: + self.scale = self.scale.unsqueeze(1).repeat(1, self._group_rows) + self.zero_point = self.zero_point.unsqueeze(1).repeat(1, self._group_rows) + shape = [-1] + [1] * (len(shape) - 1) + self.scale = self.scale.reshape(shape) + self.zero_point = self.zero_point.reshape(shape) + return + if len(shape) == 4: + self.scale = self.scale.reshape((1, -1, 1, 1)) + self.zero_point = self.zero_point.reshape((1, -1, 1, 1)) + if len(shape) == 3: + self.scale = self.scale.reshape((1, 1, -1)) + self.zero_point = self.zero_point.reshape((1, 1, -1)) + if len(shape) == 2: + self.scale = self.scale.unsqueeze(0) + self.zero_point = self.zero_point.unsqueeze(0) + + def quantize(self, x): + """ + Quantize ``x`` using pre-computed quantization parameters. + """ + if self.ready(): + return quantize(x, self.scale, self.zero_point, self.max_q, self._enable_normal_float) + return x + + def enabled(self): + """ + Returns ``True`` if quantization is enabled. + """ + return self.max_q > 0 + + def ready(self): + """ + Returns ``True`` if quantization parameters have been computed. + """ + return _torch.all(self.scale != 0) diff --git a/coremltools/optimize/torch/layerwise_compression/algorithms.py b/coremltools/optimize/torch/layerwise_compression/algorithms.py new file mode 100644 index 000000000..68d62bc20 --- /dev/null +++ b/coremltools/optimize/torch/layerwise_compression/algorithms.py @@ -0,0 +1,686 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +# Original implementation from https://github.com/IST-DASLab/sparsegpt +# Copyright 2023 IST Austria Distributed Algorithms and Systems Lab. All Rights Reserved. + +import logging as _logging +import math as _math +import time as _time +from abc import ABC as _ABC +from abc import abstractmethod as _abstractmethod +from typing import Optional as _Optional +from typing import Tuple as _Tuple +from typing import Union as _Union + +import cattrs as _cattrs +import torch as _torch +import torch.nn as _nn +from attr import define as _define +from attr import field as _field +from attrs import validators as _validators + +from coremltools.optimize.torch._utils.metadata_utils import ( + CompressionMetadata as _CompressionMetadata, +) +from coremltools.optimize.torch._utils.python_utils import ClassRegistryMixin as _ClassRegistryMixin +from coremltools.optimize.torch._utils.torch_utils import ( + get_n_bits_from_dtype as _get_n_bits_from_dtype, +) +from coremltools.optimize.torch._utils.torch_utils import ( + maybe_convert_str_to_dtype as _maybe_convert_str_to_dtype, +) +from coremltools.optimize.torch.layerwise_compression._quant import Quantizer as _Quantizer +from coremltools.optimize.torch.layerwise_compression._quant import _normal_float_palette +from coremltools.optimize.torch.layerwise_compression._quant import quantize as _quantize +from coremltools.optimize.torch.optimization_config import ( + ModuleOptimizationConfig as _ModuleOptimizationConfig, +) +from coremltools.optimize.torch.optimization_config import QuantizationGranularity +from coremltools.optimize.torch.quantization.quantization_config import ( + QuantizationScheme as _QuantizationScheme, +) + +_logger = _logging.getLogger(__name__) + + +class LayerwiseCompressionAlgorithmConfig(_ABC, _ClassRegistryMixin, _ModuleOptimizationConfig): + """ + A template class and registry for configuration classes to be used + with :py:class:`LayerwiseCompressionAlgorithm`. + """ + + pass + +@LayerwiseCompressionAlgorithmConfig.register("gptq") +@_define +class ModuleGPTQConfig(LayerwiseCompressionAlgorithmConfig): + """ + Configuration class for specifying global and module level compression options for + `GPTQ `_ algorithm. + + Args: + weight_dtype (:py:class:`torch.dtype`): The dtype to use for quantizing the weights. The number of bits used + for quantization is inferred from the dtype. When dtype is set to :py:class:`torch.float32`, the weights + corresponding to that layer are not quantized. Defaults to :py:class:`torch.uint8` which corresponds to + 8-bit quantization. + granularity (:py:class:`QuantizationGranularity`): Specifies the granularity at which quantization parameters + will be computed. Can be one of ``per_channel``, ``per_tensor`` or ``per_block``. When using ``per_block``, + ``block_size`` argument must be specified. Defaults to ``per_channel``. + quantization_scheme: (:py:class:`~.coremltools.optimize.torch.quantization.quantization_config.QuantizationScheme`): Type of + quantization configuration to use. When this parameter is set to ``QuantizationScheme.symmetric``, all + weights are quantized with zero point as zero. When it is set to ``QuantizationScheme.affine``, zero point + can be set anywhere in the range of values allowed for the quantized weight. + Defaults to ``QuantizationScheme.symmetric``. + block_size (:obj:`int`): When ``block_size`` is specified, ``block_size`` + number of values will share the same quantization parameters of scale (and zero point if applicable) across + the input-channel axis. Defaults to ``None``. + enable_normal_float (:obj:`bool`): When ``True``, normal float format is used for quantization. It's + only supported when ``weight_dtype`` is equal to ``int3`` and ``int4``. Defaults to ``False``. + hessian_dampening: (:obj:`float`): Dampening factor added to the diagonal of the + Hessian used by GPTQ algorithm. Defaults to ``0.01``. + use_activation_order_heuristic (:obj:`bool`): When ``True``, columns of weight are sorted + in descending order of values of Hessian diagonal elements. Defaults to ``True``. + processing_group_size (:obj:`int`): The weights are updated in + blocks of size processing_group_size. Defaults to ``128``. + + .. note: + Currently blocking is limited to only the input-channel axis for GPTQ. + """ + + weight_dtype: _Union[str, _torch.dtype] = _field( + default="uint8", + ) + granularity: QuantizationGranularity = _field( + default="per_channel", + converter=QuantizationGranularity, + validator=_validators.in_(QuantizationGranularity), + ) + quantization_scheme: _QuantizationScheme = _field( + default="symmetric", + converter=_QuantizationScheme, + validator=_validators.in_(_QuantizationScheme), + ) + block_size: _Optional[int] = _field( + default=None, validator=_validators.optional(_validators.instance_of(int)) + ) + enable_normal_float: bool = _field(default=False, validator=_validators.instance_of(bool)) + hessian_dampening: float = _field(default=0.01, validator=_validators.instance_of(float)) + use_activation_order_heuristic: bool = _field( + default=False, validator=_validators.instance_of(bool) + ) + processing_group_size: int = _field(default=128, validator=_validators.instance_of(int)) + algorithm: str = _field(default="gptq", validator=_validators.in_("gptq")) + + def __attrs_post_init__(self): + self.weight_n_bits = _get_n_bits_from_dtype(self.weight_dtype) + self.weight_dtype = _maybe_convert_str_to_dtype(self.weight_dtype) + if self.weight_dtype not in [_torch.uint8, _torch.float32]: + raise ValueError( + f"weight_dtype must be one of (torch.uint8, torch.float32) not {self.weight_dtype}" + ) + + @classmethod + def from_dict(cls, config_dict): + converter = _cattrs.Converter(forbid_extra_keys=True) + converter.register_structure_hook( + _Union[str, _torch.dtype], + lambda obj, type: obj, + ) + return converter.structure_attrs_fromdict(config_dict, cls) + + +@LayerwiseCompressionAlgorithmConfig.register("sparse_gpt") +@_define +class ModuleSparseGPTConfig(LayerwiseCompressionAlgorithmConfig): + """ + Configuration class for specifying global and module level compression options for + `SparseGPT `_ algorithm. + + Args: + target_sparsity (:obj:`float`): Fraction of weight elements to set to ``0``. Defaults to + ``0.5``. + n_m_ratio (:obj:`tuple` of :obj:`int`): A tuple of two integers which specify how ``n:m`` pruning should be + applied. In ``n:m`` pruning, out of every ``m`` elements, ``n`` with lowest magnitude are set to + zero. When ``n_m_ratio`` is not ``None``, the value of ``target_sparsity`` is ignored and the actual + target sparsity is determined by the ``n:m`` ratio. + weight_dtype (:py:class:`torch.dtype`): The dtype to use for quantizing the weights. The number of bits used + for quantization is inferred from the dtype. When dtype is set to :py:class:`torch.float32`, the weights + corresponding to that layer are not quantized. Defaults to :py:class:`torch.float32` which corresponds to + no quantization. + quantization_granularity (:py:class:`QuantizationGranularity`): Specifies the granularity at which quantization parameters + will be computed. Can be one of ``per_channel``, ``per_tensor`` or ``per_block``. When using ``per_block``, + ``block_size`` argument must be specified. Defaults to ``per_channel``. + quantization_scheme: (:py:class:`~.coremltools.optimize.torch.quantization.quantization_config.QuantizationScheme`): Type of + quantization configuration to use. When this parameter is set to ``QuantizationScheme.symmetric``, all + weights are quantized with zero point as zero. When it is set to ``QuantizationScheme.affine``, zero point + can be set anywhere in the range of values allowed for the quantized weight. + Defaults to ``QuantizationScheme.symmetric``. + enable_normal_float (:obj:`bool`): When ``True``, normal float format is used for quantization. It's + only supported for ``weight_dtype`` is equal to ``int3`` and ``int4``. + hessian_dampening (:obj:`float`): Dampening factor added to the diagonal of the + Hessian used by GPTQ algorithm. Defaults to ``0.01``. + processing_group_size (:obj:`int`): The weights are updated in + blocks of size processing_group_size. Defaults to ``128``. + """ + + target_sparsity: float = _field(default=0.5, validator=_validators.instance_of(float)) + n_m_ratio: _Optional[_Tuple[int, int]] = _field( + default=None, + validator=_validators.optional( + _validators.deep_iterable( + member_validator=_validators.instance_of(int), + iterable_validator=_validators.instance_of((tuple, list)), + ) + ), + ) + weight_dtype: _Union[str, _torch.dtype] = _field( + default="uint8", + ) + quantization_granularity: QuantizationGranularity = _field( + default="per_channel", + converter=QuantizationGranularity, + validator=_validators.in_(QuantizationGranularity), + ) + quantization_scheme: _QuantizationScheme = _field( + default="symmetric", + converter=_QuantizationScheme, + validator=_validators.in_(_QuantizationScheme), + ) + + enable_normal_float: bool = _field(default=False, validator=_validators.instance_of(bool)) + hessian_dampening: float = _field(default=0.01, validator=_validators.instance_of(float)) + processing_group_size: int = _field(default=128, validator=_validators.instance_of(int)) + algorithm: str = _field(default="sparse_gpt", validator=_validators.in_("sparse_gpt")) + + def __attrs_post_init__(self): + self.weight_n_bits = _get_n_bits_from_dtype(self.weight_dtype) + self.weight_dtype = _maybe_convert_str_to_dtype(self.weight_dtype) + if self.weight_dtype not in [_torch.uint8, _torch.float32]: + raise ValueError( + f"weight_dtype must be one of (torch.uint8, torch.float32) not {self.weight_dtype}" + ) + + @classmethod + def from_dict(cls, config_dict): + converter = _cattrs.Converter(forbid_extra_keys=True) + converter.register_structure_hook( + _Union[str, _torch.dtype], + lambda obj, type: obj, + ) + return converter.structure_attrs_fromdict(config_dict, cls) + + +class LayerwiseCompressionAlgorithm(_ClassRegistryMixin): + """ + A template class for implementing layerwise compression algorithms + to be used with :py:class:`LayerwiseCompressor`. + """ + + @_abstractmethod + def add_batch(self, inp: _torch.Tensor, out: _torch.Tensor) -> None: + """ + Perform computation on a batch of data to acquire statistics before + compression. + """ + raise NotImplementedError("Method not implemented in base class.") + + @_abstractmethod + def cleanup(self) -> None: + """ + Reset the state of the compression algorithm object and free GPU memory. + """ + raise NotImplementedError("Method not implemented in base class.") + + @_abstractmethod + def compress(self) -> None: + """ + Compress the weights of the layer. + """ + raise NotImplementedError("Method not implemented in base class.") + + +class OBSCompressionAlgorithm(LayerwiseCompressionAlgorithm): + """ + A compression algorithm which uses the Hessian of the reconstruction loss + to compress a weight matrix of a given layer. Based on the + optimal brain surgeon paradigm described in `Optimal Brain Compression: + A Framework for Accurate Post-Training Quantization and Pruning + `_. + """ + + def __init__(self, layer: _nn.Module, config: LayerwiseCompressionAlgorithmConfig): + self._layer = layer + self._device = self._layer.weight.device + self._nsamples = 0 + self._config = config + weight = self._layer.weight.data + if isinstance(self._layer, _nn.Conv2d): + weight = weight.flatten(1) + self._dim = weight.dim() + self._rows = weight.shape[0] + self._columns = weight.shape[1] + self._hessian = _torch.zeros((self._columns, self._columns), device=self._device) + + @_abstractmethod + def _init_parameters(self, config: LayerwiseCompressionAlgorithmConfig): + """ + Initialize parameters of the algorithm from config. + """ + raise NotImplementedError("Method not implemented in base class.") + + def add_batch(self, inp: _torch.Tensor, out: _torch.Tensor): + self._compute_hessian(inp, out) + + def _compute_hessian(self, inp: _torch.Tensor, out: _torch.Tensor): + """ + Compute Hessian of the L2 loss between the original output + of the layer and the output computed using compressed weights. + """ + self._inp1 = inp + self._out1 = out + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + tmp = inp.shape[0] + if isinstance(self._layer, _nn.Linear): + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + if isinstance(self._layer, _nn.Conv2d): + unfold = _nn.Unfold( + self._layer.kernel_size, + dilation=self._layer.dilation, + padding=self._layer.padding, + stride=self._layer.stride, + ) + inp = unfold(inp) + inp = inp.permute([1, 0, 2]) + inp = inp.flatten(1) + self._hessian *= self._nsamples / (self._nsamples + tmp) + self._nsamples += tmp + inp = _math.sqrt(2 / self._nsamples) * inp.float() + self._hessian += inp.matmul(inp.t()) + + @_abstractmethod + def _compress_impl(self): + """ + Implementation of the compression algorithm + """ + raise NotImplementedError("Method not implemented in base class.") + + def compress(self): + self._compress_impl() + # NOTE: Currently algorithm assumes weight parameter is available for all layers + # and the only parameter that gets updated + metadata = self._get_compression_metadata("weight", self._layer.weight) + metadata.register(self._layer) + + def cleanup(self): + self._inp1 = None + self._out1 = None + self._nsamples = 0 + _torch.cuda.empty_cache() + self._hessian = None + + @_abstractmethod + def _get_compression_metadata(self, param_name, param): + raise NotImplementedError("Method not implemented in base class.") + + def _store_quantization_params(self): + if self._quantizer is not None: + scale = self._quantizer.scale + scale_store = _torch.empty_like(scale, device=_torch.device("cpu")).copy_(scale) + self._scale.append(scale_store) + if not self._enable_normal_float: + zero_point = self._quantizer.zero_point + zero_point_store = _torch.empty_like(zero_point, device=_torch.device("cpu")).copy_( + zero_point + ) + self._zero_point.append(zero_point_store) + + +@LayerwiseCompressionAlgorithm.register("gptq") +class GPTQ(OBSCompressionAlgorithm): + """ + A post training compression algorithm based on the paper + `GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers + `_. + + Args: + layer (:obj:`torch.nn.Module`): Module to be compressed. + config (:py:class:`ModuleGPTQConfig`): Config specifying hyper-parameters + for the GPTQ algorithm. + """ + + def __init__(self, layer: _nn.Module, config: ModuleGPTQConfig): + super().__init__(layer, config) + self._init_parameters(config) + + def _init_parameters(self, config: ModuleGPTQConfig): + # Defaults to blocking along input channel axis + self._block_size = config.block_size + if self._block_size is not None and self._columns % self._block_size != 0: + raise ValueError( + f"Block size must completely divide the axis along which blocking is done: {self._columns} % {self._block_size} != 0" + ) + self._weight_n_bits = config.weight_n_bits + self._processing_group_size = config.processing_group_size + self._enable_normal_float = config.enable_normal_float + self._hessian_dampening = config.hessian_dampening + self._use_activation_order_heuristic = config.use_activation_order_heuristic + self._quantizer = None + if self._weight_n_bits < 16: + per_channel = config.granularity in [ + QuantizationGranularity.per_channel, + QuantizationGranularity.per_block, + ] + self._quantizer = _Quantizer( + n_bits=self._weight_n_bits, + per_channel=per_channel, + symmetric=config.quantization_scheme == _QuantizationScheme.symmetric, + enable_normal_float=config.enable_normal_float, + ) + self._scale = [] + self._zero_point = [] + + def _compress_impl(self): + weight = self._layer.weight.data.clone() + if isinstance(self._layer, _nn.Conv2d): + weight = weight.flatten(1) + weight = weight.float() + + tick = _time.time() + + if not self._quantizer.ready(): + self._quantizer.find_params(weight, weight=True) + if self._block_size == None: + self._store_quantization_params() + + hessian = self._hessian + del self._hessian + dead = _torch.diag(hessian) == 0 + hessian[dead, dead] = 1 + weight[:, dead] = 0 + + perm = None + if self._use_activation_order_heuristic: + perm = _torch.argsort(_torch.diag(hessian), descending=True) + weight = weight[:, perm] + hessian = hessian[perm][:, perm] + + losses = _torch.zeros_like(weight) + quant_weight = _torch.zeros_like(weight) + + damp = self._hessian_dampening * _torch.mean(_torch.diag(hessian)) + diag = _torch.arange(self._columns, device=self._device) + hessian[diag, diag] += damp + hessian = _torch.linalg.cholesky(hessian) + hessian = _torch.cholesky_inverse(hessian) + hessian = _torch.linalg.cholesky(hessian, upper=True) + hessian_inverse = hessian + + for i1 in range(0, self._columns, self._processing_group_size): + i2 = min(i1 + self._processing_group_size, self._columns) + count = i2 - i1 + + weight_block = weight[:, i1:i2].clone() + quant_weight_block = _torch.zeros_like(weight_block) + error_block = _torch.zeros_like(weight_block) + losses_block = _torch.zeros_like(weight_block) + hessian_inverse_block = hessian_inverse[i1:i2, i1:i2] + + for i in range(count): + w = weight_block[:, i] + d = hessian_inverse_block[i, i] + + if self._block_size is not None: + if (i1 + i) % self._block_size == 0: + self._quantizer.find_params( + weight[:, (i1 + i) : (i1 + i + self._block_size)], + weight=True, + ) + self._store_quantization_params() + + q = _quantize( + w.unsqueeze(1), + self._quantizer.scale, + self._quantizer.zero_point, + self._quantizer.max_q, + self._enable_normal_float, + ).flatten() + quant_weight_block[:, i] = q + losses_block[:, i] = (w - q) ** 2 / d**2 + + err1 = (w - q) / d + weight_block[:, i:] -= err1.unsqueeze(1).matmul( + hessian_inverse_block[i, i:].unsqueeze(0) + ) + error_block[:, i] = err1 + + quant_weight[:, i1:i2] = quant_weight_block + losses[:, i1:i2] = losses_block / 2 + + weight[:, i2:] -= error_block.matmul(hessian_inverse[i1:i2, i2:]) + + if _torch.cuda.is_available(): + _torch.cuda.synchronize() + + _logger.info( + "time %.2f, weight quantization error %.2f" + % (_time.time() - tick, _torch.sum(losses).item()) + ) + + if self._use_activation_order_heuristic: + inverse_perm = _torch.argsort(perm) + quant_weight = quant_weight[:, inverse_perm] + + self._layer.weight.data = quant_weight.reshape(self._layer.weight.shape).to( + self._layer.weight.data.dtype + ) + _logger.debug( + "quantization error in output activations = %.2f" + % (_torch.sum((self._layer(self._inp1) - self._out1) ** 2)) + ) + + def _get_compression_metadata(self, param_name, param): + metadata = _CompressionMetadata(param_name) + + scale = _torch.cat(self._scale, dim=1) + if self._enable_normal_float: + metadata.compression_type = ["palettization"] + metadata.lut = _normal_float_palette[self._weight_n_bits].unsqueeze(-1) + for _ in range(param.dim()): + metadata.lut = metadata.lut.unsqueeze(0) + metadata.palettization_scale = scale + else: + metadata.compression_type = ["quantization"] + metadata.quantization_n_bits = self._weight_n_bits + metadata.quantization_scale = scale + metadata.zero_point = _torch.cat(self._zero_point, dim=1) + + return metadata + + +@LayerwiseCompressionAlgorithm.register("sparse_gpt") +class SparseGPT(OBSCompressionAlgorithm): + """ + A post training compression algorithm based on the paper + `SparseGPT: Massive Language Models Can be Accurately Pruned in One-Shot + `_ + + Args: + layer (:obj:`torch.nn.Module`): Module to be compressed. + config (:py:class:`ModuleSparseGPTConfig`): Config specifying hyper-parameters + for the SparseGPT algorithm. + """ + + def __init__(self, layer: _nn.Module, config: ModuleSparseGPTConfig): + super().__init__(layer, config) + self._init_parameters(config) + + def _init_parameters(self, config: ModuleSparseGPTConfig): + self._target_sparsity = config.target_sparsity + self._weight_n_bits = config.weight_n_bits + self._n_m_ratio = config.n_m_ratio + self._processing_group_size = config.processing_group_size + self._enable_normal_float = config.enable_normal_float + self._hessian_dampening = config.hessian_dampening + self._quantizer = None + if self._weight_n_bits < 16: + per_channel = config.quantization_granularity in [ + QuantizationGranularity.per_channel, + QuantizationGranularity.per_block, + ] + self._quantizer = _Quantizer( + n_bits=self._weight_n_bits, + per_channel=per_channel, + symmetric=config.quantization_scheme == _QuantizationScheme.symmetric, + enable_normal_float=config.enable_normal_float, + ) + self._scale = [] + self._zero_point = [] + if self._n_m_ratio is not None: + self._prune_n, self._prune_m = self._n_m_ratio + else: + self._prune_n, self._prune_m = 0, 0 + + def _compress_impl(self): + weight = self._layer.weight.data.clone() + if isinstance(self._layer, _nn.Conv2d): + weight = weight.flatten(1) + weight = weight.float() + + if self._quantizer is not None and not self._quantizer.ready(): + self._quantizer.find_params(weight, weight=True) + self._store_quantization_params() + + tick = _time.time() + + hessian = self._hessian + del self._hessian + dead = _torch.diag(hessian) == 0 + hessian[dead, dead] = 1 + weight[:, dead] = 0 + + losses = _torch.zeros(self._rows, device=self._device) + + damp = self._hessian_dampening * _torch.mean(_torch.diag(hessian)) + diag = _torch.arange(self._columns, device=self._device) + hessian[diag, diag] += damp + hessian = _torch.linalg.cholesky(hessian) + hessian = _torch.cholesky_inverse(hessian) + hessian = _torch.linalg.cholesky(hessian, upper=True) + hessian_inverse = hessian + + mask = None + + for i1 in range(0, self._columns, self._processing_group_size): + i2 = min(i1 + self._processing_group_size, self._columns) + count = i2 - i1 + + weight_block = weight[:, i1:i2].clone() + quant_weight_block = _torch.zeros_like(weight_block) + error_block = _torch.zeros_like(weight_block) + losses_block = _torch.zeros_like(weight_block) + hessian_inverse_block = hessian_inverse[i1:i2, i1:i2] + + if self._prune_n == 0: + if mask is not None: + mask1 = mask[:, i1:i2] + else: + tmp = ( + weight_block**2 + / (_torch.diag(hessian_inverse_block).reshape((1, -1))) ** 2 + ) + thresh = _torch.sort(tmp.flatten())[0][int(tmp.numel() * self._target_sparsity)] + mask1 = tmp <= thresh + else: + mask1 = _torch.zeros_like(weight_block) == 1 + + for i in range(count): + w = weight_block[:, i] + d = hessian_inverse_block[i, i] + + if self._prune_n != 0 and i % self._prune_m == 0: + tmp = ( + weight_block[:, i : (i + self._prune_m)] ** 2 + / ( + _torch.diag(hessian_inverse_block)[i : (i + self._prune_m)].reshape( + (1, -1) + ) + ) + ** 2 + ) + mask1.scatter_( + 1, + i + _torch.topk(tmp, self._prune_n, dim=1, largest=False)[1], + True, + ) + + q = w.clone() + q[mask1[:, i]] = 0 + + if self._quantizer is not None: + q = _quantize( + q.unsqueeze(1), + self._quantizer.scale, + self._quantizer.zero_point, + self._quantizer.max_q, + self._enable_normal_float, + ).flatten() + + quant_weight_block[:, i] = q + losses_block[:, i] = (w - q) ** 2 / d**2 + + err1 = (w - q) / d + weight_block[:, i:] -= err1.unsqueeze(1).matmul( + hessian_inverse_block[i, i:].unsqueeze(0) + ) + error_block[:, i] = err1 + + weight[:, i1:i2] = quant_weight_block + losses += _torch.sum(losses_block, 1) / 2 + + weight[:, i2:] -= error_block.matmul(hessian_inverse[i1:i2, i2:]) + + if _torch.cuda.is_available(): + _torch.cuda.synchronize() + + _logger.info( + "time %.2f, weight quantization error %.2f" + % (_time.time() - tick, _torch.sum(losses).item()) + ) + + self._layer.weight.data = weight.reshape(self._layer.weight.shape).to( + self._layer.weight.data.dtype + ) + _logger.debug( + "quantization error in output activations = %.2f" + % (_torch.sum((self._layer(self._inp1) - self._out1) ** 2)) + ) + + def _get_compression_metadata(self, param_name, param): + metadata = _CompressionMetadata(param_name) + compression_type = ["pruning"] + + if not self._quantizer: + metadata.compression_type = compression_type + return metadata + + scale = _torch.cat(self._scale, dim=1) + if self._enable_normal_float: + compression_type.append("palettization") + metadata.lut = _normal_float_palette[self._weight_n_bits].unsqueeze(-1) + for _ in range(param.dim()): + metadata.lut = metadata.lut.unsqueeze(0) + metadata.palettization_scale = scale + else: + compression_type.append("quantization") + metadata.quantization_n_bits = self._weight_n_bits + metadata.quantization_scale = scale + metadata.zero_point = _torch.cat(self._zero_point, dim=1) + + metadata.compression_type = compression_type + return metadata diff --git a/coremltools/optimize/torch/layerwise_compression/input_cacher.py b/coremltools/optimize/torch/layerwise_compression/input_cacher.py new file mode 100644 index 000000000..2e498b1e5 --- /dev/null +++ b/coremltools/optimize/torch/layerwise_compression/input_cacher.py @@ -0,0 +1,184 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +# Original implementation from https://github.com/IST-DASLab/sparsegpt +# Copyright 2023 IST Austria Distributed Algorithms and Systems Lab. All Rights Reserved. + +import logging as _logging +import re as _re +from abc import ABC as _ABC +from abc import abstractmethod as _abstractmethod +from typing import Dict as _Dict +from typing import Iterable as _Iterable +from typing import List as _List +from typing import Tuple as _Tuple +from typing import Union as _Union + +import torch as _torch +import torch.nn as _nn + +from coremltools.optimize.torch._utils.python_utils import ClassRegistryMixin as _ClassRegistryMixin + +_logger = _logging.getLogger(__name__) + + +class StopExecution(ValueError): + pass + + +class FirstLayerInputCacher(_ABC, _ClassRegistryMixin): + """ + A template class for getting the inputs to feed to the first layer of the model + which is set up for compression. + """ + + def __init__(self, model: _nn.Module, layers: str): + self._model = model + self._layers = layers + + @_abstractmethod + def cache( + self, dataloader: _Iterable, nsamples: int, device: str + ) -> _Tuple[_List[_torch.Tensor], _Dict[str, _torch.Tensor]]: + """ + Cache inputs and keyword arguments to be fed to first layer of the model + which is set up for compression. + + Args: + dataloader (:py:class:`Iterable`): An iterable where each element + is an input to the model to be compressed. + nsamples (:obj:`int`): Number of samples to cache. + device (:obj:`str`): Device string for device to run compression on. + """ + raise NotImplementedError("Method not implemented in base class.") + + +@FirstLayerInputCacher.register("gpt") +class GPTFirstLayerInputCacher(FirstLayerInputCacher): + """ + An implementation of :py:class:`FirstLayerInputCacher` for GPT style models. + Computes inputs to feed to the first layer of the model which is set up for compression. + + Args: + model (:obj:`torch.nn.Module`): Module to be compressed. + layers (:obj:`str`): Regex string for the decoder layers of the model. + """ + + def __init__( + self, + model: _nn.Module, + layers: _Union[str, _List], + ): + super().__init__(model, layers) + self._pre_layers = [] + self._first_layer = None + for layer_name, layer in model.named_modules(remove_duplicate=True): + if self._first_layer_match(layer_name, layer): + self._pre_layers.append(layer) + self._first_layer = layer + # break the first time there's a match + break + elif len(list(layer.children())) == 0: + self._pre_layers.append(layer) + if self._first_layer is None: + _logger.warning( + "Could not find first decoder layer based on", + f"decoder layer path {layers} regex", + ) + + def _first_layer_match(self, layer_name: str, layer: _torch.nn.Module) -> bool: + if isinstance(self._layers, str): + return _re.fullmatch(self._layers, layer_name) + elif isinstance(self._layers, list): + if isinstance(self._layers[0], str): + return _re.fullmatch(self._layers[0], layer_name) + else: + return layer == self._layers[0] + + def _feed_data(self, dataloader: _Iterable, nsamples: int, device: str): + """ + Feed data to the model so that the inputs to the first layer can be cached. + """ + num_sampled = 0 + for batch in dataloader: + try: + self._model(batch.to(device)) + except StopExecution: + pass + num_sampled += 1 + if num_sampled >= nsamples: + break + + @staticmethod + def _get_input_cacher_pre_hook(inputs, kwarg_inputs): + """ + Returns forward_pre_hook for caching inputs and keyword arguments + to the first decoder layer of a GPT model. + """ + + def input_cacher_pre_hook(module, args, kwargs): + inputs.append(args) + for key, val in kwargs.items(): + kwarg_inputs[key] = val + raise StopExecution() + + return input_cacher_pre_hook + + def cache( + self, dataloader: _Iterable, nsamples: int, device: str + ) -> _Tuple[_List[_torch.Tensor], _Dict[str, _torch.Tensor]]: + """ + Cache inputs and keyword arguments to be fed to the first decoder layer + of a GPT style model. + + Args: + dataloader (:py:class:`Iterable`): An iterable where each element + is an input to the model to be compressed. + nsamples (:obj:`int`): Number of samples to cache. + device (:obj:`str`): Device string for device to run compression on. + """ + for layer in self._pre_layers: + layer.to(device) + + inputs, kwarg_inputs = [], {} + input_cacher_handle = self._first_layer.register_forward_pre_hook( + self._get_input_cacher_pre_hook(inputs, kwarg_inputs), with_kwargs=True + ) + self._feed_data(dataloader, nsamples, device) + input_cacher_handle.remove() + + for layer in self._pre_layers: + layer.cpu() + + for key, val in kwarg_inputs.items(): + if isinstance(val, _torch.Tensor): + kwarg_inputs[key] = val.to(device) + + return inputs, kwarg_inputs + + +@FirstLayerInputCacher.register("default") +class DefaultInputCacher(FirstLayerInputCacher): + def cache( + self, dataloader: _Iterable, nsamples: int, device: str + ) -> _Tuple[_List[_torch.Tensor], _Dict[str, _torch.Tensor]]: + """ + Cache inputs and keyword arguments to be fed to first layer of the model + which is set up for compression. + + Args: + dataloader (:py:class:`Iterable`): An iterable where each element + is an input to the model to be compressed. + nsamples (:obj:`int`): Number of samples to cache. + device (:obj:`str`): Device string for device to run compression on. + """ + inputs = [] + sampled = 0 + for batch in dataloader: + inputs.append(batch.to(device)) + sampled += 1 + if sampled == nsamples: + break + return inputs, {} diff --git a/coremltools/optimize/torch/layerwise_compression/layerwise_compressor.py b/coremltools/optimize/torch/layerwise_compression/layerwise_compressor.py new file mode 100644 index 000000000..10e443ec9 --- /dev/null +++ b/coremltools/optimize/torch/layerwise_compression/layerwise_compressor.py @@ -0,0 +1,424 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +# Original implementation from https://github.com/IST-DASLab/sparsegpt +# Copyright 2023 IST Austria Distributed Algorithms and Systems Lab. All Rights Reserved. + +import logging as _logging +import re as _re +from collections import OrderedDict as _OrderedDict +from contextlib import contextmanager as _contextmanager +from typing import Any as _Any +from typing import Callable as _Callable +from typing import Dict as _Dict +from typing import Iterable as _Iterable +from typing import List as _List +from typing import NewType as _NewType +from typing import Optional as _Optional +from typing import Tuple as _Tuple +from typing import Union as _Union + +import cattrs as _cattrs +import torch as _torch +import torch.nn as _nn +from attr import define as _define +from attr import field as _field +from attrs import validators as _validators + +from coremltools.optimize.torch._utils.metadata_utils import ( + register_metadata_version as _register_metadata_version, +) +from coremltools.optimize.torch._utils.report_utils import ( + compute_post_training_report as _compute_post_training_report, +) +from coremltools.optimize.torch._utils.torch_utils import get_atomic_layers as _get_atomic_layers +from coremltools.optimize.torch._utils.torch_utils import get_eval_model as _get_eval_model +from coremltools.optimize.torch.base_model_optimizer import ( + BaseDataCalibratedModelOptimizer as _BaseDataCalibratedModelOptimizer, +) +from coremltools.optimize.torch.base_model_optimizer import _Report +from coremltools.optimize.torch.layerwise_compression.algorithms import ( + LayerwiseCompressionAlgorithm as _LayerwiseCompressionAlgorithm, +) +from coremltools.optimize.torch.layerwise_compression.algorithms import ( + LayerwiseCompressionAlgorithmConfig as _LayerwiseCompressionAlgorithmConfig, +) +from coremltools.optimize.torch.layerwise_compression.input_cacher import ( + FirstLayerInputCacher as _FirstLayerInputCacher, +) +from coremltools.optimize.torch.optimization_config import OptimizationConfig as _OptimizationConfig + +_logger = _logging.getLogger(__name__) + + +_ModuleTypeConfigType = _NewType( + "ModuleTypeConfigType", + _Dict[_Union[_Callable, str], _Optional[_LayerwiseCompressionAlgorithmConfig]], +) + + +_SUPPORTED_MODULES = [_torch.nn.Conv2d, _torch.nn.Linear] + + +@_define +class LayerwiseCompressorConfig(_OptimizationConfig): + """ + Configuration class for specifying how different submodules of a model are + compressed by :py:class:`LayerwiseCompressor`. + + Only sequential models are supported. + + Args: + layers (:obj:`list` of :py:class:`torch.nn.Module` or :obj:`str`): List of layers + which should be compressed. The layer names can also be specified as a regex. + The layers listed should be immediate child modules of the parent container + :py:class:`torch.nn.Sequential` model and they should be contiguous, i.e., + output of layer ``n`` should be input of layer ``n+1``. + global_config (:py:class:`ModuleGPTQConfig` or :py:class:`ModuleSparseGPTConfig`): Config to be applied globally + to all supported modules. Missing values are chosen from the default config. + module_type_configs (:obj:`dict` of :obj:`str` to :py:class:`ModuleGPTQConfig` or :py:class:`ModuleSparseGPTConfig`): + Module type configs applied to a specific + module class, such as :py:class:`torch.nn.Linear`. The keys can be either strings + or module classes. + module_name_configs (:obj:`dict` of :obj:`str` to :py:class:`ModuleGPTQConfig` or :py:class:`ModuleSparseGPTConfig`): + Module level configs applied to specific modules. + The name of the module must either be a regex or a fully qualified name that can be used + to fetch it from the top level module using the ``module.get_submodule(target)`` method. + input_cacher (:obj:`str` or :py:class:`FirstLayerInputCacher`): Cacher object + which caches inputs which are fed to the first layer which is set up for compression + calibration_nsamples (:obj:`int`): Number of samples to be used for calibration. + """ + + layers: _Optional[_Union[_List[_Union[_nn.Module, str]], _nn.ModuleList]] = _field( + default=None, + validator=_validators.optional( + _validators.deep_iterable( + member_validator=_validators.instance_of((_nn.Module, str)), + iterable_validator=_validators.instance_of((list, _nn.ModuleList)), + ) + ), + ) + global_config: _Optional[_LayerwiseCompressionAlgorithmConfig] = _field( + default=None, + validator=_validators.optional( + _validators.instance_of(_LayerwiseCompressionAlgorithmConfig) + ), + ) + module_type_configs: _ModuleTypeConfigType = _field( + factory=_OrderedDict, + validator=_validators.deep_mapping( + key_validator=_validators.instance_of((str, _Callable)), + value_validator=_validators.optional( + _validators.instance_of(_LayerwiseCompressionAlgorithmConfig) + ), + mapping_validator=_validators.instance_of(dict), + ), + ) + module_name_configs: _Dict[str, _Optional[_LayerwiseCompressionAlgorithmConfig]] = _field( + factory=_OrderedDict, + validator=_validators.deep_mapping( + key_validator=_validators.instance_of(str), + value_validator=_validators.optional( + _validators.instance_of(_LayerwiseCompressionAlgorithmConfig) + ), + mapping_validator=_validators.instance_of(dict), + ), + ) + input_cacher: str = _field(default="default", converter=_FirstLayerInputCacher.get_class) + calibration_nsamples: int = _field(default=128, validator=_validators.instance_of(int)) + + @classmethod + def from_dict(cls, config_dict: _Dict[str, _Any]) -> "LayerwiseCompressorConfig": + super().from_dict(config_dict) + converter = _cattrs.Converter(forbid_extra_keys=True) + converter.register_structure_hook( + _Optional[_Union[_List[_Union[_nn.Module, str]], _nn.ModuleList]], + lambda obj, type: obj, + ) + converter.register_structure_hook( + _LayerwiseCompressionAlgorithmConfig, + lambda obj, type: _LayerwiseCompressionAlgorithmConfig.get_class( + obj["algorithm"] + ).from_dict(obj), + ) + converter.register_structure_hook( + _ModuleTypeConfigType, + lambda module_type_config, type: { + key: _LayerwiseCompressionAlgorithmConfig.get_class(val["algorithm"]).from_dict(val) + if val is not None + else None + for key, val in module_type_config.items() + }, + ) + return converter.structure_attrs_fromdict(config_dict, cls) + + def get_layers(self, model: _nn.Module): + if self.layers is None: + for module_name, module in model.named_children(): + yield module_name, module + else: + yielded = set() + for module_name, module in model.named_modules(remove_duplicate=True): + for layer in self.layers: + if isinstance(layer, str) and _re.fullmatch(layer, module_name): + if module_name not in yielded: + yielded.add(module_name) + yield module_name, module + elif module == layer: + if module_name not in yielded: + yielded.add(module_name) + yield module_name, module + + +@_contextmanager +def _set_torch_flags(): + # TODO: Copied from original implementation; determine if this is necessary + cuda_matmul_tf32 = _torch.backends.cuda.matmul.allow_tf32 + cudnn_allow_tf32 = _torch.backends.cudnn.allow_tf32 + try: + _torch.backends.cuda.matmul.allow_tf32 = False + _torch.backends.cudnn.allow_tf32 = False + yield + finally: + _torch.backends.cuda.matmul.allow_tf32 = cuda_matmul_tf32 + _torch.backends.cudnn.allow_tf32 = cudnn_allow_tf32 + + +class LayerwiseCompressor(_BaseDataCalibratedModelOptimizer): + """ + A post training compression algorithm which compresses a sequential model layer by layer. + The implementation supports two variations of this algorithm: + + 1) `GPTQ `_ + 2) `SparseGPT `_ + + At a high level, it compresses weights of a model layer by layer, + by minimizing the L2 norm of the difference between the original activations and + activations obtained on compressing the weights of a layer. The activations + are computed using a few samples of training data. + + Only sequential models are supported, where output of one layer feeds into the + input of the next layer. + + For HuggingFace models, disable use_cache config. This is used to speed up decoding but, + to generalize forward pass for LayerwiseCompressor algorithms across all + model types we need to disable this behavior. + + Example: + + .. code-block:: python + + import torch.nn as nn + from coremltools.optimize.torch.layerwise_compression import ( + LayerwiseCompressor, + LayerwiseCompressorConfig, + ) + + model = nn.Sequential( + OrderedDict( + { + "conv": nn.Conv2d(1, 20, (3, 3)), + "relu1": nn.ReLU(), + "conv2": nn.Conv2d(20, 20, (3, 3)), + "relu2": nn.ReLU(), + } + ) + ) + + dataloder = load_calibration_data() + + # initialize the quantizer + config = LayerwiseCompressorConfig.from_dict( + { + "global_config": { + "algorithm": "gptq", + "weight_dtype": "int4", + }, + "input_cacher": "default", + "calibration_nsamples": 16, + } + ) + + compressor = LayerwiseCompressor(model, config) + + compressed_model = compressor.compress(dataloader) + + Args: + model (:obj:`torch.nn.Module`): Module to be compressed. + config (:py:class:`LayerwiseCompressorConfig`): Config that specifies how + different submodules in the model will be compressed. + """ + + _supported_modules: _Tuple = tuple(_SUPPORTED_MODULES) + + def __init__(self, model: _nn.Module, config: LayerwiseCompressorConfig): + super().__init__(model, config) + self._input_cacher = self._config.input_cacher( + self._model, + self._config.layers, + ) + + @staticmethod + def _forward_layer(layer, inputs, kwarg_inputs, outputs) -> _List: + """ + Perform forward pass on layer and store outputs. + """ + for j, inp in enumerate(inputs): + if isinstance(inp, _torch.Tensor): + inp = (inp,) + outputs[j] = layer(*inp, **kwarg_inputs) + return outputs + + def _get_cached_inputs( + self, dataloader: _Iterable, device: str + ) -> _Tuple[_List[_torch.Tensor], _Dict[str, _torch.Tensor]]: + """ + Cache the inputs and keyword arguments up till the first layer set up for compression + """ + inputs, kwarg_inputs = self._input_cacher.cache( + dataloader=dataloader, + nsamples=self._config.calibration_nsamples, + device=device, + ) + return inputs, kwarg_inputs + + def _get_layers_to_compress(self) -> _Dict[str, _nn.Module]: + """ + Returns a list of layers to be compressed + """ + return self._config.get_layers(self._model) + + def _init_and_config_layer( + self, atomic_layer_name, atomic_layer + ) -> _Optional[_LayerwiseCompressionAlgorithm]: + """ + Initializes and configures the compression algorithm for a given + atomic layer. Returns the initialized and configured compression + algorithm object + """ + layer_config = self._config.get_module_config(atomic_layer_name, atomic_layer) + if layer_config is not None: + algo_class = _LayerwiseCompressionAlgorithm.get_class(layer_config.algorithm) + return algo_class(atomic_layer, layer_config) + return None + + def _register_activation_processing_hook( + self, atomic_layer, compressor_obj + ) -> _torch.utils.hooks.RemovableHandle: + """ + Registers a forward hook on the layer for performing computation + using the inputs to acquire statistics. Returns the handle for + the forward hook + """ + + def activation_processing_hook(_, inp, out): + compressor_obj.add_batch(inp[0].data, out.data) + + return atomic_layer.register_forward_hook(activation_processing_hook) + + @_torch.no_grad() + def _compress_impl(self, dataloader: _Iterable, device: str) -> _nn.Module: + """ + Compresses a model layerwise using the following steps: + 1) Compute inputs to the first layer which is set up for compression using input cacher + 2) For each layer, find submodules which are supported for compression and install compression + hooks. + 3) Run forward pass through each layer, compute activation statistics and use them to + compress weights. + 4) Compute updated outputs using compressed weights to propagate quantization error + to the next layer and set them up as inputs to next layer. + """ + inputs, kwarg_inputs = self._get_cached_inputs(dataloader, device) + outputs = [None for _ in inputs] + + # compress the layers one by one + for layer_idx, (parent_layer_name, layer) in enumerate(self._get_layers_to_compress()): + layer.to(device) + atomic_layers_dict = _get_atomic_layers( + layer, + layer_types=self._supported_modules, + name_prefix=parent_layer_name, + ) + + # dict mapping layer_name -> compression algorithm object + compression_algo_objects_dict = dict() + + # dict mapping layer_name -> forward hook handle + layer_hooks = [] + + for atomic_layer_name, atomic_layer in atomic_layers_dict.items(): + obj = self._init_and_config_layer(atomic_layer_name, atomic_layer) + + if obj is not None: + compression_algo_objects_dict[atomic_layer_name] = obj + + layer_hooks.append(self._register_activation_processing_hook(atomic_layer, obj)) + + # Compute statistics on the activations using the activation processing hooks + outputs = self._forward_layer( + layer, + inputs, + kwarg_inputs, + outputs, + ) + + # Remove the activation processing hooks + for h in layer_hooks: + h.remove() + + # compress the layers + _logger.info(f"Layer {layer_idx}") + for ( + atomic_layer_name, + compressor_algo, + ) in compression_algo_objects_dict.items(): + _logger.info(f"Compressing {atomic_layer_name}") + compressor_algo.compress() + compressor_algo.cleanup() + + del compression_algo_objects_dict + + # feed the previous layer's outputs to this layer + outputs = self._forward_layer( + layer, + inputs, + kwarg_inputs, + outputs, + ) + + # free memory + layer.cpu() + del layer + _torch.cuda.empty_cache() + + # interchange inputs and outputs + inputs, outputs = outputs, inputs + + _register_metadata_version(self._model) + return self._model + + def compress(self, dataloader: _Iterable, device: str, inplace: bool = False) -> _nn.Module: + """ + Compresses model using samples from ``dataloader``. + + Args: + dataloader (:py:class:`Iterable`): An iterable where each element + is an input to the model to be compressed. + device (:obj:`str`): Device string for device to run compression on. + inplace (:obj:`bool`): If ``True``, model transformations are carried out in-place and + the original module is mutated, otherwise a copy of the model is mutated and returned. + Defaults to ``False``. + """ + self._model = super().compress(dataloader=dataloader, inplace=inplace) + with _get_eval_model(self._model): + with _set_torch_flags(): + return self._compress_impl(dataloader, device) + + def report(self) -> _Report: + return _compute_post_training_report( + self._uncompressed_model, + self._model, + supported_modules=self._supported_modules, + ) diff --git a/coremltools/optimize/torch/optimization_config.py b/coremltools/optimize/torch/optimization_config.py index 3b3124e05..081d95fc8 100644 --- a/coremltools/optimize/torch/optimization_config.py +++ b/coremltools/optimize/torch/optimization_config.py @@ -1,62 +1,50 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause +import re as _re from collections import OrderedDict as _OrderedDict -from typing import IO as _IO +from enum import Enum as _Enum from typing import Any as _Any from typing import Callable as _Callable from typing import Dict as _Dict from typing import List as _List from typing import Optional as _Optional -from typing import Type as _Type from typing import Union as _Union -import cattrs as _cattrs import torch as _torch -import yaml as _yaml from attr import Factory as _Factory -from attr import asdict as _asdict from attr import define as _define +from attrs import field as _field +from coremltools.optimize.torch._utils.python_utils import DictableDataClass as _DictableDataClass -@_define -class ModuleOptimizationConfig: - @classmethod - def from_dict(cls, config_dict: _Dict[str, _Any]) -> "ModuleOptimizationConfig": - """ - Create class from a dictionary of string keys and values. - Args: - config_dict (:obj:`dict` of :obj:`str` and values): A nested dictionary of strings - and values. - """ - # passing forbid_extra_keys=True doesn't prevent silent failure when keys are mis-spelled - _validate_config_dict(cls, config_dict) - converter = _cattrs.Converter(forbid_extra_keys=True) - return converter.structure_attrs_fromdict(config_dict, cls) +class QuantizationGranularity(_Enum): + """ + Enum to denote granularity at which different compression schemes are applied. + See specific algorithm for more details. + """ + per_tensor = "per_tensor" + per_channel = "per_channel" + per_block = "per_block" - @classmethod - def from_yaml(cls, yml: _Union[_IO, str]) -> "ModuleOptimizationConfig": - """ - Create class from a yaml stream. +class PalettizationGranularity(_Enum): + """ + Enum to denote granularity at which different compression schemes are applied. + See specific algorithm for more details. + """ - Args: - yml: An :py:class:`IO` stream containing yaml or a :obj:`str` - path to the yaml file. - """ - return _from_yaml(cls, yml) + per_tensor = "per_tensor" + per_grouped_channel = "per_grouped_channel" - def as_dict(self) -> _Dict[str, _Any]: - """ - Returns the config as a dictionary. - """ - return _asdict(self) +class ModuleOptimizationConfig(_DictableDataClass): + pass @_define -class OptimizationConfig: +class OptimizationConfig(_DictableDataClass): global_config: _Optional[ModuleOptimizationConfig] = None module_type_configs: _Dict[ _Union[_Callable, str], _Optional[ModuleOptimizationConfig] @@ -95,9 +83,10 @@ def set_module_name( def get_module_config( self, name: str, module: _torch.nn.Module ) -> _Optional[ModuleOptimizationConfig]: - if name in self.module_name_configs: - return self.module_name_configs[name] - elif type(module) in self.module_type_configs: + for mod_name in self.module_name_configs: + if _re.fullmatch(mod_name, name): + return self.module_name_configs[mod_name] + if type(module) in self.module_type_configs: return self.module_type_configs[type(module)] elif module.__class__.__name__ in self.module_type_configs: return self.module_type_configs[module.__class__.__name__] @@ -114,26 +103,9 @@ def from_dict(cls, config_dict: _Dict[str, _Any]) -> _Optional["OptimizationConf and values. """ # passing forbid_extra_keys=True doesn't prevent silent failure when keys are mis-spelled - _validate_config_dict(cls, config_dict) + cls._validate_dict(config_dict) return - @classmethod - def from_yaml(cls, yml: _Union[_IO, str]) -> "OptimizationConfig": - """ - Create class from a yaml stream. - - Args: - yml: An :py:class:`IO` stream containing yaml or a :obj:`str` - path to the yaml file. - """ - return _from_yaml(cls, yml) - - def as_dict(self) -> _Dict[str, _Any]: - """ - Returns the config as a dictionary. - """ - return _asdict(self) - def _validate_same_params(self, param_names: _List[str]): """ This method validates that all the parameters in param_names @@ -199,27 +171,6 @@ def _structure_from_dict_hook( return _structure_from_dict_hook -def _validate_config_dict(cls: _Type, config_dict: _Dict[str, _Any]): - for key, _ in config_dict.items(): - if not hasattr(cls, key): - raise ValueError(f"Found unrecognized key {key} in config_dict: {config_dict}.") - - -def _from_yaml( - cls: _Union[_Type[OptimizationConfig], _Type[ModuleOptimizationConfig]], yml: _Union[_IO, str] -): - if isinstance(yml, str): - with open(yml, "r") as file: - dict_from_yml = _yaml.safe_load(file) - else: - dict_from_yml = _yaml.safe_load(yml) - assert isinstance(dict_from_yml, dict), ( - "Invalid yaml received. yaml stream should return a dict " - f"on parsing. Received type: {type(dict_from_yml)}." - ) - return cls.from_dict(dict_from_yml) - - def _validate_module_type_keys_factory(supported_modules): supported_module_names = [cls.__name__ for cls in supported_modules] @@ -236,3 +187,11 @@ def validate_module_type_key(instance, attribute, value): ) return validate_module_type_key + + +def _deprecated_field(message="This field is deprecated"): + def validator(inst, attr, val): + if val is not None: + raise DeprecationWarning(message) + + return _field(default=None, validator=validator, on_setattr=validator) diff --git a/coremltools/optimize/torch/palettization/__init__.py b/coremltools/optimize/torch/palettization/__init__.py index 92a00a226..037fcebc1 100644 --- a/coremltools/optimize/torch/palettization/__init__.py +++ b/coremltools/optimize/torch/palettization/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause @@ -9,8 +9,8 @@ .. include:: palettization_desc.rst :end-line: 7 -_`Palettizer` -============= +_`DKMPalettizer` +================ Top level APIs -------------- @@ -31,8 +31,37 @@ .. autoclass:: coremltools.optimize.torch.palettization.FakePalettize :no-members: +_`SensitiveKMeans` +================== + +.. autoclass:: coremltools.optimize.torch.palettization.ModuleSKMPalettizerConfig + :members: from_dict, as_dict, from_yaml + +.. autoclass:: coremltools.optimize.torch.palettization.SKMPalettizerConfig + :members: set_global, set_module_type, set_module_name, from_dict, as_dict, from_yaml + +.. autoclass:: coremltools.optimize.torch.palettization.SKMPalettizer + :members: compute_sensitivity, compress + +_`PostTrainingPalettization` +============================ + +.. autoclass:: coremltools.optimize.torch.palettization.ModulePostTrainingPalettizerConfig + :members: from_dict, as_dict, from_yaml + +.. autoclass:: coremltools.optimize.torch.palettization.PostTrainingPalettizerConfig + :members: set_global, set_module_type, set_module_name, from_dict, as_dict, from_yaml + +.. autoclass:: coremltools.optimize.torch.palettization.PostTrainingPalettizer + :members: compress """ from .fake_palettize import FakePalettize from .palettization_config import DKMPalettizerConfig, ModuleDKMPalettizerConfig from .palettizer import DKMPalettizer +from .post_training_palettization import ( + ModulePostTrainingPalettizerConfig, + PostTrainingPalettizer, + PostTrainingPalettizerConfig, +) +from .sensitive_k_means import ModuleSKMPalettizerConfig, SKMPalettizer, SKMPalettizerConfig diff --git a/coremltools/optimize/torch/palettization/_custom_conversion.py b/coremltools/optimize/torch/palettization/_custom_conversion.py index 326796f56..426a69019 100644 --- a/coremltools/optimize/torch/palettization/_custom_conversion.py +++ b/coremltools/optimize/torch/palettization/_custom_conversion.py @@ -1,11 +1,16 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause +import torch as _torch import torch.nn as _nn import torch.nn.qat as _nnqat +from coremltools.optimize.torch._utils.metadata_utils import ( + CompressionMetadata as _CompressionMetadata, +) + from ._supported_modules import Conv1d, Embedding, LayerNorm, MultiheadAttention @@ -25,16 +30,38 @@ def do_attribute_assertions(cls, observed_module: _nn.Module): observed_module, "qconfig" ), f"Module {type(observed_module)} has no attribute qconfig" assert hasattr(observed_module, "activation_post_process"), ( - f"Module {type(observed_module)} has no " f"attribute activation_post_process " + f"Module {type(observed_module)} has no " f"attribute activation_post_process" ) assert hasattr(observed_module, "weight_fake_quant"), ( - f"Module {type(observed_module)} has no attribute " f"weight_fake_quant " + f"Module {type(observed_module)} has no attribute " f"weight_fake_quant" ) @classmethod def get_finalized_weights(cls, observed_module: _nn.Module): return observed_module.weight_fake_quant.forward(observed_module.weight.detach()) + @classmethod + def add_metadata(cls, observed_module: _nn.Module, return_module: _nn.Module): + for dir_key in dir(observed_module): + if "_fake_quant" in dir_key: + if not isinstance(getattr(observed_module, dir_key).centroids[0], _torch.Tensor): + break + param_name = dir_key.replace("_fake_quant", "") + compression_metadata = _CompressionMetadata(param_name) + compression_metadata.compression_type = ["palettization"] + lut = _torch.stack(getattr(observed_module, dir_key).centroids, dim=0) + for i in range(observed_module.weight.dim() + 2 - lut.dim()): + lut = lut.unsqueeze(-3) + compression_metadata.lut = lut + if getattr(observed_module, dir_key).enable_per_channel_scale: + per_channel_scaling_factor = getattr( + observed_module, dir_key + ).per_channel_scaling_factor + for _ in range(observed_module.weight.dim() - per_channel_scaling_factor.dim()): + per_channel_scaling_factor = per_channel_scaling_factor.unsqueeze(-1) + compression_metadata.palettization_scale = per_channel_scaling_factor + compression_metadata.register(return_module) + @classmethod def from_observed(cls, observed_module: _nn.Module): """ @@ -64,6 +91,7 @@ def from_observed(cls, observed_module: _nn.Module): dtype=observed_module.dtype if hasattr(observed_module, "dtype") else None, ) return_module.weight = _nn.Parameter(finalized_weights) + cls.add_metadata(observed_module, return_module) if observed_module.bias is not None: return_module.bias = _nn.Parameter(observed_module.bias.detach()) return_module.activation_post_process = observed_module.activation_post_process @@ -96,6 +124,7 @@ def from_observed(cls, observed_module: _nn.Module): dtype=observed_module.dtype if hasattr(observed_module, "dtype") else None, ) return_module.weight = _nn.Parameter(finalized_weights) + cls.add_metadata(observed_module, return_module) if observed_module.bias is not None: return_module.bias = _nn.Parameter(observed_module.bias.detach()) return_module.activation_post_process = observed_module.activation_post_process @@ -128,6 +157,7 @@ def from_observed(cls, observed_module: _nn.Module): dtype=observed_module.dtype if hasattr(observed_module, "dtype") else None, ) return_module.weight = _nn.Parameter(finalized_weights) + cls.add_metadata(observed_module, return_module) if observed_module.bias is not None: return_module.bias = _nn.Parameter(observed_module.bias.detach()) return_module.activation_post_process = observed_module.activation_post_process @@ -160,6 +190,7 @@ def from_observed(cls, observed_module: _nn.Module): dtype=observed_module.dtype if hasattr(observed_module, "dtype") else None, ) return_module.weight = _nn.Parameter(finalized_weights) + cls.add_metadata(observed_module, return_module) if observed_module.bias is not None: return_module.bias = _nn.Parameter(observed_module.bias.detach()) return_module.activation_post_process = observed_module.activation_post_process @@ -189,6 +220,7 @@ def from_observed(cls, observed_module: _nn.Module): return_module.weight = _nn.Parameter(finalized_weights) if observed_module.bias: return_module.bias = _nn.Parameter(observed_module.bias.detach()) + cls.add_metadata(observed_module, return_module) return_module.activation_post_process = observed_module.activation_post_process return return_module @@ -201,16 +233,46 @@ class MultiheadAttentionPalettizationConversion(PalettizationCustomConversionBas def __init__(self): super().__init__() + @classmethod + def do_attribute_assertions(cls, observed_module: _nn.Module): + assert hasattr( + observed_module, "qconfig" + ), f"Module {type(observed_module)} has no attribute qconfig" + assert hasattr(observed_module, "activation_post_process"), ( + f"Module {type(observed_module)} has no " f"attribute activation_post_process" + ) + + assert hasattr(observed_module.out_proj, "weight_fake_quant"), ( + f"Module {type(observed_module.out_proj)} has no attribute " f"q_proj_weight_fake_quant" + ) + if not observed_module._qkv_same_embed_dim: + assert hasattr(observed_module, "q_proj_weight_fake_quant"), ( + f"Module {type(observed_module)} has no attribute " f"q_proj_weight_fake_quant" + ) + assert hasattr(observed_module, "k_proj_weight_fake_quant"), ( + f"Module {type(observed_module)} has no attribute " f"k_proj_weight_fake_quant" + ) + assert hasattr(observed_module, "v_proj_weight_fake_quant"), ( + f"Module {type(observed_module)} has no attribute " f"v_proj_weight_fake_quant" + ) + else: + assert hasattr(observed_module, "in_proj_weight_fake_quant"), ( + f"Module {type(observed_module)} has no attribute " f"in_proj_weight_fake_quant" + ) + @classmethod def from_observed(cls, observed_module: _nn.Module): cls.do_attribute_assertions(observed_module) - finalized_weights = cls.get_finalized_weights(observed_module) + add_bias_kv = observed_module.bias_k is not None and observed_module.bias_v is not None + bias = ( + observed_module.out_proj.bias is not None and observed_module.in_proj_bias is not None + ) return_module = _nn.MultiheadAttention( embed_dim=observed_module.embed_dim, num_heads=observed_module.num_heads, dropout=observed_module.dropout, - bias=observed_module.bias is not None, - add_bias_kv=observed_module.add_bias_kv, + bias=bias, + add_bias_kv=add_bias_kv, add_zero_attn=observed_module.add_zero_attn, kdim=observed_module.kdim, vdim=observed_module.vdim, @@ -218,13 +280,40 @@ def from_observed(cls, observed_module: _nn.Module): device=observed_module.device if hasattr(observed_module, "device") else None, dtype=observed_module.dtype if hasattr(observed_module, "dtype") else None, ) - return_module.weight = _nn.Parameter(finalized_weights) - return_module.bias = _nn.Parameter(observed_module.bias.detach()) - if observed_module.add_bias_kv: + if not observed_module._qkv_same_embed_dim: + return_module.q_proj_weight = _nn.Parameter( + observed_module.q_proj_weight_fake_quant.forward( + observed_module.q_proj_weight.detach() + ) + ) + return_module.k_proj_weight = _nn.Parameter( + observed_module.k_proj_weight_fake_quant.forward( + observed_module.k_proj_weight.detach() + ) + ) + return_module.v_proj_weight = _nn.Parameter( + observed_module.v_proj_weight_fake_quant.forward( + observed_module.v_proj_weight.detach() + ) + ) + else: + return_module.in_proj_weight = _nn.Parameter( + observed_module.in_proj_weight_fake_quant.forward( + observed_module.in_proj_weight.detach() + ) + ) + return_module.out_proj.weight = _nn.Parameter( + observed_module.out_proj.weight_fake_quant.forward( + observed_module.out_proj.weight.detach() + ) + ) + if bias: + return_module.out_proj.bias = _nn.Parameter(observed_module.out_proj.bias.detach()) + return_module.in_proj_bias = _nn.Parameter(observed_module.in_proj_bias.detach()) + if add_bias_kv: return_module.bias_k = _nn.Parameter(observed_module.bias_k.detach()) return_module.bias_v = _nn.Parameter(observed_module.bias_v.detach()) - else: - return_module.bias_k = return_module.bias_v = None + cls.add_metadata(observed_module, return_module) return_module.activation_post_process = observed_module.activation_post_process return return_module @@ -254,6 +343,7 @@ def from_observed(cls, observed_module: _nn.Module): dtype=observed_module.dtype if hasattr(observed_module, "dtype") else None, ) return_module.weight = _nn.Parameter(finalized_weights) + cls.add_metadata(observed_module, return_module) return_module.activation_post_process = observed_module.activation_post_process return return_module diff --git a/coremltools/optimize/torch/palettization/_efficient_kmeans.py b/coremltools/optimize/torch/palettization/_efficient_kmeans.py index 21305b745..24518746c 100644 --- a/coremltools/optimize/torch/palettization/_efficient_kmeans.py +++ b/coremltools/optimize/torch/palettization/_efficient_kmeans.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause @@ -14,45 +14,20 @@ class _EfficientKMeans: implementation of k-means, called ``kmeans_pp`` which runs entirely on GPU and is ~10x faster than sklearn's API. """ - @staticmethod - def get_cluster_avg(n_clusters: int, indices, vals): - v_sum = ( - _torch.zeros([n_clusters] + list(vals[0].size()), dtype=vals.dtype) - .to(vals.device) - .index_add_(0, indices, vals) - ) - v_numel = ( - _torch.zeros(n_clusters, dtype=_torch.int) - .to(vals.device) - .index_add_(0, indices, _torch.ones(len(vals), dtype=_torch.int).to(vals.device)) - ) - v_avg = v_sum / v_numel.reshape(-1, 1) - - return v_avg - - @staticmethod - def x_c_dist(weights: _torch.Tensor, centroids: _torch.Tensor): - """ - Method to calculate distance between weights and centroids. - """ - return _torch.cdist(weights, centroids).square() - def __init__( self, - n_clusters: int, - init: str, - n_init: int = 0, + n_clusters, + init, + n_init=0, labels=None, - verbose: int = 0, - max_iter: int = 100, - tol: float = 0.0001, - error_bnd: int = 0, + max_iter=100, + tol=0.0001, + error_bnd=0.0, ): self.n_clusters = n_clusters self.n_init = n_init self.max_iter = max_iter self.tol = tol - self.verbose = verbose self.labels_ = labels self.inertia_ = None self.cluster_centers_ = init @@ -61,200 +36,156 @@ def __init__( assert self.max_iter > 0 assert self.n_clusters > 0 - def kmeans_pp(self, n_init: str, X: _torch.Tensor, random_state: int, offset: int = 0): - """ - In-house implementation of kmeans that runs entirely on GPU and is ~10x faster. - """ - assert ( - len(X) >= self.n_clusters - ), f"Weight fewer points than the number of clusters: {len(X)} vs. {self.n_clusters}" - - S = X[offset:] - - self.inertia_ = None - - width = (len(S) - 1) // (random_state + 1) - - for i in range(n_init): - idx = int(i / n_init * width) - C = S[idx].unsqueeze(0) - - for j in range(len(C), self.n_clusters): - min_error, labels = self.__class__.x_c_dist(S, C).min(dim=-1) - - while True: - max_dist_idx = _torch.argmax(min_error) - assert min_error[max_dist_idx] >= 0, "Cannot find a next candidate" - - candidate_C = S[max_dist_idx] - if candidate_C in set(C): - _dist[max_dist_idx] = -1 - else: - C = _torch.vstack((C, candidate_C)) - break + @staticmethod + def _get_cluster_avg(n_clusters, indices, vals) -> _torch.Tensor: + v_sum = ( + _torch.zeros([n_clusters] + list(vals[0].size())) + .to(vals.device) + .index_add_(0, indices, vals.float()) + ) + v_numel = ( + _torch.zeros(n_clusters, dtype=_torch.int) + .to(vals.device) + .index_add_(0, indices, _torch.ones(len(vals), dtype=_torch.int).to(vals.device)) + ) + v_numel[v_numel == 0] = 1 - if len(set(C)) != self.n_clusters: - return self.kmeans_pp(n_init, X, random_state, offset + 1) + v_avg = v_sum / v_numel.reshape(-1, 1) - min_error, labels = self.__class__.x_c_dist(X, C).min(dim=-1) - cur_cost = min_error.sum() + return v_avg.to(vals.dtype) - if self.inertia_ is None or self.inertia_ > cur_cost: - self.inertia_ = cur_cost - self.cluster_centers_ = C - self.labels_ = labels + @staticmethod + def x_c_dist(params, clusters) -> _torch.Tensor: + """ + Method to calculate the distance between weights and clusters. + """ + clusters = clusters.contiguous() - def cost(self, i: int, j: int, new_cluster_cost: float): - if i > j: - cur_cost = 0 + if _torch.finfo(params.dtype).bits > _torch.finfo(clusters.dtype).bits: + return _torch.cdist(params.to(clusters.dtype), clusters).square() else: - size = j - i + 1 - sum_i_j = self.prefix_x[j] - (self.prefix_x[i - 1] if i >= 1 else 0) - sum2_i_j = self.prefix_x2[j] - (self.prefix_x2[i - 1] if i >= 1 else 0) - mean_i_j = sum_i_j / size - cc_i_j = -mean_i_j * mean_i_j * size + sum2_i_j + return _torch.cdist(params, clusters.to(params.dtype)).square() + + def _kmeans_pp(self, parameters): + assert len(parameters) >= self.n_clusters + + self.inertia_ = int(1e9) + + for n in range(self.n_init): + centroids = _torch.zeros( + (self.n_clusters, parameters.size(-1)), + device=parameters.device, + dtype=parameters.dtype, + ) + for i in range(self.n_clusters): + if i == 0: + centroids[i] = parameters[_torch.randint(0, len(parameters), [1])] + d_ij_curr = _torch.cdist(centroids[:i], parameters) + else: + d_ij_prev = _torch.cdist(centroids[i - 1 : i], parameters) + d_ij_prev[d_ij_prev == 0] = -int(1e9) - if cc_i_j < 0: - cc_i_j = 0 + d_ij_curr = _torch.cat((d_ij_curr, d_ij_prev), dim=0) - cur_cost = cc_i_j * (1 - self.tol) + new_cluster_cost * self.tol + c_to_x = _torch.min(d_ij_curr, dim=0) + centroids[i] = parameters[c_to_x[0].argmax()] - return cur_cost + for i in range(self.max_iter): + min_error, labels = _torch.cdist(parameters, centroids).min(dim=-1) + + # if W is None: + centroids.zero_() + centroids.scatter_add_( + 0, + labels.view(-1, 1).expand([-1, parameters.size(-1)]), + parameters, + ) + n_centroids = _torch.bincount(labels, minlength=self.n_clusters).view(-1, 1) - def backtrace(self, P, T, i, m): - if m >= 0: - P = [m] + P + centroids /= n_centroids + cur_inertia = min_error.square().sum() - if m == 0: - return P + if cur_inertia < self.inertia_: + exit = self.inertia_ <= cur_inertia * (1 + self.tol) + self.inertia_ = cur_inertia + self.labels_ = labels + self.cluster_centers_ = centroids + if exit: + break - return self.backtrace(P, T, i - 1, T[i - 1][m - 1]) + return self - def fit(self, X: _torch.Tensor): + def fit(self, X): """ Method to run kmeans operation. """ N = len(X) + + assert N >= self.n_clusters, f"too many clusters {self.n_clusters} for {N} samples" + if isinstance(self.cluster_centers_, str): if "kmeans++" in self.cluster_centers_: - if _dist.is_available() and _dist.is_initialized(): - world_size = _dist.get_world_size() rank = _dist.get_rank() else: - world_size = 1 rank = 0 if "cpu" in self.cluster_centers_: - import sklearn.cluster + import sklearn + + if "minibatch" in self.cluster_centers_: + clustering_method = sklearn.cluster.MiniBatchKMeans + else: + clustering_method = sklearn.cluster.KMeans - kmeans = sklearn.cluster.KMeans( - n_init=max(10, self.n_init // world_size), + kmeans = clustering_method( + n_init=self.n_init, n_clusters=self.n_clusters, max_iter=self.max_iter, random_state=rank + 1, - verbose=0, tol=self.tol, - ).fit(X.cpu().numpy()) + ).fit(X.float().cpu().numpy()) self.inertia_ = _torch.Tensor([kmeans.inertia_]).to(X.device) - self.labels_ = _torch.from_numpy(kmeans.labels_).to(_torch.int).to(X.device) + self.labels_ = _torch.from_numpy(kmeans.labels_).int().to(X.device) self.cluster_centers_ = None else: - self.kmeans_pp(self.n_init, X, rank + 1) + self._kmeans_pp(X.float()) - self.fit(X) - - bcast_rank = self.get_best_rank(self.inertia_, _torch.argmin) - if bcast_rank is not None: - _dist.broadcast(self.cluster_centers_, bcast_rank) - _dist.broadcast(self.labels_, bcast_rank) - - return self + self.cluster_centers_ = _EfficientKMeans._get_cluster_avg( + self.n_clusters, self.labels_, X + ) elif self.cluster_centers_ == "opt1d": - nX, sort_order = _torch.sort(X, dim=0) - nX = nX.cpu().numpy() - rN = range(N) - - self.prefix_x = _np.cumsum(nX) - self.prefix_x2 = _np.cumsum(_np.square(nX)) - - new_cluster_cost = 0 # 2 * self.cost(0, N - 1, 0) - - num_D = self.n_clusters if self.verbose >= 2 else 2 - - D = _np.full((num_D, N), _np.inf) - D[0] = [self.cost(0, m, new_cluster_cost) for m in rN] - T = _np.full((self.n_clusters, N), -1, dtype=int) - T[0] = [0 for m in rN] - - opt_t_cost = D[0][-1] - opt_n_clusters = 0 - for c in range(1, self.n_clusters): - if True: + from coremltools._deps import _kmeans1d - def lookup(m, j): - return -( - D[(c - 1) % num_D][min(j - 1, m)] - + self.cost(j, m, new_cluster_cost) - ) - - R = self.smawk(rN, rN, lookup) - - for k, v in R.items(): - D[c % num_D][k] = -lookup(k, v) - T[c][k] = v - else: - for m in range(1, N): - for j in range(m): - cur_cost = D[(c - 1) % num_D][j] + self.cost( - j + 1, m, new_cluster_cost - ) - if cur_cost < D[c % num_D][m]: - D[c % num_D][m] = cur_cost - T[c][m] = j + 1 - - if opt_t_cost > D[c % num_D][-1]: - opt_t_cost = D[c % num_D][-1] - opt_n_clusters = c - - P = [] - P = self.backtrace(P, T, opt_n_clusters, T[opt_n_clusters][-1]) - P.append(N) - - self.labels_ = [] - self.cluster_centers_ = [] - for i in range(len(P) - 1): - v = nX[P[i] : P[i + 1]] - if len(v): - self.labels_ += [len(self.cluster_centers_)] * len(v) - self.cluster_centers_.append([_np.mean(v)]) + self.labels_, self.cluster_centers_ = _kmeans1d.cluster(X, self.n_clusters) self.n_clusters = len(self.cluster_centers_) - self.cluster_centers_ = _torch.from_numpy(_np.array(self.cluster_centers_)).to( - device=X.device, dtype=X.dtype - ) - min_error, self.labels_ = self.__class__.x_c_dist(X, self.cluster_centers_).min( - dim=-1 + self.cluster_centers_ = ( + _torch.Tensor(self.cluster_centers_) + .to(device=X.device, dtype=X.dtype) + .view(-1, 1) ) - self.inertia_ = min_error.sum() + self.labels_ = _torch.Tensor(self.labels_).int().to(X.device) + min_error, _ = _EfficientKMeans.x_c_dist(X, self.cluster_centers_).min(dim=-1) + self.inertia_ = min_error.sum() else: self.inertia_ = None for i in range(self.max_iter): - - self.cluster_centers_ = self.__class__.get_cluster_avg( + self.cluster_centers_ = _EfficientKMeans._get_cluster_avg( self.n_clusters, self.labels_, X ) + # remove empty clusters perhaps due to pruning nan_centers = self.cluster_centers_.isnan() if nan_centers.any(): - self.kmeans_pp(self.n_init, X, i) + self._kmeans_pp(X) continue - self.x_c_dist = self.__class__.x_c_dist(X, self.cluster_centers_) - min_error, self.labels_ = self.x_c_dist.min(dim=-1) + x_c_dist = _EfficientKMeans.x_c_dist(X, self.cluster_centers_) + min_error, self.labels_ = x_c_dist.min(dim=-1) cur_inertia = min_error.sum() if self.error_bnd and _torch.sqrt(cur_inertia / N) < self.error_bnd: @@ -267,13 +198,13 @@ def lookup(m, j): reduce_cluster_centers_ = reduce_cluster_centers_[ ~_torch.isnan(reduce_cluster_centers_) ].view(-1, 1) - reduce_min_error, reduce_labels_ = self.__class__.x_c_dist( + reduce_min_error, reduce_labels_ = _EfficientKMeans.x_c_dist( X, reduce_cluster_centers_ ).min(dim=-1) reduce_inertia = reduce_cluster_centers_.sum() - self.rmse_error = _torch.sqrt(reduce_inertia / N) + rmse_error = _torch.sqrt(reduce_inertia / N) - if self.rmse_error < self.error_bnd: + if rmse_error < self.error_bnd: self.cluster_centers_ = reduce_cluster_centers_ self.labels_ = reduce_labels_ self.n_clusters = len(self.cluster_centers_) @@ -286,64 +217,3 @@ def lookup(m, j): break return self - - def get_best_rank(self, metric, func=_torch.argmin): - if _dist.is_available() and _dist.is_initialized(): - world_size = _dist.get_world_size() - if world_size > 1: - tensor_list = [_torch.zeros_like(metric) for _ in range(world_size)] - _dist.all_gather(tensor_list, metric) - bcast_rank = func(_torch.Tensor(tensor_list)) - - return bcast_rank - - return None - - def rmse_error(self, a, b): - return _torch.sqrt(_torch.mean(_torch.square(a - b))) - - def smawk(self, rows, cols, lookup): - """Search for row-maxima in a 2d totally monotone matrix M[i,j]. - The input is specified by a list of row indices, a list of column - indices, and a function "lookup" satisfying lookup(i,j) = M[i,j]. - The matrix must satisfy the totally monotone ordering property: - if i occurs before i' in rows, j occurs before j' in cols, and - M[i,j] < M[i,j'], then also M[i',j] < M[i',j']. The result is - returned as a dictionary mapping row i to the column j containing - the largest value M[i,j]. Ties are broken in favor of earlier - columns. The number of calls to lookup is O(len(rows)+len(cols)).""" - - # base case of recursion - if not rows: - return {} - - # reduce phase: make number of columns at most equal to number of rows - stack = [] - for c in cols: - while len(stack) >= 1 and lookup(rows[len(stack) - 1], stack[-1]) < lookup( - rows[len(stack) - 1], c - ): - stack.pop() - if len(stack) != len(rows): - stack.append(c) - - cols = stack - - # recursive call to search for every odd row - result = self.smawk([rows[i] for i in range(1, len(rows), 2)], cols, lookup) - - # go back and fill in the even rows - c = 0 - for r in range(0, len(rows), 2): - row = rows[r] - if r == len(rows) - 1: - cc = len(cols) - 1 # if r is last row, search through last col - else: - cc = c # otherwise only until pos of max in row r+1 - target = result[rows[r + 1]] - while cols[cc] != target: - cc += 1 - result[row] = max([(lookup(row, cols[x]), -x, cols[x]) for x in range(c, cc + 1)])[2] - c = cc - - return result diff --git a/coremltools/optimize/torch/palettization/_fake_palettizer_tensor_hook.py b/coremltools/optimize/torch/palettization/_fake_palettizer_tensor_hook.py index 41b8272bd..4ffb9f3b4 100644 --- a/coremltools/optimize/torch/palettization/_fake_palettizer_tensor_hook.py +++ b/coremltools/optimize/torch/palettization/_fake_palettizer_tensor_hook.py @@ -1,114 +1,296 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause +import gc +from typing import Callable as _Callable +from typing import Tuple as _Tuple + import torch as _torch -import torch.nn.functional as _F +import torch.distributed as _dist + +from ._utils import get_shard_list as _get_shard_list +MAX_RECURSION_DEPTH = 10 -class _FakePalettizationTensorHook: + +class _FakePalettizerTensorHook: """ - _FakePalettizationTensorHook is the class to assist in using CPU when we only want to utilize a certain percentage - of the GPU memory. + _FakePalettizerTensorHook is the custom hook that implements many of the tensor packing and unpacking + techniques illustrated in the paper `eDKM: An Efficient and Accurate Train-time Weight Clustering for Large + Language Models `_ """ + + SOFTMAX_BACKWARD = "SoftmaxBackward" + CLAMP_BACKWARD = "ClampBackward" + DIST_BACKWARD = "EuclideanDistBackward" + TRANS_BACKWARD = "TransposeBackward" + STACK_BACKWARD = "StackBackward" + INDEX_BACKWARD = "IndexBackward" + DIV_BACKWARD = "DivBackward" + SLICE_BACKWARD = "SliceBackward" + VIEW_BACKWARD = "ViewBackward" + EXPAND_BACKWARD = "ExpandBackward" + RESHAPE_BACKWARD = "ReshapeAliasBackward" + TOCOPY_BACKWARD = "ToCopyBackward" + gc_trigger = None + last_report = {} def __init__( - self, size_list, use_cpu: bool = False, name: str = None, palett_tau: float = 0.0001 + self, + zero_threshold, + device, + min_size=0, + max_mem=1.0, + use_unique=False, + use_shard=False, ): - self.name = name - self.size_list = size_list - self.tensor_list = [None] * len(size_list) - self.device_list = [None] * len(size_list) - self.use_cpu = use_cpu - self.palett_tau = palett_tau - - def init_pack(self, x: _torch.Tensor): - """ - Method that initialises packing and saving values to CPU. - """ - if x.size() in self.size_list: - idx = self.size_list.index(x.size()) + self.min_size = max(min_size, 64) + self.max_mem = max_mem + self.tensor_dict = {} + self.tensor_counter = {} + self.total_requested = 0 + self.total_allocated = 0 + self.use_unique = use_unique + self.use_shard = use_shard + self.pack_counter = -1 + self.device = device + self.zero_threshold = zero_threshold - if self.tensor_list[idx] is None: - self.device_list[idx] = x.device + t = _torch.cuda.get_device_properties(device).total_memory + a = _torch.cuda.memory_allocated(device) - if self.use_cpu: - self.tensor_list[idx] = _torch.empty( - x.size(), dtype=x.dtype, layout=x.layout, pin_memory=True - ) - self.tensor_list[idx].copy_(x) - else: - self.tensor_list[idx] = x + self.use_cpu = (a / t) > abs(self.max_mem) and hasattr(_torch.autograd, "graph") + if self.use_cpu: + if self.__class__.gc_trigger is None: + self.__class__.gc_trigger = True - elif _torch.equal(self.tensor_list[idx][0].to(self.device_list[idx]), x[0]): - pass - else: - assert False + if self.__class__.gc_trigger: + gc.collect() - return idx + def _copy_to_device(self, x) -> _torch.Tensor: + if self.use_cpu: + packed = _torch.empty(x.size(), dtype=x.dtype, layout=x.layout, pin_memory=True) + packed.copy_(x, non_blocking=True) + return packed return x - def init_unpack(self, x: _torch.Tensor): - """ - Method that initialises un-packing and retrieving values from CPU. - """ - if isinstance(x, int): - idx = x + def _unique_tensor(self, x) -> _Tuple[_torch.Tensor, _torch.Tensor, _torch.Tensor]: + if x.size(1) <= 1 or x.size(0) <= 1024: + return x - assert self.tensor_list[idx] is not None - self.tensor_list[idx] = self.tensor_list[idx].to( - self.device_list[idx], non_blocking=True - ) - return self.tensor_list[idx] + y, y_i = x.float().unique(return_inverse=True, dim=0) + y_base = 0 - return x + y = y.to(x.dtype) + y = self._copy_to_device(y) + + max_y_size = y.size(0) - def reuse_pack(self, x: _torch.Tensor): + if max_y_size >= _torch.iinfo(_torch.int16).max: + y_base = max_y_size // 2 + y_i -= y_base + max_y_size = y_base + 1 + + y_i = _lower_int(y_i, 0, max_y_size) + y_i = self._copy_to_device(y_i) + + return y, y_i, y_base + + def _compress_tensor(self, x, dtype) -> list: + if x.numel() <= self.min_size: + return x + + if x.dim() > 1: + x = x.flatten(end_dim=-2) + + world_size = _dist.get_world_size() + rank = _dist.get_rank() + + if len(x) < world_size or not self.use_shard: + x = x.to(dtype) + if self.use_unique: + x = self._unique_tensor(x) + return x + + shard_list = _get_shard_list(len(x)) + + tensor_list = [None] * world_size + shard = x[shard_list[rank] : shard_list[rank + 1]].to(dtype) + + if self.use_unique: + tensor_list[rank] = self._unique_tensor(shard) + else: + tensor_list[rank] = self._copy_to_device(shard) + + for i in range(world_size): + shard = x[shard_list[i] : shard_list[i + 1]] + if i != rank: + tensor_list[i] = {"size": shard.size(), "dtype": dtype} + + return tensor_list + + def pack(self, x) -> _Tuple[str, _Callable, _torch.device, _torch.Tensor]: """ - Method to pack reused variables on to CPU. + Function that will be called every time an operation saves a tensor for backward. """ - if x.layout != _torch.sparse_coo and x.size() in self.size_list: - idx = self.size_list.index(x.size()) + key = None + op = lambda z: z.view(size) + if x.numel() <= self.min_size: + return x - assert self.size_list[idx] is not None + x_clone = x.clone() if self.max_mem <= 0 else None + device = x.device + size = x.size() - header = self.tensor_list[idx][0].to(self.device_list[idx]) + if x.dtype.is_floating_point: + grad_fn_list = [] + full_grad_fn_list = [] + c_grad_fn = x.grad_fn - if _torch.equal(x[0], -header * header / self.palett_tau): - return idx, "x_c_dist" - elif _torch.equal(x[0], _F.softmax(-header * header / self.palett_tau)): - return idx, "softmax" - else: - return x.to_sparse(), "sparse" + while len(grad_fn_list) < 2: + if c_grad_fn: + str_grad_fn = str(type(c_grad_fn)) - return x + full_grad_fn_list.append(str_grad_fn) + + if ( + self.__class__.RESHAPE_BACKWARD in str_grad_fn + or self.__class__.TOCOPY_BACKWARD in str_grad_fn + or self.__class__.EXPAND_BACKWARD in str_grad_fn + ): + pass + else: + grad_fn_list.append(str_grad_fn) + + c_grad_fn = c_grad_fn.next_functions[0][0] if c_grad_fn.next_functions else None + else: + break + + if key is None: + for _ in range(len(grad_fn_list), 2): + grad_fn_list.append("None") + + if ( + self.__class__.SOFTMAX_BACKWARD in grad_fn_list[0] + and self.__class__.DIV_BACKWARD in grad_fn_list[1] + ): + key = "softmax" + f".{self.pack_counter}" + elif ( + self.__class__.CLAMP_BACKWARD in grad_fn_list[0] + and self.__class__.SOFTMAX_BACKWARD in grad_fn_list[1] + ): + key = "softmax" + f".{self.pack_counter}" + op = lambda z: z.view(size).clamp(min=self.zero_threshold) + elif self.__class__.DIST_BACKWARD in grad_fn_list[0]: + self.pack_counter += 1 + key = "x_c_dist" + f".{self.pack_counter}" + elif ( + self.__class__.VIEW_BACKWARD in grad_fn_list[0] + and self.__class__.DIST_BACKWARD in grad_fn_list[1] + ): + key = "x_c_dist" + f".{self.pack_counter}" + elif ( + ( + self.__class__.VIEW_BACKWARD in grad_fn_list[0] + and self.__class__.STACK_BACKWARD in grad_fn_list[1] + ) + or ( + self.__class__.STACK_BACKWARD in grad_fn_list[0] + and self.__class__.INDEX_BACKWARD in grad_fn_list[1] + ) + or ( + self.__class__.STACK_BACKWARD in grad_fn_list[0] + and self.__class__.SLICE_BACKWARD in grad_fn_list[1] + ) + ): + key = "X.b" + f".{-1}" + elif ( + self.__class__.TRANS_BACKWARD in grad_fn_list[0] + and self.__class__.STACK_BACKWARD in grad_fn_list[1] + ): + key = "X.b" + f".{-1}" + if key in self.tensor_dict: + size = x.mT.size() + op = lambda z: z.reshape(size).mT + else: + key = None + + if key is None: + key = self._compress_tensor(x, x.dtype) + elif key not in self.tensor_dict: + w = self._compress_tensor(x, x.dtype) + self.tensor_dict[key] = w + else: + key = self._compress_tensor(x, _torch.uint8) + op = lambda z: z.to(device, _torch.int32) + + return key, op, device, x_clone - def reuse_unpack(self, x: _torch.Tensor): + def unpack(self, x) -> _torch.Tensor: """ - Method to unpack reused variables from CPU. + Function that will be called to return a + value to compute a new tensor, which is the one actually used during the backward pass. """ if isinstance(x, tuple): - obj, op = x - if isinstance(obj, int): - idx = obj - assert self.tensor_list[idx] is not None - self.tensor_list[idx] = self.tensor_list[idx].to(self.device_list[idx]) - - if op == "softmax": - val = self.tensor_list[idx] * self.tensor_list[idx] / self.palett_tau - return _F.softmax(-val, dim=1) - elif op == "x_c_dist": - return -self.tensor_list[idx] * self.tensor_list[idx] / self.palett_tau - elif op == "transpose": - return self.tensor_list[idx].T - else: - assert False - elif op == "sparse": - return obj.to_dense() + key, op, device, y = x + + look_up = isinstance(key, str) + if look_up: + v = self.tensor_dict[key] + else: + v = key + + v = _decompress_tensor(v, device) + + if look_up: + self.tensor_dict[key] = v + + x = op(v) + return x - def debug_hook(self, x: _torch.Tensor): + +def _lower_int(x, x_min=None, x_max=None) -> _torch.Tensor: + if x_min is None: + x_min, x_max = x.min(), x.max() + for t in [_torch.uint8, _torch.int8, _torch.int16, _torch.int32]: + if _torch.iinfo(t).bits >= _torch.iinfo(x.dtype).bits: + break + if _torch.iinfo(t).min <= x_min and x_max <= _torch.iinfo(t).max: + x = x.to(t) + break + return x + + +def _deunique_tensor(x, device) -> _torch.Tensor: + y, y_i, y_base = x + y = y.to(device, non_blocking=True) + y_i = y_i.to(_torch.int32) + if y_base > 0: + y_i += y_base + return y[y_i] + + +def _decompress_tensor(x, device) -> _torch.Tensor: + if not isinstance(x, list): + if isinstance(x, tuple): + x = _deunique_tensor(x, device=device) return x + + distributed_world_size = _dist.get_world_size() + distributed_rank = _dist.get_rank() + for i in range(distributed_world_size): + if isinstance(x[i], dict): + x[i] = _torch.empty(**x[i], device=device) + else: + if isinstance(x[i], tuple): + x[i] = _deunique_tensor(x[i], device=device) + else: + x[i] = x[i].to(device, non_blocking=True) + + _dist.all_gather(x[:distributed_world_size], x[distributed_rank]) + return _torch.concat(x, dim=0) diff --git a/coremltools/optimize/torch/palettization/_partitioner.py b/coremltools/optimize/torch/palettization/_partitioner.py index 599519904..1827478fc 100644 --- a/coremltools/optimize/torch/palettization/_partitioner.py +++ b/coremltools/optimize/torch/palettization/_partitioner.py @@ -1,15 +1,21 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause +import math as _math +from typing import Optional as _Optional from typing import Tuple as _Tuple import torch as _torch +import torch.distributed as _dist from ._efficient_kmeans import _EfficientKMeans +from ._utils import get_shard_list as _get_shard_list +from ._utils import vectorize as _vectorize - +# NF Cluster sizes for which partitioning has been verified. +NF_CLUSTER_SIZES = [8, 9, 16, 17] class _Partitioner: """ Internal class that manages partitioning. The ``FakePalettize`` class base classes the ``_Partitioner`` class @@ -20,169 +26,309 @@ def __init__( self, n_bits: int, enforce_zero: bool, - partition_size: int, + prune_threshold: float, cluster_dim: int, cluster_permute: _Tuple, + group_size: _Optional[int], palett_tau: float, kmeans_init: str, - prune_threshold: float, + percentage_palett_enable: float, kmeans_opt1d_threshold: int, - add_extra_centroid: bool, + kmeans_batch_threshold: int, + kmeans_n_init: int, + kmeans_error_bnd: float, ): - self.centroids_init = [kmeans_init] - if add_extra_centroid: - self.n_clusters = [2 ** int(n_bits) + 1] - else: - self.n_clusters = [2 ** int(n_bits)] - self.labels_init = [None] + self.centroids = [kmeans_init] + self.n_clusters = 2 ** int(n_bits) + self.labels = [None] self.enforce_zero = [enforce_zero] + self.enable_partition = [] + self.proj_factor = None self.partitions = [] - self.partition_size = partition_size + self.cum_inertia = [] self.cluster_dim = cluster_dim self.cluster_permute = cluster_permute - self.prune_threshold = float(prune_threshold) - - self.kmeans_init = kmeans_init + self.prune_threshold = prune_threshold + self.palett_tau = palett_tau # rename to palett_tau + self.group_size = group_size + self.percentage_palett_enable = percentage_palett_enable self.kmeans_opt1d_threshold = kmeans_opt1d_threshold - self.palett_tau = palett_tau + self.kmeans_batch_threshold = kmeans_batch_threshold + self.kmeans_n_init = kmeans_n_init + self.kmeans_error_bnd = kmeans_error_bnd - def create_partitions(self, weights: _torch.Tensor): + def create_partitions(self, weights) -> None: """ Method to create partitions in the weights. These partitions can be used to run channel level palettization. """ - num_channels = len(weights) - numel_per_channel = _torch.numel(weights[0]) - num_channels_per_partition = min( - num_channels, max(1, int(self.partition_size / numel_per_channel)) - ) - - self.partitions = [ - (n, min(n + num_channels_per_partition, num_channels)) - for n in range(0, num_channels, num_channels_per_partition) - ] - num_partitions = len(self.partitions) - - if self.centroids_init[0] == "auto": - # if auto then pick either init method - numel_per_partition = numel_per_channel * num_channels_per_partition - self.centroids_init[0] = ( - "opt1d" - if ( - numel_per_partition <= self.n_clusters[0] - or numel_per_partition <= self.kmeans_opt1d_threshold - ) - and self.cluster_dim == 1 - else "cpu.kmeans++" + with _torch.no_grad(): + num_channels = len(weights) + usr_num_channels_per_partition = ( + int(self.group_size) if self.group_size else num_channels ) - - self.centroids_init = self.centroids_init * num_partitions - self.n_clusters = self.n_clusters * num_partitions - self.labels_init = self.labels_init * num_partitions - self.enforce_zero = self.enforce_zero * num_partitions - - assert ( - num_channels_per_partition * numel_per_channel - >= min(self.n_clusters) * self.cluster_dim - ), f"The number of clusters ({self.n_clusters}) and/or the cluster dim ({self.cluster_dim}) is TOO big" + self.partitions = [ + list(range(i, min(num_channels, i + usr_num_channels_per_partition))) + for i in range(0, num_channels, usr_num_channels_per_partition) + ] + num_partitions = len(self.partitions) + self.centroids = self.centroids * num_partitions + self.labels = self.labels * num_partitions + self.enforce_zero = self.enforce_zero * num_partitions + self.cum_inertia = [1e9] * num_partitions + self.partition_numel = _torch.tensor( + [_torch.numel(weights[p]) for p in self.partitions] + ) + self.enable_partition = [True] * max( + 1, int(self.percentage_palett_enable * num_partitions) + ) + self.enable_partition += [False] * (num_partitions - len(self.enable_partition)) + numel_per_partition = max(self.partition_numel) + assert numel_per_partition + assert ( + numel_per_partition >= self.n_clusters * self.cluster_dim + ), f"The number of clusters ({self.n_clusters}) and/or the cluster dim ({self.cluster_dim}) is TOO big" def get_partition_kmeans( - self, weights: _torch.Tensor, partition_index: int, partition: int, max_iter: int, init: str - ): + self, + X, + partition, + n_clusters, + labels, + enforce_zero, + max_iter, + init, + n_init=10, + ) -> _EfficientKMeans: """ Method to get kmeans for a particular partition. """ - Y = weights[partition[0] : partition[1]].detach() - cY, pad = self.flatten(Y) + cY, pad = _vectorize(X[partition], self.cluster_dim) kmeans = _EfficientKMeans( - n_clusters=self.n_clusters[partition_index], + n_clusters=n_clusters, init=init, - labels=self.labels_init[partition_index], - n_init=10, + labels=labels, + n_init=n_init, max_iter=max_iter, + error_bnd=self.kmeans_error_bnd, ).fit(cY) - if self.enforce_zero[partition_index]: - zero_point = ( - _torch.zeros(kmeans.cluster_centers_[0].size()) - .to(kmeans.cluster_centers_.device) - .unsqueeze(0) + if enforce_zero: + # fix zero + zero_point = _torch.zeros_like(kmeans.cluster_centers_[0]).unsqueeze(0) + zero_idx = _torch.argmin( + _torch.cdist(kmeans.cluster_centers_.float(), zero_point.float()) ) - zero_idx = _torch.argmin(_torch.cdist(kmeans.cluster_centers_, zero_point)) - kmeans.cluster_centers_[zero_idx] = zero_point - weights[partition[0] : partition[1]] = self.deflatten( - kmeans.cluster_centers_[kmeans.labels_], Y.size(), pad - ) + # always put zero in the first + temp = kmeans.cluster_centers_[0] + kmeans.cluster_centers_[zero_idx] = temp + kmeans.cluster_centers_[0] = zero_point return kmeans - def init_partitions(self, weights: _torch.Tensor): + def init_partitions(self, parameters) -> None: """ Method to initialize the partitions and set the k-means. Called during first iteration of palettization in the forward method of ``FakePalettize``. """ + if isinstance(self.centroids[0], _torch.Tensor): + return with _torch.no_grad(): - self.create_partitions(weights) - for i, partition in enumerate(self.partitions): - kmeans = self.get_partition_kmeans( - weights.clone(), i, partition, max_iter=100, init=self.centroids_init[i] - ) + num_partitions = len(self.partitions) + numel_per_partition = max(self.partition_numel) + if "nf" in self.centroids[0]: + if self.n_clusters in NF_CLUSTER_SIZES and self.cluster_dim == 1: + nf_fit = "fit" in self.centroids[0] + for i, partition in enumerate(self.partitions): + bit = int(_math.log2(self.n_clusters)) + sparse = bool(_math.log2(self.n_clusters) - bit) - self.centroids_init[i] = kmeans.cluster_centers_ - self.labels_init[i] = kmeans.labels_ - self.n_clusters[i] = kmeans.n_clusters + self.centroids[i] = ( + _generate_natural_float(bit=bit, sparse=sparse) + .to(parameters.device) + .to(parameters.dtype) + .view(-1, 1) + ) - def flatten(self, weight_partition: _torch.Tensor): - """ - Method to flatten a particular weight partition. - """ - permute = self.cluster_permute - dim = self.cluster_dim + if nf_fit: + best_err = _torch.finfo(_torch.float).max + best_lambd = 1 + best_retry = 0 + best_thold = 10 + up_down_hill = 0 + lambd_list = [[1 + x / 100, 1 - x / 100] for x in range(99)] + lambd_list = [1] + [v for sublist in lambd_list for v in sublist] - if permute and len(permute) == len(weight_partition.size()): - weight_partition = weight_partition.permute(permute) + cur_X = parameters[self.partitions[i]].view(-1, 1) - num_misalignment = _torch.numel(weight_partition) % dim + for cur_lambd in lambd_list: + if up_down_hill > best_thold and cur_lambd < 1: + continue - pad = None - if num_misalignment: - weight_partition = weight_partition.flatten() - pad = weight_partition[-num_misalignment:] - weight_partition = weight_partition[:-num_misalignment] + if up_down_hill < -best_thold and cur_lambd > 1: + continue - return weight_partition.reshape(-1, dim), pad + cur_lut = _torch.stack( + [x.sign() * x.abs() ** (cur_lambd) for x in self.centroids[i]] + ) + x_c_dist = _torch.cdist(cur_X, cur_lut.to(cur_X.dtype)).square() + cur_err = x_c_dist.min(-1).values.float().sum() - def deflatten(self, weight_partition: _torch.Tensor, target_size: _Tuple, pad: _torch.Tensor): - """ - Method to deflatten a particular weight partition. - """ - permute = self.cluster_permute + if best_err > cur_err: + best_retry = 0 + best_err = cur_err + best_lambd = cur_lambd + if best_lambd > 1: + up_down_hill += 1 + else: + up_down_hill -= 1 - if pad is not None: - weight_partition = _torch.cat([weight_partition.flatten(), pad]) + elif best_retry > best_thold: + break + else: + best_retry += 1 - if permute and len(permute) == len(target_size): - cur_shape = [target_size[i] for i in permute] + self.centroids[i] = _torch.stack( + [x.sign() * x.abs() ** (best_lambd) for x in self.centroids[i]] + ) + return - weight_partition = weight_partition.reshape(cur_shape) - weight_partition = weight_partition.permute( - _torch.argsort(_torch.Tensor(permute)).tolist() - ) - assert weight_partition.size() == target_size + self.centroids = ["auto"] * num_partitions + + for i in range(num_partitions): + if self.centroids[i] == "auto": + # if auto then pick either init method + self.centroids[i] = ( + "opt1d" + if ( + numel_per_partition <= self.n_clusters + or numel_per_partition <= self.kmeans_opt1d_threshold + ) + and self.cluster_dim == 1 + else "kmeans++" + ) + + if _dist.is_available() and _dist.is_initialized(): + distributed_world_size = _dist.get_world_size() + else: + distributed_world_size = 1 + if max(num_partitions, distributed_world_size) < self.kmeans_batch_threshold: + for i, partition in enumerate(self.partitions): + kmeans = self.get_partition_kmeans( + parameters, + partition, + self.n_clusters, + self.labels[i], + self.enforce_zero[i], + max_iter=100, + init=self.centroids[i], + n_init=max(1, self.kmeans_n_init // distributed_world_size), + ) + bcast_rank = _get_best_rank(kmeans.inertia_, _torch.argmin) + if bcast_rank: + _dist.broadcast(kmeans.cluster_centers_, bcast_rank) + + self.centroids[i] = kmeans.cluster_centers_ + self.labels[i] = None + else: + shard_list = _get_shard_list(num_partitions) + centroids_list = [None] * distributed_world_size - return weight_partition.reshape(target_size) + for i in range(distributed_world_size): + begin, end = shard_list[i], shard_list[i + 1] + current_rank = ( + _dist.get_rank() if _dist.is_available() and _dist.is_initialized() else 0 + ) + if i == current_rank and begin < end: + for p in range(begin, end): + kmeans = self.get_partition_kmeans( + parameters, + self.partitions[p], + self.n_clusters, + self.labels[p], + self.enforce_zero[p], + max_iter=100, + init=self.centroids[p], + n_init=self.kmeans_n_init, + ) + self.centroids[p] = kmeans.cluster_centers_ + + centroids_list[i] = _torch.stack(self.centroids[begin:end]) + else: + centroids_list[i] = _torch.full( + [end - begin, self.n_clusters, self.cluster_dim], + float("nan"), + dtype=parameters.dtype, + device=parameters.device, + ) + + if _dist.is_available() and _dist.is_initialized(): + _dist.all_gather(centroids_list, centroids_list[_dist.get_rank()]) + centroids_list = [v for sublist in centroids_list for v in sublist] + + assert len(centroids_list) == num_partitions + for p in range(num_partitions): + self.labels[p] = None + self.centroids[p] = centroids_list[p] - # Do not use _load_from_state_dict as this class doesn't call super - # So it makes multiple inheritance easier to apprehend in child classes def _load_from_state_dict_( - self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, ): self.cluster_permute = state_dict.pop(prefix + "permute") self.partitions = state_dict.pop(prefix + "partitions") + self.centroids = state_dict.pop(prefix + "centroids") + self.labels = state_dict.pop(prefix + "labels") + self.proj_factor = state_dict.pop(prefix + "proj_factor") - # Do not use _save_to_state_dict as this class doesn't call super - # So it makes multiple inheritance easier to apprehend in child classes def _save_to_state_dict_(self, destination, prefix, keep_vars): + destination[prefix + "centroids"] = self.centroids + destination[prefix + "labels"] = self.labels destination[prefix + "permute"] = self.cluster_permute destination[prefix + "partitions"] = self.partitions + + +def _get_best_rank(metric, func=_torch.argmin) -> int: + """ + Get best rank of a particular metric according to a specified function. + """ + if _dist.is_available() and _dist.is_initialized(): + distributed_world_size = _dist.get_world_size() + if distributed_world_size > 1: + tensor_list = [_torch.zeros_like(metric) for _ in range(distributed_world_size)] + _dist.all_gather(tensor_list, metric) + bcast_rank = func(_torch.Tensor(tensor_list)) + + return bcast_rank + + return None + + +def _generate_natural_float(bit=4, sparse=False, offset=0.9677083) -> _torch.Tensor: + """ + Function to generate NF4 values. + """ + from scipy.stats import norm + + space = (2**bit) // 2 + # one more positive value, this is an asymmetric type + v1 = norm.ppf(_torch.linspace(offset, 0.5, space + 1)[:-1]).tolist() + + if sparse: + v3 = [-x for x in v1] + else: + v3 = (-norm.ppf(_torch.linspace(offset, 0.5, space)[:-1])).tolist() + + v = [0] + v3 + list(reversed(v1)) + + values = _torch.Tensor(v) + values /= values.max() + + return values diff --git a/coremltools/optimize/torch/palettization/_supported_modules.py b/coremltools/optimize/torch/palettization/_supported_modules.py index 66c6e82d6..921f6159e 100644 --- a/coremltools/optimize/torch/palettization/_supported_modules.py +++ b/coremltools/optimize/torch/palettization/_supported_modules.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause @@ -195,8 +195,19 @@ class MultiheadAttention(_nn.MultiheadAttention): _FLOAT_MODULE = _nn.MultiheadAttention def forward(self, query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None): + is_batched = query.dim() == 3 + if self.batch_first and is_batched: + # Ensure that that the "is" property is maintained + if key is value: + if query is key: + query = key = value = query.transpose(1, 0) + else: + query, key = (x.transpose(1, 0) for x in (query, key)) + value = key + else: + query, key, value = (x.transpose(1, 0) for x in (query, key, value)) if not self._qkv_same_embed_dim: - return _F.multi_head_attention_forward( + attn_output, attn_output_weights = _F.multi_head_attention_forward( query, key, value, @@ -220,7 +231,7 @@ def forward(self, query, key, value, key_padding_mask=None, need_weights=True, a v_proj_weight=self.v_proj_weight_fake_quant(self.v_proj_weight), ) else: - return _F.multi_head_attention_forward( + attn_output, attn_output_weights = _F.multi_head_attention_forward( query, key, value, @@ -239,6 +250,10 @@ def forward(self, query, key, value, key_padding_mask=None, need_weights=True, a need_weights=need_weights, attn_mask=attn_mask, ) + if self.batch_first and is_batched: + return attn_output.transpose(1, 0), attn_output_weights + else: + return attn_output, attn_output_weights @classmethod def from_float(cls, mod): @@ -258,12 +273,12 @@ def from_float(cls, mod): mod.embed_dim, mod.num_heads, mod.dropout, + batch_first=mod.batch_first, bias=hasattr(mod, "in_proj_bias"), add_bias_kv=mod.bias_k is not None, add_zero_attn=mod.add_zero_attn, kdim=mod.kdim, vdim=mod.vdim, - qconfig=qconfig, ) qat.qconfig = qconfig if not qat._qkv_same_embed_dim: @@ -282,3 +297,20 @@ def from_float(cls, mod): setattr(qat.out_proj, name, param) return qat + + +def get_palettizable_parameters(module): + """ + Return a list of parameters of the module which can be palettized + """ + if isinstance(module, _nn.MultiheadAttention): + if not module._qkv_same_embed_dim: + return [ + module.out_proj.weight, + module.q_proj_weight, + module.k_proj_weight, + module.v_proj_weight, + ] + else: + return [module.in_proj_weight, module.out_proj.weight] + return [module.weight] diff --git a/coremltools/optimize/torch/palettization/_utils.py b/coremltools/optimize/torch/palettization/_utils.py new file mode 100644 index 000000000..3fd59e3d8 --- /dev/null +++ b/coremltools/optimize/torch/palettization/_utils.py @@ -0,0 +1,52 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +from typing import Tuple as _Tuple + +import torch as _torch +import torch.distributed as _dist + +def vectorize(current_tensor, cluster_dim) -> _Tuple[_torch.Tensor, _torch.Tensor]: + """ + Function to vectorize a tensor till the point where its numel is divisible by cluster_dim. The remaining parameters + are returned as a pad. + """ + num_misalignment = _torch.numel(current_tensor) % cluster_dim + + pad = None + if num_misalignment: + current_tensor = current_tensor.flatten() + pad = current_tensor[-num_misalignment:] + current_tensor = current_tensor[:-num_misalignment] + + return current_tensor.reshape(-1, cluster_dim), pad + + +def devectorize(current_tensor, pad, target_size) -> _torch.Tensor: + """ + Function to devectorize by tracing back the vectorize operation in the method above. + """ + if pad is not None: + current_tensor = _torch.cat([current_tensor.flatten(), pad]) + + return current_tensor.reshape(target_size) + + +def get_shard_list(length) -> list: + """ + Function to generate shard_list for different partitions. + """ + + distributed_world_size = ( + _dist.get_world_size() if _dist.is_available() and _dist.is_initialized() else 1 + ) + shard_size = max(1, length // distributed_world_size) + shard_list = list(range(0, length, shard_size)) + if len(shard_list) > distributed_world_size: + shard_list = shard_list[:distributed_world_size] + [length] + else: + shard_list += [length] * (distributed_world_size + 1 - len(shard_list)) + + return shard_list diff --git a/coremltools/optimize/torch/palettization/fake_palettize.py b/coremltools/optimize/torch/palettization/fake_palettize.py index 3067edecb..dbb88df60 100644 --- a/coremltools/optimize/torch/palettization/fake_palettize.py +++ b/coremltools/optimize/torch/palettization/fake_palettize.py @@ -1,21 +1,38 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause import contextlib -import gc +import logging as _logging +from distutils.version import StrictVersion as _StrictVersion +from typing import Optional as _Optional +from typing import Tuple as _Tuple +from typing import Union as _Union import torch as _torch +import torch.distributed as _dist import torch.nn.functional as _F from torch.ao.quantization.observer import ObserverBase as _ObserverBase from torch.quantization import FakeQuantize as _FakeQuantize +from coremltools.optimize.torch._utils.torch_utils import get_torch_version as _get_torch_version + from ._efficient_kmeans import _EfficientKMeans -from ._fake_palettizer_tensor_hook import _FakePalettizationTensorHook +from ._fake_palettizer_tensor_hook import _FakePalettizerTensorHook from ._partitioner import _Partitioner +from ._utils import devectorize as _devectorize +from ._utils import get_shard_list as _get_shard_list +from ._utils import vectorize as _vectorize from .palettization_config import DEFAULT_PALETTIZATION_ADVANCED_OPTIONS +# This is the maximum torch version currently supported for supporting the +# FakePalettizerTensorHook as the backward graph tracing that the pack/unpack method +# does accepts certain names for functions which have been changed after this +# torch version +MAX_TORCH_VERSION_FOR_PALETT_MAX_MEM = "2.0.1" + +_logger = _logging.getLogger(__name__) class FakePalettize(_FakeQuantize, _Partitioner): """ @@ -55,6 +72,7 @@ class FakePalettize(_FakeQuantize, _Partitioner): ), n_bits=2, cluster_dim=1, + module_parameter_shape=torch.Size([5, 4]), ) model.linear2.qconfig = torch.quantization.QConfig( activation=fq_activation, weight=fq_weight @@ -71,6 +89,9 @@ class FakePalettize(_FakeQuantize, _Partitioner): observer (:obj:`torch.ao.quantization.observer.ObserverBase`): Observer for quantizing the ``LUT``. n_bits (:obj:`int`): Number of palettization bits. There would be :math:`2^{n\_bits}` unique weights in the ``LUT``. cluster_dim (:obj:`int`): Dimensionality of centroids to use for clustering. + enable_per_channel_scale (:obj:`bool`): When set to ``True``, per channel scaling is used along the channel dimension. + group_size (:obj:`int`): Each group of ``group_size`` number of channels are palettized using + different look up tables. quant_min (:obj:`int`): The minimum allowable quantized value. quant_max (:obj:`int`): The maximum allowable quantized value. cluster_dtype (:obj:`str`): String that decides whether to quantize the ``LUT`` or not. The following are the ``str`` @@ -90,21 +111,38 @@ def __init__( observer: _ObserverBase, n_bits: int, cluster_dim: int, + enable_per_channel_scale: bool = False, + group_size: _Optional[int] = None, quant_min: int = -128, quant_max: int = 127, cluster_dtype: str = "f32", advanced_options: dict = {}, **observer_kwargs, ): - partition_size = advanced_options.get( - "partition_size", DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["partition_size"] - ) cluster_permute = advanced_options.get( "cluster_permute", DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["cluster_permute"] ) palett_max_mem = advanced_options.get( "palett_max_mem", DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["palett_max_mem"] ) + if palett_max_mem < 1: + _CURRENT_TORCH_VERSION = _get_torch_version(_torch.__version__) + if _CURRENT_TORCH_VERSION > _StrictVersion(MAX_TORCH_VERSION_FOR_PALETT_MAX_MEM): + _logger.error( + f"palett_max_mem<1 is only supported till a max torch version " + f"of:{MAX_TORCH_VERSION_FOR_PALETT_MAX_MEM} " + ) + + palett_shard = advanced_options.get( + "palett_shard", DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["palett_shard"] + ) + palett_unique = advanced_options.get( + "palett_unique", DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["palett_unique"] + ) + palett_min_tsize = advanced_options.get( + "palett_min_tsize", + DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["palett_min_tsize"], + ) kmeans_max_iter = advanced_options.get( "kmeans_max_iter", DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["kmeans_max_iter"] ) @@ -125,7 +163,8 @@ def __init__( "palett_mode", DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["palett_mode"] ) palett_cluster_tol = advanced_options.get( - "palett_cluster_tol", DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["palett_cluster_tol"] + "palett_cluster_tol", + DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["palett_cluster_tol"], ) palett_tau = advanced_options.get( "palett_tau", DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["palett_tau"] @@ -137,7 +176,38 @@ def __init__( "palett_lambda", DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["palett_lambda"] ) add_extra_centroid = advanced_options.get( - "add_extra_centroid", DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["add_extra_centroid"] + "add_extra_centroid", + DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["add_extra_centroid"], + ) + per_channel_scaling_factor_scheme = advanced_options.get( + "per_channel_scaling_factor_scheme", + DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["per_channel_scaling_factor_scheme"], + ) + percentage_palett_enable = advanced_options.get( + "percentage_palett_enable", + DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["percentage_palett_enable"], + ) + kmeans_batch_threshold = advanced_options.get( + "kmeans_batch_threshold", + DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["kmeans_batch_threshold"], + ) + kmeans_n_init = advanced_options.get( + "kmeans_n_init", DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["kmeans_n_init"] + ) + zero_threshold = advanced_options.get( + "zero_threshold", DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["zero_threshold"] + ) + palett_batch_mode = advanced_options.get( + "palett_batch_mode", + DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["palett_batch_mode"], + ) + palett_dist = advanced_options.get( + "palett_dist", + DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["palett_dist"], + ) + kmeans_error_bnd = advanced_options.get( + "kmeans_error_bnd", + DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["kmeans_error_bnd"], ) self._target_module_level_sparsity = 0.0 @@ -147,21 +217,37 @@ def __init__( self, n_bits, enforce_zero, - partition_size, + prune_threshold, cluster_dim, cluster_permute, + group_size, palett_tau, kmeans_init, - prune_threshold, + percentage_palett_enable, kmeans_opt1d_threshold, - add_extra_centroid, + kmeans_batch_threshold, + kmeans_n_init, + kmeans_error_bnd, ) + self.cluster_permute = cluster_permute + self.enable_per_channel_scale = enable_per_channel_scale + self.per_channel_scaling_factor_scheme = per_channel_scaling_factor_scheme + self.per_channel_scaling_factor = None + self.partitions = [] + self.group_size = group_size self.cluster_dtype = cluster_dtype self.add_extra_centroid = add_extra_centroid self.need_to_quantize = self.cluster_dtype in ["i8", "u8", "f16"] self.autograd_graph = hasattr(_torch.autograd, "graph") and palett_max_mem < 1.0 self.palett_max_mem = palett_max_mem + self.palett_min_tsize = palett_min_tsize + self.palett_unique = palett_unique + self.palett_shard = palett_shard + self.palett_dist = palett_dist and _dist.is_available() and _dist.is_initialized() + self.zero_threshold = zero_threshold + self.prune_threshold = prune_threshold + self.palett_batch_mode = palett_batch_mode self.palett_cluster_tol = palett_cluster_tol self.kmeans_max_iter = kmeans_max_iter self.palett_mode = palett_mode @@ -171,18 +257,9 @@ def __init__( self.n_bits = n_bits self.cluster_dim = cluster_dim self.kmeans_init = kmeans_init - # Temporary create placeholder buffers that will get replaced with proper centroids on the first forward, - # or when we reload a checkpoint. Having placeholder values is useful to maintain the structure of the state - # dict constant. - self.register_buffer("centroids", _torch.rand([1])) - self.register_buffer("labels", _torch.rand([1])) - # During init, we would want the fake_palett_enabled flag to be False, i.e. to be at a state of 0. Also, we - # would have set the fake_quant_enabled and observer_enabled to be 0 as well so that palettizer does nothing - # until the first milestone. self.register_buffer("fake_palett_enabled", _torch.tensor([0], dtype=_torch.uint8)) self.disable_fake_quant() self.disable_observer() - self.buffers_are_placeholders = True def enable_fake_palett(self, enabled: bool = True) -> None: self.fake_palett_enabled[0] = 1 if enabled else 0 @@ -190,210 +267,376 @@ def enable_fake_palett(self, enabled: bool = True) -> None: def disable_fake_palett(self): self.enable_fake_palett(False) - def diff_palettize(self, weights: _torch.Tensor): - """ - Method called to run the differentiable k-means operation. - """ - use_cpu_if_cuda_available = False - if _torch.cuda.is_available(): - t = _torch.cuda.get_device_properties(weights.device).total_memory - a = _torch.cuda.memory_allocated(weights.device) - use_cpu_if_cuda_available = (a / t) > self.palett_max_mem and self.autograd_graph - if use_cpu_if_cuda_available: - if _FakePalettizationTensorHook.gc_trigger is None: - _FakePalettizationTensorHook.gc_trigger = True - - if _FakePalettizationTensorHook.gc_trigger: - gc.collect() - - auto_grad_graph_on_cpu = ( - _torch.autograd.graph.save_on_cpu(pin_memory=True) - if use_cpu_if_cuda_available - else contextlib.nullcontext() + def diff_palettize(self, X) -> _torch.Tensor: + cX, pad = list( + zip( + *[ + _vectorize(X[partition], self.cluster_dim) + for i, partition in enumerate(self.partitions) + ] + ) ) - for i, partition in enumerate(self.partitions): + if self.training: + with _torch.no_grad(): + if self.palett_tau > 0: + new_centroid_list = [] + new_cur_n_clusters = self.n_clusters + for i, partition in enumerate(self.partitions): + if not self.enable_partition[i]: + continue + + cur_clusters, cur_inverse, cur_counts = _torch.unique( + self.centroids[i].float(), + dim=0, + return_inverse=True, + return_counts=True, + ) + cur_n_clusters = len(cur_clusters) + new_cur_n_clusters = min(new_cur_n_clusters, cur_n_clusters) + + if cur_n_clusters < self.n_clusters * (1 - self.palett_cluster_tol): + for j, count in enumerate(cur_counts): + if count > 1: + new_centroid = 0.5 * ( + cur_clusters[j] + cur_clusters[(j + 1) % cur_n_clusters] + ) + self.centroids[i][cur_inverse.tolist().index(j)] = new_centroid + new_centroid_list.append(new_centroid) + + batch_partitions = [] + seq_partitions = [] + disabled_partitions = [] + most_common_numel = None + + for i, numel in enumerate(self.partition_numel): + if self.enable_partition[i]: + if most_common_numel is None: + most_common_numel = self.partition_numel[self.enable_partition].mode()[0] + if numel == most_common_numel: + batch_partitions.append(i) + else: + seq_partitions.append(i) + elif isinstance(self.centroids[i], _torch.Tensor): + disabled_partitions.append(i) + + if len(batch_partitions) == 1 or not self.palett_batch_mode: + seq_partitions += batch_partitions + batch_partitions = [] + + if batch_partitions: + X, mean_inertia = self.diff_palettize_batch(X, cX, pad, batch_partitions) + + if seq_partitions: + X, mean_inertia = self.diff_palettize_seq(X, cX, pad, seq_partitions) + + if disabled_partitions: + X = self.palettize(X, cX, pad, disabled_partitions) + else: + X = self.palettize(X, cX, pad, partitions=range(len(self.partitions))) - current_partition_clone = weights[partition[0] : partition[1]].clone() - cX, pad = self.flatten(current_partition_clone) + return X - with _torch.no_grad(): - palett_table = _torch.unique(self.centroids[i], dim=0) - if len(palett_table) < self.n_clusters[i] * self.palett_cluster_tol: - # We use n_init as 3 so as to not spend a lot of time running this operation - kmeans = _EfficientKMeans( - n_clusters=self.n_clusters[i], - init="kmeans++", - labels=self.labels[i], - n_init=3, - max_iter=1, - ) - kmeans.kmeans_pp(3, cX, 0) - self.centroids[i] = kmeans.cluster_centers_ + def diff_palettize_seq( + self, X, cX, pad, partitions + ) -> _Tuple[_torch.Tensor, _Union[_torch.Tensor, int]]: + cur_inertia = [] + for p in partitions: + partition = self.partitions[p] + centroids = self.centroids[p].clone() + if _torch.is_grad_enabled(): + assert not centroids.requires_grad - centroids = self.centroids[i].clone() + cX_p = cX[p] + cX_pt = cX_p.T - assert not centroids.requires_grad last_inertia = None + keep_sparsity = self.prune_threshold == 0 and self.enforce_zero[p] for j in range(self.kmeans_max_iter): - if self.autograd_graph: - tensor_hook = _FakePalettizationTensorHook( - [_torch.Size([cX.size()[0], centroids.size()[0]])], - use_cpu_if_cuda_available, - f"FakePalettizationTensorHook.{i}.{j}", - self.palett_tau, - ) - auto_grad_graph_hook_init = _torch.autograd.graph.saved_tensors_hooks( - tensor_hook.init_pack, tensor_hook.init_unpack + x_c_dist = _EfficientKMeans.x_c_dist(cX_p, centroids) + + if keep_sparsity: + # need to be keep pruning exact, no additional weight to be pruned by being assigned to the zero + # centroid. the zero centroid is always centroids[0] + if _torch.is_nonzero(centroids[0]): + centroids[0] = _torch.zeros_like(centroids[0]).unsqueeze(0) + + cX_nonzero_indices = cX_p.nonzero(as_tuple=True)[0] + x_c_dist[cX_nonzero_indices, :1] = 1 / self.zero_threshold + + if self.prune_threshold > 0: + x_c_dist[:, :1] -= self.prune_threshold + + if "dkm" in self.palett_mode: + attention = _F.softmax(-x_c_dist / self.palett_tau, dim=-1).clamp( + min=self.zero_threshold ) - auto_grad_graph_hook_reuse = _torch.autograd.graph.saved_tensors_hooks( - tensor_hook.reuse_pack, tensor_hook.reuse_unpack + elif "topk" in self.palett_mode: + values, indices = _torch.topk(x_c_dist, k=2, dim=-1, largest=False) + attention_topk = _F.softmax(-values / self.palett_tau, dim=-1) + attention = _torch.zeros_like(x_c_dist) + attention[:, indices] = attention_topk + elif "hard" in self.palett_mode: + col_idx = x_c_dist.min(dim=-1).indices + row_idx = _torch.arange(start=0, end=len(col_idx), dtype=_torch.int32).to( + cX_p.device ) + attention = _torch.sparse_coo_tensor( + _torch.vstack([row_idx, col_idx]), + _torch.ones_like(row_idx).to(cX_p.device), + x_c_dist.size(), + dtype=x_c_dist.dtype, + requires_grad=True, + ).to_dense() + elif "gsm" in self.palett_mode: + attention = _F.gumbel_softmax(-x_c_dist / self.palett_tau, dim=-1) else: - auto_grad_graph_hook_init = contextlib.nullcontext() - auto_grad_graph_hook_reuse = contextlib.nullcontext() - - with auto_grad_graph_hook_init: - x_c_dist = _EfficientKMeans.x_c_dist(cX, centroids) - min_error, _ = x_c_dist.min(dim=-1) - - with auto_grad_graph_hook_reuse: - if "dkm" in self.palett_mode: - attention = _F.softmax(-x_c_dist / self.palett_tau, dim=1) - elif "gsm" in self.palett_mode: - attention = _F.gumbel_softmax(-x_c_dist / self.palett_tau, dim=1) - elif "hard" in self.palett_mode: - col_idx = x_c_dist.min(dim=1).indices - row_idx = _torch.arange(start=0, end=len(col_idx), dtype=_torch.int32).to( - cX.device - ) - attention = _torch.sparse_coo_tensor( - _torch.vstack([row_idx, col_idx]), - _torch.ones_like(row_idx).to(cX.device), - x_c_dist.size(), - dtype=x_c_dist.dtype, - requires_grad=True, - ).to_dense() + raise ValueError(f"palett_mode: {self.palett_mode} currently not supported.") - assert attention.requires_grad + # attention_sum can overflow with fp16 attention_sum = attention.sum(dim=0).view(-1, 1) - attention_sum[attention_sum == 0] = 1e-6 + assert not (attention_sum == 0).any() - with auto_grad_graph_hook_reuse: - centroids = _torch.matmul(cX.T, attention).T / attention_sum + # matmul can overflow with fp16 + centroids = _torch.matmul(cX_pt, attention).T / attention_sum - with auto_grad_graph_on_cpu: - if self.need_to_quantize: - centroids = super().forward(centroids) + if self.need_to_quantize: + centroids = super().forward(centroids) + if _torch.is_grad_enabled(): assert centroids.requires_grad - if self.prune_threshold > 0: - centroids = _torch.nn.Hardshrink(self.prune_threshold.item())(centroids) + if self.enforce_zero[p]: + # fix zero + zero_point = _torch.zeros_like(centroids[0]).unsqueeze(0) + centroids[0] = zero_point - if self.enforce_zero[i]: - zero_point = ( - _torch.zeros(centroids[0].size()).to(centroids.device).unsqueeze(0) - ) - zero_idx = _torch.argmin(_torch.cdist(centroids, zero_point)) - centroids[zero_idx] = zero_point + min_error, _ = x_c_dist.min(dim=-1) + cur_inertia.append(min_error.sum()) + + if last_inertia and abs(last_inertia - cur_inertia[-1]) <= self.palett_epsilon: + break + + last_inertia = cur_inertia[-1] - cur_inertia = min_error.sum() + X[partition] = _devectorize( + _torch.matmul(attention, centroids), pad[p], X[partition].size() + ).to(X.dtype) - if last_inertia and abs(last_inertia - cur_inertia) <= self.palett_epsilon: + self.labels[p] = None + self.centroids[p] = centroids.detach().to(X.dtype) + self.cum_inertia[p] += float(cur_inertia[-1].detach()) + + return X, (_torch.stack(cur_inertia).mean() if cur_inertia else -1) + + def diff_palettize_batch(self, X, cX, pad, partitions) -> _Tuple[_torch.Tensor, _torch.Tensor]: + num_partitions = len(partitions) + centroids = _torch.stack([self.centroids[i] for i in partitions]).clone() + cX = _torch.stack([cX[i] for i in partitions]) + cXt = cX.mT + last_inertia = None + + for j in range(self.kmeans_max_iter): + if self.palett_dist: + x_c_dist = dist_batch_cdist_square.apply(cX, centroids) + else: + x_c_dist = _EfficientKMeans.x_c_dist(cX, centroids) + + attention = _F.softmax(-x_c_dist / self.palett_tau, -1).clamp(min=self.zero_threshold) + + # attention_sum can overflow with fp16 + if _torch.is_grad_enabled(): + assert attention.requires_grad + attention_sum = attention.sum(dim=1).view(num_partitions, -1, 1) + + centroids = _torch.matmul(cXt, attention).mT / attention_sum + + if self.need_to_quantize: + centroids = super().forward(centroids) + + if _torch.is_grad_enabled(): + assert centroids.requires_grad + if self.enforce_zero[0]: + zero_point = _torch.zeros_like(centroids[0][0]).unsqueeze(0) + + for k in range(centroids.size(0)): + centroids[k][0] = zero_point + + if self.kmeans_max_iter <= 1 and self.percentage_palett_enable >= 1: + cur_inertia = _torch.zeros([num_partitions], device=X.device, dtype=X.dtype) + break + else: + min_error, _ = x_c_dist.min(dim=-1) + cur_inertia = min_error.sum(dim=1) + avg_inertia = cur_inertia.mean() + + if last_inertia and abs(last_inertia - avg_inertia) <= self.palett_epsilon: break - last_inertia = cur_inertia + last_inertia = avg_inertia - with auto_grad_graph_hook_reuse: - weights[partition[0] : partition[1]] = self.deflatten( - _torch.matmul(attention, centroids), current_partition_clone.size(), pad - ) + tX = _torch.matmul(attention, centroids) - self.centroids[i] = ( - self.palett_lambda * self.centroids[i] + (1 - self.palett_lambda) * centroids - ).detach() - self.labels[i] = attention.detach().max(dim=1)[1].data + for i, p in enumerate(partitions): + partition = self.partitions[p] + X[partition] = _devectorize(tX[i], pad[p], X[partition].size()).to(X.dtype) + self.labels[p] = None + self.centroids[p] = centroids[i].detach().to(X.dtype) + self.cum_inertia[p] += float(cur_inertia[i].detach()) - return weights + return X, cur_inertia - def palettize(self, weights: _torch.Tensor): + def palettize(self, X, cX, pad, partitions) -> _torch.Tensor: """ This method is run during inference time by the forward method of the ``FakePalettize`` class. It calculates the weight from the ``LUT`` and ``indices`` across all partitions and returns them. """ - for i, partition in enumerate(self.partitions): - labels = self.labels[i] - if labels is not None: - current_weight_partition = weights[partition[0] : partition[1]].detach() - _, pad = self.flatten(current_weight_partition) - - weights[partition[0] : partition[1]] = self.deflatten( - self.centroids[i][labels.long()], current_weight_partition.size(), pad - ) + batch_partitions = [] + seq_partitions = [] + most_common_numel = self.partition_numel[partitions].mode()[0] + + for p in partitions: + if self.partition_numel[p] == most_common_numel and self.labels[p] is None: + batch_partitions.append(p) + else: + seq_partitions.append(p) + + if len(batch_partitions) == 1 or not self.palett_batch_mode: + seq_partitions += batch_partitions + batch_partitions = [] + + if seq_partitions: + X = self.palettize_seq(X, cX, pad, seq_partitions) + + if batch_partitions: + X = self.palettize_batch(X, cX, pad, batch_partitions) - return weights + return X - def forward(self, weights: _torch.Tensor): - if self.partition_size == 0: - forwarded_weights = super().forward(weights) - if self.fake_palett_enabled[0] == 1: - with _torch.no_grad(): - quant_centroids, quant_labels = forwarded_weights.unique(return_inverse=True) - self.centroids = _torch.stack([quant_centroids.view(-1, self.cluster_dim)]) - self.labels = _torch.stack([quant_labels]) + def palettize_seq(self, X, cX, pad, partitions) -> _torch.Tensor: + for p in partitions: + partition = self.partitions[p] + labels = self.labels[p] + centroids = self.centroids[p] + if labels is None: + cX_p = cX[p] + + x_c_dist = _EfficientKMeans.x_c_dist(cX_p, centroids) + + if self.prune_threshold > 0: + x_c_dist[:, :1] -= self.prune_threshold + + min_error, labels = x_c_dist.min(dim=-1) + self.labels[p] = labels.to(_torch.int).cpu() + + if X is not None: + X[partition] = _devectorize( + centroids[self.labels[p]], pad[p], X[partition].size() + ).to(X.dtype) + + return X + + def palettize_batch(self, X, cX, pad, partitions) -> _torch.Tensor: + # intentionally use cat instead of stack to make the backward graph distinguishable from diff_palettize_batch + cX = _torch.cat([cX[i] for i in partitions]).view(len(partitions), -1, self.cluster_dim) + centroids = _torch.stack([self.centroids[i] for i in partitions]) + x_c_dist = _EfficientKMeans.x_c_dist(cX, centroids) + + if self.prune_threshold > 0: + x_c_dist[:, :, :1] -= self.prune_threshold + + min_error, labels = x_c_dist.min(dim=-1) + + for i, p in enumerate(partitions): + partition = self.partitions[p] + centroids = self.centroids[p] + self.labels[p] = labels[i].to(_torch.int).cpu() + + X[partition] = _devectorize(centroids[self.labels[p]], pad[p], X[partition].size()).to( + X.dtype + ) + + return X + + def forward(self, weights: _torch.Tensor) -> _torch.Tensor: + if self.cluster_permute and len(self.cluster_permute) == len(weights.size()): + weights = weights.permute(self.cluster_permute) + if self.enable_per_channel_scale: + if not isinstance(self.per_channel_scaling_factor, _torch.Tensor): + self.per_channel_scaling_factor = _torch.zeros((weights.flatten(1).shape[0], 1)) + with _torch.no_grad(): + if not self.per_channel_scaling_factor[0][0]: + permuted_weights_proj = weights.flatten(1) + if self.per_channel_scaling_factor_scheme == "min_max": + self.per_channel_scaling_factor = 0.5 * ( + permuted_weights_proj.max(1)[0].view(-1, 1) + - permuted_weights_proj.min(1)[0].view(-1, 1) + ) + elif self.per_channel_scaling_factor_scheme == "abs": + self.per_channel_scaling_factor = ( + permuted_weights_proj.abs().max(1)[0].view(-1, 1) + ) + else: + raise ValueError( + f"Unsupported per_channel_scaling_factor_scheme:{self.per_channel_scaling_factor_scheme}" + ) + + weights = (weights.flatten(1) / self.per_channel_scaling_factor).view( + weights.size() + ) # scale the weights using projection factors + + if self.fake_palett_enabled[0] == 1: + if not self.partitions: + self.create_partitions(weights.detach()) + tensor_hook = None + if self.training and self.palett_max_mem < 1.0: + tensor_hook = _FakePalettizerTensorHook( + zero_threshold=self.zero_threshold, + device=weights.device, + min_size=self.palett_min_tsize, + max_mem=self.palett_max_mem, + use_unique=self.palett_unique + and self.cluster_dim == 1 + and weights.dtype in [_torch.bfloat16, _torch.float16], + use_shard=self.palett_shard, + ) + + with _torch.autograd.graph.saved_tensors_hooks( + tensor_hook.pack, tensor_hook.unpack + ) if tensor_hook else contextlib.nullcontext(): + cloned_weights = weights.clone() + self.init_partitions(cloned_weights.detach()) + palettized_weights = self.diff_palettize(cloned_weights) else: - forwarded_weights = weights.clone() + palettized_weights = super().forward(weights) - if self.fake_palett_enabled[0] == 1: - if not self.partitions: - self.init_partitions(weights.detach()) - self.centroids = _torch.stack(self.centroids_init) - self.labels = _torch.stack(self.labels_init) - self.buffers_are_placeholders = False + if self.enable_per_channel_scale: + palettized_weights = ( + palettized_weights.flatten(1) * self.per_channel_scaling_factor + ).view(palettized_weights.size()) - if self.training: - forwarded_weights = self.diff_palettize(forwarded_weights) - else: - forwarded_weights = self.palettize(forwarded_weights) - else: - forwarded_weights = super().forward(weights) + if self.cluster_permute: + palettized_weights = palettized_weights.permute( + _torch.argsort(_torch.Tensor(self.cluster_permute)).tolist() + ) if self.cluster_dtype == "f16": - forwarded_weights = forwarded_weights.to(_torch.float16).to(weights.dtype) + palettized_weights = palettized_weights.to(_torch.float16).to(weights.dtype) elif self.cluster_dtype == "b16": - forwarded_weights = forwarded_weights.to(_torch.bfloat16).to(weights.dtype) + palettized_weights = palettized_weights.to(_torch.bfloat16).to(weights.dtype) - return forwarded_weights + return palettized_weights def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): self.cluster_dtype = local_metadata["cluster_dtype"] - state_dict_buffers_are_placeholders = local_metadata["buffers_are_placeholders"] - - if not self.buffers_are_placeholders and state_dict_buffers_are_placeholders: - raise ValueError( - f"Trying to reload an uninitialized state dict onto an initialized module: {prefix[:-1]}" - ) - - if self.buffers_are_placeholders and not state_dict_buffers_are_placeholders: - # We only change the size of the placeholders if we intend to reload a proper checkpoint - # onto an uninitialized module. In the other cases, we expect the state dict and the module to be compatible. - self.centroids = _torch.empty( - state_dict[prefix + "centroids"].size(), device=self.centroids.device - ) - self.labels = _torch.empty( - state_dict[prefix + "labels"].size(), device=self.labels.device - ) - self.fake_palett_enabled = _torch.empty( - state_dict[prefix + "fake_palett_enabled"].size(), device=self.labels.device - ) - - self.buffers_are_placeholders = state_dict_buffers_are_placeholders - + self.fake_palett_enabled = _torch.empty( + state_dict[prefix + "fake_palett_enabled"].size(), + device=self.centroids.device, + ) _Partitioner._load_from_state_dict_( self, state_dict, @@ -438,24 +681,95 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): # Skip FakeQuantize._save_to_state_dict and go directly to nn.Module._save_to_state_dict super(_FakeQuantize, self)._save_to_state_dict(destination, prefix, keep_vars) - # State dicts can only contain tensors (for DDP), so store infos in the metadata dict (in particular str) + # State dicts can only contain tensors (for DDP), so store infos in the metatadata dict (in particular str) destination._metadata[prefix[:-1]]["cluster_dtype"] = self.cluster_dtype - destination._metadata[prefix[:-1]][ - "buffers_are_placeholders" - ] = self.buffers_are_placeholders + destination[prefix + "per_channel_scaling_factor"] = self.per_channel_scaling_factor _Partitioner._save_to_state_dict_(self, destination, prefix + "palett.", keep_vars) def __repr__(self): rep = super().__repr__() - if self.centroids.shape[0] != self.n_clusters: - rep += " ===> centroids: uninitialised buffer, " - rep += "labels: uninitialised buffer, " - else: - rep += f" ===> centroids: {self.centroids}, " - rep += f"labels: {self.labels}, " rep += f"cluster_dtype: {self.cluster_dtype}, " rep += f"n_bits: {self.n_bits}, " rep += f"cluster_dim: {self.cluster_dim}, " rep += f"palett_tau: {self.palett_tau}, " rep += f"palett_mode: {self.palett_mode}" return rep + + +class dist_batch_cdist_square(_torch.autograd.Function): + def forward_2d(X, C): + _C = C.reshape(-1) + _X = X.repeat(1, C.size(0)) + _T = _X - _C + _T = _T.square() + T = _T.view(X.size(0), C.size(0), C.size(1)).sum(dim=-1) + return T + + def forward_3d(X, C): + T = [None] * X.size(0) + + for i in range(X.size(0)): + T[i] = dist_batch_cdist_square.forward_2d(X[i], C[i]) + + return _torch.stack(T) + + def backward_2d(X, C, grad_output): + _C = C.reshape(-1) + _X = X.repeat(1, C.size(0)) + _T = _X - _C + _T = _T.view(-1, C.size(0), C.size(1)) + _T = _T * grad_output.unsqueeze(-1).expand( + grad_output.size(0), grad_output.size(1), C.size(1) + ) + + grad_X = _T.sum(dim=1) + grad_C = _T.sum(dim=0) + + return 2 * grad_X, -2 * grad_C + + def backward_3d(X, C, grad_output): + grad_X = [None] * X.size(0) + grad_C = [None] * X.size(0) + + for i in range(X.size(0)): + grad_X[i], grad_C[i] = dist_batch_cdist_square.backward_2d(X[i], C[i], grad_output[i]) + + return _torch.stack(grad_X), _torch.stack(grad_C) + + @staticmethod + def forward(ctx, X, C): + shard_list = _get_shard_list(X.size(0)) + T = [None] * _dist.world_size + + for i in range(_dist.world_size): + cur_X = X[shard_list[i] : shard_list[i + 1]] + cur_C = C[shard_list[i] : shard_list[i + 1]] + + if i == _dist.rank: + T[i] = _torch.cdist(cur_X, cur_C).square() + else: + T[i] = _torch.zeros( + [cur_X.size(0), cur_X.size(1), cur_C.size(1)], + device=X.device, + dtype=X.dtype, + ) + + _dist.all_gather(T, T[_dist.rank]) + T = _torch.cat(T) + + M = _torch.Tensor([]) + ctx.save_for_backward(X, C, M) + + return T + + @staticmethod + def backward(ctx, grad_output): + X, C, _ = ctx.saved_tensors + + # gradient is data-dependent, so it CANNOT be sharded + if X.dim() == 3: + grad_X, grad_C = dist_batch_cdist_square.backward_3d(X, C, grad_output) + else: + grad_X, grad_C = dist_batch_cdist_square.backward_2d(X, C, grad_output) + + return grad_X, grad_C diff --git a/coremltools/optimize/torch/palettization/palettization_config.py b/coremltools/optimize/torch/palettization/palettization_config.py index 06aba9fb7..775c2dc59 100644 --- a/coremltools/optimize/torch/palettization/palettization_config.py +++ b/coremltools/optimize/torch/palettization/palettization_config.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause @@ -26,16 +26,19 @@ ModuleOptimizationConfig as _ModuleOptimizationConfig, ) from coremltools.optimize.torch.optimization_config import OptimizationConfig as _OptimizationConfig -from coremltools.optimize.torch.optimization_config import _validate_module_type_keys_factory +from coremltools.optimize.torch.optimization_config import ( + PalettizationGranularity, + _deprecated_field, + _validate_module_type_keys_factory, +) # Default advanced options for palettization DEFAULT_PALETTIZATION_ADVANCED_OPTIONS = { - "partition_size": 2000000000, "cluster_permute": None, "palett_max_mem": 1.0, "kmeans_max_iter": 3, - "prune_threshold": 0.0, - "kmeans_init": "cpu.kmeans++", + "prune_threshold": 1e-7, + "kmeans_init": "auto", "kmeans_opt1d_threshold": 1024, "enforce_zero": False, "palett_mode": "dkm", @@ -43,7 +46,18 @@ "palett_epsilon": 0.0001, "palett_lambda": 0.0, "add_extra_centroid": False, - "palett_cluster_tol": 0.05, + "palett_cluster_tol": 0.0, + "palett_min_tsize": 64 * 1024, + "palett_unique": False, + "palett_shard": False, + "palett_batch_mode": False, + "palett_dist": False, + "per_channel_scaling_factor_scheme": "min_max", + "percentage_palett_enable": 1.0, + "kmeans_batch_threshold": 4, + "kmeans_n_init": 100, + "zero_threshold": 1e-7, + "kmeans_error_bnd": 0.0, } @@ -55,6 +69,9 @@ "weight_threshold": 2048, "milestone": 0, "quantize_activations": False, + "enable_per_channel_scale": False, + "granularity": "per_tensor", + "group_size": None, } @@ -106,7 +123,16 @@ class ModuleDKMPalettizerConfig(_ModuleOptimizationConfig): n_bits (:obj:`int`): Number of clusters. The number of clusters used is :math:`2^{n\_bits}`. Defaults to ``4`` for linear layers and ``2`` for all other layers. weight_threshold (:obj:`int`): A module is only palettized if the number of elements in - its weight matrix exceeds ``weight_threshold``. Defaults to ``2048``. + its weight matrix exceeds ``weight_threshold``. If there are multiple weights in a + module (like :py:class:`torch.nn.MultiheadAttention`), all of them must have + more elements than the ``weight_threshold`` for the module to be palettized. + Defaults to ``2048``. + granularity (:py:class:`PalettizationGranularity`) – Granularity for palettization. + One of ``per_tensor`` or ``per_grouped_channel``. Defaults to ``per_tensor``. + group_size (:obj:`int`): Specify the number of channels in a group. + Only effective when granularity is ``per_grouped_channel``. + enable_per_channel_scale (:obj:`bool`): When set to ``True``, per channel scaling is used along the channel + dimension. milestone (:obj:`int`): Step or epoch at which palettization begins. Defaults to ``0``. cluster_dim (:obj:`int`, ``optional``): The dimension of each cluster. Defaults to ``1``. quant_min: (:obj:`int`, ``optional``): The minimum value for each element in the weight clusters if they are @@ -120,8 +146,6 @@ class ModuleDKMPalettizerConfig(_ModuleOptimizationConfig): by default, the clusters aren't quantized. quantize_activations (:obj:`bool`, ``optional``): When ``True``, the activation are quantized. Defaults to ``False``. - partition_size (:obj:`int`, ``optional``): partition_size helps in per channel palettization. - Defaults to ``2000000000``. cluster_permute (:obj:`tuple`, ``optional``): Permutation order to apply to weight partitions. Defaults to ``None``. palett_max_mem (:obj:`float`, ``optional``): Proportion of available GPU memory that should be used for @@ -129,9 +153,9 @@ class ModuleDKMPalettizerConfig(_ModuleOptimizationConfig): kmeans_max_iter (:obj:`int`, ``optional``): Maximum number of differentiable ``k-means`` iterations. Defaults to ``3``. prune_threshold (:obj:`float`, ``optional``): Hard-shrinks weights between [``-prune_threshold``, - ``prune_threshold``] to zero. Useful for joint pruning and palettization. Defaults to ``0.0``. + ``prune_threshold``] to zero. Useful for joint pruning and palettization. Defaults to ``1e-7``. kmeans_init (:obj:`str`, ``optional``): ``k-means`` algorithm to use. Allowed options are - ``efficient_kmeans``, ``cpu.kmeans++`` and ``kmeans_pp``. Defaults to ``cpu.kmeans++``. + ``opt1d``, ``cpu.kmeans++`` and ``kmeans++``. Defaults to ``auto``. kmeans_opt1d_threshold (:obj:`int`, ``optional``): Channel threshold to decide if ``opt1d kmeans`` should be used. Defaults to ``1024``. enforce_zero (:obj:`bool`, ``optional``): If ``True``, enforces the LUT centroid which is closest to @@ -147,7 +171,46 @@ class ModuleDKMPalettizerConfig(_ModuleOptimizationConfig): add_extra_centroid (:obj:`bool`, ``optional``): If ``True``, adds an extra centroid to the LUT. Defaults to ``False``. palett_cluster_tol (:obj:`float`, ``optional``): Tolerance for non-unique centroids in the LUT. - The higher the number, the more tolerance for non-unique centroids. Defaults to ``0.05``. + The higher the number, the more tolerance for non-unique centroids. Defaults to ``0.0``. + palett_min_tsize (:obj:`int`, ``optional``): Weight threshold beyond which to use custom packing and unpacking + hook for autograd. Defaults to ``64*1024``. + palett_unique (:obj:`bool`, ``optional``): If ``True``, reduces the attention map by leveraging the fact that + FP16 only has ``2^16`` unique values. Useful for Large Models like LLMs where attention maps can be huge. + Defaults to ``False``. More details can be found `eDKM: An Efficient and Accurate Train-time Weight + Clustering for Large Language Models `_ . + palett_shard (:obj:`bool`, ``optional``): If ``True``, the index list is sharded across GPUs. + Defaults to ``False``. More details can be found `eDKM: An Efficient and Accurate Train-time Weight + Clustering for Large Language Models `_ . + palett_batch_mode (:obj:`bool`, ``optional``): If ``True``, performs batch DKM across different partitions + created for different blocks. Defaults to ``False``. More details can be found `eDKM: An Efficient and Accurate Train-time Weight + Clustering for Large Language Models `_ . + palett_dist (:obj:`bool`, ``optional``): If ``True``, performs distributed distance calculation in batch_mode if + distributed torch is available. Defaults to ``False``. + per_channel_scaling_factor_scheme (:obj:`str`, ``optional``): Criteria to calculate the + ``per_channel_scaling_factor``. Allowed options are ``min_max`` and ``abs``. Defaults to ``min_max``. + percentage_palett_enable (:obj:`float`, ``optional``): Percentage partitions to enable for DKM. + Defaults to ``1.0``. + kmeans_batch_threshold (:obj:`int`, ``optional``): Threshold to decide at what num_partitions to go through with + sharded centroids list. num_partitions is calculated by dividing the channel size with the group_size + provided. If the kmeans_batch_threshold, the algorithm resorts to performing distirbuted kmeans for lower + partition numbers, given that num_partition number of GPUs are available. Defaults to ``4``. + kmeans_n_init (:obj:`int`, ``optional``): Number of time the k-means algorithm will be run with different + centroid seeds. The final results will be the best output of kmeans_n_init consecutive runs in terms of inertia. + zero_threshold (:obj:`int`, ``optional``): Zero threshold to be used to decide min value of clamp for softmax + . Defaults to ``1e-7``. + kmeans_error_bnd (:obj:`float`, ``optional``): Distance threshold to decide at what distance between parameters + and clusters to stop the kmeans operation. Defaults to ``0.0``. + + This class supports few different configurations to structure the palettization: + + 1. **Per-tensor palettization**: This is the default configuration where the whole tensor shares a single look-up + table. The ``granularity`` is set to ``per_tensor`` and ``group_size`` is ``None``. + + 2. **Per-grouped-channel palettization**: In this configuration, ``group_size`` number of channels along + ``channel_axis`` share the same look-up table. For example, for a weight matrix of shape ``(16, 25)``, if we provide + ``group_size = 8``, the shape of the look-up table would be ``(2, 2^n_bits)``. + + NOTE: Currently grouping is only supported along output channel axis. """ n_bits: _Optional[int] = _field( default=None, validator=_validators.optional(_validators.instance_of(int)) @@ -156,6 +219,19 @@ class ModuleDKMPalettizerConfig(_ModuleOptimizationConfig): default=DEFAULT_PALETTIZATION_OPTIONS["weight_threshold"], validator=_validators.instance_of(int), ) + granularity: PalettizationGranularity = _field( + default=DEFAULT_PALETTIZATION_OPTIONS["granularity"], + converter=PalettizationGranularity, + validator=_validators.in_(PalettizationGranularity), + ) + group_size: _Optional[int] = _field( + default=DEFAULT_PALETTIZATION_OPTIONS["group_size"], + validator=_validators.optional(_validators.instance_of(int)), + ) + enable_per_channel_scale: bool = _field( + default=DEFAULT_PALETTIZATION_OPTIONS["enable_per_channel_scale"], + validator=_validators.instance_of(bool), + ) milestone: int = _field( default=DEFAULT_PALETTIZATION_OPTIONS["milestone"], validator=_validators.instance_of(int), @@ -174,7 +250,10 @@ class ModuleDKMPalettizerConfig(_ModuleOptimizationConfig): dtype: _torch.dtype = _field( default=DEFAULT_PALETTIZATION_OPTIONS["dtype"], converter=_maybe_convert_str_to_dtype, - validator=_validators.instance_of(_torch.dtype), + validator=[ + _validators.instance_of(_torch.dtype), + _validators.in_([_torch.qint8, _torch.quint8, _torch.float32]), + ], ) cluster_dtype: str = _field( default=DEFAULT_PALETTIZATION_OPTIONS["cluster_dtype"], @@ -184,10 +263,6 @@ class ModuleDKMPalettizerConfig(_ModuleOptimizationConfig): default=DEFAULT_PALETTIZATION_OPTIONS["quantize_activations"], validator=_validators.instance_of(bool), ) - partition_size: int = _field( - default=DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["partition_size"], - validator=_validators.instance_of(int), - ) cluster_permute: _Optional[tuple] = _field( default=DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["cluster_permute"], validator=_validators.optional(_validators.instance_of(tuple)), @@ -240,6 +315,68 @@ class ModuleDKMPalettizerConfig(_ModuleOptimizationConfig): default=DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["palett_cluster_tol"], validator=_validators.instance_of(float), ) + palett_min_tsize: int = _field( + default=DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["palett_min_tsize"], + validator=_validators.instance_of(int), + ) + palett_unique: bool = _field( + default=DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["palett_unique"], + validator=_validators.instance_of(bool), + ) + palett_shard: bool = _field( + default=DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["palett_shard"], + validator=_validators.instance_of(bool), + ) + palett_batch_mode: bool = _field( + default=DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["palett_batch_mode"], + validator=_validators.instance_of(bool), + ) + palett_dist: bool = _field( + default=DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["palett_dist"], + validator=_validators.instance_of(bool), + ) + per_channel_scaling_factor_scheme: str = _field( + default=DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["per_channel_scaling_factor_scheme"], + validator=_validators.and_( + _validators.instance_of(str), _validators.in_(["min_max", "abs"]) + ), + ) + percentage_palett_enable: float = _field( + default=DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["percentage_palett_enable"], + validator=_validators.instance_of(float), + ) + kmeans_batch_threshold: int = _field( + default=DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["kmeans_batch_threshold"], + validator=_validators.instance_of(int), + ) + kmeans_n_init: int = _field( + default=DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["kmeans_n_init"], + validator=_validators.instance_of(int), + ) + zero_threshold: float = _field( + default=DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["zero_threshold"], + validator=_validators.instance_of(float), + ) + kmeans_error_bnd: float = _field( + default=DEFAULT_PALETTIZATION_ADVANCED_OPTIONS["kmeans_error_bnd"], + validator=_validators.instance_of(float), + ) + partition_size: int = _deprecated_field( + message=( + "partition_size is being deprecated and will be removed in " + "future versions. Please use group_size parameter instead." + ) + ) + + @group_size.validator + def per_grouped_channel_granularity(self, attribute, value): + if self.granularity == PalettizationGranularity.per_grouped_channel: + assert ( + value is not None + ), "group_size has to be specified along with per_grouped_channel granularity." + assert value > 0, "group_size should be greater than zero" + else: + assert value is None, "group_size can't be specified along with per_tensor granularity." _default_module_type_configs = _OrderedDict( @@ -335,9 +472,7 @@ class DKMPalettizerConfig(_OptimizationConfig): factory=_OrderedDict, validator=_validators.deep_mapping( key_validator=_validators.instance_of(str), - value_validator=_validators.optional( - _validators.instance_of(ModuleDKMPalettizerConfig) - ), + value_validator=_validate_dkm_config_type, mapping_validator=_validators.instance_of(dict), ), ) diff --git a/coremltools/optimize/torch/palettization/palettizer.py b/coremltools/optimize/torch/palettization/palettizer.py index d391a0b84..c0f744d5c 100644 --- a/coremltools/optimize/torch/palettization/palettizer.py +++ b/coremltools/optimize/torch/palettization/palettizer.py @@ -1,9 +1,8 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause -import copy as _copy import logging as _logging from typing import Dict as _Dict from typing import Optional as _Optional @@ -14,9 +13,12 @@ from coremltools.optimize.torch._typing import ParamsDict as _ParamsDict from coremltools.optimize.torch._utils.math_utils import rmse_error as _rmse_error +from coremltools.optimize.torch._utils.metadata_utils import ( + register_metadata_version as _register_metadata_version, +) from coremltools.optimize.torch._utils.torch_utils import get_eval_model as _get_eval_model from coremltools.optimize.torch.base_model_optimizer import ( - BaseModelOptimizer as _BaseModelOptimizer, + BaseTrainingTimeModelOptimizer as _BaseTrainingTimeModelOptimizer, ) from coremltools.optimize.torch.base_model_optimizer import _Report from coremltools.optimize.torch.palettization._custom_conversion import ( @@ -25,6 +27,9 @@ from coremltools.optimize.torch.palettization._supported_modules import ( _get_palettization_qat_mappings, ) +from coremltools.optimize.torch.palettization._supported_modules import ( + get_palettizable_parameters as _get_palettizable_parameters, +) from coremltools.optimize.torch.palettization.fake_palettize import FakePalettize as _FakePalettize from coremltools.optimize.torch.palettization.palettization_config import ( DEFAULT_PALETTIZATION_ADVANCED_OPTIONS as _DEFAULT_PALETTIZATION_ADVANCED_OPTIONS, @@ -42,7 +47,7 @@ _logger = _logging.getLogger(__name__) -class Palettizer(_BaseModelOptimizer): +class Palettizer(_BaseTrainingTimeModelOptimizer): pass @@ -108,7 +113,10 @@ def _palettize_supported_modules(self): if config is not None: submod_configs = config if isinstance(config, list) else [config] for submod_config in submod_configs: - if submodule.weight.numel() > submod_config.weight_threshold: + if all( + param.numel() > submod_config.weight_threshold + for param in _get_palettizable_parameters(submodule) + ): module_level_advanced_options = self._get_module_level_advanced_options( submodule, submod_config ) @@ -122,10 +130,19 @@ def _palettize_supported_modules(self): if submod_config.cluster_dim is not None else _DEFAULT_PALETTIZATION_SCHEME[type(submodule)]["cluster_dim"] ) + enable_per_channel_scale = ( + submod_config.enable_per_channel_scale + if submod_config.enable_per_channel_scale is not None + else _DEFAULT_PALETTIZATION_SCHEME[type(submodule)][ + "enable_per_channel_scale" + ] + ) self._palettize_module( submodule, n_bits, cluster_dim, + enable_per_channel_scale, + submod_config.group_size, submod_config.quant_min, submod_config.quant_max, submod_config.cluster_dtype, @@ -140,6 +157,8 @@ def _palettize_module( module: _nn.Module, n_bits: int, cluster_dim: int, + enable_per_channel_scale: bool, + group_size: _Optional[int], quant_min: int, quant_max: int, cluster_dtype: str, @@ -157,6 +176,8 @@ def _palettize_module( ), n_bits=n_bits, cluster_dim=cluster_dim, + enable_per_channel_scale=enable_per_channel_scale, + group_size=group_size, quant_min=quant_min, quant_max=quant_max, cluster_dtype=cluster_dtype, @@ -165,7 +186,7 @@ def _palettize_module( if quantize_activations: fq_activation = _FakeQuantize.with_args( observer=_torch.quantization.MovingAveragePerChannelMinMaxObserver.with_args( - quant_min=quant_min, quant_max=quant_max + quant_min=quant_min, quant_max=quant_max, dtype=dtype ), quant_min=quant_min, quant_max=quant_max, @@ -201,9 +222,7 @@ def prepare(self, inplace: bool = False) -> _nn.Module: inplace (:obj:`bool`): If ``True``, model transformations are carried out in-place and the original module is mutated, otherwise a copy of the model is mutated and returned. """ - if not inplace: - self._model = _copy.deepcopy(self._model) - + self._model = self._get_model_for_compression(inplace) self._model.train() self._palettize_supported_modules() qat_mappings = _get_palettization_qat_mappings() @@ -234,6 +253,7 @@ def finalize(self, model: _Optional[_nn.Module] = None, inplace: bool = False) - if model is None: self._model = finalized_model + _register_metadata_version(finalized_model) return finalized_model def step(self): @@ -291,10 +311,11 @@ def enable_fake_palett(self, flag: bool): @staticmethod def _enable_fake_palett_impl(module: _torch.nn.Module, flag: bool): - if hasattr(module, "weight_fake_quant") and isinstance( - module.weight_fake_quant, _FakePalettize - ): - module.weight_fake_quant.enable_fake_palett(flag) + def enable_fn(mod): + if hasattr(mod, "enable_fake_palett"): + mod.enable_fake_palett(flag) + + module.apply(enable_fn) def report(self) -> _Report: """ @@ -316,7 +337,7 @@ def report(self) -> _Report: module_summary["error"] = _rmse_error( module.weight.detach(), qweight ).item() - n_clusters = module.weight_fake_quant.n_clusters[0] + n_clusters = module.weight_fake_quant.n_clusters module_summary["#params"] = int(_torch.numel(qweight)) cluster_dim = module.weight_fake_quant.cluster_dim module_summary["#dtype"] = ( diff --git a/coremltools/optimize/torch/palettization/post_training_palettization.py b/coremltools/optimize/torch/palettization/post_training_palettization.py new file mode 100644 index 000000000..4b4e1ae26 --- /dev/null +++ b/coremltools/optimize/torch/palettization/post_training_palettization.py @@ -0,0 +1,327 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import logging as _logging +from collections import OrderedDict as _OrderedDict +from typing import Any as _Any +from typing import Callable as _Callable +from typing import Dict as _Dict +from typing import NewType as _NewType +from typing import Optional as _Optional +from typing import Tuple as _Tuple +from typing import Union as _Union + +import cattrs as _cattrs +import torch as _torch +from attr import define as _define +from attr import field as _field +from attrs import validators as _validators + +from coremltools.optimize.torch._utils.k_means import KMeansConfig as _KMeansConfig +from coremltools.optimize.torch._utils.k_means import ( + KMeansSupportedModulesRegistry as _KMeansSupportedModulesRegistry, +) +from coremltools.optimize.torch._utils.k_means import ParallelKMeans as _ParallelKMeans +from coremltools.optimize.torch._utils.k_means import SequentialKMeans as _SequentialKMeans +from coremltools.optimize.torch._utils.report_utils import ( + compute_post_training_report as _compute_post_training_report, +) +from coremltools.optimize.torch._utils.torch_utils import ( + maybe_convert_str_to_dtype as _maybe_convert_str_to_dtype, +) +from coremltools.optimize.torch._utils.torch_utils import ( + maybe_convert_str_to_mod_type as _maybe_convert_str_to_mod_type, +) +from coremltools.optimize.torch._utils.validation_utils import ( + validate_param_config as _validate_param_config, +) +from coremltools.optimize.torch.base_model_optimizer import ( + BasePostTrainingModelOptimizer as _BasePostTrainingModelOptimizer, +) +from coremltools.optimize.torch.base_model_optimizer import _Report +from coremltools.optimize.torch.optimization_config import ( + ModuleOptimizationConfig as _ModuleOptimizationConfig, +) +from coremltools.optimize.torch.optimization_config import OptimizationConfig as _OptimizationConfig +from coremltools.optimize.torch.optimization_config import ( + PalettizationGranularity, + _structure_from_dict_hook_factory, +) + +_logger = _logging.getLogger(__name__) + + +@_define +class ModulePostTrainingPalettizerConfig(_ModuleOptimizationConfig): + """ + Configuration class for specifying global and module level palettization options for + :py:class:`PostTrainingPalettizerConfig` algorithm. + + Args: + n_bits (:obj:`int`): Number of bits to use for palettizing the weights. Defaults to ``4``. + lut_dtype (:py:class:`torch.dtype`): The dtype to use for representing each element in look up tables. + When value is None, no quantization is performed. Supported values are :py:class:`torch.int8` and + :py:class:`torch.uint8`. Defaults to None. + granularity (:py:class:`PalettizationGranularity`) – Granularity for palettization. + One of ``per_tensor`` or ``per_grouped_channel``. Defaults to ``per_tensor``. + group_size (:obj:`int`): Specify the number of channels in a group. + Only effective when granularity is ``per_grouped_channel``. + channel_axis (:obj:`int`): Specify the channel axis to form a group of channels. + Only effective when granularity is ``per_grouped_channel``. Defaults to output channel axis. + cluster_dim (:obj:`int`): The dimension of centroids for each look up table. Defaults to ``1``. + The centroid is a scalar by default. When ``cluster_dim > 1``, it indicates 2-D clustering + and each ``cluster_dim`` length of weight vectors along the output channel are palettized + using the same 2-D centroid. The length of each entry in the look-up tables is equal to ``cluster_dim``. + enable_per_channel_scale (:obj:`bool`): When set to ``True``, weights are normalized along the output channels + using per channel scales before being palettized. This is not supported with ``cluster_dim > 1``. + + This class supports few different configurations to structure the palettization: + + 1. **Per-tensor palettization**: This is the default configuration where the whole tensor shares a single look-up + table. The ``granularity`` is set to ``per_tensor`` and ``group_size`` is ``None``. + + 2. **Per-grouped-channel palettization**: In this configuration, ``group_size`` number of channels along + ``channel_axis`` share the same look-up table. For example, for a weight matrix of shape ``(16, 25)``, if we provide + ``group_size = 8``, the shape of the look-up table would be ``(2, 2^n_bits)``. + + NOTE: Currently grouping is only supported along either input or output channel axis. + """ + + n_bits: _Optional[int] = _field( + default=4, validator=_validators.optional(_validators.instance_of(int)) + ) + lut_dtype: _torch.dtype = _field( + default=None, + converter=lambda val: _maybe_convert_str_to_dtype(val) if val else val, + validator=_validators.optional( + [ + _validators.instance_of(_torch.dtype), + _validators.in_([_torch.int8, _torch.uint8]), + ] + ), + ) + granularity: PalettizationGranularity = _field( + default="per_tensor", + converter=PalettizationGranularity, + validator=_validators.in_(PalettizationGranularity), + ) + group_size: _Optional[int] = _field( + default=None, validator=_validators.optional(_validators.instance_of(int)) + ) + channel_axis: int = _field( + default=0, + validator=_validators.optional([_validators.instance_of(int), _validators.in_([0, 1])]), + ) + cluster_dim: _Optional[int] = _field( + default=None, validator=_validators.optional(_validators.instance_of(int)) + ) + enable_per_channel_scale: _Optional[bool] = _field( + default=False, validator=_validators.optional(_validators.instance_of(bool)) + ) + + @group_size.validator + def per_grouped_channel_granularity(self, attribute, value): + if self.granularity == PalettizationGranularity.per_grouped_channel: + assert ( + value is not None + ), "group_size has to be specified along with per_grouped_channel granularity." + assert value > 0, "group_size should be greater than zero" + else: + assert value is None, "group_size can't be specified along with per_tensor granularity." + + @cluster_dim.validator + def per_tensor_granularity(self, attribute, value): + if value and value > 1: + assert ( + self.granularity == PalettizationGranularity.per_tensor + ), "cluster_dim larger than 1 is only supported with per tensor palettization" + + @cluster_dim.validator + def no_per_channel_scale(self, attribute, value): + if value and value > 1: + assert ( + self.enable_per_channel_scale == False + ), f"Enabling per_channel_scale is not supported for cluster_dim={value} larger than 1" + + +_ModuleTypeConfigType = _NewType( + "ModuleTypeConfigType", + _Dict[_Union[_Callable, str], _Optional[ModulePostTrainingPalettizerConfig]], +) + + +@_define +class PostTrainingPalettizerConfig(_OptimizationConfig): + """ + Configuration class for specifying how different submodules of a model + should be post-training palettized by :py:class:`PostTrainingPalettizer`. + + Args: + global_config (:py:class:`ModulePostTrainingPalettizerConfig`): Config to be applied globally + to all supported modules. + module_type_configs (:obj:`dict` of :obj:`str` to :py:class:`ModulePostTrainingPalettizerConfig`): + Module type configs applied to a specific module class, such as :py:class:`torch.nn.Linear`. + The keys can be either strings or module classes. + module_name_configs (:obj:`dict` of :obj:`str` to :py:class:`ModulePostTrainingPalettizerConfig`): + Module name configs applied to specific modules. This can be a dictionary with module names pointing to their + corresponding :py:class:`ModulePostTrainingPalettizerConfig`s + """ + + global_config: _Optional[ModulePostTrainingPalettizerConfig] = _field( + default=None, + validator=_validators.optional(_validators.instance_of(ModulePostTrainingPalettizerConfig)), + ) + module_type_configs: _ModuleTypeConfigType = _field( + factory=_OrderedDict, + validator=_validators.deep_mapping( + key_validator=_validators.instance_of((str, _Callable)), + value_validator=_validators.optional( + _validators.instance_of(ModulePostTrainingPalettizerConfig) + ), + mapping_validator=_validators.instance_of(dict), + ), + ) + module_name_configs: _Dict[str, _Optional[ModulePostTrainingPalettizerConfig]] = _field( + factory=_OrderedDict, + validator=_validators.deep_mapping( + key_validator=_validators.instance_of(str), + value_validator=_validators.optional( + _validators.instance_of(ModulePostTrainingPalettizerConfig) + ), + mapping_validator=_validators.instance_of(dict), + ), + ) + + def __attrs_post_init__(self): + if ( + self.global_config is None + and len(self.module_type_configs) == 0 + and len(self.module_name_configs) == 0 + ): + self.global_config = ModulePostTrainingPalettizerConfig() + self.module_type_configs = { + _maybe_convert_str_to_mod_type(key): val + for key, val in self.module_type_configs.items() + } + + @classmethod + def from_dict(cls, config_dict: _Dict[str, _Any]) -> "PostTrainingPalettizerConfig": + super().from_dict(config_dict) + converter = _cattrs.Converter(forbid_extra_keys=True) + converter.register_structure_hook( + _ModuleTypeConfigType, + _structure_from_dict_hook_factory(ModulePostTrainingPalettizerConfig), + ) + return converter.structure_attrs_fromdict(config_dict, cls) + + +class PostTrainingPalettizer(_BasePostTrainingModelOptimizer): + """ + Perform post-training palettization on a torch model. Post palettization, all the weights in supported + layers point to elements in a look-up table after performing a kmeans operation. + + Example: + + .. code-block:: python + + import torch.nn as nn + from coremltools.optimize.torch.palettization import ( + PostTrainingPalettizerConfig, + PostTrainingPalettizer, + ) + + model = nn.Sequential( + OrderedDict( + { + "conv": nn.Conv2d(1, 20, (3, 3)), + "relu1": nn.ReLU(), + "conv2": nn.Conv2d(20, 20, (3, 3)), + "relu2": nn.ReLU(), + } + ) + ) + + # initialize the palettizer + config = PostTrainingPalettizerConfig.from_dict( + { + "global_config": { + "n_bits": 4, + }, + } + ) + + ptpalettizer = PostTrainingPalettizer(model, config) + palettized_model = ptpalettizer.compress() + + Args: + model (:obj:`torch.nn.Module`): Module to be compressed. + config (:py:class:`PostTrainingPalettizerConfig`): Config that specifies how + different submodules in the model will be palettized. + """ + + _supported_modules: _Tuple = _KMeansSupportedModulesRegistry.get_supported_modules() + + def __init__(self, model: _torch.nn.Module, config: PostTrainingPalettizerConfig = None): + config = PostTrainingPalettizerConfig() if config is None else config + super().__init__(model, config) + + def compress(self, num_kmeans_workers: int = 1, inplace: bool = False) -> _torch.nn.Module: + """ + The compress method performs a `kmeans` operation on all supported modules. + Args: + num_kmeans_workers (:obj:`int`): Number of worker processes used for + performing post-training palettization. Defaults to ``1``. + inplace (:obj:`bool`): If ``True``, model transformations are carried out in-place and + the original module is mutated, otherwise a copy of the model is mutated and returned. + Defaults to ``False``. + """ + self._model = super().compress(inplace=inplace) + kmeans_config_dict = dict() + for name, submodule in self._model.named_modules(): + submod_config = self._config.get_module_config(name, submodule) + if submod_config is None: + continue + + k_means_module_cls = _KMeansSupportedModulesRegistry.get_kmeans_module(submodule) + if k_means_module_cls is None: + continue + + for param_name in k_means_module_cls.parameter_names: + # Validate configuration for parameter + param = submodule.get_parameter(param_name) + updated_config = _validate_param_config( + name + "." + param_name, + param, + submod_config, + ["palettization_group_size", "palettization_cluster_dim"], + ) + if not updated_config: + continue + + if name not in kmeans_config_dict: + kmeans_config_dict[name] = {} + + kmeans_config_dict[name][param_name] = _KMeansConfig( + n_bits=updated_config.n_bits, + axis=updated_config.channel_axis, + lut_dtype=updated_config.lut_dtype, + block_size=updated_config.group_size, + cluster_dim=updated_config.cluster_dim, + enable_per_channel_scale=updated_config.enable_per_channel_scale, + ) + + if num_kmeans_workers > 1: + return _ParallelKMeans.cluster_weights( + self._model, kmeans_config_dict, num_workers=num_kmeans_workers + ) + else: + return _SequentialKMeans.cluster_weights(self._model, kmeans_config_dict) + + def report(self) -> _Report: + return _compute_post_training_report( + self._uncompressed_model, + self._model, + supported_modules=self._supported_modules, + ) diff --git a/coremltools/optimize/torch/palettization/sensitive_k_means.py b/coremltools/optimize/torch/palettization/sensitive_k_means.py new file mode 100644 index 000000000..ca4027d8f --- /dev/null +++ b/coremltools/optimize/torch/palettization/sensitive_k_means.py @@ -0,0 +1,680 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import copy as _copy +import logging as _logging +import tempfile as _tempfile +from collections import OrderedDict as _OrderedDict +from contextlib import contextmanager as _contextmanager +from typing import Any as _Any +from typing import Callable as _Callable +from typing import Dict as _Dict +from typing import Iterable as _Iterable +from typing import List as _List +from typing import NewType as _NewType +from typing import Optional as _Optional +from typing import Tuple as _Tuple +from typing import Union as _Union + +import cattrs as _cattrs +import torch as _torch +import torch.multiprocessing as _mp +from attr import define as _define +from attr import field as _field +from attrs import validators as _validators +from torch.distributed.fsdp import FullStateDictConfig as _FullStateDictConfig +from torch.distributed.fsdp import FullyShardedDataParallel as _FSDP +from torch.distributed.fsdp import ShardingStrategy as _ShardingStrategy +from torch.distributed.fsdp import StateDictType as _StateDictType + +from coremltools.optimize.torch._utils.dist_utils import ddp_setup as _ddp_setup +from coremltools.optimize.torch._utils.dist_utils import is_leader as _is_leader +from coremltools.optimize.torch._utils.fsdp_utils import FSDPAutoWrapPolicy as _FSDPAutoWrapPolicy +from coremltools.optimize.torch._utils.k_means import KMeansConfig as _KMeansConfig +from coremltools.optimize.torch._utils.k_means import ( + KMeansSupportedModulesRegistry as _KMeansSupportedModulesRegistry, +) +from coremltools.optimize.torch._utils.k_means import ParallelKMeans as _ParallelKMeans +from coremltools.optimize.torch._utils.k_means import SequentialKMeans as _SequentialKMeans +from coremltools.optimize.torch._utils.report_utils import ( + compute_post_training_report as _compute_post_training_report, +) +from coremltools.optimize.torch._utils.torch_utils import ( + maybe_convert_str_to_dtype as _maybe_convert_str_to_dtype, +) +from coremltools.optimize.torch._utils.torch_utils import ( + maybe_convert_str_to_mod_type as _maybe_convert_str_to_mod_type, +) +from coremltools.optimize.torch._utils.validation_utils import ( + validate_param_config as _validate_param_config, +) +from coremltools.optimize.torch.base_model_optimizer import ( + BaseDataCalibratedModelOptimizer as _BaseDataCalibratedModelOptimizer, +) +from coremltools.optimize.torch.base_model_optimizer import _Report +from coremltools.optimize.torch.optimization_config import ( + ModuleOptimizationConfig as _ModuleOptimizationConfig, +) +from coremltools.optimize.torch.optimization_config import OptimizationConfig as _OptimizationConfig +from coremltools.optimize.torch.optimization_config import ( + PalettizationGranularity, + _structure_from_dict_hook_factory, +) + +_logger = _logging.getLogger(__name__) + + +@_define +class ModuleSKMPalettizerConfig(_ModuleOptimizationConfig): + """ + Configuration class for specifying global and module level compression options for + :py:class:`SKMPalettizer` algorithm. + + Args: + n_bits (:obj:`int`): Number of bits to use for palettizing the weights. Defaults to ``4``. + lut_dtype (:py:class:`torch.dtype`): The dtype to use for representing each element in look up tables. + When value is None, no quantization is performed. Supported values are :py:class:`torch.int8` and + :py:class:`torch.uint8`. Defaults to None. + granularity (:py:class:`PalettizationGranularity`) – Granularity for palettization. + One of ``per_tensor`` or ``per_grouped_channel``. Defaults to ``per_tensor``. + group_size (:obj:`int`): Specify the number of channels in a group. + Only effective when granularity is ``per_grouped_channel``. + channel_axis (:obj:`int`): Specify the channel axis to form a group of channels. + Only effective when granularity is ``per_grouped_channel``. Defaults to output channel axis. + enable_per_channel_scale (:obj:`bool`): When set to ``True``, weights are normalized along the output channels + using per channel scales before being palettized. This is not supported with ``cluster_dim > 1``. + + This class supports few different configurations to structure the palettization: + + 1. **Per-tensor palettization**: This is the default configuration where the whole tensor shares a single look-up + table. The ``granularity`` is set to ``per_tensor``. + + 2. **Per-grouped-channel palettization**: In this configuration, ``group_size`` number of channels along + ``channel_axis`` share the same look-up table. For example, for a weight matrix of shape ``(16, 25)``, if we provide + ``group_size = 8``, the shape of the look-up table would be ``(2, 2^n_bits)``. + + NOTE: Currently grouping is only supported along either input or output channel axis. + """ + + n_bits: int = _field(default=4, validator=_validators.instance_of(int)) + lut_dtype: _torch.dtype = _field( + default=None, + converter=lambda val: _maybe_convert_str_to_dtype(val) if val else val, + validator=_validators.optional( + [ + _validators.instance_of(_torch.dtype), + _validators.in_([_torch.int8, _torch.uint8]), + ] + ), + ) + granularity: PalettizationGranularity = _field( + default="per_tensor", + converter=PalettizationGranularity, + validator=_validators.in_(PalettizationGranularity), + ) + group_size: _Optional[int] = _field( + default=None, validator=_validators.optional(_validators.instance_of(int)) + ) + channel_axis: int = _field( + default=0, + validator=_validators.optional([_validators.instance_of(int), _validators.in_([0, 1])]), + ) + enable_per_channel_scale: bool = _field( + default=False, validator=_validators.optional(_validators.instance_of(bool)) + ) + + @group_size.validator + def per_grouped_channel_granularity(self, attribute, value): + if self.granularity == PalettizationGranularity.per_grouped_channel: + assert ( + value is not None + ), "group_size has to be specified along with per_grouped_channel granularity." + assert value > 0, "group_size should be greater than zero" + else: + assert value is None, "group_size can't be specified along with per_tensor granularity." + + +_ModuleTypeConfigType = _NewType( + "ModuleTypeConfigType", + _Dict[_Union[_Callable, str], _Optional[ModuleSKMPalettizerConfig]], +) + + +@_define +class SKMPalettizerConfig(_OptimizationConfig): + """ + Configuration class for specifying how different submodules of a model are + palettized by :py:class:`SKMPalettizer`. + + Args: + global_config (:py:class:`ModuleSKMPalettizerConfig`): Config to be applied globally + to all supported modules. Missing values are chosen from the default config. + module_type_configs (:obj:`dict` of :obj:`str` to :py:class:`ModuleSKMPalettizerConfig`): + Module type configs applied to a specific module class, such as :py:class:`torch.nn.Linear`. + The keys can be either strings or module classes. + module_name_configs (:obj:`dict` of :obj:`str` to :py:class:`ModuleSKMPalettizerConfig`): + Module level configs applied to specific modules. The name of the module must either be + a regex or a fully qualified name that can be used to fetch it from the top level module + using the ``module.get_submodule(target)`` method. + calibration_nsamples (:obj:`int`): Number of samples to be used for calibration. + """ + + global_config: _Optional[ModuleSKMPalettizerConfig] = _field( + default=None, + validator=_validators.optional(_validators.instance_of(ModuleSKMPalettizerConfig)), + ) + module_type_configs: _ModuleTypeConfigType = _field( + factory=_OrderedDict, + validator=_validators.deep_mapping( + key_validator=_validators.instance_of((str, _Callable)), + value_validator=_validators.optional( + _validators.instance_of(ModuleSKMPalettizerConfig) + ), + mapping_validator=_validators.instance_of(dict), + ), + ) + module_name_configs: _Dict[str, _Optional[ModuleSKMPalettizerConfig]] = _field( + factory=_OrderedDict, + validator=_validators.deep_mapping( + key_validator=_validators.instance_of(str), + value_validator=_validators.optional( + _validators.instance_of(ModuleSKMPalettizerConfig) + ), + mapping_validator=_validators.instance_of(dict), + ), + ) + calibration_nsamples: int = _field(default=128, validator=_validators.instance_of(int)) + + def __attrs_post_init__(self): + if ( + self.global_config is None + and len(self.module_type_configs) == 0 + and len(self.module_name_configs) == 0 + ): + self.global_config = ModuleSKMPalettizerConfig() + self.module_type_configs = { + _maybe_convert_str_to_mod_type(key): val + for key, val in self.module_type_configs.items() + } + + @classmethod + def from_dict(cls, config_dict: _Dict[str, _Any]) -> "SKMPalettizerConfig": + super().from_dict(config_dict) + converter = _cattrs.Converter(forbid_extra_keys=True) + converter.register_structure_hook( + _ModuleTypeConfigType, + _structure_from_dict_hook_factory(ModuleSKMPalettizerConfig), + ) + return converter.structure_attrs_fromdict(config_dict, cls) + + +class SKMPalettizer(_BaseDataCalibratedModelOptimizer): + """ + Perform post-training palettization of weights by running a weighted k-means + on the model weights. The weight values used for weighing different elements of + a model's weight matrix are computed using the Fisher information matrix, which + is an approximation of the Hessian. These weight values indicate how sensitive + a given weight element is; the more sensitive an element, the larger impact perturbing + it (or palettizing it) has on the model's loss function. Thus, weighted k-means + moves the clusters closer to the sensitive weight values, allowing them to be + represented more exactly and thus leading to a lower degradation in model performance + after palettization. The Fisher information matrix is computed using a few + samples of calibration data. + + This algorithm implements `SqueezeLLM: Dense-and-Sparse Quantization `_. + + Example: + + .. code-block:: python + + import torch.nn as nn + from coremltools.optimize.torch.palettization import ( + SKMPalettizer, + SKMPalettizerConfig, + ) + + model = nn.Sequential( + OrderedDict( + { + "conv": nn.Conv2d(1, 20, (3, 3)), + "relu1": nn.ReLU(), + "conv2": nn.Conv2d(20, 20, (3, 3)), + "relu2": nn.ReLU(), + } + ) + ) + + dataloder = load_calibration_data() + + # define callable for loss function + def loss_fn(model, data): + inp, target = data + out = model(inp) + return nn.functional.mse_loss(out, target) + + + # initialize the palettizer + config = SKMPalettizerConfig.from_dict( + { + "global_config": { + "n_bits": 4, + }, + "calibration_nsamples": 16, + } + ) + + compressor = SKMPalettizer(model, config) + compressed_model = compressor.compress(dataloader=dataloader, loss_fn=loss_fn) + + Args: + model (:obj:`torch.nn.Module`): Module to be compressed. + config (:py:class:`LayerwiseCompressorConfig`): Config that specifies how + different submodules in the model will be compressed. + """ + + _supported_modules: _Tuple = _KMeansSupportedModulesRegistry.get_supported_modules() + _SENSITIVITY_CLIP_THR: int = 1e-12 + + def __init__(self, model: _torch.nn.Module, config: _Optional[SKMPalettizerConfig] = None): + config = SKMPalettizerConfig() if config is None else config + super().__init__(model, config) + self._tempdir = _tempfile.TemporaryDirectory() + self._sensitivity_path = self._tempdir.name + "/sensitivity.pt" + self._model_checkpoint_path = self._tempdir.name + "/model.pt" + + def _compute_sensitivity_impl_single_worker( + self, dataset: _List, loss_fn: _Callable, sensitivity_path: _Optional[str] + ): + """ + Computes sensitivity for the model weights using a single process. + """ + if _torch.cuda.is_available(): + self._model.cuda() + + self._model.zero_grad() + + with self._register_grad_square_hooks(self._model): + for didx, data in enumerate(dataset): + _logger.info(f"Computing sensitivity using sample {didx}") + loss = loss_fn(self._model, data) + loss.backward() + + sensitivity_dict = dict() + for name, param in self._model.named_parameters(remove_duplicate=True): + if param.requires_grad: + sensitivity_dict[name] = -param.grad.cpu() + + _torch.save(sensitivity_dict, self._get_sensitivity_path(sensitivity_path)) + + def _compute_sensitivity_impl_multiple_workers( + self, + rank: int, + num_workers: int, + dataset: _List, + loss_fn: _Callable, + sensitivity_path: _Optional[str] = None, + fsdp_auto_wrap_policy: _Optional[_FSDPAutoWrapPolicy] = None, + ): + """ + Computes sensitivity for the model weights using multiple processes. + This mode is useful for large models for which computing gradients on a single + process is infeasible because the model does not fit on a single GPU. The model is + sharded on multiple GPUs using :py:class:`FullyShardedDataParallel`, which enables + distributed computation of gradients. + + If ``sensitivity_path`` is passed as ``None``, the sensitivity matrices are + stored temporarily and deleted after model compression. Otherwise, they are + saved at the location specified by ``sensitivity_path``. + + Args: + rank (:obj:`int`): Rank of the worker process on which this function is executed + num_workers (:obj:`int`): Number of workers used for computing sensitivity + dataset (:py:class:`Iterable`): An iterable where each element + is an input to the model to be compressed. Used for computing gradients of model weights. + loss_fn (:obj:`Callable`): A callable which takes the model and data as input and performs + a forward pass on the model and computes the training loss + sensitivity_path (:obj:`str` or ``None``): An optional path for saving the sensitivity + of weights. Defaults to ``None``. + fsdp_auto_wrap_policy (:py:class:`_FSDPAutoWrapPolicy` or ``None``): Policy to apply + :py:class:`FullyShardedDataParallel` to submodules of ``model``. Defaults to ``None``. + """ + _ddp_setup(rank, num_workers) + auto_wrap_policy = ( + fsdp_auto_wrap_policy.get_policy() if fsdp_auto_wrap_policy is not None else None + ) + model = _FSDP( + module=self._model, + auto_wrap_policy=auto_wrap_policy, + sharding_strategy=_ShardingStrategy.FULL_SHARD, + use_orig_params=False, + device_id=_torch.cuda.current_device(), + sync_module_states=True, + ) + + # We want to compute squares of gradients of the un-sharded parameters + # to use later for k-means. However, parameters are sharded and gradients + # are also computed in the sharded state. And there is no efficient way + # to un-shard them, hence we use an optimizer to add the sharded gradients + # to the parameters, which can later be un-sharded when we save the state dict. + optim = _torch.optim.SGD( + [param for param in model.parameters() if param.requires_grad], lr=1.0 + ) + optim.zero_grad() + + with self._register_grad_square_hooks(model): + for didx, data in enumerate(dataset): + if _is_leader(): + _logger.info(f"Computing sensitivity using sample {didx}") + loss = loss_fn(model, data) + loss.backward() + + # we set the parameters to zero so that when we call optim.step, + # the parameter values are equal to the square of the gradient + with _torch.no_grad(): + for param in model.parameters(): + param.data.zero_() + + optim.step() + + cfg = _FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with _FSDP.state_dict_type(model, _StateDictType.FULL_STATE_DICT, cfg): + sensitivity_dict = model.state_dict() + + if _is_leader(): + _torch.save(sensitivity_dict, self._get_sensitivity_path(sensitivity_path)) + + def _get_dataset(self, rank: int, num_workers: int, dataloader: _Iterable) -> _List[_Any]: + """ + Create a subset of dataloader for worker with given rank. + """ + dataset = [] + num_samples = self._config.calibration_nsamples // num_workers + sampled = 0 + for idx, data in enumerate(dataloader): + if idx % num_workers == rank: + dataset.append(_copy.deepcopy(data)) + sampled += 1 + if sampled == num_samples: + break + return dataset + + @staticmethod + @_contextmanager + def _register_grad_square_hooks(model: _torch.nn.Module): + """ + Context manager for registering gradient squaring hooks within the context + and unregistering them on exit. + """ + hook_handles = [] + for param in model.parameters(): + if param.requires_grad: + hook_handles.append(param.register_hook(lambda grad: _torch.square(grad))) + try: + yield model + finally: + for handle in hook_handles: + handle.remove() + + def _get_sensitivity_path(self, sensitivity_path: _Optional[str]) -> str: + """ + Return sensitivity_path if it's not None else a temporary path + """ + return sensitivity_path if sensitivity_path is not None else self._sensitivity_path + + def compute_sensitivity( + self, + dataloader: _Iterable, + loss_fn: _Callable, + sensitivity_path: _Optional[str] = None, + num_sensitivity_workers: int = 1, + fsdp_auto_wrap_policy: _Optional[_FSDPAutoWrapPolicy] = None, + ) -> _Dict[str, _Any]: + """ + Compute sensitivities of model's weights. A weight element's sensitivity indicates + how much effect perturbing it has on the model's loss function. The sensitivities + are computed as Fisher information of the model's weights. + + If ``sensitivity_path`` is passed as a non ``None`` value, the sensitivity matrices + saved at the location specified by ``sensitivity_path``. + + When computing sensitivity of large models, it is recommended to use ``num_sensitivity_workers`` + equal to the number of GPUs available. This is because computing gradients using a single + process maybe infeasible for a large model as it may not fit on a single GPU. + When ``num_sensitivity_workers > 1``, the model is sharded on multiple GPUs using + :py:class:`FullyShardedDataParallel`, which enables distributed computation of gradients. + + Args: + dataloader (:py:class:`Iterable`): An iterable where each element + is an input to the model to be compressed. Used for computing gradients of model weights. + loss_fn (:obj:`Callable`): A callable which takes the model and data as input and performs + a forward pass on the model and computes the training loss + sensitivity_path (:obj:`str` or ``None``): An optional path for saving the sensitivity + of weights. Defaults to ``None``. + num_sensitivity_workers (:obj:`int`): Number of worker processes used for computing sensitivity. + Defaults to ``1``. + fsdp_auto_wrap_policy (:py:class:`_FSDPAutoWrapPolicy` or ``None``): Policy which specifies + how different submodules of ``model`` are wrapped with individual + :py:class:`FullyShardedDataParallel` wrappers. This argument is only used when + ``num_sensitivity_workers > 1`` and it is only necessary when the model cannot be fit on a single GPU. + Please refer to documentation of :py:class:`_FSDPAutoWrapPolicy` for more details. + Defaults to ``None`. + """ + if num_sensitivity_workers > 1 and not _torch.cuda.is_available(): + _logger.warning( + "num_sensitivity_workers > 1 is only supported on GPUs with CUDA. Setting " + "num_sensitivity_workers to 1, since a CUDA compatible PyTorch installation" + "couldn't be found." + ) + num_sensitivity_workers = 1 + + # We save the model's state dict so that we can restore it later + # We need to do this because _compute_sensitivity_impl_multiple_workers + # sets the parameters' value to squares of their gradients and + # _compute_sensitivity_impl_single_worker can modify layers such as batch norm + # during forward pass + _torch.save(self._model.state_dict(), self._model_checkpoint_path) + if num_sensitivity_workers == 1: + self._compute_sensitivity_impl_single_worker( + self._get_dataset(0, 1, dataloader), + loss_fn, + sensitivity_path, + ) + else: + if fsdp_auto_wrap_policy is None: + _logger.warning( + "num_sensitivity_workers > 1 and fsdp_auto_wrap_policy is None. For a large model, this might " + "lead to OOM issue on GPUs. Consider setting fsdp_auto_wrap_policy to indicate how different " + "submodules of the model should be wrapped with FSDP wrappers to prevent all gather for all " + "parameters on all GPUs." + ) + + ctx = _mp.get_context("spawn") + + worker_processes = [ + ctx.Process( + target=self._compute_sensitivity_impl_multiple_workers, + args=( + rank, + num_sensitivity_workers, + self._get_dataset(rank, num_sensitivity_workers, dataloader), + loss_fn, + sensitivity_path, + fsdp_auto_wrap_policy, + ), + name=f"Process-{rank}", + ) + for rank in range(num_sensitivity_workers) + ] + for worker_process in worker_processes: + worker_process.start() + _logger.info(f"Started {worker_process.name} for computing sensitivity.") + + for worker_process in worker_processes: + worker_process.join() + _logger.info(f"Finished {worker_process.name}.") + + # restore the original state of the model + self._model.cpu() + old_state_dict = _torch.load(self._model_checkpoint_path) + self._model.load_state_dict(old_state_dict) + + return self._process_sensitivity(sensitivity_path) + + def _process_sensitivity(self, sensitivity_path: _Optional[str] = None) -> _Dict[str, _Any]: + """ + Post process the sensitivity values to normalize them. + """ + raw_sensitivity_dict = _torch.load(self._get_sensitivity_path(sensitivity_path)) + sensitivity_dict = dict() + for key, val in raw_sensitivity_dict.items(): + # Since optimizer sets param value as: p <= p - learning_rate * (grad**2), + # we need to negate the values to get grad**2 + val = 100 * -val + if len(val.nonzero()) == 0: + val[val == 0] = 1.0 + + # normalize sensitivity between 0 and 1 + val = val / _torch.max(val) + + # Clipping very small or zero sensitivity values stabilizes k-means, + # they can lead to divergence otherwise + val[val == 0] = _torch.min(val[val != 0]) + val[val < self._SENSITIVITY_CLIP_THR] = self._SENSITIVITY_CLIP_THR + + sensitivity_dict[key] = val + + # If user wants to save sensitivity values at the specified path + # we save them in the processed state + if sensitivity_path is not None: + _torch.save(sensitivity_dict, sensitivity_path) + return sensitivity_dict + + def _compute_outlier_mask(self, sensitivity: _torch.Tensor, outliers: float) -> _torch.Tensor: + """ + Compute outlier masks using the sensitivity values. + """ + sensitivity_flat = sensitivity.flatten() + numel = sensitivity_flat.numel() + num_outliers = int(numel * (outliers / 100.0)) + mask = _torch.ones_like(sensitivity_flat, dtype=_torch.bool) + mask[_torch.argsort(sensitivity_flat, descending=True)[:num_outliers]] = False + return mask.reshape(sensitivity.shape) + + def _get_submodules_to_compress(self) -> _Iterable[_Tuple[str, _torch.nn.Module]]: + """ + Return an iterator over the names and submodules to be compressed. + """ + for name, submodule in self._model.named_modules(): + yield name, submodule + + def compress( + self, + dataloader: _Optional[_Iterable] = None, + loss_fn: _Optional[_Callable] = None, + sensitivity_path: _Optional[str] = None, + num_kmeans_workers: int = 1, + num_sensitivity_workers: int = 1, + inplace: bool = False, + fsdp_auto_wrap_policy: _Optional[_FSDPAutoWrapPolicy] = None, + ) -> _torch.nn.Module: + """ + Compresses a model's weights using Fisher information sensitivity based weighted k-means + palettization. + + Args: + dataloader (:py:class:`Iterable`): An iterable where each element + is an input to the model to be compressed. Used for computing gradients of model weights. + This argument is not needed if ``sensitivity_path`` is specified and will be ignored. + It is required then ``sensitivity_path`` is ``None``. Defaults to ``None``. + loss_fn (:obj:`Callable`): A callable which takes the model and data as input and performs + a forward pass on the model and computes the training loss. This argument is not needed if + ``sensitivity_path`` is specified and will be ignored. It is required when ``sensitivity_path`` + is ``None``. Defaults to ``None``. + sensitivity_path (:obj:`str` or ``None``): An optional path from which the sensitivity values + are loaded. If ``sensitivity_path`` is not ``None``, sensitivity values are loaded from the + path specified, otherwise, sensitivity values are computed using the ``dataloader`` and + ``loss_fn``. The sensitivity values stored at ``sensitivity_path`` should be a dictionary + from strings indicating fully qualified parameter names to tensors with the same shape as the + parameters, with each element of the tensor indicating how important that element is. This is + usally the output of the :py:meth:`compute_sensitivity` method. Defaults to ``None``. + num_kmeans_workers (:obj:`int`): Number of worker processes to use for performing k-means. + It is recommended to use more than one worker process to parallelize the clustering, + especially when multiple CPUs are available. Defaults to ``1``. + num_sensitivity_workers (:obj:`int`): Number of worker processes to use for computing + sensitivity. For large models, it is recommended to set this value to the number + of GPUs available. Defaults to ``1``. + inplace (:obj:`bool`): If ``True``, model transformations are carried out in-place and + the original module is mutated, otherwise a copy of the model is mutated and returned. + Defaults to ``False``. + fsdp_auto_wrap_policy (:py:class:`_FSDPAutoWrapPolicy` or ``None``): Policy which specifies + how different submodules of ``model`` are wrapped with individual + :py:class:`FullyShardedDataParallel` wrappers. This argument is only used when + ``num_sensitivity_workers > 1`` and it is only necessary when the model cannot be fit on a single GPU. + Please refer to documentation of :py:class:`_FSDPAutoWrapPolicy` for more details. + Defaults to ``None`. + """ + self._model = super().compress(dataloader=dataloader, inplace=inplace) + if sensitivity_path is None: + sensitivity_dict = self.compute_sensitivity( + dataloader, + loss_fn, + sensitivity_path, + num_sensitivity_workers, + fsdp_auto_wrap_policy=fsdp_auto_wrap_policy, + ) + else: + _logger.info(f"Loading sensitivity values from {sensitivity_path}.") + sensitivity_dict = _torch.load(sensitivity_path) + + kmeans_config_dict = dict() + for name, submodule in self._get_submodules_to_compress(): + submod_config = self._config.get_module_config(name, submodule) + if submod_config is None: + continue + + k_means_module_cls = _KMeansSupportedModulesRegistry.get_kmeans_module(submodule) + if k_means_module_cls is None: + continue + + for param_name in k_means_module_cls.parameter_names: + # Validate configuration for parameter + param = submodule.get_parameter(param_name) + updated_config = _validate_param_config( + name + "." + param_name, + param, + submod_config, + ["palettization_group_size"], + ) + if not updated_config: + continue + + sensitivity_key = f"{name}.{param_name}" if len(name) > 0 else param_name + sensitivity = sensitivity_dict[sensitivity_key] + + if name not in kmeans_config_dict: + kmeans_config_dict[name] = {} + + kmeans_config_dict[name][param_name] = _KMeansConfig( + n_bits=updated_config.n_bits, + axis=updated_config.channel_axis, + lut_dtype=updated_config.lut_dtype, + block_size=updated_config.group_size, + importance=sensitivity, + enable_per_channel_scale=updated_config.enable_per_channel_scale, + ) + + if num_kmeans_workers > 1: + return _ParallelKMeans.cluster_weights( + self._model, kmeans_config_dict, num_workers=num_kmeans_workers + ) + else: + return _SequentialKMeans.cluster_weights(self._model, kmeans_config_dict) + + def report(self) -> _Report: + return _compute_post_training_report( + self._uncompressed_model, + self._model, + supported_modules=self._supported_modules, + ) diff --git a/coremltools/optimize/torch/pruning/__init__.py b/coremltools/optimize/torch/pruning/__init__.py index d78884380..a954d3d19 100644 --- a/coremltools/optimize/torch/pruning/__init__.py +++ b/coremltools/optimize/torch/pruning/__init__.py @@ -1,9 +1,12 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause """ +.. _coremltools_optimize_torch_pruning: + +.. include:: pruning_desc.rst _`MagnitudePruner` ================== diff --git a/coremltools/optimize/torch/pruning/_base_pruner.py b/coremltools/optimize/torch/pruning/_base_pruner.py index 7b30f95fe..556d84368 100644 --- a/coremltools/optimize/torch/pruning/_base_pruner.py +++ b/coremltools/optimize/torch/pruning/_base_pruner.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause @@ -10,20 +10,26 @@ import torch as _torch +from coremltools.optimize.torch._utils.metadata_utils import ( + register_metadata_version as _register_metadata_version, +) from coremltools.optimize.torch._utils.torch_utils import get_eval_model as _get_eval_model from coremltools.optimize.torch.base_model_optimizer import ( - BaseModelOptimizer as _BaseModelOptimizer, + BaseTrainingTimeModelOptimizer as _BaseTrainingTimeModelOptimizer, ) from coremltools.optimize.torch.base_model_optimizer import _Report from coremltools.optimize.torch.optimization_config import OptimizationConfig as _OptimizationConfig from coremltools.optimize.torch.pruning._utils import ( get_global_sparsity_summaries as _get_global_sparsity_summaries, ) +from coremltools.optimize.torch.pruning._utils import ( + register_compression_metadata as _register_compression_metadata, +) _logger = _logging.getLogger(__name__) -class BasePruner(_BaseModelOptimizer): +class BasePruner(_BaseTrainingTimeModelOptimizer): pass @@ -51,7 +57,7 @@ def prepare(self, inplace: bool = False) -> _torch.nn.Module: inplace (:obj:`bool`): If ``True``, model transformations are carried out in-place and the original module is mutated, otherwise a copy of the model is mutated and returned. """ - return _copy.deepcopy(self._model) if not inplace else self._model + return self._get_model_for_compression(inplace=inplace) def step(self): """ @@ -76,9 +82,11 @@ def finalize( if model is None: model = self._model finalized_model = model if inplace else _copy.deepcopy(model) - for _, submodule in finalized_model.named_modules(remove_duplicate=True): + _register_metadata_version(finalized_model) + for name, submodule in finalized_model.named_modules(remove_duplicate=True): if hasattr(submodule, "pruning_method"): submodule.pruning_method.remove(submodule) + _register_compression_metadata(submodule, self._pruner_info[name].config) if model is None: self._model = finalized_model return finalized_model diff --git a/coremltools/optimize/torch/pruning/_base_pruning_method.py b/coremltools/optimize/torch/pruning/_base_pruning_method.py index bafdd4dcc..2b19f81cf 100644 --- a/coremltools/optimize/torch/pruning/_base_pruning_method.py +++ b/coremltools/optimize/torch/pruning/_base_pruning_method.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause diff --git a/coremltools/optimize/torch/pruning/_utils.py b/coremltools/optimize/torch/pruning/_utils.py index f53465568..af3b444ba 100644 --- a/coremltools/optimize/torch/pruning/_utils.py +++ b/coremltools/optimize/torch/pruning/_utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause @@ -9,9 +9,12 @@ from typing import Tuple as _Tuple from typing import cast as _cast -import numpy as _np import torch as _torch +from coremltools.optimize.torch._utils.metadata_utils import ( + CompressionMetadata as _CompressionMetadata, +) + logger = _logging.getLogger(__name__) @@ -71,11 +74,6 @@ def magnitude_ranked_mask( ch_shape = magnitude_map.shape[0] assert ch_shape % block_size == 0 - if block_size > ch_shape / 2: - raise ValueError( - f"Pruning block size ({block_size}) can be at most half the number of output channels ({ch_shape}/2={ch_shape/2})" - ) - # Reshape to expose the "block" sub-axis s = list(magnitude_map.shape) # block exposed shape s.insert(1, block_size) @@ -188,63 +186,6 @@ def unstructured_sparsity(weight: _torch.Tensor) -> _torch.Tensor: return weight.eq(0.0).float().mean().item() -def unstructured_sparsity_matrix( - name: str, weight: _torch.Tensor, block_size: int -) -> _torch.Tensor: - import matplotlib - - matplotlib.use("agg") - import matplotlib.pyplot as plt - - rank = len(weight.shape) - - weight = weight.clone().detach() - - if block_size is not None and block_size > 1: - C_out, C_in = weight.shape[:2] - assert C_out % block_size == 0 - if rank > 2: - weight = weight.flatten(2).view(C_out // block_size, block_size, C_in, -1) - else: - weight = weight.view(C_out // 2, 2, C_in) - - sparsity_matrix = weight.sum(1).eq(0.0).float() - else: - sparsity_matrix = weight.eq(0.0).float() - - if rank > 2: - max_kernel_support = _np.prod(sparsity_matrix.shape[2:]) - sparsity_matrix = sparsity_matrix.sum(dim=tuple(range(2, len(sparsity_matrix.shape)))) - else: - max_kernel_support = 1 - - f = plt.figure() - ax = f.gca() - ax.imshow( - max_kernel_support - sparsity_matrix.cpu().numpy(), - cmap="jet", - interpolation="nearest", - vmin=0, - vmax=max_kernel_support, - ) - ax.set_xlabel("Input channels index") - ax.set_ylabel("Output channels index") - sparsity_type = ( - f"Block-{block_size}" if block_size is not None and block_size > 1 else "Unstructured" - ) - ax.set_title(f"{sparsity_type} Sparsity Matrix for Layer {name}") - ax.set_xticks([]) - ax.set_yticks([]) - f.canvas.draw() - - im = _np.frombuffer(f.canvas.tostring_rgb(), dtype=_np.uint8).copy() - im = im.reshape((1,) + f.canvas.get_width_height()[::-1] + (3,)) - - f.clear() - plt.close(f) - return _torch.from_numpy(im) - - def get_global_sparsity_summaries( layer_sparsities: _List[_torch.Tensor], layer_numel: _List[int] ) -> float: @@ -273,3 +214,10 @@ def validate_allowed_granularity_values(instance, attribute, value): f"Allowed values for granularity are: {', '.join(allowed_values)}. " f"Received: {value}" ) + + +def register_compression_metadata(submodule, config): + param_name = config.param_name + metadata = _CompressionMetadata(param_name) + metadata.compression_type = ["pruning"] + metadata.register(submodule) diff --git a/coremltools/optimize/torch/pruning/magnitude_pruner.py b/coremltools/optimize/torch/pruning/magnitude_pruner.py index 6ef8af0be..3e050ea49 100644 --- a/coremltools/optimize/torch/pruning/magnitude_pruner.py +++ b/coremltools/optimize/torch/pruning/magnitude_pruner.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause @@ -301,12 +301,6 @@ def __attrs_post_init__(self): f"When n_m_ratio != None, the only allowed value of granularity is " f"per_scalar." ) - if self.initial_sparsity is not None and self.initial_sparsity > 0.0: - raise ValueError( - f"Received initial_sparsity = {self.initial_sparsity} and " - f"n_m_ratio = {self.nm_ratio}. When n_m_ratio != None, the only allowed " - f"value of initial_sparsity is 0." - ) _ModuleTypeConfigType = _NewType( @@ -456,7 +450,7 @@ class MagnitudePruner(_BasePrunerWithPruningMethod): loss_fn = define_loss() # define the loss function # initialize pruner and configure it - # we only prune the first conv layer + # we only prune the fisrt conv layer config = MagnitudePrunerConfig.from_dict( { "module_name_configs": { diff --git a/coremltools/optimize/torch/pruning/pruning_scheduler.py b/coremltools/optimize/torch/pruning/pruning_scheduler.py index 591f9e587..7177d73f1 100644 --- a/coremltools/optimize/torch/pruning/pruning_scheduler.py +++ b/coremltools/optimize/torch/pruning/pruning_scheduler.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause diff --git a/coremltools/optimize/torch/quantization/__init__.py b/coremltools/optimize/torch/quantization/__init__.py index 3b58bf642..5e7a7c78e 100644 --- a/coremltools/optimize/torch/quantization/__init__.py +++ b/coremltools/optimize/torch/quantization/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause @@ -27,6 +27,18 @@ .. autoclass:: coremltools.optimize.torch.quantization.QuantizationScheme +_`PostTrainingQuantization` +============================ + +.. autoclass:: coremltools.optimize.torch.quantization.ModulePostTrainingQuantizerConfig + :members: from_dict, as_dict, from_yaml + +.. autoclass:: coremltools.optimize.torch.quantization.PostTrainingQuantizerConfig + :members: set_global, set_module_type, set_module_name, from_dict, as_dict, from_yaml + +.. autoclass:: coremltools.optimize.torch.quantization.PostTrainingQuantizer + :members: compress + """ from .quantization_config import ( @@ -36,3 +48,8 @@ QuantizationScheme, ) from .quantizer import LinearQuantizer +from .post_training_quantization import ( + ModulePostTrainingQuantizerConfig, + PostTrainingQuantizer, + PostTrainingQuantizerConfig, +) diff --git a/coremltools/optimize/torch/quantization/_annotation_handler_utils.py b/coremltools/optimize/torch/quantization/_annotation_handler_utils.py new file mode 100644 index 000000000..d499e98fe --- /dev/null +++ b/coremltools/optimize/torch/quantization/_annotation_handler_utils.py @@ -0,0 +1,726 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import itertools as _itertools +from typing import Callable as _Callable +from typing import List as _List +from typing import Optional as _Optional +from typing import Tuple as _Tuple + +import torch as _torch +import torch.nn.functional as _F +from torch.ao.quantization.pt2e.utils import get_aten_graph_module as _get_aten_graph_module +from torch.ao.quantization.quantizer.quantizer import ( + FixedQParamsQuantizationSpec as _FixedQParamsQuantizationSpec, +) +from torch.ao.quantization.quantizer.quantizer import ( + QuantizationAnnotation as _QuantizationAnnotation, +) +from torch.ao.quantization.quantizer.quantizer import QuantizationSpec as _TorchQuantizationSpec +from torch.ao.quantization.quantizer.quantizer import ( + QuantizationSpecBase as _TorchQuantizationSpecBase, +) +from torch.ao.quantization.quantizer.quantizer import ( + SharedQuantizationSpec as _SharedQuantizationSpec, +) +from torch.ao.quantization.quantizer.xnnpack_quantizer import _get_module_name_filter +from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( + _is_annotated, + _mark_nodes_as_annotated, +) +from torch.fx import Node as _Node +from torch.fx.passes.utils.matcher_with_name_node_map_utils import ( + SubgraphMatcherWithNameNodeMap as _SubgraphMatcherWithNameNodeMap, +) +from torch.fx.passes.utils.source_matcher_utils import ( + get_source_partitions as _get_source_partitions, +) + +from coremltools.optimize.torch.quantization._annotation_config import ( + AnnotationConfig as _AnnotationConfig, +) + +# All activations for which fusion is supported +_supported_activations = ( + _F.relu, + _F.relu6, + _F.leaky_relu, + _F.silu, + _F.elu, + _F.celu, + _F.selu, + _F.mish, + _F.hardtanh, + _F.hardswish, + _F.hardsigmoid, +) + + +# These activation functions don't have an inplace argument +_supported_activations_no_inplace = (_F.gelu, _F.sigmoid, _F.logsigmoid, _F.tanh) + + +# Map of dimension to convolution function +_conv_fn_map = {1: _F.conv1d, 2: _F.conv2d, 3: _F.conv3d} + + +def _adjust_activation_qspec( + node: _torch.fx.Node, qspec: _Optional[_TorchQuantizationSpecBase] +) -> _Optional[_TorchQuantizationSpecBase]: + """ + Adjust quantization spec for ops which can use fixed qparams + or ops for which we can use affine quantization mode during + symmetric quantization because their output is always positive. + """ + if qspec is None: + return qspec + + tanh_qspec = _FixedQParamsQuantizationSpec( + dtype=_torch.uint8, + scale=2.0 / 256.0, + zero_point=128, + quant_min=0, + quant_max=255, + qscheme=_torch.per_tensor_symmetric, + ) + + sigmoid_qspec = _FixedQParamsQuantizationSpec( + dtype=_torch.uint8, + scale=1.0 / 256.0, + zero_point=0, + quant_min=0, + quant_max=255, + qscheme=_torch.per_tensor_affine, + ) + + fixed_q_params_ops = { + _torch.ops.aten.tanh.default: tanh_qspec, + _torch.ops.aten.tanh_.default: tanh_qspec, + _torch.ops.aten.sigmoid.default: sigmoid_qspec, + _torch.ops.aten.sigmoid_.default: sigmoid_qspec, + _torch.ops.aten.hardsigmoid.default: sigmoid_qspec, + _torch.ops.aten.hardsigmoid_.default: sigmoid_qspec, + } + + always_affine_ops = ( + _torch.ops.aten.relu.default, + _torch.ops.aten.relu_.default, + _torch.ops.aten.relu6.default, + _torch.ops.aten.relu6_.default, + ) + + # ReLU6 activation maps to _torch.ops.aten.hardtanh.default with + # min_val = 0 and max_val = 6 + is_always_affine_op = node.target in always_affine_ops or ( + node.target in [_torch.ops.aten.hardtanh.default, _torch.ops.aten.hardtanh_.default] + and node.args[1] == 0 # min_val, corresponding to ReLU6 + and node.args[2] == 6 # max_val, corresponding to ReLU6 + ) + + if node.target in fixed_q_params_ops: + return _TorchQuantizationSpec( + observer_or_fake_quant_ctr=qspec.observer_or_fake_quant_ctr, + dtype=qspec.dtype, + qscheme=fixed_q_params_ops[node.target].qscheme, + ) + # FIXME: Because of a bug in PyTorch in function _create_obs_or_fq_from_qspec + # in module torch/ao/quantization/fx/prepare.py which creates a + # FixedQParamsFakeQuantize partial, instead of an instance, we cannot + # actually create FixedQParamsQuantizationSpec + if is_always_affine_op: + return _TorchQuantizationSpec( + observer_or_fake_quant_ctr=qspec.observer_or_fake_quant_ctr, + dtype=qspec.dtype, + qscheme=_torch.per_tensor_affine, + ) + return qspec + + +def get_object_type_filter(tp: _Callable): + """ + Returns a filter which returns True if a node in the final exported graph + was created because of an object of type ``tp`` + """ + + def object_type_filter(n: _Node) -> bool: + # example: { + # 'L__self___sub': ("L['self'].sub", ), + # 'L__self___sub_linear': ("L['self'].sub.linear", ) + # 'add_10': ('add', ) + # } + nn_module_stack = n.meta.get("nn_module_stack", {}) + types = [t for _, t in nn_module_stack.values()] + source_fn_stack = n.meta.get("source_fn_stack", {}) + types.extend([t for _, t in source_fn_stack]) + return tp in types + + return object_type_filter + + +def get_not_object_type_or_name_filter( + tp_list: _List[_Callable], module_name_list: _List[str] +) -> _Callable[[_Node], bool]: + """ + Returns a filter which returns True if a node in the final exported graph + was not created using any modules with names in ``module_name_list`` or objects with + type in ``tp_list``. + """ + object_type_filters = [get_object_type_filter(tp) for tp in tp_list] + module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list] + + def not_object_type_or_name_filter(n: _Node) -> bool: + return not any(f(n) for f in object_type_filters + module_name_list_filters) + + return not_object_type_or_name_filter + + +def _get_weighted_mod_pattern( + mod_fn: _Callable, + example_inputs: _Tuple[_torch.Tensor, ...], + act_fn: _Optional[_Callable] = None, + act_in_place: bool = False, +) -> _torch.nn.Module: + """ + Returns an aten graph corresponding to a sequence of these ops: + input -> weighted_mod -> activation -> output + + A weighted mod is a module which has a weight and bias, such as a + convolution module or a linear module. + + No activation is used if ``act_fn`` is ``None``. + ``act_fn`` is an activation function from _supported_activations or + _supported_activations_no_inplace + """ + + def pattern(input, weight, bias): + mod_out = mod_fn(input, weight, bias) + output = mod_out + node_dict = { + "input": input, + "mod": mod_out, + "weight": weight, + "bias": bias, + } + if act_fn is not None: + # Only add output if activation function is applied to model output + output = act_fn(output, inplace=True) if act_in_place else act_fn(output) + node_dict["output"] = output + return output, node_dict + + return _get_aten_graph_module(pattern, example_inputs, is_cuda=False) + + +def _get_weighted_mod_bn_pattern( + mod_fn: _Callable, + example_inputs: _Tuple[_torch.Tensor, ...], + act_fn: _Optional[_Callable] = None, + act_in_place: bool = False, +) -> _torch.nn.Module: + """ + Returns an aten graph corresponding to a sequence of these ops: + input -> weighted_mod -> batch_norm -> activation -> output + + A weighted mod is a module which has a weight and bias, such as a + convolution module or a linear module. + + No activation is used if ``act_fn`` is ``None``. + ``act_fn`` is an activation function from _supported_activations or + _supported_activations_no_inplace + """ + + def pattern(input, weight, bias, bn_weight, bn_bias, bn_run_mean, bn_run_var): + mod_out = mod_fn(input, weight, bias) + output = _F.batch_norm(mod_out, bn_run_mean, bn_run_var, bn_weight, bn_bias, training=True) + if act_fn is not None: + output = act_fn(output, inplace=True) if act_in_place else act_fn(output) + return output, { + "input": input, + "mod": mod_out, + "weight": weight, + "bias": bias, + "output": output, + } + + return _get_aten_graph_module(pattern, example_inputs, is_cuda=False) + + +def get_binary_op_act_pattern( + binary_op: _Callable, + act_fn: _Optional[_Callable] = None, + act_in_place: bool = False, +) -> _torch.nn.Module: + """ + Returns an aten graph corresponding to a sequence of these ops: + input_1 --- + \ + --> binary_op -> activation -> output + / + input_2 --- + + A binary op is any operation which consumes two inputs to create one output, + such as addition or multiplication. + + No activation is used if ``act_fn`` is ``None``. + ``act_fn`` is an activation function from _supported_activations or + _supported_activations_no_inplace + """ + + def pattern(input_1, input_2): + binary_op_out = binary_op(input_1, input_2) + node_dict = { + "binary_op": binary_op_out, + } + output = binary_op_out + if act_fn is not None: + output = act_fn(output, inplace=True) if act_in_place else act_fn(output) + node_dict["output"] = output + return output, node_dict + + example_inputs = (_torch.randn(1), _torch.randn(1)) + return _get_aten_graph_module(pattern, example_inputs, is_cuda=False) + + +def get_conv_pattern( + conv_dim: int, act_fn: _Optional[_Callable] = None, act_in_place: bool = False +) -> _torch.nn.Module: + """ + Returns an aten graph corresponding to a sequence of these ops: + input -> conv -> activation -> output + + No activation is used if ``act_fn`` is ``None``. + ``act_fn`` is an activation function from _supported_activations or + _supported_activations_no_inplace + """ + assert ( + conv_dim in _conv_fn_map + ), f"Dimension {conv_dim} is not supported for Convolution layers." + + example_inputs = ( + _torch.randn(1, 1, *[3] * conv_dim), # input + _torch.randn(1, 1, *[1] * conv_dim), # conv weight + _torch.randn(1), # conv bias + ) + return _get_weighted_mod_pattern(_conv_fn_map[conv_dim], example_inputs, act_fn, act_in_place) + + +def get_conv_bn_pattern( + conv_dim: int, act_fn: _Optional[_Callable] = None, act_in_place: bool = False +) -> _torch.nn.Module: + """ + Returns an aten graph corresponding to a sequence of these ops: + input -> conv -> batch_norm -> activation -> output + + No activation is used if ``act_fn`` is ``None``. + ``act_fn`` is an activation function from _supported_activations or + _supported_activations_no_inplace + """ + assert ( + conv_dim in _conv_fn_map + ), f"Dimension {conv_dim} is not supported for Convolution layers." + + example_inputs = ( + _torch.randn(1, 1, *[3] * conv_dim), # input + _torch.randn(1, 1, *[1] * conv_dim), # conv weight + _torch.randn(1), # conv bias + _torch.randn(1), # bn_weight + _torch.randn(1), # bn_bias + _torch.randn(1), # bn_run_mean + _torch.randn(1), # bn_run_var + ) + return _get_weighted_mod_bn_pattern( + _conv_fn_map[conv_dim], example_inputs, act_fn, act_in_place + ) + + +def get_linear_pattern( + act_fn: _Optional[_Callable] = None, act_in_place: bool = False +) -> _torch.nn.Module: + """ + Returns an aten graph corresponding to a sequence of these ops: + input -> linear -> activation -> output + + No activation is used if ``act_fn`` is ``None``. + ``act_fn`` is an activation function from _supported_activations or + _supported_activations_no_inplace + """ + example_inputs = ( + _torch.randn(1, 1), # input + _torch.randn(3, 1), # linear weight + _torch.randn(3), # linear bias + ) + return _get_weighted_mod_pattern(_F.linear, example_inputs, act_fn, act_in_place) + + +def get_linear_bn_pattern( + act_fn: _Optional[_Callable] = None, act_in_place: bool = False +) -> _torch.nn.Module: + """ + Returns an aten graph corresponding to a sequence of these ops: + input -> linear -> batch_norm -> activation -> output + + No activation is used if ``act_fn`` is ``None``. + ``act_fn`` is an activation function from _supported_activations or + _supported_activations_no_inplace + """ + example_inputs = ( + _torch.randn(2, 1), # input + _torch.randn(3, 1), # linear weight + _torch.randn(3), # linear bias + _torch.randn(3), # bn_weight + _torch.randn(3), # bn_bias + _torch.randn(3), # bn_run_mean + _torch.randn(3), # bn_run_var + ) + return _get_weighted_mod_bn_pattern(_F.linear, example_inputs, act_fn, act_in_place) + + +def annotate_weighted_mod_pattern( + model: _torch.fx.GraphModule, + pattern_gm: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]], +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates all nodes in ``model``, which match the pattern specified by ``pattern_gm`` + using ``quantization_config``. + + ``pattern_gm`` captures patterns of the following type: + + input -> weighted_mod -> batch_norm -> activation -> output + + batch_norm and activation may or may not be applied in the pattern. + + Only annotates those patterns in which all nodes return True when ``filter_fn`` is applied + to them. + """ + model.graph.eliminate_dead_code() + model.recompile() + + matcher = _SubgraphMatcherWithNameNodeMap(pattern_gm, ignore_literals=True) + matches = matcher.match(model.graph) + + annotated_partitions = [] + for match in matches: + name_node_map = match.name_node_map + input_node = name_node_map["input"] + mod_node = name_node_map["mod"] + weight_node = name_node_map["weight"] + bias_node = name_node_map["bias"] + if "output" in name_node_map: + # In this case, an activation is applied to the weighted module output + output_node = name_node_map["output"] + # If the output is same as mod_node, it means we have an inplace activation, + # so we need to correct the mod_node + if mod_node == output_node: + mod_node = mod_node.args[0] + else: + output_node = None + + # Validate mod args + if mod_node.args[0] is not input_node: + raise ValueError(f"Weighted module arg did not contain input node {input_node}") + if mod_node.args[1] is not weight_node: + raise ValueError(f"Weighted module arg did not contain weight node {weight_node}") + if len(mod_node.args) > 2 and mod_node.args[2] is not bias_node: + raise ValueError(f"Weighted module arg did not contain bias node {bias_node}") + + # Skip if the partition is already annotated or is filtered out by the user + partition = [mod_node, weight_node] + if bias_node is not None: + partition.append(bias_node) + if _is_annotated(partition): + continue + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + # Annotate conv inputs and pattern output + input_qspec_map = dict() + if not _is_annotated([input_node]): + input_qspec_map[input_node] = ( + quantization_config.input_activation if quantization_config else None + ) + else: + input_qspec_map[input_node] = input_node.meta["quantization_annotation"].output_qspec + + input_qspec_map[weight_node] = quantization_config.weight if quantization_config else None + output_qspec = quantization_config.output_activation if quantization_config else None + if bias_node is not None: + input_qspec_map[bias_node] = None + + if output_node is None: + mod_node.meta["quantization_annotation"] = _QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_qspec, + _annotated=True, + ) + else: + mod_node.meta["quantization_annotation"] = _QuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + if not _is_annotated([output_node]): + output_qspec = _adjust_activation_qspec(node=output_node, qspec=output_qspec) + output_node.meta["quantization_annotation"] = _QuantizationAnnotation( + output_qspec=output_qspec, + _annotated=True, + ) + + _mark_nodes_as_annotated(partition) + annotated_partitions.append(partition) + return annotated_partitions + + +def annotate_binary_op_act_pattern( + model: _torch.fx.GraphModule, + pattern_gm: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates all nodes in ``model``, which match the pattern specified by ``pattern_gm`` + using ``quantization_config``. + + ``pattern_gm`` captures patterns of the following type: + + input_1 --- + \ + --> binary_op -> activation -> output + / + input_2 --- + + activation may or may not be applied in the pattern. + + Only annotates those patterns in which all nodes return True when ``filter_fn`` is applied + to them. + """ + model.graph.eliminate_dead_code() + model.recompile() + + matcher = _SubgraphMatcherWithNameNodeMap(pattern_gm, ignore_literals=True) + matches = matcher.match(model.graph) + + annotated_partitions = [] + for match in matches: + name_node_map = match.name_node_map + binary_op_node: _Node = name_node_map["binary_op"] + if "output" in name_node_map: + output_node = name_node_map["output"] + # In this case, an activation is applied to the weighted module output + output_node = name_node_map["output"] + # If the output is same as binary_op_node, it means we have an inplace activation, + # so we need to correct the binary_op_node + if binary_op_node == output_node: + binary_op_node = binary_op_node.args[0] + partition = [output_node, binary_op_node] + else: + output_node = None + partition = [binary_op_node] + + if output_node is not None and len(binary_op_node.users) > 1: + raise ValueError("Binary op with activation has more than one users.") + + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + input_act_qspec = quantization_config.input_activation if quantization_config else None + output_act_qspec = quantization_config.output_activation if quantization_config else None + + input_qspec_map = {} + input_act0 = binary_op_node.args[0] + if isinstance(input_act0, _Node): + input_qspec_map[input_act0] = input_act_qspec + + input_act1 = binary_op_node.args[1] + if isinstance(input_act1, _Node): + input_qspec_map[input_act1] = input_act_qspec + + if output_node is None: + binary_op_node.meta["quantization_annotation"] = _QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_act_qspec, + _annotated=True, + ) + else: + binary_op_node.meta["quantization_annotation"] = _QuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + output_act_qspec = _adjust_activation_qspec(node=output_node, qspec=output_act_qspec) + output_node.meta["quantization_annotation"] = _QuantizationAnnotation( + output_qspec=output_act_qspec, + _annotated=True, + ) + _mark_nodes_as_annotated(partition) + annotated_partitions.append(partition) + return annotated_partitions + + +def annotate_unary_shared_observer_ops( + model: _torch.fx.GraphModule, + ops: _List[_Callable], + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + Annotates all nodes in ``model``, which correspond to unary ops specified in ``ops``. + + input --> op --> output + + input and output nodes share the same quantization parameters. + """ + partitions = _get_source_partitions(model.graph, ops, filter_fn) + annotated_partitions = [] + for _, op_partitions in partitions.items(): + for partition in op_partitions: + output_node = partition.output_nodes[0] + op_node = partition.nodes[0] + if _is_annotated([output_node, op_node]): + continue + + input_node = op_node.args[0] + + input_act_qspec = quantization_config.input_activation if quantization_config else None + output_act_qspec = ( + quantization_config.output_activation if quantization_config else None + ) + + if ( + "quantization_annotation" not in input_node.meta + or not input_node.meta["quantization_annotation"]._annotated + or input_node.meta["quantization_annotation"].output_qspec is None + or input_act_qspec is None + or output_act_qspec is None + ): + continue + + # input and output of op will share quantization parameter with input of op + act_qspec = _SharedQuantizationSpec(input_node) + op_node.meta["quantization_annotation"] = _QuantizationAnnotation( + input_qspec_map={ + input_node: act_qspec, + }, + _annotated=True, + ) + output_node.meta["quantization_annotation"] = _QuantizationAnnotation( + output_qspec=act_qspec, + _annotated=True, + ) + annotated_partitions.append(partition.nodes) + return annotated_partitions + + +def annotate_conv_bn_act_helper( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, + use_bn: bool = False, +) -> _Optional[_List[_List[_Node]]]: + """ + A helper function for annotating all patterns involving convolution operations, i.e., + + input -> conv -> batch_norm -> activation -> output + + conv can be either 1D, 2D or 3D + batch_norm and activation may or may not be applied. + """ + annotated_partitions = [] + + pattern_map = { + True: get_conv_bn_pattern, + False: get_conv_pattern, + } + + conv_dims = [1, 2, 3] + combinations = _itertools.product(conv_dims, _supported_activations, [True, False]) + for conv_dim, act_fn, act_in_place in combinations: + pattern_gm = pattern_map[use_bn](conv_dim, act_fn, act_in_place) + annotated_partitions.extend( + annotate_weighted_mod_pattern(model, pattern_gm, quantization_config, filter_fn) + ) + + combinations = _itertools.product(conv_dims, _supported_activations_no_inplace) + for conv_dim, act_fn in combinations: + pattern_gm = pattern_map[use_bn](conv_dim, act_fn, act_in_place=False) + annotated_partitions.extend( + annotate_weighted_mod_pattern(model, pattern_gm, quantization_config, filter_fn) + ) + + return annotated_partitions + + +def annotate_linear_bn_act_helper( + model: _torch.fx.GraphModule, + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, + use_bn: bool = False, +) -> _Optional[_List[_List[_Node]]]: + """ + A helper function for annotating all patterns involving linear operations, i.e., + + input -> linear -> batch_norm -> activation -> output + + batch_norm and activation may or may not be applied. + """ + annotated_partitions = [] + + pattern_map = { + True: get_linear_bn_pattern, + False: get_linear_pattern, + } + + combinations = _itertools.product(_supported_activations, [True, False]) + for act_fn, act_in_place in combinations: + pattern_gm = pattern_map[use_bn](act_fn, act_in_place) + annotated_partitions.extend( + annotate_weighted_mod_pattern(model, pattern_gm, quantization_config, filter_fn) + ) + + for act_fn in _supported_activations_no_inplace: + pattern_gm = pattern_map[use_bn](act_fn, act_in_place=False) + annotated_partitions.extend( + annotate_weighted_mod_pattern(model, pattern_gm, quantization_config, filter_fn) + ) + + return annotated_partitions + + +def annotate_binary_op_helper( + model: _torch.fx.GraphModule, + binary_ops: _List[_Callable], + quantization_config: _Optional[_AnnotationConfig], + filter_fn: _Optional[_Callable[[_Node], bool]] = None, +) -> _Optional[_List[_List[_Node]]]: + """ + A helper function for annotating all patterns involving binary operations, i.e., + using ``quantization_config``. + + input_1 --- + \ + --> binary_op -> activation -> output + / + input_2 --- + + activation may or may not be applied in the pattern. + """ + annotated_partitions = [] + + combinations = _itertools.product(binary_ops, _supported_activations, [True, False]) + for binary_op, act_fn, act_in_place in combinations: + pattern_gm = get_binary_op_act_pattern(binary_op, act_fn, act_in_place) + annotated_partitions.extend( + annotate_binary_op_act_pattern(model, pattern_gm, quantization_config, filter_fn) + ) + + combinations = _itertools.product(binary_ops, _supported_activations_no_inplace) + for binary_op, act_fn in combinations: + pattern_gm = get_binary_op_act_pattern(binary_op, act_fn, act_in_place=False) + annotated_partitions.extend( + annotate_binary_op_act_pattern(model, pattern_gm, quantization_config, filter_fn) + ) + + return annotated_partitions diff --git a/coremltools/optimize/torch/quantization/_backend_config.py b/coremltools/optimize/torch/quantization/_backend_config.py index b0e67930a..e32b359e0 100644 --- a/coremltools/optimize/torch/quantization/_backend_config.py +++ b/coremltools/optimize/torch/quantization/_backend_config.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause @@ -863,7 +863,7 @@ def get_backend_config() -> _BackendConfig: def get_supported_modules() -> _List[_Any]: """ - Returns a list of modules which are supported for quantization + Returns a tuple of modules which are supported for quantization aware training. """ return tuple(_BackendConfigRegistry.supported_modules) diff --git a/coremltools/optimize/torch/quantization/_backend_config_utils.py b/coremltools/optimize/torch/quantization/_backend_config_utils.py index 2a0d6d29b..a918ca0da 100644 --- a/coremltools/optimize/torch/quantization/_backend_config_utils.py +++ b/coremltools/optimize/torch/quantization/_backend_config_utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause @@ -41,20 +41,34 @@ weighted_dtype_configs = [ - # weight int, act float + # weight int, act float, weight dtype signed _DTypeConfig( input_dtype=_torch.float, output_dtype=_torch.float, weight_dtype=_torch.qint8, bias_dtype=_torch.float, ), - # weight int, act int + # weight int, act float, weight dtype unsigned + _DTypeConfig( + input_dtype=_torch.float, + output_dtype=_torch.float, + weight_dtype=_torch.quint8, + bias_dtype=_torch.float, + ), + # weight int, act int, weight dtype signed _DTypeConfig( input_dtype=_torch.quint8, output_dtype=_torch.quint8, weight_dtype=_torch.qint8, bias_dtype=_torch.float, ), + # weight int, act int, weight dtype unsigned + _DTypeConfig( + input_dtype=_torch.quint8, + output_dtype=_torch.quint8, + weight_dtype=_torch.quint8, + bias_dtype=_torch.float, + ), ] diff --git a/coremltools/optimize/torch/quantization/_configure.py b/coremltools/optimize/torch/quantization/_configure.py index 350f2c479..eaf3d0f50 100644 --- a/coremltools/optimize/torch/quantization/_configure.py +++ b/coremltools/optimize/torch/quantization/_configure.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause @@ -6,6 +6,7 @@ from collections import defaultdict as _defaultdict from typing import Any as _Any from typing import Optional as _Optional +from typing import Tuple as _Tuple import torch as _torch import torch.ao.quantization as _aoquant @@ -155,7 +156,7 @@ def __init__( self._modules_to_replace = _defaultdict(list) self._new_act_post_process = dict() - def prepare(self, model: _nn.Module, example_inputs: _Any): + def prepare(self, model: _nn.Module, example_inputs: _Tuple[_Any, ...]): """ Performs graph passes on model to configure activation and weight quantization layers. """ diff --git a/coremltools/optimize/torch/quantization/_coreml_quantizer_utils.py b/coremltools/optimize/torch/quantization/_coreml_quantizer_utils.py index 7821c2243..868338cdc 100644 --- a/coremltools/optimize/torch/quantization/_coreml_quantizer_utils.py +++ b/coremltools/optimize/torch/quantization/_coreml_quantizer_utils.py @@ -11,20 +11,21 @@ import torch as _torch import torch.nn.functional as _F + +_IS_TORCH_OLDER_THAN_2_3 = tuple(map(int, _torch.__version__.split(".")[:2])) < (2, 3) _IS_TORCH_OLDER_THAN_2_4 = tuple(map(int, _torch.__version__.split(".")[:2])) < (2, 4) if _IS_TORCH_OLDER_THAN_2_4: from torch.ao.quantization.pt2e.utils import get_aten_graph_module else: from torch.ao.quantization.pt2e.utils import _get_aten_graph_module_for_pattern + from torch.ao.quantization.quantizer.quantizer import ( FixedQParamsQuantizationSpec as _FixedQParamsQuantizationSpec, ) from torch.ao.quantization.quantizer.quantizer import ( QuantizationAnnotation as _QuantizationAnnotation, ) -from torch.ao.quantization.quantizer.quantizer import ( - QuantizationSpec as _TorchQuantizationSpec, -) +from torch.ao.quantization.quantizer.quantizer import QuantizationSpec as _TorchQuantizationSpec from torch.ao.quantization.quantizer.quantizer import ( QuantizationSpecBase as _TorchQuantizationSpecBase, ) @@ -75,8 +76,10 @@ def _get_aten_graph_module( pattern: _torch.nn.Module, example_inputs: _Tuple[_torch.Tensor], is_cuda: bool = False ): - if _IS_TORCH_OLDER_THAN_2_4: + if _IS_TORCH_OLDER_THAN_2_3: return get_aten_graph_module(pattern.forward, example_inputs, is_cuda) + elif _IS_TORCH_OLDER_THAN_2_4: + return get_aten_graph_module(pattern, example_inputs, is_cuda) else: return _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda) diff --git a/coremltools/optimize/torch/quantization/_qconfig_mapping.py b/coremltools/optimize/torch/quantization/_qconfig_mapping.py index 506a54b19..702df079c 100644 --- a/coremltools/optimize/torch/quantization/_qconfig_mapping.py +++ b/coremltools/optimize/torch/quantization/_qconfig_mapping.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause @@ -14,6 +14,7 @@ from coremltools.optimize.torch.quantization._backend_config import ( get_supported_modules as _get_supported_modules, ) +from coremltools.optimize.torch.quantization._utils import get_quant_range as _get_quant_range from coremltools.optimize.torch.quantization.quantization_config import ( LinearQuantizerConfig as _LinearQuantizerConfig, ) @@ -125,6 +126,16 @@ def _create_qconfig_from_quantization_config( is_per_channel=False, ), ) + + quant_min, quant_max = ( + _get_quant_range( + n_bits=quantization_config.weight_n_bits, + dtype=quantization_config.weight_dtype, + ) + if quantization_config.weight_n_bits < 8 + else (None, None) + ) + weight_qconfig = _aoquant.FakeQuantize.with_args( observer=_ObserverType.get_observer( quantization_config.weight_observer, @@ -135,6 +146,8 @@ def _create_qconfig_from_quantization_config( quantization_config.quantization_scheme, is_per_channel=quantization_config.weight_per_channel, ), + quant_min=quant_min, + quant_max=quant_max, ) return _aoquant.QConfig(activation=activation_qconfig, weight=weight_qconfig) diff --git a/coremltools/optimize/torch/quantization/_utils.py b/coremltools/optimize/torch/quantization/_utils.py index 5ddfa1c7b..d85f90bf5 100644 --- a/coremltools/optimize/torch/quantization/_utils.py +++ b/coremltools/optimize/torch/quantization/_utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause @@ -9,6 +9,7 @@ from typing import Dict as _Dict from typing import List as _List from typing import Optional as _Optional +from typing import Tuple as _Tuple import torch as _torch import torch.ao.quantization as _aoquant @@ -16,6 +17,9 @@ from torch.ao.quantization.backend_config import BackendConfig as _BackendConfig from torch.ao.quantization.backend_config import ObservationType as _ObservationType +from coremltools.optimize.torch._utils.metadata_utils import ( + CompressionMetadata as _CompressionMetadata, +) from coremltools.optimize.torch._utils.version_utils import is_torch_2 as _is_torch_2 @@ -153,3 +157,26 @@ def get_share_qparams_ops(backend_config: _BackendConfig): for op in configs if configs[op].observation_type == _ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT ] + + +def get_quant_range(n_bits: int, dtype: _torch.dtype) -> _Tuple[int, int]: + """ + Returns quant_max and quant_min values for a given quantization n_bits. + """ + max_q = 2**n_bits + if dtype in [_torch.quint8, _torch.uint8]: + quant_min = 0 + quant_max = max_q - 1 + else: + quant_min = -max_q / 2 + quant_max = max_q / 2 - 1 + return int(quant_min), int(quant_max) + + +def register_compression_metadata(submodule, config): + metadata = _CompressionMetadata("weight") + metadata.compression_type = ["quantization"] + metadata.quantization_n_bits = config.weight_n_bits + metadata.quantization_scale = submodule.weight_scale.detach().clone().unsqueeze(-1) + metadata.zero_point = submodule.weight_zero_point.detach().clone().unsqueeze(-1) + metadata.register(submodule) diff --git a/coremltools/optimize/torch/quantization/modules/__init__.py b/coremltools/optimize/torch/quantization/modules/__init__.py index 25c7d28c5..5dc5e6747 100644 --- a/coremltools/optimize/torch/quantization/modules/__init__.py +++ b/coremltools/optimize/torch/quantization/modules/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause diff --git a/coremltools/optimize/torch/quantization/modules/fused_modules.py b/coremltools/optimize/torch/quantization/modules/fused_modules.py index ee5eeebcb..17b78080c 100644 --- a/coremltools/optimize/torch/quantization/modules/fused_modules.py +++ b/coremltools/optimize/torch/quantization/modules/fused_modules.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause @@ -34,7 +34,6 @@ def weight(self): class ConvAct1d(_ConvAct): pass - class ConvAct2d(_ConvAct): pass diff --git a/coremltools/optimize/torch/quantization/modules/qat_modules.py b/coremltools/optimize/torch/quantization/modules/qat_modules.py index 452941b61..d6d8046b6 100644 --- a/coremltools/optimize/torch/quantization/modules/qat_modules.py +++ b/coremltools/optimize/torch/quantization/modules/qat_modules.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause @@ -84,55 +84,40 @@ class ConvAct1d(_ConvAct): root_mod = _nn.Conv1d qat_mod = _nnqat.Conv1d fused_mod = _fuse.ConvAct1d - - def __init__(self, conv: _nnqat.Conv1d, act: _nn.Module, qconfig: _aoquant.QConfig): - super().__init__(conv, act, qconfig) + pass class ConvAct2d(_ConvAct): root_mod = _nn.Conv2d qat_mod = _nnqat.Conv2d fused_mod = _fuse.ConvAct2d - - def __init__(self, conv: _nnqat.Conv2d, act: _nn.Module, qconfig: _aoquant.QConfig): - super().__init__(conv, act, qconfig) + pass class ConvAct3d(_ConvAct): root_mod = _nn.Conv3d qat_mod = _nnqat.Conv3d fused_mod = _fuse.ConvAct3d - - def __init__(self, conv: _nnqat.Conv3d, act: _nn.Module, qconfig: _aoquant.QConfig): - super().__init__(conv, act, qconfig) + pass class ConvBnAct1d(_ConvBnAct): intr_mod = _nni.ConvBn1d qat_mod = _nniqat.ConvBn1d fused_mod = _fuse.ConvAct1d - - def __init__(self, conv: _nniqat.ConvBn1d, act: _nn.Module, qconfig: _aoquant.QConfig): - super().__init__(conv, act, qconfig) - + pass class ConvBnAct2d(_ConvBnAct): intr_mod = _nni.ConvBn2d qat_mod = _nniqat.ConvBn2d fused_mod = _fuse.ConvAct2d - - def __init__(self, conv: _nniqat.ConvBn2d, act: _nn.Module, qconfig: _aoquant.QConfig): - super().__init__(conv, act, qconfig) - + pass class ConvBnAct3d(_ConvBnAct): intr_mod = _nni.ConvBn3d qat_mod = _nniqat.ConvBn3d fused_mod = _fuse.ConvAct3d - - def __init__(self, conv: _nniqat.ConvBn3d, act: _nn.Module, qconfig: _aoquant.QConfig): - super().__init__(conv, act, qconfig) - + pass class LinearAct(_torch.nn.Sequential): def __init__(self, linear: _nnqat.Linear, act: _nn.Module, qconfig: _aoquant.QConfig): diff --git a/coremltools/optimize/torch/quantization/modules/quantized_modules.py b/coremltools/optimize/torch/quantization/modules/quantized_modules.py index f1810c045..61e7ce21f 100644 --- a/coremltools/optimize/torch/quantization/modules/quantized_modules.py +++ b/coremltools/optimize/torch/quantization/modules/quantized_modules.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause @@ -24,23 +24,17 @@ def from_float(cls, float_conv_act, weight_qparams): class QuantizedConvAct1d(_QuantizedConvAct): ref_quant_mod = _reference.Conv1d - - def __init__(self, conv: _reference.Conv1d, act: _nn.Module): - super().__init__(conv, act) + pass class QuantizedConvAct2d(_QuantizedConvAct): ref_quant_mod = _reference.Conv2d - - def __init__(self, conv: _reference.Conv2d, act: _nn.Module): - super().__init__(conv, act) + pass class QuantizedConvAct3d(_QuantizedConvAct): ref_quant_mod = _reference.Conv3d - - def __init__(self, conv: _reference.Conv3d, act: _nn.Module): - super().__init__(conv, act) + pass class QuantizedLinearAct(_nn.Sequential): diff --git a/coremltools/optimize/torch/quantization/post_training_quantization.py b/coremltools/optimize/torch/quantization/post_training_quantization.py new file mode 100644 index 000000000..9da2c308a --- /dev/null +++ b/coremltools/optimize/torch/quantization/post_training_quantization.py @@ -0,0 +1,460 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import logging as _logging +from collections import OrderedDict as _OrderedDict +from typing import Any as _Any +from typing import Callable as _Callable +from typing import Dict as _Dict +from typing import NewType as _NewType +from typing import Optional as _Optional +from typing import Tuple as _Tuple +from typing import Type as _Type +from typing import Union as _Union + +import cattrs as _cattrs +import torch as _torch +import torch.nn as _nn +from attr import define as _define +from attr import field as _field +from attrs import validators as _validators + +from coremltools.converters.mil.mil.ops.defs.iOS18 import constexpr_blockwise_shift_scale +from coremltools.optimize.coreml._utils import compute_qparams as _cti_compute_qparams +from coremltools.optimize.torch._utils.metadata_utils import ( + CompressionMetadata as _CompressionMetadata, +) +from coremltools.optimize.torch._utils.report_utils import ( + compute_post_training_report as _compute_post_training_report, +) +from coremltools.optimize.torch._utils.torch_utils import get_atomic_layers as _get_atomic_layers +from coremltools.optimize.torch._utils.torch_utils import ( + get_n_bits_from_dtype as _get_n_bits_from_dtype, +) +from coremltools.optimize.torch._utils.torch_utils import ( + maybe_convert_str_to_dtype as _maybe_convert_str_to_dtype, +) +from coremltools.optimize.torch._utils.torch_utils import ( + maybe_convert_str_to_mod_type as _maybe_convert_str_to_mod_type, +) +from coremltools.optimize.torch._utils.validation_utils import ( + validate_param_config as _validate_param_config, +) +from coremltools.optimize.torch.base_model_optimizer import ( + BasePostTrainingModelOptimizer as _BasePostTrainingModelOptimizer, +) +from coremltools.optimize.torch.base_model_optimizer import _Report +from coremltools.optimize.torch.optimization_config import ( + ModuleOptimizationConfig as _ModuleOptimizationConfig, +) +from coremltools.optimize.torch.optimization_config import OptimizationConfig as _OptimizationConfig +from coremltools.optimize.torch.optimization_config import ( + QuantizationGranularity, + _structure_from_dict_hook_factory, +) +from coremltools.optimize.torch.quantization import QuantizationScheme as _QuantizationScheme + +_default_ptq_options = { + "weight_dtype": "int8", + "granularity": "per_channel", + "quantization_scheme": _QuantizationScheme.symmetric, + "block_size": None, +} + +_logger = _logging.getLogger(__name__) + + +@_define +class ModulePostTrainingQuantizerConfig(_ModuleOptimizationConfig): + """ + Configuration class for specifying global and module level quantizer options for + :py:class:`PostTrainingQuantizer` algorithm. + + Args: + weight_dtype (:py:class:`torch.dtype`): The dtype to use for quantizing the weights. The number of bits used + for quantization is inferred from the dtype. When dtype is set to :py:class:`torch.float32`, the weights + corresponding to that layer are not quantized. Defaults to :py:class:`torch.int8` which corresponds to + 8-bit quantization. + granularity (:py:class:`QuantizationGranularity`): Specifies the granularity at which quantization parameters + will be computed. Can be one of ``per_channel``, ``per_tensor`` or ``per_block``. When using ``per_block``, + ``block_size`` argument must be specified. Defaults to ``per_channel``. + quantization_scheme (:py:class:`~.coremltools.optimize.torch.quantization.quantization_config.QuantizationScheme`): Type of + quantization configuration to use. When this parameter is set to ``QuantizationScheme.symmetric``, + all weights are quantized with zero point as zero. When it is set to ``QuantizationScheme.affine``, + zero point can be set anywhere in the range of values allowed for the quantized weight. + Defaults to ``QuantizationScheme.symmetric``. + block_size (:obj:`tuple` of :obj:`int` or :obj:`int`): When ``block_size`` is specified, ``block_size`` + number of values will share the same quantization parameters of scale (and zero point if applicable) across + the input-channel axis. A tuple of integers can be provided for arbitrary sized blockwise quantization. + See below for more details on different possible configurations. Defaults to ``None``. + + This class supports few different configurations to structure the quantization: + + 1. **Per-channel quantization**: This is the default configuration where ``granularity`` is ``per_channel`` and + ``block_size`` is ``None``. In this configuration, quantization parameters are computed for each output channel. + + 2. **Per-tensor quantization**: In this configuration, quantization paramaters are computed for the tensor as a whole. That is, + all values in the tensor will share a single scale (and a single zero point if applicable). The ``granularity`` argument is set + to ``per_tensor``. + + 3. **Per-block quantization**: This is used to structure the tensor for block-wise quantization. For this configuration, + the ``granularity`` is set to ``per_block`` and the ``block_size`` argument has to be specified. + The ``block_size`` argument can either be: + * int: In this case, each row along the output-channel axis will have its own quantization parameters (similar to ``per_channel``). + Additionally, ``block_size`` number of values will share the same quantization parameters, along the input-channel axis. + For example, for a weight matrix of shape ``(10, 10)``, if we provide ``block_size = 2``, the shape of the quantization + parameters would be ``(10, 5)``. + * tuple: For more advanced configuration, users can provide an arbitrary N-D shaped block to share the quantization parameters. + This is specified in the form of a tuple where each value corresponds to the block size for the respective axis of the + weight matrix. The length of the provided tuple should be at most the number of dimensions of the weight matrix. + + .. note: + When performing 4-bit quantization, ``weight_dtype`` is set to :py:class:`torch.int8` for ``int4`` or + :py:class:`torch.uint8` for ``uint4``. This is because PyTorch currently doesn't provide support for 4-bit + data types. However, the quantization range is set according to 4-bit quantization and based on + whether the ``weight_dtype`` is signed or unsigned. + """ + + weight_dtype: _Union[str, _torch.dtype] = _field( + default=_default_ptq_options["weight_dtype"], + ) + granularity: QuantizationGranularity = _field( + default=_default_ptq_options["granularity"], + converter=QuantizationGranularity, + validator=_validators.in_(QuantizationGranularity), + ) + quantization_scheme: _QuantizationScheme = _field( + default=_default_ptq_options["quantization_scheme"], + converter=_QuantizationScheme, + validator=_validators.in_(_QuantizationScheme), + ) + block_size: _Optional[_Union[int, _Tuple[int]]] = _field( + default=_default_ptq_options["block_size"], + converter=lambda val: (val,) if type(val) is int else val, + validator=_validators.optional( + _validators.deep_iterable( + member_validator=_validators.instance_of(int), + iterable_validator=_validators.instance_of(tuple), + ) + ), + ) + + def __attrs_post_init__(self): + self.weight_n_bits = _get_n_bits_from_dtype(self.weight_dtype) + self.weight_dtype = _maybe_convert_str_to_dtype(self.weight_dtype) + if self.weight_dtype not in [_torch.int8, _torch.uint8, _torch.float32]: + raise ValueError( + f"weight_dtype must be one of (torch.uint8, torch.float32) not {self.weight_dtype}" + ) + + @block_size.validator + def per_block_granularity(self, attribute, value): + if self.granularity == QuantizationGranularity.per_block: + assert ( + value is not None + ), "block_size has to be specified along with per_block granularity." + else: + assert ( + value is None + ), "block_size can't be specified along with per_tensor or per_channel granularity." + + @classmethod + def from_dict(cls, config_dict): + converter = _cattrs.Converter(forbid_extra_keys=True) + converter.register_structure_hook( + _Union[str, _torch.dtype], + lambda obj, type: obj, + ) + return converter.structure_attrs_fromdict(config_dict, cls) + + +_ModuleTypeConfigType = _NewType( + "ModuleTypeConfigType", + _Dict[_Union[_Callable, str], _Optional[ModulePostTrainingQuantizerConfig]], +) + + +@_define +class PostTrainingQuantizerConfig(_OptimizationConfig): + """ + Configuration class for specifying how different submodules of a model + should be post-training quantized by :py:class:`PostTrainingQuantizer`. + + Args: + global_config (:py:class:`ModulePostTrainingQuantizerConfig`): Config to be applied globally + to all supported modules. + module_type_configs (:obj:`dict` of :obj:`str` to :py:class:`ModulePostTrainingQuantizerConfig`): + Module type configs applied to a specific module class, such as :py:class:`torch.nn.Linear`. + The keys can be either strings or module classes. + module_name_configs (:obj:`dict` of :obj:`str` to :py:class:`ModulePostTrainingQuantizerConfig`): + Module name configs applied to specific modules. This can be a dictionary with module names pointing to their + corresponding :py:class:`ModulePostTrainingQuantizerConfig`s + """ + + global_config: _Optional[ModulePostTrainingQuantizerConfig] = _field( + default=None, + validator=_validators.optional(_validators.instance_of(ModulePostTrainingQuantizerConfig)), + ) + module_type_configs: _ModuleTypeConfigType = _field( + factory=_OrderedDict, + validator=_validators.deep_mapping( + key_validator=_validators.instance_of((str, _Callable)), + value_validator=_validators.optional( + _validators.instance_of(ModulePostTrainingQuantizerConfig) + ), + mapping_validator=_validators.instance_of(dict), + ), + ) + module_name_configs: _Dict[str, _Optional[ModulePostTrainingQuantizerConfig]] = _field( + factory=_OrderedDict, + validator=_validators.deep_mapping( + key_validator=_validators.instance_of(str), + value_validator=_validators.optional( + _validators.instance_of(ModulePostTrainingQuantizerConfig) + ), + mapping_validator=_validators.instance_of(dict), + ), + ) + + def __attrs_post_init__(self): + if ( + self.global_config is None + and len(self.module_type_configs) == 0 + and len(self.module_name_configs) == 0 + ): + self.global_config = ModulePostTrainingQuantizerConfig() + self.module_type_configs = { + _maybe_convert_str_to_mod_type(key): val + for key, val in self.module_type_configs.items() + } + self._validate_same_params(["quantization_scheme"]) + + @classmethod + def from_dict(cls, config_dict: _Dict[str, _Any]) -> "PostTrainingQuantizerConfig": + super().from_dict(config_dict) + converter = _cattrs.Converter(forbid_extra_keys=True) + converter.register_structure_hook( + _Union[str, _torch.dtype], + lambda obj, type: obj, + ) + converter.register_structure_hook( + _ModuleTypeConfigType, + _structure_from_dict_hook_factory(ModulePostTrainingQuantizerConfig), + ) + return converter.structure_attrs_fromdict(config_dict, cls) + + +class PostTrainingQuantizer(_BasePostTrainingModelOptimizer): + """ + Perform post-training quantization on a torch model. After quantization, weights of all + submodules selected for quantization contain full precision values obtained by quantizing + and dequantizing the original weights which captures the error induced by quantization. + + .. note:: + After quantization, the weight values stored will still remain in full precision, therefore + the PyTorch model size will not be reduced. To see the reduction in model size, please convert + the model using ``coremltools.convert(...)``, which will produce a MIL model containing the + compressed weights. + + Example: + + .. code-block:: python + + import torch.nn as nn + from coremltools.optimize.torch.quantization import ( + PostTrainingQuantizerConfig, + PostTrainingQuantizer, + ) + + model = nn.Sequential( + OrderedDict( + { + "conv": nn.Conv2d(1, 20, (3, 3)), + "relu1": nn.ReLU(), + "conv2": nn.Conv2d(20, 20, (3, 3)), + "relu2": nn.ReLU(), + } + ) + ) + + # initialize the quantizer + config = PostTrainingquantizerConfig.from_dict( + { + "global_config": { + "weight_dtype": "int8", + }, + } + ) + + ptq = PostTrainingQuantizer(model, config) + quantized_model = ptq.compress() + + Args: + model (:obj:`torch.nn.Module`): Module to be compressed. + config (:py:class:`PostTrainingQuantizerConfig`): Config that specifies how + different submodules in the model will be quantized. + """ + + _supported_modules: _Tuple[_Type[_torch.nn.Module]] = ( + _nn.Conv2d, + _nn.Linear, + _nn.MultiheadAttention, + ) + + def __init__(self, model: _torch.nn.Module, config: PostTrainingQuantizerConfig = None): + config = PostTrainingQuantizerConfig() if config is None else config + super().__init__(model, config) + + @_torch.no_grad() + def _quantize_weight( + self, + submod_name: str, + submodule: _torch.nn.Module, + submod_config: ModulePostTrainingQuantizerConfig, + param_name: str, + ) -> _Optional[_Tuple[_torch.Tensor, _torch.Tensor, _Optional[_torch.Tensor]]]: + """ + Helper function to perform the quantization on a PyTorch submodule's parameter + + Args: + submod_name (:obj:`str`): Name of the submodule + submodule (:obj:`torch.nn.Module`) Submodule which is being quantized + submod_config (:py:class:`ModulePostTrainingQuantizerConfig`): Config for the submodule + param_name (:obj:`str`): Name of the parameter within the submodule to quantize + + .. note:: + This function extracts the numpy array out of the torch weight value and + uses that for performing the quantization + """ + + torch_weight = submodule.get_parameter(param_name) + weight = torch_weight.numpy() + + block_sizes = [0] * weight.ndim + assert len(block_sizes) >= 2, "Weight matrix has to be at least 2D or greater" + + if submod_config.granularity == QuantizationGranularity.per_channel: + block_sizes[0] = 1 + + elif submod_config.granularity == QuantizationGranularity.per_block: + updated_config = _validate_param_config( + submod_name + "." + param_name, + torch_weight, + submod_config, + ["quantization_block_size"], + ) + if not updated_config: + _logger.warning(f"Unable to quantize layer {submod_name} - skipping it.") + return + block_size_config = list(updated_config.block_size) + block_sizes[: len(block_size_config)] = block_size_config + + quantization_mode = ( + "LINEAR_SYMMETRIC" + if submod_config.quantization_scheme == _QuantizationScheme.symmetric + else "LINEAR" + ) + + ret = _cti_compute_qparams( + weight=weight, + nbits=submod_config.weight_n_bits, + quantization_mode=quantization_mode, + dtype=weight.dtype, + block_sizes=block_sizes, + signed=True, # Always used signed dtype range + ) + + if ret is None: + _logger.warning(f"Unable to quantize layer {submod_name} - skipping it.") + return + + quant_weight, scale, zp = ret + + dequant_weight = constexpr_blockwise_shift_scale.decompress( + quant_weight, + scale, + zp, + ) + + # Convert back to torch tensors + dequant_weight = _torch.from_numpy(dequant_weight) + scale = _torch.from_numpy(scale) + if zp is not None: + zp = _torch.from_numpy(zp) + + # Replace the parameter's value + submodule.get_parameter(param_name).data.copy_(dequant_weight) + + # Register compression metadata + metadata = self._get_compression_metadata(param_name, submod_config, scale, zp) + metadata.register(submodule) + + def _get_compression_metadata(self, param_name, submod_config, scale, zero_point): + metadata = _CompressionMetadata(param_name) + + metadata.compression_type = ["quantization"] + metadata.quantization_n_bits = submod_config.weight_n_bits + metadata.quantization_scale = scale + if submod_config.quantization_scheme == _QuantizationScheme.affine: + assert zero_point is not None + metadata.zero_point = zero_point + + return metadata + + def compress(self, inplace: bool = False) -> _torch.nn.Module: + """ + Compress the supported layers in the module by quantizing each weight value of the layer. + + Args: + inplace (:obj:`bool`): If ``True``, model transformations are carried out in-place and + the original module is mutated, otherwise a copy of the model is mutated and returned. + Defaults to ``False``. + """ + self._model = super().compress(inplace=inplace) + for submod_name, submodule in _get_atomic_layers( + self._model, layer_types=list(self._supported_modules) + ).items(): + submod_config = self._config.get_module_config(submod_name, submodule) + if submod_config is None: + continue + + # TODO: Replace this with supported modules abstraction + # --- Conv2D & Linear layers --- + if isinstance(submodule, (_nn.Conv2d, _nn.Linear)): + assert hasattr( + submodule, "weight" + ), f"No parameter named weight in submodule {submod_name}" + self._quantize_weight(submod_name, submodule, submod_config, "weight") + + # --- MultiheadAttention layer --- + elif isinstance(submodule, _nn.MultiheadAttention): + param_names = [ + "in_proj_weight", + "q_proj_weight", + "k_proj_weight", + "v_proj_weight", + ] + for param_name in param_names: + if not hasattr(submodule, param_name): + continue + if getattr(submodule, param_name) is None: + continue + self._quantize_weight(submod_name, submodule, submod_config, param_name) + + if hasattr(submodule, "out_proj") and submodule.out_proj.weight is not None: + self._quantize_weight( + f"{submod_name}.out_proj", + submodule.out_proj, + submod_config, + "weight", + ) + return self._model + + def report(self) -> _Report: + return _compute_post_training_report( + self._uncompressed_model, + self._model, + supported_modules=self._supported_modules, + ) diff --git a/coremltools/optimize/torch/quantization/quantization_config.py b/coremltools/optimize/torch/quantization/quantization_config.py index fa82531d1..d6c3404ad 100644 --- a/coremltools/optimize/torch/quantization/quantization_config.py +++ b/coremltools/optimize/torch/quantization/quantization_config.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause @@ -22,6 +22,9 @@ from attr import field as _field from attrs import validators as _validators +from coremltools.optimize.torch._utils.torch_utils import ( + get_n_bits_from_dtype as _get_n_bits_from_dtype, +) from coremltools.optimize.torch._utils.torch_utils import ( maybe_convert_str_to_dtype as _maybe_convert_str_to_dtype, ) @@ -94,6 +97,10 @@ def get_qscheme( } +# Backends only support 4 and 8 bit quantization +_SUPPORTED_N_BITS = [4, 8, 32] + + @_define class ModuleLinearQuantizerConfig(_ModuleOptimizationConfig): """ @@ -159,9 +166,10 @@ class ModuleLinearQuantizerConfig(_ModuleOptimizationConfig): # mode, thus more closely simulating the inference numerics during training time. Args: - weight_dtype (:py:class:`torch.dtype`): The dtype to use for quantizing the weights. When dtype - is set to :py:class:`torch.float32`, the weights corresponding to that layer are not quantized. - Defaults to :py:class:`torch.qint8`. + weight_dtype (:py:class:`torch.dtype`): The dtype to use for quantizing the weights. The number of bits used + for quantization is inferred from the dtype. When dtype is set to :py:class:`torch.float32`, the weights + corresponding to that layer are not quantized. Defaults to :py:class:`torch.int8` which corresponds to + 8-bit quantization. weight_observer (:py:class:`ObserverType`): Type of observer to use for quantizing weights. Defaults to ``moving_average_min_max``. weight_per_channel (:obj:`bool`): When ``True``, weights are quantized per channel; otherwise, per tensor. @@ -181,16 +189,11 @@ class ModuleLinearQuantizerConfig(_ModuleOptimizationConfig): quantization simulation, the third to disabling observers, and the last to freezing batch norm statistics. Defaults to ``None``, which means the ``step`` method of :py:class:`LinearQuantizer` will be a no-op and all observers and quantization simulation will be turned on from the first step, batch norm layers always - operate in training mode, and mean and variance statistics collection is not frozen. + operate in training mode, and mean and varaince statistics collection is not frozen. """ - weight_dtype: _torch.dtype = _field( + weight_dtype: _Union[str, _torch.dtype] = _field( default=_default_quantization_options["weight_dtype"], - converter=_maybe_convert_str_to_dtype, - validator=[ - _validators.instance_of(_torch.dtype), - _validators.in_([_torch.qint8, _torch.quint8, _torch.int8, _torch.uint8, _torch.float32]), - ], ) weight_observer: ObserverType = _field( default=_default_quantization_options["observer"], @@ -206,7 +209,7 @@ class ModuleLinearQuantizerConfig(_ModuleOptimizationConfig): converter=_maybe_convert_str_to_dtype, validator=[ _validators.instance_of(_torch.dtype), - _validators.in_([_torch.quint8, _torch.uint8, _torch.float32]), + _validators.in_([_torch.quint8, _torch.float32]), ], ) activation_observer: ObserverType = _field( @@ -230,6 +233,13 @@ class ModuleLinearQuantizerConfig(_ModuleOptimizationConfig): ) def __attrs_post_init__(self): + self.weight_n_bits = _get_n_bits_from_dtype(self.weight_dtype) + self.weight_dtype = _maybe_convert_str_to_dtype(self.weight_dtype) + if self.weight_dtype not in [_torch.qint8, _torch.quint8, _torch.float32]: + raise ValueError( + f"weight_dtype must be one of (_torch.qint8, _torch.quint8, _torch.float32) not {self.weight_dtype}" + ) + if self.weight_dtype == _torch.float32 and self.activation_dtype != _torch.float32: raise ValueError( f"Unsupported configuration: weight_dtype = {self.weight_dtype}, " @@ -246,6 +256,15 @@ def _check_milestones(self, attribute, value): f"Refer to docs for more information." ) + @classmethod + def from_dict(cls, config_dict): + converter = _cattrs.Converter(forbid_extra_keys=True) + converter.register_structure_hook( + _Union[str, _torch.dtype], + lambda obj, type: obj, + ) + return converter.structure_attrs_fromdict(config_dict, cls) + _ModuleTypeConfigType = _NewType( "ModuleTypeConfigType", @@ -298,6 +317,17 @@ class LinearQuantizerConfig(_OptimizationConfig): } ) + # If model has some methods and attributes which are not used in the forward + # pass, but are needed to be preserved after quantization is added, they can + # be preserved on the quantized model by passing them in preserved_attributes + # parameter + + model = MyModel() + model.key_1 = value_1 + model.key_2 = value_2 + + config = LinearQuantizerConfig.from_dict({"preserved_attributes": ["key_1", "key_2"]}) + Args: global_config (:py:class:`ModuleLinearQuantizerConfig`): Config to be applied globally to all supported modules. Missing values are chosen from the default config. @@ -311,6 +341,9 @@ class LinearQuantizerConfig(_OptimizationConfig): from the top level module using the ``module.get_submodule(target)`` method. non_traceable_module_names (:obj:`list` of :obj:`str`): Names of modules which cannot be traced using ``torch.fx``. + preserved_attributes (:obj:`list` of :obj:`str`): Names of attributes of the model + which should be preserved on the prepared and finalized models, even if they are not + used in the model's forward pass. .. note:: The ``quantization_scheme`` parameter must be the same across all configs. @@ -347,6 +380,12 @@ class LinearQuantizerConfig(_OptimizationConfig): member_validator=_validators.instance_of(str), ), ) + preserved_attributes: _List[str] = _field( + factory=list, + validator=_validators.deep_iterable( + member_validator=_validators.instance_of(str), + ), + ) def __attrs_post_init__(self): if ( @@ -365,6 +404,10 @@ def __attrs_post_init__(self): def from_dict(cls, config_dict: _Dict[str, _Any]) -> "LinearQuantizerConfig": super().from_dict(config_dict) converter = _cattrs.Converter(forbid_extra_keys=True) + converter.register_structure_hook( + _Union[str, _torch.dtype], + lambda obj, type: obj, + ) converter.register_structure_hook( _ModuleTypeConfigType, _structure_from_dict_hook_factory(ModuleLinearQuantizerConfig), diff --git a/coremltools/optimize/torch/quantization/quantizer.py b/coremltools/optimize/torch/quantization/quantizer.py index 151ca4947..44ab8053f 100644 --- a/coremltools/optimize/torch/quantization/quantizer.py +++ b/coremltools/optimize/torch/quantization/quantizer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause @@ -12,13 +12,17 @@ import torch as _torch import torch.ao.quantization as _aoquant import torch.nn.intrinsic.qat as _nnintrinsicqat +from torch.ao.quantization.fx.custom_config import ConvertCustomConfig as _ConvertCustomConfig from torch.ao.quantization.fx.custom_config import PrepareCustomConfig as _PrepareCustomConfig from torch.ao.quantization.quantize_fx import convert_to_reference_fx as _convert_to_reference_fx from coremltools.optimize.torch._utils.math_utils import rmse_error as _rmse_error +from coremltools.optimize.torch._utils.metadata_utils import ( + register_metadata_version as _register_metadata_version, +) from coremltools.optimize.torch._utils.torch_utils import get_eval_model as _get_eval_model from coremltools.optimize.torch.base_model_optimizer import ( - BaseModelOptimizer as _BaseModelOptimizer, + BaseTrainingTimeModelOptimizer as _BaseTrainingTimeModelOptimizer, ) from coremltools.optimize.torch.base_model_optimizer import _Report from coremltools.optimize.torch.quantization._backend_config import ( @@ -31,6 +35,9 @@ QATConfigurationHandler as _QATConfigurationHandler, ) from coremltools.optimize.torch.quantization._qconfig_mapping import _QConfigMappingBuilder +from coremltools.optimize.torch.quantization._utils import ( + register_compression_metadata as _register_compression_metadata, +) from coremltools.optimize.torch.quantization.quantization_config import ( LinearQuantizerConfig as _LinearQuantizerConfig, ) @@ -41,7 +48,7 @@ _logger = _logging.getLogger(__name__) -class Quantizer(_BaseModelOptimizer): +class Quantizer(_BaseTrainingTimeModelOptimizer): pass @@ -140,12 +147,14 @@ def _construct_global_config(self) -> _ModuleLinearQuantizerConfig: return config return _ModuleLinearQuantizerConfig() - def prepare(self, example_inputs: _Any, inplace: bool = False) -> _torch.nn.Module: + def prepare(self, example_inputs: _Tuple[_Any, ...], inplace: bool = False) -> _torch.nn.Module: """ Prepares the model for quantization aware training by inserting :py:class:`torch.ao.quantization.FakeQuantize` layers in the model in appropriate places. Args: + example_inputs (:obj:`Tuple[Any, ...]`): Example inputs for forward function of the model, + tuple of positional args (keyword args can be passed as positional args as well) inplace (:obj:`bool`): If ``True``, model transformations are carried out in-place and the original module is mutated, otherwise a copy of the model is mutated and returned. @@ -163,13 +172,14 @@ def prepare(self, example_inputs: _Any, inplace: bool = False) -> _torch.nn.Modu "will be a no-op." ) return self._model - model = self._model - if not inplace: - model = _copy.deepcopy(self._model) + model = self._get_model_for_compression(inplace=inplace) model.train() prepare_custom_config = _PrepareCustomConfig().set_non_traceable_module_names( self._config.non_traceable_module_names ) + prepare_custom_config = prepare_custom_config.set_preserved_attributes( + self._config.preserved_attributes + ) qat_handler = _QATConfigurationHandler( prepare_custom_config=prepare_custom_config, qconfig_mapping=self._qconfig_mapping, @@ -245,9 +255,21 @@ def finalize( if not inplace: model = _copy.deepcopy(model) model.eval() + convert_custom_config = _ConvertCustomConfig().set_preserved_attributes( + self._config.preserved_attributes + ) finalized_model = _convert_to_reference_fx( - model, qconfig_mapping=self._qconfig_mapping, backend_config=_get_backend_config() + model, + convert_custom_config=convert_custom_config, + qconfig_mapping=self._qconfig_mapping, + backend_config=_get_backend_config(), ) + _register_metadata_version(finalized_model) + for name, submodule in finalized_model.named_modules(remove_duplicate=True): + if hasattr(submodule, "weight_scale"): + submod_config = self._config.get_module_config(name, submodule) + _register_compression_metadata(submodule, submod_config) + if model is None: self._model = finalized_model return finalized_model diff --git a/coremltools/proto/FeatureTypes_pb2.py b/coremltools/proto/FeatureTypes_pb2.py index ef54f1120..021f5bfd6 100644 --- a/coremltools/proto/FeatureTypes_pb2.py +++ b/coremltools/proto/FeatureTypes_pb2.py @@ -16,10 +16,12 @@ DESCRIPTOR = _descriptor.FileDescriptor( - name='FeatureTypes.proto', - package='CoreML.Specification', - syntax='proto3', - serialized_pb=_b('\n\x12\x46\x65\x61tureTypes.proto\x12\x14\x43oreML.Specification\"\x12\n\x10Int64FeatureType\"\x13\n\x11\x44oubleFeatureType\"\x13\n\x11StringFeatureType\"3\n\tSizeRange\x12\x12\n\nlowerBound\x18\x01 \x01(\x04\x12\x12\n\nupperBound\x18\x02 \x01(\x03\"\x95\x05\n\x10ImageFeatureType\x12\r\n\x05width\x18\x01 \x01(\x03\x12\x0e\n\x06height\x18\x02 \x01(\x03\x12V\n\x0f\x65numeratedSizes\x18\x15 \x01(\x0b\x32;.CoreML.Specification.ImageFeatureType.EnumeratedImageSizesH\x00\x12O\n\x0eimageSizeRange\x18\x1f \x01(\x0b\x32\x35.CoreML.Specification.ImageFeatureType.ImageSizeRangeH\x00\x12\x45\n\ncolorSpace\x18\x03 \x01(\x0e\x32\x31.CoreML.Specification.ImageFeatureType.ColorSpace\x1a*\n\tImageSize\x12\r\n\x05width\x18\x01 \x01(\x04\x12\x0e\n\x06height\x18\x02 \x01(\x04\x1aW\n\x14\x45numeratedImageSizes\x12?\n\x05sizes\x18\x01 \x03(\x0b\x32\x30.CoreML.Specification.ImageFeatureType.ImageSize\x1a{\n\x0eImageSizeRange\x12\x33\n\nwidthRange\x18\x01 \x01(\x0b\x32\x1f.CoreML.Specification.SizeRange\x12\x34\n\x0bheightRange\x18\x02 \x01(\x0b\x32\x1f.CoreML.Specification.SizeRange\"]\n\nColorSpace\x12\x17\n\x13INVALID_COLOR_SPACE\x10\x00\x12\r\n\tGRAYSCALE\x10\n\x12\x07\n\x03RGB\x10\x14\x12\x07\n\x03\x42GR\x10\x1e\x12\x15\n\x11GRAYSCALE_FLOAT16\x10(B\x11\n\x0fSizeFlexibility\"\x9d\x05\n\x10\x41rrayFeatureType\x12\r\n\x05shape\x18\x01 \x03(\x03\x12\x46\n\x08\x64\x61taType\x18\x02 \x01(\x0e\x32\x34.CoreML.Specification.ArrayFeatureType.ArrayDataType\x12S\n\x10\x65numeratedShapes\x18\x15 \x01(\x0b\x32\x37.CoreML.Specification.ArrayFeatureType.EnumeratedShapesH\x00\x12G\n\nshapeRange\x18\x1f \x01(\x0b\x32\x31.CoreML.Specification.ArrayFeatureType.ShapeRangeH\x00\x12\x19\n\x0fintDefaultValue\x18) \x01(\x05H\x01\x12\x1b\n\x11\x66loatDefaultValue\x18\x33 \x01(\x02H\x01\x12\x1c\n\x12\x64oubleDefaultValue\x18= \x01(\x01H\x01\x1a\x16\n\x05Shape\x12\r\n\x05shape\x18\x01 \x03(\x03\x1aP\n\x10\x45numeratedShapes\x12<\n\x06shapes\x18\x01 \x03(\x0b\x32,.CoreML.Specification.ArrayFeatureType.Shape\x1a\x41\n\nShapeRange\x12\x33\n\nsizeRanges\x18\x01 \x03(\x0b\x32\x1f.CoreML.Specification.SizeRange\"e\n\rArrayDataType\x12\x1b\n\x17INVALID_ARRAY_DATA_TYPE\x10\x00\x12\r\n\x07\x46LOAT32\x10\xa0\x80\x04\x12\x0c\n\x06\x44OUBLE\x10\xc0\x80\x04\x12\x0b\n\x05INT32\x10\xa0\x80\x08\x12\r\n\x07\x46LOAT16\x10\x90\x80\x04\x42\x12\n\x10ShapeFlexibilityB\x16\n\x14\x64\x65\x66\x61ultOptionalValue\"\xa4\x01\n\x15\x44ictionaryFeatureType\x12>\n\x0cint64KeyType\x18\x01 \x01(\x0b\x32&.CoreML.Specification.Int64FeatureTypeH\x00\x12@\n\rstringKeyType\x18\x02 \x01(\x0b\x32\'.CoreML.Specification.StringFeatureTypeH\x00\x42\t\n\x07KeyType\"\xcd\x01\n\x13SequenceFeatureType\x12;\n\tint64Type\x18\x01 \x01(\x0b\x32&.CoreML.Specification.Int64FeatureTypeH\x00\x12=\n\nstringType\x18\x03 \x01(\x0b\x32\'.CoreML.Specification.StringFeatureTypeH\x00\x12\x32\n\tsizeRange\x18\x65 \x01(\x0b\x32\x1f.CoreML.Specification.SizeRangeB\x06\n\x04Type\"\xee\x03\n\x0b\x46\x65\x61tureType\x12;\n\tint64Type\x18\x01 \x01(\x0b\x32&.CoreML.Specification.Int64FeatureTypeH\x00\x12=\n\ndoubleType\x18\x02 \x01(\x0b\x32\'.CoreML.Specification.DoubleFeatureTypeH\x00\x12=\n\nstringType\x18\x03 \x01(\x0b\x32\'.CoreML.Specification.StringFeatureTypeH\x00\x12;\n\timageType\x18\x04 \x01(\x0b\x32&.CoreML.Specification.ImageFeatureTypeH\x00\x12@\n\x0emultiArrayType\x18\x05 \x01(\x0b\x32&.CoreML.Specification.ArrayFeatureTypeH\x00\x12\x45\n\x0e\x64ictionaryType\x18\x06 \x01(\x0b\x32+.CoreML.Specification.DictionaryFeatureTypeH\x00\x12\x41\n\x0csequenceType\x18\x07 \x01(\x0b\x32).CoreML.Specification.SequenceFeatureTypeH\x00\x12\x13\n\nisOptional\x18\xe8\x07 \x01(\x08\x42\x06\n\x04TypeB\x02H\x03\x62\x06proto3') + name="FeatureTypes.proto", + package="CoreML.Specification", + syntax="proto3", + serialized_pb=_b( + '\n\x12\x46\x65\x61tureTypes.proto\x12\x14\x43oreML.Specification"\x12\n\x10Int64FeatureType"\x13\n\x11\x44oubleFeatureType"\x13\n\x11StringFeatureType"3\n\tSizeRange\x12\x12\n\nlowerBound\x18\x01 \x01(\x04\x12\x12\n\nupperBound\x18\x02 \x01(\x03"\x95\x05\n\x10ImageFeatureType\x12\r\n\x05width\x18\x01 \x01(\x03\x12\x0e\n\x06height\x18\x02 \x01(\x03\x12V\n\x0f\x65numeratedSizes\x18\x15 \x01(\x0b\x32;.CoreML.Specification.ImageFeatureType.EnumeratedImageSizesH\x00\x12O\n\x0eimageSizeRange\x18\x1f \x01(\x0b\x32\x35.CoreML.Specification.ImageFeatureType.ImageSizeRangeH\x00\x12\x45\n\ncolorSpace\x18\x03 \x01(\x0e\x32\x31.CoreML.Specification.ImageFeatureType.ColorSpace\x1a*\n\tImageSize\x12\r\n\x05width\x18\x01 \x01(\x04\x12\x0e\n\x06height\x18\x02 \x01(\x04\x1aW\n\x14\x45numeratedImageSizes\x12?\n\x05sizes\x18\x01 \x03(\x0b\x32\x30.CoreML.Specification.ImageFeatureType.ImageSize\x1a{\n\x0eImageSizeRange\x12\x33\n\nwidthRange\x18\x01 \x01(\x0b\x32\x1f.CoreML.Specification.SizeRange\x12\x34\n\x0bheightRange\x18\x02 \x01(\x0b\x32\x1f.CoreML.Specification.SizeRange"]\n\nColorSpace\x12\x17\n\x13INVALID_COLOR_SPACE\x10\x00\x12\r\n\tGRAYSCALE\x10\n\x12\x07\n\x03RGB\x10\x14\x12\x07\n\x03\x42GR\x10\x1e\x12\x15\n\x11GRAYSCALE_FLOAT16\x10(B\x11\n\x0fSizeFlexibility"\x9d\x05\n\x10\x41rrayFeatureType\x12\r\n\x05shape\x18\x01 \x03(\x03\x12\x46\n\x08\x64\x61taType\x18\x02 \x01(\x0e\x32\x34.CoreML.Specification.ArrayFeatureType.ArrayDataType\x12S\n\x10\x65numeratedShapes\x18\x15 \x01(\x0b\x32\x37.CoreML.Specification.ArrayFeatureType.EnumeratedShapesH\x00\x12G\n\nshapeRange\x18\x1f \x01(\x0b\x32\x31.CoreML.Specification.ArrayFeatureType.ShapeRangeH\x00\x12\x19\n\x0fintDefaultValue\x18) \x01(\x05H\x01\x12\x1b\n\x11\x66loatDefaultValue\x18\x33 \x01(\x02H\x01\x12\x1c\n\x12\x64oubleDefaultValue\x18= \x01(\x01H\x01\x1a\x16\n\x05Shape\x12\r\n\x05shape\x18\x01 \x03(\x03\x1aP\n\x10\x45numeratedShapes\x12<\n\x06shapes\x18\x01 \x03(\x0b\x32,.CoreML.Specification.ArrayFeatureType.Shape\x1a\x41\n\nShapeRange\x12\x33\n\nsizeRanges\x18\x01 \x03(\x0b\x32\x1f.CoreML.Specification.SizeRange"e\n\rArrayDataType\x12\x1b\n\x17INVALID_ARRAY_DATA_TYPE\x10\x00\x12\r\n\x07\x46LOAT32\x10\xa0\x80\x04\x12\x0c\n\x06\x44OUBLE\x10\xc0\x80\x04\x12\x0b\n\x05INT32\x10\xa0\x80\x08\x12\r\n\x07\x46LOAT16\x10\x90\x80\x04\x42\x12\n\x10ShapeFlexibilityB\x16\n\x14\x64\x65\x66\x61ultOptionalValue"\xa4\x01\n\x15\x44ictionaryFeatureType\x12>\n\x0cint64KeyType\x18\x01 \x01(\x0b\x32&.CoreML.Specification.Int64FeatureTypeH\x00\x12@\n\rstringKeyType\x18\x02 \x01(\x0b\x32\'.CoreML.Specification.StringFeatureTypeH\x00\x42\t\n\x07KeyType"\xcd\x01\n\x13SequenceFeatureType\x12;\n\tint64Type\x18\x01 \x01(\x0b\x32&.CoreML.Specification.Int64FeatureTypeH\x00\x12=\n\nstringType\x18\x03 \x01(\x0b\x32\'.CoreML.Specification.StringFeatureTypeH\x00\x12\x32\n\tsizeRange\x18\x65 \x01(\x0b\x32\x1f.CoreML.Specification.SizeRangeB\x06\n\x04Type"W\n\x10StateFeatureType\x12;\n\tarrayType\x18\x01 \x01(\x0b\x32&.CoreML.Specification.ArrayFeatureTypeH\x00\x42\x06\n\x04Type"\xab\x04\n\x0b\x46\x65\x61tureType\x12;\n\tint64Type\x18\x01 \x01(\x0b\x32&.CoreML.Specification.Int64FeatureTypeH\x00\x12=\n\ndoubleType\x18\x02 \x01(\x0b\x32\'.CoreML.Specification.DoubleFeatureTypeH\x00\x12=\n\nstringType\x18\x03 \x01(\x0b\x32\'.CoreML.Specification.StringFeatureTypeH\x00\x12;\n\timageType\x18\x04 \x01(\x0b\x32&.CoreML.Specification.ImageFeatureTypeH\x00\x12@\n\x0emultiArrayType\x18\x05 \x01(\x0b\x32&.CoreML.Specification.ArrayFeatureTypeH\x00\x12\x45\n\x0e\x64ictionaryType\x18\x06 \x01(\x0b\x32+.CoreML.Specification.DictionaryFeatureTypeH\x00\x12\x41\n\x0csequenceType\x18\x07 \x01(\x0b\x32).CoreML.Specification.SequenceFeatureTypeH\x00\x12;\n\tstateType\x18\x08 \x01(\x0b\x32&.CoreML.Specification.StateFeatureTypeH\x00\x12\x13\n\nisOptional\x18\xe8\x07 \x01(\x08\x42\x06\n\x04TypeB\x02H\x03\x62\x06proto3' + ), ) @@ -629,86 +631,231 @@ ) +_STATEFEATURETYPE = _descriptor.Descriptor( + name="StateFeatureType", + full_name="CoreML.Specification.StateFeatureType", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="arrayType", + full_name="CoreML.Specification.StateFeatureType.arrayType", + index=0, + number=1, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name="Type", + full_name="CoreML.Specification.StateFeatureType.Type", + index=0, + containing_type=None, + fields=[], + ), + ], + serialized_start=1870, + serialized_end=1957, +) + + _FEATURETYPE = _descriptor.Descriptor( - name='FeatureType', - full_name='CoreML.Specification.FeatureType', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='int64Type', full_name='CoreML.Specification.FeatureType.int64Type', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='doubleType', full_name='CoreML.Specification.FeatureType.doubleType', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='stringType', full_name='CoreML.Specification.FeatureType.stringType', index=2, - number=3, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='imageType', full_name='CoreML.Specification.FeatureType.imageType', index=3, - number=4, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='multiArrayType', full_name='CoreML.Specification.FeatureType.multiArrayType', index=4, - number=5, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='dictionaryType', full_name='CoreML.Specification.FeatureType.dictionaryType', index=5, - number=6, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='sequenceType', full_name='CoreML.Specification.FeatureType.sequenceType', index=6, - number=7, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='isOptional', full_name='CoreML.Specification.FeatureType.isOptional', index=7, - number=1000, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - _descriptor.OneofDescriptor( - name='Type', full_name='CoreML.Specification.FeatureType.Type', - index=0, containing_type=None, fields=[]), - ], - serialized_start=1871, - serialized_end=2365, + name="FeatureType", + full_name="CoreML.Specification.FeatureType", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="int64Type", + full_name="CoreML.Specification.FeatureType.int64Type", + index=0, + number=1, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="doubleType", + full_name="CoreML.Specification.FeatureType.doubleType", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="stringType", + full_name="CoreML.Specification.FeatureType.stringType", + index=2, + number=3, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="imageType", + full_name="CoreML.Specification.FeatureType.imageType", + index=3, + number=4, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="multiArrayType", + full_name="CoreML.Specification.FeatureType.multiArrayType", + index=4, + number=5, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="dictionaryType", + full_name="CoreML.Specification.FeatureType.dictionaryType", + index=5, + number=6, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="sequenceType", + full_name="CoreML.Specification.FeatureType.sequenceType", + index=6, + number=7, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="stateType", + full_name="CoreML.Specification.FeatureType.stateType", + index=7, + number=8, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="isOptional", + full_name="CoreML.Specification.FeatureType.isOptional", + index=8, + number=1000, + type=8, + cpp_type=7, + label=1, + has_default_value=False, + default_value=False, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name="Type", + full_name="CoreML.Specification.FeatureType.Type", + index=0, + containing_type=None, + fields=[], + ), + ], + serialized_start=1960, + serialized_end=2515, ) _IMAGEFEATURETYPE_IMAGESIZE.containing_type = _IMAGEFEATURETYPE @@ -736,75 +883,106 @@ _ARRAYFEATURETYPE.fields_by_name['enumeratedShapes'].message_type = _ARRAYFEATURETYPE_ENUMERATEDSHAPES _ARRAYFEATURETYPE.fields_by_name['shapeRange'].message_type = _ARRAYFEATURETYPE_SHAPERANGE _ARRAYFEATURETYPE_ARRAYDATATYPE.containing_type = _ARRAYFEATURETYPE -_ARRAYFEATURETYPE.oneofs_by_name['ShapeFlexibility'].fields.append( - _ARRAYFEATURETYPE.fields_by_name['enumeratedShapes']) -_ARRAYFEATURETYPE.fields_by_name['enumeratedShapes'].containing_oneof = _ARRAYFEATURETYPE.oneofs_by_name['ShapeFlexibility'] -_ARRAYFEATURETYPE.oneofs_by_name['ShapeFlexibility'].fields.append( - _ARRAYFEATURETYPE.fields_by_name['shapeRange']) -_ARRAYFEATURETYPE.fields_by_name['shapeRange'].containing_oneof = _ARRAYFEATURETYPE.oneofs_by_name['ShapeFlexibility'] -_ARRAYFEATURETYPE.oneofs_by_name['defaultOptionalValue'].fields.append( - _ARRAYFEATURETYPE.fields_by_name['intDefaultValue']) -_ARRAYFEATURETYPE.fields_by_name['intDefaultValue'].containing_oneof = _ARRAYFEATURETYPE.oneofs_by_name['defaultOptionalValue'] -_ARRAYFEATURETYPE.oneofs_by_name['defaultOptionalValue'].fields.append( - _ARRAYFEATURETYPE.fields_by_name['floatDefaultValue']) -_ARRAYFEATURETYPE.fields_by_name['floatDefaultValue'].containing_oneof = _ARRAYFEATURETYPE.oneofs_by_name['defaultOptionalValue'] -_ARRAYFEATURETYPE.oneofs_by_name['defaultOptionalValue'].fields.append( - _ARRAYFEATURETYPE.fields_by_name['doubleDefaultValue']) -_ARRAYFEATURETYPE.fields_by_name['doubleDefaultValue'].containing_oneof = _ARRAYFEATURETYPE.oneofs_by_name['defaultOptionalValue'] -_DICTIONARYFEATURETYPE.fields_by_name['int64KeyType'].message_type = _INT64FEATURETYPE -_DICTIONARYFEATURETYPE.fields_by_name['stringKeyType'].message_type = _STRINGFEATURETYPE -_DICTIONARYFEATURETYPE.oneofs_by_name['KeyType'].fields.append( - _DICTIONARYFEATURETYPE.fields_by_name['int64KeyType']) -_DICTIONARYFEATURETYPE.fields_by_name['int64KeyType'].containing_oneof = _DICTIONARYFEATURETYPE.oneofs_by_name['KeyType'] -_DICTIONARYFEATURETYPE.oneofs_by_name['KeyType'].fields.append( - _DICTIONARYFEATURETYPE.fields_by_name['stringKeyType']) -_DICTIONARYFEATURETYPE.fields_by_name['stringKeyType'].containing_oneof = _DICTIONARYFEATURETYPE.oneofs_by_name['KeyType'] -_SEQUENCEFEATURETYPE.fields_by_name['int64Type'].message_type = _INT64FEATURETYPE -_SEQUENCEFEATURETYPE.fields_by_name['stringType'].message_type = _STRINGFEATURETYPE -_SEQUENCEFEATURETYPE.fields_by_name['sizeRange'].message_type = _SIZERANGE -_SEQUENCEFEATURETYPE.oneofs_by_name['Type'].fields.append( - _SEQUENCEFEATURETYPE.fields_by_name['int64Type']) -_SEQUENCEFEATURETYPE.fields_by_name['int64Type'].containing_oneof = _SEQUENCEFEATURETYPE.oneofs_by_name['Type'] -_SEQUENCEFEATURETYPE.oneofs_by_name['Type'].fields.append( - _SEQUENCEFEATURETYPE.fields_by_name['stringType']) -_SEQUENCEFEATURETYPE.fields_by_name['stringType'].containing_oneof = _SEQUENCEFEATURETYPE.oneofs_by_name['Type'] -_FEATURETYPE.fields_by_name['int64Type'].message_type = _INT64FEATURETYPE -_FEATURETYPE.fields_by_name['doubleType'].message_type = _DOUBLEFEATURETYPE -_FEATURETYPE.fields_by_name['stringType'].message_type = _STRINGFEATURETYPE -_FEATURETYPE.fields_by_name['imageType'].message_type = _IMAGEFEATURETYPE -_FEATURETYPE.fields_by_name['multiArrayType'].message_type = _ARRAYFEATURETYPE -_FEATURETYPE.fields_by_name['dictionaryType'].message_type = _DICTIONARYFEATURETYPE -_FEATURETYPE.fields_by_name['sequenceType'].message_type = _SEQUENCEFEATURETYPE -_FEATURETYPE.oneofs_by_name['Type'].fields.append( - _FEATURETYPE.fields_by_name['int64Type']) -_FEATURETYPE.fields_by_name['int64Type'].containing_oneof = _FEATURETYPE.oneofs_by_name['Type'] -_FEATURETYPE.oneofs_by_name['Type'].fields.append( - _FEATURETYPE.fields_by_name['doubleType']) -_FEATURETYPE.fields_by_name['doubleType'].containing_oneof = _FEATURETYPE.oneofs_by_name['Type'] -_FEATURETYPE.oneofs_by_name['Type'].fields.append( - _FEATURETYPE.fields_by_name['stringType']) -_FEATURETYPE.fields_by_name['stringType'].containing_oneof = _FEATURETYPE.oneofs_by_name['Type'] -_FEATURETYPE.oneofs_by_name['Type'].fields.append( - _FEATURETYPE.fields_by_name['imageType']) -_FEATURETYPE.fields_by_name['imageType'].containing_oneof = _FEATURETYPE.oneofs_by_name['Type'] -_FEATURETYPE.oneofs_by_name['Type'].fields.append( - _FEATURETYPE.fields_by_name['multiArrayType']) -_FEATURETYPE.fields_by_name['multiArrayType'].containing_oneof = _FEATURETYPE.oneofs_by_name['Type'] -_FEATURETYPE.oneofs_by_name['Type'].fields.append( - _FEATURETYPE.fields_by_name['dictionaryType']) -_FEATURETYPE.fields_by_name['dictionaryType'].containing_oneof = _FEATURETYPE.oneofs_by_name['Type'] -_FEATURETYPE.oneofs_by_name['Type'].fields.append( - _FEATURETYPE.fields_by_name['sequenceType']) -_FEATURETYPE.fields_by_name['sequenceType'].containing_oneof = _FEATURETYPE.oneofs_by_name['Type'] -DESCRIPTOR.message_types_by_name['Int64FeatureType'] = _INT64FEATURETYPE -DESCRIPTOR.message_types_by_name['DoubleFeatureType'] = _DOUBLEFEATURETYPE -DESCRIPTOR.message_types_by_name['StringFeatureType'] = _STRINGFEATURETYPE -DESCRIPTOR.message_types_by_name['SizeRange'] = _SIZERANGE -DESCRIPTOR.message_types_by_name['ImageFeatureType'] = _IMAGEFEATURETYPE -DESCRIPTOR.message_types_by_name['ArrayFeatureType'] = _ARRAYFEATURETYPE -DESCRIPTOR.message_types_by_name['DictionaryFeatureType'] = _DICTIONARYFEATURETYPE -DESCRIPTOR.message_types_by_name['SequenceFeatureType'] = _SEQUENCEFEATURETYPE -DESCRIPTOR.message_types_by_name['FeatureType'] = _FEATURETYPE +_ARRAYFEATURETYPE.oneofs_by_name["ShapeFlexibility"].fields.append( + _ARRAYFEATURETYPE.fields_by_name["enumeratedShapes"] +) +_ARRAYFEATURETYPE.fields_by_name[ + "enumeratedShapes" +].containing_oneof = _ARRAYFEATURETYPE.oneofs_by_name["ShapeFlexibility"] +_ARRAYFEATURETYPE.oneofs_by_name["ShapeFlexibility"].fields.append( + _ARRAYFEATURETYPE.fields_by_name["shapeRange"] +) +_ARRAYFEATURETYPE.fields_by_name["shapeRange"].containing_oneof = _ARRAYFEATURETYPE.oneofs_by_name[ + "ShapeFlexibility" +] +_ARRAYFEATURETYPE.oneofs_by_name["defaultOptionalValue"].fields.append( + _ARRAYFEATURETYPE.fields_by_name["intDefaultValue"] +) +_ARRAYFEATURETYPE.fields_by_name[ + "intDefaultValue" +].containing_oneof = _ARRAYFEATURETYPE.oneofs_by_name["defaultOptionalValue"] +_ARRAYFEATURETYPE.oneofs_by_name["defaultOptionalValue"].fields.append( + _ARRAYFEATURETYPE.fields_by_name["floatDefaultValue"] +) +_ARRAYFEATURETYPE.fields_by_name[ + "floatDefaultValue" +].containing_oneof = _ARRAYFEATURETYPE.oneofs_by_name["defaultOptionalValue"] +_ARRAYFEATURETYPE.oneofs_by_name["defaultOptionalValue"].fields.append( + _ARRAYFEATURETYPE.fields_by_name["doubleDefaultValue"] +) +_ARRAYFEATURETYPE.fields_by_name[ + "doubleDefaultValue" +].containing_oneof = _ARRAYFEATURETYPE.oneofs_by_name["defaultOptionalValue"] +_DICTIONARYFEATURETYPE.fields_by_name["int64KeyType"].message_type = _INT64FEATURETYPE +_DICTIONARYFEATURETYPE.fields_by_name["stringKeyType"].message_type = _STRINGFEATURETYPE +_DICTIONARYFEATURETYPE.oneofs_by_name["KeyType"].fields.append( + _DICTIONARYFEATURETYPE.fields_by_name["int64KeyType"] +) +_DICTIONARYFEATURETYPE.fields_by_name[ + "int64KeyType" +].containing_oneof = _DICTIONARYFEATURETYPE.oneofs_by_name["KeyType"] +_DICTIONARYFEATURETYPE.oneofs_by_name["KeyType"].fields.append( + _DICTIONARYFEATURETYPE.fields_by_name["stringKeyType"] +) +_DICTIONARYFEATURETYPE.fields_by_name[ + "stringKeyType" +].containing_oneof = _DICTIONARYFEATURETYPE.oneofs_by_name["KeyType"] +_SEQUENCEFEATURETYPE.fields_by_name["int64Type"].message_type = _INT64FEATURETYPE +_SEQUENCEFEATURETYPE.fields_by_name["stringType"].message_type = _STRINGFEATURETYPE +_SEQUENCEFEATURETYPE.fields_by_name["sizeRange"].message_type = _SIZERANGE +_SEQUENCEFEATURETYPE.oneofs_by_name["Type"].fields.append( + _SEQUENCEFEATURETYPE.fields_by_name["int64Type"] +) +_SEQUENCEFEATURETYPE.fields_by_name[ + "int64Type" +].containing_oneof = _SEQUENCEFEATURETYPE.oneofs_by_name["Type"] +_SEQUENCEFEATURETYPE.oneofs_by_name["Type"].fields.append( + _SEQUENCEFEATURETYPE.fields_by_name["stringType"] +) +_SEQUENCEFEATURETYPE.fields_by_name[ + "stringType" +].containing_oneof = _SEQUENCEFEATURETYPE.oneofs_by_name["Type"] +_STATEFEATURETYPE.fields_by_name["arrayType"].message_type = _ARRAYFEATURETYPE +_STATEFEATURETYPE.oneofs_by_name["Type"].fields.append( + _STATEFEATURETYPE.fields_by_name["arrayType"] +) +_STATEFEATURETYPE.fields_by_name["arrayType"].containing_oneof = _STATEFEATURETYPE.oneofs_by_name[ + "Type" +] +_FEATURETYPE.fields_by_name["int64Type"].message_type = _INT64FEATURETYPE +_FEATURETYPE.fields_by_name["doubleType"].message_type = _DOUBLEFEATURETYPE +_FEATURETYPE.fields_by_name["stringType"].message_type = _STRINGFEATURETYPE +_FEATURETYPE.fields_by_name["imageType"].message_type = _IMAGEFEATURETYPE +_FEATURETYPE.fields_by_name["multiArrayType"].message_type = _ARRAYFEATURETYPE +_FEATURETYPE.fields_by_name["dictionaryType"].message_type = _DICTIONARYFEATURETYPE +_FEATURETYPE.fields_by_name["sequenceType"].message_type = _SEQUENCEFEATURETYPE +_FEATURETYPE.fields_by_name["stateType"].message_type = _STATEFEATURETYPE +_FEATURETYPE.oneofs_by_name["Type"].fields.append(_FEATURETYPE.fields_by_name["int64Type"]) +_FEATURETYPE.fields_by_name["int64Type"].containing_oneof = _FEATURETYPE.oneofs_by_name["Type"] +_FEATURETYPE.oneofs_by_name["Type"].fields.append(_FEATURETYPE.fields_by_name["doubleType"]) +_FEATURETYPE.fields_by_name["doubleType"].containing_oneof = _FEATURETYPE.oneofs_by_name["Type"] +_FEATURETYPE.oneofs_by_name["Type"].fields.append(_FEATURETYPE.fields_by_name["stringType"]) +_FEATURETYPE.fields_by_name["stringType"].containing_oneof = _FEATURETYPE.oneofs_by_name["Type"] +_FEATURETYPE.oneofs_by_name["Type"].fields.append(_FEATURETYPE.fields_by_name["imageType"]) +_FEATURETYPE.fields_by_name["imageType"].containing_oneof = _FEATURETYPE.oneofs_by_name["Type"] +_FEATURETYPE.oneofs_by_name["Type"].fields.append(_FEATURETYPE.fields_by_name["multiArrayType"]) +_FEATURETYPE.fields_by_name["multiArrayType"].containing_oneof = _FEATURETYPE.oneofs_by_name["Type"] +_FEATURETYPE.oneofs_by_name["Type"].fields.append(_FEATURETYPE.fields_by_name["dictionaryType"]) +_FEATURETYPE.fields_by_name["dictionaryType"].containing_oneof = _FEATURETYPE.oneofs_by_name["Type"] +_FEATURETYPE.oneofs_by_name["Type"].fields.append(_FEATURETYPE.fields_by_name["sequenceType"]) +_FEATURETYPE.fields_by_name["sequenceType"].containing_oneof = _FEATURETYPE.oneofs_by_name["Type"] +_FEATURETYPE.oneofs_by_name["Type"].fields.append(_FEATURETYPE.fields_by_name["stateType"]) +_FEATURETYPE.fields_by_name["stateType"].containing_oneof = _FEATURETYPE.oneofs_by_name["Type"] +DESCRIPTOR.message_types_by_name["Int64FeatureType"] = _INT64FEATURETYPE +DESCRIPTOR.message_types_by_name["DoubleFeatureType"] = _DOUBLEFEATURETYPE +DESCRIPTOR.message_types_by_name["StringFeatureType"] = _STRINGFEATURETYPE +DESCRIPTOR.message_types_by_name["SizeRange"] = _SIZERANGE +DESCRIPTOR.message_types_by_name["ImageFeatureType"] = _IMAGEFEATURETYPE +DESCRIPTOR.message_types_by_name["ArrayFeatureType"] = _ARRAYFEATURETYPE +DESCRIPTOR.message_types_by_name["DictionaryFeatureType"] = _DICTIONARYFEATURETYPE +DESCRIPTOR.message_types_by_name["SequenceFeatureType"] = _SEQUENCEFEATURETYPE +DESCRIPTOR.message_types_by_name["StateFeatureType"] = _STATEFEATURETYPE +DESCRIPTOR.message_types_by_name["FeatureType"] = _FEATURETYPE _sym_db.RegisterFileDescriptor(DESCRIPTOR) Int64FeatureType = _reflection.GeneratedProtocolMessageType('Int64FeatureType', (_message.Message,), dict( @@ -911,6 +1089,17 @@ )) _sym_db.RegisterMessage(SequenceFeatureType) +StateFeatureType = _reflection.GeneratedProtocolMessageType( + "StateFeatureType", + (_message.Message,), + dict( + DESCRIPTOR=_STATEFEATURETYPE, + __module__="FeatureTypes_pb2" + # @@protoc_insertion_point(class_scope:CoreML.Specification.StateFeatureType) + ), +) +_sym_db.RegisterMessage(StateFeatureType) + FeatureType = _reflection.GeneratedProtocolMessageType('FeatureType', (_message.Message,), dict( DESCRIPTOR = _FEATURETYPE, __module__ = 'FeatureTypes_pb2' diff --git a/coremltools/proto/MIL_pb2.py b/coremltools/proto/MIL_pb2.py index b1be30e92..c6abbbb2e 100644 --- a/coremltools/proto/MIL_pb2.py +++ b/coremltools/proto/MIL_pb2.py @@ -17,83 +17,68 @@ DESCRIPTOR = _descriptor.FileDescriptor( - name='MIL.proto', - package='CoreML.Specification.MILSpec', - syntax='proto3', - serialized_pb=_b('\n\tMIL.proto\x12\x1c\x43oreML.Specification.MILSpec\"\xf3\x02\n\x07Program\x12\x0f\n\x07version\x18\x01 \x01(\x03\x12G\n\tfunctions\x18\x02 \x03(\x0b\x32\x34.CoreML.Specification.MILSpec.Program.FunctionsEntry\x12\x11\n\tdocString\x18\x03 \x01(\t\x12I\n\nattributes\x18\x04 \x03(\x0b\x32\x35.CoreML.Specification.MILSpec.Program.AttributesEntry\x1aX\n\x0e\x46unctionsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x35\n\x05value\x18\x02 \x01(\x0b\x32&.CoreML.Specification.MILSpec.Function:\x02\x38\x01\x1aV\n\x0f\x41ttributesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x32\n\x05value\x18\x02 \x01(\x0b\x32#.CoreML.Specification.MILSpec.Value:\x02\x38\x01\"\xbe\x03\n\x08\x46unction\x12<\n\x06inputs\x18\x01 \x03(\x0b\x32,.CoreML.Specification.MILSpec.NamedValueType\x12\r\n\x05opset\x18\x02 \x01(\t\x12_\n\x15\x62lock_specializations\x18\x03 \x03(\x0b\x32@.CoreML.Specification.MILSpec.Function.BlockSpecializationsEntry\x12J\n\nattributes\x18\x04 \x03(\x0b\x32\x36.CoreML.Specification.MILSpec.Function.AttributesEntry\x1a`\n\x19\x42lockSpecializationsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x32\n\x05value\x18\x02 \x01(\x0b\x32#.CoreML.Specification.MILSpec.Block:\x02\x38\x01\x1aV\n\x0f\x41ttributesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x32\n\x05value\x18\x02 \x01(\x0b\x32#.CoreML.Specification.MILSpec.Value:\x02\x38\x01\"\xb4\x02\n\x05\x42lock\x12<\n\x06inputs\x18\x01 \x03(\x0b\x32,.CoreML.Specification.MILSpec.NamedValueType\x12\x0f\n\x07outputs\x18\x02 \x03(\t\x12;\n\noperations\x18\x03 \x03(\x0b\x32\'.CoreML.Specification.MILSpec.Operation\x12G\n\nattributes\x18\x04 \x03(\x0b\x32\x33.CoreML.Specification.MILSpec.Block.AttributesEntry\x1aV\n\x0f\x41ttributesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x32\n\x05value\x18\x02 \x01(\x0b\x32#.CoreML.Specification.MILSpec.Value:\x02\x38\x01\"\xa9\x01\n\x08\x41rgument\x12\x41\n\targuments\x18\x01 \x03(\x0b\x32..CoreML.Specification.MILSpec.Argument.Binding\x1aZ\n\x07\x42inding\x12\x0e\n\x04name\x18\x01 \x01(\tH\x00\x12\x34\n\x05value\x18\x02 \x01(\x0b\x32#.CoreML.Specification.MILSpec.ValueH\x00\x42\t\n\x07\x62inding\"\xce\x03\n\tOperation\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x43\n\x06inputs\x18\x02 \x03(\x0b\x32\x33.CoreML.Specification.MILSpec.Operation.InputsEntry\x12=\n\x07outputs\x18\x03 \x03(\x0b\x32,.CoreML.Specification.MILSpec.NamedValueType\x12\x33\n\x06\x62locks\x18\x04 \x03(\x0b\x32#.CoreML.Specification.MILSpec.Block\x12K\n\nattributes\x18\x05 \x03(\x0b\x32\x37.CoreML.Specification.MILSpec.Operation.AttributesEntry\x1aU\n\x0bInputsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x35\n\x05value\x18\x02 \x01(\x0b\x32&.CoreML.Specification.MILSpec.Argument:\x02\x38\x01\x1aV\n\x0f\x41ttributesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x32\n\x05value\x18\x02 \x01(\x0b\x32#.CoreML.Specification.MILSpec.Value:\x02\x38\x01\"U\n\x0eNamedValueType\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x35\n\x04type\x18\x02 \x01(\x0b\x32\'.CoreML.Specification.MILSpec.ValueType\"\x95\x02\n\tValueType\x12>\n\ntensorType\x18\x01 \x01(\x0b\x32(.CoreML.Specification.MILSpec.TensorTypeH\x00\x12:\n\x08listType\x18\x02 \x01(\x0b\x32&.CoreML.Specification.MILSpec.ListTypeH\x00\x12<\n\ttupleType\x18\x03 \x01(\x0b\x32\'.CoreML.Specification.MILSpec.TupleTypeH\x00\x12\x46\n\x0e\x64ictionaryType\x18\x04 \x01(\x0b\x32,.CoreML.Specification.MILSpec.DictionaryTypeH\x00\x42\x06\n\x04type\"\xb7\x02\n\nTensorType\x12\x38\n\x08\x64\x61taType\x18\x01 \x01(\x0e\x32&.CoreML.Specification.MILSpec.DataType\x12\x0c\n\x04rank\x18\x02 \x01(\x03\x12;\n\ndimensions\x18\x03 \x03(\x0b\x32\'.CoreML.Specification.MILSpec.Dimension\x12L\n\nattributes\x18\x04 \x03(\x0b\x32\x38.CoreML.Specification.MILSpec.TensorType.AttributesEntry\x1aV\n\x0f\x41ttributesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x32\n\x05value\x18\x02 \x01(\x0b\x32#.CoreML.Specification.MILSpec.Value:\x02\x38\x01\"C\n\tTupleType\x12\x36\n\x05types\x18\x01 \x03(\x0b\x32\'.CoreML.Specification.MILSpec.ValueType\"z\n\x08ListType\x12\x35\n\x04type\x18\x01 \x01(\x0b\x32\'.CoreML.Specification.MILSpec.ValueType\x12\x37\n\x06length\x18\x02 \x01(\x0b\x32\'.CoreML.Specification.MILSpec.Dimension\"\x86\x01\n\x0e\x44ictionaryType\x12\x38\n\x07keyType\x18\x01 \x01(\x0b\x32\'.CoreML.Specification.MILSpec.ValueType\x12:\n\tvalueType\x18\x02 \x01(\x0b\x32\'.CoreML.Specification.MILSpec.ValueType\"\xfd\x01\n\tDimension\x12M\n\x08\x63onstant\x18\x01 \x01(\x0b\x32\x39.CoreML.Specification.MILSpec.Dimension.ConstantDimensionH\x00\x12K\n\x07unknown\x18\x02 \x01(\x0b\x32\x38.CoreML.Specification.MILSpec.Dimension.UnknownDimensionH\x00\x1a!\n\x11\x43onstantDimension\x12\x0c\n\x04size\x18\x01 \x01(\x04\x1a$\n\x10UnknownDimension\x12\x10\n\x08variadic\x18\x01 \x01(\x08\x42\x0b\n\tdimension\"\xb9\x04\n\x05Value\x12\x11\n\tdocString\x18\x01 \x01(\t\x12\x35\n\x04type\x18\x02 \x01(\x0b\x32\'.CoreML.Specification.MILSpec.ValueType\x12L\n\x0eimmediateValue\x18\x03 \x01(\x0b\x32\x32.CoreML.Specification.MILSpec.Value.ImmediateValueH\x00\x12J\n\rblobFileValue\x18\x05 \x01(\x0b\x32\x31.CoreML.Specification.MILSpec.Value.BlobFileValueH\x00\x1a\x8f\x02\n\x0eImmediateValue\x12;\n\x06tensor\x18\x01 \x01(\x0b\x32).CoreML.Specification.MILSpec.TensorValueH\x00\x12\x39\n\x05tuple\x18\x02 \x01(\x0b\x32(.CoreML.Specification.MILSpec.TupleValueH\x00\x12\x37\n\x04list\x18\x03 \x01(\x0b\x32\'.CoreML.Specification.MILSpec.ListValueH\x00\x12\x43\n\ndictionary\x18\x04 \x01(\x0b\x32-.CoreML.Specification.MILSpec.DictionaryValueH\x00\x42\x07\n\x05value\x1a\x31\n\rBlobFileValue\x12\x10\n\x08\x66ileName\x18\x01 \x01(\t\x12\x0e\n\x06offset\x18\x02 \x01(\x04\x42\x07\n\x05value\"\xac\x06\n\x0bTensorValue\x12J\n\x06\x66loats\x18\x01 \x01(\x0b\x32\x38.CoreML.Specification.MILSpec.TensorValue.RepeatedFloatsH\x00\x12\x46\n\x04ints\x18\x02 \x01(\x0b\x32\x36.CoreML.Specification.MILSpec.TensorValue.RepeatedIntsH\x00\x12H\n\x05\x62ools\x18\x03 \x01(\x0b\x32\x37.CoreML.Specification.MILSpec.TensorValue.RepeatedBoolsH\x00\x12L\n\x07strings\x18\x04 \x01(\x0b\x32\x39.CoreML.Specification.MILSpec.TensorValue.RepeatedStringsH\x00\x12N\n\x08longInts\x18\x05 \x01(\x0b\x32:.CoreML.Specification.MILSpec.TensorValue.RepeatedLongIntsH\x00\x12L\n\x07\x64oubles\x18\x06 \x01(\x0b\x32\x39.CoreML.Specification.MILSpec.TensorValue.RepeatedDoublesH\x00\x12H\n\x05\x62ytes\x18\x07 \x01(\x0b\x32\x37.CoreML.Specification.MILSpec.TensorValue.RepeatedBytesH\x00\x1a$\n\x0eRepeatedFloats\x12\x12\n\x06values\x18\x01 \x03(\x02\x42\x02\x10\x01\x1a%\n\x0fRepeatedDoubles\x12\x12\n\x06values\x18\x01 \x03(\x01\x42\x02\x10\x01\x1a\"\n\x0cRepeatedInts\x12\x12\n\x06values\x18\x01 \x03(\x05\x42\x02\x10\x01\x1a&\n\x10RepeatedLongInts\x12\x12\n\x06values\x18\x01 \x03(\x03\x42\x02\x10\x01\x1a#\n\rRepeatedBools\x12\x12\n\x06values\x18\x01 \x03(\x08\x42\x02\x10\x01\x1a!\n\x0fRepeatedStrings\x12\x0e\n\x06values\x18\x01 \x03(\t\x1a\x1f\n\rRepeatedBytes\x12\x0e\n\x06values\x18\x01 \x01(\x0c\x42\x07\n\x05value\"A\n\nTupleValue\x12\x33\n\x06values\x18\x01 \x03(\x0b\x32#.CoreML.Specification.MILSpec.Value\"@\n\tListValue\x12\x33\n\x06values\x18\x01 \x03(\x0b\x32#.CoreML.Specification.MILSpec.Value\"\xd3\x01\n\x0f\x44ictionaryValue\x12J\n\x06values\x18\x01 \x03(\x0b\x32:.CoreML.Specification.MILSpec.DictionaryValue.KeyValuePair\x1at\n\x0cKeyValuePair\x12\x30\n\x03key\x18\x01 \x01(\x0b\x32#.CoreML.Specification.MILSpec.Value\x12\x32\n\x05value\x18\x02 \x01(\x0b\x32#.CoreML.Specification.MILSpec.Value*\xc0\x01\n\x08\x44\x61taType\x12\x0f\n\x0bUNUSED_TYPE\x10\x00\x12\x08\n\x04\x42OOL\x10\x01\x12\n\n\x06STRING\x10\x02\x12\x0b\n\x07\x46LOAT16\x10\n\x12\x0b\n\x07\x46LOAT32\x10\x0b\x12\x0b\n\x07\x46LOAT64\x10\x0c\x12\x0c\n\x08\x42\x46LOAT16\x10\r\x12\x08\n\x04INT8\x10\x15\x12\t\n\x05INT16\x10\x16\x12\t\n\x05INT32\x10\x17\x12\t\n\x05INT64\x10\x18\x12\t\n\x05UINT8\x10\x1f\x12\n\n\x06UINT16\x10 \x12\n\n\x06UINT32\x10!\x12\n\n\x06UINT64\x10\"B\x02H\x03\x62\x06proto3') + name="MIL.proto", + package="CoreML.Specification.MILSpec", + syntax="proto3", + serialized_pb=_b( + '\n\tMIL.proto\x12\x1c\x43oreML.Specification.MILSpec"\xf3\x02\n\x07Program\x12\x0f\n\x07version\x18\x01 \x01(\x03\x12G\n\tfunctions\x18\x02 \x03(\x0b\x32\x34.CoreML.Specification.MILSpec.Program.FunctionsEntry\x12\x11\n\tdocString\x18\x03 \x01(\t\x12I\n\nattributes\x18\x04 \x03(\x0b\x32\x35.CoreML.Specification.MILSpec.Program.AttributesEntry\x1aX\n\x0e\x46unctionsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x35\n\x05value\x18\x02 \x01(\x0b\x32&.CoreML.Specification.MILSpec.Function:\x02\x38\x01\x1aV\n\x0f\x41ttributesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x32\n\x05value\x18\x02 \x01(\x0b\x32#.CoreML.Specification.MILSpec.Value:\x02\x38\x01"\xbe\x03\n\x08\x46unction\x12<\n\x06inputs\x18\x01 \x03(\x0b\x32,.CoreML.Specification.MILSpec.NamedValueType\x12\r\n\x05opset\x18\x02 \x01(\t\x12_\n\x15\x62lock_specializations\x18\x03 \x03(\x0b\x32@.CoreML.Specification.MILSpec.Function.BlockSpecializationsEntry\x12J\n\nattributes\x18\x04 \x03(\x0b\x32\x36.CoreML.Specification.MILSpec.Function.AttributesEntry\x1a`\n\x19\x42lockSpecializationsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x32\n\x05value\x18\x02 \x01(\x0b\x32#.CoreML.Specification.MILSpec.Block:\x02\x38\x01\x1aV\n\x0f\x41ttributesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x32\n\x05value\x18\x02 \x01(\x0b\x32#.CoreML.Specification.MILSpec.Value:\x02\x38\x01"\xb4\x02\n\x05\x42lock\x12<\n\x06inputs\x18\x01 \x03(\x0b\x32,.CoreML.Specification.MILSpec.NamedValueType\x12\x0f\n\x07outputs\x18\x02 \x03(\t\x12;\n\noperations\x18\x03 \x03(\x0b\x32\'.CoreML.Specification.MILSpec.Operation\x12G\n\nattributes\x18\x04 \x03(\x0b\x32\x33.CoreML.Specification.MILSpec.Block.AttributesEntry\x1aV\n\x0f\x41ttributesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x32\n\x05value\x18\x02 \x01(\x0b\x32#.CoreML.Specification.MILSpec.Value:\x02\x38\x01"\xa9\x01\n\x08\x41rgument\x12\x41\n\targuments\x18\x01 \x03(\x0b\x32..CoreML.Specification.MILSpec.Argument.Binding\x1aZ\n\x07\x42inding\x12\x0e\n\x04name\x18\x01 \x01(\tH\x00\x12\x34\n\x05value\x18\x02 \x01(\x0b\x32#.CoreML.Specification.MILSpec.ValueH\x00\x42\t\n\x07\x62inding"\xce\x03\n\tOperation\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x43\n\x06inputs\x18\x02 \x03(\x0b\x32\x33.CoreML.Specification.MILSpec.Operation.InputsEntry\x12=\n\x07outputs\x18\x03 \x03(\x0b\x32,.CoreML.Specification.MILSpec.NamedValueType\x12\x33\n\x06\x62locks\x18\x04 \x03(\x0b\x32#.CoreML.Specification.MILSpec.Block\x12K\n\nattributes\x18\x05 \x03(\x0b\x32\x37.CoreML.Specification.MILSpec.Operation.AttributesEntry\x1aU\n\x0bInputsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x35\n\x05value\x18\x02 \x01(\x0b\x32&.CoreML.Specification.MILSpec.Argument:\x02\x38\x01\x1aV\n\x0f\x41ttributesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x32\n\x05value\x18\x02 \x01(\x0b\x32#.CoreML.Specification.MILSpec.Value:\x02\x38\x01"U\n\x0eNamedValueType\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x35\n\x04type\x18\x02 \x01(\x0b\x32\'.CoreML.Specification.MILSpec.ValueType"\xd3\x02\n\tValueType\x12>\n\ntensorType\x18\x01 \x01(\x0b\x32(.CoreML.Specification.MILSpec.TensorTypeH\x00\x12:\n\x08listType\x18\x02 \x01(\x0b\x32&.CoreML.Specification.MILSpec.ListTypeH\x00\x12<\n\ttupleType\x18\x03 \x01(\x0b\x32\'.CoreML.Specification.MILSpec.TupleTypeH\x00\x12\x46\n\x0e\x64ictionaryType\x18\x04 \x01(\x0b\x32,.CoreML.Specification.MILSpec.DictionaryTypeH\x00\x12<\n\tstateType\x18\x05 \x01(\x0b\x32\'.CoreML.Specification.MILSpec.StateTypeH\x00\x42\x06\n\x04type"\xb7\x02\n\nTensorType\x12\x38\n\x08\x64\x61taType\x18\x01 \x01(\x0e\x32&.CoreML.Specification.MILSpec.DataType\x12\x0c\n\x04rank\x18\x02 \x01(\x03\x12;\n\ndimensions\x18\x03 \x03(\x0b\x32\'.CoreML.Specification.MILSpec.Dimension\x12L\n\nattributes\x18\x04 \x03(\x0b\x32\x38.CoreML.Specification.MILSpec.TensorType.AttributesEntry\x1aV\n\x0f\x41ttributesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x32\n\x05value\x18\x02 \x01(\x0b\x32#.CoreML.Specification.MILSpec.Value:\x02\x38\x01"C\n\tTupleType\x12\x36\n\x05types\x18\x01 \x03(\x0b\x32\'.CoreML.Specification.MILSpec.ValueType"z\n\x08ListType\x12\x35\n\x04type\x18\x01 \x01(\x0b\x32\'.CoreML.Specification.MILSpec.ValueType\x12\x37\n\x06length\x18\x02 \x01(\x0b\x32\'.CoreML.Specification.MILSpec.Dimension"\x86\x01\n\x0e\x44ictionaryType\x12\x38\n\x07keyType\x18\x01 \x01(\x0b\x32\'.CoreML.Specification.MILSpec.ValueType\x12:\n\tvalueType\x18\x02 \x01(\x0b\x32\'.CoreML.Specification.MILSpec.ValueType"I\n\tStateType\x12<\n\x0bwrappedType\x18\x01 \x01(\x0b\x32\'.CoreML.Specification.MILSpec.ValueType"\xfd\x01\n\tDimension\x12M\n\x08\x63onstant\x18\x01 \x01(\x0b\x32\x39.CoreML.Specification.MILSpec.Dimension.ConstantDimensionH\x00\x12K\n\x07unknown\x18\x02 \x01(\x0b\x32\x38.CoreML.Specification.MILSpec.Dimension.UnknownDimensionH\x00\x1a!\n\x11\x43onstantDimension\x12\x0c\n\x04size\x18\x01 \x01(\x04\x1a$\n\x10UnknownDimension\x12\x10\n\x08variadic\x18\x01 \x01(\x08\x42\x0b\n\tdimension"\xb9\x04\n\x05Value\x12\x11\n\tdocString\x18\x01 \x01(\t\x12\x35\n\x04type\x18\x02 \x01(\x0b\x32\'.CoreML.Specification.MILSpec.ValueType\x12L\n\x0eimmediateValue\x18\x03 \x01(\x0b\x32\x32.CoreML.Specification.MILSpec.Value.ImmediateValueH\x00\x12J\n\rblobFileValue\x18\x05 \x01(\x0b\x32\x31.CoreML.Specification.MILSpec.Value.BlobFileValueH\x00\x1a\x8f\x02\n\x0eImmediateValue\x12;\n\x06tensor\x18\x01 \x01(\x0b\x32).CoreML.Specification.MILSpec.TensorValueH\x00\x12\x39\n\x05tuple\x18\x02 \x01(\x0b\x32(.CoreML.Specification.MILSpec.TupleValueH\x00\x12\x37\n\x04list\x18\x03 \x01(\x0b\x32\'.CoreML.Specification.MILSpec.ListValueH\x00\x12\x43\n\ndictionary\x18\x04 \x01(\x0b\x32-.CoreML.Specification.MILSpec.DictionaryValueH\x00\x42\x07\n\x05value\x1a\x31\n\rBlobFileValue\x12\x10\n\x08\x66ileName\x18\x01 \x01(\t\x12\x0e\n\x06offset\x18\x02 \x01(\x04\x42\x07\n\x05value"\xac\x06\n\x0bTensorValue\x12J\n\x06\x66loats\x18\x01 \x01(\x0b\x32\x38.CoreML.Specification.MILSpec.TensorValue.RepeatedFloatsH\x00\x12\x46\n\x04ints\x18\x02 \x01(\x0b\x32\x36.CoreML.Specification.MILSpec.TensorValue.RepeatedIntsH\x00\x12H\n\x05\x62ools\x18\x03 \x01(\x0b\x32\x37.CoreML.Specification.MILSpec.TensorValue.RepeatedBoolsH\x00\x12L\n\x07strings\x18\x04 \x01(\x0b\x32\x39.CoreML.Specification.MILSpec.TensorValue.RepeatedStringsH\x00\x12N\n\x08longInts\x18\x05 \x01(\x0b\x32:.CoreML.Specification.MILSpec.TensorValue.RepeatedLongIntsH\x00\x12L\n\x07\x64oubles\x18\x06 \x01(\x0b\x32\x39.CoreML.Specification.MILSpec.TensorValue.RepeatedDoublesH\x00\x12H\n\x05\x62ytes\x18\x07 \x01(\x0b\x32\x37.CoreML.Specification.MILSpec.TensorValue.RepeatedBytesH\x00\x1a$\n\x0eRepeatedFloats\x12\x12\n\x06values\x18\x01 \x03(\x02\x42\x02\x10\x01\x1a%\n\x0fRepeatedDoubles\x12\x12\n\x06values\x18\x01 \x03(\x01\x42\x02\x10\x01\x1a"\n\x0cRepeatedInts\x12\x12\n\x06values\x18\x01 \x03(\x05\x42\x02\x10\x01\x1a&\n\x10RepeatedLongInts\x12\x12\n\x06values\x18\x01 \x03(\x03\x42\x02\x10\x01\x1a#\n\rRepeatedBools\x12\x12\n\x06values\x18\x01 \x03(\x08\x42\x02\x10\x01\x1a!\n\x0fRepeatedStrings\x12\x0e\n\x06values\x18\x01 \x03(\t\x1a\x1f\n\rRepeatedBytes\x12\x0e\n\x06values\x18\x01 \x01(\x0c\x42\x07\n\x05value"A\n\nTupleValue\x12\x33\n\x06values\x18\x01 \x03(\x0b\x32#.CoreML.Specification.MILSpec.Value"@\n\tListValue\x12\x33\n\x06values\x18\x01 \x03(\x0b\x32#.CoreML.Specification.MILSpec.Value"\xd3\x01\n\x0f\x44ictionaryValue\x12J\n\x06values\x18\x01 \x03(\x0b\x32:.CoreML.Specification.MILSpec.DictionaryValue.KeyValuePair\x1at\n\x0cKeyValuePair\x12\x30\n\x03key\x18\x01 \x01(\x0b\x32#.CoreML.Specification.MILSpec.Value\x12\x32\n\x05value\x18\x02 \x01(\x0b\x32#.CoreML.Specification.MILSpec.Value*\xa3\x02\n\x08\x44\x61taType\x12\x0f\n\x0bUNUSED_TYPE\x10\x00\x12\x08\n\x04\x42OOL\x10\x01\x12\n\n\x06STRING\x10\x02\x12\x10\n\x0c\x46LOAT8E4M3FN\x10(\x12\x0e\n\nFLOAT8E5M2\x10)\x12\x0b\n\x07\x46LOAT16\x10\n\x12\x0b\n\x07\x46LOAT32\x10\x0b\x12\x0b\n\x07\x46LOAT64\x10\x0c\x12\x0c\n\x08\x42\x46LOAT16\x10\r\x12\x08\n\x04INT8\x10\x15\x12\t\n\x05INT16\x10\x16\x12\t\n\x05INT32\x10\x17\x12\t\n\x05INT64\x10\x18\x12\x08\n\x04INT4\x10\x19\x12\t\n\x05UINT8\x10\x1f\x12\n\n\x06UINT16\x10 \x12\n\n\x06UINT32\x10!\x12\n\n\x06UINT64\x10"\x12\t\n\x05UINT4\x10#\x12\t\n\x05UINT2\x10$\x12\t\n\x05UINT1\x10%\x12\t\n\x05UINT6\x10&\x12\t\n\x05UINT3\x10\'B\x02H\x03\x62\x06proto3' + ), ) _DATATYPE = _descriptor.EnumDescriptor( - name='DataType', - full_name='CoreML.Specification.MILSpec.DataType', - filename=None, - file=DESCRIPTOR, - values=[ - _descriptor.EnumValueDescriptor( - name='UNUSED_TYPE', index=0, number=0, - options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='BOOL', index=1, number=1, - options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='STRING', index=2, number=2, - options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='FLOAT16', index=3, number=10, - options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='FLOAT32', index=4, number=11, - options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='FLOAT64', index=5, number=12, - options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='BFLOAT16', index=6, number=13, - options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='INT8', index=7, number=21, - options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='INT16', index=8, number=22, - options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='INT32', index=9, number=23, - options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='INT64', index=10, number=24, - options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='UINT8', index=11, number=31, - options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='UINT16', index=12, number=32, - options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='UINT32', index=13, number=33, - options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='UINT64', index=14, number=34, - options=None, - type=None), - ], - containing_type=None, - options=None, - serialized_start=4816, - serialized_end=5008, + name="DataType", + full_name="CoreML.Specification.MILSpec.DataType", + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name="UNUSED_TYPE", index=0, number=0, options=None, type=None + ), + _descriptor.EnumValueDescriptor(name="BOOL", index=1, number=1, options=None, type=None), + _descriptor.EnumValueDescriptor(name="STRING", index=2, number=2, options=None, type=None), + _descriptor.EnumValueDescriptor( + name="FLOAT8E4M3FN", index=3, number=40, options=None, type=None + ), + _descriptor.EnumValueDescriptor( + name="FLOAT8E5M2", index=4, number=41, options=None, type=None + ), + _descriptor.EnumValueDescriptor( + name="FLOAT16", index=5, number=10, options=None, type=None + ), + _descriptor.EnumValueDescriptor( + name="FLOAT32", index=6, number=11, options=None, type=None + ), + _descriptor.EnumValueDescriptor( + name="FLOAT64", index=7, number=12, options=None, type=None + ), + _descriptor.EnumValueDescriptor( + name="BFLOAT16", index=8, number=13, options=None, type=None + ), + _descriptor.EnumValueDescriptor(name="INT8", index=9, number=21, options=None, type=None), + _descriptor.EnumValueDescriptor(name="INT16", index=10, number=22, options=None, type=None), + _descriptor.EnumValueDescriptor(name="INT32", index=11, number=23, options=None, type=None), + _descriptor.EnumValueDescriptor(name="INT64", index=12, number=24, options=None, type=None), + _descriptor.EnumValueDescriptor(name="INT4", index=13, number=25, options=None, type=None), + _descriptor.EnumValueDescriptor(name="UINT8", index=14, number=31, options=None, type=None), + _descriptor.EnumValueDescriptor( + name="UINT16", index=15, number=32, options=None, type=None + ), + _descriptor.EnumValueDescriptor( + name="UINT32", index=16, number=33, options=None, type=None + ), + _descriptor.EnumValueDescriptor( + name="UINT64", index=17, number=34, options=None, type=None + ), + _descriptor.EnumValueDescriptor(name="UINT4", index=18, number=35, options=None, type=None), + _descriptor.EnumValueDescriptor(name="UINT2", index=19, number=36, options=None, type=None), + _descriptor.EnumValueDescriptor(name="UINT1", index=20, number=37, options=None, type=None), + _descriptor.EnumValueDescriptor(name="UINT6", index=21, number=38, options=None, type=None), + _descriptor.EnumValueDescriptor(name="UINT3", index=22, number=39, options=None, type=None), + ], + containing_type=None, + options=None, + serialized_start=4953, + serialized_end=5244, ) _sym_db.RegisterEnumDescriptor(_DATATYPE) @@ -101,6 +86,8 @@ UNUSED_TYPE = 0 BOOL = 1 STRING = 2 +FLOAT8E4M3FN = 40 +FLOAT8E5M2 = 41 FLOAT16 = 10 FLOAT32 = 11 FLOAT64 = 12 @@ -109,10 +96,16 @@ INT16 = 22 INT32 = 23 INT64 = 24 +INT4 = 25 UINT8 = 31 UINT16 = 32 UINT32 = 33 UINT64 = 34 +UINT4 = 35 +UINT2 = 36 +UINT1 = 37 +UINT6 = 38 +UINT3 = 39 @@ -700,57 +693,116 @@ _VALUETYPE = _descriptor.Descriptor( - name='ValueType', - full_name='CoreML.Specification.MILSpec.ValueType', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='tensorType', full_name='CoreML.Specification.MILSpec.ValueType.tensorType', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='listType', full_name='CoreML.Specification.MILSpec.ValueType.listType', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='tupleType', full_name='CoreML.Specification.MILSpec.ValueType.tupleType', index=2, - number=3, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='dictionaryType', full_name='CoreML.Specification.MILSpec.ValueType.dictionaryType', index=3, - number=4, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - _descriptor.OneofDescriptor( - name='type', full_name='CoreML.Specification.MILSpec.ValueType.type', - index=0, containing_type=None, fields=[]), - ], - serialized_start=1902, - serialized_end=2179, + name="ValueType", + full_name="CoreML.Specification.MILSpec.ValueType", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="tensorType", + full_name="CoreML.Specification.MILSpec.ValueType.tensorType", + index=0, + number=1, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="listType", + full_name="CoreML.Specification.MILSpec.ValueType.listType", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="tupleType", + full_name="CoreML.Specification.MILSpec.ValueType.tupleType", + index=2, + number=3, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="dictionaryType", + full_name="CoreML.Specification.MILSpec.ValueType.dictionaryType", + index=3, + number=4, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="stateType", + full_name="CoreML.Specification.MILSpec.ValueType.stateType", + index=4, + number=5, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name="type", + full_name="CoreML.Specification.MILSpec.ValueType.type", + index=0, + containing_type=None, + fields=[], + ), + ], + serialized_start=1902, + serialized_end=2241, ) @@ -792,824 +844,1254 @@ ) _TENSORTYPE = _descriptor.Descriptor( - name='TensorType', - full_name='CoreML.Specification.MILSpec.TensorType', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='dataType', full_name='CoreML.Specification.MILSpec.TensorType.dataType', index=0, - number=1, type=14, cpp_type=8, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='rank', full_name='CoreML.Specification.MILSpec.TensorType.rank', index=1, - number=2, type=3, cpp_type=2, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='dimensions', full_name='CoreML.Specification.MILSpec.TensorType.dimensions', index=2, - number=3, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='attributes', full_name='CoreML.Specification.MILSpec.TensorType.attributes', index=3, - number=4, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[_TENSORTYPE_ATTRIBUTESENTRY, ], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=2182, - serialized_end=2493, + name="TensorType", + full_name="CoreML.Specification.MILSpec.TensorType", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="dataType", + full_name="CoreML.Specification.MILSpec.TensorType.dataType", + index=0, + number=1, + type=14, + cpp_type=8, + label=1, + has_default_value=False, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="rank", + full_name="CoreML.Specification.MILSpec.TensorType.rank", + index=1, + number=2, + type=3, + cpp_type=2, + label=1, + has_default_value=False, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="dimensions", + full_name="CoreML.Specification.MILSpec.TensorType.dimensions", + index=2, + number=3, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="attributes", + full_name="CoreML.Specification.MILSpec.TensorType.attributes", + index=3, + number=4, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[ + _TENSORTYPE_ATTRIBUTESENTRY, + ], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=2244, + serialized_end=2555, ) _TUPLETYPE = _descriptor.Descriptor( - name='TupleType', - full_name='CoreML.Specification.MILSpec.TupleType', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='types', full_name='CoreML.Specification.MILSpec.TupleType.types', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=2495, - serialized_end=2562, + name="TupleType", + full_name="CoreML.Specification.MILSpec.TupleType", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="types", + full_name="CoreML.Specification.MILSpec.TupleType.types", + index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=2557, + serialized_end=2624, ) _LISTTYPE = _descriptor.Descriptor( - name='ListType', - full_name='CoreML.Specification.MILSpec.ListType', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='type', full_name='CoreML.Specification.MILSpec.ListType.type', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='length', full_name='CoreML.Specification.MILSpec.ListType.length', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=2564, - serialized_end=2686, + name="ListType", + full_name="CoreML.Specification.MILSpec.ListType", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="type", + full_name="CoreML.Specification.MILSpec.ListType.type", + index=0, + number=1, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="length", + full_name="CoreML.Specification.MILSpec.ListType.length", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=2626, + serialized_end=2748, ) _DICTIONARYTYPE = _descriptor.Descriptor( - name='DictionaryType', - full_name='CoreML.Specification.MILSpec.DictionaryType', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='keyType', full_name='CoreML.Specification.MILSpec.DictionaryType.keyType', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='valueType', full_name='CoreML.Specification.MILSpec.DictionaryType.valueType', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=2689, - serialized_end=2823, + name="DictionaryType", + full_name="CoreML.Specification.MILSpec.DictionaryType", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="keyType", + full_name="CoreML.Specification.MILSpec.DictionaryType.keyType", + index=0, + number=1, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="valueType", + full_name="CoreML.Specification.MILSpec.DictionaryType.valueType", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=2751, + serialized_end=2885, +) + + +_STATETYPE = _descriptor.Descriptor( + name="StateType", + full_name="CoreML.Specification.MILSpec.StateType", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="wrappedType", + full_name="CoreML.Specification.MILSpec.StateType.wrappedType", + index=0, + number=1, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=2887, + serialized_end=2960, ) _DIMENSION_CONSTANTDIMENSION = _descriptor.Descriptor( - name='ConstantDimension', - full_name='CoreML.Specification.MILSpec.Dimension.ConstantDimension', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='size', full_name='CoreML.Specification.MILSpec.Dimension.ConstantDimension.size', index=0, - number=1, type=4, cpp_type=4, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=2995, - serialized_end=3028, + name="ConstantDimension", + full_name="CoreML.Specification.MILSpec.Dimension.ConstantDimension", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="size", + full_name="CoreML.Specification.MILSpec.Dimension.ConstantDimension.size", + index=0, + number=1, + type=4, + cpp_type=4, + label=1, + has_default_value=False, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=3132, + serialized_end=3165, ) _DIMENSION_UNKNOWNDIMENSION = _descriptor.Descriptor( - name='UnknownDimension', - full_name='CoreML.Specification.MILSpec.Dimension.UnknownDimension', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='variadic', full_name='CoreML.Specification.MILSpec.Dimension.UnknownDimension.variadic', index=0, - number=1, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=3030, - serialized_end=3066, + name="UnknownDimension", + full_name="CoreML.Specification.MILSpec.Dimension.UnknownDimension", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="variadic", + full_name="CoreML.Specification.MILSpec.Dimension.UnknownDimension.variadic", + index=0, + number=1, + type=8, + cpp_type=7, + label=1, + has_default_value=False, + default_value=False, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=3167, + serialized_end=3203, ) _DIMENSION = _descriptor.Descriptor( - name='Dimension', - full_name='CoreML.Specification.MILSpec.Dimension', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='constant', full_name='CoreML.Specification.MILSpec.Dimension.constant', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='unknown', full_name='CoreML.Specification.MILSpec.Dimension.unknown', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[_DIMENSION_CONSTANTDIMENSION, _DIMENSION_UNKNOWNDIMENSION, ], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - _descriptor.OneofDescriptor( - name='dimension', full_name='CoreML.Specification.MILSpec.Dimension.dimension', - index=0, containing_type=None, fields=[]), - ], - serialized_start=2826, - serialized_end=3079, + name="Dimension", + full_name="CoreML.Specification.MILSpec.Dimension", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="constant", + full_name="CoreML.Specification.MILSpec.Dimension.constant", + index=0, + number=1, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="unknown", + full_name="CoreML.Specification.MILSpec.Dimension.unknown", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[ + _DIMENSION_CONSTANTDIMENSION, + _DIMENSION_UNKNOWNDIMENSION, + ], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name="dimension", + full_name="CoreML.Specification.MILSpec.Dimension.dimension", + index=0, + containing_type=None, + fields=[], + ), + ], + serialized_start=2963, + serialized_end=3216, ) _VALUE_IMMEDIATEVALUE = _descriptor.Descriptor( - name='ImmediateValue', - full_name='CoreML.Specification.MILSpec.Value.ImmediateValue', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='tensor', full_name='CoreML.Specification.MILSpec.Value.ImmediateValue.tensor', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='tuple', full_name='CoreML.Specification.MILSpec.Value.ImmediateValue.tuple', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='list', full_name='CoreML.Specification.MILSpec.Value.ImmediateValue.list', index=2, - number=3, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='dictionary', full_name='CoreML.Specification.MILSpec.Value.ImmediateValue.dictionary', index=3, - number=4, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - _descriptor.OneofDescriptor( - name='value', full_name='CoreML.Specification.MILSpec.Value.ImmediateValue.value', - index=0, containing_type=None, fields=[]), - ], - serialized_start=3320, - serialized_end=3591, + name="ImmediateValue", + full_name="CoreML.Specification.MILSpec.Value.ImmediateValue", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="tensor", + full_name="CoreML.Specification.MILSpec.Value.ImmediateValue.tensor", + index=0, + number=1, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="tuple", + full_name="CoreML.Specification.MILSpec.Value.ImmediateValue.tuple", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="list", + full_name="CoreML.Specification.MILSpec.Value.ImmediateValue.list", + index=2, + number=3, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="dictionary", + full_name="CoreML.Specification.MILSpec.Value.ImmediateValue.dictionary", + index=3, + number=4, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name="value", + full_name="CoreML.Specification.MILSpec.Value.ImmediateValue.value", + index=0, + containing_type=None, + fields=[], + ), + ], + serialized_start=3457, + serialized_end=3728, ) _VALUE_BLOBFILEVALUE = _descriptor.Descriptor( - name='BlobFileValue', - full_name='CoreML.Specification.MILSpec.Value.BlobFileValue', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='fileName', full_name='CoreML.Specification.MILSpec.Value.BlobFileValue.fileName', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='offset', full_name='CoreML.Specification.MILSpec.Value.BlobFileValue.offset', index=1, - number=2, type=4, cpp_type=4, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=3593, - serialized_end=3642, + name="BlobFileValue", + full_name="CoreML.Specification.MILSpec.Value.BlobFileValue", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="fileName", + full_name="CoreML.Specification.MILSpec.Value.BlobFileValue.fileName", + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="offset", + full_name="CoreML.Specification.MILSpec.Value.BlobFileValue.offset", + index=1, + number=2, + type=4, + cpp_type=4, + label=1, + has_default_value=False, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=3730, + serialized_end=3779, ) _VALUE = _descriptor.Descriptor( - name='Value', - full_name='CoreML.Specification.MILSpec.Value', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='docString', full_name='CoreML.Specification.MILSpec.Value.docString', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='type', full_name='CoreML.Specification.MILSpec.Value.type', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='immediateValue', full_name='CoreML.Specification.MILSpec.Value.immediateValue', index=2, - number=3, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='blobFileValue', full_name='CoreML.Specification.MILSpec.Value.blobFileValue', index=3, - number=5, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[_VALUE_IMMEDIATEVALUE, _VALUE_BLOBFILEVALUE, ], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - _descriptor.OneofDescriptor( - name='value', full_name='CoreML.Specification.MILSpec.Value.value', - index=0, containing_type=None, fields=[]), - ], - serialized_start=3082, - serialized_end=3651, + name="Value", + full_name="CoreML.Specification.MILSpec.Value", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="docString", + full_name="CoreML.Specification.MILSpec.Value.docString", + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="type", + full_name="CoreML.Specification.MILSpec.Value.type", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="immediateValue", + full_name="CoreML.Specification.MILSpec.Value.immediateValue", + index=2, + number=3, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="blobFileValue", + full_name="CoreML.Specification.MILSpec.Value.blobFileValue", + index=3, + number=5, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[ + _VALUE_IMMEDIATEVALUE, + _VALUE_BLOBFILEVALUE, + ], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name="value", + full_name="CoreML.Specification.MILSpec.Value.value", + index=0, + containing_type=None, + fields=[], + ), + ], + serialized_start=3219, + serialized_end=3788, ) _TENSORVALUE_REPEATEDFLOATS = _descriptor.Descriptor( - name='RepeatedFloats', - full_name='CoreML.Specification.MILSpec.TensorValue.RepeatedFloats', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='values', full_name='CoreML.Specification.MILSpec.TensorValue.RepeatedFloats.values', index=0, - number=1, type=2, cpp_type=6, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=4201, - serialized_end=4237, + name="RepeatedFloats", + full_name="CoreML.Specification.MILSpec.TensorValue.RepeatedFloats", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="values", + full_name="CoreML.Specification.MILSpec.TensorValue.RepeatedFloats.values", + index=0, + number=1, + type=2, + cpp_type=6, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b("\020\001")), + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=4338, + serialized_end=4374, ) _TENSORVALUE_REPEATEDDOUBLES = _descriptor.Descriptor( - name='RepeatedDoubles', - full_name='CoreML.Specification.MILSpec.TensorValue.RepeatedDoubles', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='values', full_name='CoreML.Specification.MILSpec.TensorValue.RepeatedDoubles.values', index=0, - number=1, type=1, cpp_type=5, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=4239, - serialized_end=4276, + name="RepeatedDoubles", + full_name="CoreML.Specification.MILSpec.TensorValue.RepeatedDoubles", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="values", + full_name="CoreML.Specification.MILSpec.TensorValue.RepeatedDoubles.values", + index=0, + number=1, + type=1, + cpp_type=5, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b("\020\001")), + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=4376, + serialized_end=4413, ) _TENSORVALUE_REPEATEDINTS = _descriptor.Descriptor( - name='RepeatedInts', - full_name='CoreML.Specification.MILSpec.TensorValue.RepeatedInts', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='values', full_name='CoreML.Specification.MILSpec.TensorValue.RepeatedInts.values', index=0, - number=1, type=5, cpp_type=1, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=4278, - serialized_end=4312, + name="RepeatedInts", + full_name="CoreML.Specification.MILSpec.TensorValue.RepeatedInts", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="values", + full_name="CoreML.Specification.MILSpec.TensorValue.RepeatedInts.values", + index=0, + number=1, + type=5, + cpp_type=1, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b("\020\001")), + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=4415, + serialized_end=4449, ) _TENSORVALUE_REPEATEDLONGINTS = _descriptor.Descriptor( - name='RepeatedLongInts', - full_name='CoreML.Specification.MILSpec.TensorValue.RepeatedLongInts', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='values', full_name='CoreML.Specification.MILSpec.TensorValue.RepeatedLongInts.values', index=0, - number=1, type=3, cpp_type=2, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=4314, - serialized_end=4352, + name="RepeatedLongInts", + full_name="CoreML.Specification.MILSpec.TensorValue.RepeatedLongInts", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="values", + full_name="CoreML.Specification.MILSpec.TensorValue.RepeatedLongInts.values", + index=0, + number=1, + type=3, + cpp_type=2, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b("\020\001")), + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=4451, + serialized_end=4489, ) _TENSORVALUE_REPEATEDBOOLS = _descriptor.Descriptor( - name='RepeatedBools', - full_name='CoreML.Specification.MILSpec.TensorValue.RepeatedBools', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='values', full_name='CoreML.Specification.MILSpec.TensorValue.RepeatedBools.values', index=0, - number=1, type=8, cpp_type=7, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=4354, - serialized_end=4389, + name="RepeatedBools", + full_name="CoreML.Specification.MILSpec.TensorValue.RepeatedBools", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="values", + full_name="CoreML.Specification.MILSpec.TensorValue.RepeatedBools.values", + index=0, + number=1, + type=8, + cpp_type=7, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b("\020\001")), + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=4491, + serialized_end=4526, ) _TENSORVALUE_REPEATEDSTRINGS = _descriptor.Descriptor( - name='RepeatedStrings', - full_name='CoreML.Specification.MILSpec.TensorValue.RepeatedStrings', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='values', full_name='CoreML.Specification.MILSpec.TensorValue.RepeatedStrings.values', index=0, - number=1, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=4391, - serialized_end=4424, + name="RepeatedStrings", + full_name="CoreML.Specification.MILSpec.TensorValue.RepeatedStrings", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="values", + full_name="CoreML.Specification.MILSpec.TensorValue.RepeatedStrings.values", + index=0, + number=1, + type=9, + cpp_type=9, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=4528, + serialized_end=4561, ) _TENSORVALUE_REPEATEDBYTES = _descriptor.Descriptor( - name='RepeatedBytes', - full_name='CoreML.Specification.MILSpec.TensorValue.RepeatedBytes', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='values', full_name='CoreML.Specification.MILSpec.TensorValue.RepeatedBytes.values', index=0, - number=1, type=12, cpp_type=9, label=1, - has_default_value=False, default_value=_b(""), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=4426, - serialized_end=4457, + name="RepeatedBytes", + full_name="CoreML.Specification.MILSpec.TensorValue.RepeatedBytes", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="values", + full_name="CoreML.Specification.MILSpec.TensorValue.RepeatedBytes.values", + index=0, + number=1, + type=12, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b(""), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=4563, + serialized_end=4594, ) _TENSORVALUE = _descriptor.Descriptor( - name='TensorValue', - full_name='CoreML.Specification.MILSpec.TensorValue', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='floats', full_name='CoreML.Specification.MILSpec.TensorValue.floats', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='ints', full_name='CoreML.Specification.MILSpec.TensorValue.ints', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='bools', full_name='CoreML.Specification.MILSpec.TensorValue.bools', index=2, - number=3, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='strings', full_name='CoreML.Specification.MILSpec.TensorValue.strings', index=3, - number=4, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='longInts', full_name='CoreML.Specification.MILSpec.TensorValue.longInts', index=4, - number=5, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='doubles', full_name='CoreML.Specification.MILSpec.TensorValue.doubles', index=5, - number=6, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='bytes', full_name='CoreML.Specification.MILSpec.TensorValue.bytes', index=6, - number=7, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[_TENSORVALUE_REPEATEDFLOATS, _TENSORVALUE_REPEATEDDOUBLES, _TENSORVALUE_REPEATEDINTS, _TENSORVALUE_REPEATEDLONGINTS, _TENSORVALUE_REPEATEDBOOLS, _TENSORVALUE_REPEATEDSTRINGS, _TENSORVALUE_REPEATEDBYTES, ], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - _descriptor.OneofDescriptor( - name='value', full_name='CoreML.Specification.MILSpec.TensorValue.value', - index=0, containing_type=None, fields=[]), - ], - serialized_start=3654, - serialized_end=4466, + name="TensorValue", + full_name="CoreML.Specification.MILSpec.TensorValue", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="floats", + full_name="CoreML.Specification.MILSpec.TensorValue.floats", + index=0, + number=1, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="ints", + full_name="CoreML.Specification.MILSpec.TensorValue.ints", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="bools", + full_name="CoreML.Specification.MILSpec.TensorValue.bools", + index=2, + number=3, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="strings", + full_name="CoreML.Specification.MILSpec.TensorValue.strings", + index=3, + number=4, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="longInts", + full_name="CoreML.Specification.MILSpec.TensorValue.longInts", + index=4, + number=5, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="doubles", + full_name="CoreML.Specification.MILSpec.TensorValue.doubles", + index=5, + number=6, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="bytes", + full_name="CoreML.Specification.MILSpec.TensorValue.bytes", + index=6, + number=7, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[ + _TENSORVALUE_REPEATEDFLOATS, + _TENSORVALUE_REPEATEDDOUBLES, + _TENSORVALUE_REPEATEDINTS, + _TENSORVALUE_REPEATEDLONGINTS, + _TENSORVALUE_REPEATEDBOOLS, + _TENSORVALUE_REPEATEDSTRINGS, + _TENSORVALUE_REPEATEDBYTES, + ], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name="value", + full_name="CoreML.Specification.MILSpec.TensorValue.value", + index=0, + containing_type=None, + fields=[], + ), + ], + serialized_start=3791, + serialized_end=4603, ) _TUPLEVALUE = _descriptor.Descriptor( - name='TupleValue', - full_name='CoreML.Specification.MILSpec.TupleValue', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='values', full_name='CoreML.Specification.MILSpec.TupleValue.values', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=4468, - serialized_end=4533, + name="TupleValue", + full_name="CoreML.Specification.MILSpec.TupleValue", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="values", + full_name="CoreML.Specification.MILSpec.TupleValue.values", + index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=4605, + serialized_end=4670, ) _LISTVALUE = _descriptor.Descriptor( - name='ListValue', - full_name='CoreML.Specification.MILSpec.ListValue', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='values', full_name='CoreML.Specification.MILSpec.ListValue.values', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=4535, - serialized_end=4599, + name="ListValue", + full_name="CoreML.Specification.MILSpec.ListValue", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="values", + full_name="CoreML.Specification.MILSpec.ListValue.values", + index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=4672, + serialized_end=4736, ) _DICTIONARYVALUE_KEYVALUEPAIR = _descriptor.Descriptor( - name='KeyValuePair', - full_name='CoreML.Specification.MILSpec.DictionaryValue.KeyValuePair', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='CoreML.Specification.MILSpec.DictionaryValue.KeyValuePair.key', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='value', full_name='CoreML.Specification.MILSpec.DictionaryValue.KeyValuePair.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=4697, - serialized_end=4813, + name="KeyValuePair", + full_name="CoreML.Specification.MILSpec.DictionaryValue.KeyValuePair", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="key", + full_name="CoreML.Specification.MILSpec.DictionaryValue.KeyValuePair.key", + index=0, + number=1, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="value", + full_name="CoreML.Specification.MILSpec.DictionaryValue.KeyValuePair.value", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=4834, + serialized_end=4950, ) _DICTIONARYVALUE = _descriptor.Descriptor( - name='DictionaryValue', - full_name='CoreML.Specification.MILSpec.DictionaryValue', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='values', full_name='CoreML.Specification.MILSpec.DictionaryValue.values', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[_DICTIONARYVALUE_KEYVALUEPAIR, ], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=4602, - serialized_end=4813, + name="DictionaryValue", + full_name="CoreML.Specification.MILSpec.DictionaryValue", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="values", + full_name="CoreML.Specification.MILSpec.DictionaryValue.values", + index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[ + _DICTIONARYVALUE_KEYVALUEPAIR, + ], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=4739, + serialized_end=4950, ) _PROGRAM_FUNCTIONSENTRY.fields_by_name['value'].message_type = _FUNCTION @@ -1643,37 +2125,37 @@ _OPERATION_INPUTSENTRY.containing_type = _OPERATION _OPERATION_ATTRIBUTESENTRY.fields_by_name['value'].message_type = _VALUE _OPERATION_ATTRIBUTESENTRY.containing_type = _OPERATION -_OPERATION.fields_by_name['inputs'].message_type = _OPERATION_INPUTSENTRY -_OPERATION.fields_by_name['outputs'].message_type = _NAMEDVALUETYPE -_OPERATION.fields_by_name['blocks'].message_type = _BLOCK -_OPERATION.fields_by_name['attributes'].message_type = _OPERATION_ATTRIBUTESENTRY -_NAMEDVALUETYPE.fields_by_name['type'].message_type = _VALUETYPE -_VALUETYPE.fields_by_name['tensorType'].message_type = _TENSORTYPE -_VALUETYPE.fields_by_name['listType'].message_type = _LISTTYPE -_VALUETYPE.fields_by_name['tupleType'].message_type = _TUPLETYPE -_VALUETYPE.fields_by_name['dictionaryType'].message_type = _DICTIONARYTYPE -_VALUETYPE.oneofs_by_name['type'].fields.append( - _VALUETYPE.fields_by_name['tensorType']) -_VALUETYPE.fields_by_name['tensorType'].containing_oneof = _VALUETYPE.oneofs_by_name['type'] -_VALUETYPE.oneofs_by_name['type'].fields.append( - _VALUETYPE.fields_by_name['listType']) -_VALUETYPE.fields_by_name['listType'].containing_oneof = _VALUETYPE.oneofs_by_name['type'] -_VALUETYPE.oneofs_by_name['type'].fields.append( - _VALUETYPE.fields_by_name['tupleType']) -_VALUETYPE.fields_by_name['tupleType'].containing_oneof = _VALUETYPE.oneofs_by_name['type'] -_VALUETYPE.oneofs_by_name['type'].fields.append( - _VALUETYPE.fields_by_name['dictionaryType']) -_VALUETYPE.fields_by_name['dictionaryType'].containing_oneof = _VALUETYPE.oneofs_by_name['type'] -_TENSORTYPE_ATTRIBUTESENTRY.fields_by_name['value'].message_type = _VALUE +_OPERATION.fields_by_name["inputs"].message_type = _OPERATION_INPUTSENTRY +_OPERATION.fields_by_name["outputs"].message_type = _NAMEDVALUETYPE +_OPERATION.fields_by_name["blocks"].message_type = _BLOCK +_OPERATION.fields_by_name["attributes"].message_type = _OPERATION_ATTRIBUTESENTRY +_NAMEDVALUETYPE.fields_by_name["type"].message_type = _VALUETYPE +_VALUETYPE.fields_by_name["tensorType"].message_type = _TENSORTYPE +_VALUETYPE.fields_by_name["listType"].message_type = _LISTTYPE +_VALUETYPE.fields_by_name["tupleType"].message_type = _TUPLETYPE +_VALUETYPE.fields_by_name["dictionaryType"].message_type = _DICTIONARYTYPE +_VALUETYPE.fields_by_name["stateType"].message_type = _STATETYPE +_VALUETYPE.oneofs_by_name["type"].fields.append(_VALUETYPE.fields_by_name["tensorType"]) +_VALUETYPE.fields_by_name["tensorType"].containing_oneof = _VALUETYPE.oneofs_by_name["type"] +_VALUETYPE.oneofs_by_name["type"].fields.append(_VALUETYPE.fields_by_name["listType"]) +_VALUETYPE.fields_by_name["listType"].containing_oneof = _VALUETYPE.oneofs_by_name["type"] +_VALUETYPE.oneofs_by_name["type"].fields.append(_VALUETYPE.fields_by_name["tupleType"]) +_VALUETYPE.fields_by_name["tupleType"].containing_oneof = _VALUETYPE.oneofs_by_name["type"] +_VALUETYPE.oneofs_by_name["type"].fields.append(_VALUETYPE.fields_by_name["dictionaryType"]) +_VALUETYPE.fields_by_name["dictionaryType"].containing_oneof = _VALUETYPE.oneofs_by_name["type"] +_VALUETYPE.oneofs_by_name["type"].fields.append(_VALUETYPE.fields_by_name["stateType"]) +_VALUETYPE.fields_by_name["stateType"].containing_oneof = _VALUETYPE.oneofs_by_name["type"] +_TENSORTYPE_ATTRIBUTESENTRY.fields_by_name["value"].message_type = _VALUE _TENSORTYPE_ATTRIBUTESENTRY.containing_type = _TENSORTYPE -_TENSORTYPE.fields_by_name['dataType'].enum_type = _DATATYPE -_TENSORTYPE.fields_by_name['dimensions'].message_type = _DIMENSION -_TENSORTYPE.fields_by_name['attributes'].message_type = _TENSORTYPE_ATTRIBUTESENTRY -_TUPLETYPE.fields_by_name['types'].message_type = _VALUETYPE -_LISTTYPE.fields_by_name['type'].message_type = _VALUETYPE -_LISTTYPE.fields_by_name['length'].message_type = _DIMENSION -_DICTIONARYTYPE.fields_by_name['keyType'].message_type = _VALUETYPE -_DICTIONARYTYPE.fields_by_name['valueType'].message_type = _VALUETYPE +_TENSORTYPE.fields_by_name["dataType"].enum_type = _DATATYPE +_TENSORTYPE.fields_by_name["dimensions"].message_type = _DIMENSION +_TENSORTYPE.fields_by_name["attributes"].message_type = _TENSORTYPE_ATTRIBUTESENTRY +_TUPLETYPE.fields_by_name["types"].message_type = _VALUETYPE +_LISTTYPE.fields_by_name["type"].message_type = _VALUETYPE +_LISTTYPE.fields_by_name["length"].message_type = _DIMENSION +_DICTIONARYTYPE.fields_by_name["keyType"].message_type = _VALUETYPE +_DICTIONARYTYPE.fields_by_name["valueType"].message_type = _VALUETYPE +_STATETYPE.fields_by_name["wrappedType"].message_type = _VALUETYPE _DIMENSION_CONSTANTDIMENSION.containing_type = _DIMENSION _DIMENSION_UNKNOWNDIMENSION.containing_type = _DIMENSION _DIMENSION.fields_by_name['constant'].message_type = _DIMENSION_CONSTANTDIMENSION @@ -1751,25 +2233,26 @@ _DICTIONARYVALUE_KEYVALUEPAIR.fields_by_name['key'].message_type = _VALUE _DICTIONARYVALUE_KEYVALUEPAIR.fields_by_name['value'].message_type = _VALUE _DICTIONARYVALUE_KEYVALUEPAIR.containing_type = _DICTIONARYVALUE -_DICTIONARYVALUE.fields_by_name['values'].message_type = _DICTIONARYVALUE_KEYVALUEPAIR -DESCRIPTOR.message_types_by_name['Program'] = _PROGRAM -DESCRIPTOR.message_types_by_name['Function'] = _FUNCTION -DESCRIPTOR.message_types_by_name['Block'] = _BLOCK -DESCRIPTOR.message_types_by_name['Argument'] = _ARGUMENT -DESCRIPTOR.message_types_by_name['Operation'] = _OPERATION -DESCRIPTOR.message_types_by_name['NamedValueType'] = _NAMEDVALUETYPE -DESCRIPTOR.message_types_by_name['ValueType'] = _VALUETYPE -DESCRIPTOR.message_types_by_name['TensorType'] = _TENSORTYPE -DESCRIPTOR.message_types_by_name['TupleType'] = _TUPLETYPE -DESCRIPTOR.message_types_by_name['ListType'] = _LISTTYPE -DESCRIPTOR.message_types_by_name['DictionaryType'] = _DICTIONARYTYPE -DESCRIPTOR.message_types_by_name['Dimension'] = _DIMENSION -DESCRIPTOR.message_types_by_name['Value'] = _VALUE -DESCRIPTOR.message_types_by_name['TensorValue'] = _TENSORVALUE -DESCRIPTOR.message_types_by_name['TupleValue'] = _TUPLEVALUE -DESCRIPTOR.message_types_by_name['ListValue'] = _LISTVALUE -DESCRIPTOR.message_types_by_name['DictionaryValue'] = _DICTIONARYVALUE -DESCRIPTOR.enum_types_by_name['DataType'] = _DATATYPE +_DICTIONARYVALUE.fields_by_name["values"].message_type = _DICTIONARYVALUE_KEYVALUEPAIR +DESCRIPTOR.message_types_by_name["Program"] = _PROGRAM +DESCRIPTOR.message_types_by_name["Function"] = _FUNCTION +DESCRIPTOR.message_types_by_name["Block"] = _BLOCK +DESCRIPTOR.message_types_by_name["Argument"] = _ARGUMENT +DESCRIPTOR.message_types_by_name["Operation"] = _OPERATION +DESCRIPTOR.message_types_by_name["NamedValueType"] = _NAMEDVALUETYPE +DESCRIPTOR.message_types_by_name["ValueType"] = _VALUETYPE +DESCRIPTOR.message_types_by_name["TensorType"] = _TENSORTYPE +DESCRIPTOR.message_types_by_name["TupleType"] = _TUPLETYPE +DESCRIPTOR.message_types_by_name["ListType"] = _LISTTYPE +DESCRIPTOR.message_types_by_name["DictionaryType"] = _DICTIONARYTYPE +DESCRIPTOR.message_types_by_name["StateType"] = _STATETYPE +DESCRIPTOR.message_types_by_name["Dimension"] = _DIMENSION +DESCRIPTOR.message_types_by_name["Value"] = _VALUE +DESCRIPTOR.message_types_by_name["TensorValue"] = _TENSORVALUE +DESCRIPTOR.message_types_by_name["TupleValue"] = _TUPLEVALUE +DESCRIPTOR.message_types_by_name["ListValue"] = _LISTVALUE +DESCRIPTOR.message_types_by_name["DictionaryValue"] = _DICTIONARYVALUE +DESCRIPTOR.enum_types_by_name["DataType"] = _DATATYPE _sym_db.RegisterFileDescriptor(DESCRIPTOR) Program = _reflection.GeneratedProtocolMessageType('Program', (_message.Message,), dict( @@ -1921,6 +2404,17 @@ )) _sym_db.RegisterMessage(DictionaryType) +StateType = _reflection.GeneratedProtocolMessageType( + "StateType", + (_message.Message,), + dict( + DESCRIPTOR=_STATETYPE, + __module__="MIL_pb2" + # @@protoc_insertion_point(class_scope:CoreML.Specification.MILSpec.StateType) + ), +) +_sym_db.RegisterMessage(StateType) + Dimension = _reflection.GeneratedProtocolMessageType('Dimension', (_message.Message,), dict( ConstantDimension = _reflection.GeneratedProtocolMessageType('ConstantDimension', (_message.Message,), dict( diff --git a/coremltools/proto/Model_pb2.py b/coremltools/proto/Model_pb2.py index 86743064a..b2d2b740a 100644 --- a/coremltools/proto/Model_pb2.py +++ b/coremltools/proto/Model_pb2.py @@ -250,15 +250,79 @@ from .ClassConfidenceThresholding_pb2 import * DESCRIPTOR = _descriptor.FileDescriptor( - name='Model.proto', - package='CoreML.Specification', - syntax='proto3', - serialized_pb=_b('\n\x0bModel.proto\x12\x14\x43oreML.Specification\x1a\x18VisionFeaturePrint.proto\x1a\x17\x41udioFeaturePrint.proto\x1a\x14TextClassifier.proto\x1a\x10WordTagger.proto\x1a\x0fGazetteer.proto\x1a\x13WordEmbedding.proto\x1a\x1b\x41rrayFeatureExtractor.proto\x1a\x1d\x42\x61yesianProbitRegressor.proto\x1a\x18\x43\x61tegoricalMapping.proto\x1a\x11\x43ustomModel.proto\x1a\x14\x44ictVectorizer.proto\x1a\x12\x46\x65\x61tureTypes.proto\x1a\x17\x46\x65\x61tureVectorizer.proto\x1a\x12GLMRegressor.proto\x1a\x13GLMClassifier.proto\x1a\x16NearestNeighbors.proto\x1a\x0eIdentity.proto\x1a\rImputer.proto\x1a\tMIL.proto\x1a\x13NeuralNetwork.proto\x1a\x10Normalizer.proto\x1a\x13OneHotEncoder.proto\x1a\x0cScaler.proto\x1a\x1bNonMaximumSuppression.proto\x1a\tSVM.proto\x1a\x12TreeEnsemble.proto\x1a\x10Parameters.proto\x1a\x1fItemSimilarityRecommender.proto\x1a SoundAnalysisPreprocessing.proto\x1a\x11LinkedModel.proto\x1a!ClassConfidenceThresholding.proto\"F\n\x08Pipeline\x12+\n\x06models\x18\x01 \x03(\x0b\x32\x1b.CoreML.Specification.Model\x12\r\n\x05names\x18\x02 \x03(\t\"F\n\x12PipelineClassifier\x12\x30\n\x08pipeline\x18\x01 \x01(\x0b\x32\x1e.CoreML.Specification.Pipeline\"E\n\x11PipelineRegressor\x12\x30\n\x08pipeline\x18\x01 \x01(\x0b\x32\x1e.CoreML.Specification.Pipeline\"m\n\x12\x46\x65\x61tureDescription\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x18\n\x10shortDescription\x18\x02 \x01(\t\x12/\n\x04type\x18\x03 \x01(\x0b\x32!.CoreML.Specification.FeatureType\"\xd6\x01\n\x08Metadata\x12\x18\n\x10shortDescription\x18\x01 \x01(\t\x12\x15\n\rversionString\x18\x02 \x01(\t\x12\x0e\n\x06\x61uthor\x18\x03 \x01(\t\x12\x0f\n\x07license\x18\x04 \x01(\t\x12\x44\n\x0buserDefined\x18\x64 \x03(\x0b\x32/.CoreML.Specification.Metadata.UserDefinedEntry\x1a\x32\n\x10UserDefinedEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xba\x02\n\x10ModelDescription\x12\x37\n\x05input\x18\x01 \x03(\x0b\x32(.CoreML.Specification.FeatureDescription\x12\x38\n\x06output\x18\n \x03(\x0b\x32(.CoreML.Specification.FeatureDescription\x12\x1c\n\x14predictedFeatureName\x18\x0b \x01(\t\x12\"\n\x1apredictedProbabilitiesName\x18\x0c \x01(\t\x12?\n\rtrainingInput\x18\x32 \x03(\x0b\x32(.CoreML.Specification.FeatureDescription\x12\x30\n\x08metadata\x18\x64 \x01(\x0b\x32\x1e.CoreML.Specification.Metadata\"4\n\x0fSerializedModel\x12\x12\n\nidentifier\x18\x01 \x01(\t\x12\r\n\x05model\x18\x02 \x01(\x0c\"\xf1\x15\n\x05Model\x12\x1c\n\x14specificationVersion\x18\x01 \x01(\x05\x12;\n\x0b\x64\x65scription\x18\x02 \x01(\x0b\x32&.CoreML.Specification.ModelDescription\x12\x13\n\x0bisUpdatable\x18\n \x01(\x08\x12G\n\x12pipelineClassifier\x18\xc8\x01 \x01(\x0b\x32(.CoreML.Specification.PipelineClassifierH\x00\x12\x45\n\x11pipelineRegressor\x18\xc9\x01 \x01(\x0b\x32\'.CoreML.Specification.PipelineRegressorH\x00\x12\x33\n\x08pipeline\x18\xca\x01 \x01(\x0b\x32\x1e.CoreML.Specification.PipelineH\x00\x12;\n\x0cglmRegressor\x18\xac\x02 \x01(\x0b\x32\".CoreML.Specification.GLMRegressorH\x00\x12O\n\x16supportVectorRegressor\x18\xad\x02 \x01(\x0b\x32,.CoreML.Specification.SupportVectorRegressorH\x00\x12M\n\x15treeEnsembleRegressor\x18\xae\x02 \x01(\x0b\x32+.CoreML.Specification.TreeEnsembleRegressorH\x00\x12O\n\x16neuralNetworkRegressor\x18\xaf\x02 \x01(\x0b\x32,.CoreML.Specification.NeuralNetworkRegressorH\x00\x12Q\n\x17\x62\x61yesianProbitRegressor\x18\xb0\x02 \x01(\x0b\x32-.CoreML.Specification.BayesianProbitRegressorH\x00\x12=\n\rglmClassifier\x18\x90\x03 \x01(\x0b\x32#.CoreML.Specification.GLMClassifierH\x00\x12Q\n\x17supportVectorClassifier\x18\x91\x03 \x01(\x0b\x32-.CoreML.Specification.SupportVectorClassifierH\x00\x12O\n\x16treeEnsembleClassifier\x18\x92\x03 \x01(\x0b\x32,.CoreML.Specification.TreeEnsembleClassifierH\x00\x12Q\n\x17neuralNetworkClassifier\x18\x93\x03 \x01(\x0b\x32-.CoreML.Specification.NeuralNetworkClassifierH\x00\x12Y\n\x1bkNearestNeighborsClassifier\x18\x94\x03 \x01(\x0b\x32\x31.CoreML.Specification.KNearestNeighborsClassifierH\x00\x12=\n\rneuralNetwork\x18\xf4\x03 \x01(\x0b\x32#.CoreML.Specification.NeuralNetworkH\x00\x12U\n\x19itemSimilarityRecommender\x18\xf5\x03 \x01(\x0b\x32/.CoreML.Specification.ItemSimilarityRecommenderH\x00\x12;\n\tmlProgram\x18\xf6\x03 \x01(\x0b\x32%.CoreML.Specification.MILSpec.ProgramH\x00\x12\x39\n\x0b\x63ustomModel\x18\xab\x04 \x01(\x0b\x32!.CoreML.Specification.CustomModelH\x00\x12\x39\n\x0blinkedModel\x18\xac\x04 \x01(\x0b\x32!.CoreML.Specification.LinkedModelH\x00\x12Y\n\x1b\x63lassConfidenceThresholding\x18\xb0\x04 \x01(\x0b\x32\x31.CoreML.Specification.ClassConfidenceThresholdingH\x00\x12=\n\roneHotEncoder\x18\xd8\x04 \x01(\x0b\x32#.CoreML.Specification.OneHotEncoderH\x00\x12\x31\n\x07imputer\x18\xd9\x04 \x01(\x0b\x32\x1d.CoreML.Specification.ImputerH\x00\x12\x45\n\x11\x66\x65\x61tureVectorizer\x18\xda\x04 \x01(\x0b\x32\'.CoreML.Specification.FeatureVectorizerH\x00\x12?\n\x0e\x64ictVectorizer\x18\xdb\x04 \x01(\x0b\x32$.CoreML.Specification.DictVectorizerH\x00\x12/\n\x06scaler\x18\xdc\x04 \x01(\x0b\x32\x1c.CoreML.Specification.ScalerH\x00\x12G\n\x12\x63\x61tegoricalMapping\x18\xde\x04 \x01(\x0b\x32(.CoreML.Specification.CategoricalMappingH\x00\x12\x37\n\nnormalizer\x18\xdf\x04 \x01(\x0b\x32 .CoreML.Specification.NormalizerH\x00\x12M\n\x15\x61rrayFeatureExtractor\x18\xe1\x04 \x01(\x0b\x32+.CoreML.Specification.ArrayFeatureExtractorH\x00\x12M\n\x15nonMaximumSuppression\x18\xe2\x04 \x01(\x0b\x32+.CoreML.Specification.NonMaximumSuppressionH\x00\x12\x33\n\x08identity\x18\x84\x07 \x01(\x0b\x32\x1e.CoreML.Specification.IdentityH\x00\x12L\n\x0etextClassifier\x18\xd0\x0f \x01(\x0b\x32\x31.CoreML.Specification.CoreMLModels.TextClassifierH\x00\x12\x44\n\nwordTagger\x18\xd1\x0f \x01(\x0b\x32-.CoreML.Specification.CoreMLModels.WordTaggerH\x00\x12T\n\x12visionFeaturePrint\x18\xd2\x0f \x01(\x0b\x32\x35.CoreML.Specification.CoreMLModels.VisionFeaturePrintH\x00\x12\x64\n\x1asoundAnalysisPreprocessing\x18\xd3\x0f \x01(\x0b\x32=.CoreML.Specification.CoreMLModels.SoundAnalysisPreprocessingH\x00\x12\x42\n\tgazetteer\x18\xd4\x0f \x01(\x0b\x32,.CoreML.Specification.CoreMLModels.GazetteerH\x00\x12J\n\rwordEmbedding\x18\xd5\x0f \x01(\x0b\x32\x30.CoreML.Specification.CoreMLModels.WordEmbeddingH\x00\x12R\n\x11\x61udioFeaturePrint\x18\xd6\x0f \x01(\x0b\x32\x34.CoreML.Specification.CoreMLModels.AudioFeaturePrintH\x00\x12\x41\n\x0fserializedModel\x18\xb8\x17 \x01(\x0b\x32%.CoreML.Specification.SerializedModelH\x00\x42\x06\n\x04TypeB\x02H\x03P\x00P\x01P\x02P\x03P\x04P\x05P\x06P\x07P\x08P\tP\nP\x0bP\x0cP\rP\x0eP\x0fP\x10P\x11P\x12P\x13P\x14P\x15P\x16P\x17P\x18P\x19P\x1aP\x1bP\x1cP\x1dP\x1e\x62\x06proto3') - , - dependencies=[VisionFeaturePrint__pb2.DESCRIPTOR,AudioFeaturePrint__pb2.DESCRIPTOR,TextClassifier__pb2.DESCRIPTOR,WordTagger__pb2.DESCRIPTOR,Gazetteer__pb2.DESCRIPTOR,WordEmbedding__pb2.DESCRIPTOR,ArrayFeatureExtractor__pb2.DESCRIPTOR,BayesianProbitRegressor__pb2.DESCRIPTOR,CategoricalMapping__pb2.DESCRIPTOR,CustomModel__pb2.DESCRIPTOR,DictVectorizer__pb2.DESCRIPTOR,FeatureTypes__pb2.DESCRIPTOR,FeatureVectorizer__pb2.DESCRIPTOR,GLMRegressor__pb2.DESCRIPTOR,GLMClassifier__pb2.DESCRIPTOR,NearestNeighbors__pb2.DESCRIPTOR,Identity__pb2.DESCRIPTOR,Imputer__pb2.DESCRIPTOR,MIL__pb2.DESCRIPTOR,NeuralNetwork__pb2.DESCRIPTOR,Normalizer__pb2.DESCRIPTOR,OneHotEncoder__pb2.DESCRIPTOR,Scaler__pb2.DESCRIPTOR,NonMaximumSuppression__pb2.DESCRIPTOR,SVM__pb2.DESCRIPTOR,TreeEnsemble__pb2.DESCRIPTOR,Parameters__pb2.DESCRIPTOR,ItemSimilarityRecommender__pb2.DESCRIPTOR,SoundAnalysisPreprocessing__pb2.DESCRIPTOR,LinkedModel__pb2.DESCRIPTOR,ClassConfidenceThresholding__pb2.DESCRIPTOR,], - public_dependencies=[VisionFeaturePrint__pb2.DESCRIPTOR,AudioFeaturePrint__pb2.DESCRIPTOR,TextClassifier__pb2.DESCRIPTOR,WordTagger__pb2.DESCRIPTOR,Gazetteer__pb2.DESCRIPTOR,WordEmbedding__pb2.DESCRIPTOR,ArrayFeatureExtractor__pb2.DESCRIPTOR,BayesianProbitRegressor__pb2.DESCRIPTOR,CategoricalMapping__pb2.DESCRIPTOR,CustomModel__pb2.DESCRIPTOR,DictVectorizer__pb2.DESCRIPTOR,FeatureTypes__pb2.DESCRIPTOR,FeatureVectorizer__pb2.DESCRIPTOR,GLMRegressor__pb2.DESCRIPTOR,GLMClassifier__pb2.DESCRIPTOR,NearestNeighbors__pb2.DESCRIPTOR,Identity__pb2.DESCRIPTOR,Imputer__pb2.DESCRIPTOR,MIL__pb2.DESCRIPTOR,NeuralNetwork__pb2.DESCRIPTOR,Normalizer__pb2.DESCRIPTOR,OneHotEncoder__pb2.DESCRIPTOR,Scaler__pb2.DESCRIPTOR,NonMaximumSuppression__pb2.DESCRIPTOR,SVM__pb2.DESCRIPTOR,TreeEnsemble__pb2.DESCRIPTOR,Parameters__pb2.DESCRIPTOR,ItemSimilarityRecommender__pb2.DESCRIPTOR,SoundAnalysisPreprocessing__pb2.DESCRIPTOR,LinkedModel__pb2.DESCRIPTOR,ClassConfidenceThresholding__pb2.DESCRIPTOR,]) - - + name="Model.proto", + package="CoreML.Specification", + syntax="proto3", + serialized_pb=_b( + '\n\x0bModel.proto\x12\x14\x43oreML.Specification\x1a\x18VisionFeaturePrint.proto\x1a\x17\x41udioFeaturePrint.proto\x1a\x14TextClassifier.proto\x1a\x10WordTagger.proto\x1a\x0fGazetteer.proto\x1a\x13WordEmbedding.proto\x1a\x1b\x41rrayFeatureExtractor.proto\x1a\x1d\x42\x61yesianProbitRegressor.proto\x1a\x18\x43\x61tegoricalMapping.proto\x1a\x11\x43ustomModel.proto\x1a\x14\x44ictVectorizer.proto\x1a\x12\x46\x65\x61tureTypes.proto\x1a\x17\x46\x65\x61tureVectorizer.proto\x1a\x12GLMRegressor.proto\x1a\x13GLMClassifier.proto\x1a\x16NearestNeighbors.proto\x1a\x0eIdentity.proto\x1a\rImputer.proto\x1a\tMIL.proto\x1a\x13NeuralNetwork.proto\x1a\x10Normalizer.proto\x1a\x13OneHotEncoder.proto\x1a\x0cScaler.proto\x1a\x1bNonMaximumSuppression.proto\x1a\tSVM.proto\x1a\x12TreeEnsemble.proto\x1a\x10Parameters.proto\x1a\x1fItemSimilarityRecommender.proto\x1a SoundAnalysisPreprocessing.proto\x1a\x11LinkedModel.proto\x1a!ClassConfidenceThresholding.proto"F\n\x08Pipeline\x12+\n\x06models\x18\x01 \x03(\x0b\x32\x1b.CoreML.Specification.Model\x12\r\n\x05names\x18\x02 \x03(\t"F\n\x12PipelineClassifier\x12\x30\n\x08pipeline\x18\x01 \x01(\x0b\x32\x1e.CoreML.Specification.Pipeline"E\n\x11PipelineRegressor\x12\x30\n\x08pipeline\x18\x01 \x01(\x0b\x32\x1e.CoreML.Specification.Pipeline"m\n\x12\x46\x65\x61tureDescription\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x18\n\x10shortDescription\x18\x02 \x01(\t\x12/\n\x04type\x18\x03 \x01(\x0b\x32!.CoreML.Specification.FeatureType"\xd6\x01\n\x08Metadata\x12\x18\n\x10shortDescription\x18\x01 \x01(\t\x12\x15\n\rversionString\x18\x02 \x01(\t\x12\x0e\n\x06\x61uthor\x18\x03 \x01(\t\x12\x0f\n\x07license\x18\x04 \x01(\t\x12\x44\n\x0buserDefined\x18\x64 \x03(\x0b\x32/.CoreML.Specification.Metadata.UserDefinedEntry\x1a\x32\n\x10UserDefinedEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01"\x91\x02\n\x13\x46unctionDescription\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x37\n\x05input\x18\x02 \x03(\x0b\x32(.CoreML.Specification.FeatureDescription\x12\x38\n\x06output\x18\x03 \x03(\x0b\x32(.CoreML.Specification.FeatureDescription\x12\x37\n\x05state\x18\x06 \x03(\x0b\x32(.CoreML.Specification.FeatureDescription\x12\x1c\n\x14predictedFeatureName\x18\x04 \x01(\t\x12"\n\x1apredictedProbabilitiesName\x18\x05 \x01(\t"\xce\x03\n\x10ModelDescription\x12<\n\tfunctions\x18\x14 \x03(\x0b\x32).CoreML.Specification.FunctionDescription\x12\x1b\n\x13\x64\x65\x66\x61ultFunctionName\x18\x15 \x01(\t\x12\x30\n\x08metadata\x18\x64 \x01(\x0b\x32\x1e.CoreML.Specification.Metadata\x12\x37\n\x05input\x18\x01 \x03(\x0b\x32(.CoreML.Specification.FeatureDescription\x12\x38\n\x06output\x18\n \x03(\x0b\x32(.CoreML.Specification.FeatureDescription\x12\x37\n\x05state\x18\r \x03(\x0b\x32(.CoreML.Specification.FeatureDescription\x12\x1c\n\x14predictedFeatureName\x18\x0b \x01(\t\x12"\n\x1apredictedProbabilitiesName\x18\x0c \x01(\t\x12?\n\rtrainingInput\x18\x32 \x03(\x0b\x32(.CoreML.Specification.FeatureDescription"4\n\x0fSerializedModel\x12\x12\n\nidentifier\x18\x01 \x01(\t\x12\r\n\x05model\x18\x02 \x01(\x0c"\xf1\x15\n\x05Model\x12\x1c\n\x14specificationVersion\x18\x01 \x01(\x05\x12;\n\x0b\x64\x65scription\x18\x02 \x01(\x0b\x32&.CoreML.Specification.ModelDescription\x12\x13\n\x0bisUpdatable\x18\n \x01(\x08\x12G\n\x12pipelineClassifier\x18\xc8\x01 \x01(\x0b\x32(.CoreML.Specification.PipelineClassifierH\x00\x12\x45\n\x11pipelineRegressor\x18\xc9\x01 \x01(\x0b\x32\'.CoreML.Specification.PipelineRegressorH\x00\x12\x33\n\x08pipeline\x18\xca\x01 \x01(\x0b\x32\x1e.CoreML.Specification.PipelineH\x00\x12;\n\x0cglmRegressor\x18\xac\x02 \x01(\x0b\x32".CoreML.Specification.GLMRegressorH\x00\x12O\n\x16supportVectorRegressor\x18\xad\x02 \x01(\x0b\x32,.CoreML.Specification.SupportVectorRegressorH\x00\x12M\n\x15treeEnsembleRegressor\x18\xae\x02 \x01(\x0b\x32+.CoreML.Specification.TreeEnsembleRegressorH\x00\x12O\n\x16neuralNetworkRegressor\x18\xaf\x02 \x01(\x0b\x32,.CoreML.Specification.NeuralNetworkRegressorH\x00\x12Q\n\x17\x62\x61yesianProbitRegressor\x18\xb0\x02 \x01(\x0b\x32-.CoreML.Specification.BayesianProbitRegressorH\x00\x12=\n\rglmClassifier\x18\x90\x03 \x01(\x0b\x32#.CoreML.Specification.GLMClassifierH\x00\x12Q\n\x17supportVectorClassifier\x18\x91\x03 \x01(\x0b\x32-.CoreML.Specification.SupportVectorClassifierH\x00\x12O\n\x16treeEnsembleClassifier\x18\x92\x03 \x01(\x0b\x32,.CoreML.Specification.TreeEnsembleClassifierH\x00\x12Q\n\x17neuralNetworkClassifier\x18\x93\x03 \x01(\x0b\x32-.CoreML.Specification.NeuralNetworkClassifierH\x00\x12Y\n\x1bkNearestNeighborsClassifier\x18\x94\x03 \x01(\x0b\x32\x31.CoreML.Specification.KNearestNeighborsClassifierH\x00\x12=\n\rneuralNetwork\x18\xf4\x03 \x01(\x0b\x32#.CoreML.Specification.NeuralNetworkH\x00\x12U\n\x19itemSimilarityRecommender\x18\xf5\x03 \x01(\x0b\x32/.CoreML.Specification.ItemSimilarityRecommenderH\x00\x12;\n\tmlProgram\x18\xf6\x03 \x01(\x0b\x32%.CoreML.Specification.MILSpec.ProgramH\x00\x12\x39\n\x0b\x63ustomModel\x18\xab\x04 \x01(\x0b\x32!.CoreML.Specification.CustomModelH\x00\x12\x39\n\x0blinkedModel\x18\xac\x04 \x01(\x0b\x32!.CoreML.Specification.LinkedModelH\x00\x12Y\n\x1b\x63lassConfidenceThresholding\x18\xb0\x04 \x01(\x0b\x32\x31.CoreML.Specification.ClassConfidenceThresholdingH\x00\x12=\n\roneHotEncoder\x18\xd8\x04 \x01(\x0b\x32#.CoreML.Specification.OneHotEncoderH\x00\x12\x31\n\x07imputer\x18\xd9\x04 \x01(\x0b\x32\x1d.CoreML.Specification.ImputerH\x00\x12\x45\n\x11\x66\x65\x61tureVectorizer\x18\xda\x04 \x01(\x0b\x32\'.CoreML.Specification.FeatureVectorizerH\x00\x12?\n\x0e\x64ictVectorizer\x18\xdb\x04 \x01(\x0b\x32$.CoreML.Specification.DictVectorizerH\x00\x12/\n\x06scaler\x18\xdc\x04 \x01(\x0b\x32\x1c.CoreML.Specification.ScalerH\x00\x12G\n\x12\x63\x61tegoricalMapping\x18\xde\x04 \x01(\x0b\x32(.CoreML.Specification.CategoricalMappingH\x00\x12\x37\n\nnormalizer\x18\xdf\x04 \x01(\x0b\x32 .CoreML.Specification.NormalizerH\x00\x12M\n\x15\x61rrayFeatureExtractor\x18\xe1\x04 \x01(\x0b\x32+.CoreML.Specification.ArrayFeatureExtractorH\x00\x12M\n\x15nonMaximumSuppression\x18\xe2\x04 \x01(\x0b\x32+.CoreML.Specification.NonMaximumSuppressionH\x00\x12\x33\n\x08identity\x18\x84\x07 \x01(\x0b\x32\x1e.CoreML.Specification.IdentityH\x00\x12L\n\x0etextClassifier\x18\xd0\x0f \x01(\x0b\x32\x31.CoreML.Specification.CoreMLModels.TextClassifierH\x00\x12\x44\n\nwordTagger\x18\xd1\x0f \x01(\x0b\x32-.CoreML.Specification.CoreMLModels.WordTaggerH\x00\x12T\n\x12visionFeaturePrint\x18\xd2\x0f \x01(\x0b\x32\x35.CoreML.Specification.CoreMLModels.VisionFeaturePrintH\x00\x12\x64\n\x1asoundAnalysisPreprocessing\x18\xd3\x0f \x01(\x0b\x32=.CoreML.Specification.CoreMLModels.SoundAnalysisPreprocessingH\x00\x12\x42\n\tgazetteer\x18\xd4\x0f \x01(\x0b\x32,.CoreML.Specification.CoreMLModels.GazetteerH\x00\x12J\n\rwordEmbedding\x18\xd5\x0f \x01(\x0b\x32\x30.CoreML.Specification.CoreMLModels.WordEmbeddingH\x00\x12R\n\x11\x61udioFeaturePrint\x18\xd6\x0f \x01(\x0b\x32\x34.CoreML.Specification.CoreMLModels.AudioFeaturePrintH\x00\x12\x41\n\x0fserializedModel\x18\xb8\x17 \x01(\x0b\x32%.CoreML.Specification.SerializedModelH\x00\x42\x06\n\x04TypeB\x02H\x03P\x00P\x01P\x02P\x03P\x04P\x05P\x06P\x07P\x08P\tP\nP\x0bP\x0cP\rP\x0eP\x0fP\x10P\x11P\x12P\x13P\x14P\x15P\x16P\x17P\x18P\x19P\x1aP\x1bP\x1cP\x1dP\x1e\x62\x06proto3' + ), + dependencies=[ + VisionFeaturePrint__pb2.DESCRIPTOR, + AudioFeaturePrint__pb2.DESCRIPTOR, + TextClassifier__pb2.DESCRIPTOR, + WordTagger__pb2.DESCRIPTOR, + Gazetteer__pb2.DESCRIPTOR, + WordEmbedding__pb2.DESCRIPTOR, + ArrayFeatureExtractor__pb2.DESCRIPTOR, + BayesianProbitRegressor__pb2.DESCRIPTOR, + CategoricalMapping__pb2.DESCRIPTOR, + CustomModel__pb2.DESCRIPTOR, + DictVectorizer__pb2.DESCRIPTOR, + FeatureTypes__pb2.DESCRIPTOR, + FeatureVectorizer__pb2.DESCRIPTOR, + GLMRegressor__pb2.DESCRIPTOR, + GLMClassifier__pb2.DESCRIPTOR, + NearestNeighbors__pb2.DESCRIPTOR, + Identity__pb2.DESCRIPTOR, + Imputer__pb2.DESCRIPTOR, + MIL__pb2.DESCRIPTOR, + NeuralNetwork__pb2.DESCRIPTOR, + Normalizer__pb2.DESCRIPTOR, + OneHotEncoder__pb2.DESCRIPTOR, + Scaler__pb2.DESCRIPTOR, + NonMaximumSuppression__pb2.DESCRIPTOR, + SVM__pb2.DESCRIPTOR, + TreeEnsemble__pb2.DESCRIPTOR, + Parameters__pb2.DESCRIPTOR, + ItemSimilarityRecommender__pb2.DESCRIPTOR, + SoundAnalysisPreprocessing__pb2.DESCRIPTOR, + LinkedModel__pb2.DESCRIPTOR, + ClassConfidenceThresholding__pb2.DESCRIPTOR, + ], + public_dependencies=[ + VisionFeaturePrint__pb2.DESCRIPTOR, + AudioFeaturePrint__pb2.DESCRIPTOR, + TextClassifier__pb2.DESCRIPTOR, + WordTagger__pb2.DESCRIPTOR, + Gazetteer__pb2.DESCRIPTOR, + WordEmbedding__pb2.DESCRIPTOR, + ArrayFeatureExtractor__pb2.DESCRIPTOR, + BayesianProbitRegressor__pb2.DESCRIPTOR, + CategoricalMapping__pb2.DESCRIPTOR, + CustomModel__pb2.DESCRIPTOR, + DictVectorizer__pb2.DESCRIPTOR, + FeatureTypes__pb2.DESCRIPTOR, + FeatureVectorizer__pb2.DESCRIPTOR, + GLMRegressor__pb2.DESCRIPTOR, + GLMClassifier__pb2.DESCRIPTOR, + NearestNeighbors__pb2.DESCRIPTOR, + Identity__pb2.DESCRIPTOR, + Imputer__pb2.DESCRIPTOR, + MIL__pb2.DESCRIPTOR, + NeuralNetwork__pb2.DESCRIPTOR, + Normalizer__pb2.DESCRIPTOR, + OneHotEncoder__pb2.DESCRIPTOR, + Scaler__pb2.DESCRIPTOR, + NonMaximumSuppression__pb2.DESCRIPTOR, + SVM__pb2.DESCRIPTOR, + TreeEnsemble__pb2.DESCRIPTOR, + Parameters__pb2.DESCRIPTOR, + ItemSimilarityRecommender__pb2.DESCRIPTOR, + SoundAnalysisPreprocessing__pb2.DESCRIPTOR, + LinkedModel__pb2.DESCRIPTOR, + ClassConfidenceThresholding__pb2.DESCRIPTOR, + ], +) _PIPELINE = _descriptor.Descriptor( @@ -502,414 +566,1064 @@ ) +_FUNCTIONDESCRIPTION = _descriptor.Descriptor( + name="FunctionDescription", + full_name="CoreML.Specification.FunctionDescription", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="name", + full_name="CoreML.Specification.FunctionDescription.name", + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="input", + full_name="CoreML.Specification.FunctionDescription.input", + index=1, + number=2, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="output", + full_name="CoreML.Specification.FunctionDescription.output", + index=2, + number=3, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="state", + full_name="CoreML.Specification.FunctionDescription.state", + index=3, + number=6, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="predictedFeatureName", + full_name="CoreML.Specification.FunctionDescription.predictedFeatureName", + index=4, + number=4, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="predictedProbabilitiesName", + full_name="CoreML.Specification.FunctionDescription.predictedProbabilitiesName", + index=5, + number=5, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=1262, + serialized_end=1535, +) + + _MODELDESCRIPTION = _descriptor.Descriptor( - name='ModelDescription', - full_name='CoreML.Specification.ModelDescription', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='input', full_name='CoreML.Specification.ModelDescription.input', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='output', full_name='CoreML.Specification.ModelDescription.output', index=1, - number=10, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='predictedFeatureName', full_name='CoreML.Specification.ModelDescription.predictedFeatureName', index=2, - number=11, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='predictedProbabilitiesName', full_name='CoreML.Specification.ModelDescription.predictedProbabilitiesName', index=3, - number=12, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='trainingInput', full_name='CoreML.Specification.ModelDescription.trainingInput', index=4, - number=50, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='metadata', full_name='CoreML.Specification.ModelDescription.metadata', index=5, - number=100, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1262, - serialized_end=1576, + name="ModelDescription", + full_name="CoreML.Specification.ModelDescription", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="functions", + full_name="CoreML.Specification.ModelDescription.functions", + index=0, + number=20, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="defaultFunctionName", + full_name="CoreML.Specification.ModelDescription.defaultFunctionName", + index=1, + number=21, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="metadata", + full_name="CoreML.Specification.ModelDescription.metadata", + index=2, + number=100, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="input", + full_name="CoreML.Specification.ModelDescription.input", + index=3, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="output", + full_name="CoreML.Specification.ModelDescription.output", + index=4, + number=10, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="state", + full_name="CoreML.Specification.ModelDescription.state", + index=5, + number=13, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="predictedFeatureName", + full_name="CoreML.Specification.ModelDescription.predictedFeatureName", + index=6, + number=11, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="predictedProbabilitiesName", + full_name="CoreML.Specification.ModelDescription.predictedProbabilitiesName", + index=7, + number=12, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="trainingInput", + full_name="CoreML.Specification.ModelDescription.trainingInput", + index=8, + number=50, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=1538, + serialized_end=2000, ) _SERIALIZEDMODEL = _descriptor.Descriptor( - name='SerializedModel', - full_name='CoreML.Specification.SerializedModel', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='identifier', full_name='CoreML.Specification.SerializedModel.identifier', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='model', full_name='CoreML.Specification.SerializedModel.model', index=1, - number=2, type=12, cpp_type=9, label=1, - has_default_value=False, default_value=_b(""), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=1578, - serialized_end=1630, + name="SerializedModel", + full_name="CoreML.Specification.SerializedModel", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="identifier", + full_name="CoreML.Specification.SerializedModel.identifier", + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="model", + full_name="CoreML.Specification.SerializedModel.model", + index=1, + number=2, + type=12, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b(""), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=2002, + serialized_end=2054, ) _MODEL = _descriptor.Descriptor( - name='Model', - full_name='CoreML.Specification.Model', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='specificationVersion', full_name='CoreML.Specification.Model.specificationVersion', index=0, - number=1, type=5, cpp_type=1, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='description', full_name='CoreML.Specification.Model.description', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='isUpdatable', full_name='CoreML.Specification.Model.isUpdatable', index=2, - number=10, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='pipelineClassifier', full_name='CoreML.Specification.Model.pipelineClassifier', index=3, - number=200, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='pipelineRegressor', full_name='CoreML.Specification.Model.pipelineRegressor', index=4, - number=201, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='pipeline', full_name='CoreML.Specification.Model.pipeline', index=5, - number=202, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='glmRegressor', full_name='CoreML.Specification.Model.glmRegressor', index=6, - number=300, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='supportVectorRegressor', full_name='CoreML.Specification.Model.supportVectorRegressor', index=7, - number=301, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='treeEnsembleRegressor', full_name='CoreML.Specification.Model.treeEnsembleRegressor', index=8, - number=302, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='neuralNetworkRegressor', full_name='CoreML.Specification.Model.neuralNetworkRegressor', index=9, - number=303, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='bayesianProbitRegressor', full_name='CoreML.Specification.Model.bayesianProbitRegressor', index=10, - number=304, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='glmClassifier', full_name='CoreML.Specification.Model.glmClassifier', index=11, - number=400, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='supportVectorClassifier', full_name='CoreML.Specification.Model.supportVectorClassifier', index=12, - number=401, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='treeEnsembleClassifier', full_name='CoreML.Specification.Model.treeEnsembleClassifier', index=13, - number=402, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='neuralNetworkClassifier', full_name='CoreML.Specification.Model.neuralNetworkClassifier', index=14, - number=403, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='kNearestNeighborsClassifier', full_name='CoreML.Specification.Model.kNearestNeighborsClassifier', index=15, - number=404, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='neuralNetwork', full_name='CoreML.Specification.Model.neuralNetwork', index=16, - number=500, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='itemSimilarityRecommender', full_name='CoreML.Specification.Model.itemSimilarityRecommender', index=17, - number=501, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='mlProgram', full_name='CoreML.Specification.Model.mlProgram', index=18, - number=502, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='customModel', full_name='CoreML.Specification.Model.customModel', index=19, - number=555, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='linkedModel', full_name='CoreML.Specification.Model.linkedModel', index=20, - number=556, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='classConfidenceThresholding', full_name='CoreML.Specification.Model.classConfidenceThresholding', index=21, - number=560, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='oneHotEncoder', full_name='CoreML.Specification.Model.oneHotEncoder', index=22, - number=600, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='imputer', full_name='CoreML.Specification.Model.imputer', index=23, - number=601, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='featureVectorizer', full_name='CoreML.Specification.Model.featureVectorizer', index=24, - number=602, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='dictVectorizer', full_name='CoreML.Specification.Model.dictVectorizer', index=25, - number=603, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='scaler', full_name='CoreML.Specification.Model.scaler', index=26, - number=604, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='categoricalMapping', full_name='CoreML.Specification.Model.categoricalMapping', index=27, - number=606, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='normalizer', full_name='CoreML.Specification.Model.normalizer', index=28, - number=607, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='arrayFeatureExtractor', full_name='CoreML.Specification.Model.arrayFeatureExtractor', index=29, - number=609, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='nonMaximumSuppression', full_name='CoreML.Specification.Model.nonMaximumSuppression', index=30, - number=610, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='identity', full_name='CoreML.Specification.Model.identity', index=31, - number=900, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='textClassifier', full_name='CoreML.Specification.Model.textClassifier', index=32, - number=2000, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='wordTagger', full_name='CoreML.Specification.Model.wordTagger', index=33, - number=2001, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='visionFeaturePrint', full_name='CoreML.Specification.Model.visionFeaturePrint', index=34, - number=2002, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='soundAnalysisPreprocessing', full_name='CoreML.Specification.Model.soundAnalysisPreprocessing', index=35, - number=2003, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='gazetteer', full_name='CoreML.Specification.Model.gazetteer', index=36, - number=2004, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='wordEmbedding', full_name='CoreML.Specification.Model.wordEmbedding', index=37, - number=2005, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='audioFeaturePrint', full_name='CoreML.Specification.Model.audioFeaturePrint', index=38, - number=2006, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='serializedModel', full_name='CoreML.Specification.Model.serializedModel', index=39, - number=3000, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - _descriptor.OneofDescriptor( - name='Type', full_name='CoreML.Specification.Model.Type', - index=0, containing_type=None, fields=[]), - ], - serialized_start=1633, - serialized_end=4434, + name="Model", + full_name="CoreML.Specification.Model", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="specificationVersion", + full_name="CoreML.Specification.Model.specificationVersion", + index=0, + number=1, + type=5, + cpp_type=1, + label=1, + has_default_value=False, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="description", + full_name="CoreML.Specification.Model.description", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="isUpdatable", + full_name="CoreML.Specification.Model.isUpdatable", + index=2, + number=10, + type=8, + cpp_type=7, + label=1, + has_default_value=False, + default_value=False, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="pipelineClassifier", + full_name="CoreML.Specification.Model.pipelineClassifier", + index=3, + number=200, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="pipelineRegressor", + full_name="CoreML.Specification.Model.pipelineRegressor", + index=4, + number=201, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="pipeline", + full_name="CoreML.Specification.Model.pipeline", + index=5, + number=202, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="glmRegressor", + full_name="CoreML.Specification.Model.glmRegressor", + index=6, + number=300, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="supportVectorRegressor", + full_name="CoreML.Specification.Model.supportVectorRegressor", + index=7, + number=301, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="treeEnsembleRegressor", + full_name="CoreML.Specification.Model.treeEnsembleRegressor", + index=8, + number=302, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="neuralNetworkRegressor", + full_name="CoreML.Specification.Model.neuralNetworkRegressor", + index=9, + number=303, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="bayesianProbitRegressor", + full_name="CoreML.Specification.Model.bayesianProbitRegressor", + index=10, + number=304, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="glmClassifier", + full_name="CoreML.Specification.Model.glmClassifier", + index=11, + number=400, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="supportVectorClassifier", + full_name="CoreML.Specification.Model.supportVectorClassifier", + index=12, + number=401, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="treeEnsembleClassifier", + full_name="CoreML.Specification.Model.treeEnsembleClassifier", + index=13, + number=402, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="neuralNetworkClassifier", + full_name="CoreML.Specification.Model.neuralNetworkClassifier", + index=14, + number=403, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="kNearestNeighborsClassifier", + full_name="CoreML.Specification.Model.kNearestNeighborsClassifier", + index=15, + number=404, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="neuralNetwork", + full_name="CoreML.Specification.Model.neuralNetwork", + index=16, + number=500, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="itemSimilarityRecommender", + full_name="CoreML.Specification.Model.itemSimilarityRecommender", + index=17, + number=501, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="mlProgram", + full_name="CoreML.Specification.Model.mlProgram", + index=18, + number=502, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="customModel", + full_name="CoreML.Specification.Model.customModel", + index=19, + number=555, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="linkedModel", + full_name="CoreML.Specification.Model.linkedModel", + index=20, + number=556, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="classConfidenceThresholding", + full_name="CoreML.Specification.Model.classConfidenceThresholding", + index=21, + number=560, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="oneHotEncoder", + full_name="CoreML.Specification.Model.oneHotEncoder", + index=22, + number=600, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="imputer", + full_name="CoreML.Specification.Model.imputer", + index=23, + number=601, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="featureVectorizer", + full_name="CoreML.Specification.Model.featureVectorizer", + index=24, + number=602, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="dictVectorizer", + full_name="CoreML.Specification.Model.dictVectorizer", + index=25, + number=603, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="scaler", + full_name="CoreML.Specification.Model.scaler", + index=26, + number=604, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="categoricalMapping", + full_name="CoreML.Specification.Model.categoricalMapping", + index=27, + number=606, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="normalizer", + full_name="CoreML.Specification.Model.normalizer", + index=28, + number=607, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="arrayFeatureExtractor", + full_name="CoreML.Specification.Model.arrayFeatureExtractor", + index=29, + number=609, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="nonMaximumSuppression", + full_name="CoreML.Specification.Model.nonMaximumSuppression", + index=30, + number=610, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="identity", + full_name="CoreML.Specification.Model.identity", + index=31, + number=900, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="textClassifier", + full_name="CoreML.Specification.Model.textClassifier", + index=32, + number=2000, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="wordTagger", + full_name="CoreML.Specification.Model.wordTagger", + index=33, + number=2001, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="visionFeaturePrint", + full_name="CoreML.Specification.Model.visionFeaturePrint", + index=34, + number=2002, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="soundAnalysisPreprocessing", + full_name="CoreML.Specification.Model.soundAnalysisPreprocessing", + index=35, + number=2003, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="gazetteer", + full_name="CoreML.Specification.Model.gazetteer", + index=36, + number=2004, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="wordEmbedding", + full_name="CoreML.Specification.Model.wordEmbedding", + index=37, + number=2005, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="audioFeaturePrint", + full_name="CoreML.Specification.Model.audioFeaturePrint", + index=38, + number=2006, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + _descriptor.FieldDescriptor( + name="serializedModel", + full_name="CoreML.Specification.Model.serializedModel", + index=39, + number=3000, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name="Type", + full_name="CoreML.Specification.Model.Type", + index=0, + containing_type=None, + fields=[], + ), + ], + serialized_start=2057, + serialized_end=4858, ) _PIPELINE.fields_by_name['models'].message_type = _MODEL @@ -917,168 +1631,167 @@ _PIPELINEREGRESSOR.fields_by_name['pipeline'].message_type = _PIPELINE _FEATUREDESCRIPTION.fields_by_name['type'].message_type = FeatureTypes__pb2._FEATURETYPE _METADATA_USERDEFINEDENTRY.containing_type = _METADATA -_METADATA.fields_by_name['userDefined'].message_type = _METADATA_USERDEFINEDENTRY -_MODELDESCRIPTION.fields_by_name['input'].message_type = _FEATUREDESCRIPTION -_MODELDESCRIPTION.fields_by_name['output'].message_type = _FEATUREDESCRIPTION -_MODELDESCRIPTION.fields_by_name['trainingInput'].message_type = _FEATUREDESCRIPTION -_MODELDESCRIPTION.fields_by_name['metadata'].message_type = _METADATA -_MODEL.fields_by_name['description'].message_type = _MODELDESCRIPTION -_MODEL.fields_by_name['pipelineClassifier'].message_type = _PIPELINECLASSIFIER -_MODEL.fields_by_name['pipelineRegressor'].message_type = _PIPELINEREGRESSOR -_MODEL.fields_by_name['pipeline'].message_type = _PIPELINE -_MODEL.fields_by_name['glmRegressor'].message_type = GLMRegressor__pb2._GLMREGRESSOR -_MODEL.fields_by_name['supportVectorRegressor'].message_type = SVM__pb2._SUPPORTVECTORREGRESSOR -_MODEL.fields_by_name['treeEnsembleRegressor'].message_type = TreeEnsemble__pb2._TREEENSEMBLEREGRESSOR -_MODEL.fields_by_name['neuralNetworkRegressor'].message_type = NeuralNetwork__pb2._NEURALNETWORKREGRESSOR -_MODEL.fields_by_name['bayesianProbitRegressor'].message_type = BayesianProbitRegressor__pb2._BAYESIANPROBITREGRESSOR -_MODEL.fields_by_name['glmClassifier'].message_type = GLMClassifier__pb2._GLMCLASSIFIER -_MODEL.fields_by_name['supportVectorClassifier'].message_type = SVM__pb2._SUPPORTVECTORCLASSIFIER -_MODEL.fields_by_name['treeEnsembleClassifier'].message_type = TreeEnsemble__pb2._TREEENSEMBLECLASSIFIER -_MODEL.fields_by_name['neuralNetworkClassifier'].message_type = NeuralNetwork__pb2._NEURALNETWORKCLASSIFIER -_MODEL.fields_by_name['kNearestNeighborsClassifier'].message_type = NearestNeighbors__pb2._KNEARESTNEIGHBORSCLASSIFIER -_MODEL.fields_by_name['neuralNetwork'].message_type = NeuralNetwork__pb2._NEURALNETWORK -_MODEL.fields_by_name['itemSimilarityRecommender'].message_type = ItemSimilarityRecommender__pb2._ITEMSIMILARITYRECOMMENDER -_MODEL.fields_by_name['mlProgram'].message_type = MIL__pb2._PROGRAM -_MODEL.fields_by_name['customModel'].message_type = CustomModel__pb2._CUSTOMMODEL -_MODEL.fields_by_name['linkedModel'].message_type = LinkedModel__pb2._LINKEDMODEL -_MODEL.fields_by_name['classConfidenceThresholding'].message_type = ClassConfidenceThresholding__pb2._CLASSCONFIDENCETHRESHOLDING -_MODEL.fields_by_name['oneHotEncoder'].message_type = OneHotEncoder__pb2._ONEHOTENCODER -_MODEL.fields_by_name['imputer'].message_type = Imputer__pb2._IMPUTER -_MODEL.fields_by_name['featureVectorizer'].message_type = FeatureVectorizer__pb2._FEATUREVECTORIZER -_MODEL.fields_by_name['dictVectorizer'].message_type = DictVectorizer__pb2._DICTVECTORIZER -_MODEL.fields_by_name['scaler'].message_type = Scaler__pb2._SCALER -_MODEL.fields_by_name['categoricalMapping'].message_type = CategoricalMapping__pb2._CATEGORICALMAPPING -_MODEL.fields_by_name['normalizer'].message_type = Normalizer__pb2._NORMALIZER -_MODEL.fields_by_name['arrayFeatureExtractor'].message_type = ArrayFeatureExtractor__pb2._ARRAYFEATUREEXTRACTOR -_MODEL.fields_by_name['nonMaximumSuppression'].message_type = NonMaximumSuppression__pb2._NONMAXIMUMSUPPRESSION -_MODEL.fields_by_name['identity'].message_type = Identity__pb2._IDENTITY -_MODEL.fields_by_name['textClassifier'].message_type = TextClassifier__pb2._TEXTCLASSIFIER -_MODEL.fields_by_name['wordTagger'].message_type = WordTagger__pb2._WORDTAGGER -_MODEL.fields_by_name['visionFeaturePrint'].message_type = VisionFeaturePrint__pb2._VISIONFEATUREPRINT -_MODEL.fields_by_name['soundAnalysisPreprocessing'].message_type = SoundAnalysisPreprocessing__pb2._SOUNDANALYSISPREPROCESSING -_MODEL.fields_by_name['gazetteer'].message_type = Gazetteer__pb2._GAZETTEER -_MODEL.fields_by_name['wordEmbedding'].message_type = WordEmbedding__pb2._WORDEMBEDDING -_MODEL.fields_by_name['audioFeaturePrint'].message_type = AudioFeaturePrint__pb2._AUDIOFEATUREPRINT -_MODEL.fields_by_name['serializedModel'].message_type = _SERIALIZEDMODEL -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['pipelineClassifier']) -_MODEL.fields_by_name['pipelineClassifier'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['pipelineRegressor']) -_MODEL.fields_by_name['pipelineRegressor'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['pipeline']) -_MODEL.fields_by_name['pipeline'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['glmRegressor']) -_MODEL.fields_by_name['glmRegressor'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['supportVectorRegressor']) -_MODEL.fields_by_name['supportVectorRegressor'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['treeEnsembleRegressor']) -_MODEL.fields_by_name['treeEnsembleRegressor'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['neuralNetworkRegressor']) -_MODEL.fields_by_name['neuralNetworkRegressor'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['bayesianProbitRegressor']) -_MODEL.fields_by_name['bayesianProbitRegressor'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['glmClassifier']) -_MODEL.fields_by_name['glmClassifier'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['supportVectorClassifier']) -_MODEL.fields_by_name['supportVectorClassifier'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['treeEnsembleClassifier']) -_MODEL.fields_by_name['treeEnsembleClassifier'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['neuralNetworkClassifier']) -_MODEL.fields_by_name['neuralNetworkClassifier'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['kNearestNeighborsClassifier']) -_MODEL.fields_by_name['kNearestNeighborsClassifier'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['neuralNetwork']) -_MODEL.fields_by_name['neuralNetwork'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['itemSimilarityRecommender']) -_MODEL.fields_by_name['itemSimilarityRecommender'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['mlProgram']) -_MODEL.fields_by_name['mlProgram'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['customModel']) -_MODEL.fields_by_name['customModel'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['linkedModel']) -_MODEL.fields_by_name['linkedModel'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['classConfidenceThresholding']) -_MODEL.fields_by_name['classConfidenceThresholding'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['oneHotEncoder']) -_MODEL.fields_by_name['oneHotEncoder'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['imputer']) -_MODEL.fields_by_name['imputer'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['featureVectorizer']) -_MODEL.fields_by_name['featureVectorizer'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['dictVectorizer']) -_MODEL.fields_by_name['dictVectorizer'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['scaler']) -_MODEL.fields_by_name['scaler'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['categoricalMapping']) -_MODEL.fields_by_name['categoricalMapping'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['normalizer']) -_MODEL.fields_by_name['normalizer'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['arrayFeatureExtractor']) -_MODEL.fields_by_name['arrayFeatureExtractor'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['nonMaximumSuppression']) -_MODEL.fields_by_name['nonMaximumSuppression'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['identity']) -_MODEL.fields_by_name['identity'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['textClassifier']) -_MODEL.fields_by_name['textClassifier'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['wordTagger']) -_MODEL.fields_by_name['wordTagger'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['visionFeaturePrint']) -_MODEL.fields_by_name['visionFeaturePrint'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['soundAnalysisPreprocessing']) -_MODEL.fields_by_name['soundAnalysisPreprocessing'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['gazetteer']) -_MODEL.fields_by_name['gazetteer'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['wordEmbedding']) -_MODEL.fields_by_name['wordEmbedding'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['audioFeaturePrint']) -_MODEL.fields_by_name['audioFeaturePrint'].containing_oneof = _MODEL.oneofs_by_name['Type'] -_MODEL.oneofs_by_name['Type'].fields.append( - _MODEL.fields_by_name['serializedModel']) -_MODEL.fields_by_name['serializedModel'].containing_oneof = _MODEL.oneofs_by_name['Type'] -DESCRIPTOR.message_types_by_name['Pipeline'] = _PIPELINE -DESCRIPTOR.message_types_by_name['PipelineClassifier'] = _PIPELINECLASSIFIER -DESCRIPTOR.message_types_by_name['PipelineRegressor'] = _PIPELINEREGRESSOR -DESCRIPTOR.message_types_by_name['FeatureDescription'] = _FEATUREDESCRIPTION -DESCRIPTOR.message_types_by_name['Metadata'] = _METADATA -DESCRIPTOR.message_types_by_name['ModelDescription'] = _MODELDESCRIPTION -DESCRIPTOR.message_types_by_name['SerializedModel'] = _SERIALIZEDMODEL -DESCRIPTOR.message_types_by_name['Model'] = _MODEL +_METADATA.fields_by_name["userDefined"].message_type = _METADATA_USERDEFINEDENTRY +_FUNCTIONDESCRIPTION.fields_by_name["input"].message_type = _FEATUREDESCRIPTION +_FUNCTIONDESCRIPTION.fields_by_name["output"].message_type = _FEATUREDESCRIPTION +_FUNCTIONDESCRIPTION.fields_by_name["state"].message_type = _FEATUREDESCRIPTION +_MODELDESCRIPTION.fields_by_name["functions"].message_type = _FUNCTIONDESCRIPTION +_MODELDESCRIPTION.fields_by_name["metadata"].message_type = _METADATA +_MODELDESCRIPTION.fields_by_name["input"].message_type = _FEATUREDESCRIPTION +_MODELDESCRIPTION.fields_by_name["output"].message_type = _FEATUREDESCRIPTION +_MODELDESCRIPTION.fields_by_name["state"].message_type = _FEATUREDESCRIPTION +_MODELDESCRIPTION.fields_by_name["trainingInput"].message_type = _FEATUREDESCRIPTION +_MODEL.fields_by_name["description"].message_type = _MODELDESCRIPTION +_MODEL.fields_by_name["pipelineClassifier"].message_type = _PIPELINECLASSIFIER +_MODEL.fields_by_name["pipelineRegressor"].message_type = _PIPELINEREGRESSOR +_MODEL.fields_by_name["pipeline"].message_type = _PIPELINE +_MODEL.fields_by_name["glmRegressor"].message_type = GLMRegressor__pb2._GLMREGRESSOR +_MODEL.fields_by_name["supportVectorRegressor"].message_type = SVM__pb2._SUPPORTVECTORREGRESSOR +_MODEL.fields_by_name[ + "treeEnsembleRegressor" +].message_type = TreeEnsemble__pb2._TREEENSEMBLEREGRESSOR +_MODEL.fields_by_name[ + "neuralNetworkRegressor" +].message_type = NeuralNetwork__pb2._NEURALNETWORKREGRESSOR +_MODEL.fields_by_name[ + "bayesianProbitRegressor" +].message_type = BayesianProbitRegressor__pb2._BAYESIANPROBITREGRESSOR +_MODEL.fields_by_name["glmClassifier"].message_type = GLMClassifier__pb2._GLMCLASSIFIER +_MODEL.fields_by_name["supportVectorClassifier"].message_type = SVM__pb2._SUPPORTVECTORCLASSIFIER +_MODEL.fields_by_name[ + "treeEnsembleClassifier" +].message_type = TreeEnsemble__pb2._TREEENSEMBLECLASSIFIER +_MODEL.fields_by_name[ + "neuralNetworkClassifier" +].message_type = NeuralNetwork__pb2._NEURALNETWORKCLASSIFIER +_MODEL.fields_by_name[ + "kNearestNeighborsClassifier" +].message_type = NearestNeighbors__pb2._KNEARESTNEIGHBORSCLASSIFIER +_MODEL.fields_by_name["neuralNetwork"].message_type = NeuralNetwork__pb2._NEURALNETWORK +_MODEL.fields_by_name[ + "itemSimilarityRecommender" +].message_type = ItemSimilarityRecommender__pb2._ITEMSIMILARITYRECOMMENDER +_MODEL.fields_by_name["mlProgram"].message_type = MIL__pb2._PROGRAM +_MODEL.fields_by_name["customModel"].message_type = CustomModel__pb2._CUSTOMMODEL +_MODEL.fields_by_name["linkedModel"].message_type = LinkedModel__pb2._LINKEDMODEL +_MODEL.fields_by_name[ + "classConfidenceThresholding" +].message_type = ClassConfidenceThresholding__pb2._CLASSCONFIDENCETHRESHOLDING +_MODEL.fields_by_name["oneHotEncoder"].message_type = OneHotEncoder__pb2._ONEHOTENCODER +_MODEL.fields_by_name["imputer"].message_type = Imputer__pb2._IMPUTER +_MODEL.fields_by_name["featureVectorizer"].message_type = FeatureVectorizer__pb2._FEATUREVECTORIZER +_MODEL.fields_by_name["dictVectorizer"].message_type = DictVectorizer__pb2._DICTVECTORIZER +_MODEL.fields_by_name["scaler"].message_type = Scaler__pb2._SCALER +_MODEL.fields_by_name[ + "categoricalMapping" +].message_type = CategoricalMapping__pb2._CATEGORICALMAPPING +_MODEL.fields_by_name["normalizer"].message_type = Normalizer__pb2._NORMALIZER +_MODEL.fields_by_name[ + "arrayFeatureExtractor" +].message_type = ArrayFeatureExtractor__pb2._ARRAYFEATUREEXTRACTOR +_MODEL.fields_by_name[ + "nonMaximumSuppression" +].message_type = NonMaximumSuppression__pb2._NONMAXIMUMSUPPRESSION +_MODEL.fields_by_name["identity"].message_type = Identity__pb2._IDENTITY +_MODEL.fields_by_name["textClassifier"].message_type = TextClassifier__pb2._TEXTCLASSIFIER +_MODEL.fields_by_name["wordTagger"].message_type = WordTagger__pb2._WORDTAGGER +_MODEL.fields_by_name[ + "visionFeaturePrint" +].message_type = VisionFeaturePrint__pb2._VISIONFEATUREPRINT +_MODEL.fields_by_name[ + "soundAnalysisPreprocessing" +].message_type = SoundAnalysisPreprocessing__pb2._SOUNDANALYSISPREPROCESSING +_MODEL.fields_by_name["gazetteer"].message_type = Gazetteer__pb2._GAZETTEER +_MODEL.fields_by_name["wordEmbedding"].message_type = WordEmbedding__pb2._WORDEMBEDDING +_MODEL.fields_by_name["audioFeaturePrint"].message_type = AudioFeaturePrint__pb2._AUDIOFEATUREPRINT +_MODEL.fields_by_name["serializedModel"].message_type = _SERIALIZEDMODEL +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["pipelineClassifier"]) +_MODEL.fields_by_name["pipelineClassifier"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["pipelineRegressor"]) +_MODEL.fields_by_name["pipelineRegressor"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["pipeline"]) +_MODEL.fields_by_name["pipeline"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["glmRegressor"]) +_MODEL.fields_by_name["glmRegressor"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["supportVectorRegressor"]) +_MODEL.fields_by_name["supportVectorRegressor"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["treeEnsembleRegressor"]) +_MODEL.fields_by_name["treeEnsembleRegressor"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["neuralNetworkRegressor"]) +_MODEL.fields_by_name["neuralNetworkRegressor"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["bayesianProbitRegressor"]) +_MODEL.fields_by_name["bayesianProbitRegressor"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["glmClassifier"]) +_MODEL.fields_by_name["glmClassifier"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["supportVectorClassifier"]) +_MODEL.fields_by_name["supportVectorClassifier"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["treeEnsembleClassifier"]) +_MODEL.fields_by_name["treeEnsembleClassifier"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["neuralNetworkClassifier"]) +_MODEL.fields_by_name["neuralNetworkClassifier"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["kNearestNeighborsClassifier"]) +_MODEL.fields_by_name["kNearestNeighborsClassifier"].containing_oneof = _MODEL.oneofs_by_name[ + "Type" +] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["neuralNetwork"]) +_MODEL.fields_by_name["neuralNetwork"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["itemSimilarityRecommender"]) +_MODEL.fields_by_name["itemSimilarityRecommender"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["mlProgram"]) +_MODEL.fields_by_name["mlProgram"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["customModel"]) +_MODEL.fields_by_name["customModel"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["linkedModel"]) +_MODEL.fields_by_name["linkedModel"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["classConfidenceThresholding"]) +_MODEL.fields_by_name["classConfidenceThresholding"].containing_oneof = _MODEL.oneofs_by_name[ + "Type" +] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["oneHotEncoder"]) +_MODEL.fields_by_name["oneHotEncoder"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["imputer"]) +_MODEL.fields_by_name["imputer"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["featureVectorizer"]) +_MODEL.fields_by_name["featureVectorizer"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["dictVectorizer"]) +_MODEL.fields_by_name["dictVectorizer"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["scaler"]) +_MODEL.fields_by_name["scaler"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["categoricalMapping"]) +_MODEL.fields_by_name["categoricalMapping"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["normalizer"]) +_MODEL.fields_by_name["normalizer"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["arrayFeatureExtractor"]) +_MODEL.fields_by_name["arrayFeatureExtractor"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["nonMaximumSuppression"]) +_MODEL.fields_by_name["nonMaximumSuppression"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["identity"]) +_MODEL.fields_by_name["identity"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["textClassifier"]) +_MODEL.fields_by_name["textClassifier"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["wordTagger"]) +_MODEL.fields_by_name["wordTagger"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["visionFeaturePrint"]) +_MODEL.fields_by_name["visionFeaturePrint"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["soundAnalysisPreprocessing"]) +_MODEL.fields_by_name["soundAnalysisPreprocessing"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["gazetteer"]) +_MODEL.fields_by_name["gazetteer"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["wordEmbedding"]) +_MODEL.fields_by_name["wordEmbedding"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["audioFeaturePrint"]) +_MODEL.fields_by_name["audioFeaturePrint"].containing_oneof = _MODEL.oneofs_by_name["Type"] +_MODEL.oneofs_by_name["Type"].fields.append(_MODEL.fields_by_name["serializedModel"]) +_MODEL.fields_by_name["serializedModel"].containing_oneof = _MODEL.oneofs_by_name["Type"] +DESCRIPTOR.message_types_by_name["Pipeline"] = _PIPELINE +DESCRIPTOR.message_types_by_name["PipelineClassifier"] = _PIPELINECLASSIFIER +DESCRIPTOR.message_types_by_name["PipelineRegressor"] = _PIPELINEREGRESSOR +DESCRIPTOR.message_types_by_name["FeatureDescription"] = _FEATUREDESCRIPTION +DESCRIPTOR.message_types_by_name["Metadata"] = _METADATA +DESCRIPTOR.message_types_by_name["FunctionDescription"] = _FUNCTIONDESCRIPTION +DESCRIPTOR.message_types_by_name["ModelDescription"] = _MODELDESCRIPTION +DESCRIPTOR.message_types_by_name["SerializedModel"] = _SERIALIZEDMODEL +DESCRIPTOR.message_types_by_name["Model"] = _MODEL _sym_db.RegisterFileDescriptor(DESCRIPTOR) Pipeline = _reflection.GeneratedProtocolMessageType('Pipeline', (_message.Message,), dict( @@ -1124,6 +1837,17 @@ _sym_db.RegisterMessage(Metadata) _sym_db.RegisterMessage(Metadata.UserDefinedEntry) +FunctionDescription = _reflection.GeneratedProtocolMessageType( + "FunctionDescription", + (_message.Message,), + dict( + DESCRIPTOR=_FUNCTIONDESCRIPTION, + __module__="Model_pb2" + # @@protoc_insertion_point(class_scope:CoreML.Specification.FunctionDescription) + ), +) +_sym_db.RegisterMessage(FunctionDescription) + ModelDescription = _reflection.GeneratedProtocolMessageType('ModelDescription', (_message.Message,), dict( DESCRIPTOR = _MODELDESCRIPTION, __module__ = 'Model_pb2' diff --git a/coremltools/test/api/test_api_visibilities.py b/coremltools/test/api/test_api_visibilities.py index b783828bc..49fc0e401 100644 --- a/coremltools/test/api/test_api_visibilities.py +++ b/coremltools/test/api/test_api_visibilities.py @@ -44,6 +44,7 @@ def _check_visible_modules(actual, expected): "libmodelpackage", "libmilstoragepython", "optimize", + "StateType", ] @@ -67,6 +68,9 @@ def test_utils(self): "load_spec", "rename_feature", "save_spec", + "save_multifunction", + "MultiFunctionDescriptor", + "randomize_weights", ] _check_visible_modules(_get_visible_items(ct.utils), expected) @@ -100,6 +104,7 @@ def test_models_mlmodel(self): "user_defined_metadata", "version", "weights_dir", + "make_state", ] _check_visible_modules(_get_visible_items(ct.models.MLModel), expected) @@ -161,6 +166,7 @@ def test_converters(self): "mil", "sklearn", "xgboost", + "StateType", ] _check_visible_modules(_get_visible_items(ct.converters), expected) @@ -178,6 +184,7 @@ def test_optimize_coreml(self): "OpPalettizerConfig", "OptimizationConfig", "OpThresholdPrunerConfig", + "experimental", "linear_quantize_weights", "palettize_weights", "prune_weights", diff --git a/coremltools/test/blob/test_weights.py b/coremltools/test/blob/test_weights.py index 72bb061d1..0b68317c2 100644 --- a/coremltools/test/blob/test_weights.py +++ b/coremltools/test/blob/test_weights.py @@ -9,18 +9,77 @@ import unittest import numpy as np +import pytest +import coremltools as ct +from coremltools import _SPECIFICATION_VERSION_IOS_18 +from coremltools.converters.mil import mil +from coremltools.converters.mil.converter import mil_convert as _mil_convert +from coremltools.converters.mil.mil.builder import Builder as mb from coremltools.libmilstoragepython import _BlobStorageReader as BlobReader from coremltools.libmilstoragepython import _BlobStorageWriter as BlobWriter -class WeightTest(unittest.TestCase): - def setUp(self): - self.working_dir = tempfile.mkdtemp() +class TestWeightBlob: + @classmethod + def setup_class(cls): + cls.working_dir = tempfile.mkdtemp() - def tearDown(self): - if os.path.exists(self.working_dir): - shutil.rmtree(self.working_dir) + @classmethod + def teardown_class(cls): + if os.path.exists(cls.working_dir): + shutil.rmtree(cls.working_dir) + + def test_weight_blob_int4(self): + writer = BlobWriter(self.working_dir + "/net.wt") + # All values in input_arr should be within range of int4, although they are stored in int8. + input_arr1 = np.array([-8, -2, 0, 2, 7], dtype=np.int8) + offset1 = writer.write_int4_data(input_arr1) + input_arr2 = np.array([3, -8, 5, 7, -6], dtype=np.int8) + offset2 = writer.write_int4_data(input_arr2) + writer = None + + reader = BlobReader(self.working_dir + "/net.wt") + output_arr1 = reader.read_int4_data(offset1) + output_arr2 = reader.read_int4_data(offset2) + np.testing.assert_equal(input_arr1, output_arr1) + np.testing.assert_equal(input_arr2, output_arr2) + + def test_weight_blob_int4_invalid(self): + writer = BlobWriter(self.working_dir + "/net.wt") + input_arr = np.array([-80, -2, 0, 2, 7], dtype=np.float32) + with pytest.raises( + ValueError, match="Value -80 is outside allowed subbyte datatype range \[-8, 7\]." + ): + writer.write_int4_data(input_arr) + + @pytest.mark.parametrize("nbits", (1, 2, 3, 4, 6)) + def test_weight_blob_unsigned_sub_byte(self, nbits): + writer = BlobWriter(self.working_dir + "/net.wt") + # All values in input_arr are within range of uint{nbits}, but stored in uint8. + input_arr1 = np.random.randint(0, 2**nbits, (5,), dtype=np.uint8) + write_method = getattr(writer, f"write_uint{nbits}_data") + offset1 = write_method(input_arr1) + input_arr2 = np.random.randint(0, 2**nbits, (5,), dtype=np.uint8) + offset2 = write_method(input_arr2) + writer = None + + reader = BlobReader(self.working_dir + "/net.wt") + read_method = getattr(reader, f"read_uint{nbits}_data") + output_arr1 = read_method(offset1) + output_arr2 = read_method(offset2) + np.testing.assert_equal(input_arr1, output_arr1) + np.testing.assert_equal(input_arr2, output_arr2) + + @pytest.mark.parametrize("nbits", (1, 2, 3, 4, 6)) + def test_weight_blob_unsigned_sub_byte_invalid(self, nbits): + writer = BlobWriter(self.working_dir + "/net.wt") + input_arr = np.array([1, 80, 2, 0, 2]) + with pytest.raises( + ValueError, + match=f"Value 80 is outside allowed subbyte datatype range \[0, {2 ** nbits - 1}\].", + ): + getattr(writer, f"write_uint{nbits}_data")(input_arr) def test_weight_blob_int8(self): writer = BlobWriter(self.working_dir + "/net.wt") @@ -52,6 +111,16 @@ def test_weight_blob_int16(self): output_arr = reader.read_int16_data(offset) np.testing.assert_equal(input_arr, output_arr) + def test_weight_blob_int32(self): + writer = BlobWriter(self.working_dir + "/net.wt") + input_arr = np.array([-5, -2, 0, 2, 5], dtype=np.int32) + offset = writer.write_int32_data(input_arr) + writer = None + + reader = BlobReader(self.working_dir + "/net.wt") + output_arr = reader.read_int32_data(offset) + np.testing.assert_equal(input_arr, output_arr) + def test_weight_blob_uint16(self): writer = BlobWriter(self.working_dir + "/net.wt") input_arr = np.array([1, 2, 3, 4, 5], dtype=np.uint16) @@ -62,6 +131,16 @@ def test_weight_blob_uint16(self): output_arr = reader.read_uint16_data(offset) np.testing.assert_almost_equal(input_arr, output_arr) + def test_weight_blob_uint32(self): + writer = BlobWriter(self.working_dir + "/net.wt") + input_arr = np.array([1, 2, 3, 4, 5], dtype=np.uint32) + offset = writer.write_uint32_data(input_arr) + writer = None + + reader = BlobReader(self.working_dir + "/net.wt") + output_arr = reader.read_uint32_data(offset) + np.testing.assert_almost_equal(input_arr, output_arr) + def test_weight_blob_fp16(self): writer = BlobWriter(self.working_dir + "/net.wt") input_arr = np.array([2.3, 4.6, 7.9], dtype=np.float16) @@ -84,5 +163,150 @@ def test_weight_blob_fp32(self): output_arr = reader.read_float_data(offset) np.testing.assert_almost_equal(input_arr, output_arr) + +@pytest.mark.skipif(ct.utils._macos_version() < (15, 0), + reason="Multi-function only supported on macOS 15+") +class TestWeightIDSharing: + @staticmethod + def test_single_function(): + @mb.program( + input_specs=[mb.TensorSpec((500,))], + opset_version=ct.target.iOS16, + ) + def prog(x): + val = np.random.rand( + 500, + ) + const_1 = mb.const(val=val, name="const_1") + const_2 = mb.const(val=val, name="const_2") + const_3 = mb.const(val=val, name="const_3") + + # const 1 and 2 share the same weight id, so they should be serialized + # as the same blob value + const_1.op.weight_id = 0 + const_2.op.weight_id = 0 + + x = mb.add(x=x, y=const_1) + x = mb.add(x=x, y=const_2) + x = mb.add(x=x, y=const_3) + + return x + + # skip all passes to avoid running the const_deduplicate pass + prog.skip_all_passes = True + mlmodel = ct.convert( + prog, + convert_to="mlprogram", + compute_precision=ct.precision.FLOAT32, + minimum_deployment_target=ct.target.iOS16, + ) + + # In the above model, const_1 and const_2 are going to share the same blob file value. + package_path = tempfile.mkdtemp(suffix=".mlpackage") + mlmodel.save(package_path) + + with tempfile.TemporaryDirectory() as serialize_dir: + os.system(f"coremlcompiler compile {package_path} {serialize_dir}") + model_name_with_extension = os.path.basename(package_path) + model_name_wo_extension, _ = os.path.splitext(model_name_with_extension) + mil_file = open( + os.path.join(serialize_dir, f"{model_name_wo_extension}.mlmodelc", "model.mil") + ) + mil_txt = mil_file.read() + + assert ( + 'tensor const_1 = const()[name = string("const_1"), val = tensor(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(64)))];' + in mil_txt + ) + assert ( + 'tensor const_2 = const()[name = string("const_2"), val = tensor(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(64)))];' + in mil_txt + ) + assert ( + 'tensor const_3 = const()[name = string("const_3"), val = tensor(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(2176)))];' + in mil_txt + ) + assert "add(x = x, y = const_1)" in mil_txt + assert "add(x = add_0, y = const_2)" in mil_txt + + shutil.rmtree(package_path) + + @staticmethod + def test_multi_functions(): + + val = np.random.rand( + 500, + ) + + @mb.function( + input_specs=[mb.TensorSpec((500,))], + opset_version=ct.target.iOS16, + ) + def func(x): + const_1 = mb.const(val=val, name="const_1") + const_1.op.weight_id = 0 + return mb.add(x=x, y=const_1) + + @mb.function( + input_specs=[mb.TensorSpec((500,))], + opset_version=ct.target.iOS16, + ) + def func_1(x): + const_2 = mb.const(val=val, name="const_2") + const_3 = mb.const(val=val, name="const_3") + # const_3 shared the same blob file value with const_1 in another function + const_3.op.weight_id = 0 + + x = mb.add(x=x, y=const_2) + return mb.add(x=x, y=const_3) + + prog = mil.Program() + prog.add_function("main", func) + prog.add_function("func_1", func_1) + + # skip all passes to avoid running the const_deduplicate pass + prog.skip_all_passes = True + mlmodel = _mil_convert( + prog, + convert_to="mlprogram", + convert_from="milinternal", + specification_version=_SPECIFICATION_VERSION_IOS_18, + compute_units=ct.ComputeUnit.CPU_ONLY, + export_multi_functions=True, + skip_model_load=True, + ) + + # In the above model, const_1 and const_3 are going to share the same blob file value. + package_path = tempfile.mkdtemp(suffix=".mlpackage") + mlmodel.save(package_path) + + with tempfile.TemporaryDirectory() as serialize_dir: + os.system(f"coremlcompiler compile {package_path} {serialize_dir}") + model_name_with_extension = os.path.basename(package_path) + model_name_wo_extension, _ = os.path.splitext(model_name_with_extension) + mil_file = open( + os.path.join(serialize_dir, f"{model_name_wo_extension}.mlmodelc", "model.mil") + ) + mil_txt = mil_file.read() + + assert ( + 'tensor const_3 = const()[name = string("const_3"), val = tensor(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(64)))];' + in mil_txt + ) + assert ( + 'tensor const_2 = const()[name = string("const_2"), val = tensor(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(2176)))];' + in mil_txt + ) + assert ( + 'tensor const_1 = const()[name = string("const_1"), val = tensor(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(64)))];' + in mil_txt + ) + assert "add(x = x, y = const_2)" in mil_txt + assert "add(x = add_1, y = const_3)" in mil_txt + assert "add(x = x, y = const_1)" in mil_txt + + shutil.rmtree(package_path) + + if __name__ == "__main__": unittest.main() diff --git a/coremltools/test/ml_program/test_utils.py b/coremltools/test/ml_program/test_utils.py new file mode 100644 index 000000000..b982e70a1 --- /dev/null +++ b/coremltools/test/ml_program/test_utils.py @@ -0,0 +1,894 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import os +import platform +import shutil +import tempfile + +import numpy as np +import pytest +import torch + +import coremltools as ct +from coremltools import _SPECIFICATION_VERSION_IOS_18, proto +from coremltools.converters.mil import mil +from coremltools.converters.mil.converter import mil_convert as _mil_convert +from coremltools.converters.mil.mil.builder import Builder as mb +from coremltools.models.utils import MultiFunctionDescriptor, load_spec, save_multifunction + + +@pytest.mark.skipif(ct.utils._macos_version() < (15, 0), + reason="Multi-function only supported on macOS 15+") +class TestMultiFunctionDescriptor: + @staticmethod + def _convert_multifunction_prog(prog): + mlmodel = _mil_convert( + prog, + convert_to="mlprogram", + convert_from="milinternal", + specification_version=_SPECIFICATION_VERSION_IOS_18, + compute_units=ct.ComputeUnit.CPU_ONLY, + export_multi_functions=True, + skip_model_load=True, + ) + package_path = tempfile.mkdtemp(suffix=".mlpackage") + mlmodel.save(package_path) + return package_path + + @staticmethod + def _get_singlefunction_mlpackage(opset_version=ct.target.iOS16): + @mb.program( + input_specs=[mb.TensorSpec((3,))], + opset_version=opset_version, + ) + def prog(x): + return mb.relu(x=x) + + mlmodel = ct.convert( + prog, + minimum_deployment_target=opset_version, + ) + package_path = tempfile.mkdtemp(suffix=".mlpackage") + mlmodel.save(package_path) + + return package_path + + def _get_multifunction_mlpackage_1(self): + @mb.function( + input_specs=[mb.TensorSpec((3,))], + opset_version=ct.target.iOS18, + ) + def func(x): + return mb.relu(x=x) + + @mb.function( + input_specs=[mb.TensorSpec((3,))], + opset_version=ct.target.iOS18, + ) + def func_1(x): + return mb.sin(x=x) + + @mb.function( + input_specs=[mb.TensorSpec((3,))], + opset_version=ct.target.iOS18, + ) + def func_2(x): + return mb.cos(x=x) + + prog = mil.Program() + prog.add_function("relu", func) + prog.add_function("sin", func_1) + prog.add_function("cos", func_2) + prog.default_function_name = "relu" + + return self._convert_multifunction_prog(prog) + + def _get_multifunction_mlpackage_2(self): + @mb.function( + input_specs=[mb.TensorSpec((3,))], + opset_version=ct.target.iOS18, + ) + def func(x): + return mb.relu(x=x) + + @mb.function( + input_specs=[mb.TensorSpec((3,))], + opset_version=ct.target.iOS18, + ) + def func_1(x): + return mb.sin(x=x) + + prog = mil.Program() + prog.add_function("relu", func) + prog.add_function("sin", func_1) + prog.default_function_name = "sin" + + return self._convert_multifunction_prog(prog) + + def _get_multifunction_mlpackage_3(self): + @mb.function( + input_specs=[mb.TensorSpec((3,))], + opset_version=ct.target.iOS18, + ) + def func(x): + return mb.relu(x=x) + + prog = mil.Program() + prog.add_function("relu", func) + prog.default_function_name = "relu" + + return self._convert_multifunction_prog(prog) + + def test_initialization(self): + # Test empty initialization + desc = MultiFunctionDescriptor() + assert desc._functions() == {} + + # Initialize with a single function model + model = self._get_singlefunction_mlpackage() + desc = MultiFunctionDescriptor(model) + assert desc._functions() == {"main": (model, "main")} + shutil.rmtree(model) + + # Initialize with a multifunction model with only a single function + model = self._get_multifunction_mlpackage_3() + desc = MultiFunctionDescriptor(model) + assert desc._functions() == {"relu": (model, "relu")} + shutil.rmtree(model) + + # Initialize with a multifunction model with several functions + model = self._get_multifunction_mlpackage_1() + desc = MultiFunctionDescriptor(model) + assert desc._functions() == { + "relu": (model, "relu"), + "sin": (model, "sin"), + "cos": (model, "cos"), + } + shutil.rmtree(model) + + # Initialize with invalid path + with pytest.raises(ValueError, match="invalid model_path invalid_path with error"): + desc = MultiFunctionDescriptor("invalid_path") + + def test_add_function(self): + # Add function from a single function model + desc = MultiFunctionDescriptor() + model = self._get_singlefunction_mlpackage() + desc.add_function(model, "main", "main_1") + assert desc._functions() == {"main_1": (model, "main")} + desc.add_function(model, "main", "main_2") + assert desc._functions() == {"main_1": (model, "main"), "main_2": (model, "main")} + with pytest.raises(ValueError, match="src_function_name invalid not found in"): + desc.add_function(model, "invalid", "main_3") + with pytest.raises(ValueError, match="function main_1 already exist"): + desc.add_function(model, "main", "main_1") + shutil.rmtree(model) + + # Add function from multifunction model + desc = MultiFunctionDescriptor() + model = self._get_multifunction_mlpackage_1() + desc.add_function(model, "relu", "main_1") + assert desc._functions() == {"main_1": (model, "relu")} + desc.add_function(model, "sin", "main_2") + assert desc._functions() == {"main_1": (model, "relu"), "main_2": (model, "sin")} + shutil.rmtree(model) + + # Initialize a desc with a model and add functions to it + model = self._get_multifunction_mlpackage_1() + desc = MultiFunctionDescriptor(model) + assert desc._functions() == { + "relu": (model, "relu"), + "sin": (model, "sin"), + "cos": (model, "cos"), + } + model_2 = self._get_multifunction_mlpackage_2() + desc.add_function(model_2, "sin", "new_sin") + assert desc._functions() == { + "relu": (model, "relu"), + "sin": (model, "sin"), + "cos": (model, "cos"), + "new_sin": (model_2, "sin"), + } + with pytest.raises(ValueError, match="function relu already exist"): + desc.add_function(model, "relu", "relu") + shutil.rmtree(model) + shutil.rmtree(model_2) + + def test_add_model(self): + # Add model from a single function model + desc = MultiFunctionDescriptor() + model = self._get_singlefunction_mlpackage() + desc.add_model(model) + assert desc._functions() == {"main": (model, "main")} + shutil.rmtree(model) + + # Add a multifunction model with only a single function + desc = MultiFunctionDescriptor() + model = self._get_multifunction_mlpackage_3() + desc.add_model(model) + assert desc._functions() == {"relu": (model, "relu")} + shutil.rmtree(model) + + # Add a multifunction model with several functions + desc = MultiFunctionDescriptor() + model = self._get_multifunction_mlpackage_1() + desc.add_model(model) + assert desc._functions() == { + "relu": (model, "relu"), + "sin": (model, "sin"), + "cos": (model, "cos"), + } + shutil.rmtree(model) + + # Add a model to a desc with functions + model = self._get_singlefunction_mlpackage() + desc = MultiFunctionDescriptor(model) + assert desc._functions() == {"main": (model, "main")} + model_2 = self._get_multifunction_mlpackage_1() + desc.add_model(model_2) + assert desc._functions() == { + "relu": (model_2, "relu"), + "sin": (model_2, "sin"), + "cos": (model_2, "cos"), + "main": (model, "main"), + } + shutil.rmtree(model) + shutil.rmtree(model_2) + + # Error handling when adding model with duplicated function name + model = self._get_multifunction_mlpackage_2() + with pytest.raises(ValueError, match="function relu already exist"): + desc.add_model(model) + shutil.rmtree(model) + + def test_remove_function(self): + model = self._get_multifunction_mlpackage_1() + desc = MultiFunctionDescriptor(model) + assert desc._functions() == { + "relu": (model, "relu"), + "sin": (model, "sin"), + "cos": (model, "cos"), + } + desc.remove_function("relu") + assert desc._functions() == { + "sin": (model, "sin"), + "cos": (model, "cos"), + } + with pytest.raises(ValueError, match="function_name relu not found"): + desc.remove_function("relu") + + desc.remove_function("sin") + assert desc._functions() == { + "cos": (model, "cos"), + } + desc.remove_function("cos") + assert desc._functions() == {} + with pytest.raises(ValueError, match="function_name relu not found"): + desc.remove_function("relu") + shutil.rmtree(model) + + def test_convert_single_function_into_multifunction_model(self): + """ + Convert a single function model into a multifunction model format, + but only consists of one function. + """ + model = self._get_singlefunction_mlpackage() + desc = MultiFunctionDescriptor() + desc.add_function(model, "main", "main_1") + desc.default_function_name = "main_1" + package_path = tempfile.mkdtemp(suffix=".mlpackage") + save_multifunction(desc, package_path) + shutil.rmtree(model) + + # verify the model spec + spec = load_spec(package_path) + model_desc = spec.description + assert len(model_desc.functions) == 1 + assert model_desc.functions[0].name == "main_1" + assert model_desc.defaultFunctionName == "main_1" + + # verify the model can be load / run + new_model = ct.models.MLModel(package_path, function_name="main_1") + new_model.predict( + { + "x": np.random.rand( + 3, + ) + } + ) + shutil.rmtree(package_path) + + def test_merge_two_models_into_multifunction_model(self): + """ + Merge two single function models into one multifunction model. + """ + model_1 = self._get_singlefunction_mlpackage() + model_2 = self._get_singlefunction_mlpackage() + desc = MultiFunctionDescriptor() + desc.add_function(model_1, "main", "main_1") + desc.add_function(model_2, "main", "main_2") + desc.default_function_name = "main_2" + package_path = tempfile.mkdtemp(suffix=".mlpackage") + save_multifunction(desc, package_path) + shutil.rmtree(model_1) + shutil.rmtree(model_2) + + # verify the model spec + spec = load_spec(package_path) + model_desc = spec.description + assert len(model_desc.functions) == 2 + assert model_desc.functions[0].name == "main_1" + assert model_desc.functions[1].name == "main_2" + assert model_desc.defaultFunctionName == "main_2" + + # verify the model can be load / run + new_model = ct.models.MLModel(package_path, function_name="main_1") + new_model.predict( + { + "x": np.random.rand( + 3, + ) + } + ) + new_model = ct.models.MLModel(package_path, function_name="main_2") + new_model.predict( + { + "x": np.random.rand( + 3, + ) + } + ) + shutil.rmtree(package_path) + + def test_copy_a_single_model_twice_into_multifunction_model(self): + """ + Copy the function in a single function model twice to make a multifunction model. + """ + model = self._get_singlefunction_mlpackage() + desc = MultiFunctionDescriptor() + desc.add_function(model, "main", "main_1") + desc.add_function(model, "main", "main_2") + desc.default_function_name = "main_2" + package_path = tempfile.mkdtemp(suffix=".mlpackage") + save_multifunction(desc, package_path) + shutil.rmtree(model) + + # verify the model spec + spec = load_spec(package_path) + model_desc = spec.description + assert len(model_desc.functions) == 2 + assert model_desc.functions[0].name == "main_1" + assert model_desc.functions[1].name == "main_2" + assert model_desc.defaultFunctionName == "main_2" + + # verify the model can be load / run + new_model = ct.models.MLModel(package_path, function_name="main_1") + new_model.predict( + { + "x": np.random.rand( + 3, + ) + } + ) + new_model = ct.models.MLModel(package_path, function_name="main_2") + new_model.predict( + { + "x": np.random.rand( + 3, + ) + } + ) + shutil.rmtree(package_path) + + def test_combine_multifunctin_models(self): + """ + Combine two multifunction models into one multifunction model. + """ + model_1 = self._get_multifunction_mlpackage_1() + desc = MultiFunctionDescriptor(model_1) + model_2 = self._get_multifunction_mlpackage_2() + desc.add_function(model_2, "relu", "main_1") + desc.add_function(model_2, "sin", "main_2") + desc.default_function_name = "main_2" + package_path = tempfile.mkdtemp(suffix=".mlpackage") + save_multifunction(desc, package_path) + shutil.rmtree(model_1) + shutil.rmtree(model_2) + + # verify the model spec + spec = load_spec(package_path) + model_desc = spec.description + assert len(model_desc.functions) == 5 + assert model_desc.functions[0].name == "relu" + assert model_desc.functions[1].name == "sin" + assert model_desc.functions[2].name == "cos" + assert model_desc.functions[3].name == "main_1" + assert model_desc.functions[4].name == "main_2" + assert model_desc.defaultFunctionName == "main_2" + + # verify the model can be load / run + new_model = ct.models.MLModel(package_path, function_name="relu") + new_model.predict( + { + "x": np.random.rand( + 3, + ) + } + ) + new_model = ct.models.MLModel(package_path, function_name="sin") + new_model.predict( + { + "x": np.random.rand( + 3, + ) + } + ) + new_model = ct.models.MLModel(package_path, function_name="cos") + new_model.predict( + { + "x": np.random.rand( + 3, + ) + } + ) + new_model = ct.models.MLModel(package_path, function_name="main_1") + new_model.predict( + { + "x": np.random.rand( + 3, + ) + } + ) + new_model = ct.models.MLModel(package_path, function_name="main_2") + new_model.predict( + { + "x": np.random.rand( + 3, + ) + } + ) + shutil.rmtree(package_path) + + def test_invalid_default_function_name(self): + # invalid type + model = self._get_multifunction_mlpackage_1() + desc = MultiFunctionDescriptor(model) + with pytest.raises(ValueError, match="default_function_name must be type of str. Got 1."): + desc.default_function_name = 1 + + # default function name not found in the program + desc.default_function_name = "invalid" + package_path = tempfile.mkdtemp(suffix=".mlpackage") + with pytest.raises( + ValueError, match="default_function_name invalid not found in the program." + ): + save_multifunction(desc, package_path) + + # default function name not set + desc = MultiFunctionDescriptor(model) + with pytest.raises( + ValueError, + match="default_function_name must be set for the MultiFunctionDescriptor instance before calling save_multifunction.", + ): + save_multifunction(desc, package_path) + + # cleanup + + def test_spec_version_save_multifunction(self): + """ + When save models to the multifunction format, the spec version are promoted to iOS18. + """ + model_1 = self._get_singlefunction_mlpackage(opset_version=ct.target.iOS15) + model_2 = self._get_singlefunction_mlpackage(opset_version=ct.target.iOS16) + desc = MultiFunctionDescriptor(model_1) + desc.add_function(model_2, "main", "main_2") + desc.default_function_name = "main_2" + package_path = tempfile.mkdtemp(suffix=".mlpackage") + save_multifunction(desc, package_path) + shutil.rmtree(model_1) + shutil.rmtree(model_2) + + # verify the spec version of the multifunctino model is iOS18 + spec = load_spec(package_path) + assert spec.specificationVersion == _SPECIFICATION_VERSION_IOS_18 + shutil.rmtree(package_path) + + @staticmethod + def _multifunction_model_from_single_function(model_path: str) -> str: + desc = MultiFunctionDescriptor() + desc.add_function(model_path, "main", "main_1") + desc.add_function(model_path, "main", "main_2") + desc.default_function_name = "main_1" + multifunction_path = tempfile.mkdtemp(suffix=".mlpackage") + save_multifunction(desc, multifunction_path) + return multifunction_path + + @staticmethod + def _multifunction_model_from_multifunction_model(model_path: str) -> str: + desc = MultiFunctionDescriptor() + desc.add_function(model_path, "main_1", "main_3") + desc.add_function(model_path, "main_2", "main_4") + desc.default_function_name = "main_3" + multifunction_path = tempfile.mkdtemp(suffix=".mlpackage") + save_multifunction(desc, multifunction_path) + return multifunction_path + + def test_classifier_description(self): + """ + If the source model is a classifier, the resulting multifunction model should + inherit the classifier description as well. + """ + + def check_classifier_spec(model_path: str) -> None: + spec = load_spec(model_path) + model_desc = spec.description + + assert len(model_desc.functions) == 2 + + for idx in [0, 1]: + assert model_desc.functions[idx].predictedFeatureName == "class_label" + assert model_desc.functions[idx].predictedProbabilitiesName == "class_label_probs" + assert model_desc.functions[idx].output[0].name == "class_label" + assert model_desc.functions[idx].output[1].name == "class_label_probs" + + # source model with classifier config + torch_model = torch.nn.ReLU().eval() + traced_model = torch.jit.trace( + torch_model, + torch.rand( + 3, + ), + ) + variable_name = "var_2" + class_label_name = "class_label" + classifier_config = ct.ClassifierConfig( + class_labels=["a", "b", "c"], + predicted_feature_name=class_label_name, + predicted_probabilities_output=variable_name, + ) + + mlmodel = ct.convert( + traced_model, + inputs=[ct.TensorType(shape=(3,))], + classifier_config=classifier_config, + minimum_deployment_target=ct.target.iOS16, + ) + + package_path = tempfile.mkdtemp(suffix=".mlpackage") + mlmodel.save(package_path) + + # multifunction model should have the same classifier description + model_path = self._multifunction_model_from_single_function(package_path) + check_classifier_spec(model_path) + + # construct another multifunction model with an existing multifunction model, + # the classifier description should still be the same. + model_path_2 = self._multifunction_model_from_multifunction_model(model_path) + check_classifier_spec(model_path_2) + + # cleanup + shutil.rmtree(package_path) + shutil.rmtree(model_path) + shutil.rmtree(model_path_2) + + def test_input_output_description(self): + """ + When using save_multifunction to produce a model, we should respect + the original model description in the original model. + """ + + def check_i_o_spec(model_path: str) -> None: + spec = load_spec(model_path) + model_desc = spec.description + + assert len(model_desc.functions) == 2 + + for idx in [0, 1]: + assert ( + model_desc.functions[idx].input[0].type.imageType.colorSpace + == proto.FeatureTypes_pb2.ImageFeatureType.BGR + ) + assert ( + model_desc.functions[idx].output[0].type.imageType.colorSpace + == proto.FeatureTypes_pb2.ImageFeatureType.RGB + ) + + # source model with i/o with ImageType + class Model(torch.nn.Module): + def forward(self, x): + return x + 5.0 + + example_input = torch.randint(0, 100, (1, 3, 10, 20), dtype=torch.float32) + model = torch.jit.trace(Model().eval(), example_input) + mlmodel = ct.convert( + model, + inputs=[ct.ImageType(shape=(1, 3, 10, 20), color_layout=ct.colorlayout.BGR)], + outputs=[ct.ImageType(color_layout=ct.colorlayout.RGB)], + minimum_deployment_target=ct.target.iOS16, + ) + package_path = tempfile.mkdtemp(suffix=".mlpackage") + mlmodel.save(package_path) + + # multifunction model should have the same i/o description + model_path = self._multifunction_model_from_single_function(package_path) + check_i_o_spec(model_path) + + # construct another multifunction model with an existing multifunction model, + # the i/o description should still be the same + model_path_2 = self._multifunction_model_from_multifunction_model(model_path) + check_i_o_spec(model_path_2) + + # cleanup + shutil.rmtree(package_path) + shutil.rmtree(model_path) + shutil.rmtree(model_path_2) + + +@pytest.mark.skipif(ct.utils._macos_version() < (15, 0), + reason="Multi-function only supported on macOS 15+") +class TestMultiFunctionModelEnd2End: + @staticmethod + def _get_test_model(): + class TestModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(1, 8, 5, padding="same", bias=False) + self.bn1 = torch.nn.BatchNorm2d(8) + self.linear1 = torch.nn.Linear(28 * 28 * 8, 5, bias=False) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.linear1(torch.flatten(x)) + return x + + model = TestModel().eval() + example_input = torch.rand(1, 1, 28, 28) + return torch.jit.trace(model, example_input) + + @staticmethod + def _get_test_model_2(): + """ + Base model have the same weights, while the weights in submodule are different. + """ + + class SubModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(28 * 28 * 8, 5, bias=False) + + def forward(self, x): + return self.linear1(x) + + class TestModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(1, 8, 5, padding="same", bias=False) + self.bn1 = torch.nn.BatchNorm2d(8) + self.linear1 = None + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.linear1(torch.flatten(x)) + return x + + example_input = torch.rand(1, 1, 28, 28) + model = TestModel().eval() + + submodule_1 = SubModel().eval() + model.linear1 = submodule_1 + trace_1 = torch.jit.trace(model, example_input) + + submodule_2 = SubModel().eval() + model.linear1 = submodule_2 + trace_2 = torch.jit.trace(model, example_input) + + return trace_1, trace_2 + + def test_two_models(self): + """ + model_1: base + function_1 + model_2: base + function_2 + + After merging model_1 with model_2, the base weights should be shared. + """ + traced_model_1, traced_model_2 = self._get_test_model_2() + input = np.random.rand(1, 1, 28, 28) + + mlmodel_1 = ct.convert( + traced_model_1, + inputs=[ct.TensorType(name="x", shape=(1, 1, 28, 28))], + outputs=[ct.TensorType(name="out")], + convert_to="mlprogram", + minimum_deployment_target=ct.target.iOS17, + ) + mlmodel_2 = ct.convert( + traced_model_2, + inputs=[ct.TensorType(name="x", shape=(1, 1, 28, 28))], + outputs=[ct.TensorType(name="out")], + convert_to="mlprogram", + minimum_deployment_target=ct.target.iOS17, + ) + + gt_output_1 = mlmodel_1.predict({"x": input})["out"] + gt_output_2 = mlmodel_2.predict({"x": input})["out"] + + package_path_1 = tempfile.mkdtemp(suffix=".mlpackage") + mlmodel_1.save(package_path_1) + package_path_2 = tempfile.mkdtemp(suffix=".mlpackage") + mlmodel_2.save(package_path_2) + + # save multifuntion model + desc = MultiFunctionDescriptor() + desc.add_function(package_path_1, "main", "main_1") + desc.add_function(package_path_2, "main", "main_2") + desc.default_function_name = "main_1" + saved_package_path = tempfile.mkdtemp(suffix=".mlpackage") + save_multifunction(desc, saved_package_path) + shutil.rmtree(package_path_1) + shutil.rmtree(package_path_2) + + # verify the model spec + spec = load_spec(saved_package_path) + model_desc = spec.description + assert len(model_desc.functions) == 2 + assert model_desc.functions[0].name == "main_1" + assert model_desc.functions[1].name == "main_2" + assert model_desc.defaultFunctionName == "main_1" + + # verify the model can be load / run + # rdar://126898335 ([multifunction][bug] CoreML "maybe" is not handling the fallback for the compute units) + if platform.machine() == "arm64": + multifunction_mlmodel_1 = ct.models.MLModel(saved_package_path, function_name="main_1") + output = multifunction_mlmodel_1.predict({"x": input})["out"] + np.testing.assert_allclose(gt_output_1, output) + + multifunction_mlmodel_2 = ct.models.MLModel(saved_package_path, function_name="main_2") + output = multifunction_mlmodel_2.predict({"x": input})["out"] + np.testing.assert_allclose(gt_output_2, output) + + # make sure the weights are deduplicated + with tempfile.TemporaryDirectory() as serialize_dir: + os.system(f"coremlcompiler compile {saved_package_path} {serialize_dir}") + model_name_with_extension = os.path.basename(saved_package_path) + model_name_wo_extension, _ = os.path.splitext(model_name_with_extension) + mil_file = open( + os.path.join(serialize_dir, f"{model_name_wo_extension}.mlmodelc", "model.mil") + ) + mil_txt = mil_file.read() + assert ( + mil_txt.count( + 'const()[name = string("x_weight_0_to_fp16"), val = tensor(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(64)))];' + ) + == 2 + ) + assert ( + mil_txt.count( + 'tensor linear1_linear1_weight_to_fp16 = const()[name = string("linear1_linear1_weight_to_fp16"), val = tensor(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(576)))];' + ) + == 1 + ) + assert ( + mil_txt.count( + 'tensor linear1_linear1_weight_to_fp16 = const()[name = string("linear1_linear1_weight_to_fp16"), val = tensor(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(63360)))];' + ) + == 1 + ) + shutil.rmtree(saved_package_path) + + def test_single_model(self): + """ + Convert a single model into a multi-functions model with only one function. + """ + traced_model = self._get_test_model() + input = np.random.rand(1, 1, 28, 28) + mlmodel = ct.convert( + traced_model, + inputs=[ct.TensorType(name="x", shape=(1, 1, 28, 28))], + outputs=[ct.TensorType(name="out")], + convert_to="mlprogram", + minimum_deployment_target=ct.target.iOS16, + ) + gt_output = mlmodel.predict({"x": input})["out"] + package_path = tempfile.mkdtemp(suffix=".mlpackage") + mlmodel.save(package_path) + + # save multifuntion model + desc = MultiFunctionDescriptor() + desc.add_function(package_path, "main", "main_1") + desc.default_function_name = "main_1" + saved_package_path = tempfile.mkdtemp(suffix=".mlpackage") + save_multifunction(desc, saved_package_path) + shutil.rmtree(package_path) + + # verify the model spec + spec = load_spec(saved_package_path) + model_desc = spec.description + assert len(model_desc.functions) == 1 + assert model_desc.functions[0].name == "main_1" + assert model_desc.defaultFunctionName == "main_1" + + # verify the model can be load / run + # rdar://126898335 ([multifunction][bug] CoreML "maybe" is not handling the fallback for the compute units) + if platform.machine() == "arm64": + multifunction_mlmodel = ct.models.MLModel(saved_package_path, function_name="main_1") + output = multifunction_mlmodel.predict({"x": input})["out"] + np.testing.assert_allclose(gt_output, output) + shutil.rmtree(saved_package_path) + + def test_10_duplicated_model(self): + """ + Copy a single model 10 times and create a multi-functions model with 10 functions. + """ + traced_model = self._get_test_model() + input = np.random.rand(1, 1, 28, 28) + NUM_MODEL = 10 + saved_paths = [] + + for i in range(NUM_MODEL): + mlmodel = ct.convert( + traced_model, + inputs=[ct.TensorType(name="x", shape=(1, 1, 28, 28))], + outputs=[ct.TensorType(name="out")], + convert_to="mlprogram", + minimum_deployment_target=ct.target.iOS17, + ) + gt_output = mlmodel.predict({"x": input})["out"] + saved_paths.append(tempfile.mkdtemp(suffix=".mlpackage")) + mlmodel.save(saved_paths[-1]) + + # save the multifunction model + desc = MultiFunctionDescriptor() + for i in range(NUM_MODEL): + desc.add_function(saved_paths[i], "main", f"main_{i}") + desc.default_function_name = "main_5" + saved_package_path = tempfile.mkdtemp(suffix=".mlpackage") + save_multifunction(desc, saved_package_path) + + for val in saved_paths: + shutil.rmtree(val) + + # verify the model spec + spec = load_spec(saved_package_path) + model_desc = spec.description + assert len(model_desc.functions) == NUM_MODEL + for i in range(NUM_MODEL): + assert model_desc.functions[i].name == f"main_{i}" + assert model_desc.defaultFunctionName == "main_5" + + # verify the model can be load / run + # rdar://126898335 ([multifunction][bug] CoreML "maybe" is not handling the fallback for the compute units) + if platform.machine() == "arm64": + for i in range(NUM_MODEL): + multifunction_mlmodel = ct.models.MLModel( + saved_package_path, function_name=f"main_{i}" + ) + output = multifunction_mlmodel.predict({"x": input})["out"] + np.testing.assert_allclose(gt_output, output) + + # make sure the weights are deduplicated + with tempfile.TemporaryDirectory() as serialize_dir: + os.system(f"coremlcompiler compile {saved_package_path} {serialize_dir}") + model_name_with_extension = os.path.basename(saved_package_path) + model_name_wo_extension, _ = os.path.splitext(model_name_with_extension) + mil_file = open( + os.path.join(serialize_dir, f"{model_name_wo_extension}.mlmodelc", "model.mil") + ) + mil_txt = mil_file.read() + assert ( + mil_txt.count( + 'const()[name = string("x_weight_0_to_fp16"), val = tensor(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(64)))];' + ) + == 10 + ) + assert ( + mil_txt.count( + 'tensor linear1_weight_to_fp16 = const()[name = string("linear1_weight_to_fp16"), val = tensor(BLOBFILE(path = string("@model_path/weights/weight.bin"), offset = uint64(576)))];' + ) + == 10 + ) + shutil.rmtree(saved_package_path) diff --git a/coremltools/test/modelpackage/test_modelpackage.py b/coremltools/test/modelpackage/test_modelpackage.py index d23591ea9..0f581fc42 100644 --- a/coremltools/test/modelpackage/test_modelpackage.py +++ b/coremltools/test/modelpackage/test_modelpackage.py @@ -5,6 +5,7 @@ import json import os +import platform import shutil import tempfile @@ -12,11 +13,14 @@ import pytest import coremltools +import coremltools as ct from coremltools import ComputeUnit, utils from coremltools._deps import _HAS_EXECUTORCH, _HAS_TORCH from coremltools.converters.mil import Builder as mb +from coremltools.converters.mil.mil import types +from coremltools.converters.mil.mil.builder import Builder as mb from coremltools.libmodelpackage import ModelPackage -from coremltools.models import _METADATA_VERSION, MLModel +from coremltools.models import _METADATA_VERSION, CompiledMLModel, MLModel from coremltools.models.utils import _MLPACKAGE_AUTHOR_NAME, _WEIGHTS_DIR_NAME from coremltools.proto import Model_pb2 @@ -295,7 +299,7 @@ def _compare_loaded_debug_handle_mapping_with_original(package): assert loaded_debug_handle_mapping == debug_handle_mapping def _compare_prediction_with_torch(coreml_model, torch_model): - x = torch.rand(2, 10) + x = torch.rand(INPUT_SHAPE) coreml_x = {list(coreml_model.input_description)[0]: x.numpy()} coreml_preds = coreml_model.predict(coreml_x) @@ -303,18 +307,16 @@ def _compare_prediction_with_torch(coreml_model, torch_model): coreml_y = list(coreml_preds.values())[0] torch_y = torch_model(x).detach().numpy() - np.testing.assert_allclose(coreml_y, torch_y, rtol=1e-6, atol=1e-6) + np.testing.assert_allclose(coreml_y, torch_y, rtol=1e-3, atol=1e-3) torch_model = TestModule() torch_model.eval() - example_input = (torch.rand(*INPUT_SHAPE),) + example_input = (torch.rand(*INPUT_SHAPE, dtype=torch.float16).to(torch.float32),) exir_program_aten = torch.export.export(torch_model, example_input) exir_program_edge = executorch.exir.to_edge(exir_program_aten).exported_program() - coreml_model = coremltools.convert( - exir_program_edge, compute_precision=coremltools.precision.FLOAT32 - ) + coreml_model = coremltools.convert(exir_program_edge) debug_handle_mapping = { "version" : coreml_model.user_defined_metadata[_METADATA_VERSION], "mapping" : { @@ -323,7 +325,6 @@ def _compare_prediction_with_torch(coreml_model, torch_model): }, } - with tempfile.TemporaryDirectory(suffix=".mlpackage") as package0: coreml_model.save(package0) loaded_model0 = MLModel(package0) @@ -457,6 +458,63 @@ def forward(self, x): shutil.rmtree(package_path) + +class TestCompiledMLModel: + @pytest.mark.skipif(ct.utils._macos_version() < (15, 0), reason="State only supported on macOS 15+") + def test_state(self): + """ + Test prediction from a stateful model + """ + + @mb.program( + input_specs=[ + mb.StateTensorSpec((1,), dtype=types.fp16), + ], + opset_version=ct.target.iOS18, + ) + def increment(x): + # Read + y = mb.read_state(input=x) + # Update + y = mb.add(x=y, y=np.array([1.0]).astype("float16")) + # Write + y = mb.coreml_update_state(state=x, value=y) + # Return + return y + + mlmodel = ct.convert( + increment, + convert_to="mlprogram", + minimum_deployment_target=ct.target.iOS18, + ) + + def extract_value(y): + return list(y.values())[0][0] + + compiled_model = CompiledMLModel(mlmodel.get_compiled_model_path()) + + # Using first state + state1 = compiled_model.make_state() + for i in range(1, 5): + y = compiled_model.predict({}, state=state1) + assert extract_value(y) == i + + # rdar://126957030 ([State][Bug][Intel] Stateful model prediction is wrong on Intel laptop) + if platform.machine() != "arm64": + return + + # Use a new state + state2 = compiled_model.make_state() + for i in range(1, 5): + y = compiled_model.predict({}, state=state2) + assert extract_value(y) == i + + # Go back to using the first state + for i in range(5, 10): + y = compiled_model.predict({}, state=state1) + assert extract_value(y) == i + + class TestSpecAndMLModelAPIs: def setup_class(self): diff --git a/coremltools/test/optimize/coreml/test_passes.py b/coremltools/test/optimize/coreml/test_passes.py index 3ca77687c..0ffc7d054 100644 --- a/coremltools/test/optimize/coreml/test_passes.py +++ b/coremltools/test/optimize/coreml/test_passes.py @@ -5,6 +5,7 @@ import itertools import os +import re import tempfile import cattrs @@ -18,7 +19,7 @@ from coremltools.converters.mil.mil import Builder as mb from coremltools.converters.mil.mil import types from coremltools.converters.mil.mil.passes.tests.test_passes import CONSTEXPR_FUNCS, CONSTEXPR_OPS -from coremltools.converters.mil.testing_utils import get_op_types_in_program +from coremltools.converters.mil.testing_utils import compute_snr_and_psnr, get_op_types_in_program class TestCompressionNumerical: @@ -52,6 +53,166 @@ def test_linear_quantizer_compression(self, axis, mode, source_dtype, target_dty decompressed_val = quantization.linear_quantize_weights.decompress(params) np.testing.assert_allclose(val, decompressed_val, rtol=1e-02, atol=1e-02) + @pytest.mark.parametrize( + "nbits, signed, block_size, mode, source_dtype, data_range", + itertools.product( + [4, 8], + [True, False], + [0, 1, 2, 8, 32], + ["LINEAR", "LINEAR_SYMMETRIC"], + [np.float16, np.float32], + [ + [-1.0, 1.0], + [-3.0, -1.0], + [1.0, 3.0], + [1.0, 1.0], # Test corner case of same values. + ], + ), + ) + def test_linear_quantizer_compression_blockwise( + self, + nbits, + signed, + block_size, + mode, + source_dtype, + data_range, + ): + """ + This test mainly follows the weights pattern in real life's ML models. However, when compressing + weights to a small number of bits (such as 4-bit), the information loss is critical, which + makes the numerical test hard. That's why we adjust the atol and rtol based on nbits and + block_size values. + For more comprehensive numerical tests, see `test_linear_quantizer_compression_blockwise_integer`. + """ + original_data = np.random.uniform(data_range[0], data_range[1], (32, 64)).astype( + source_dtype + ) + + compressed_params = quantization.linear_quantize_weights.blockwise_compress( + original_data, nbits, mode, signed, block_sizes=[1, block_size] + ) + decompressed_val = quantization.linear_quantize_weights.decompress(compressed_params) + + if nbits > 4 and block_size < 3: + # When block size is small and nbits is large, the information loss is limited. + atol, rtol = 1e-02, 1e-02 + elif nbits <= 2 and block_size >= 2: + atol, rtol = 0.5, 0.5 + else: + atol, rtol = 0.2, 0.2 + np.testing.assert_allclose(original_data, decompressed_val, rtol=rtol, atol=atol) + + @pytest.mark.parametrize( + "nbits, signed, block_size, mode", + itertools.product( + [4, 8], + [True, False], + [1, 2, 8, 32], + ["LINEAR", "LINEAR_SYMMETRIC"], + ), + ) + def test_linear_quantizer_compression_blockwise_integer(self, nbits, signed, block_size, mode): + """ + We use int input because after rounding the dequantized data the numerical loss is less + critical when comparing it to the original data. + """ + input_shape = (32, 64) + nbits_range_max = 2 ** (nbits - 1) - 1 + nbits_range_min = -nbits_range_max + original_data = np.random.randint(nbits_range_min, nbits_range_max, input_shape).astype( + np.float32 + ) + compressed_params = quantization.linear_quantize_weights.blockwise_compress( + original_data, nbits, mode, signed, block_sizes=[1, block_size] + ) + decompressed_val = quantization.linear_quantize_weights.decompress(compressed_params) + decompressed_val = np.round(decompressed_val).astype(original_data.dtype) + + assert np.sum(original_data != decompressed_val) / original_data.size < 0.03 + assert np.all(np.abs(original_data - decompressed_val) <= 1) + + def test_linear_quantizer_compression_blockwise_corner_case(self): + """ + When the input data is [-2, -10, 6, -3], the + np.round(quantized_data / scale) + np.round(zero_point) + AND + np.round(quantized_data / scale + zero_point) + is different ([-1, -8, 7, -2] vs [0, -8, 7, -1]), while we follow PyTorch to use the former. + """ + original_data = np.array([-2, -10, 6, -3]).astype(np.float32) + params = quantization.linear_quantize_weights.blockwise_compress( + original_data, + nbits=4, + block_sizes=[4], + mode="LINEAR", + signed=True, + ) + expected_quantized_data = np.array([-1, -8, 7, -2], dtype=np.int8) + np.testing.assert_equal(params.data, expected_quantized_data) + + def test_linear_quantizer_compression_blockwise_invalid_original_data(self): + original_data_not_np_array = [1.0, 2.0] + with pytest.raises(ValueError, match="Only numpy arrays are supported"): + quantization.linear_quantize_weights.blockwise_compress( + original_data_not_np_array, + nbits=8, + block_sizes=[2], + mode="LINEAR", + signed=True, + ) + + original_data_integer = np.random.randint(0, 10, size=(3, 2)) + with pytest.raises(ValueError, match="Only floating numpy arrays are supported."): + quantization.linear_quantize_weights.blockwise_compress( + original_data_integer, + nbits=8, + block_sizes=[0, 2], + mode="LINEAR", + signed=True, + ) + + def test_linear_quantizer_compression_blockwise_invalid_block_size(self, caplog): + original_data = np.random.uniform(-1.0, 1.0, (4, 6)) + + params = quantization.linear_quantize_weights.blockwise_compress( + original_data, + nbits=8, + block_sizes=[1, 2], + mode="LINEAR", + signed=True, + ) + assert params.scale.shape == (4, 3) + + params = quantization.linear_quantize_weights.blockwise_compress( + original_data, + nbits=8, + block_sizes=[1, 6], + mode="LINEAR", + signed=True, + ) + assert params.scale.shape == (4, 1) + + params = quantization.linear_quantize_weights.blockwise_compress( + original_data, + nbits=8, + block_sizes=[2, 6], + mode="LINEAR", + signed=True, + ) + assert params.scale.shape == (2, 1) + + result = quantization.linear_quantize_weights.blockwise_compress( + original_data, + nbits=8, + block_sizes=[1, 8], + mode="LINEAR", + signed=True, + ) + assert result is None + expected_warning_msg = "Invalid block_sizes" + assert any([expected_warning_msg in rec.message for rec in caplog.records]) + @pytest.mark.parametrize( "mode, nbits, shape", itertools.product( @@ -82,6 +243,197 @@ def test_palettizer_compression(self, mode, nbits, shape): if (mode in ["UNIQUE", "KMEANS"]) or (mode == "UNIFORM" and max_val <= val_size): np.testing.assert_allclose(val, decompressed_val, rtol=1e-02, atol=1e-02) + def test_palettizer_compression_channelwise_basic(self): + original_data = np.arange(16, dtype=np.float32).reshape((4, 4)) + + # Group on axis=0. + result = quantization.palettize_weights.blockwise_compress( + original_data, "UNIQUE", nbits=3, block_sizes=[2, 0] + ) + expected_lut = np.array( + [[0, 1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14, 15]], dtype=np.float32 + ).reshape((2, 1, 8, 1)) + np.testing.assert_array_equal(result.lut, expected_lut) + expected_indices = np.array( + [[0, 1, 2, 3], [4, 5, 6, 7], [0, 1, 2, 3], [4, 5, 6, 7]] + ).astype(np.int8) + np.testing.assert_array_equal(result.indices, expected_indices) + + # Group on axis=1. + result = quantization.palettize_weights.blockwise_compress( + original_data, "UNIQUE", nbits=3, block_sizes=[0, 2] + ) + expected_lut = np.array( + [[0, 1, 4, 5, 8, 9, 12, 13], [2, 3, 6, 7, 10, 11, 14, 15]], dtype=np.float32 + ).reshape((1, 2, 8, 1)) + np.testing.assert_array_equal(result.lut, expected_lut) + expected_indices = np.array( + [[0, 1, 0, 1], [2, 3, 2, 3], [4, 5, 4, 5], [6, 7, 6, 7]] + ).astype(np.int8) + np.testing.assert_array_equal(result.indices, expected_indices) + + @pytest.mark.parametrize( + "nbits, channel_axis, mode, source_dtype, data_range, channel_group_size", + itertools.product( + [1, 2, 3, 4, 6, 8], + [0, 1, 2, -1], + ["KMEANS", "UNIFORM"], + [np.float16, np.float32], + [ + [-1.0, 1.0], + [-3.0, -1.0], + [1.0, 3.0], + [1.0, 1.0], + ], + [0, 1, 2], + ), + ) + def test_palettizer_compression_channelwise_stress( + self, nbits, channel_axis, mode, source_dtype, data_range, channel_group_size + ): + if nbits < 8: + # As sub-byte numerical accuracy loss is significant, we construct palettize-friendly data. + upper_bound = 2**nbits + original_data = np.stack( + [np.arange(upper_bound).reshape((1, upper_bound)) for _ in range(4)], + axis=channel_axis, + ) + else: + original_data = np.random.uniform(data_range[0], data_range[1], (2, 4, 16)).astype( + source_dtype + ) + block_sizes = [0] * len(original_data.shape) + block_sizes[channel_axis] = channel_group_size + params = quantization.palettize_weights.blockwise_compress( + original_data, + mode, + nbits, + block_sizes, + ) + decompressed_val = quantization.palettize_weights.decompress(params) + if nbits < 8 or mode == "KMEANS": + np.testing.assert_allclose(original_data, decompressed_val, rtol=3e-4, atol=3e-4) + else: + np.testing.assert_array_almost_equal(original_data, decompressed_val, decimal=2) + + @pytest.mark.parametrize( + "nbits, channel_axis, channel_group_size", + itertools.product( + [2, 3, 4, 6], + [0, 1, -1], + [0, 1, 2], + ), + ) + def test_grouped_channelwise_equivalent_to_blockwise( + self, nbits, channel_axis, channel_group_size + ): + """The grouped channelwise palettization could be expressed as general blockwise.""" + original_data = np.random.randint(low=-256, high=256, size=(16, 16, 2, 2)).astype( + np.float32 + ) + + params_grouped_channelwise = quantization.palettize_weights.grouped_channelwise_compress( + original_data, "UNIFORM", nbits, channel_axis, channel_group_size + ) + decompressed_grouped_channelwise = quantization.palettize_weights.decompress( + params_grouped_channelwise + ) + + block_sizes = [0] * len(original_data.shape) + block_sizes[channel_axis] = channel_group_size + params_blockwise = quantization.palettize_weights.blockwise_compress( + original_data, "UNIFORM", nbits, block_sizes=block_sizes + ) + decompressed_blockwise = quantization.palettize_weights.decompress(params_blockwise) + + np.testing.assert_allclose( + np.sort(params_grouped_channelwise.lut, axis=None), + np.sort(params_blockwise.lut, axis=None), + ) + np.testing.assert_allclose(decompressed_grouped_channelwise, decompressed_blockwise) + + @pytest.mark.parametrize( + "nbits, mode", + itertools.product( + [2, 3, 4, 6], + ["KMEANS", "UNIFORM"], + ), + ) + def test_tensorwise_equivalent_to_blockwise_zero(self, nbits, mode): + """The block_size=0 in palettization is equivalent to legacy tensorwise compression.""" + original_data = np.random.randint(low=-256, high=256, size=(16, 16, 2, 2)).astype( + np.float32 + ) + params_old = quantization.palettize_weights.compress(original_data, mode, nbits) + decompressed_old = quantization.palettize_weights.decompress(params_old) + params_new = quantization.palettize_weights.blockwise_compress( + original_data, mode, nbits, block_sizes=[0] * len(original_data.shape) + ) + decompressed_new = quantization.palettize_weights.decompress(params_new) + np.testing.assert_allclose( + np.sort(params_old.lut, axis=None), + np.sort(params_new.lut, axis=None), + atol=5e-5, + rtol=1e-6, + ) + np.testing.assert_allclose(decompressed_old, decompressed_new, atol=5e-5, rtol=1e-6) + + @pytest.mark.parametrize( + "nbits, channel_axis, channel_group_size", + itertools.product( + [2, 3, 4], + [0, 1], + [1, 2], + ), + ) + def test_grouped_channelwise_better_than_tensorwise( + self, nbits, channel_axis, channel_group_size + ): + """The noise introduced by per-tensor lut should be more than grouped-channel-wise lut.""" + original_data = np.random.randint(low=-512, high=512, size=(32, 32, 2, 2)).astype( + np.float32 + ) + block_sizes_channelwise = [0] * len(original_data.shape) + block_sizes_channelwise[channel_axis] = channel_group_size + params_grouped_channelwise = quantization.palettize_weights.blockwise_compress( + original_data, + "UNIFORM", + nbits, + block_sizes_channelwise, + ) + + block_sizes_per_tensor = [0] * len(original_data.shape) + params_per_tensor = quantization.palettize_weights.blockwise_compress( + original_data, + "UNIFORM", + nbits, + block_sizes_per_tensor, + ) + decompressed_grouped_channelwise = quantization.palettize_weights.decompress( + params_grouped_channelwise + ) + decompressed_per_tensor = quantization.palettize_weights.decompress(params_per_tensor) + snr_grouped_channelwise = compute_snr_and_psnr( + original_data, decompressed_grouped_channelwise + )[0] + snr_per_tensor = compute_snr_and_psnr(original_data, decompressed_per_tensor)[0] + assert snr_grouped_channelwise > snr_per_tensor + + def test_palettizer_compression_blockwise_invalid(self): + with pytest.raises(ValueError, match="Only numpy arrays are supported"): + quantization.palettize_weights.blockwise_compress(10, "KMEANS", 6, [0]) + with pytest.raises(ValueError, match="Invalid nbits."): + quantization.palettize_weights.blockwise_compress( + np.random.uniform(-1.0, 1.0, (2, 3, 4)), "KMEANS", nbits=5, block_sizes=[0, 0, 1] + ) + + assert ( + quantization.palettize_weights.blockwise_compress( + np.random.uniform(-1.0, 1.0, (2, 3, 4)), "KMEANS", nbits=3, block_sizes=[3, 0, 0] + ) + is None + ) + def test_block_sparsity_pruning_smoke(self): # dim = 0 val = np.array( @@ -434,6 +786,33 @@ def prog(x): return x return prog + @staticmethod + def _get_test_program_3(): + """An iOS18 program with conv, linear, matmul, and conv_transpose.""" + + @mb.program( + input_specs=[mb.TensorSpec(shape=(1, 30, 10, 10))], + opset_version=ct.target.iOS18, + ) + def prog(x): + # weight + conv_weight = np.random.rand(90, 30, 2, 2).astype(np.float32) + linear_weight = np.random.rand(70, 81).astype(np.float32) + matmul_weight = np.random.rand(2, 1, 70, 35).astype(np.float32) + conv_transpose_weight = np.random.rand(30, 4, 21, 10).astype(np.float32) + + # graph + x = mb.conv(x=x, weight=conv_weight, name="conv") + x = mb.reshape(x=x, shape=(1, 90, 81), name="reshape_1") + x = mb.linear(x=x, weight=linear_weight, name="linear") + x = mb.matmul(x=x, y=matmul_weight, transpose_y=False, name="matmul") + x = mb.reshape(x=x, shape=(1, 30, 21, 10), name="reshape_2") + x = mb.conv_transpose(x=x, weight=conv_transpose_weight, name="conv_transpose") + return x + + return prog + + class TestOptimizationConfig(TestCompressionPasses): """ Test some basic functionality of the OptimizationConfig. @@ -933,6 +1312,66 @@ def test_global_config_affine_quantizer(self, mode, dtype, weight_threshold, fak ] assert get_op_types_in_program(prog) == expected_ops + @pytest.mark.parametrize( + "mode, dtype, block_size, weight_threshold, fake_compression", + itertools.product( + ["LINEAR", "LINEAR_SYMMETRIC"], + ["int4", "uint4", "int8", "uint8", np.int8, np.uint8], + [1], + [1000, 7000], + [True, False], + ), + ) + def test_global_config_affine_quantizer_blockwise( + self, mode, dtype, block_size, weight_threshold, fake_compression + ): + """ + Global config would compress all operations with the same config for blockwise. + """ + op_config = cto.coreml.OpLinearQuantizerConfig( + mode=mode, + dtype=dtype, + granularity="per_block", + block_size=block_size, + weight_threshold=weight_threshold, + ) + config = cto.coreml.OptimizationConfig(global_config=op_config) + compressor = quantization.linear_quantize_weights( + config=config, fake_compression=fake_compression + ) + prog = self._get_test_program_3() + compressor.apply(prog) + + if fake_compression: + expected_ops = ["conv", "reshape", "linear", "matmul", "reshape", "conv_transpose"] + elif weight_threshold == 1000: + expected_ops = [ + "constexpr_blockwise_shift_scale", + "conv", + "reshape", + "constexpr_blockwise_shift_scale", + "linear", + "constexpr_blockwise_shift_scale", + "matmul", + "reshape", + "constexpr_blockwise_shift_scale", + "conv_transpose", + ] + else: + assert weight_threshold == 7000 + # linear and matmul weight size < 7000 + expected_ops = [ + "constexpr_blockwise_shift_scale", + "conv", + "reshape", + "linear", + "matmul", + "reshape", + "constexpr_blockwise_shift_scale", + "conv_transpose", + ] + assert get_op_types_in_program(prog) == expected_ops + def test_op_type_config_linear_quantizer(self): """ set_op_type allow the user to set different config for each op type. @@ -993,6 +1432,66 @@ def test_op_type_config_linear_quantizer(self): == np.uint8 ) + def test_op_type_config_linear_quantizer_blockwise(self): + """ + set_op_type allow the user to set different config for each op type for blockwise. + Also checking that the config can be overwritten. + """ + conv_config_1 = cto.coreml.OpLinearQuantizerConfig( + mode="LINEAR_SYMMETRIC", + dtype="int8", + granularity="per_block", + block_size=10, + weight_threshold=5000, + ) + # conv_config_2 overwrite conv_config_1 + conv_config_2 = cto.coreml.OpLinearQuantizerConfig( + mode="LINEAR_SYMMETRIC", + dtype="int4", + granularity="per_block", + block_size=3, + weight_threshold=2000, + ) + # The weight_threshold is super large so linear is not going to be compressed + linear_config = cto.coreml.OpLinearQuantizerConfig( + mode="LINEAR_SYMMETRIC", + dtype="int4", + granularity="per_block", + weight_threshold=1000000, + ) + conv_transpose_config = cto.coreml.OpLinearQuantizerConfig( + mode="LINEAR", + dtype="int8", + granularity="per_block", + block_size=10, + weight_threshold=2000, + ) + + config = cto.coreml.OptimizationConfig() + config.set_op_type("conv", conv_config_1) + config.set_op_type("conv", conv_config_2) + config.set_op_type("linear", linear_config) + config.set_op_type("conv_transpose", conv_transpose_config) + + compressor = quantization.linear_quantize_weights(config=config) + + prog = self._get_test_program_3() + compressor.apply(prog) + + expected_ops = [ + "constexpr_blockwise_shift_scale", + "conv", + "reshape", + "linear", + "matmul", + "reshape", + "constexpr_blockwise_shift_scale", + "conv_transpose", + ] + assert get_op_types_in_program(prog) == expected_ops + assert prog.find_ops(op_type="constexpr_blockwise_shift_scale")[0].offset is None + assert prog.find_ops(op_type="constexpr_blockwise_shift_scale")[1].offset is not None + def test_op_name_config_linear_quantizer(self): """ set_op_name allow the user to set different config for each op specified by name. @@ -1053,6 +1552,149 @@ def test_op_name_config_linear_quantizer(self): == np.uint8 ) + def test_op_name_config_linear_quantizer_blockwise(self): + """ + set_op_name allow the user to set different config for each op specified by name. + Also checking that the config can be overwritten + """ + conv_config_1 = cto.coreml.OpLinearQuantizerConfig( + mode="LINEAR_SYMMETRIC", + dtype="int8", + granularity="per_block", + block_size=4, + weight_threshold=2000, + ) + # conv_config_2 overwrite conv_config_1 + conv_config_2 = cto.coreml.OpLinearQuantizerConfig( + mode="LINEAR_SYMMETRIC", + dtype="int8", + granularity="per_block", + block_size=2, + weight_threshold=2000, + ) + # The weight_threshold is super large so linear is not going to be compressed + linear_config = cto.coreml.OpLinearQuantizerConfig( + mode="LINEAR_SYMMETRIC", + dtype="int4", + weight_threshold=1000000, + ) + conv_transpose_config = cto.coreml.OpLinearQuantizerConfig( + mode="LINEAR", + dtype="int8", + granularity="per_block", + block_size=6, + weight_threshold=2000, + ) + + config = cto.coreml.OptimizationConfig() + config.set_op_name("conv", conv_config_1) + config.set_op_name("conv", conv_config_2) + config.set_op_name("linear", linear_config) + config.set_op_name("conv_transpose", conv_transpose_config) + + compressor = quantization.linear_quantize_weights(config=config) + + prog = self._get_test_program_3() + compressor.apply(prog) + + expected_ops = [ + "constexpr_blockwise_shift_scale", + "conv", + "reshape", + "linear", + "matmul", + "reshape", + "constexpr_blockwise_shift_scale", + "conv_transpose", + ] + assert get_op_types_in_program(prog) == expected_ops + blockwise_ops = prog.find_ops(op_type="constexpr_blockwise_shift_scale") + assert blockwise_ops[0].offset is None + assert blockwise_ops[1].offset is not None + # Conv transpose original weight shape is (30, 4, 21, 10). The output channel axis is 1 and + # input channel axis is 0, so the scale's first axis dim is 30 / 6 = 5. + assert blockwise_ops[1].scale.shape == (5, 4, 1, 1) + + def test_auto_pick_channel_axis_quantizer(self): + """ + Check the right output channel axis is picked for block-wise quantization. + """ + global_config = cto.coreml.OpLinearQuantizerConfig( + mode="LINEAR", + dtype="int4", + granularity="per_block", + block_size=2, + weight_threshold=2000, + ) + linear_config = cto.coreml.OpLinearQuantizerConfig( + mode="LINEAR_SYMMETRIC", + dtype="int4", + granularity="per_block", + block_size=9, + weight_threshold=100, + ) + config = cto.coreml.OptimizationConfig() + config.set_global(global_config) + config.set_op_name("linear", linear_config) + compressor = quantization.linear_quantize_weights(config=config) + + prog = self._get_test_program_3() + compressor.apply(prog) + + blockwise_ops = prog.find_ops(op_type="constexpr_blockwise_shift_scale") + # For conv, input channel axis is 1, output channel axis is 0. + # The original weight shape is [90, 30, 2, 2], the scale's second dim is 30 / 2 = 15. + assert blockwise_ops[0].scale.shape == (90, 15, 1, 1) + # For linear, input channel axis is 1, output channel axis is 0. + # The original weight shape is [70, 81], the scale's second dim is 81 / 9 = 9. + assert blockwise_ops[1].scale.shape == (70, 9) + # For matmul (transpose_y=False), input channel axis is -2, output channel axis is -1. + # The original weight shape is [2, 1, 70, 35], the scale's third dim is 70 / 2 = 35. + assert blockwise_ops[2].scale.shape == (1, 1, 35, 35) + # For conv_transpose, input channel axis is 0, output channel axis is 1. + # The original weight shape is [30, 4, 21, 10], the scale's first dim is 30 / 2 = 15. + assert blockwise_ops[3].scale.shape == (15, 4, 1, 1) + + def test_invalid_config(self): + with pytest.raises( + ValueError, + match="Invalid dtype int2. Only support int8/uint8/int4/uint4", + ): + cto.coreml.OpLinearQuantizerConfig( + mode="LINEAR_SYMMETRIC", + dtype="int2", + block_size=2, + weight_threshold=2000, + ) + + with pytest.raises( + ValueError, + match="Only mode \('LINEAR_SYMMETRIC', 'LINEAR'\) supported for weight affine quantization. Got mode: \"DUMMY\".", + ): + cto.coreml.OpLinearQuantizerConfig( + mode="DUMMY", + dtype="int4", + block_size=32, + weight_threshold=5000, + ) + + def test_not_divisible_block_size(self, caplog): + global_config = cto.coreml.OpLinearQuantizerConfig( + mode="LINEAR_SYMMETRIC", + granularity="per_block", + dtype="int4", + block_size=13, + weight_threshold=100, + ) + config = cto.coreml.OptimizationConfig() + config.set_global(global_config) + compressor = quantization.linear_quantize_weights(config=config) + + prog = self._get_test_program_3() + compressor.apply(prog) + warning_msg = "Invalid block_sizes; On 1th axis, the dim size 30 is not divisible by block size 13. Unable to perform structured quantization." + assert any([re.match(warning_msg, rec.message) for rec in caplog.records]) + class TestPruner(TestCompressionPasses): @pytest.mark.parametrize( @@ -1572,6 +2214,228 @@ def test_op_name_config_palettizer(self): assert prog.find_ops(op_type="constexpr_lut_to_dense")[0].lut.val.shape == (4,) assert prog.find_ops(op_type="constexpr_lut_to_dense")[1].lut.val.shape == (16,) + def test_op_name_config_palettizer_blockwise(self): + """ + set_op_name allow the user to set different config for each op specified by name. + Also checking that the config can be overwritten. + """ + conv_config_1 = cto.coreml.OpPalettizerConfig( + mode="uniform", + nbits=4, + granularity="per_tensor", + weight_threshold=500000, + ) + # The conv_config_2 overwrites conv_config_1. + conv_config_2 = cto.coreml.OpPalettizerConfig( + mode="kmeans", + nbits=8, + granularity="per_grouped_channel", + group_size=1, + channel_axis=1, + weight_threshold=2000, + ) + # The weight_threshold is super large so linear is not going to be compressed. + linear_config = cto.coreml.OpPalettizerConfig( + mode="kmeans", + nbits=4, + weight_threshold=1000000, + ) + conv_transpose_config = cto.coreml.OpPalettizerConfig( + mode="uniform", + nbits=4, + granularity="per_grouped_channel", + group_size=1, + weight_threshold=2000, + ) + + config = cto.coreml.OptimizationConfig() + config.set_op_name("conv", conv_config_1) + config.set_op_name("conv", conv_config_2) + config.set_op_name("linear", linear_config) + config.set_op_name("conv_transpose", conv_transpose_config) + + prog = self._get_test_program_3() + compressor = quantization.palettize_weights(config=config) + compressor.apply(prog) + + expected_ops = [ + "constexpr_lut_to_dense", + "conv", + "reshape", + "linear", + "matmul", + "reshape", + "constexpr_lut_to_dense", + "conv_transpose", + ] + assert get_op_types_in_program(prog) == expected_ops + assert prog.find_ops(op_type="constexpr_lut_to_dense")[0].vector_axis is None + # Makes sure the channel_axis in conv_config_2 is effective. + conv_lut = prog.find_ops(op_type="constexpr_lut_to_dense")[0].lut + assert conv_lut.shape[0] == 1 + assert conv_lut.shape[1] == 30 + + def test_invalid_granularity(self): + with pytest.raises( + ValueError, + match='"granularity" must be one of .*, but got CompressionGranularity.PER_CHANNEL', + ): + cto.coreml.OpPalettizerConfig( + mode="kmeans", + nbits=4, + granularity="per_channel", + weight_threshold=2000, + ) + + with pytest.raises(TypeError, match="got an unexpected keyword argument 'block_size'"): + cto.coreml.OpPalettizerConfig( + mode="kmeans", + nbits=4, + granularity="per_tensor", + block_size=2, + weight_threshold=2000, + ) + + def test_auto_pick_channel_axis_palettizer(self): + """ + Check the right output channel axis is picked for granularity='per_grouped_channel'. + """ + global_config = cto.coreml.OpPalettizerConfig( + mode="kmeans", + nbits=4, + granularity="per_grouped_channel", + group_size=1, + weight_threshold=2000, + ) + config = cto.coreml.OptimizationConfig() + config.set_global(global_config) + compressor = quantization.palettize_weights(config=config) + + prog = self._get_test_program_3() + compressor.apply(prog) + + # For conv, the output channel-axis is 0. + conv_lut = prog.find_ops(op_type="constexpr_lut_to_dense")[0].lut + assert conv_lut.shape[0] == 90 + assert conv_lut.shape[1] == 1 + # For linear, the output channel-axis is 0. + linear_lut = prog.find_ops(op_type="constexpr_lut_to_dense")[1].lut + assert linear_lut.shape[0] == 70 + assert linear_lut.shape[1] == 1 + # For matmul with transpose_y=False, the output channel-axis is -1. + matmul_lut = prog.find_ops(op_type="constexpr_lut_to_dense")[2].lut + assert matmul_lut.shape == (1, 1, 1, 35, 16, 1) + # For conv_transpose, the output channel-axis is -2. + conv_transpose_lut = prog.find_ops(op_type="constexpr_lut_to_dense")[3].lut + assert conv_transpose_lut.shape[0] == 1 + assert conv_transpose_lut.shape[1] == 4 + + def test_group_channel_wise(self): + global_config = cto.coreml.OpPalettizerConfig( + mode="kmeans", + nbits=3, + granularity="per_grouped_channel", + group_size=2, + weight_threshold=2000, + ) + config = cto.coreml.OptimizationConfig() + config.set_global(global_config) + compressor = quantization.palettize_weights(config=config) + + prog = self._get_test_program_3() + compressor.apply(prog) + lut_ops = prog.find_ops(op_type="constexpr_lut_to_dense") + # The conv weight dense shape is (90, 30, 2, 2). Auto-picked axis=0. + assert lut_ops[0].lut.shape == (45, 1, 1, 1, 8, 1) + # The linear weight dense shape is (70, 81). Auto-picked axis=0. + assert lut_ops[1].lut.shape == (35, 1, 8, 1) + # The matmul y dense shape is (2, 1, 70, 35). Auto-picked axis=-1. + # However, the 35 is not divisible by 2, so it will get skipped. + assert prog.find_ops(op_type="matmul")[0].y.op.op_type == "const" + # The conv_transpose weight dense shape is (30, 4, 21, 10). Auto-picked axis=-2. + assert lut_ops[2].lut.shape == (1, 2, 1, 1, 8, 1) + + def test_tensor_wise(self): + """Test granularity='per_block' with block_size=0 equivalent to granularity='per_tensor'.""" + global_config_1 = cto.coreml.OpPalettizerConfig( + mode="kmeans", + nbits=3, + granularity="per_tensor", + weight_threshold=2000, + ) + global_config_2 = cto.coreml.OpPalettizerConfig( + mode="kmeans", + nbits=3, + granularity="per_grouped_channel", + group_size=0, + weight_threshold=2000, + ) + + for global_config in (global_config_1, global_config_2): + config = cto.coreml.OptimizationConfig(global_config=global_config) + compressor = quantization.palettize_weights(config=config) + + prog = self._get_test_program_3() + compressor.apply(prog) + lut_ops = prog.find_ops(op_type="constexpr_lut_to_dense") + # The conv weight dense shape is (90, 30, 2, 2). + assert lut_ops[0].lut.shape == (1, 1, 1, 1, 8, 1) + # The linear weight dense shape is (70, 81). + assert lut_ops[1].lut.shape == (1, 1, 8, 1) + # The matmul y dense shape is (2, 1, 70, 35). + assert lut_ops[2].lut.shape == (1, 1, 1, 1, 8, 1) + # The conv_transpose weight dense shape is (30, 4, 21, 10). + assert lut_ops[3].lut.shape == (1, 1, 1, 1, 8, 1) + + def test_not_divisible_channel_group_size(self, caplog): + global_config = cto.coreml.OpPalettizerConfig( + mode="kmeans", + nbits=4, + granularity="per_grouped_channel", + group_size=3, + weight_threshold=2000, + ) + config = cto.coreml.OptimizationConfig() + config.set_global(global_config) + compressor = quantization.palettize_weights(config=config) + + prog = self._get_test_program_3() + compressor.apply(prog) + + # The axis-0 in linear (70), axis-3 in matmul (35), and axis-1 in conv_transpose (4) are not divisible by 3. + for axis in (0, 3, 1): + warning_msg = ( + f"Can't perform palettization: The number of channels at {axis}th axis .* is not " + "divisible by channel_group_size" + ) + assert any([re.match(warning_msg, rec.message) for rec in caplog.records]) + # Only the conv get compressed. + lut_ops = prog.find_ops(op_type="constexpr_lut_to_dense") + assert len(lut_ops) == 1 + assert lut_ops[0].outputs[0].child_ops[0].op_type == "conv" + + def test_ios16_program_not_support_channel_wise_lut(self): + global_config = cto.coreml.OpPalettizerConfig( + mode="kmeans", + nbits=4, + granularity="per_grouped_channel", + group_size=3, + weight_threshold=2000, + ) + config = cto.coreml.OptimizationConfig() + config.set_global(global_config) + compressor = quantization.palettize_weights(config=config) + + prog = self._get_test_program() + with pytest.raises( + AssertionError, + match=re.escape( + "The iOS16 only supports per-tensor lut, but got more than one lut " + "on 0th axis. LUT shape: (30, 1, 1, 1, 16, 1)" + ), + ): + compressor.apply(prog) + class TestCompressionOperations(TestCompressionPasses): """ @@ -2082,34 +2946,37 @@ def test_config_load_invalid_key(config_cls): config_cls._from_dict(config_dict) @pytest.mark.parametrize( - "mode, dtype, weight_threshold, use_yaml", + "mode, dtype, granularity, block_size, weight_threshold, use_yaml", itertools.product( ["linear", "linear_symmetric"], - ["int8", "uint8", np.int8, np.uint8, types.int8, types.uint8], + ["int4", "uint4", "int8", "uint8", np.int8, np.uint8, types.int8, types.uint8], + ["per_tensor", "per_channel", "per_block"], + [0, 1, 2, [0, 1]], [1024, None], [True, False], ), ) - def test_linear_quantizer_config_load_stress(self, mode, dtype, weight_threshold, use_yaml): + def test_linear_quantizer_config_load_stress( + self, mode, dtype, granularity, block_size, weight_threshold, use_yaml + ): config_dict = { "mode": mode, "dtype": dtype, + "granularity": granularity, + "block_size": block_size, "weight_threshold": weight_threshold, } - if use_yaml and dtype in ("int8", "uint8"): + if use_yaml and isinstance(dtype, str): config_dict = self.load_to_yaml(config_dict) config = quantization.OpLinearQuantizerConfig._from_dict(config_dict) - if dtype in ["int8", np.int8, types.int8]: - expected_dtype = np.int8 - elif dtype in ["uint8", np.uint8, types.uint8]: - expected_dtype = np.uint8 - expected_config = quantization.OpLinearQuantizerConfig( mode=mode, - dtype=expected_dtype, + dtype=dtype, + granularity=granularity, + block_size=block_size, weight_threshold=weight_threshold, ) assert config == expected_config @@ -2211,24 +3078,37 @@ def test_magnitude_block_sparsity_pruner_config_load_stress( assert config == expected_config @pytest.mark.parametrize( - "mode_nbits, weight_threshold, use_yaml", + "mode, nbits, granularity, group_size, channel_axis, weight_threshold, num_kmeans_workers, use_yaml", itertools.product( - [ - ("kmeans", 2), - ("uniform", 1), - ("unique", None), - ], - [None, 1024], + ["kmeans", "uniform"], + [1, 2, 3, 4, 6, 8], + ["per_tensor", "per_grouped_channel"], + [0, 1, 32], + [None, 0, 1], + [1024, None], + [1, 4], [True, False], ), ) - def test_palettizer_config_load_stress(self, mode_nbits, weight_threshold, use_yaml): - mode, nbits = mode_nbits - + def test_palettizer_config_load_stress( + self, + mode, + nbits, + granularity, + group_size, + channel_axis, + weight_threshold, + num_kmeans_workers, + use_yaml, + ): config_dict = { "mode": mode, "nbits": nbits, + "granularity": granularity, + "group_size": group_size, + "channel_axis": channel_axis, "weight_threshold": weight_threshold, + "num_kmeans_workers": num_kmeans_workers, } if use_yaml: @@ -2239,7 +3119,11 @@ def test_palettizer_config_load_stress(self, mode_nbits, weight_threshold, use_y expected_config = quantization.OpPalettizerConfig( mode=mode, nbits=nbits, + granularity=granularity, + group_size=group_size, + channel_axis=channel_axis, weight_threshold=weight_threshold, + num_kmeans_workers=num_kmeans_workers, ) assert config == expected_config diff --git a/coremltools/test/optimize/coreml/test_post_training_quantization.py b/coremltools/test/optimize/coreml/test_post_training_quantization.py index dacae90e2..5540f0325 100644 --- a/coremltools/test/optimize/coreml/test_post_training_quantization.py +++ b/coremltools/test/optimize/coreml/test_post_training_quantization.py @@ -4,6 +4,8 @@ # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause import itertools +import shutil +import tempfile from typing import Tuple import numpy as np @@ -13,9 +15,15 @@ import coremltools as ct import coremltools.optimize as cto from coremltools._deps import _HAS_SKLEARN +from coremltools.converters.mil.frontend.torch.test.test_torch_conversion_api import ( + TestPyTorchConverterExamples, +) from coremltools.converters.mil.mil import Builder as mb from coremltools.converters.mil.mil import types -from coremltools.converters.mil.testing_utils import get_op_types_in_program +from coremltools.converters.mil.mil.ops.tests.iOS18 import backends +from coremltools.converters.mil.testing_reqs import compute_units +from coremltools.converters.mil.testing_utils import compute_snr_and_psnr, get_op_types_in_program +from coremltools.models.utils import MultiFunctionDescriptor, _macos_version, save_multifunction from coremltools.optimize.coreml import _utils as optimize_utils from coremltools.optimize.coreml._post_training_quantization import CoreMLWeightMetaData from coremltools.test.ml_program.test_compression import get_test_model_and_data @@ -172,6 +180,35 @@ def verify_model_outputs(model, compressed_model, input_values, rtol=1e-7, atol= class TestLinearQuantizeWeights: + @staticmethod + def test_linear_quantization_with_classifier(): + traced_model, example_input = TestPyTorchConverterExamples._get_classifier_model() + for class_type in ("str", "int"): + mlmodel = TestPyTorchConverterExamples._convert_classifier_model( + traced_model, example_input, class_type + ) + config = cto.coreml.OptimizationConfig() + global_config = cto.coreml.OpLinearQuantizerConfig( + mode="linear_symmetric", dtype=np.int8, weight_threshold=0 + ) + config.set_global(global_config) + mlmodel = cto.coreml.linear_quantize_weights(mlmodel, config) + expected_ops = [ + "cast", + "reshape", + "constexpr_affine_dequantize", + "linear", + "relu", + "constexpr_affine_dequantize", + "linear", + "relu", + "constexpr_affine_dequantize", + "linear", + "cast", + "classify", + ] + assert get_op_types_in_program(mlmodel._mil_program) == expected_ops + @staticmethod def test_linear_quantization(): model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data_complex() @@ -234,8 +271,230 @@ def test_linear_quanitzation_stress(mode, dtype): verify_model_outputs(mlmodel, mlmodel_quantized, coreml_input_values) + @pytest.mark.parametrize( + "compute_unit, backend", + itertools.product(compute_units, backends), + ) + def test_blockwise_quantization(self, compute_unit, backend): + model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data_complex() + torchmodel = torch.jit.trace(model, torch_input_values) + mlmodel = ct.convert( + torchmodel, + inputs=inputs, + convert_to="mlprogram", + minimum_deployment_target=backend.opset_version, + compute_precision=ct.precision.FLOAT16 + if backend.precision == "fp16" + else ct.precision.FLOAT32, + compute_units=compute_unit, + ) + + config = cto.coreml.OptimizationConfig() + conv_config = cto.coreml.OpLinearQuantizerConfig( + mode="linear_symmetric", + dtype="int4", + granularity="per_block", + block_size=2, + weight_threshold=500, + ) + lstm_config = cto.coreml.OpLinearQuantizerConfig( + mode="linear", + dtype="int4", + granularity="per_block", + block_size=2, + weight_threshold=4800, + ) + + config.set_op_type("conv", conv_config) + config.set_op_type("lstm", lstm_config) + # Set a specific conv's config to None to prevent it being compressed. + conv_not_to_compress_name = "conv_2_1" + if backend.precision == "fp16": + conv_not_to_compress_name += "_cast_fp16" + config.set_op_name(conv_not_to_compress_name, None) + + mlmodel_quantized = cto.coreml.linear_quantize_weights(mlmodel, config) + expected_ops = [ + "constexpr_blockwise_shift_scale", + "conv", + "conv", + "reshape", + "linear", + "linear", + "constexpr_blockwise_shift_scale", + "constexpr_blockwise_shift_scale", + "constexpr_blockwise_shift_scale", + "lstm", + "expand_dims", + "expand_dims", + ] + prog = mlmodel_quantized._mil_program + assert get_op_types_in_program(prog) == expected_ops + assert prog.find_ops(op_type="conv")[1].weight.op.op_type == "const" + + quantize_ops = prog.find_ops(op_type="constexpr_blockwise_shift_scale") + for quantize_op in quantize_ops: + assert quantize_op.data.dtype == types.int4 + assert types.builtin_to_string(quantize_op.scale.dtype) == backend.precision + + if _macos_version() >= (15, 0): + verify_model_outputs( + mlmodel, mlmodel_quantized, coreml_input_values, rtol=1e-2, atol=4e-2 + ) + + @staticmethod + @pytest.mark.parametrize( + "compute_unit, backend, mode, nbits, signed, block_size", + itertools.product( + compute_units, + backends, + ("linear", "linear_symmetric"), + (4, 8), + (True, False), + (0, 1, 2, 4), + ), + ) + def test_blockwise_quanitzation_stress(compute_unit, backend, mode, nbits, signed, block_size): + model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data() + torchmodel = torch.jit.trace(model, torch_input_values) + mlmodel = ct.convert( + torchmodel, + inputs=inputs, + convert_to="mlprogram", + minimum_deployment_target=backend.opset_version, + compute_precision=ct.precision.FLOAT16 + if backend.precision == "fp16" + else ct.precision.FLOAT32, + compute_units=compute_unit, + ) + + dtype_str = types.builtin_to_string(types.get_nbits_int_builtin_type(nbits, signed)) + op_config = cto.coreml.OpLinearQuantizerConfig( + mode=mode, dtype=dtype_str, granularity="per_block", block_size=block_size + ) + config = cto.coreml.OptimizationConfig(global_config=op_config) + mlmodel_quantized = cto.coreml.linear_quantize_weights(mlmodel, config) + + # Verify ops. + if backend.precision == "fp16": + # For fp16 precision there is no extra cast op inserted. + expected_ops = ["constexpr_blockwise_shift_scale", "conv"] + else: + expected_ops = ["constexpr_blockwise_shift_scale", "cast", "conv", "cast"] + assert get_op_types_in_program(mlmodel_quantized._mil_program) == expected_ops + quantize_op = mlmodel_quantized._mil_program.functions["main"].find_ops( + op_type="constexpr_blockwise_shift_scale" + )[0] + assert types.builtin_to_string(quantize_op.data.dtype) == dtype_str + # For sub-byte dtype, we still use np.int8/uint8 to store the data. + assert quantize_op.data.val.dtype == np.int8 if signed else np.uint8 + assert model.weight.detach().numpy().size == quantize_op.data.val.size + + if _macos_version() >= (15, 0): + verify_model_outputs(mlmodel, mlmodel_quantized, coreml_input_values) + + # The verify_model_outputs only check compressed and decompressed consistency. + # Also need to compare original and compressed model. + original_output = mlmodel.predict(coreml_input_values) + quantized_output = mlmodel_quantized.predict(coreml_input_values) + + for k, v in quantized_output.items(): + + if nbits <= 4 and block_size != 1: + # Low-bit has too much info lost when block size is not 1. + continue + + # When nbits is larger and block_size is smaller, the info lost is less. + atol, rtol = 0.4, 0.4 + if block_size == 1 and nbits > 4: + atol, rtol = 1e-2, 1e-2 + + np.testing.assert_allclose(v, original_output[k], atol=atol, rtol=rtol) + + @staticmethod + @pytest.mark.parametrize( + "compute_unit, backend, mode, nbits", + itertools.product( + compute_units, + backends, + ("linear", "linear_symmetric"), + (4, 8), + ), + ) + def test_per_tensor_quantization_with_blockwise_op(compute_unit, backend, mode, nbits): + op_config = cto.coreml.OpLinearQuantizerConfig( + mode=mode, dtype=f"int{nbits}", granularity="per_tensor" + ) + config = cto.coreml.OptimizationConfig(global_config=op_config) + + model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data( + quantize_config=op_config + ) + torchmodel = torch.jit.trace(model, torch_input_values) + mlmodel = ct.convert( + torchmodel, + inputs=inputs, + convert_to="mlprogram", + minimum_deployment_target=backend.opset_version, + compute_precision=ct.precision.FLOAT16 + if backend.precision == "fp16" + else ct.precision.FLOAT32, + compute_units=compute_unit, + ) + + mlmodel_quantized = cto.coreml.linear_quantize_weights(mlmodel, config) + + # Verify ops. + if backend.precision == "fp16": + # For fp16 precision there is no extra cast op inserted. + expected_ops = ["constexpr_blockwise_shift_scale", "conv"] + else: + expected_ops = ["constexpr_blockwise_shift_scale", "cast", "conv", "cast"] + assert get_op_types_in_program(mlmodel_quantized._mil_program) == expected_ops + quantize_op = mlmodel_quantized._mil_program.functions["main"].find_ops( + op_type="constexpr_blockwise_shift_scale" + )[0] + assert types.builtin_to_string(quantize_op.data.dtype) == f"int{nbits}" + if mode == "linear": + assert types.builtin_to_string(quantize_op.offset.dtype) == f"int{nbits}" + # For int4, we still use np.int8 to store the data. + assert quantize_op.data.val.dtype == np.int8 + assert model.weight.detach().numpy().size == quantize_op.data.val.size + + if _macos_version() >= (15, 0): + verify_model_outputs(mlmodel, mlmodel_quantized, coreml_input_values) + class TestPalettizeWeights: + @staticmethod + def test_palettization_with_classifier(): + traced_model, example_input = TestPyTorchConverterExamples._get_classifier_model() + for class_type in ("str", "int"): + mlmodel = TestPyTorchConverterExamples._convert_classifier_model( + traced_model, example_input, class_type + ) + config = cto.coreml.OptimizationConfig() + global_config = cto.coreml.OpPalettizerConfig( + nbits=8, mode="kmeans", weight_threshold=2 + ) + config.set_global(global_config) + mlmodel = cto.coreml.palettize_weights(mlmodel, config) + expected_ops = [ + "cast", + "reshape", + "constexpr_lut_to_dense", + "linear", + "relu", + "constexpr_lut_to_dense", + "linear", + "relu", + "constexpr_lut_to_dense", + "linear", + "cast", + "classify", + ] + assert get_op_types_in_program(mlmodel._mil_program) == expected_ops + @staticmethod def test_palettization(): model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data_complex() @@ -504,8 +763,359 @@ def test_convert_palettized_source_model_custom(): assert linear_ops[0].weight.op.op_type == "constexpr_lut_to_dense" assert linear_ops[1].weight.op.op_type == "constexpr_lut_to_dense" + @pytest.mark.parametrize( + "compute_unit, backend", + itertools.product(compute_units, backends), + ) + def test_channelwise_palettization(self, compute_unit, backend): + model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data_complex() + torchmodel = torch.jit.trace(model, torch_input_values) + mlmodel = ct.convert( + torchmodel, + inputs=inputs, + convert_to="mlprogram", + minimum_deployment_target=backend.opset_version, + compute_precision=ct.precision.FLOAT16 + if backend.precision == "fp16" + else ct.precision.FLOAT32, + compute_units=compute_unit, + ) + + config = cto.coreml.OptimizationConfig() + conv_config = cto.coreml.OpPalettizerConfig( + mode="kmeans", + nbits=8, + granularity="per_grouped_channel", + group_size=1, + weight_threshold=500, + ) + lstm_config = cto.coreml.OpPalettizerConfig( + mode="kmeans", + nbits=4, + granularity="per_grouped_channel", + group_size=1, + weight_threshold=4800, + ) + + config.set_op_type("conv", conv_config) + config.set_op_type("lstm", lstm_config) + # Set a specific conv's config to None to prevent it being compressed. + conv_not_to_compress_name = "conv_2_1" + if backend.precision == "fp16": + conv_not_to_compress_name += "_cast_fp16" + config.set_op_name(conv_not_to_compress_name, None) + + mlmodel_palettized = cto.coreml.palettize_weights(mlmodel, config) + expected_ops = [ + "constexpr_lut_to_dense", + "conv", + "conv", + "reshape", + "linear", + "linear", + "constexpr_lut_to_dense", + "constexpr_lut_to_dense", + "constexpr_lut_to_dense", + "lstm", + "expand_dims", + "expand_dims", + ] + prog = mlmodel_palettized._mil_program + assert get_op_types_in_program(prog) == expected_ops + assert prog.find_ops(op_type="conv")[1].weight.op.op_type == "const" + + palettize_ops = prog.find_ops(op_type="constexpr_lut_to_dense") + for quantize_op in palettize_ops: + assert types.builtin_to_string(quantize_op.lut.dtype) == backend.precision + assert types.builtin_to_string(palettize_ops[0].indices.dtype) == "uint8" + assert types.builtin_to_string(palettize_ops[3].indices.dtype) == "uint4" + + if _macos_version() >= (15, 0): + verify_model_outputs(mlmodel, mlmodel_palettized, coreml_input_values) + + @pytest.mark.parametrize( + "compute_unit, backend", + itertools.product(compute_units, backends), + ) + def test_channelwise_palettization_unique_skip_op(self, compute_unit, backend, caplog): + """Test where mode is unique and can't use nbits to represent the weight""" + model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data_complex() + traced_model = torch.jit.trace(model, torch_input_values) + mlmodel = ct.convert( + traced_model, + inputs=inputs, + convert_to="mlprogram", + minimum_deployment_target=backend.opset_version, + compute_precision=ct.precision.FLOAT16 + if backend.precision == "fp16" + else ct.precision.FLOAT32, + compute_units=compute_unit, + ) + + config = cto.coreml.OptimizationConfig() + global_config = cto.coreml.OpPalettizerConfig( + mode="unique", + granularity="per_grouped_channel", + group_size=1, + weight_threshold=100, + ) + # For conv weight in the whole tensor cannot be represented by 2**8 unique values. + conv_config = cto.coreml.OpPalettizerConfig( + mode="unique", + granularity="per_tensor", + weight_threshold=100, + ) + config.set_global(global_config) + config.set_op_type("conv", conv_config) + mlmodel_palettized = cto.coreml.palettize_weights(mlmodel, config) + assert any( + [ + "Unique values in weight cannot be represented by 8 bits palettization." + in rec.message + for rec in caplog.records + ] + ) + # There is no constexpr for the conv weight. + for conv_op in mlmodel_palettized._mil_program.find_ops(op_type="conv"): + assert conv_op.weight.op.op_type == "const" + # There are still constexpr ops for linear and lstm weights. + assert len(mlmodel_palettized._mil_program.find_ops(op_type="constexpr_lut_to_dense")) == 5 + + if _macos_version() >= (15, 0): + verify_model_outputs(mlmodel, mlmodel_palettized, coreml_input_values) + + @staticmethod + @pytest.mark.parametrize( + "compute_unit, backend, mode, nbits, channel_axis, channel_group_size", + itertools.product( + compute_units, + backends, + ("kmeans", "uniform"), + (1, 2, 3, 4, 6, 8), + (0, 1), + (0, 1, 2), + ), + ) + def test_channelwise_palettization_stress( + compute_unit, backend, mode, nbits, channel_axis, channel_group_size + ): + model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data() + torchmodel = torch.jit.trace(model, torch_input_values) + mlmodel = ct.convert( + torchmodel, + inputs=inputs, + convert_to="mlprogram", + minimum_deployment_target=backend.opset_version, + compute_precision=ct.precision.FLOAT16 + if backend.precision == "fp16" + else ct.precision.FLOAT32, + compute_units=compute_unit, + ) + + op_config = cto.coreml.OpPalettizerConfig( + mode=mode, + nbits=nbits, + granularity="per_grouped_channel", + group_size=channel_group_size, + channel_axis=channel_axis, + ) + config = cto.coreml.OptimizationConfig(global_config=op_config) + mlmodel_palettized = cto.coreml.palettize_weights(mlmodel, config) + + # Verify ops. + if backend.precision == "fp16": + # For fp16 precision there is no extra cast op inserted. + expected_ops = ["constexpr_lut_to_dense", "conv"] + else: + expected_ops = ["constexpr_lut_to_dense", "cast", "conv", "cast"] + assert get_op_types_in_program(mlmodel_palettized._mil_program) == expected_ops + palettize_op = mlmodel_palettized._mil_program.functions["main"].find_ops( + op_type="constexpr_lut_to_dense" + )[0] + assert types.builtin_to_string(palettize_op.indices.dtype) == f"uint{nbits}" + # For uint4, we still use np.uint8 to store the data. + assert palettize_op.indices.val.dtype == np.uint8 + assert model.weight.detach().numpy().shape == palettize_op.indices.val.shape + + if _macos_version() >= (15, 0): + verify_model_outputs(mlmodel, mlmodel_palettized, coreml_input_values) + + # The verify_model_outputs compares the decompressed model with compressed model. + # We further compare the compressed model with original model. + ref_output_dict = mlmodel.predict(coreml_input_values) + output_dict = mlmodel_palettized.predict(coreml_input_values) + for k, v in output_dict.items(): + assert k in ref_output_dict + if nbits == 1: + continue # nbits=1 numerical loss is too significant. + elif nbits <= 3: + large_diff_count = np.sum((v - ref_output_dict[k]) > 0.2) + threshold = 0.15 if channel_group_size != 0 else 0.5 + assert large_diff_count / v.size < threshold + elif nbits < 8: + np.testing.assert_almost_equal(v, ref_output_dict[k], decimal=1) + else: + err_tol = 1e-5 if mode == "kmeans" and channel_group_size == 1 else 1e-2 + np.testing.assert_allclose(v, ref_output_dict[k], atol=err_tol, rtol=err_tol) + + @pytest.mark.parametrize( + "compute_unit, backend", + itertools.product(compute_units, backends), + ) + def test_grouped_channelwise_palettization_better_than_per_tensor(self, compute_unit, backend): + """The grouped channelwise lut should be better than per-tensor lut.""" + model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data() + torchmodel = torch.jit.trace(model, torch_input_values) + mlmodel = ct.convert( + torchmodel, + inputs=inputs, + convert_to="mlprogram", + minimum_deployment_target=backend.opset_version, + compute_precision=ct.precision.FLOAT16 + if backend.precision == "fp16" + else ct.precision.FLOAT32, + compute_units=compute_unit, + ) + + per_tensor_config = cto.coreml.OptimizationConfig( + global_config=cto.coreml.OpPalettizerConfig( + mode="kmeans", + nbits=4, + granularity="per_tensor", + ) + ) + grouped_channelwise_config = cto.coreml.OptimizationConfig( + global_config=cto.coreml.OpPalettizerConfig( + mode="kmeans", + nbits=4, + granularity="per_grouped_channel", + group_size=1, + ) + ) + + if _macos_version() < (15, 0): + pytest.skip("Channelwise palettization prediction only support in iOS18+") + + mlmodel_per_tensor_palettized = cto.coreml.palettize_weights(mlmodel, per_tensor_config) + mlmodel_grouped_channelwise_palettized = cto.coreml.palettize_weights( + mlmodel, grouped_channelwise_config + ) + output_ref = mlmodel.predict(coreml_input_values) + output_per_tensor = mlmodel_per_tensor_palettized.predict(coreml_input_values) + output_grouped_channelwise = mlmodel_grouped_channelwise_palettized.predict( + coreml_input_values + ) + for k_ref, v_ref in output_ref.items(): + snr_per_tensor = compute_snr_and_psnr(v_ref, output_per_tensor[k_ref])[0] + snr_grouped_channelwise = compute_snr_and_psnr( + v_ref, output_grouped_channelwise[k_ref] + )[0] + assert snr_grouped_channelwise > snr_per_tensor + + def test_channelwise_palettization_invalid_config(self): + with pytest.raises(ValueError, match='Invalid value of "nbits" \(7\) for palettization'): + cto.coreml.OpPalettizerConfig( + mode="kmeans", + nbits=7, + granularity="per_tensor", + weight_threshold=500, + ) + + @pytest.mark.parametrize( + "compute_unit, backend, group_size", + itertools.product(compute_units, backends, [1, 16]), + ) + def test_convert_palettized_model_with_pipeline(self, compute_unit, backend, group_size): + model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data( + multi_layer=True + ) + with torch.no_grad(): + model.conv_1.weight = torch.nn.Parameter( + torch.Tensor(create_unique_weight(model.conv_1.weight, nbits=2)) + ) + model.conv_2.weight = torch.nn.Parameter( + torch.Tensor(create_unique_weight(model.conv_2.weight, nbits=6)) + ) + + torchmodel = torch.jit.trace(model, torch_input_values) + mlmodel = ct.convert( + torchmodel, + inputs=inputs, + convert_to="mlprogram", + minimum_deployment_target=backend.opset_version, + compute_precision=ct.precision.FLOAT16 + if backend.precision == "fp16" + else ct.precision.FLOAT32, + compute_units=compute_unit, + ) + + pass_pipeline = ct.PassPipeline.DEFAULT_PALETTIZATION + pass_pipeline.set_options( + "compression::palettize_weights", + { + "config": cto.coreml.OptimizationConfig( + global_config=cto.coreml.OpPalettizerConfig( + mode="unique", granularity="per_grouped_channel", group_size=group_size + ) + ) + }, + ) + mlmodel_palettized = ct.convert( + torchmodel, + inputs=inputs, + convert_to="mlprogram", + minimum_deployment_target=backend.opset_version, + compute_precision=ct.precision.FLOAT16 + if backend.precision == "fp16" + else ct.precision.FLOAT32, + compute_units=compute_unit, + pass_pipeline=pass_pipeline, + ) + + expected_ops = ["constexpr_lut_to_dense", "constexpr_lut_to_dense", "conv", "conv"] + assert get_op_types_in_program(mlmodel_palettized._mil_program) == expected_ops + palettize_ops = mlmodel_palettized._mil_program.functions["main"].find_ops( + op_type="constexpr_lut_to_dense" + ) + assert types.builtin_to_string(palettize_ops[0].indices.dtype) == "uint2" + assert palettize_ops[0].lut.shape == (32 // group_size, 1, 1, 1, 4, 1) + assert types.builtin_to_string(palettize_ops[1].indices.dtype) == "uint6" + assert palettize_ops[1].lut.shape == (64 // group_size, 1, 1, 1, 64, 1) + + if _macos_version() >= (15, 0): + verify_model_outputs(mlmodel, mlmodel_palettized, coreml_input_values) + class TestPruneWeights: + @staticmethod + def test_pruning_with_classifier(): + traced_model, example_input = TestPyTorchConverterExamples._get_classifier_model() + for class_type in ("str", "int"): + mlmodel = TestPyTorchConverterExamples._convert_classifier_model( + traced_model, example_input, class_type + ) + config = cto.coreml.OptimizationConfig() + global_config = cto.coreml.OpMagnitudePrunerConfig( + target_sparsity=0.9, weight_threshold=2 + ) + config.set_global(global_config) + mlmodel = cto.coreml.prune_weights(mlmodel, config) + expected_ops = [ + "cast", + "reshape", + "constexpr_sparse_to_dense", + "linear", + "relu", + "constexpr_sparse_to_dense", + "linear", + "relu", + "constexpr_sparse_to_dense", + "linear", + "cast", + "classify", + ] + assert get_op_types_in_program(mlmodel._mil_program) == expected_ops + @staticmethod def test_pruning(): model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data_complex() @@ -736,6 +1346,554 @@ def test_convert_sparse_source_model_custom(self): assert linear_ops[0].weight.op.op_type == "constexpr_sparse_to_dense" assert linear_ops[1].weight.op.op_type == "const" + @pytest.mark.parametrize( + "compute_unit, backend", + itertools.product(compute_units, backends), + ) + def test_default_prune_pipeline_ios18(self, compute_unit, backend): + """Make sure the new iOS18 op is used for DEFAULT_PRUNING pass pipeline.""" + # Make the weight size not divisible by 8, to make sure the internal conversion to ios18 + # sparse_to_dense op handles sub-byte masks correctly. + model = torch.nn.Linear(21, 121) + model.eval() + weight_sparse = create_sparse_weight(model.weight, 0.7) + with torch.no_grad(): + model.weight = torch.nn.Parameter(torch.Tensor(weight_sparse)) + + inputs = [ct.TensorType(name="data", shape=(4, 21))] + torch_input_values = [torch.rand(*i.shape.to_list()) for i in inputs] + coreml_input_values = { + i.name: val.detach().numpy() for i, val in zip(inputs, torch_input_values) + } + torchmodel = torch.jit.trace(model, torch_input_values) + + mlmodel = ct.convert( + torchmodel, + inputs=inputs, + convert_to="mlprogram", + minimum_deployment_target=backend.opset_version, + compute_precision=ct.precision.FLOAT16 + if backend.precision == "fp16" + else ct.precision.FLOAT32, + compute_units=compute_unit, + ) + mlmodel_pruned = ct.convert( + torchmodel, + inputs=inputs, + convert_to="mlprogram", + minimum_deployment_target=backend.opset_version, + compute_precision=ct.precision.FLOAT16 + if backend.precision == "fp16" + else ct.precision.FLOAT32, + compute_units=compute_unit, + pass_pipeline=ct.PassPipeline.DEFAULT_PRUNING, + ) + sparse_ops = mlmodel_pruned._mil_program.find_ops(op_type="constexpr_sparse_to_dense") + assert len(sparse_ops) > 0 + for sparse_op in sparse_ops: + assert types.builtin_to_string(sparse_op.nonzero_data.dtype) == backend.precision + if backend.opset_version >= ct.target.iOS18: + assert types.builtin_to_string(sparse_op.mask.dtype) == "uint1" + else: + assert types.builtin_to_string(sparse_op.mask.dtype) == "uint8" + assert types.builtin_to_string(sparse_op.shape.dtype) == "uint32" + + if _macos_version() >= (15, 0): + verify_model_outputs(mlmodel, mlmodel_pruned, coreml_input_values) + + +class TestJointCompressWeights: + """Test using coremltools PTQ to do joint compression.""" + + @pytest.mark.parametrize( + "compute_unit, backend, dtype, block_size, output_channel_block_size, prune_first", + itertools.product( + compute_units, + backends, + ("int4", "int8", "uint4", "uint8"), + (0, 1, 2), + (0, 1), + (True, False), + ), + ) + def test_joint_prune_quantize_weights( + self, compute_unit, backend, dtype, block_size, output_channel_block_size, prune_first + ): + """Jointly prune and quantize the model, where non-sparse entries are quantized.""" + model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data_complex() + torchmodel = torch.jit.trace(model, torch_input_values) + mlmodel = ct.convert( + torchmodel, + inputs=inputs, + convert_to="mlprogram", + minimum_deployment_target=backend.opset_version, + compute_precision=ct.precision.FLOAT16 + if backend.precision == "fp16" + else ct.precision.FLOAT32, + compute_units=compute_unit, + ) + + prune_config = cto.coreml.OptimizationConfig( + global_config=cto.coreml.OpMagnitudePrunerConfig( + target_sparsity=0.5, weight_threshold=500 + ) + ) + + quant_config = cto.coreml.OptimizationConfig( + global_config=cto.coreml.OpLinearQuantizerConfig( + mode="linear", + dtype=dtype, + granularity="per_block", + block_size=[0, block_size] if output_channel_block_size == 0 else block_size, + weight_threshold=500, + ), + op_type_configs={ + "conv": cto.coreml.OpLinearQuantizerConfig( + mode="linear", + dtype=dtype, + granularity="per_block", + block_size=[0, block_size, 0, 0] + if output_channel_block_size == 0 + else block_size, + weight_threshold=500, + ), + }, + ) + + if prune_first: + mlmodel_pruned = cto.coreml.prune_weights(mlmodel, prune_config) + mlmodel_joint_pruned_quantized = cto.coreml.linear_quantize_weights( + mlmodel_pruned, quant_config, joint_compression=True + ) + else: + mlmodel_quantized = cto.coreml.linear_quantize_weights(mlmodel, quant_config) + mlmodel_joint_pruned_quantized = cto.coreml.prune_weights( + mlmodel_quantized, prune_config, joint_compression=True + ) + + # If run prune first, the all-zero const for lstm won't have nonzero-data, so it won't be + # further quantized. + lstm_weight_compression_ops = ( + ["constexpr_sparse_to_dense"] + if prune_first + else ["constexpr_sparse_blockwise_shift_scale", "constexpr_sparse_to_dense"] + ) + expected_ops = ( + ["constexpr_sparse_blockwise_shift_scale", "constexpr_sparse_to_dense", "conv"] * 2 + + ["reshape"] + + ["constexpr_sparse_blockwise_shift_scale", "constexpr_sparse_to_dense", "linear"] * 2 + + lstm_weight_compression_ops + + ["constexpr_sparse_blockwise_shift_scale", "constexpr_sparse_to_dense"] * 2 + + ["lstm", "expand_dims", "expand_dims"] + ) + prog = mlmodel_joint_pruned_quantized._mil_program + assert get_op_types_in_program(prog) == expected_ops + + for linear_op in prog.find_ops(op_type="linear"): + assert linear_op.weight.op.op_type == "constexpr_sparse_to_dense" + for conv_op in prog.find_ops(op_type="conv"): + assert conv_op.weight.op.op_type == "constexpr_sparse_to_dense" + + sparse_quantize_ops = prog.find_ops(op_type="constexpr_sparse_blockwise_shift_scale") + assert len(sparse_quantize_ops) > 0 + for sparse_quantize_op in sparse_quantize_ops: + assert types.builtin_to_string(sparse_quantize_op.nonzero_data.dtype) == dtype + assert sparse_quantize_op.data_mask.dtype == types.uint1 + assert sparse_quantize_op.scale.dtype == types.fp16 + assert types.builtin_to_string(sparse_quantize_op.offset.dtype) == dtype + assert sparse_quantize_op.outputs[1].child_ops[0].op_type == "constexpr_sparse_to_dense" + # As both quantization and pruning is on the original weight, the shape of scale should + # match the original weight's shape except on the input/output channel. + weight_shape = sparse_quantize_op.outputs[1].child_ops[0].outputs[0].shape + expected_scale_shape = [1] * len(weight_shape) + if block_size > 0: + expected_scale_shape[1] = weight_shape[1] // block_size + if output_channel_block_size > 0: + expected_scale_shape[0] = weight_shape[0] // output_channel_block_size + assert sparse_quantize_op.scale.shape == tuple(expected_scale_shape) + + sparse_ops = prog.find_ops(op_type="constexpr_sparse_to_dense") + assert len(sparse_ops) > 0 + for sparse_op in sparse_ops: + assert sparse_op.mask.dtype == types.uint1 + assert sparse_op.nonzero_data.dtype == types.fp16 + + if _macos_version() >= (15, 0): + atol = 5e-4 if compute_unit == ct.ComputeUnit.CPU_AND_GPU else 1e-6 + verify_model_outputs( + mlmodel, mlmodel_joint_pruned_quantized, coreml_input_values, atol=atol + ) + + @pytest.mark.parametrize( + "compute_unit, backend, nbits, channel_group_size, prune_first", + itertools.product( + compute_units, + backends, + (3, 4, 8), + (0, 1, 2), + (True, False), + ), + ) + def test_joint_prune_palettize_weights( + self, compute_unit, backend, nbits, channel_group_size, prune_first + ): + """Jointly prune and palettize the model, where non-sparse entries are palettized.""" + model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data_complex() + torchmodel = torch.jit.trace(model, torch_input_values) + mlmodel = ct.convert( + torchmodel, + inputs=inputs, + convert_to="mlprogram", + minimum_deployment_target=backend.opset_version, + compute_precision=ct.precision.FLOAT16 + if backend.precision == "fp16" + else ct.precision.FLOAT32, + compute_units=compute_unit, + ) + + prune_config = cto.coreml.OptimizationConfig( + global_config=cto.coreml.OpMagnitudePrunerConfig( + target_sparsity=0.2, + weight_threshold=500, + ) + ) + palettize_config = cto.coreml.OptimizationConfig( + global_config=cto.coreml.OpPalettizerConfig( + mode="uniform", + nbits=nbits, + granularity="per_grouped_channel", + group_size=channel_group_size, + weight_threshold=500, + ) + ) + + if prune_first: + mlmodel_pruned = cto.coreml.prune_weights(mlmodel, prune_config) + mlmodel_joint_pruned_palettized = cto.coreml.palettize_weights( + mlmodel_pruned, palettize_config, joint_compression=True + ) + else: + mlmodel_palettized = cto.coreml.palettize_weights(mlmodel, palettize_config) + mlmodel_joint_pruned_palettized = cto.coreml.prune_weights( + mlmodel_palettized, prune_config, joint_compression=True + ) + + # If run prune first, the all-zero const for lstm won't have nonzero-data, so it won't be + # further quantized. + lstm_weight_compression_ops = ( + ["constexpr_sparse_to_dense"] + if prune_first + else ["constexpr_lut_to_sparse", "constexpr_sparse_to_dense"] + ) + expected_ops = ( + ["constexpr_lut_to_sparse", "constexpr_sparse_to_dense", "conv"] * 2 + + ["reshape"] + + ["constexpr_lut_to_sparse", "constexpr_sparse_to_dense", "linear"] * 2 + + lstm_weight_compression_ops + + ["constexpr_lut_to_sparse", "constexpr_sparse_to_dense"] * 2 + + ["lstm", "expand_dims", "expand_dims"] + ) + prog = mlmodel_joint_pruned_palettized._mil_program + assert get_op_types_in_program(prog) == expected_ops + + for linear_op in prog.find_ops(op_type="linear"): + assert linear_op.weight.op.op_type == "constexpr_sparse_to_dense" + for conv_op in prog.find_ops(op_type="conv"): + assert conv_op.weight.op.op_type == "constexpr_sparse_to_dense" + + sparse_palettize_ops = prog.find_ops(op_type="constexpr_lut_to_sparse") + assert len(sparse_palettize_ops) > 0 + for sparse_palettize_op in sparse_palettize_ops: + assert sparse_palettize_op.indices_nonzero_data.dtype == types.string_to_builtin( + f"uint{nbits}" + ) + assert sparse_palettize_op.indices_mask.dtype == types.uint1 + assert sparse_palettize_op.lut.dtype == types.fp16 + assert ( + sparse_palettize_op.outputs[1].child_ops[0].op_type == "constexpr_sparse_to_dense" + ) + # As both palettization and pruning is on the original weight, the shape of lut should + # match the original weight's shape except on the output channel. + weight_shape = sparse_palettize_op.outputs[1].child_ops[0].outputs[0].shape + expected_lut_shape = [1] * len(weight_shape) + [2**nbits] + [1] + if channel_group_size > 0: + expected_lut_shape[0] = weight_shape[0] // channel_group_size + assert sparse_palettize_op.lut.shape == tuple(expected_lut_shape) + + sparse_ops = prog.find_ops(op_type="constexpr_sparse_to_dense") + assert len(sparse_ops) > 0 + for sparse_op in sparse_ops: + assert sparse_op.mask.dtype == types.uint1 + assert sparse_op.nonzero_data.dtype == types.fp16 + + if _macos_version() >= (15, 0): + atol = 5e-4 if compute_unit == ct.ComputeUnit.CPU_AND_GPU else 1e-6 + verify_model_outputs( + mlmodel, mlmodel_joint_pruned_palettized, coreml_input_values, atol=atol + ) + + @pytest.mark.parametrize( + "compute_unit, backend, nbits, channel_group_size", + itertools.product( + compute_units, + backends, + (3, 4, 8), + (0, 1, 2), + ), + ) + def test_joint_palettize_quantize_weights( + self, compute_unit, backend, nbits, channel_group_size + ): + """First palettize to get fp16 lut, and then quantize the lut to make int8 lut.""" + model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data_complex() + torchmodel = torch.jit.trace(model, torch_input_values) + mlmodel = ct.convert( + torchmodel, + inputs=inputs, + convert_to="mlprogram", + minimum_deployment_target=backend.opset_version, + compute_precision=ct.precision.FLOAT16 + if backend.precision == "fp16" + else ct.precision.FLOAT32, + compute_units=compute_unit, + ) + + palettize_config = cto.coreml.OptimizationConfig( + global_config=cto.coreml.OpPalettizerConfig( + mode="uniform", + nbits=nbits, + granularity="per_grouped_channel", + group_size=channel_group_size, + weight_threshold=500, + ) + ) + quant_config = cto.coreml.OptimizationConfig( + global_config=cto.coreml.OpLinearQuantizerConfig( + # Quantize the whole lut tensor as the lut usually is not huge. + mode="linear", + dtype="int8", + granularity="per_tensor", + weight_threshold=500, + ) + ) + + mlmodel_palettized = cto.coreml.palettize_weights(mlmodel, palettize_config) + mlmodel_joint_palettized_quantized = cto.coreml.linear_quantize_weights( + mlmodel_palettized, quant_config, joint_compression=True + ) + expected_ops = ( + ["constexpr_blockwise_shift_scale", "constexpr_lut_to_dense", "conv"] * 2 + + ["reshape"] + + ["constexpr_blockwise_shift_scale", "constexpr_lut_to_dense", "linear"] * 2 + + ["constexpr_blockwise_shift_scale", "constexpr_lut_to_dense"] * 3 + + ["lstm", "expand_dims", "expand_dims"] + ) + prog = mlmodel_joint_palettized_quantized._mil_program + if channel_group_size == 0: + # When use per-tensor lut, the lut size is too small, so it's stored as ImmediateValue + # which won't be quantized. + ops_in_prog = get_op_types_in_program(prog) + if nbits >= 4: + assert ops_in_prog.count("constexpr_blockwise_shift_scale") >= 6 + else: + assert ops_in_prog.count("constexpr_blockwise_shift_scale") == 0 + else: + assert get_op_types_in_program(prog) == expected_ops + + for linear_op in prog.find_ops(op_type="linear"): + assert linear_op.weight.op.op_type == "constexpr_lut_to_dense" + for conv_op in prog.find_ops(op_type="conv"): + assert conv_op.weight.op.op_type == "constexpr_lut_to_dense" + + for quantize_op in prog.find_ops(op_type="constexpr_blockwise_shift_scale"): + assert quantize_op.data.dtype == types.int8 + assert quantize_op.scale.dtype == types.fp16 + assert quantize_op.offset.dtype == types.int8 + assert quantize_op.outputs[0].child_ops[0].op_type == "constexpr_lut_to_dense" + + for palettize_op in prog.find_ops(op_type="constexpr_lut_to_dense"): + assert palettize_op.lut.dtype == types.fp16 + assert palettize_op.indices.dtype == types.string_to_builtin(f"uint{nbits}") + + if _macos_version() >= (15, 0): + verify_model_outputs(mlmodel, mlmodel_joint_palettized_quantized, coreml_input_values) + + @pytest.mark.parametrize( + "compute_unit, backend", + itertools.product(compute_units, backends), + ) + def test_joint_palettize_quantize_weights_invalid(self, compute_unit, backend): + """Only support per-tensor quantization for this case.""" + model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data_complex() + torchmodel = torch.jit.trace(model, torch_input_values) + mlmodel = ct.convert( + torchmodel, + inputs=inputs, + convert_to="mlprogram", + minimum_deployment_target=backend.opset_version, + compute_precision=ct.precision.FLOAT16 + if backend.precision == "fp16" + else ct.precision.FLOAT32, + compute_units=compute_unit, + ) + + palettize_config = cto.coreml.OptimizationConfig( + global_config=cto.coreml.OpPalettizerConfig( + mode="uniform", + nbits=4, + granularity="per_grouped_channel", + group_size=1, + weight_threshold=500, + ) + ) + quant_config = cto.coreml.OptimizationConfig( + global_config=cto.coreml.OpLinearQuantizerConfig( + mode="linear", + block_size=1, + weight_threshold=500, + ) + ) + + mlmodel_palettized = cto.coreml.palettize_weights(mlmodel, palettize_config) + with pytest.raises( + NotImplementedError, + match="When use joint compression for palettization-quantization, " + "please make sure to use per-tensor quantization", + ): + cto.coreml.linear_quantize_weights( + mlmodel_palettized, quant_config, joint_compression=True + ) + + @pytest.mark.parametrize( + "compute_unit, backend, nbits, channel_group_size, target_sparsity", + itertools.product( + compute_units, + backends, + (3, 4, 8), + (0, 1, 2), + (0.2, 0.8), + ), + ) + def test_joint_prune_palettize_quantize_weights( + self, compute_unit, backend, nbits, channel_group_size, target_sparsity + ): + """ + First prune to get sparse weight, and then palettize the non-sparse entries to get fp16 + lut, and then quantize the lut to make int8 lut. + """ + model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data_complex() + torchmodel = torch.jit.trace(model, torch_input_values) + mlmodel = ct.convert( + torchmodel, + inputs=inputs, + convert_to="mlprogram", + minimum_deployment_target=backend.opset_version, + compute_precision=ct.precision.FLOAT16 + if backend.precision == "fp16" + else ct.precision.FLOAT32, + compute_units=compute_unit, + ) + + prune_config = cto.coreml.OptimizationConfig( + global_config=cto.coreml.OpMagnitudePrunerConfig( + target_sparsity=target_sparsity, weight_threshold=500 + ) + ) + palettize_config = cto.coreml.OptimizationConfig( + global_config=cto.coreml.OpPalettizerConfig( + mode="kmeans", + nbits=nbits, + granularity="per_grouped_channel", + group_size=channel_group_size, + weight_threshold=500, + ) + ) + quant_config = cto.coreml.OptimizationConfig( + global_config=cto.coreml.OpLinearQuantizerConfig( + mode="linear", + dtype="int8", + granularity="per_tensor", + weight_threshold=200, # Need to be smaller than entries in lut (2**8=256). + ) + ) + + mlmodel_pruned = cto.coreml.prune_weights(mlmodel, prune_config) + mlmodel_joint_pruned_palettized = cto.coreml.palettize_weights( + mlmodel_pruned, palettize_config, joint_compression=True + ) + mlmodel_joint_pruned_palettized_quantized = cto.coreml.linear_quantize_weights( + mlmodel_joint_pruned_palettized, quant_config, joint_compression=True + ) + expected_ops = ( + [ + "constexpr_blockwise_shift_scale", + "constexpr_lut_to_sparse", + "constexpr_sparse_to_dense", + "conv", + ] + * 2 + + ["reshape"] + + [ + "constexpr_blockwise_shift_scale", + "constexpr_lut_to_sparse", + "constexpr_sparse_to_dense", + "linear", + ] + * 2 + + ["constexpr_sparse_to_dense"] + + [ + "constexpr_blockwise_shift_scale", + "constexpr_lut_to_sparse", + "constexpr_sparse_to_dense", + ] + * 2 + + ["lstm", "expand_dims", "expand_dims"] + ) + if nbits < 4 and channel_group_size == 0: + # The lut tensor is too small, which is stored as immediate values. + expected_ops = [ + expected_op + for expected_op in expected_ops + if expected_op != "constexpr_blockwise_shift_scale" + ] + prog = mlmodel_joint_pruned_palettized_quantized._mil_program + assert get_op_types_in_program(prog) == expected_ops + + for linear_op in prog.find_ops(op_type="linear"): + assert linear_op.weight.op.op_type == "constexpr_sparse_to_dense" + for conv_op in prog.find_ops(op_type="conv"): + assert conv_op.weight.op.op_type == "constexpr_sparse_to_dense" + + for quantize_op in prog.find_ops(op_type="constexpr_blockwise_shift_scale"): + assert types.builtin_to_string(quantize_op.data.dtype) == "int8" + assert types.builtin_to_string(quantize_op.scale.dtype) == backend.precision + assert types.builtin_to_string(quantize_op.offset.dtype) == "int8" + assert quantize_op.outputs[0].child_ops[0].op_type == "constexpr_lut_to_sparse" + + for sparse_palettize_op in prog.find_ops(op_type="constexpr_lut_to_sparse"): + assert ( + types.builtin_to_string(sparse_palettize_op.indices_nonzero_data.dtype) + == f"uint{nbits}" + ) + assert sparse_palettize_op.indices_mask.dtype == types.uint1 + assert ( + sparse_palettize_op.outputs[1].child_ops[0].op_type == "constexpr_sparse_to_dense" + ) + + for sparse_op in prog.find_ops(op_type="constexpr_sparse_to_dense"): + assert sparse_op.mask.dtype == types.uint1 + assert types.builtin_to_string(sparse_op.nonzero_data.dtype) == backend.precision + + if _macos_version() >= (15, 0): + atol = 5e-4 if compute_unit == ct.ComputeUnit.CPU_AND_GPU else 1e-6 + verify_model_outputs( + mlmodel, + mlmodel_joint_pruned_palettized_quantized, + coreml_input_values, + atol=atol, + ) + + class TestDecompressWeights: @staticmethod def test_weight_decopmression_coreml_optimize(): @@ -864,12 +2022,12 @@ def test_error_handling(): linear_quantize_weights(mlmodel, mode="invalid_mode") # Test invalid dtype for affine quantization - expected_err_str = "is unsupported for affine_quantize_weight" + expected_err_str = "Should be int4/8 or uint4/8, but got int32" with pytest.raises(ValueError, match=expected_err_str): linear_quantize_weights(mlmodel, dtype=np.int32) - expected_err_str = "\'dtype\' must be \ \(got \'int32\'" - with pytest.raises(TypeError, match=expected_err_str): + expected_err_str = "Should be int4/8 or uint4/8, but got int32" + with pytest.raises(ValueError, match=expected_err_str): linear_quantize_weights(mlmodel, dtype="int32") # Test invalid threshold for weight sparsification @@ -914,6 +2072,41 @@ def test_error_handling(): with pytest.raises(ValueError, match=expected_err_str): palettize_weights(mlmodel, mode="custom", lut_function=1) + @staticmethod + def test_error_out_multifunction(): + # prepare a mlmodel from a torch model + model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data() + torchmodel = torch.jit.trace(model, torch_input_values) + mlmodel = ct.convert(torchmodel, inputs=inputs, convert_to="mlprogram") + + # make a multifunction model + package_path = tempfile.mkdtemp(suffix=".mlpackage") + mlmodel.save(package_path) + desc = MultiFunctionDescriptor(package_path) + desc.default_function_name = "main" + multifunction_path = tempfile.mkdtemp(suffix=".mlpackage") + save_multifunction(desc, multifunction_path) + multifunction_mlmodel = ct.models.MLModel(multifunction_path) + + # all PTQ API should error out, until the radar is fixed: + # rdar://126084385 ([Infra] Figure out the story of PTQ or other passes operate on loaded Mutli-function model) + def run_palettization(mlmodel): + return palettize_weights(mlmodel, nbits=2) + + for func in [ + linear_quantize_weights, + prune_weights, + run_palettization, + decompress_weights, + ct.optimize.coreml.get_weights_metadata, + ]: + with pytest.raises(ValueError, match="is not supported for a multifunction model"): + func(multifunction_mlmodel) + + # cleanup + shutil.rmtree(package_path) + shutil.rmtree(multifunction_path) + class TestCoreMLWeightMetaData: """ diff --git a/coremltools/test/optimize/coreml/test_utils.py b/coremltools/test/optimize/coreml/test_utils.py new file mode 100644 index 000000000..497a87921 --- /dev/null +++ b/coremltools/test/optimize/coreml/test_utils.py @@ -0,0 +1,173 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import itertools + +import numpy as np +import pytest + +from coremltools.converters.mil.mil import types +from coremltools.converters.mil.mil.ops.defs.iOS18.compression import constexpr_lut_to_dense +from coremltools.optimize.coreml import _utils as optimize_utils + + +class TestComputeQuantizationParams: + @pytest.mark.parametrize( + "quant_mode, rank, block_size", + itertools.product( + ["LINEAR", "LINEAR_SYMMETRIC"], + [1, 2, 3], + [0, 1, 2], + ), + ) + def test_compute_qparams(self, quant_mode, rank, block_size): + weight_shape = [10] * rank + weight = np.random.randn(*weight_shape) + ret = optimize_utils.compute_qparams( + weight, + nbits=8, + signed=True, + quantization_mode=quant_mode, + dtype=np.int8, + block_sizes=[block_size] * rank, + ) + if quant_mode == "LINEAR_SYMMETRIC": + assert ret[-1] is None + else: + assert ret[-1] is not None + + assert ret[0].shape == weight.shape + + @pytest.mark.parametrize( + "quant_mode, block_sizes", + itertools.product( + ["LINEAR", "LINEAR_SYMMETRIC"], + [ + [0], + [4, 5], + [3, 9], + [4, 5, 6], + ], + ), + ) + def test_compute_qparams_failure(self, block_sizes, quant_mode): + weight = np.random.randn(10, 10) + with pytest.raises(AssertionError): + ret = optimize_utils.compute_qparams( + weight, + nbits=8, + signed=True, + quantization_mode=quant_mode, + dtype=np.int8, + block_sizes=block_sizes, + ) + + assert ret is not None + + +class TestFindIndicesForLut: + def test_basic(self): + """ + data: [3.01, -7.99, -8.01, 3.02, 3.89, -1.88, -2.02, -6.98] + lut: [-8, -7, 3, 4, -2] + expected indices: [2, 0, 0, 2, 3, 4, 4, 1] + """ + data = np.array([3.01, -7.99, -8.01, 3.02, 3.89, 0.98, 1.98, -6.98], dtype=np.float16) + lut = np.array([-8, -7, 3, 4], dtype=np.int8).reshape((1, 4, 1)) + expected_indices = np.array([2, 0, 0, 2, 3, 2, 2, 1], dtype=np.uint8) + indices = optimize_utils.find_indices_for_lut(data, lut) + np.testing.assert_array_equal(indices, expected_indices) + assert types.builtin_to_string(types.numpy_type_to_builtin_type(indices.dtype)) == "uint2" + + @pytest.mark.parametrize( + "nbits, block_sizes", + itertools.product( + (2, 3, 4, 8), + ( + [0], + [1], + [2], + [2, 2], + [1, 2, 1], + [0, 2, 2], + [4, 0, 0, 1], + [8, 4, 2, 3], + ), + ), + ) + def test_stress(self, nbits, block_sizes): + """ + As finding indices is the reverse progress of generating data from lut, we first manually + construct indices and lut, and then generate data from lut and salt it, and finally check + if the restored indices are identical to the original indices. + """ + data_shape = [8, 4, 2, 3] + lut_shape = data_shape + [2**nbits, 1] + for idx, dim_size in enumerate(data_shape): + if idx < len(block_sizes): + lut_shape[idx] = 1 if block_sizes[idx] == 0 else data_shape[idx] // block_sizes[idx] + + nbits_range = types.type_mapping.builtin_to_range(types.string_to_builtin(f"uint{nbits}")) + lut = np.arange(np.prod(lut_shape)).reshape(lut_shape).astype(np.float32) + expected_indices = np.random.randint( + low=nbits_range.low, high=nbits_range.high + 1, size=data_shape, dtype=np.uint8 + ) + + data = constexpr_lut_to_dense.decompress(expected_indices, lut, vector_axis=None) + # Salting the data to manually introduce numerical instability. + data += np.random.randint(low=0, high=2, size=data.shape) * 0.01 + data -= np.random.randint(low=0, high=2, size=data.shape) * 0.01 + + indices = optimize_utils.find_indices_for_lut(data, lut) + + np.testing.assert_array_equal(indices, expected_indices) + assert ( + types.builtin_to_string(types.numpy_type_to_builtin_type(indices.dtype)) + == f"uint{nbits}" + ) + + +class TestPackUnpackBits: + def test_pack_basic(self): + """ + Original data: [-8, 7, 3, 4, -2]. + The 4-bit binary representation for those elements are: + -8: 1000; + 7: 0111; + 3: 0011 + 4: 0100 + -2: 1110 + Hence the packed quantized_data will be 3 bytes long, i.e., 24 bits long, which is: + 0111 1000 0100 0011 0000 1110 + So the packed data is represented by 3 uint8 values: [120, 67, 14]. + """ + original_data = np.array([-8, 7, 3, 4, -2], dtype=np.int8) + expected_packed_data = np.array([120, 67, 14], dtype=np.uint8) + packed_data = optimize_utils.pack_elements_into_bits(original_data, nbits=4) + np.testing.assert_array_equal(packed_data, expected_packed_data) + + def test_pack_basic_2(self): + original_data = np.array([1, 2, 3, 4, 5], dtype=np.int8) + expected_packed_data = np.array([33, 67, 5], dtype=np.uint8) + packed_data = optimize_utils.pack_elements_into_bits(original_data, nbits=4) + np.testing.assert_array_equal(packed_data, expected_packed_data) + + @pytest.mark.parametrize( + "nbits, data_dtype, element_num", + itertools.product(list(range(1, 9)), [np.int8, np.uint8], [1, 3, 20]), + ) + def test_round_trip_pack_unpack(self, nbits, data_dtype, element_num): + is_data_signed = np.issubdtype(data_dtype, np.signedinteger) + low, high = 0, 2**nbits + if is_data_signed: + low, high = -(2 ** (nbits - 1)), 2 ** (nbits - 1) + original_data = np.random.randint(low=low, high=high, size=(element_num,)).astype( + data_dtype + ) + packed_data = optimize_utils.pack_elements_into_bits(original_data, nbits) + restored_data = optimize_utils.restore_elements_from_packed_bits( + packed_data, nbits, element_num, are_packed_values_signed=is_data_signed + ) + np.testing.assert_array_equal(restored_data, original_data) diff --git a/coremltools/test/optimize/torch/__init__.py b/coremltools/test/optimize/torch/__init__.py index 461b69f06..ac8b4b20c 100644 --- a/coremltools/test/optimize/torch/__init__.py +++ b/coremltools/test/optimize/torch/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause diff --git a/coremltools/test/optimize/torch/conftest.py b/coremltools/test/optimize/torch/conftest.py index 182d47798..acf59383d 100644 --- a/coremltools/test/optimize/torch/conftest.py +++ b/coremltools/test/optimize/torch/conftest.py @@ -1,33 +1,43 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause import os import shutil + +import pytest + from coremltools.test.optimize.torch.models.mnist import ( mnist_dataset, + mnist_example_input, + mnist_example_output, mnist_model, mnist_model_large, mnist_model_quantization, + residual_mnist_model, ) from coremltools.test.optimize.torch.pruning.pruning_utils import get_model_and_pruner -import pytest - # dummy function to use the imported fixtures so that linter # does not remove them as unused imports def _dummy( mnist_dataset, + mnist_example_input, + mnist_example_output, mnist_model, + residual_mnist_model, mnist_model_large, mnist_model_quantization, get_model_and_pruner, ): return ( mnist_dataset, + mnist_example_input, + mnist_example_output, mnist_model, + residual_mnist_model, mnist_model_large, mnist_model_quantization, get_model_and_pruner, @@ -51,3 +61,52 @@ def datadir(request): Directory for storing test data for latter inspection. """ return _datadir(request) + + +@pytest.fixture +def mock_name_main(monkeypatch): + monkeypatch.setattr(__import__("__main__"), "__name__", "__main__") + + +def pytest_addoption(parser): + """ + Adds command line option --runopt to the pytest parser + By default, evaluates to False. + If command line option passed, evaluates to True + """ + + parser.addoption("--runopt", action="store_true", default=False, help="run optional tests") + + +def pytest_configure(config): + """ + Adds info about optional marker to pytest config + """ + config.addinivalue_line("markers", "optional: mark test run as optional") + + +def marker_names(item): + """ + Returns set containing the name of each marker associated with + the given test item + """ + return set(m.name for m in item.iter_markers()) + + +def pytest_collection_modifyitems(config, items): + """ + Modifies the test items so that items marked optional are skipped + when the --runopt command line option is not provided. + Otherwise, will not perform any modifications. + """ + + # No modifications required + if config.getoption("--runopt"): + return + + skip_opt = pytest.mark.skip(reason="need --runopt option to run") + + for item in items: + markers = marker_names(item) + if "optional" in markers: + item.add_marker(skip_opt) diff --git a/coremltools/test/optimize/torch/conversion/__init__.py b/coremltools/test/optimize/torch/conversion/__init__.py new file mode 100644 index 000000000..5dc5e6747 --- /dev/null +++ b/coremltools/test/optimize/torch/conversion/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause diff --git a/coremltools/test/optimize/torch/conversion/conversion_utils.py b/coremltools/test/optimize/torch/conversion/conversion_utils.py new file mode 100644 index 000000000..c4447ea1c --- /dev/null +++ b/coremltools/test/optimize/torch/conversion/conversion_utils.py @@ -0,0 +1,109 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import sys + +import numpy as np +import torch + +import coremltools as ct + + +def convert_and_verify( + pytorch_model, + input_data, + input_as_shape=False, + pass_pipeline=None, + minimum_deployment_target=ct.target.iOS18, + expected_ops=None, +): + """ + Utility to: + 1) Convert a PyTorch model to coreml format + 2) Compare their outputs for numerical equivalence + 3) Verify the converted model contains expected ops + + Args: + input_as_shape: If true generates random input data with shape. + expected_ops: List of MIL ops expected in the converted model + Returns: + Converted coreml model + """ + if input_as_shape: + example_input = torch.rand(input_data) + else: + example_input = input_data + + # Generate converted model + coreml_model = get_converted_model( + pytorch_model, example_input, pass_pipeline, minimum_deployment_target + ) + assert coreml_model is not None + + # Verify converted model output matches torch model + verify_model_outputs(pytorch_model, coreml_model, example_input) + + # Verify desired ops are present + verify_ops(coreml_model, expected_ops) + + return coreml_model + + +def get_converted_model( + pytorch_model, + input_data, + pass_pipeline=None, + minimum_deployment_target=ct.target.iOS17, +): + """ + Utility that takes a PyTorch model and converts it to a coreml model + """ + traced_model = torch.jit.trace(pytorch_model, example_inputs=(input_data,)) + coreml_model = None + try: + coreml_model = ct.convert( + traced_model, + inputs=[ct.TensorType(shape=input_data.shape)], + pass_pipeline=pass_pipeline, + minimum_deployment_target=minimum_deployment_target, + ) + except Exception as err: + print(f"Conversion Error: {err}") + + return coreml_model + + +def verify_model_outputs(pytorch_model, coreml_model, input_value): + """ + This utility functions does the following checks: + (1) Verify the output of the coreml model has the same shape of the PyTorch model + (2) The PyTorch and coreml model have the same numerical outputs + """ + # Validate the output shape / type + ref_output = pytorch_model(input_value) + output = coreml_model._mil_program.functions["main"].outputs[0] + + assert ref_output.shape == output.shape + + # Cannot run predict on linux + if sys.platform == "linux": + return + + # Validate that the coreml model produces correct outputs + pytorch_model.eval() + ref_output_dict = pytorch_model(input_value) + coreml_input_value = {"input_1": input_value.detach().numpy()} + output_dict = coreml_model.predict(coreml_input_value) + for k, v in output_dict.items(): + np.testing.assert_allclose(v, output_dict[k]) + + +def verify_ops(coreml_model, expected_ops): + if not expected_ops: + return + + for op in expected_ops: + compressed_ops = coreml_model._mil_program.functions["main"].find_ops(op_type=op) + assert len(compressed_ops) >= 1 diff --git a/coremltools/test/optimize/torch/conversion/joint/__init__.py b/coremltools/test/optimize/torch/conversion/joint/__init__.py new file mode 100644 index 000000000..5dc5e6747 --- /dev/null +++ b/coremltools/test/optimize/torch/conversion/joint/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause diff --git a/coremltools/test/optimize/torch/conversion/joint/test_joint_compression_conversion.py b/coremltools/test/optimize/torch/conversion/joint/test_joint_compression_conversion.py new file mode 100644 index 000000000..fc415de1c --- /dev/null +++ b/coremltools/test/optimize/torch/conversion/joint/test_joint_compression_conversion.py @@ -0,0 +1,99 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import pytest + +ct = pytest.importorskip("coremltools") +import coremltools.test.optimize.torch.conversion.conversion_utils as util +from coremltools.optimize.torch.layerwise_compression import ( + LayerwiseCompressor, + LayerwiseCompressorConfig, +) +from coremltools.optimize.torch.pruning import MagnitudePruner, MagnitudePrunerConfig +from coremltools.optimize.torch.quantization import LinearQuantizer, LinearQuantizerConfig + + +@pytest.mark.xfail( + reason="rdar://129302570 (Fix conversion support for jointly compressed models using training time algorithms)", + strict=True, +) +@pytest.mark.skipif(ct.utils._macos_version() < (15, 0), reason="Only supported on macOS 15+") +def test_joint_pruning_quantization(mnist_model, mnist_example_input): + example_input = mnist_example_input + quant_config = LinearQuantizerConfig.from_dict( + { + "global_config": { + "milestones": [0, 0, 10, 10], + } + } + ) + prune_config = MagnitudePrunerConfig.from_dict({"global_config": {"target_sparsity": 0.5}}) + + quantizer = LinearQuantizer(mnist_model, quant_config) + quant_model = quantizer.prepare(example_inputs=(example_input,)) + + pruner = MagnitudePruner(quant_model, prune_config) + pruned_quant_model = pruner.prepare(inplace=True) + + quantizer.step() + pruner.step() + + # Do a forward pass for pruner mask to be applied + # Alternatively can set initial sparsity to target sparsity + pruned_quant_model(example_input) + + quantizer.finalize(inplace=True) + finalized_model = pruner.finalize(inplace=True) + + util.convert_and_verify( + finalized_model, + example_input, + pass_pipeline=ct.PassPipeline.DEFAULT_PRUNING, + minimum_deployment_target=ct.target.iOS18, + expected_ops=["constexpr_sparse_to_dense"], + ) + + +@pytest.mark.skipif(ct.utils._macos_version() < (15, 0), reason="Only supported on macOS 15+") +@pytest.mark.parametrize( + "config, expected_ops", + [ + pytest.param( + {"global_config": {"algorithm": "sparse_gpt"}}, + ["constexpr_sparse_to_dense"], + id="pruning", + ), + pytest.param( + {"global_config": {"algorithm": "sparse_gpt", "weight_dtype": "uint4"}}, + ["constexpr_sparse_to_dense", "constexpr_sparse_blockwise_shift_scale"], + id="joint_pruning_quantization", + ), + pytest.param( + { + "global_config": { + "algorithm": "sparse_gpt", + "weight_dtype": "uint4", + "enable_normal_float": True, + } + }, + ["constexpr_sparse_to_dense", "constexpr_lut_to_sparse"], + id="joint_pruning_palettization", + ), + ], +) +def test_sparsegpt(config, mnist_model, mnist_example_input, expected_ops): + compressor_config = LayerwiseCompressorConfig.from_dict(config) + compressor = LayerwiseCompressor(mnist_model, compressor_config) + + def calibration_loader(): + yield mnist_example_input + + compressed_model = compressor.compress(calibration_loader(), device="cpu") + + util.convert_and_verify( + compressed_model, + mnist_example_input, + expected_ops=expected_ops, + ) diff --git a/coremltools/test/optimize/torch/conversion/palettization/__init__.py b/coremltools/test/optimize/torch/conversion/palettization/__init__.py new file mode 100644 index 000000000..5dc5e6747 --- /dev/null +++ b/coremltools/test/optimize/torch/conversion/palettization/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause diff --git a/coremltools/test/optimize/torch/conversion/palettization/test_palettization_conversion.py b/coremltools/test/optimize/torch/conversion/palettization/test_palettization_conversion.py new file mode 100644 index 000000000..21ef5a10e --- /dev/null +++ b/coremltools/test/optimize/torch/conversion/palettization/test_palettization_conversion.py @@ -0,0 +1,399 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import pytest +import torch +import torch.nn as nn + +import coremltools.test.optimize.torch.conversion.conversion_utils as util +from coremltools.optimize.torch.palettization import ( + DKMPalettizer, + DKMPalettizerConfig, + PostTrainingPalettizer, + PostTrainingPalettizerConfig, + SKMPalettizer, + SKMPalettizerConfig, +) +from coremltools.test.optimize.torch.utils import count_unique_params + +ct = pytest.importorskip("coremltools") +cto = pytest.importorskip("coremltools.optimize") + + +# region DKMPalettizer +@pytest.mark.parametrize( + "config, lut_shape_map", + [ + pytest.param( + {"module_name_configs": {"conv2": {"n_bits": 4}}}, + {"conv2": (1, 1, 1, 1, 16, 1)}, + id="scalar_per_tensor", + ), + pytest.param( + {"module_name_configs": {"conv2": {"n_bits": 4, "cluster_dim": 4}}}, + {"conv2": (1, 1, 1, 1, 16, 4)}, + marks=pytest.mark.xfail( + reason="rdar://124474258 ([Compression] Support Vector Palettization in coremltools)" + ), + id="vector_per_tensor", + ), + ], +) +@pytest.mark.skipif(ct.utils._macos_version() < (15, 0), reason="Only supported on macOS 15+") +def test_dkm(mnist_model, mnist_example_input, config, lut_shape_map): + palettizer_config = DKMPalettizerConfig.from_dict(config) + palettizer = DKMPalettizer(mnist_model, palettizer_config) + palettized_model = get_palettized_model(palettizer) + + # Validate on torch model. + weight_sample = palettized_model.conv2.weight.detach() # per tensor + _n_bits = config["module_name_configs"]["conv2"]["n_bits"] + max_unique_values = 2**_n_bits + if "cluster_dim" in config["module_name_configs"]["conv2"]: + _cluster_dim = config["module_name_configs"]["conv2"]["cluster_dim"] + max_unique_values *= _cluster_dim + assert count_unique_params(torch.unique(weight_sample)) <= max_unique_values + + # Convert and validate on coreml model. + palettized_model_coreml = util.convert_and_verify( + palettized_model, + mnist_example_input, + pass_pipeline=ct.PassPipeline.DEFAULT_PALETTIZATION, + expected_ops=["constexpr_lut_to_dense"], + ) + verify_op_constexpr_lut_to_dense(palettized_model_coreml, lut_shape_map) +# endregion + + +# region SKM/PTP - per_tensor +@pytest.mark.parametrize( + "config, lut_shape_map", + [ + # Exclude testing for 8/6 bits since all ops in MNIST get skipped for 8/6-bit palettization. + pytest.param( + { + "global_config": {"n_bits": 4, "granularity": "per_tensor"}, + }, + { + "conv1": (1, 1, 1, 1, 16, 1), + "conv2": (1, 1, 1, 1, 16, 1), + "dense1": (1, 1, 16, 1), + "dense2": (1, 1, 16, 1), + }, + id="4bits", + ), + pytest.param( + { + "global_config": {"n_bits": 2, "granularity": "per_tensor"}, + }, + { + "conv1": (1, 1, 1, 1, 4, 1), + "conv2": (1, 1, 1, 1, 4, 1), + "dense1": (1, 1, 4, 1), + "dense2": (1, 1, 4, 1), + }, + id="2bits", + ), + ], +) +@pytest.mark.skip( + "rdar://128875026 ([Compression] Per-channel post training palettization conversion throwing a SIGABORT)" +) +@pytest.mark.parametrize("algorithm", ["SKM", "PTP"]) +def test_post_training_palettization_per_tensor( + mnist_model, + mnist_example_input, + mnist_example_output, + config, + lut_shape_map, + algorithm, +): + compressed_model = get_compressed_model( + algorithm, mnist_model, mnist_example_input, mnist_example_output, config + ) + + weight_sample = compressed_model.conv1.weight.detach() # per tensor + + # Validate on torch model. + _n_bits = config["global_config"]["n_bits"] + max_unique_values = 2**_n_bits + assert count_unique_params(torch.unique(weight_sample)) <= max_unique_values + + # Convert and validate on coreml model. + compressed_model_coreml = util.convert_and_verify( + compressed_model, + mnist_example_input, + expected_ops=["constexpr_lut_to_dense"], + ) + verify_op_constexpr_lut_to_dense(compressed_model_coreml, lut_shape_map) +# endregion + + +# region SKM/PTP - per_channel_scale +@pytest.mark.parametrize( + "config, lut_shape_map", + [ + # Exclude testing for 8/6 bits since all ops in MNIST get skipped for 8/6-bit palettization. + pytest.param( + { + "global_config": { + "n_bits": 4, + "granularity": "per_grouped_channel", + "group_size": 1, + "enable_per_channel_scale": True, + }, + }, + { + "conv1": (32, 1, 1, 1, 16, 1), + "conv2": (64, 1, 1, 1, 16, 1), + "dense1": (1024, 1, 16, 1), + "dense2": (10, 1, 16, 1), + }, + id="4bits", + ), + pytest.param( + { + "global_config": { + "n_bits": 2, + "granularity": "per_grouped_channel", + "group_size": 1, + "enable_per_channel_scale": True, + }, + }, + { + "conv1": (32, 1, 1, 1, 4, 1), + "conv2": (64, 1, 1, 1, 4, 1), + "dense1": (1024, 1, 4, 1), + "dense2": (10, 1, 4, 1), + }, + id="2bits", + ), + ], +) +@pytest.mark.skip( + "rdar://128875026 ([Compression] Per-channel post training palettization conversion throwing a SIGABORT)" +) +@pytest.mark.parametrize("algorithm", ["SKM", "PTP"]) +def test_post_training_palettization_per_channel_scale( + mnist_model, + mnist_example_input, + mnist_example_output, + config, + lut_shape_map, + algorithm, +): + compressed_model = get_compressed_model( + algorithm, mnist_model, mnist_example_input, mnist_example_output, config + ) + + # Validate on torch model. + for i in range(32): + weight_sample = compressed_model.conv1.weight[i].detach() # per channel + _n_bits = config["global_config"]["n_bits"] + max_unique_values = 2**_n_bits + assert count_unique_params(torch.unique(weight_sample)) <= max_unique_values + + compressed_model_coreml = util.convert_and_verify( + compressed_model, + mnist_example_input, + expected_ops=["constexpr_lut_to_dense"], + ) + verify_op_constexpr_lut_to_dense(compressed_model_coreml, lut_shape_map) +# endregion + + +# region SKM/PTP - grouped_channelwise +@pytest.mark.parametrize( + "config, lut_shape_map", + [ + pytest.param( + { + "global_config": { + "n_bits": 4, + "granularity": "per_grouped_channel", + "group_size": 16, + "channel_axis": 0, + }, + }, + { + "conv1": (2, 1, 1, 1, 16, 1), + "conv2": (4, 1, 1, 1, 16, 1), + "dense1": (64, 1, 16, 1), + }, + id="4bits_group_size_16_axis_0", + ), + pytest.param( + { + "global_config": { + "n_bits": 4, + "granularity": "per_grouped_channel", + "group_size": 16, + "channel_axis": 1, + }, + }, + { + "conv2": (1, 2, 1, 1, 16, 1), + "dense1": (1, 196, 16, 1), + "dense2": (1, 64, 16, 1), + }, + id="4bits_group_size_16_axis_1", + ), + ], +) +@pytest.mark.skip( + "rdar://128875026 ([Compression] Per-channel post training palettization conversion throwing a SIGABORT)" +) +@pytest.mark.parametrize("algorithm", ["SKM", "PTP"]) +def test_post_training_palettization_grouped_channelwise( + mnist_model, + mnist_example_input, + mnist_example_output, + config, + lut_shape_map, + algorithm, +): + compressed_model = get_compressed_model( + algorithm, mnist_model, mnist_example_input, mnist_example_output, config + ) + + # Validate on torch model. + _group_size = config["global_config"]["group_size"] + _axis = config["global_config"]["channel_axis"] + + for i in range(0, _group_size, 32): + if _axis == 1: + # blocking along input channel axis + weight_sample = compressed_model.conv2.weight[:, i : i + _group_size].detach() + else: + # blocking along output channel axis + weight_sample = compressed_model.conv2.weight[i : i + _group_size].detach() + _n_bits = config["global_config"]["n_bits"] + max_unique_values = 2**_n_bits + assert count_unique_params(torch.unique(weight_sample)) <= max_unique_values + + compressed_model_coreml = util.convert_and_verify( + compressed_model, + mnist_example_input, + expected_ops=["constexpr_lut_to_dense"], + ) + verify_op_constexpr_lut_to_dense(compressed_model_coreml, lut_shape_map) + + +# endregion + + +# region PTP - vector +@pytest.mark.parametrize( + "config, lut_shape_map", + [ + pytest.param( + { + "module_name_configs": { + "conv2": { + "n_bits": 4, + "granularity": "per_tensor", + "cluster_dim": 4, + } + }, + }, + { + "conv2": (1, 1, 1, 1, 16, 4), + }, + marks=pytest.mark.xfail( + reason="rdar://124474258 ([Compression] Support Vector Palettization in coremltools)" + ), + id="4bits_vector_4", + ), + ], +) +@pytest.mark.skip( + "rdar://128875026 ([Compression] Per-channel post training palettization conversion throwing a SIGABORT)" +) +@pytest.mark.parametrize("algorithm", ["PTP"]) +def test_post_training_palettization_vector( + mnist_model, + mnist_example_input, + mnist_example_output, + config, + lut_shape_map, + algorithm, +): + compressed_model = get_compressed_model( + algorithm, mnist_model, mnist_example_input, mnist_example_output, config + ) + + # Validate on torch model. + _cluster_dim = config["module_name_configs"]["conv2"]["cluster_dim"] + weight_sample = compressed_model.conv2.weight.reshape(-1, _cluster_dim) + + _n_bits = config["module_name_configs"]["conv2"]["n_bits"] + max_unique_values = 2**_n_bits + assert len(torch.unique(weight_sample, dim=0)) <= max_unique_values + + compressed_model_coreml = util.convert_and_verify( + compressed_model, + mnist_example_input, + expected_ops=["constexpr_lut_to_dense"], + ) + verify_op_constexpr_lut_to_dense(compressed_model_coreml, lut_shape_map) + + +# endregion + + +# region HelperMethods +def get_palettized_model(palettizer): + palettizer.prepare(inplace=True) + palettizer.step() + model = palettizer.finalize() + return model + + +def get_compressed_model(algorithm, mnist_model, mnist_example_input, mnist_example_output, config): + if algorithm == "SKM": + return get_compressed_model_for_skm( + mnist_model, mnist_example_input, mnist_example_output, config + ) + else: + return get_compressed_model_for_ptp(mnist_model, config) + + +# Get a compressed MNIST model with SKMPalettizer and calibration data. +def get_compressed_model_for_skm(mnist_model, mnist_example_input, mnist_example_output, config): + palettizer_config = SKMPalettizerConfig.from_dict(config) + + def calibration_loader(): + yield mnist_example_input + + def loss_fn(mnist_model, mnist_example_input): + out = mnist_model(mnist_example_input) + return nn.functional.mse_loss(out, mnist_example_output) + + compressor = SKMPalettizer(mnist_model, palettizer_config) + compressed_model = compressor.compress(dataloader=calibration_loader(), loss_fn=loss_fn) + return compressed_model + + +# Get a compressed MNIST model with PostTrainingPalettization +def get_compressed_model_for_ptp(mnist_model, config): + palettizer_config = PostTrainingPalettizerConfig.from_dict(config) + compressor = PostTrainingPalettizer(mnist_model, palettizer_config) + compressed_model = compressor.compress() + return compressed_model + + +def verify_op_constexpr_lut_to_dense(coreml_model, per_layer_lut_shape): + compressed_ops = coreml_model._mil_program.functions["main"].find_ops( + op_type="constexpr_lut_to_dense" + ) + assert len(compressed_ops) >= 1 + + # Verify if number of bits is correct. + # For palettization, it's hidden in the shape of LUT. + for compressed_op in compressed_ops: + layer_name = compressed_op.name.split("_weight")[0] + assert compressed_op.lut.shape == per_layer_lut_shape[layer_name] + +# endregion diff --git a/coremltools/test/optimize/torch/conversion/pruning/__init__.py b/coremltools/test/optimize/torch/conversion/pruning/__init__.py new file mode 100644 index 000000000..5dc5e6747 --- /dev/null +++ b/coremltools/test/optimize/torch/conversion/pruning/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause diff --git a/coremltools/test/optimize/torch/conversion/pruning/test_pruning_conversion.py b/coremltools/test/optimize/torch/conversion/pruning/test_pruning_conversion.py new file mode 100644 index 000000000..61502c39a --- /dev/null +++ b/coremltools/test/optimize/torch/conversion/pruning/test_pruning_conversion.py @@ -0,0 +1,88 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import pytest + +ct = pytest.importorskip("coremltools") +import coremltools.test.optimize.torch.conversion.conversion_utils as util +from coremltools.optimize.torch.pruning import MagnitudePruner, MagnitudePrunerConfig + + +# region MagnitudePruner +@pytest.mark.parametrize( + "config", + [ + pytest.param( + { + "global_config": { + "initial_sparsity": 0.5, + "target_sparsity": 0.5, + } + }, + id="unstructured_sparsity", + ), + pytest.param( + { + "global_config": { + "initial_sparsity": 0.5, + "target_sparsity": 0.5, + "block_size": 2, + } + }, + id="block_structured_sparsity", + ), + pytest.param( + { + "global_config": { + "initial_sparsity": 0.5, + "target_sparsity": 0.5, + "n_m_ratio": (1, 2), + } + }, + id="n_m_structured_sparsity", + ), + pytest.param( + { + "global_config": { + "initial_sparsity": 0.5, + "target_sparsity": 0.5, + "granularity": "per_channel", + } + }, + id="general_structured_sparsity", + ), + ], +) +@pytest.mark.skipif(ct.utils._macos_version() < (15, 0), reason="Only supported on macOS 15+") +def test_magnitude_pruner(config, mnist_model, mnist_example_input): + pruner_config = MagnitudePrunerConfig.from_dict(config) + pruner = MagnitudePruner(mnist_model, pruner_config) + pruned_model = get_pruned_model(pruner) + + util.convert_and_verify( + pruned_model, + mnist_example_input, + pass_pipeline=ct.PassPipeline.DEFAULT_PRUNING, + expected_ops=["constexpr_sparse_to_dense"], + ) + +# endregion + +# region GlobalUnstructuredPruner + +# endregion + +# region STRPruner + +# endregion + + +# region HelperMethods +def get_pruned_model(pruner): + pruner.prepare(inplace=True) + pruner.step() + return pruner.finalize() + +# endregion diff --git a/coremltools/test/optimize/torch/conversion/quantization/__init__.py b/coremltools/test/optimize/torch/conversion/quantization/__init__.py new file mode 100644 index 000000000..5dc5e6747 --- /dev/null +++ b/coremltools/test/optimize/torch/conversion/quantization/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause diff --git a/coremltools/test/optimize/torch/conversion/quantization/test_quantization_conversion.py b/coremltools/test/optimize/torch/conversion/quantization/test_quantization_conversion.py new file mode 100644 index 000000000..579b7642b --- /dev/null +++ b/coremltools/test/optimize/torch/conversion/quantization/test_quantization_conversion.py @@ -0,0 +1,173 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import pytest + +ct = pytest.importorskip("coremltools") +import torch.nn as nn + +import coremltools.test.optimize.torch.conversion.conversion_utils as util +from coremltools.optimize.torch.layerwise_compression import ( + LayerwiseCompressor, + LayerwiseCompressorConfig, +) +from coremltools.optimize.torch.quantization import LinearQuantizer, LinearQuantizerConfig + + +# region LinearQuantizer +@pytest.mark.parametrize( + "config", + [ + pytest.param( + {"global_config": {"quantization_scheme": "symmetric"}}, + id="symmetric_per_tensor", + ), + pytest.param({"global_config": {"quantization_scheme": "affine"}}, id="affine_per_tensor"), + pytest.param( + { + "global_config": { + "weight_dtype": "qint4", + "quantization_scheme": "symmetric", + } + }, + id="4bit_symmetric_per_tensor", + ), + pytest.param( + { + "global_config": { + "weight_dtype": "qint4", + "quantization_scheme": "affine", + } + }, + id="4bit_affine_per_tensor", + ), + ], +) +@pytest.mark.skipif(ct.utils._macos_version() < (15, 0), reason="Only supported on macOS 15+") +def test_linear_quantizer(config, mnist_model, mnist_example_input): + quantizer_config = LinearQuantizerConfig.from_dict(config) + quantizer = LinearQuantizer(mnist_model, quantizer_config) + quantized_model = get_quantized_model(quantizer, mnist_example_input) + + util.convert_and_verify( + quantized_model, + mnist_example_input, + expected_ops=["constexpr_blockwise_shift_scale"], + ) + + +# endregion + + +# region GPTQ +@pytest.mark.parametrize( + "config", + [ + pytest.param( + {"global_config": {"algorithm": "gptq", "weight_dtype": "uint4"}}, + id="4bit", + ), + pytest.param( + { + "module_type_configs": { + nn.Linear: { + "algorithm": "gptq", + "weight_dtype": "uint8", + "block_size": 32, + "granularity": "per_block", + } + } + }, + id="blockwise", + ), + pytest.param( + { + "module_type_configs": { + nn.Linear: { + "algorithm": "gptq", + "weight_dtype": "uint4", + "block_size": 32, + "granularity": "per_block", + } + } + }, + id="4bit_blockwise", + ), + ], +) +@pytest.mark.skipif(ct.utils._macos_version() < (15, 0), reason="Only supported on macOS 15+") +def test_gptq(config, mnist_model, mnist_example_input): + compressor_config = LayerwiseCompressorConfig.from_dict(config) + compressor = LayerwiseCompressor(mnist_model, compressor_config) + + def calibration_loader(): + yield mnist_example_input + + compressed_model = compressor.compress(calibration_loader(), device="cpu") + + util.convert_and_verify( + compressed_model, + mnist_example_input, + expected_ops=["constexpr_blockwise_shift_scale"], + ) +# endregion + + +# region PTQ +@pytest.mark.parametrize( + "config", + [ + pytest.param( + {"global_config": {"weight_dtype": "int4", "granularity": "per_tensor"}}, + id="4bit_per_tensor", + ), + pytest.param( + {"global_config": {"weight_dtype": "int4", "granularity": "per_channel"}}, + id="4bit_per_channel", + ), + pytest.param( + { + "global_config": { + "weight_dtype": "int4", + "granularity": "per_block", + "block_size": 16, + } + }, + id="4bit_per_block", + ), + ], +) +@pytest.mark.skipif(ct.utils._macos_version() < (15, 0), reason="Only supported on macOS 15+") +def test_ptq(mnist_model, mnist_example_input, config): + pytest.importorskip("coremltools.optimize.coreml._utils") + from coremltools.optimize.torch.quantization.post_training_quantization import ( + PostTrainingQuantizer, + PostTrainingQuantizerConfig, + ) + + model = mnist_model + ptq_config = PostTrainingQuantizerConfig.from_dict(config) + ptq = PostTrainingQuantizer(model, ptq_config) + compressed_model = ptq.compress() + + util.convert_and_verify( + compressed_model, + mnist_example_input, + expected_ops=["constexpr_blockwise_shift_scale"], + ) + + +# endregion + +# region HelperMethods + +def get_quantized_model(quantizer, example_input): + quantizer.prepare(example_inputs=(example_input,), inplace=True) + quantizer.step() + model = quantizer.finalize() + return model + + +# endregion diff --git a/coremltools/test/optimize/torch/layerwise_compression/__init__.py b/coremltools/test/optimize/torch/layerwise_compression/__init__.py new file mode 100644 index 000000000..5dc5e6747 --- /dev/null +++ b/coremltools/test/optimize/torch/layerwise_compression/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause diff --git a/coremltools/test/optimize/torch/layerwise_compression/test_algorithms.py b/coremltools/test/optimize/torch/layerwise_compression/test_algorithms.py new file mode 100644 index 000000000..8c1c48548 --- /dev/null +++ b/coremltools/test/optimize/torch/layerwise_compression/test_algorithms.py @@ -0,0 +1,285 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +from contextlib import nullcontext as does_not_raise + +import pytest +import torch +import torch.nn as nn +from attr import define, field, validators + +from coremltools.optimize.torch._utils.metadata_utils import ( + METADATA_VERSION, + METADATA_VERSION_BUFFER, + CompressionMetadata, + CompressionType, +) +from coremltools.optimize.torch.layerwise_compression import ( + LayerwiseCompressor, + LayerwiseCompressorConfig, +) +from coremltools.optimize.torch.layerwise_compression.algorithms import ( + GPTQ, + LayerwiseCompressionAlgorithmConfig, + ModuleGPTQConfig, + ModuleSparseGPTConfig, +) + + +@pytest.mark.parametrize( + "global_config_and_class", + [ + ({"algorithm": "gptq", "weight_dtype": "uint4"}, ModuleGPTQConfig), + ( + { + "algorithm": "sparse_gpt", + "weight_dtype": "uint8", + "target_sparsity": 0.25, + }, + ModuleSparseGPTConfig, + ), + ], +) +def test_obs_compression_algorithm_config(global_config_and_class): + """ + Test the registry-based configuration of the :py:class:`LayerwiseCompressionAlgorithmConfig` + using :py:class:`LayerwiseCompressorConfig` + """ + + global_config, class_type = global_config_and_class + # compress + config = LayerwiseCompressorConfig.from_dict( + { + "global_config": global_config, + "input_cacher": "default", + "calibration_nsamples": 128, + } + ) + algo = global_config.get("algorithm") + algo_class = LayerwiseCompressionAlgorithmConfig.get_class(algo) + assert algo_class == class_type + assert isinstance(config.global_config, class_type) + + +def test_custom_obs_compression_algorithm_config(): + @LayerwiseCompressionAlgorithmConfig.register("foo") + @define + class FooConfig(LayerwiseCompressionAlgorithmConfig): + bar: str = field(default=None, validator=validators.instance_of(str)) + algorithm: str = field(default="foo", validator=validators.instance_of(str)) + + config = LayerwiseCompressorConfig.from_dict( + {"global_config": {"algorithm": "foo", "bar": "baz"}} + ) + + assert isinstance(config.global_config, FooConfig) + assert config.global_config.bar == "baz" + + +@pytest.mark.parametrize( + "input_size, expectation", + [ + (512, does_not_raise()), + (1024, does_not_raise()), + (480, pytest.raises(ValueError)), + (960, pytest.raises(ValueError)), + ], +) +def test_block_size_validation_gptq(input_size, expectation): + """ + Test handling of block_size configuration for GPTQ algorithm + """ + config = ModuleGPTQConfig.from_dict( + { + "algorithm": "gptq", + "weight_dtype": "uint8", + "block_size": 128, + "granularity": "per_block", + } + ) + + _model = nn.Transformer(d_model=input_size, nhead=8) + layer = _model.encoder.layers.get_submodule("0.linear1") + + with expectation: + gptq = GPTQ(layer, config) + assert gptq is not None + + +@pytest.mark.parametrize( + "config", + [ + {"global_config": {"algorithm": "gptq", "weight_dtype": "uint4"}}, + { + "global_config": { + "algorithm": "gptq", + "weight_dtype": "uint8", + "block_size": 16, + "granularity": "per_block", + } + }, + { + "global_config": { + "algorithm": "gptq", + "weight_dtype": "uint4", + "enable_normal_float": True, + } + }, + { + "global_config": { + "algorithm": "gptq", + "weight_dtype": "uint3", + "enable_normal_float": True, + } + }, + ], +) +def test_gptq_metadata(config): + """ + Test registration of metadata buffers for GPTQ algorithm + """ + # Setup to get compressed model + model = nn.Sequential(nn.Linear(4096, 1024)) + compressor_config = LayerwiseCompressorConfig.from_dict(config) + compressor = LayerwiseCompressor(model, compressor_config) + + def calibration_loader(): + yield torch.rand(1, 4096) + + compressed_model = compressor.compress(calibration_loader(), device="cpu") + + # Extract registered metadata from state_dict + state_dict = compressed_model[0].state_dict() + metadata_dict = CompressionMetadata.from_state_dict(state_dict) + assert len(metadata_dict) == 1 + assert "weight" in metadata_dict + + # Verification + metadata = metadata_dict["weight"] + if compressor_config.global_config.enable_normal_float: + assert metadata.compression_type == [CompressionType.palettization.value] + assert metadata.lut.shape == ( + 1, + 1, + 2**compressor_config.global_config.weight_n_bits, + 1, + ) + assert metadata.palettization_scale.shape == (state_dict["weight"].shape[0], 1) + else: + assert metadata.compression_type == [CompressionType.quantization.value] + assert metadata.quantization_n_bits == compressor_config.global_config.weight_n_bits + assert metadata.zero_point.shape == metadata.quantization_scale.shape + assert metadata.quantization_scale.shape[0] == state_dict["weight"].shape[0] + block_size = compressor_config.global_config.block_size + if block_size is None: + assert metadata.quantization_scale.shape[1] == 1 + else: + assert ( + metadata.quantization_scale.shape[1] == state_dict["weight"].shape[1] / block_size + ) + + assert METADATA_VERSION_BUFFER in compressed_model.state_dict() + assert torch.equal(compressed_model.state_dict()[METADATA_VERSION_BUFFER], METADATA_VERSION) + + +@pytest.mark.parametrize( + "config", + [ + pytest.param({"global_config": {"algorithm": "sparse_gpt"}}, id="pruning"), + pytest.param( + {"global_config": {"algorithm": "sparse_gpt", "weight_dtype": "uint8"}}, + id="pruning_quantization", + ), + pytest.param( + { + "global_config": { + "algorithm": "sparse_gpt", + "weight_dtype": "uint4", + "enable_normal_float": True, + } + }, + id="pruning_palettization", + ), + ], +) +def test_sparse_gpt_metadata(config): + """ + Test registration of metadata buffers for GPTQ algorithm + """ + # Setup to get compressed model + model = nn.Sequential(nn.Linear(4096, 1024)) + compressor_config = LayerwiseCompressorConfig.from_dict(config) + compressor = LayerwiseCompressor(model, compressor_config) + + def calibration_loader(): + yield torch.rand(1, 4096) + + compressed_model = compressor.compress(calibration_loader(), device="cpu") + + # Extract registered metadata from state_dict + state_dict = compressed_model[0].state_dict() + metadata_dict = CompressionMetadata.from_state_dict(state_dict) + assert len(metadata_dict) == 1 + assert "weight" in metadata_dict + + # Verification + metadata = metadata_dict["weight"] + if compressor_config.global_config.enable_normal_float: + assert metadata.compression_type == [ + CompressionType.pruning.value, + CompressionType.palettization.value, + ] + assert metadata.lut.shape == ( + 1, + 1, + 2**compressor_config.global_config.weight_n_bits, + 1, + ) + assert metadata.palettization_scale.shape == (state_dict["weight"].shape[0], 1) + elif ( + compressor_config.global_config.weight_n_bits is not None + and compressor_config.global_config.weight_n_bits < 16 + ): + assert metadata.compression_type == [ + CompressionType.pruning.value, + CompressionType.quantization.value, + ] + assert metadata.quantization_n_bits == compressor_config.global_config.weight_n_bits + assert metadata.zero_point.shape == metadata.quantization_scale.shape + + assert METADATA_VERSION_BUFFER in compressed_model.state_dict() + assert torch.equal(compressed_model.state_dict()[METADATA_VERSION_BUFFER], METADATA_VERSION) + + +@pytest.mark.parametrize( + "config", + [ + { + "global_config": { + "algorithm": "gptq", + "weight_dtype": "uint8", + "block_size": 16, + "granularity": "per_block", + } + }, + { + "global_config": { + "algorithm": "gptq", + "weight_dtype": "uint8", + "block_size": None, + "granularity": "per_block", + } + }, + ], +) +def test_gptq_block_size_configs(config): + model = nn.Sequential(nn.Linear(4096, 1024)) + compressor_config = LayerwiseCompressorConfig.from_dict(config) + compressor = LayerwiseCompressor(model, compressor_config) + + def calibration_loader(): + yield torch.rand(1, 4096) + + compressed_model = compressor.compress(calibration_loader(), device="cpu") diff --git a/coremltools/test/optimize/torch/models/__init__.py b/coremltools/test/optimize/torch/models/__init__.py index 25c7d28c5..5dc5e6747 100644 --- a/coremltools/test/optimize/torch/models/__init__.py +++ b/coremltools/test/optimize/torch/models/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause diff --git a/coremltools/test/optimize/torch/models/mnist.py b/coremltools/test/optimize/torch/models/mnist.py index c7679c49e..bba511e9c 100644 --- a/coremltools/test/optimize/torch/models/mnist.py +++ b/coremltools/test/optimize/torch/models/mnist.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause @@ -6,13 +6,15 @@ # type: ignore import os from collections import OrderedDict -from coremltools.test.optimize.torch.utils import test_data_path import pytest +import torch import torch.nn as nn from filelock import FileLock from torchvision import datasets, transforms +from coremltools.test.optimize.torch.utils import test_data_path + # IMPORTANT: DO NOT import these fixtures in your tests. # That leads pytest to run the fixtures (even session-scoped) multiple times. # These have been imported into conftest.py, which makes them available for all @@ -22,22 +24,37 @@ num_classes = 10 +@pytest.fixture() +def mnist_example_input(): + return torch.rand(1, 1, 28, 28) + + +@pytest.fixture() +def mnist_example_output(): + return torch.rand(1, num_classes) + + @pytest.fixture def mnist_model(): - return nn.Sequential(OrderedDict([ - ('conv1', nn.Conv2d(1, 32, (5, 5), padding='same')), - ('relu1', nn.ReLU()), - ('pool1', nn.MaxPool2d(2, stride=2, padding=0)), - ('bn1', nn.BatchNorm2d(32, eps=0.001, momentum=0.01)), - ('conv2', nn.Conv2d(32, 64, (5, 5), padding='same')), - ('relu2', nn.ReLU()), - ('pool2', nn.MaxPool2d(2, stride=2, padding=0)), - ('flatten', nn.Flatten()), - ('dense1', nn.Linear(3136, 1024)), - ('relu3', nn.ReLU()), - ('dropout', nn.Dropout(p=0.4)), - ('dense2', nn.Linear(1024, num_classes)), - ('softmax', nn.LogSoftmax())])) + return nn.Sequential( + OrderedDict( + [ + ("conv1", nn.Conv2d(1, 32, (5, 5), padding=2)), + ("relu1", nn.ReLU()), + ("pool1", nn.MaxPool2d(2, stride=2, padding=0)), + ("bn1", nn.BatchNorm2d(32, eps=0.001, momentum=0.01)), + ("conv2", nn.Conv2d(32, 64, (5, 5), padding=2)), + ("relu2", nn.ReLU()), + ("pool2", nn.MaxPool2d(2, stride=2, padding=0)), + ("flatten", nn.Flatten()), + ("dense1", nn.Linear(3136, 1024)), + ("relu3", nn.ReLU()), + ("dropout", nn.Dropout(p=0.4)), + ("dense2", nn.Linear(1024, num_classes)), + ("softmax", nn.LogSoftmax()), + ] + ) + ) @pytest.fixture @@ -60,6 +77,61 @@ def mnist_model_quantization(): ('softmax', nn.LogSoftmax())])) +class Residual(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, inputs): + return self.module(inputs) + inputs + + +@pytest.fixture +def residual_mnist_model(): + return nn.Sequential( + OrderedDict( + [ + ( + "conv1", + nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False), + ), + ("bn1", nn.BatchNorm2d(64)), + ("relu1", nn.ReLU()), + ("pool1", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), + ( + "add1", + Residual( + nn.Sequential( + OrderedDict( + [ + ( + "conv2", + nn.Conv2d(64, 64, kernel_size=1, stride=1, bias=False), + ), + ("bn2", nn.BatchNorm2d(64)), + ("relu2", nn.ReLU()), + ( + "conv3", + nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), + ), + ("bn3", nn.BatchNorm2d(64)), + ] + ) + ) + ), + ), + ("relu3", nn.ReLU()), + ("flatten", nn.Flatten()), + ("dense1", nn.Linear(3136, 1024)), + ("relu4", nn.ReLU()), + ("dropout", nn.Dropout(p=0.4)), + ("dense2", nn.Linear(1024, num_classes)), + ("softmax", nn.LogSoftmax()), + ] + ) + ) + + @pytest.fixture def mnist_model_large(): """ @@ -85,6 +157,31 @@ def mnist_model_large(): ('softmax', nn.LogSoftmax())])) +def LeNet5(): + """ + Original LeNet5 model for MNIST with sigmoid activations. + """ + return nn.Sequential( + OrderedDict( + [ + ("conv1", nn.Conv2d(1, 6, 5, 1, 2)), + ("sigmoid1", nn.Sigmoid()), + ("pool1", nn.AvgPool2d(2, 2)), + ("conv2", nn.Conv2d(6, 16, 5, 1, 0)), + ("sigmoid2", nn.Sigmoid()), + ("pool2", nn.AvgPool2d(2, 2)), + ("flatten", nn.Flatten()), + ("dense1", nn.Linear(5 * 5 * 16, 120)), + ("sigmoid3", nn.Sigmoid()), + ("dense2", nn.Linear(120, 84)), + ("sigmoid4", nn.Sigmoid()), + ("dense3", nn.Linear(84, num_classes)), + ("softmax", nn.LogSoftmax(dim=1)), + ] + ) + ) + + @pytest.fixture(scope="session") def mnist_dataset(): transform = transforms.Compose([ diff --git a/coremltools/test/optimize/torch/palettization/__init__.py b/coremltools/test/optimize/torch/palettization/__init__.py index 25c7d28c5..5dc5e6747 100644 --- a/coremltools/test/optimize/torch/palettization/__init__.py +++ b/coremltools/test/optimize/torch/palettization/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause diff --git a/coremltools/test/optimize/torch/palettization/palettization_utils.py b/coremltools/test/optimize/torch/palettization/palettization_utils.py index 889e93dcb..13db92cf4 100644 --- a/coremltools/test/optimize/torch/palettization/palettization_utils.py +++ b/coremltools/test/optimize/torch/palettization/palettization_utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause @@ -12,9 +12,14 @@ def _assert_changes_post_attach(module, n_bits, cluster_dim): assert module.qconfig.weight.p.keywords["cluster_dim"] == cluster_dim -def _assert_changes_post_prepare(original_module, palettized_module, n_bits, cluster_dim, kmeans_max_iter): - assert type(palettized_module) == quantization_mappings.DEFAULT_QAT_MODULE_MAPPINGS[type(original_module)] - assert palettized_module.weight_fake_quant.n_clusters[0] == 2 ** n_bits +def _assert_changes_post_prepare( + original_module, palettized_module, n_bits, cluster_dim, kmeans_max_iter +): + assert ( + type(palettized_module) + == quantization_mappings.DEFAULT_QAT_MODULE_MAPPINGS[type(original_module)] + ) + assert palettized_module.weight_fake_quant.n_clusters == 2**n_bits assert palettized_module.weight_fake_quant.cluster_dim == cluster_dim assert palettized_module.weight_fake_quant.kmeans_max_iter == kmeans_max_iter diff --git a/coremltools/test/optimize/torch/palettization/test_palettization_api.py b/coremltools/test/optimize/torch/palettization/test_palettization_api.py index 5f377fa15..689d3b921 100644 --- a/coremltools/test/optimize/torch/palettization/test_palettization_api.py +++ b/coremltools/test/optimize/torch/palettization/test_palettization_api.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause @@ -7,10 +7,14 @@ import pytest import torch -import torch.functional as F import torch.nn as nn +import torch.nn.functional as F -from coremltools.optimize.torch.palettization import DKMPalettizer, DKMPalettizerConfig +from coremltools.optimize.torch.palettization import ( + DKMPalettizer, + DKMPalettizerConfig, + ModuleDKMPalettizerConfig, +) from coremltools.optimize.torch.palettization.palettization_config import ( DEFAULT_PALETTIZATION_SCHEME, ) @@ -19,6 +23,16 @@ _assert_changes_post_prepare, ) +REGEX_YAML = """ +module_name_configs: + conv\d+: + - n_bits: 4 + weight_threshold: 400 + palett_tau: 0.000004 + - n_bits: 2 + weight_threshold: 1000 + palett_tau: 0.000004 +""" def _create_simple_model(): class Net(nn.Module): @@ -75,6 +89,82 @@ def test_inplace_false_attach_config(simple_model): ) +def test_empty_dict_for_config(simple_model): + ## This test should behave the same as that when a None config is passed to DKMPalettizer + config = DKMPalettizerConfig.from_dict({}) + palettizer = DKMPalettizer(simple_model, config) + prepared_model = palettizer.prepare() + + assert not hasattr(simple_model.conv1, "qconfig") + assert not hasattr(simple_model.conv2, "qconfig") + assert not hasattr(simple_model.fc1, "qconfig") + assert not hasattr(simple_model.fc2, "qconfig") + assert not hasattr(simple_model.fc3, "qconfig") + + _assert_changes_post_attach( + prepared_model.conv2, + DEFAULT_PALETTIZATION_SCHEME[type(simple_model.conv2)]["n_bits"], + DEFAULT_PALETTIZATION_SCHEME[type(simple_model.conv2)]["cluster_dim"], + ) + _assert_changes_post_attach( + prepared_model.fc1, + DEFAULT_PALETTIZATION_SCHEME[type(simple_model.fc1)]["n_bits"], + DEFAULT_PALETTIZATION_SCHEME[type(simple_model.fc1)]["cluster_dim"], + ) + _assert_changes_post_attach( + prepared_model.fc2, + DEFAULT_PALETTIZATION_SCHEME[type(simple_model.fc2)]["n_bits"], + DEFAULT_PALETTIZATION_SCHEME[type(simple_model.fc2)]["cluster_dim"], + ) + + +@pytest.fixture(scope="session") +def test_empty_yaml_for_config(simple_model, tmp_path_factory): + ## This test should behave the same as that when a None config is passed to DKMPalettizer + fname = tmp_path_factory.mktemp("test_configs") / "empty_config.yaml" + with open(fname, "w") as file: + file.write("\n") + config = DKMPalettizerConfig.from_yaml(fname) + palettizer = DKMPalettizer(simple_model, config) + prepared_model = palettizer.prepare() + + assert not hasattr(simple_model.conv1, "qconfig") + assert not hasattr(simple_model.conv2, "qconfig") + assert not hasattr(simple_model.fc1, "qconfig") + assert not hasattr(simple_model.fc2, "qconfig") + assert not hasattr(simple_model.fc3, "qconfig") + + _assert_changes_post_attach( + prepared_model.conv2, + DEFAULT_PALETTIZATION_SCHEME[type(simple_model.conv2)]["n_bits"], + DEFAULT_PALETTIZATION_SCHEME[type(simple_model.conv2)]["cluster_dim"], + ) + _assert_changes_post_attach( + prepared_model.fc1, + DEFAULT_PALETTIZATION_SCHEME[type(simple_model.fc1)]["n_bits"], + DEFAULT_PALETTIZATION_SCHEME[type(simple_model.fc1)]["cluster_dim"], + ) + _assert_changes_post_attach( + prepared_model.fc2, + DEFAULT_PALETTIZATION_SCHEME[type(simple_model.fc2)]["n_bits"], + DEFAULT_PALETTIZATION_SCHEME[type(simple_model.fc2)]["cluster_dim"], + ) + + +@pytest.fixture(scope="session") +def test_regex_module_name_configs(simple_model, tmp_path_factory): + fname = tmp_path_factory.mktemp("test_configs") / "regex_config.yaml" + with open(fname, "w") as file: + file.write(REGEX_YAML) + config = DKMPalettizerConfig.from_yaml(fname) + palettizer = DKMPalettizer(simple_model, config) + palettizer.prepare(inplace=True) + + assert hasattr(simple_model.fc1, "qconfig") and simple_model.fc1.qconfig is None + _assert_changes_post_attach(simple_model.conv1, 4, 1) + _assert_changes_post_attach(simple_model.conv2, 2, 1) + + def test_attach_config_simple_model_uniform_palettization_config(simple_model): config = DKMPalettizerConfig.from_dict({"global_config": {"n_bits": 4}}) palettizer = DKMPalettizer(simple_model, config) @@ -275,7 +365,6 @@ def test_inplace_true_prepare_palettizer(simple_model): ) - def test_prepare_palettizer_simple_model_custom_palettization_config_milestone_1(simple_model): custom_config = {nn.Conv2d: {"n_bits": 2, "cluster_dim": 2, "kmeans_max_iter": 4, "milestone": 1}, nn.Linear: {"n_bits": 4, "cluster_dim": 1, "kmeans_max_iter": 5, "milestone": 1}} @@ -561,3 +650,26 @@ def test_inplace_true_prepare_palettizer(simple_model): custom_config[nn.Linear]["cluster_dim"], custom_config[nn.Linear]["kmeans_max_iter"], ) + + +def test_quantize_activations_flag(simple_model): + config = DKMPalettizerConfig.from_dict( + {"global_config": {"n_bits": 2, "cluster_dim": 1, "quantize_activations": True}} + ) + + palettizer = DKMPalettizer(simple_model, config) + + palettizer.prepare() + for _ in range(3): + palettizer.step() + + assert not isinstance(palettizer._model.conv2.activation_post_process, torch.nn.Identity) + + +def test_deprecated_api(): + with pytest.raises(DeprecationWarning): + config = DKMPalettizerConfig.from_dict({"global_config": {"partition_size": 100}}) + + config = DKMPalettizerConfig(global_config=ModuleDKMPalettizerConfig()) + with pytest.raises(DeprecationWarning): + config.global_config.partition_size = 100 diff --git a/coremltools/test/optimize/torch/palettization/test_palettizer.py b/coremltools/test/optimize/torch/palettization/test_palettizer.py new file mode 100644 index 000000000..cc762aa2c --- /dev/null +++ b/coremltools/test/optimize/torch/palettization/test_palettizer.py @@ -0,0 +1,109 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import pytest +import torch + +from coremltools.optimize.torch.palettization import ( + DKMPalettizer, + DKMPalettizerConfig, + FakePalettize, + ModuleDKMPalettizerConfig, +) + + +@pytest.fixture +def palettizer_config(): + return DKMPalettizerConfig( + global_config=ModuleDKMPalettizerConfig(n_bits=4, cluster_dim=1, weight_threshold=0) + ) + + +@pytest.mark.parametrize( + "module", + [ + torch.nn.Conv1d(2, 10, (1,)), + torch.nn.Conv2d(2, 10, (2, 2)), + torch.nn.Conv3d(2, 10, (2, 2, 2)), + torch.nn.Linear(10, 20), + torch.nn.LayerNorm(10), + torch.nn.Embedding(10, 20), + ], +) +def test_fake_palettize_insertion_weighted_modules(module, palettizer_config): + wrapped_module = torch.nn.Sequential(module) + + palettizer = DKMPalettizer(wrapped_module, palettizer_config) + palettized_module = palettizer.prepare() + assert isinstance(palettized_module[0].weight_fake_quant, FakePalettize) + + +@pytest.mark.parametrize("kdim,vdim", [(None, None), (1, 1)]) +@pytest.mark.parametrize("batch_first", [True, False]) +def test_fake_palettize_insertion_multihead_attention(kdim, vdim, batch_first, palettizer_config): + attention_module = torch.nn.MultiheadAttention( + bias=True, + embed_dim=6, + num_heads=3, + add_bias_kv=True, + kdim=kdim, + vdim=vdim, + batch_first=batch_first, + ) + + class WrappedModule(torch.nn.Sequential): + def __init__(self, module): + super().__init__(module) + + def forward(self, query, key, value): + return self[0](query, key, value) + + wrapped_module = WrappedModule(attention_module) + + palettizer = DKMPalettizer(wrapped_module, palettizer_config) + palettized_module = palettizer.prepare(inplace=False) + palettizer.enable_fake_palett(True) + + query_shape = (2, 3, 6) + assert isinstance(palettized_module[0].out_proj.weight_fake_quant, FakePalettize) + assert palettized_module[0].out_proj.weight_fake_quant.fake_palett_enabled + if kdim is None and vdim is None: + assert isinstance(palettized_module[0].in_proj_weight_fake_quant, FakePalettize) + assert palettized_module[0].in_proj_weight_fake_quant.fake_palett_enabled + data_q = data_k = data_v = torch.randn(query_shape) + else: + assert isinstance(palettized_module[0].q_proj_weight_fake_quant, FakePalettize) + assert palettized_module[0].q_proj_weight_fake_quant.fake_palett_enabled + assert isinstance(palettized_module[0].k_proj_weight_fake_quant, FakePalettize) + assert palettized_module[0].k_proj_weight_fake_quant.fake_palett_enabled + assert isinstance(palettized_module[0].v_proj_weight_fake_quant, FakePalettize) + assert palettized_module[0].v_proj_weight_fake_quant.fake_palett_enabled + data_q = torch.randn(query_shape) + data_k = data_v = torch.randn(2, 3, 1) + + palettizer.enable_fake_palett(False) + output, _ = palettized_module(data_q, data_k, data_v) + if batch_first: + assert output.shape[0] == query_shape[0] + else: + assert output.shape[1] == query_shape[1] + palettizer.finalize() + assert torch.all(palettized_module[0].out_proj.bias == attention_module.out_proj.bias) + assert torch.all(palettized_module[0].in_proj_bias == attention_module.in_proj_bias) + assert torch.all(palettized_module[0].bias_k == attention_module.bias_k) + assert torch.all(palettized_module[0].bias_v == attention_module.bias_v) + # assert hasattr() + + +@pytest.mark.parametrize("module", [torch.nn.Conv1d(2, 10, (1,))]) +def test_fake_palettize_train_no_grad_fwd(module, palettizer_config): + wrapped_module = torch.nn.Sequential(module) + + palettizer = DKMPalettizer(wrapped_module, palettizer_config) + palettized_module = palettizer.prepare() + palettized_module.train() + palettizer.step() + with torch.no_grad(): + palettized_module(torch.randn(3, 2, 10)) diff --git a/coremltools/test/optimize/torch/palettization/test_post_training_palettization.py b/coremltools/test/optimize/torch/palettization/test_post_training_palettization.py new file mode 100644 index 000000000..151074757 --- /dev/null +++ b/coremltools/test/optimize/torch/palettization/test_post_training_palettization.py @@ -0,0 +1,229 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import copy + +import pytest +import torch +import torch.functional as F +import torch.nn as nn + +from coremltools.optimize.torch._utils.metadata_utils import CompressionMetadata +from coremltools.optimize.torch.palettization import ( + PostTrainingPalettizer, + PostTrainingPalettizerConfig, + SKMPalettizer, + SKMPalettizerConfig, +) + + +@pytest.fixture +def simple_model(): + class Net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = torch.flatten(x, 1) # flatten all dimensions except batch + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + return Net() + + +def test_no_config(simple_model): + # Would do a 4-bit kmeans for all supported modules after giving a warning + ptpalettizer = PostTrainingPalettizer(simple_model) + palettized_model = ptpalettizer.compress() + assert palettized_model.conv1.weight.unique().size()[0] == 16 + assert palettized_model.conv2.weight.unique().size()[0] == 16 + assert palettized_model.fc1.weight.unique().size()[0] == 16 + assert palettized_model.fc2.weight.unique().size()[0] == 16 + assert palettized_model.fc3.weight.unique().size()[0] == 16 + + +@pytest.mark.parametrize( + "config_dict,expected_output", + [ + ( + {"global_config": {"n_bits": 4}}, + ["==16", "==16", "==16", "==16", "==16"], + ), + ( + { + "module_name_configs": { + "conv1": {"n_bits": 4}, + "fc1": {"n_bits": 2}, + }, + }, + ["==16", ">16", "==4", ">4", ">4"], + ), + ( + { + "module_type_configs": { + nn.Conv2d: {"n_bits": 4}, + nn.Linear: {"n_bits": 2}, + }, + }, + ["==16", "==16", "==4", "==4", "==4"], + ), + ( + { + "module_type_configs": { + # Invalid cluster_dim gets ignored. + # Conv2d should be skipped + nn.Conv2d: {"n_bits": 4, "cluster_dim": 5}, + nn.Linear: {"n_bits": 2}, + }, + }, + [">16", ">16", "==4", "==4", "==4"], + ), + ], +) +def test_post_training_palettization_dict_config(simple_model, config_dict, expected_output): + dict_config = PostTrainingPalettizerConfig.from_dict(config_dict) + ptpalettizer = PostTrainingPalettizer(simple_model, dict_config) + palettized_model = ptpalettizer.compress() + i = 0 + for name, mod in palettized_model.named_modules(): + if hasattr(mod, "weight"): + assert eval(f"mod.weight.unique().size()[0] {expected_output[i]}") + i += 1 + + +@pytest.mark.parametrize( + "config_dict,expected_output", + [ + ( + { + "module_name_configs": { + "conv1": { + "n_bits": 4, + "granularity": "per_tensor", + "cluster_dim": 3, + }, + "conv2": { + "n_bits": 4, + "granularity": "per_tensor", + "cluster_dim": 4, + }, + "fc3": { + "n_bits": 2, + "granularity": "per_tensor", + "cluster_dim": 2, + }, + }, + }, + ["==16", ">16", "==4"], + ), + ], +) +def test_post_training_vector_palettization_dict_config(simple_model, config_dict, expected_output): + dict_config = PostTrainingPalettizerConfig.from_dict(config_dict) + ptpalettizer = PostTrainingPalettizer(simple_model, dict_config) + palettized_model = ptpalettizer.compress() + i = 0 + for name, mod in palettized_model.named_modules(): + # Only validate the layers that get palettized. + if name in config_dict["module_name_configs"] and hasattr(mod, "weight"): + _cluster_dim = config_dict["module_name_configs"][name]["cluster_dim"] + weight_reshaped = mod.weight.flatten(1).reshape(-1, _cluster_dim) + unique_vector = torch.unique(weight_reshaped, dim=0) + assert eval(f"len(unique_vector) {expected_output[i]}") + i += 1 + + +@pytest.mark.parametrize( + "config_dict", + [ + { + "n_bits": 4, + "granularity": "per_tensor", + }, + { + "n_bits": 4, + "granularity": "per_grouped_channel", + "group_size": 4, + }, + { + "n_bits": 4, + "cluster_dim": 3, + }, + { + "n_bits": 4, + "granularity": "per_grouped_channel", + "group_size": 4, + "enable_per_channel_scale": True, + }, + ], +) +@pytest.mark.parametrize( + "lut_dtype", + [torch.int8, torch.uint8], +) +@pytest.mark.parametrize( + "layer", + ["conv2", "fc2"], +) +def test_ptp_int_lut(simple_model, config_dict, lut_dtype, layer): + config_dict["lut_dtype"] = lut_dtype + module_config = {"module_name_configs": {layer: config_dict}} + config = PostTrainingPalettizerConfig.from_dict(module_config) + ptpalettizer = PostTrainingPalettizer(simple_model, config) + palettized_model = ptpalettizer.compress() + + submodule = palettized_model.get_submodule(layer) + metadata_dict = CompressionMetadata.from_state_dict(submodule.state_dict()) + metadata = metadata_dict["weight"] + assert metadata.quantization_n_bits == 8 + scale = metadata.quantization_scale + zp = metadata.zero_point + lut = metadata.lut + + if lut_dtype == torch.int8: + assert zp is None + lut_quant = lut / scale + assert torch.min(lut_quant).int() >= -127 + assert torch.max(lut_quant).int() <= 128 + else: + assert zp is not None + lut_quant = lut / scale + zp + assert torch.min(lut_quant).int() >= 0 + assert torch.max(lut_quant).int() <= 254 + + +def loss_fn(model, input): + out = model(input) + return nn.functional.mse_loss(out, torch.rand(1, 10)) + + +def test_compute_sensitivity_single_worker_mutability(mnist_model, mnist_example_input): + config = {"global_config": {"n_bits": 4}} + skm_config = SKMPalettizerConfig.from_dict(config) + palettizer = SKMPalettizer(mnist_model, skm_config) + + state_dict_before = copy.deepcopy(palettizer._model.state_dict()) + + def calibration_loader(): + yield mnist_example_input + + palettizer.compute_sensitivity( + dataloader=calibration_loader(), loss_fn=loss_fn, num_sensitivity_workers=1 + ) + + state_dict_after = palettizer._model.state_dict() + assert len(state_dict_before) == len(state_dict_after) + for key in state_dict_before: + assert torch.equal(state_dict_before[key], state_dict_after[key]) diff --git a/coremltools/test/optimize/torch/palettization/test_sensitive_k_means.py b/coremltools/test/optimize/torch/palettization/test_sensitive_k_means.py new file mode 100644 index 000000000..a73480ac5 --- /dev/null +++ b/coremltools/test/optimize/torch/palettization/test_sensitive_k_means.py @@ -0,0 +1,358 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +from collections import OrderedDict +from contextlib import nullcontext +from typing import Any, Dict +from unittest.mock import ANY, Mock, patch + +import pytest +import torch + +from coremltools.optimize.torch._utils.fsdp_utils import ( + FSDPAutoWrapPolicy, + ModuleWrapPolicy, + SizeBasedWrapPolicy, +) +from coremltools.optimize.torch._utils.k_means import KMeansConfig +from coremltools.optimize.torch.palettization.sensitive_k_means import ( + ModuleSKMPalettizerConfig, + SKMPalettizer, + SKMPalettizerConfig, +) + + +@pytest.mark.parametrize( + "auto_wrap_policy", + [ + ModuleWrapPolicy(module_classes=torch.nn.Linear), + SizeBasedWrapPolicy(min_num_params=1000), + None, + ], +) +@pytest.mark.parametrize("num_sensitivity_workers", [1, 8]) +@pytest.mark.parametrize("num_kmeans_workers", [1, 8]) +def test_fsdp_auto_wrap_policy_compress_call( + mocker, num_kmeans_workers, num_sensitivity_workers, auto_wrap_policy +): + """ + Test compress passes fsdp_auto_wrap_policy argument correctly to + compute_sensitivity method. + """ + mock_compute_sensitivity = Mock(return_value={"weight": None}) + + mocker.patch.object(SKMPalettizer, "compute_sensitivity", mock_compute_sensitivity) + mocker.patch("coremltools.optimize.torch.palettization.sensitive_k_means._ParallelKMeans") + mocker.patch("coremltools.optimize.torch.palettization.sensitive_k_means._SequentialKMeans") + + model = torch.nn.Linear(5, 10) + palettizer = SKMPalettizer(model) + + palettizer.compress( + num_sensitivity_workers=num_sensitivity_workers, + fsdp_auto_wrap_policy=auto_wrap_policy, + num_kmeans_workers=num_kmeans_workers, + ) + + mock_compute_sensitivity.assert_called_once_with( + None, + None, + None, + num_sensitivity_workers, + fsdp_auto_wrap_policy=auto_wrap_policy, + ) + + +@pytest.mark.parametrize( + "auto_wrap_policy", + [ + ModuleWrapPolicy(module_classes=torch.nn.Linear), + SizeBasedWrapPolicy(min_num_params=1000), + None, + ], +) +@pytest.mark.parametrize("num_sensitivity_workers", [1, 8]) +def test_fsdp_auto_wrap_policy_compute_sensitivity_call( + mocker, num_sensitivity_workers, auto_wrap_policy +): + """ + Test compute_sensitivity passes fsdp_auto_wrap_policy argument correctly to + impl methods + """ + model = torch.nn.Linear(5, 10) + + mocker.patch("coremltools.optimize.torch.palettization.sensitive_k_means._torch.save") + mocker.patch( + "coremltools.optimize.torch.palettization.sensitive_k_means._torch.load", + Mock(return_value=model.state_dict()), + ) + mocker.patch( + "coremltools.optimize.torch.palettization.sensitive_k_means._torch.cuda.is_available", + Mock(return_value=True), + ) + mock_ctx = Mock() + mocker.patch( + "coremltools.optimize.torch.palettization.sensitive_k_means._mp.get_context", + Mock(return_value=mock_ctx), + ) + mock_compute_sen_single_worker = Mock() + mocker.patch.object( + SKMPalettizer, + "_compute_sensitivity_impl_single_worker", + mock_compute_sen_single_worker, + ) + mock_dataset = Mock() + mocker.patch.object(SKMPalettizer, "_get_dataset", Mock(return_value=mock_dataset)) + mocker.patch.object(SKMPalettizer, "_process_sensitivity") + + palettizer = SKMPalettizer(model) + + dataloader = [] + loss_fn = lambda mod, dat: mod(dat) + + palettizer.compute_sensitivity( + dataloader=dataloader, + loss_fn=loss_fn, + sensitivity_path=None, + num_sensitivity_workers=num_sensitivity_workers, + fsdp_auto_wrap_policy=auto_wrap_policy, + ) + + if num_sensitivity_workers > 1: + for rank in range(num_sensitivity_workers): + mock_ctx.Process.assert_any_call( + target=palettizer._compute_sensitivity_impl_multiple_workers, + args=( + rank, + num_sensitivity_workers, + mock_dataset, + loss_fn, + None, + auto_wrap_policy, + ), + name=f"Process-{rank}", + ) + else: + mock_compute_sen_single_worker.assert_called_once_with(mock_dataset, loss_fn, None) + + +@pytest.mark.parametrize("auto_wrap_policy", [Mock(spec=FSDPAutoWrapPolicy), None]) +def test_fsdp_auto_wrap_policy_multi_worker_compute_sensitivity_call(mocker, auto_wrap_policy): + """ + Test _compute_sensitivity_impl_multiple_workers passes correct value of fsdp auto wrap policy + to FSDP call + """ + model = torch.nn.Linear(5, 10) + + mocker.patch("coremltools.optimize.torch.palettization.sensitive_k_means._torch") + mocker.patch("coremltools.optimize.torch.palettization.sensitive_k_means._ddp_setup") + mocker.patch( + "coremltools.optimize.torch.palettization.sensitive_k_means._is_leader", + Mock(return_value=True), + ) + mocker.patch.object( + SKMPalettizer, "_register_grad_square_hooks", Mock(return_value=nullcontext()) + ) + + if auto_wrap_policy is not None: + expected_auto_wrap_policy = Mock() + auto_wrap_policy.get_policy.return_value = expected_auto_wrap_policy + else: + expected_auto_wrap_policy = None + + with patch( + "coremltools.optimize.torch.palettization.sensitive_k_means._FSDP", autospec=True + ) as mock_fsdp: + mock_fsdp.state_dict_type.return_value = nullcontext() + + palettizer = SKMPalettizer(model) + + palettizer._compute_sensitivity_impl_multiple_workers( + rank=0, + num_workers=1, + dataset=[None], + loss_fn=Mock(), + sensitivity_path=None, + fsdp_auto_wrap_policy=auto_wrap_policy, + ) + + # test FSDP either gets None or output of get_policy method on the + # FSDPAutoWrapPolicy object + mock_fsdp.assert_called_with( + module=palettizer._model, + auto_wrap_policy=expected_auto_wrap_policy, + sharding_strategy=ANY, + use_orig_params=False, + device_id=ANY, + sync_module_states=True, + ) + + +@pytest.fixture() +def model_for_compression() -> torch.nn.Module: + return torch.nn.Sequential( + OrderedDict( + [ + ("modconv", torch.nn.Conv2d(3, 10, (3, 3))), + ("modlinear", torch.nn.Linear(2, 5)), + ("multihead", torch.nn.MultiheadAttention(10, 5)), + ("embedding", torch.nn.Embedding(100, 10)), + ] + ) + ) + + +@pytest.fixture() +def sensitvity_dict_for_compression() -> Dict[str, Any]: + return { + "modconv.weight": Mock(), + "modlinear.weight": Mock(), + "multihead.in_proj_weight": Mock(), + "multihead.out_proj.weight": Mock(), + "embedding.weight": Mock(), + } + + +@pytest.fixture() +def model_for_compression_custom_module() -> torch.nn.Module: + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter(data=torch.randn(5, 10)) + + return torch.nn.Sequential( + OrderedDict( + [ + ("modconv", torch.nn.Conv2d(3, 10, (3, 3))), + ("modlinear", torch.nn.Linear(2, 5)), + ("multihead", torch.nn.MultiheadAttention(10, 5)), + ("custom", MyModule()), + ] + ) + ) + + +@pytest.mark.parametrize( + "model,sensitivity_dict,config,kmeans_keys", + [ + (torch.nn.Linear(5, 10), {"weight": None}, None, {"": "weight"}), + ( + "model_for_compression", + "sensitvity_dict_for_compression", + SKMPalettizerConfig( + global_config=ModuleSKMPalettizerConfig(), + module_name_configs={ + "modconv": None, + }, + ), + { + "modlinear": "weight", + "multihead": "in_proj_weight", + "multihead.out_proj": "weight", + "embedding": "weight", + }, + ), + ( + "model_for_compression", + "sensitvity_dict_for_compression", + SKMPalettizerConfig( + global_config=ModuleSKMPalettizerConfig(), + module_name_configs={ + "mod.*": None, + }, + ), + { + "multihead": "in_proj_weight", + "multihead.out_proj": "weight", + "embedding": "weight", + }, + ), + ( + "model_for_compression", + "sensitvity_dict_for_compression", + SKMPalettizerConfig( + global_config=ModuleSKMPalettizerConfig(), + module_type_configs={torch.nn.Embedding: None}, + ), + { + "modconv": "weight", + "modlinear": "weight", + "multihead": "in_proj_weight", + "multihead.out_proj": "weight", + }, + ), + ( + "model_for_compression", + "sensitvity_dict_for_compression", + SKMPalettizerConfig( + global_config=ModuleSKMPalettizerConfig(), + module_type_configs={"MultiheadAttention": None}, + module_name_configs={"multihead.out_proj": None}, + ), + {"modconv": "weight", "modlinear": "weight", "embedding": "weight"}, + ), + ( + "model_for_compression_custom_module", + "sensitvity_dict_for_compression", + None, + { + "modconv": "weight", + "modlinear": "weight", + "multihead": "in_proj_weight", + "multihead.out_proj": "weight", + }, + ), + ], +) +@pytest.mark.parametrize("num_kmeans_workers", [1, 8]) +def test_compress_cluster_weights_call( + mocker, num_kmeans_workers, model, sensitivity_dict, config, kmeans_keys, request +): + """ + Test ParallelKMeans/SequentialKMeans are called with correct arguments + """ + if isinstance(model, str): + model = request.getfixturevalue(model) + if isinstance(sensitivity_dict, str): + sensitivity_dict = request.getfixturevalue(sensitivity_dict) + + mocker.patch.object(SKMPalettizer, "compute_sensitivity", Mock(return_value=sensitivity_dict)) + mock_parallel = mocker.patch( + "coremltools.optimize.torch.palettization.sensitive_k_means._ParallelKMeans" + ) + mock_sequential = mocker.patch( + "coremltools.optimize.torch.palettization.sensitive_k_means._SequentialKMeans" + ) + + palettizer = SKMPalettizer(model, config) + + palettizer.compress( + num_sensitivity_workers=1, + fsdp_auto_wrap_policy=None, + num_kmeans_workers=num_kmeans_workers, + ) + + k_means_config_dict = {} + for key, val in kmeans_keys.items(): + sensitivity_key = f"{key}.{val}" if len(key) > 0 else val + k_means_config_dict[key] = { + val: KMeansConfig( + n_bits=ModuleSKMPalettizerConfig().n_bits, + axis=0, + block_size=None, + importance=sensitivity_dict[sensitivity_key], + enable_per_channel_scale=ModuleSKMPalettizerConfig().enable_per_channel_scale, + ) + } + + if num_kmeans_workers > 1: + mock_parallel.cluster_weights.assert_called_once_with( + palettizer._model, k_means_config_dict, num_workers=num_kmeans_workers + ) + else: + mock_sequential.cluster_weights.assert_called_once_with( + palettizer._model, + k_means_config_dict, + ) diff --git a/coremltools/test/optimize/torch/pruning/__init__.py b/coremltools/test/optimize/torch/pruning/__init__.py index 25c7d28c5..5dc5e6747 100644 --- a/coremltools/test/optimize/torch/pruning/__init__.py +++ b/coremltools/test/optimize/torch/pruning/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause diff --git a/coremltools/test/optimize/torch/pruning/pruning_utils.py b/coremltools/test/optimize/torch/pruning/pruning_utils.py index ad50cfe41..3dab93895 100644 --- a/coremltools/test/optimize/torch/pruning/pruning_utils.py +++ b/coremltools/test/optimize/torch/pruning/pruning_utils.py @@ -1,16 +1,16 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause +import os + import numpy as np import torch -import torch.nn.functional as F -import os -image_size = 28 +import coremltools.test.optimize.torch.utils as utils + batch_size = 128 -num_classes = 10 def verify_global_pruning_amount(supported_modules, model, expected_sparsity): @@ -28,58 +28,27 @@ def verify_global_pruning_amount(supported_modules, model, expected_sparsity): np.testing.assert_allclose(actual_global_sparsity, expected_sparsity, atol=0.02) -def train_and_eval_model(model, mnist_dataset, pruner, num_epochs): - # setup data loaders - train, test = mnist_dataset - train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True) - test_loader = torch.utils.data.DataLoader(test, batch_size=batch_size) +def train_and_eval_model(model, dataset, pruner, num_epochs, pass_loss=False): + train_loader, test_loader = utils.setup_data_loaders(dataset, batch_size) - # train the model optimizer = torch.optim.Adam(model.parameters(), eps=1e-07, weight_decay=1e-4) - # train one epoch + # train the model for epoch in range(num_epochs): model.train() for batch_idx, (data, target) in enumerate(train_loader): - optimizer.zero_grad() - output = model(data) - loss = F.nll_loss(output, target) - loss.backward() - optimizer.step() - pruner.step() - if batch_idx % 100 == 0: - print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( - epoch, batch_idx * len(data), len(train_loader.dataset), - 100. * batch_idx / len(train_loader), loss.item())) - # if not isinstance(pruner, GlobalChannelPruner): - # print(pruner.get_submodule_sparsity_summaries()) - - accuracy = eval_model(model, test_loader) - return accuracy - + loss = utils.train_step(model, optimizer, train_loader, data, target, batch_idx, epoch) + if pass_loss: + pruner.step(epoch, loss) + else: + pruner.step() -def eval_model(model, test_loader): - model.eval() - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - output = model(data) - test_loss += F.nll_loss(output, target, reduction='sum').item() - pred = output.argmax(dim=1, keepdim=True) - correct += pred.eq(target.view_as(pred)).sum().item() - - test_loss /= len(test_loader.dataset) - accuracy = 100. * correct / len(test_loader.dataset) - - print('\nTest set: Average loss: {:.4f}, Accuracy: {:.0f}%\n'.format( - test_loss, accuracy)) + accuracy = utils.eval_model(model, test_loader) return accuracy - def get_compression_ratio(model, pruner): # export the model - import coremltools_internal as ct + import coremltools as ct model.eval() pruner.finalize(inplace=True) diff --git a/coremltools/test/optimize/torch/pruning/test_magnitude_pruner.py b/coremltools/test/optimize/torch/pruning/test_magnitude_pruner.py index 8c8ce3896..bac869681 100644 --- a/coremltools/test/optimize/torch/pruning/test_magnitude_pruner.py +++ b/coremltools/test/optimize/torch/pruning/test_magnitude_pruner.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause @@ -9,7 +9,9 @@ import numpy as np import pytest import torch +import torch.nn as nn +from coremltools.optimize.torch._utils.metadata_utils import CompressionMetadata, CompressionType from coremltools.optimize.torch.pruning import ( MagnitudePruner, MagnitudePrunerConfig, @@ -23,7 +25,7 @@ def _zero_loss(x, y): def _mock_initializer(shape, dtype): - # Each output channel is (entirely) an integer, increasing. This makes it so + # Each output channel is (entirely) an integer, increaing. This makes it so # that we know what to expect from the LnPruner. output_channel_index = 0 num_output_channels = shape[output_channel_index] @@ -77,9 +79,8 @@ def sample_data(): return X, Y - -@pytest.mark.parametrize('out_channels', [17, 127]) -@pytest.mark.parametrize('block_size', [2, 3, 4]) +@pytest.mark.parametrize("out_channels", [17, 127]) +@pytest.mark.parametrize("block_size", [2, 3, 4]) def test_magnitude_pruner_nondivisible_block_size(out_channels, block_size): """ Test block sparsity when the number of channels is not divisible by block size @@ -118,31 +119,53 @@ def test_magnitude_pruner_nondivisible_block_size(out_channels, block_size): np.testing.assert_array_almost_equal(sparsity, 0.5, decimal=2) -def test_magnitude_pruner_incompatible_block_size(simple_module): +@pytest.mark.parametrize("out_channels", [8]) +@pytest.mark.parametrize("block_size", [5, 8, 9]) +def test_magnitude_pruner_morethanhalf_block_size(out_channels, block_size): """ - Test MagnitudePruner init failure when block_size is incompatibe with the number of channels - in the block sparsity dimension + Test block sparsity when the block size is greater than half the number of channels """ - # block size greater than half the number of channels - with pytest.raises(ValueError): - config = MagnitudePrunerConfig.from_dict( - {"global_config": - { - "scheduler": {"update_steps": [0, 1]}, - "block_size": 4 - }}, - ) - pruner = MagnitudePruner(simple_module, config) - pruner.prepare(inplace=True) - # block size greater than half the number of channels - config.global_config.block_size = 4 - with pytest.raises(ValueError): - pruner = MagnitudePruner(simple_module, config) - pruner.prepare(inplace=True) + conv2d = torch.nn.Conv2d( + in_channels=3, + out_channels=out_channels, + kernel_size=(3, 3), + bias=False, + groups=1, + ) + + weight_tensor = torch.rand_like(conv2d.weight) + weight_tensor[weight_tensor == 0] = 1.0 + conv2d.weight.data = weight_tensor + + config = MagnitudePrunerConfig.from_dict( + { + "global_config": { + "scheduler": {"update_steps": [1, 2]}, + "initial_sparsity": 0.0, + "target_sparsity": 0.5, + "block_size": block_size, + } + }, + ) + pruner = MagnitudePruner(conv2d, config) + conv2d = pruner.prepare() + + for _ in range(4): + pruner.step() + + if block_size > 1: + block_sparse_channels = out_channels - out_channels % block_size + for idx in range(0, block_sparse_channels, block_size): + for jdx in range(1, block_size): + assert torch.all(conv2d.weight_mask[idx] == conv2d.weight_mask[idx + jdx]) + + sparsity = conv2d.weight_mask.eq(0).sum() / conv2d.weight_mask.numel() + assert np.isclose(sparsity, 0.5, rtol=0.05) + @pytest.mark.parametrize( "options", - [("block_size", 2), ("initial_sparsity", 0.5), ("granularity", "per_channel")], + [("block_size", 2), ("granularity", "per_channel")], ) def test_magnitude_pruner_n_m_ratio_param_usage(options): param_name, val = options @@ -568,3 +591,33 @@ def test_nm_pruner_polynomial_scheduler(): model(data) for row in range(2): assert torch.count_nonzero(model.weight_mask[row]) == (7 - idx) + + +def test_compression_metadata(): + """ + Test that calling finalize on the module leads to compression metadata being added to the model + """ + model = nn.Sequential( + OrderedDict([("conv1", nn.Conv2d(3, 32, 3)), ("fc1", nn.Linear(32, 100))]) + ) + # Disable compression for Linear layer + config = MagnitudePrunerConfig().set_module_name("fc1", None) + pruner = MagnitudePruner(model, config) + pruner.prepare(inplace=True) + pruner.step() + pruner.finalize(inplace=True) + + # Verify metadata version is added to model + assert "_COREML_/metadata_version" in model.state_dict() + + # Verify compression metadata is added for conv1 + metadata_dict = CompressionMetadata.from_state_dict(model.conv1.state_dict()) + assert len(metadata_dict) == 1 + assert "weight" in metadata_dict + + metadata = metadata_dict["weight"] + assert metadata.compression_type == [CompressionType.pruning.value] + + # Verify no compression metadata is added for fc1 + metadata_dict = CompressionMetadata.from_state_dict(model.fc1.state_dict()) + assert len(metadata_dict) == 0 diff --git a/coremltools/test/optimize/torch/pruning/test_pruning_scheduler.py b/coremltools/test/optimize/torch/pruning/test_pruning_scheduler.py index 85827651f..07f21c143 100644 --- a/coremltools/test/optimize/torch/pruning/test_pruning_scheduler.py +++ b/coremltools/test/optimize/torch/pruning/test_pruning_scheduler.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause diff --git a/coremltools/test/optimize/torch/quantization/__init__.py b/coremltools/test/optimize/torch/quantization/__init__.py index 25c7d28c5..b9f43a673 100644 --- a/coremltools/test/optimize/torch/quantization/__init__.py +++ b/coremltools/test/optimize/torch/quantization/__init__.py @@ -1,4 +1,5 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + diff --git a/coremltools/test/optimize/torch/quantization/test_configure.py b/coremltools/test/optimize/torch/quantization/test_configure.py index a133733e8..6b042ac47 100644 --- a/coremltools/test/optimize/torch/quantization/test_configure.py +++ b/coremltools/test/optimize/torch/quantization/test_configure.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause @@ -20,7 +20,11 @@ from coremltools.optimize.torch.quantization import LinearQuantizer, LinearQuantizerConfig from coremltools.optimize.torch.quantization._backend_config import _mod_activations from coremltools.optimize.torch.quantization._qconfig_mapping import _QConfigMappingBuilder -from coremltools.optimize.torch.quantization._utils import find_module, is_activation_post_process +from coremltools.optimize.torch.quantization._utils import ( + find_module, + get_quant_range, + is_activation_post_process, +) from coremltools.optimize.torch.quantization.modules import fused_modules as _fused from coremltools.optimize.torch.quantization.modules import qat_modules as _qat from coremltools.optimize.torch.quantization.modules import quantized_modules as _quantized @@ -30,6 +34,7 @@ def get_configs_for_qscheme( activation_dtype=torch.quint8, weight_per_channel=True, + weight_dtype=torch.qint8, ) -> List[LinearQuantizerConfig]: return [ LinearQuantizerConfig.from_dict( @@ -37,6 +42,7 @@ def get_configs_for_qscheme( "global_config": { "quantization_scheme": QuantizationScheme.symmetric, "milestones": [0, 0, 10, 10], + "weight_dtype": weight_dtype, "activation_dtype": activation_dtype, "weight_per_channel": weight_per_channel, } @@ -47,6 +53,7 @@ def get_configs_for_qscheme( "global_config": { "quantization_scheme": QuantizationScheme.affine, "milestones": [0, 0, 10, 10], + "weight_dtype": weight_dtype, "activation_dtype": activation_dtype, "weight_per_channel": weight_per_channel, } @@ -57,15 +64,24 @@ def get_configs_for_qscheme( def quantize_model(model, data, config=None): quantizer = LinearQuantizer(model, config) - prepared_model = quantizer.prepare(example_inputs=data, inplace=False) + prepared_model = quantizer.prepare(example_inputs=(data,), inplace=False) quantizer.step() prepared_model(data) return prepared_model, quantizer +def _verify_quant_range(fake_quant, weight_n_bits, weight_dtype): + quant_min, quant_max = get_quant_range(n_bits=weight_n_bits, dtype=weight_dtype) + assert fake_quant.quant_min == quant_min + assert fake_quant.quant_max == quant_max + + @pytest.mark.parametrize( "config", - get_configs_for_qscheme() + get_configs_for_qscheme(weight_per_channel=False), + get_configs_for_qscheme() + + get_configs_for_qscheme(weight_per_channel=False) + + get_configs_for_qscheme(weight_dtype="qint8") + + get_configs_for_qscheme(weight_dtype=torch.quint8), ) def test_conv_relu_fusion(config): model = nn.Sequential( @@ -81,6 +97,11 @@ def test_conv_relu_fusion(config): prepared_model, quantizer = quantize_model(model, data, config) assert isinstance(prepared_model.conv, torch.nn.intrinsic.qat.ConvReLU2d) + _verify_quant_range( + prepared_model.conv.weight_fake_quant, + weight_n_bits=config.global_config.weight_n_bits, + weight_dtype=config.global_config.weight_dtype, + ) converted_model = quantizer.finalize(inplace=False) @@ -90,7 +111,10 @@ def test_conv_relu_fusion(config): @pytest.mark.parametrize( "config", - get_configs_for_qscheme() + get_configs_for_qscheme(weight_per_channel=False), + get_configs_for_qscheme() + + get_configs_for_qscheme(weight_per_channel=False) + + get_configs_for_qscheme(weight_dtype="qint4") + + get_configs_for_qscheme(weight_dtype="quint4"), ) @pytest.mark.parametrize("activation_fn", list(_mod_activations)) def test_conv_act_fusion(config, activation_fn): @@ -104,6 +128,11 @@ def test_conv_act_fusion(config, activation_fn): assert isinstance(prepared_model.conv, _qat.ConvAct2d) assert isinstance(prepared_model.conv.act, activation_fn) + _verify_quant_range( + prepared_model.conv.conv.weight_fake_quant, + weight_n_bits=config.global_config.weight_n_bits, + weight_dtype=config.global_config.weight_dtype, + ) converted_model = quantizer.finalize(inplace=False) @@ -113,7 +142,10 @@ def test_conv_act_fusion(config, activation_fn): @pytest.mark.parametrize( "config", - get_configs_for_qscheme() + get_configs_for_qscheme(weight_per_channel=False), + get_configs_for_qscheme() + + get_configs_for_qscheme(weight_per_channel=False) + + get_configs_for_qscheme(weight_dtype="qint4") + + get_configs_for_qscheme(weight_dtype="quint4"), ) def test_conv_bn_relu_fusion(config): model = nn.Sequential( @@ -130,6 +162,11 @@ def test_conv_bn_relu_fusion(config): prepared_model, quantizer = quantize_model(model, data, config) assert isinstance(prepared_model.conv, torch.nn.intrinsic.qat.ConvBnReLU2d) + _verify_quant_range( + prepared_model.conv.weight_fake_quant, + weight_n_bits=config.global_config.weight_n_bits, + weight_dtype=config.global_config.weight_dtype, + ) converted_model = quantizer.finalize(inplace=False) @@ -139,7 +176,10 @@ def test_conv_bn_relu_fusion(config): @pytest.mark.parametrize( "config", - get_configs_for_qscheme() + get_configs_for_qscheme(weight_per_channel=False), + get_configs_for_qscheme() + + get_configs_for_qscheme(weight_per_channel=False) + + get_configs_for_qscheme(weight_dtype="qint4") + + get_configs_for_qscheme(weight_dtype="quint4"), ) @pytest.mark.parametrize("activation_fn", list(_mod_activations)) def test_conv_bn_act_fusion(config, activation_fn): @@ -154,6 +194,11 @@ def test_conv_bn_act_fusion(config, activation_fn): assert isinstance(prepared_model.conv, _qat.ConvBnAct2d) assert isinstance(prepared_model.conv.act, activation_fn) + _verify_quant_range( + prepared_model.conv.conv.weight_fake_quant, + weight_n_bits=config.global_config.weight_n_bits, + weight_dtype=config.global_config.weight_dtype, + ) converted_model = quantizer.finalize(inplace=False) @@ -163,7 +208,10 @@ def test_conv_bn_act_fusion(config, activation_fn): @pytest.mark.parametrize( "config", - get_configs_for_qscheme() + get_configs_for_qscheme(weight_per_channel=False), + get_configs_for_qscheme() + + get_configs_for_qscheme(weight_per_channel=False) + + get_configs_for_qscheme(weight_dtype="qint4") + + get_configs_for_qscheme(weight_dtype="quint4"), ) def test_linear_relu_fusion(config): model = nn.Sequential(OrderedDict({"linear": nn.Linear(20, 100), "act": nn.ReLU()})) @@ -172,6 +220,11 @@ def test_linear_relu_fusion(config): prepared_model, quantizer = quantize_model(model, data, config) assert isinstance(prepared_model.linear, torch.nn.intrinsic.qat.LinearReLU) + _verify_quant_range( + prepared_model.linear.weight_fake_quant, + weight_n_bits=config.global_config.weight_n_bits, + weight_dtype=config.global_config.weight_dtype, + ) converted_model = quantizer.finalize(inplace=False) @@ -181,7 +234,10 @@ def test_linear_relu_fusion(config): @pytest.mark.parametrize( "config", - get_configs_for_qscheme() + get_configs_for_qscheme(weight_per_channel=False), + get_configs_for_qscheme() + + get_configs_for_qscheme(weight_per_channel=False) + + get_configs_for_qscheme(weight_dtype="qint4") + + get_configs_for_qscheme(weight_dtype="quint4"), ) @pytest.mark.parametrize("activation_fn", list(_mod_activations)) def test_linear_act_fusion(config, activation_fn): @@ -195,6 +251,11 @@ def test_linear_act_fusion(config, activation_fn): assert isinstance(prepared_model.linear, _qat.LinearAct) assert isinstance(prepared_model.linear.act, activation_fn) + _verify_quant_range( + prepared_model.linear.linear.weight_fake_quant, + weight_n_bits=config.global_config.weight_n_bits, + weight_dtype=config.global_config.weight_dtype, + ) converted_model = quantizer.finalize(inplace=False) @@ -288,7 +349,7 @@ def test_sequential_network_config_for_symmetric(mnist_model_quantization): """ Tests a sequential network with multiple modules is configured correctly. This network has layers where input and output observers are shared. We test - that for these layers, we set activation quantizer correctly for always affine layers + that for these layers, we set acitvation quantizer correctly for always affine layers """ data = torch.randn(1, 1, 28, 28) prepared_model, quantizer = quantize_model(mnist_model_quantization, data) diff --git a/coremltools/test/optimize/torch/quantization/test_post_training_quantization.py b/coremltools/test/optimize/torch/quantization/test_post_training_quantization.py new file mode 100644 index 000000000..8eff82beb --- /dev/null +++ b/coremltools/test/optimize/torch/quantization/test_post_training_quantization.py @@ -0,0 +1,235 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + + +import numpy as np +import pytest +import torch + +ct = pytest.importorskip("coremltools") +pytest.importorskip("coremltools.optimize.coreml._utils") + + +from coremltools.optimize.torch.optimization_config import QuantizationGranularity +from coremltools.optimize.torch.quantization import ( + PostTrainingQuantizer, + PostTrainingQuantizerConfig, + QuantizationScheme, +) + +np.random.seed(0) +torch.manual_seed(0) + + +def get_rmse(a, b): + return torch.norm(torch.abs(a - b)) + + +def get_atol_rtol(block_size, weight_n_bits): + if block_size is None: + block_size = 0 + if block_size == 1: + # With block_size == 1, the information loss is minimum. + atol, rtol = 1e-02, 1e-02 + elif weight_n_bits >= 4 and block_size < 3: + # When block size is small and nbits is large, the information loss is limited. + atol, rtol = 3e-02, 3e-02 + elif weight_n_bits <= 2 and block_size >= 2: + atol, rtol = 0.5, 0.5 + else: + atol, rtol = 0.4, 0.4 + return (atol, rtol) + + +def test_ptq_default_config(): + config = PostTrainingQuantizerConfig() + ptq = PostTrainingQuantizer(torch.nn.Identity(), config) + assert ptq is not None + assert config.global_config.block_size is None + assert config.global_config.weight_dtype == torch.int8 + assert config.global_config.quantization_scheme == QuantizationScheme.symmetric + assert config.global_config.weight_dtype == torch.int8 + assert config.global_config.granularity == QuantizationGranularity.per_channel + + +@pytest.mark.parametrize( + "module", + [ + torch.nn.Linear(10, 10), + torch.nn.Conv2d(10, 10, 3, 3), + torch.nn.MultiheadAttention( + bias=True, + embed_dim=6, + num_heads=3, + add_bias_kv=True, + kdim=1, + vdim=1, + ), + torch.nn.MultiheadAttention( + bias=True, + embed_dim=6, + num_heads=3, + add_bias_kv=True, + kdim=None, + vdim=None, + ), + ], +) +@pytest.mark.parametrize( + "granularity_block_size", + [ + ("per_channel", None), + ("per_tensor", None), + ("per_block", 2), + ("per_block", 5), + ("per_block", (2,)), + ("per_block", (5,)), + ("per_block", (5, 2)), + ("per_block", (2, 5)), + ], +) +@pytest.mark.parametrize("quantization_scheme", ["symmetric", "affine"]) +@pytest.mark.parametrize("weight_dtype", ["int8", "int4", "uint8", "uint4"]) +def test_ptq_compress_all_combinations( + module, + quantization_scheme, + granularity_block_size, + weight_dtype, +): + granularity, block_size = granularity_block_size + config = PostTrainingQuantizerConfig.from_dict( + { + "global_config": { + "quantization_scheme": quantization_scheme, + "granularity": granularity, + "weight_dtype": weight_dtype, + "block_size": block_size, + } + } + ) + ptq = PostTrainingQuantizer(module, config) + module = ptq.compress() + + +@pytest.mark.parametrize("quantization_scheme", ["symmetric", "affine"]) +@pytest.mark.parametrize( + "granularity_block_size", + [ + ("per_channel", None), + ("per_tensor", None), + ("per_block", 2), + ("per_block", 5), + ], +) +@pytest.mark.parametrize("weight_dtype", ["int4", "int8"]) +@pytest.mark.parametrize("module", [torch.nn.Conv2d(10, 10, 3, 3), torch.nn.Linear(10, 10)]) +def test_ptq_post_compress_conv_linear( + quantization_scheme, granularity_block_size, weight_dtype, module +): + granularity, block_size = granularity_block_size + orig_weight = module.weight.clone() + config = PostTrainingQuantizerConfig.from_dict( + { + "global_config": { + "weight_dtype": weight_dtype, + "granularity": granularity, + "block_size": block_size, + "quantization_scheme": quantization_scheme, + } + } + ) + ptq = PostTrainingQuantizer(module, config) + module = ptq.compress() + + assert hasattr(module, "_COREML_/weight/quantization_scale") + if quantization_scheme == "affine": + assert hasattr(module, "_COREML_/weight/zero_point") + assert not torch.equal(orig_weight, module.weight) + atol, rtol = get_atol_rtol(block_size, config.global_config.weight_n_bits) + np.testing.assert_allclose( + orig_weight.detach().numpy(), + module.weight.detach().numpy(), + atol=atol, + rtol=rtol, + ) + + +@pytest.mark.parametrize("quantization_scheme", ["symmetric", "affine"]) +@pytest.mark.parametrize( + "granularity_block_size", + [ + ("per_channel", None), + ("per_tensor", None), + ("per_block", 2), + ("per_block", 5), + ], +) +@pytest.mark.parametrize("weight_dtype", ["int4", "int8"]) +def test_ptq_post_compress_multihead( + quantization_scheme, + granularity_block_size, + weight_dtype, +): + granularity, block_size = granularity_block_size + module = torch.nn.MultiheadAttention( + bias=True, + embed_dim=10, + num_heads=10, + add_bias_kv=True, + kdim=None, + vdim=None, + ) + assert hasattr(module, "in_proj_weight") + assert hasattr(module.out_proj, "weight") + orig_in_proj_weight = module.in_proj_weight.clone() + orig_out_proj_weight = module.out_proj.weight.clone() + config = PostTrainingQuantizerConfig.from_dict( + { + "global_config": { + "weight_dtype": weight_dtype, + "quantization_scheme": quantization_scheme, + "granularity": granularity, + "block_size": block_size, + } + } + ) + ptq = PostTrainingQuantizer(module, config) + module = ptq.compress() + + assert hasattr(module, "_COREML_/in_proj_weight/quantization_scale") + assert hasattr(module.out_proj, "_COREML_/weight/quantization_scale") + if quantization_scheme == "affine": + assert hasattr(module, "_COREML_/in_proj_weight/zero_point") + assert hasattr(module.out_proj, "_COREML_/weight/zero_point") + + assert not torch.equal(orig_in_proj_weight, module.in_proj_weight) + assert not torch.equal(orig_out_proj_weight, module.out_proj.weight) + atol, rtol = get_atol_rtol(block_size, config.global_config.weight_n_bits) + np.testing.assert_allclose( + orig_in_proj_weight.detach().numpy(), + module.in_proj_weight.detach().numpy(), + atol=atol, + rtol=rtol, + ) + np.testing.assert_allclose( + orig_out_proj_weight.detach().numpy(), + module.out_proj.weight.detach().numpy(), + atol=atol, + rtol=rtol, + ) + + +def test_ptq_compression_metadata(): + config = PostTrainingQuantizerConfig() + ptq = PostTrainingQuantizer(torch.nn.Linear(10, 10), config) + model = ptq.compress() + + from coremltools.optimize.torch._utils.metadata_utils import CompressionType + + assert hasattr(model, "_COREML_/weight/compression_type") + assert torch.IntTensor([CompressionType.quantization.value]) == getattr( + model, "_COREML_/weight/compression_type" + ) + assert torch.IntTensor([8]) == getattr(model, "_COREML_/weight/quantization_n_bits") diff --git a/coremltools/test/optimize/torch/quantization/test_quantizer.py b/coremltools/test/optimize/torch/quantization/test_quantizer.py index 20177a3e0..a30eb6ffc 100644 --- a/coremltools/test/optimize/torch/quantization/test_quantizer.py +++ b/coremltools/test/optimize/torch/quantization/test_quantizer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause @@ -15,6 +15,7 @@ import torch.nn.quantized import torch.nn.quantized.modules.utils +from coremltools.optimize.torch._utils.metadata_utils import CompressionMetadata, CompressionType from coremltools.optimize.torch.quantization import ( LinearQuantizer, LinearQuantizerConfig, @@ -173,7 +174,7 @@ def test_activation_defaults(quantization_scheme): }} ) quantizer = LinearQuantizer(model, config) - model = quantizer.prepare(example_inputs=torch.randn(1, 1, 28, 28)) + model = quantizer.prepare(example_inputs=(torch.randn(1, 1, 28, 28),)) assert isinstance(model.conv, torch.nn.intrinsic.qat.ConvReLU2d) assert model.activation_post_process_0.dtype == torch.quint8 @@ -200,7 +201,7 @@ def test_quantizer_step_mechanism(quantization_scheme): }} ) quantizer = LinearQuantizer(model, config) - model = quantizer.prepare(example_inputs=torch.randn(1, 1, 28, 28)) + model = quantizer.prepare(example_inputs=(torch.randn(1, 1, 28, 28),)) assert not model.activation_post_process_0.observer_enabled assert not model.activation_post_process_0.fake_quant_enabled @@ -233,3 +234,118 @@ def test_quantizer_step_mechanism(quantization_scheme): assert model.activation_post_process_0.fake_quant_enabled assert not model.activation_post_process_1.observer_enabled assert model.activation_post_process_1.fake_quant_enabled + + +def test_preserved_attributes(): + """ + Test if methods and attributes on the model are preserved by passing + preserved_attributes to the config. + """ + + class MyModel(nn.Sequential): + def __init__(self): + super().__init__( + OrderedDict( + { + "conv": nn.Conv2d(1, 20, (3, 3)), + "bn": nn.BatchNorm2d(20), + "relu": nn.ReLU(), + } + ) + ) + self.conv.weight.data = torch.ones_like(self.conv.weight.data) + + def my_method(self): + return self.weight + torch.ones_like(self.weight) + + @property + def weight(self): + return ( + self.conv.weight + if hasattr(self.conv, "weight") + else self.conv.get_submodule("0").weight + ) + + preserved_attrs = ["key_1", "key_2", "my_method", "weight"] + + model = MyModel() + model.key_1 = 5 + model.key_2 = torch.tensor(5) + + config = LinearQuantizerConfig.from_dict( + { + "global_config": { + "milestones": [0, 3, 4, 5], + }, + "preserved_attributes": preserved_attrs, + } + ) + quantizer_1 = LinearQuantizer(model, LinearQuantizerConfig()) + prepared_model = quantizer_1.prepare(example_inputs=(torch.randn(1),), inplace=False) + for attr in preserved_attrs: + assert not hasattr(prepared_model, attr) + + quantizer_2 = LinearQuantizer(model, config) + prepared_model = quantizer_2.prepare(example_inputs=(torch.randn(1),), inplace=False) + for attr in preserved_attrs: + assert hasattr(prepared_model, attr) + assert torch.all( + prepared_model.my_method() == 2 * torch.ones_like(prepared_model.conv.weight.data) + ) + + quantizer_2.step() + prepared_model(torch.randn(2, 1, 28, 28)) + final_model = quantizer_2.finalize() + for attr in preserved_attrs: + assert hasattr(final_model, attr) + assert torch.all( + final_model.my_method() + == final_model.weight.data + torch.ones_like(prepared_model.weight.data) + ) + + +@pytest.mark.parametrize("dtype", ["qint4", "qint8"]) +@pytest.mark.parametrize("scheme", ["symmetric", "affine"]) +def test_compression_metadata(dtype, scheme): + """ + Test that calling finalize on the module leads to compression metadata being added to the model + """ + model = nn.Sequential( + OrderedDict([("conv1", nn.Conv2d(1, 20, 3)), ("fc1", nn.Linear(20, 100))]) + ) + config = LinearQuantizerConfig.from_dict( + { + "module_name_configs": { + "conv1": { + "weight_dtype": dtype, + "quantization_scheme": scheme, + }, + "fc1": None, + } + } + ) + quantizer = LinearQuantizer(model, config) + quantizer.prepare(inplace=True, example_inputs=(torch.randn(1, 1, 28, 28),)) + for _ in range(4): + quantizer.step() + model = quantizer.finalize() + + # Verify metadata version is added to model + assert "_COREML_/metadata_version" in model.state_dict() + + # Verify compression metadata is added for conv1 + metadata_dict = CompressionMetadata.from_state_dict(model.conv1.state_dict()) + assert len(metadata_dict) == 1 + assert "weight" in metadata_dict + + metadata = metadata_dict["weight"] + assert metadata.compression_type == [CompressionType.quantization.value] + assert metadata.quantization_n_bits == 4 if dtype == "qint4" else 8 + assert metadata.quantization_scale.shape == (20, 1) + assert metadata.zero_point.shape == (20, 1) + if scheme == "symmetric": + assert torch.all(metadata.zero_point == 0) + + # # Verify no compression metadata is added for fc1 + metadata_dict = CompressionMetadata.from_state_dict(model.fc1.state_dict()) + assert len(metadata_dict) == 0 diff --git a/coremltools/test/optimize/torch/quantization/test_utils.py b/coremltools/test/optimize/torch/quantization/test_utils.py new file mode 100644 index 000000000..3df9fe42e --- /dev/null +++ b/coremltools/test/optimize/torch/quantization/test_utils.py @@ -0,0 +1,39 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import pytest +import torch + +from coremltools.optimize.torch.quantization._utils import get_quant_range + + +@pytest.mark.parametrize("n_bits", list(range(2, 8))) +@pytest.mark.parametrize("dtype", [torch.quint8, torch.uint8, torch.qint8, torch.int8]) +def test_quant_range(dtype, n_bits): + quant_min, quant_max = get_quant_range(n_bits, dtype) + signed_expected_values = { + 2: [-2, 1], + 3: [-4, 3], + 4: [-8, 7], + 5: [-16, 15], + 6: [-32, 31], + 7: [-64, 63], + 8: [-128, 127], + } + unsigned_expected_values = { + 2: [0, 3], + 3: [0, 7], + 4: [0, 15], + 5: [0, 31], + 6: [0, 63], + 7: [0, 127], + 8: [0, 256], + } + if dtype in [torch.quint8, torch.uint8]: + assert quant_min == unsigned_expected_values[n_bits][0] + assert quant_max == unsigned_expected_values[n_bits][1] + else: + assert quant_min == signed_expected_values[n_bits][0] + assert quant_max == signed_expected_values[n_bits][1] diff --git a/coremltools/test/optimize/torch/smoke_test.py b/coremltools/test/optimize/torch/smoke_test.py new file mode 100644 index 000000000..25ce00878 --- /dev/null +++ b/coremltools/test/optimize/torch/smoke_test.py @@ -0,0 +1,31 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import torch + + +class TestSmokeTest: + def test_coremltools_optimize_torch_import(self): + import coremltools.optimize.torch + + def test_model_optimizations(self): + from coremltools.optimize.torch.palettization import DKMPalettizer, DKMPalettizerConfig + from coremltools.optimize.torch.pruning import MagnitudePruner, MagnitudePrunerConfig + from coremltools.optimize.torch.quantization import LinearQuantizer, LinearQuantizerConfig + + for OptCls, OptConfig, args in [ + (MagnitudePruner, MagnitudePrunerConfig, None), + (DKMPalettizer, DKMPalettizerConfig, None), + (LinearQuantizer, LinearQuantizerConfig, torch.randn(100)), + ]: + obj = OptCls(torch.nn.Identity(), OptConfig()) + obj.prepare(args) + obj.finalize() + + def test_model_conversion(self, mnist_model, mnist_example_input): + import coremltools.test.optimize.torch.conversion.conversion_utils as util + + converted_model = util.get_converted_model(mnist_model, mnist_example_input) + assert converted_model is not None diff --git a/coremltools/test/optimize/torch/test_api_surface.py b/coremltools/test/optimize/torch/test_api_surface.py index 8862c59ce..701afa093 100644 --- a/coremltools/test/optimize/torch/test_api_surface.py +++ b/coremltools/test/optimize/torch/test_api_surface.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause @@ -30,19 +30,30 @@ def test_top_level(self): "palettization", "pruning", "quantization", + "layerwise_compression", ] visible_modules = _get_visible_items(coremltools.optimize.torch) _check_visible_modules(visible_modules, expected) def test_base_model_optimizer_module(self): # coremltools.optimize.torch.base_model_optimizer.* - expected = ["BaseModelOptimizer"] + expected = [ + "BaseModelOptimizer", + "BaseTrainingTimeModelOptimizer", + "BasePostTrainingModelOptimizer", + "BaseDataCalibratedModelOptimizer", + ] visible_modules = _get_visible_items(coremltools.optimize.torch.base_model_optimizer) _check_visible_modules(visible_modules, expected) def test_optimization_config_module(self): # coremltools.optimize.torch.optimization_config.* - expected = ["ModuleOptimizationConfig", "OptimizationConfig"] + expected = [ + "PalettizationGranularity", + "QuantizationGranularity", + "ModuleOptimizationConfig", + "OptimizationConfig", + ] visible_modules = _get_visible_items(coremltools.optimize.torch.optimization_config) _check_visible_modules(visible_modules, expected) @@ -56,6 +67,14 @@ def test_palettization_module(self): "palettization_config", "fake_palettize", "palettizer", + "post_training_palettization", + "ModulePostTrainingPalettizerConfig", + "PostTrainingPalettizer", + "PostTrainingPalettizerConfig", + "sensitive_k_means", + "ModuleSKMPalettizerConfig", + "SKMPalettizer", + "SKMPalettizerConfig", ] visible_modules = _get_visible_items(coremltools.optimize.torch.palettization) _check_visible_modules(visible_modules, expected) @@ -92,6 +111,10 @@ def test_quantization_module(self): "quantizer", "quantization_config", "modules", + "ModulePostTrainingQuantizerConfig", + "PostTrainingQuantizer", + "PostTrainingQuantizerConfig", + "post_training_quantization", ] visible_modules = _get_visible_items(coremltools.optimize.torch.quantization) _check_visible_modules(visible_modules, expected) @@ -114,3 +137,23 @@ def test_quantization_module(self): ] visible_modules = _get_visible_items(coremltools.optimize.torch.quantization.quantizer) _check_visible_modules(visible_modules, expected) + + def test_layerwise_compression_module(self): + expected = [ + "algorithms", + "LayerwiseCompressionAlgorithm", + "LayerwiseCompressionAlgorithmConfig", + "SparseGPT", + "GPTQ", + "ModuleGPTQConfig", + "ModuleSparseGPTConfig", + "input_cacher", + "FirstLayerInputCacher", + "DefaultInputCacher", + "GPTFirstLayerInputCacher", + "layerwise_compressor", + "LayerwiseCompressor", + "LayerwiseCompressorConfig", + ] + visible_modules = _get_visible_items(coremltools.optimize.torch.layerwise_compression) + _check_visible_modules(visible_modules, expected) diff --git a/coremltools/test/optimize/torch/test_base_optimizer.py b/coremltools/test/optimize/torch/test_base_optimizer.py index 7cd5ae95b..c3a42f1fa 100644 --- a/coremltools/test/optimize/torch/test_base_optimizer.py +++ b/coremltools/test/optimize/torch/test_base_optimizer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause @@ -6,6 +6,10 @@ import pytest import torch +from coremltools.optimize.torch.base_model_optimizer import ( + BaseDataCalibratedModelOptimizer, + BasePostTrainingModelOptimizer, +) from coremltools.optimize.torch.palettization import DKMPalettizer from coremltools.optimize.torch.pruning import MagnitudePruner from coremltools.optimize.torch.quantization import LinearQuantizer @@ -18,7 +22,7 @@ def test_report_model_train_state(optimizer, inplace): opt = optimizer(model) if optimizer == LinearQuantizer: - p_model = opt.prepare(inplace=inplace, example_inputs=torch.randn(1)) + p_model = opt.prepare(inplace=inplace, example_inputs=(torch.randn(1),)) else: p_model = opt.prepare(inplace=inplace) @@ -29,3 +33,36 @@ def test_report_model_train_state(optimizer, inplace): p_model.eval() opt.report() assert not p_model.training + + +@pytest.mark.parametrize( + "optimizer", [BasePostTrainingModelOptimizer, BaseDataCalibratedModelOptimizer] +) +@pytest.mark.parametrize("inplace", [True, False]) +def test_inplace_behavior_for_optimizers(optimizer, inplace): + def create_model(): + return torch.nn.Sequential(torch.nn.Conv2d(1, 31, 2, 1), torch.nn.Conv2d(31, 21, 2, 1)) + + class DummyOptimizer(optimizer): + def report(self): + return None + + @torch.no_grad() + def compress(self, *args, inplace, **kwargs): + super().compress(*args, inplace=inplace, **kwargs) + self._model[0].weight.data = torch.ones_like(self._model[0].weight.data) + return self._model + + model = create_model() + opt = DummyOptimizer(model) + opt.compress(dataloader=None, inplace=inplace) + + if inplace: + assert id(opt._model) == id(model) + assert id(opt._uncompressed_model) != id(model) + else: + assert id(opt._model) != id(model) + assert id(opt._uncompressed_model) == id(model) + + assert torch.all(opt._model[0].weight == 1) + assert not torch.all(opt._uncompressed_model[0].weight == 1) diff --git a/coremltools/test/optimize/torch/test_utils/__init__.py b/coremltools/test/optimize/torch/test_utils/__init__.py new file mode 100644 index 000000000..5dc5e6747 --- /dev/null +++ b/coremltools/test/optimize/torch/test_utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause diff --git a/coremltools/test/optimize/torch/test_utils/test_fsdp_utils.py b/coremltools/test/optimize/torch/test_utils/test_fsdp_utils.py new file mode 100644 index 000000000..bddd6d0f2 --- /dev/null +++ b/coremltools/test/optimize/torch/test_utils/test_fsdp_utils.py @@ -0,0 +1,28 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import torch + +from coremltools.optimize.torch._utils.fsdp_utils import ModuleWrapPolicy, SizeBasedWrapPolicy + + +def test_module_wrap_policy(): + """ + Test constructor for underlying FSDP policy is called with correct arguments + """ + module_classes = [torch.nn.Linear, torch.nn.Conv2d] + policy = ModuleWrapPolicy(module_classes=module_classes) + policy = policy.get_policy() + assert policy._module_classes == set(module_classes) + + +def test_size_based_policy(): + """ + Test constructor for underlying FSDP policy is called with correct arguments + """ + min_num_params = 100 + policy = SizeBasedWrapPolicy(min_num_params=min_num_params) + policy = policy.get_policy() + assert policy.keywords["min_num_params"] == min_num_params diff --git a/coremltools/test/optimize/torch/test_utils/test_k_means.py b/coremltools/test/optimize/torch/test_utils/test_k_means.py new file mode 100644 index 000000000..dc00af0ac --- /dev/null +++ b/coremltools/test/optimize/torch/test_utils/test_k_means.py @@ -0,0 +1,262 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import pytest +import torch + +from coremltools.optimize.torch._utils.k_means import ( + KMeansConfig, + KMeansSupportedModulesRegistry, + ParallelKMeans, + SequentialKMeans, +) +from coremltools.test.optimize.torch.utils import count_unique_params + + +@pytest.mark.parametrize( + "config", + [ + KMeansConfig(n_bits=2, enable_per_channel_scale=False), + { + "conv1": {"weight": KMeansConfig(n_bits=4, enable_per_channel_scale=False)}, + "dense1": {"weight": KMeansConfig(n_bits=2, enable_per_channel_scale=True)}, + }, + ], +) +@pytest.mark.parametrize( + "kmeans_cls", + [SequentialKMeans, ParallelKMeans], +) +def test_k_means_mnist_per_weight(mock_name_main, mnist_model, config, kmeans_cls): + model = kmeans_cls.cluster_weights(mnist_model, config=config, num_workers=4) + + layers = [ + ("conv1", model.conv1), + ("conv2", model.conv2), + ("dense1", model.dense1), + ("dense2", model.dense2), + ] + with torch.no_grad(): + for layer_name, layer in layers: + if isinstance(config, dict): + if layer_name in config: + for param_name, layer_config in config[layer_name].items(): + param = getattr(layer, param_name) + if layer_config.enable_per_channel_scale: + per_channel_scale_key = f"_COREML_/{param_name}/palettization_scale" + assert per_channel_scale_key in layer.state_dict() + per_channel_scale = layer.state_dict()[per_channel_scale_key] + param = param / per_channel_scale + assert count_unique_params(torch.unique(param)) == 2**layer_config.n_bits + else: + assert len(torch.unique(layer.weight)) > 16 + else: + assert len(torch.unique(layer.weight)) == 2**config.n_bits + + +@pytest.mark.parametrize( + "config", + [ + KMeansConfig(n_bits=4, block_size=4, axis=0, enable_per_channel_scale=False), + KMeansConfig(n_bits=4, block_size=4, axis=0, enable_per_channel_scale=True), + KMeansConfig(n_bits=4, block_size=4, axis=1, enable_per_channel_scale=False), + KMeansConfig(n_bits=4, block_size=4, axis=1, enable_per_channel_scale=True), + ], +) +@pytest.mark.parametrize( + "kmeans_cls", + [SequentialKMeans, ParallelKMeans], +) +def test_k_means_block_wise(mock_name_main, config, kmeans_cls): + model = torch.nn.Conv2d(12, 32, (2, 2)) + model = kmeans_cls.cluster_weights(model, config=config, num_workers=4) + block_size = config.block_size + + with torch.no_grad(): + weight = model.weight + + if config.enable_per_channel_scale: + per_channel_scale_key = "_COREML_/weight/palettization_scale" + assert per_channel_scale_key in model.state_dict() + per_channel_scale = model.state_dict()[per_channel_scale_key] + weight = weight / per_channel_scale + + if config.axis == 0: + weight_flat = weight.flatten(1) + else: + weight_flat = weight.transpose(0, 1).flatten(1).transpose(0, 1) + + weight_shape = weight_flat.shape[config.axis] + if config.axis == 0: + for block_idx in range(0, weight_shape, block_size): + assert ( + count_unique_params( + torch.unique(weight_flat[block_idx : block_idx + block_size, :]) + ) + == 2**config.n_bits + ) + else: + for block_idx in range(0, weight_shape, block_size): + assert ( + count_unique_params( + torch.unique(weight_flat[:, block_idx : block_idx + block_size]) + ) + == 2**config.n_bits + ) + + +@pytest.mark.parametrize( + "config", + [ + KMeansConfig(n_bits=4, cluster_dim=4, axis=0, enable_per_channel_scale=False), + KMeansConfig(n_bits=4, cluster_dim=4, axis=1, enable_per_channel_scale=False), + KMeansConfig(n_bits=2, cluster_dim=2, axis=0, enable_per_channel_scale=False), + KMeansConfig(n_bits=2, cluster_dim=2, axis=1, enable_per_channel_scale=False), + ], +) +@pytest.mark.parametrize( + "kmeans_cls", + [SequentialKMeans, ParallelKMeans], +) +def test_k_means_vector_wise(mock_name_main, config, kmeans_cls): + model = torch.nn.Conv2d(16, 8, (2, 2)) + model = kmeans_cls.cluster_weights(model, config=config, num_workers=4) + cluster_dim = config.cluster_dim + + with torch.no_grad(): + weight = model.weight + if config.axis == 0: + weight_reshaped = weight.flatten(1).reshape(-1, cluster_dim) + elif config.axis == 1: + weight_reshaped = weight.transpose(0, 1).flatten(1).reshape(-1, cluster_dim) + else: + raise ValueError("axis must be 0 or 1.") + + unique_vector = torch.unique(weight_reshaped, dim=0) + assert len(unique_vector) == 2**config.n_bits + + +@pytest.mark.parametrize("importance", [True, False]) +@pytest.mark.parametrize("config", [tuple((4, 4, 0)), tuple((4, 4, 1))]) +@pytest.mark.parametrize( + "kmeans_cls", + [ + SequentialKMeans, + ParallelKMeans, + ], +) +def test_k_means_masked(mock_name_main, importance, config, kmeans_cls): + model = torch.nn.Linear(32, 32) + block_size = config[1] + axis = config[2] + + weight_mask = torch.ones_like(model.weight.data, dtype=torch.bool) + for idx in range(32): + if axis == 0: + weight_mask[idx, torch.randperm(32)[:4]] = False + else: + weight_mask[torch.randperm(32)[:4], idx] = False + + importance = torch.abs(torch.randn(model.weight.shape)) if importance else None + config = KMeansConfig( + n_bits=config[0], + block_size=block_size, + enable_per_channel_scale=False, + axis=axis, + mask=weight_mask, + importance=importance, + ) + + weight_clone = model.weight.clone() + + model = kmeans_cls.cluster_weights(model, config=config, num_workers=4) + + with torch.no_grad(): + model_weight = model.weight + weight_shape = model_weight.shape[config.axis] + for block_idx in range(0, weight_shape, block_size): + if config.axis == 0: + mask_block = weight_mask[block_idx : block_idx + block_size, :] + weight_block_masked = model_weight[block_idx : block_idx + block_size, :][ + mask_block + ] + weight_unmasked = model_weight[block_idx : block_idx + block_size, :][~mask_block] + weight_orig_unmasked = weight_clone[block_idx : block_idx + block_size, :][ + ~mask_block + ] + else: + mask_block = weight_mask[:, block_idx : block_idx + block_size] + weight_block_masked = model_weight[:, block_idx : block_idx + block_size][ + mask_block + ] + weight_unmasked = model_weight[:, block_idx : block_idx + block_size][~mask_block] + weight_orig_unmasked = weight_clone[:, block_idx : block_idx + block_size][ + ~mask_block + ] + assert len(torch.unique(weight_block_masked)) == 2**config.n_bits + assert torch.all(weight_orig_unmasked == weight_unmasked) + + +# region KMeansModule Tests + + +@pytest.mark.parametrize( + "layer, layer_config", + [ + ( + torch.nn.Linear(10, 100), + {"weight": KMeansConfig(n_bits=4, enable_per_channel_scale=True)}, + ), + ], +) +@torch.no_grad() +def test_zero_per_channel_scale(layer, layer_config): + k_means_module_cls = KMeansSupportedModulesRegistry.get_kmeans_module(layer) + k_means_module = k_means_module_cls(layer, layer_config) + # Set one output chanel to zero so its per_channel_scale is 0 + layer.weight[0] = 0 + orig_weight = layer.weight.clone() + # Scale weights + scaled_weight = k_means_module._scale_by_per_channel_scale("weight", layer.weight) + # Verify no NaN values are introduced + assert not torch.any(torch.isnan(scaled_weight)) + # Confirm layer weights for corresponding channel remain 0 + assert torch.all(scaled_weight[0] == 0) + # Unscale weights + unscaled_weight = k_means_module._unscale_by_per_channel_scale("weight", layer.weight) + # Verify no NaN values are introduced + assert not torch.any(torch.isnan(unscaled_weight)) + # Confirm unscaled weights match original weights within tolerance + assert torch.all(torch.isclose(unscaled_weight, orig_weight)) + + +@pytest.mark.parametrize( + "layer, param_name, axis, expected_shape", + [ + (torch.nn.Conv2d(16, 32, 5), "weight", 0, (32, 16 * 5 * 5)), + (torch.nn.Conv2d(16, 32, 5), "weight", 1, (32 * 5 * 5, 16)), + (torch.nn.Linear(1024, 10), "weight", 0, (10, 1024)), + (torch.nn.Linear(1024, 10), "weight", 1, (10, 1024)), + (torch.nn.Embedding(50000, 256), "weight", 0, (50000, 256)), + (torch.nn.Embedding(50000, 256), "weight", 1, (50000, 256)), + (torch.nn.MultiheadAttention(256, 4), "in_proj_weight", 0, (3 * 256, 256)), + (torch.nn.MultiheadAttention(256, 4), "in_proj_weight", 1, (3 * 256, 256)), + ], +) +def test_parameter_reshaping(layer, param_name, axis, expected_shape): + config = {param_name: KMeansConfig(n_bits=4, block_size=8, axis=axis)} + k_means_module_cls = KMeansSupportedModulesRegistry.get_kmeans_module(layer) + k_means_module = k_means_module_cls(layer, config) + + # reshape for kmeans + param = getattr(layer, param_name) + new_param = k_means_module._reshape_for_kmeans(param_name, param) + assert new_param.shape == expected_shape + + # reshape back to original weight shape + reshaped_param = k_means_module._reshape_to_original(param_name, new_param) + assert reshaped_param.shape == param.shape + +# endregion diff --git a/coremltools/test/optimize/torch/test_utils/test_metadata_utils.py b/coremltools/test/optimize/torch/test_utils/test_metadata_utils.py new file mode 100644 index 000000000..41d88b8ca --- /dev/null +++ b/coremltools/test/optimize/torch/test_utils/test_metadata_utils.py @@ -0,0 +1,158 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +from contextlib import nullcontext as does_not_raise + +import pytest +import torch + +from coremltools.optimize.torch._utils.metadata_utils import ( + METADATA_VERSION, + METADATA_VERSION_BUFFER, + CompressionMetadata, + CompressionType, + register_metadata_version, +) + + +@pytest.mark.parametrize( + "metadata_dict, expectation", + [ + ( + { + "param_name": "weight", + "quantization_scale": torch.rand(3, 1), + "quantization_n_bits": 4, + "compression_type": ["pruning", "quantization"], + }, + does_not_raise(), + ), + ( + { + "param_name": "weight", + "quantization_scale": torch.rand(3, 1), + "quantization_n_bits": 4, + "compression_type": ["pruning", "quantizatoin"], # mis-spelled + }, + pytest.raises(KeyError), + ), + ], +) +def test_metadata_from_dict(metadata_dict, expectation): + with expectation: + metadata = CompressionMetadata.from_dict(metadata_dict) + assert torch.equal(metadata.quantization_scale, metadata_dict["quantization_scale"]) + assert metadata.quantization_n_bits == metadata_dict["quantization_n_bits"] + assert metadata.compression_type == [ + CompressionType[x].value for x in metadata_dict["compression_type"] + ] + + for key, value in metadata.as_dict().items(): + if key not in metadata_dict: + assert value is None + + +@pytest.mark.parametrize( + "state_dict", + [ + { + "_COREML_/weight/quantization_scale": torch.rand(3, 1), + "_COREML_/weight/quantization_n_bits": torch.tensor(4), + "_COREML_/weight/compression_type": torch.tensor([1, 2]), + "_COREML_/bias/quantization_scale": torch.rand(3, 1), + "_COREML_/bias/quantization_n_bits": torch.tensor(8), + "_COREML_/bias/compression_type": torch.tensor([1, 3]), + } + ], +) +def test_metadata_from_state_dict(state_dict): + metadata_dict = CompressionMetadata.from_state_dict(state_dict) + print(metadata_dict) + assert len(metadata_dict) == 2 + assert "weight" in metadata_dict + assert "bias" in metadata_dict + for param in ["weight", "bias"]: + metadata = metadata_dict[param] + assert metadata.param_name == param + assert torch.equal( + metadata.quantization_scale, + state_dict[f"_COREML_/{param}/quantization_scale"], + ) + assert ( + metadata.quantization_n_bits + == state_dict[f"_COREML_/{param}/quantization_n_bits"].item() + ) + assert ( + metadata.compression_type == state_dict[f"_COREML_/{param}/compression_type"].tolist() + ) + + non_none_keys = [ + "quantization_n_bits", + "quantization_scale", + "param_name", + "compression_type", + ] + for key, value in metadata.as_dict().items(): + if key not in non_none_keys: + assert value is None + + +@pytest.mark.parametrize( + "metadata_dict", + [ + { + "param_name": "weight", + "zero_point": torch.rand(3, 1), + "compression_type": ["pruning", "quantization"], + }, + ], +) +def test_register(metadata_dict): + module = torch.nn.Conv2d(3, 32, 3) + metadata = CompressionMetadata.from_dict(metadata_dict) + + state_dict = module.state_dict() + for key in metadata_dict: + assert metadata._get_metadata_buffer_name(key) not in state_dict + + metadata.register(module) + + state_dict = module.state_dict() + for key, value in metadata_dict.items(): + if key != "param_name": + metadata_key = metadata._get_metadata_buffer_name(key) + if key == "compression_type": + metadata_value = torch.tensor([CompressionType[x].value for x in value]) + else: + metadata_value = torch.tensor(value) + assert metadata_key in state_dict + assert torch.equal(state_dict[metadata_key], metadata_value) + + +def test_chaining_compression_type(): + module = torch.nn.Conv2d(3, 32, 3) + metadata = CompressionMetadata(param_name="weight") + metadata.compression_type = ["pruning"] + + metadata.register(module) + + buffer_name = metadata._get_metadata_buffer_name("compression_type") + assert buffer_name in module.state_dict() + assert torch.equal(module.state_dict()[buffer_name], torch.tensor([1])) + + metadata2 = CompressionMetadata(param_name="weight") + metadata2.compression_type = ["palettization"] + + metadata2.register(module) + assert buffer_name in module.state_dict() + assert torch.equal(module.state_dict()[buffer_name], torch.tensor([1, 2])) + + +def test_register_metadata_version(): + model = torch.nn.Sequential(torch.nn.Conv2d(3, 32, 3), torch.nn.ReLU()) + assert METADATA_VERSION_BUFFER not in model.state_dict() + register_metadata_version(model) + assert METADATA_VERSION_BUFFER in model.state_dict() + assert torch.equal(model.state_dict()[METADATA_VERSION_BUFFER], METADATA_VERSION) diff --git a/coremltools/test/optimize/torch/test_utils/test_report_utils.py b/coremltools/test/optimize/torch/test_utils/test_report_utils.py new file mode 100644 index 000000000..e1764e2c8 --- /dev/null +++ b/coremltools/test/optimize/torch/test_utils/test_report_utils.py @@ -0,0 +1,314 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +from collections import OrderedDict +from typing import Tuple + +import pytest +import torch + +from coremltools.optimize.torch.layerwise_compression import ( + LayerwiseCompressor, + LayerwiseCompressorConfig, +) +from coremltools.optimize.torch.palettization import ( + PostTrainingPalettizer, + PostTrainingPalettizerConfig, +) +from coremltools.optimize.torch.palettization.sensitive_k_means import ( + SKMPalettizer, + SKMPalettizerConfig, +) +from coremltools.optimize.torch.quantization import ( + PostTrainingQuantizer, + PostTrainingQuantizerConfig, +) + + +@pytest.fixture() +def model_for_compression(request) -> torch.nn.Module: + decomposed_multihead_forward = request.param + + class ProjectionModule(torch.nn.Module): + def __init__(self, embed_dim: int, hidden_dim: int): + super().__init__() + self.query = torch.nn.Linear(embed_dim, hidden_dim) + self.key = torch.nn.Linear(embed_dim, hidden_dim) + self.value = torch.nn.Linear(embed_dim, hidden_dim) + + def forward(self, x: torch.Tensor): + return self.query(x), self.key(x), self.value(x) + + if decomposed_multihead_forward: + + class MultiheadWrapper(torch.nn.Module): + def __init__(self, multihead_layer): + super().__init__() + self.layer = multihead_layer + + def forward(self, q, k, v): + return self.layer(q, k, v, need_weights=False)[0] + + else: + + class MultiheadWrapper(torch.nn.Module): + def __init__(self, multihead_layer): + super().__init__() + self.layer = multihead_layer + + def forward(self, x: Tuple[torch.Tensor]): + return self.layer(x[0], x[1], x[2], need_weights=False)[0] + + class LinearWrapper(torch.nn.Module): + def __init__(self, linear_layer): + super().__init__() + self.layer = linear_layer + + def forward(self, x): + out = self.layer(x) + return out.reshape(-1, 100, 10, 10) + + return torch.nn.Sequential( + OrderedDict( + [ + ("embedding", torch.nn.Embedding(100, 100)), + ("projection", ProjectionModule(100, 100)), + ( + "multihead", + MultiheadWrapper(torch.nn.MultiheadAttention(100, 5, batch_first=True)), + ), + ("linear", LinearWrapper(torch.nn.Linear(100, 100))), + ("conv", torch.nn.Conv2d(100, 100, (3, 3), padding=(1, 1))), + ] + ) + ) + + +@pytest.mark.parametrize( + "config, expected_num_columns", + [ + ( + { + "global_config": {"algorithm": "gptq", "weight_dtype": "uint4"}, + "module_name_configs": {"multihead.layer.out_proj": None}, + "input_cacher": "default", + "calibration_nsamples": 128, + }, + 3, + ), + ( + { + "global_config": { + "algorithm": "gptq", + "weight_dtype": "uint4", + "enable_normal_float": True, + }, + "module_name_configs": {"multihead.layer.out_proj": None}, + "input_cacher": "default", + "calibration_nsamples": 128, + }, + 3, + ), + ( + { + "global_config": {"algorithm": "gptq", "weight_dtype": "uint8"}, + "module_name_configs": { + "projection.*": { + "algorithm": "sparse_gpt", + "weight_dtype": "uint8", + "target_sparsity": 0.25, + }, + "multihead.layer.out_proj": None, + }, + "input_cacher": "default", + "calibration_nsamples": 128, + }, + 6, + ), + ], +) +@pytest.mark.parametrize("model_for_compression", [True], indirect=True) +def test_report_layerwise_compressor(model_for_compression, config, expected_num_columns): + config = LayerwiseCompressorConfig.from_dict(config) + compressor = LayerwiseCompressor(model_for_compression, config) + + def compression_loader(): + dataset = torch.utils.data.TensorDataset(torch.randint(0, high=100, size=(100, 100))) + loader = torch.utils.data.DataLoader(dataset, batch_size=10) + for data in loader: + yield data[0] + + compressor.compress(compression_loader(), device="cpu") + report = compressor.report() + print(report) + assert (len(report)) == 5 + expected_params = [ + "projection.query.weight", + "projection.key.weight", + "projection.value.weight", + "linear.layer.weight", + "conv.weight", + ] + for param_name in expected_params: + assert param_name in report + param_report = report[param_name] + assert len(param_report) == expected_num_columns + if not config.global_config.enable_normal_float: + assert param_report["dtype"] == f"dtype=int{config.global_config.weight_n_bits}" + else: + assert ( + param_report["palettization_mode"] + == f"num_clusters={2 ** config.global_config.weight_n_bits}, cluster_dim=1" + ) + + +@pytest.mark.parametrize("quantization_scheme", ["symmetric", "affine"]) +@pytest.mark.parametrize( + "granularity_block_size", + [ + ("per_channel", None), + ("per_tensor", None), + ("per_block", 5), + ], +) +@pytest.mark.parametrize("weight_dtype", ["int4", "int8"]) +@pytest.mark.parametrize("model_for_compression", [True], indirect=True) +def test_report_post_training_quantization( + model_for_compression, + quantization_scheme, + granularity_block_size, + weight_dtype, +): + granularity, block_size = granularity_block_size + config = PostTrainingQuantizerConfig.from_dict( + { + "global_config": { + "weight_dtype": weight_dtype, + "granularity": granularity, + "block_size": block_size, + "quantization_scheme": quantization_scheme, + } + } + ) + compressor = PostTrainingQuantizer(model_for_compression, config) + model = compressor.compress() + + report = compressor.report() + + assert (len(report)) == 7 + for param_name, param in model.named_parameters(): + if "embedding" not in param_name and "bias" not in param_name: + assert param_name in report + param_report = report[param_name] + assert len(param_report) == 3 + assert param_report["dtype"] == f"dtype=int{config.global_config.weight_n_bits}" + + +@pytest.mark.parametrize( + "config", + [ + { + "global_config": {"granularity": "per_tensor", "n_bits": 4}, + }, + { + "global_config": { + "granularity": "per_grouped_channel", + "n_bits": 4, + "group_size": 1, + }, + }, + { + "global_config": { + "granularity": "per_grouped_channel", + "n_bits": 4, + "group_size": 5, + }, + }, + { + "global_config": {"granularity": "per_tensor", "n_bits": 4}, + "module_name_configs": { + "linear.layer": { + "n_bits": 4, + "granularity": "per_tensor", + "cluster_dim": 5, + }, + "conv": { + "n_bits": 4, + "granularity": "per_tensor", + "cluster_dim": 4, + }, + }, + }, + ], +) +@pytest.mark.parametrize("model_for_compression", [True], indirect=True) +def test_report_post_training_palettization(model_for_compression, config): + config = PostTrainingPalettizerConfig.from_dict(config) + compressor = PostTrainingPalettizer(model_for_compression, config) + model = compressor.compress(num_kmeans_workers=1) + + report = compressor.report() + assert (len(report)) == 8 + for param_name, param in model.named_parameters(): + if "bias" not in param_name: + assert param_name in report + param_report = report[param_name] + assert len(param_report) == 3 + assert "num_clusters=16" in param_report["palettization_mode"] + + +@pytest.mark.parametrize( + "config", + [ + { + "global_config": {"granularity": "per_tensor", "n_bits": 6}, + }, + { + "global_config": { + "granularity": "per_grouped_channel", + "n_bits": 8, + "group_size": 1, + }, + }, + { + "global_config": { + "granularity": "per_grouped_channel", + "n_bits": 4, + "group_size": 5, + }, + }, + ], +) +@pytest.mark.parametrize("model_for_compression", [False], indirect=True) +def test_report_skm_palettizer(model_for_compression, config): + config = SKMPalettizerConfig.from_dict(config) + compressor = SKMPalettizer(model_for_compression, config) + + def compression_loader(): + dataset = torch.utils.data.TensorDataset(torch.randint(0, high=100, size=(100, 100))) + loader = torch.utils.data.DataLoader(dataset, batch_size=10) + for data in loader: + yield data[0] + + def loss_fn(model, data): + out = model(data) + return torch.sum(out) + + model = compressor.compress( + dataloader=compression_loader(), + loss_fn=loss_fn, + ) + + report = compressor.report() + assert (len(report)) == 8 + for param_name, param in model.named_parameters(): + if "bias" not in param_name: + assert param_name in report + param_report = report[param_name] + assert len(param_report) == 3 + assert ( + param_report["palettization_mode"] + == f"num_clusters={2 ** config.global_config.n_bits}, cluster_dim=1" + ) diff --git a/coremltools/test/optimize/torch/test_utils/test_validation_utils.py b/coremltools/test/optimize/torch/test_utils/test_validation_utils.py new file mode 100644 index 000000000..ddf4435bd --- /dev/null +++ b/coremltools/test/optimize/torch/test_utils/test_validation_utils.py @@ -0,0 +1,134 @@ +# Copyright (c) 2024, Apple Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import pytest +import torch + +from coremltools.optimize.torch._utils.validation_utils import ( + ConfigValidator, + validate_param_config, +) +from coremltools.optimize.torch.palettization import ( + ModulePostTrainingPalettizerConfig, + ModuleSKMPalettizerConfig, +) +from coremltools.optimize.torch.quantization import ModulePostTrainingQuantizerConfig + + +@pytest.mark.parametrize( + "config, expectation", + [ + ( + ModulePostTrainingPalettizerConfig( + n_bits=4, granularity="per_grouped_channel", group_size=4 + ), + True, + ), + ( + ModulePostTrainingPalettizerConfig(n_bits=4, granularity="per_tensor", cluster_dim=3), + False, + ), + ( + ModulePostTrainingPalettizerConfig(n_bits=4, granularity="per_tensor", cluster_dim=5), + True, + ), + ], +) +def test_validate_param_config(config, expectation): + module = torch.nn.Conv2d(16, 32, 5) + result = validate_param_config( + "weight", + module.weight, + config, + ["palettization_group_size", "palettization_cluster_dim"], + ) + if expectation: + assert result is not None + else: + assert result is None + + +def test_validate_no_check(): + module = torch.nn.Conv2d(3, 16, 5) + config = ModuleSKMPalettizerConfig() + validator = ConfigValidator("weight", module.weight, config) + with pytest.raises(AssertionError): + validator.validate(["invalid_check"]) + +@pytest.mark.parametrize( + "group_size, channel_axis, expectation", + [ + pytest.param(4, None, True, id="default_axis"), + pytest.param(4, 0, True, id="axis_0"), + pytest.param(4, 1, True, id="axis_1"), + pytest.param(5, None, False, id="default_indivisible_group_size"), + pytest.param(5, 0, False, id="axis_0_indivisible_group_size"), + pytest.param(5, 1, False, id="axis_1_indivisible_group_size"), + ], +) +def test_validate_palettization_group_size(group_size, channel_axis, expectation): + module = torch.nn.Conv2d(16, 32, 5) + if channel_axis: + config = ModuleSKMPalettizerConfig( + n_bits=4, + granularity="per_grouped_channel", + group_size=group_size, + channel_axis=channel_axis, + ) + else: + config = ModuleSKMPalettizerConfig( + n_bits=4, + granularity="per_grouped_channel", + group_size=group_size, + ) + validator = ConfigValidator("weight", module.weight, config) + assert validator.validate(["palettization_group_size"]) == expectation + + +@pytest.mark.parametrize( + "block_size, sanitized_block_size, expectation", + [ + pytest.param(4, (1, 4), True, id="default_axis_int_block_size"), + pytest.param((1, 4), (1, 4), True, id="tuple_with_per_channel"), + pytest.param((4, 16), (4, 16), True, id="tuple_block_size"), + pytest.param((4, 16, 5, 5), (4, 16), True, id="tuple_block_size_greater_than_ndim"), + pytest.param((0, 16), -1, False, id="per_block_without_per_channel"), + pytest.param((0, 0), -1, False, id="no_blocking_tuple"), + pytest.param(0, -1, False, id="no_blocking_int"), + pytest.param(5, -1, False, id="non_divisible_block_size_int"), + pytest.param((5, 5), -1, False, id="non_divisible_block_size_tuple"), + pytest.param((5, 16), -1, False, id="non_divisible_block_size_tuple_axis_0"), + pytest.param((4, 5), -1, False, id="non_divisible_block_size_tuple_axis_1"), + ], +) +def test_validate_quantization_block_size(block_size, sanitized_block_size, expectation): + module = torch.nn.Conv2d(16, 32, 5) + config = ModulePostTrainingQuantizerConfig( + weight_dtype="int4", granularity="per_block", block_size=block_size + ) + validator = ConfigValidator("weight", module.weight, config) + assert validator.validate(["quantization_block_size"]) == expectation + + if expectation is True: + assert validator.config.block_size == sanitized_block_size + + +@pytest.mark.parametrize( + "cluster_dim, expectation", + [ + pytest.param(None, True, id="cluster_dim_unspecified"), + pytest.param(1, True, id="cluster_dim_scalar"), + pytest.param(3, True, id="cluster_dim_valid_1"), + pytest.param(5, True, id="cluster_dim_valid_2"), + pytest.param(4, False, id="cluster_dim_invalid"), + ], +) +def test_validate_palettization_cluster_dim(cluster_dim, expectation): + module = torch.nn.Conv2d(3, 16, 5) + config = ModulePostTrainingPalettizerConfig( + n_bits=4, granularity="per_tensor", cluster_dim=cluster_dim + ) + validator = ConfigValidator("weight", module.weight, config) + assert validator.validate(["palettization_cluster_dim"]) == expectation diff --git a/coremltools/test/optimize/torch/utils.py b/coremltools/test/optimize/torch/utils.py index 6c50628b7..beec6b808 100644 --- a/coremltools/test/optimize/torch/utils.py +++ b/coremltools/test/optimize/torch/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Apple Inc. All rights reserved. +# Copyright (c) 2024, Apple Inc. All rights reserved. # # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause @@ -6,9 +6,13 @@ import pathlib import sys -from packaging.version import Version +import torch +import torch.nn.functional as F +import torch.utils.data +from packaging import version +# region version_utils def _python_version(): """ Return python version as a tuple of integers @@ -34,23 +38,92 @@ def _macos_version(): ) return tuple([int(v) for v in ver_str.split(".")]) except: - raise Exception("Unable to determine the macOS version") + raise Exception("Unable to detemine the macOS version") return () +def count_unique_params(tensor): + """ + Returns number of unique parameters in the same tensor. + Set a defaulted absolute tolerance, so that very close values can be treated as identical in palletization. + """ + unique_set = {tensor[0]} + for elem in tensor[1:]: + if all(not torch.isclose(elem, uelem, atol=1e-6) for uelem in unique_set): + unique_set.add(elem) + return len(unique_set) + + def version_ge(module, target_version): """ Example usage: >>> import torch # v1.5.0 >>> version_ge(torch, '1.6.0') # False """ - return Version(module.__version__) >= Version(target_version) + return version.parse(module.__version__) >= version.parse(target_version) def version_lt(module, target_version): """See version_ge""" - return Version(module.__version__) < Version(target_version) + return version.parse(module.__version__) < version.parse(target_version) + + +# endregion +# region path_utils def test_data_path(): return pathlib.Path(__file__).parent.absolute() / "_test_data" + + +# endregion + +# region train_utils + + +def setup_data_loaders(dataset, batch_size): + train, test = dataset + train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True) + test_loader = torch.utils.data.DataLoader(test, batch_size=batch_size) + + return train_loader, test_loader + + +def train_step(model, optimizer, train_loader, data, target, batch_idx, epoch): + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + if batch_idx % 100 == 0: + print( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( + epoch, + batch_idx * len(data), + len(train_loader.dataset), + 100.0 * batch_idx / len(train_loader), + loss.item(), + ) + ) + return loss + + +def eval_model(model, test_loader): + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + output = model(data) + test_loss += F.nll_loss(output, target, reduction="sum").item() + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(target.view_as(pred)).sum().item() + + test_loss /= len(test_loader.dataset) + accuracy = 100.0 * correct / len(test_loader.dataset) + + print("\nTest set: Average loss: {:.4f}, Accuracy: {:.0f}%\n".format(test_loss, accuracy)) + return accuracy + + +# endregion diff --git a/coremltools/version.py b/coremltools/version.py index 3e31cd3cb..af0652eca 100644 --- a/coremltools/version.py +++ b/coremltools/version.py @@ -4,4 +4,4 @@ # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause -__version__ = "7.2.1" # VERSION_STRING +__version__ = "8.0b1" # VERSION_STRING diff --git a/docs/source/coremltools.converters.mil.input_types.rst b/docs/source/coremltools.converters.mil.input_types.rst index 06916c236..d8bfa7c5a 100644 --- a/docs/source/coremltools.converters.mil.input_types.rst +++ b/docs/source/coremltools.converters.mil.input_types.rst @@ -47,6 +47,12 @@ Input types supported by the Model Intermediate Language (MIL): :members: + StateType + --------- + .. autoclass:: StateType + :members: + + TensorType ---------- diff --git a/docs/source/coremltools.converters.mil.mil.ops.defs.rst b/docs/source/coremltools.converters.mil.mil.ops.defs.rst index adad2723b..fc2f586d5 100644 --- a/docs/source/coremltools.converters.mil.mil.ops.defs.rst +++ b/docs/source/coremltools.converters.mil.mil.ops.defs.rst @@ -90,6 +90,13 @@ conv (iOS 17+) .. autoclass:: conv .. autoclass:: conv_transpose +coreml_update_state +--------------------------------------------------- + +.. automodule:: coremltools.converters.mil.mil.ops.defs.coreml_dialect.ops + + .. autoclass:: coreml_update_state + elementwise\_binary ------------------------------------------------------------------ @@ -332,6 +339,13 @@ scatter\_gather (iOS 17+) .. autoclass:: scatter_along_axis .. autoclass:: scatter_nd +states (iOS 18+) +-------------------------------------------------------------- + +.. automodule:: coremltools.converters.mil.mil.ops.defs.iOS18.states + + .. autoclass:: read_state + tensor\_operation (iOS 15+) ---------------------------------------------------------------- @@ -415,3 +429,17 @@ tensor\_transformation (iOS 17+) .. autoclass:: squeeze .. autoclass:: transpose +tensor\_transformation (iOS 18+) +--------------------------------------------------------------------- + +.. automodule:: coremltools.converters.mil.mil.ops.defs.iOS18.tensor_transformation + + .. autoclass:: slice_update + +transformers (iOS 18+) +--------------------------------------------------------------------- + +.. automodule:: coremltools.converters.mil.mil.ops.defs.iOS18.transformers + + .. autoclass:: scaled_dot_product_attention + diff --git a/milstoragepython/MilStorage.cpp b/milstoragepython/MilStorage.cpp index 2c702ed74..43fc561a6 100644 --- a/milstoragepython/MilStorage.cpp +++ b/milstoragepython/MilStorage.cpp @@ -4,9 +4,11 @@ // found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause #include "MilStorage.hpp" +#include "MILBlob/SubByteTypes.hpp" #include "MILBlob/Blob/StorageReader.hpp" #include "MILBlob/Blob/StorageWriter.hpp" #include "MILBlob/Util/SpanCast.hpp" +#include "MILBlob/Util/SubByteConversionUtils.hpp" #include @@ -34,11 +36,53 @@ namespace { return m_writer.WriteData(fpSpan); } + template + u_int64_t writeUnsignedSubByteData(MILBlob::Blob::StorageWriter& m_writer, + const py::array_t& data) { + // The `data` is stored in uint8 as numpy doesn't support uint1/2/3/4/6 (denoted as uint{x}). + // First pack those uint{x} data into uint8 Span, and then cast to uint{x} Span and write it. + auto uint8SpanData = MILBlob::Util::Span(data.data(), data.size()); + std::vector packedValues = MILBlob::PackUInt8Span(uint8SpanData); + auto packedValuesSpan = MILBlob::Util::Span(packedValues.data(), packedValues.size()); + auto uintSubByteSpan = MILBlob::Util::CastToBitSpan(packedValuesSpan, data.size()); + return m_writer.WriteData(uintSubByteSpan); + } + } // These methods are needed in addition to the above template methods // because pybind does not allow us to expose template methods to // Python with gcc on Linux. +u_int64_t MilStoragePythonWriter::write_int4_data(const py::array_t& data) { + // The `data` is stored in int8 because numpy doesn't support int4. + // First pack those int4 data into uint8 Span, and then cast to int4 Span and write it. + auto int8SpanData = MILBlob::Util::Span(data.data(), data.size()); + std::vector packedValues = MILBlob::PackInt8Span(int8SpanData); + auto packedValuesSpan = MILBlob::Util::Span(packedValues.data(), packedValues.size()); + auto int4Span = MILBlob::Util::CastToBitSpan(packedValuesSpan, data.size()); + return m_writer->WriteData(int4Span); +} + +u_int64_t MilStoragePythonWriter::write_uint1_data(const py::array_t& data) { + return writeUnsignedSubByteData(*m_writer, data); +} + +u_int64_t MilStoragePythonWriter::write_uint2_data(const py::array_t& data) { + return writeUnsignedSubByteData(*m_writer, data); +} + +u_int64_t MilStoragePythonWriter::write_uint3_data(const py::array_t& data) { + return writeUnsignedSubByteData(*m_writer, data); +} + +u_int64_t MilStoragePythonWriter::write_uint4_data(const py::array_t& data) { + return writeUnsignedSubByteData(*m_writer, data); +} + +u_int64_t MilStoragePythonWriter::write_uint6_data(const py::array_t& data) { + return writeUnsignedSubByteData(*m_writer, data); +} + u_int64_t MilStoragePythonWriter::write_int8_data(const py::array_t& data) { return writeData(*m_writer, data); } @@ -55,6 +99,14 @@ u_int64_t MilStoragePythonWriter::write_uint16_data(const py::array_t& return writeData(*m_writer, data); } +u_int64_t MilStoragePythonWriter::write_int32_data(const py::array_t& data) { + return writeData(*m_writer, data); +} + +u_int64_t MilStoragePythonWriter::write_uint32_data(const py::array_t& data) { + return writeData(*m_writer, data); +} + u_int64_t MilStoragePythonWriter::write_fp16_data(const py::array_t& data){ auto intSpan = MILBlob::Util::Span(data.data(), data.size()); @@ -86,11 +138,49 @@ namespace { auto spanData = m_reader.GetDataView(offset); return py::array_t(spanData.Size(), spanData.Data()); } + + template + py::array_t readUnsignedSubByteData(MILBlob::Blob::StorageReader& m_reader, + uint64_t offset) { + // First read packed data using MILBlob reader, and restore to uint8 values which represents uint{x}. + auto uintSubByteSpanData = m_reader.GetDataView(offset); + MILBlob::Util::Span packedValuesSpan = MILBlob::Util::CastFromBitSpan(uintSubByteSpanData); + auto unpackedUIntSubByteData = MILBlob::UnPackSubByteVec({packedValuesSpan.begin(), packedValuesSpan.end()}, uintSubByteSpanData.Size()); + return py::array_t(unpackedUIntSubByteData.size(), reinterpret_cast(unpackedUIntSubByteData.data())); + } } // These methods are needed in addition to the above template methods // because pybind does not allow us to expose template methods to // Python with gcc on Linux. +py::array_t MilStoragePythonReader::read_int4_data(uint64_t offset) { + // First read packed data using MILBlob reader, and restore to int8 values which represents int4. + auto int4SpanData = m_reader->GetDataView(offset); + MILBlob::Util::Span packedValuesSpan = MILBlob::Util::CastFromBitSpan(int4SpanData); + auto unpackedInt4Data = MILBlob::UnPackSubByteVec({packedValuesSpan.begin(), packedValuesSpan.end()}, int4SpanData.Size()); + return py::array_t(unpackedInt4Data.size(), reinterpret_cast(unpackedInt4Data.data())); +} + +py::array_t MilStoragePythonReader::read_uint1_data(uint64_t offset) { + return readUnsignedSubByteData(*m_reader, offset); +} + +py::array_t MilStoragePythonReader::read_uint2_data(uint64_t offset) { + return readUnsignedSubByteData(*m_reader, offset); +} + +py::array_t MilStoragePythonReader::read_uint3_data(uint64_t offset) { + return readUnsignedSubByteData(*m_reader, offset); +} + +py::array_t MilStoragePythonReader::read_uint4_data(uint64_t offset) { + return readUnsignedSubByteData(*m_reader, offset); +} + +py::array_t MilStoragePythonReader::read_uint6_data(uint64_t offset) { + return readUnsignedSubByteData(*m_reader, offset); +} + py::array_t MilStoragePythonReader::read_int8_data(uint64_t offset) { return readData(*m_reader, offset); } @@ -107,6 +197,14 @@ py::array_t MilStoragePythonReader::read_uint16_data(uint64_t offset) return readData(*m_reader, offset); } +py::array_t MilStoragePythonReader::read_int32_data(uint64_t offset) { + return readData(*m_reader, offset); +} + +py::array_t MilStoragePythonReader::read_uint32_data(uint64_t offset) { + return readData(*m_reader, offset); +} + py::array_t MilStoragePythonReader::read_fp16_data(uint64_t offset) { auto fpView = m_reader->GetDataView(offset); @@ -117,4 +215,4 @@ py::array_t MilStoragePythonReader::read_fp16_data(uint64_t offset) { py::array_t MilStoragePythonReader::read_float_data(uint64_t offset) { return readData(*m_reader, offset); -} \ No newline at end of file +} diff --git a/milstoragepython/MilStorage.hpp b/milstoragepython/MilStorage.hpp index 1e0ac7b89..94c65c33f 100644 --- a/milstoragepython/MilStorage.hpp +++ b/milstoragepython/MilStorage.hpp @@ -34,10 +34,18 @@ namespace CoreML { MilStoragePythonWriter(const std::string& filePath, bool truncateFile); ~MilStoragePythonWriter(); + u_int64_t write_int4_data(const py::array_t& data); + u_int64_t write_uint1_data(const py::array_t& data); + u_int64_t write_uint2_data(const py::array_t& data); + u_int64_t write_uint3_data(const py::array_t& data); + u_int64_t write_uint4_data(const py::array_t& data); + u_int64_t write_uint6_data(const py::array_t& data); u_int64_t write_int8_data(const py::array_t& data); u_int64_t write_uint8_data(const py::array_t& data); u_int64_t write_int16_data(const py::array_t& data); u_int64_t write_uint16_data(const py::array_t& data); + u_int64_t write_int32_data(const py::array_t& data); + u_int64_t write_uint32_data(const py::array_t& data); u_int64_t write_fp16_data(const py::array_t& data); u_int64_t write_float_data(const py::array_t& data); @@ -55,10 +63,18 @@ namespace CoreML { MilStoragePythonReader(std::string filePath); ~MilStoragePythonReader(); + py::array_t read_int4_data(uint64_t offset); + py::array_t read_uint1_data(uint64_t offset); + py::array_t read_uint2_data(uint64_t offset); + py::array_t read_uint3_data(uint64_t offset); + py::array_t read_uint4_data(uint64_t offset); + py::array_t read_uint6_data(uint64_t offset); py::array_t read_int8_data(uint64_t offset); py::array_t read_uint8_data(uint64_t offset); py::array_t read_int16_data(uint64_t offset); py::array_t read_uint16_data(uint64_t offset); + py::array_t read_int32_data(uint64_t offset); + py::array_t read_uint32_data(uint64_t offset); py::array_t read_fp16_data(uint64_t offset); py::array_t read_float_data(uint64_t offset); diff --git a/milstoragepython/MilStoragePython.cpp b/milstoragepython/MilStoragePython.cpp index 3d70e5c4b..730c45a0e 100644 --- a/milstoragepython/MilStoragePython.cpp +++ b/milstoragepython/MilStoragePython.cpp @@ -31,21 +31,39 @@ using namespace CoreML::MilStoragePython; PYBIND11_PLUGIN(libmilstoragepython) { py::module m("libmilstoragepython", "Library to create, access and edit CoreML blob files."); + // As we have both pybind for the same MilStoragePythonWriter class, we need to set module_local + // to avoid conflicts between coremltools and coremltools-internal. py::class_ blobStorageWriter(m, "_BlobStorageWriter", py::module_local()); blobStorageWriter.def(py::init(), py::arg("file_name"), py::arg("truncate_file") = true) + .def("write_int4_data", &MilStoragePythonWriter::write_int4_data) + .def("write_uint1_data", &MilStoragePythonWriter::write_uint1_data) + .def("write_uint2_data", &MilStoragePythonWriter::write_uint2_data) + .def("write_uint3_data", &MilStoragePythonWriter::write_uint3_data) + .def("write_uint4_data", &MilStoragePythonWriter::write_uint4_data) + .def("write_uint6_data", &MilStoragePythonWriter::write_uint6_data) .def("write_int8_data", &MilStoragePythonWriter::write_int8_data) .def("write_uint8_data", &MilStoragePythonWriter::write_uint8_data) .def("write_int16_data", &MilStoragePythonWriter::write_int16_data) .def("write_uint16_data", &MilStoragePythonWriter::write_uint16_data) + .def("write_int32_data", &MilStoragePythonWriter::write_int32_data) + .def("write_uint32_data", &MilStoragePythonWriter::write_uint32_data) .def("write_fp16_data", &MilStoragePythonWriter::write_fp16_data) .def("write_float_data", &MilStoragePythonWriter::write_float_data); py::class_ blobStorageReader(m, "_BlobStorageReader", py::module_local()); blobStorageReader.def(py::init()) + .def("read_int4_data", &MilStoragePythonReader::read_int4_data) + .def("read_uint1_data", &MilStoragePythonReader::read_uint1_data) + .def("read_uint2_data", &MilStoragePythonReader::read_uint2_data) + .def("read_uint3_data", &MilStoragePythonReader::read_uint3_data) + .def("read_uint4_data", &MilStoragePythonReader::read_uint4_data) + .def("read_uint6_data", &MilStoragePythonReader::read_uint6_data) .def("read_int8_data", &MilStoragePythonReader::read_int8_data) .def("read_uint8_data", &MilStoragePythonReader::read_uint8_data) .def("read_int16_data", &MilStoragePythonReader::read_int16_data) .def("read_uint16_data", &MilStoragePythonReader::read_uint16_data) + .def("read_int32_data", &MilStoragePythonReader::read_int32_data) + .def("read_uint32_data", &MilStoragePythonReader::read_uint32_data) .def("read_fp16_data", &MilStoragePythonReader::read_fp16_data) .def("read_float_data", &MilStoragePythonReader::read_float_data); diff --git a/mlmodel/CMakeLists.txt b/mlmodel/CMakeLists.txt index 6d7b7fa66..ef52f181b 100644 --- a/mlmodel/CMakeLists.txt +++ b/mlmodel/CMakeLists.txt @@ -117,6 +117,9 @@ add_library(mlmodel src/TreeEnsembleCommon.cpp src/Utils.cpp + src/MILBlob/SubByteTypeList.hpp + src/MILBlob/SubByteTypes.cpp + src/MILBlob/Fp8.cpp src/MILBlob/Blob/FileWriter.cpp src/MILBlob/Blob/MMapFileReader.cpp src/MILBlob/Blob/MMapFileReaderFactory.cpp diff --git a/mlmodel/build/format/CategoricalMapping.pb.h b/mlmodel/build/format/CategoricalMapping.pb.h index 44b34bf39..f08e618a1 100644 --- a/mlmodel/build/format/CategoricalMapping.pb.h +++ b/mlmodel/build/format/CategoricalMapping.pb.h @@ -110,6 +110,9 @@ extern SequenceFeatureTypeDefaultTypeInternal _SequenceFeatureType_default_insta class SizeRange; class SizeRangeDefaultTypeInternal; extern SizeRangeDefaultTypeInternal _SizeRange_default_instance_; +class StateFeatureType; +class StateFeatureTypeDefaultTypeInternal; +extern StateFeatureTypeDefaultTypeInternal _StateFeatureType_default_instance_; class StringFeatureType; class StringFeatureTypeDefaultTypeInternal; extern StringFeatureTypeDefaultTypeInternal _StringFeatureType_default_instance_; diff --git a/mlmodel/build/format/ClassConfidenceThresholding.pb.h b/mlmodel/build/format/ClassConfidenceThresholding.pb.h index 445ef7bf3..d311cacc7 100644 --- a/mlmodel/build/format/ClassConfidenceThresholding.pb.h +++ b/mlmodel/build/format/ClassConfidenceThresholding.pb.h @@ -110,6 +110,9 @@ extern SequenceFeatureTypeDefaultTypeInternal _SequenceFeatureType_default_insta class SizeRange; class SizeRangeDefaultTypeInternal; extern SizeRangeDefaultTypeInternal _SizeRange_default_instance_; +class StateFeatureType; +class StateFeatureTypeDefaultTypeInternal; +extern StateFeatureTypeDefaultTypeInternal _StateFeatureType_default_instance_; class StringFeatureType; class StringFeatureTypeDefaultTypeInternal; extern StringFeatureTypeDefaultTypeInternal _StringFeatureType_default_instance_; diff --git a/mlmodel/build/format/DataStructures.pb.h b/mlmodel/build/format/DataStructures.pb.h index 64a835f61..f7acf83ab 100644 --- a/mlmodel/build/format/DataStructures.pb.h +++ b/mlmodel/build/format/DataStructures.pb.h @@ -109,6 +109,9 @@ extern SequenceFeatureTypeDefaultTypeInternal _SequenceFeatureType_default_insta class SizeRange; class SizeRangeDefaultTypeInternal; extern SizeRangeDefaultTypeInternal _SizeRange_default_instance_; +class StateFeatureType; +class StateFeatureTypeDefaultTypeInternal; +extern StateFeatureTypeDefaultTypeInternal _StateFeatureType_default_instance_; class StringFeatureType; class StringFeatureTypeDefaultTypeInternal; extern StringFeatureTypeDefaultTypeInternal _StringFeatureType_default_instance_; @@ -1542,7 +1545,7 @@ inline ::google::protobuf::int64 Int64Range::minvalue() const { return minvalue_; } inline void Int64Range::set_minvalue(::google::protobuf::int64 value) { - + minvalue_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Int64Range.minValue) } @@ -1556,7 +1559,7 @@ inline ::google::protobuf::int64 Int64Range::maxvalue() const { return maxvalue_; } inline void Int64Range::set_maxvalue(::google::protobuf::int64 value) { - + maxvalue_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Int64Range.maxValue) } @@ -1608,7 +1611,7 @@ inline double DoubleRange::minvalue() const { return minvalue_; } inline void DoubleRange::set_minvalue(double value) { - + minvalue_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.DoubleRange.minValue) } @@ -1622,7 +1625,7 @@ inline double DoubleRange::maxvalue() const { return maxvalue_; } inline void DoubleRange::set_maxvalue(double value) { - + maxvalue_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.DoubleRange.maxValue) } @@ -1645,7 +1648,7 @@ inline const ::CoreML::Specification::FloatVector& PrecisionRecallCurve::precisi : *::CoreML::Specification::FloatVector::internal_default_instance(); } inline ::CoreML::Specification::FloatVector* PrecisionRecallCurve::mutable_precisionvalues() { - + if (precisionvalues_ == NULL) { precisionvalues_ = new ::CoreML::Specification::FloatVector; } @@ -1654,7 +1657,7 @@ inline ::CoreML::Specification::FloatVector* PrecisionRecallCurve::mutable_preci } inline ::CoreML::Specification::FloatVector* PrecisionRecallCurve::release_precisionvalues() { // @@protoc_insertion_point(field_release:CoreML.Specification.PrecisionRecallCurve.precisionValues) - + ::CoreML::Specification::FloatVector* temp = precisionvalues_; precisionvalues_ = NULL; return temp; @@ -1663,9 +1666,9 @@ inline void PrecisionRecallCurve::set_allocated_precisionvalues(::CoreML::Specif delete precisionvalues_; precisionvalues_ = precisionvalues; if (precisionvalues) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.PrecisionRecallCurve.precisionValues) } @@ -1684,7 +1687,7 @@ inline const ::CoreML::Specification::FloatVector& PrecisionRecallCurve::precisi : *::CoreML::Specification::FloatVector::internal_default_instance(); } inline ::CoreML::Specification::FloatVector* PrecisionRecallCurve::mutable_precisionconfidencethresholds() { - + if (precisionconfidencethresholds_ == NULL) { precisionconfidencethresholds_ = new ::CoreML::Specification::FloatVector; } @@ -1693,7 +1696,7 @@ inline ::CoreML::Specification::FloatVector* PrecisionRecallCurve::mutable_preci } inline ::CoreML::Specification::FloatVector* PrecisionRecallCurve::release_precisionconfidencethresholds() { // @@protoc_insertion_point(field_release:CoreML.Specification.PrecisionRecallCurve.precisionConfidenceThresholds) - + ::CoreML::Specification::FloatVector* temp = precisionconfidencethresholds_; precisionconfidencethresholds_ = NULL; return temp; @@ -1702,9 +1705,9 @@ inline void PrecisionRecallCurve::set_allocated_precisionconfidencethresholds(:: delete precisionconfidencethresholds_; precisionconfidencethresholds_ = precisionconfidencethresholds; if (precisionconfidencethresholds) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.PrecisionRecallCurve.precisionConfidenceThresholds) } @@ -1723,7 +1726,7 @@ inline const ::CoreML::Specification::FloatVector& PrecisionRecallCurve::recallv : *::CoreML::Specification::FloatVector::internal_default_instance(); } inline ::CoreML::Specification::FloatVector* PrecisionRecallCurve::mutable_recallvalues() { - + if (recallvalues_ == NULL) { recallvalues_ = new ::CoreML::Specification::FloatVector; } @@ -1732,7 +1735,7 @@ inline ::CoreML::Specification::FloatVector* PrecisionRecallCurve::mutable_recal } inline ::CoreML::Specification::FloatVector* PrecisionRecallCurve::release_recallvalues() { // @@protoc_insertion_point(field_release:CoreML.Specification.PrecisionRecallCurve.recallValues) - + ::CoreML::Specification::FloatVector* temp = recallvalues_; recallvalues_ = NULL; return temp; @@ -1741,9 +1744,9 @@ inline void PrecisionRecallCurve::set_allocated_recallvalues(::CoreML::Specifica delete recallvalues_; recallvalues_ = recallvalues; if (recallvalues) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.PrecisionRecallCurve.recallValues) } @@ -1762,7 +1765,7 @@ inline const ::CoreML::Specification::FloatVector& PrecisionRecallCurve::recallc : *::CoreML::Specification::FloatVector::internal_default_instance(); } inline ::CoreML::Specification::FloatVector* PrecisionRecallCurve::mutable_recallconfidencethresholds() { - + if (recallconfidencethresholds_ == NULL) { recallconfidencethresholds_ = new ::CoreML::Specification::FloatVector; } @@ -1771,7 +1774,7 @@ inline ::CoreML::Specification::FloatVector* PrecisionRecallCurve::mutable_recal } inline ::CoreML::Specification::FloatVector* PrecisionRecallCurve::release_recallconfidencethresholds() { // @@protoc_insertion_point(field_release:CoreML.Specification.PrecisionRecallCurve.recallConfidenceThresholds) - + ::CoreML::Specification::FloatVector* temp = recallconfidencethresholds_; recallconfidencethresholds_ = NULL; return temp; @@ -1780,9 +1783,9 @@ inline void PrecisionRecallCurve::set_allocated_recallconfidencethresholds(::Cor delete recallconfidencethresholds_; recallconfidencethresholds_ = recallconfidencethresholds; if (recallconfidencethresholds) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.PrecisionRecallCurve.recallConfidenceThresholds) } diff --git a/mlmodel/build/format/DictVectorizer.pb.h b/mlmodel/build/format/DictVectorizer.pb.h index 49db4e424..5fc6635b3 100644 --- a/mlmodel/build/format/DictVectorizer.pb.h +++ b/mlmodel/build/format/DictVectorizer.pb.h @@ -110,6 +110,9 @@ extern SequenceFeatureTypeDefaultTypeInternal _SequenceFeatureType_default_insta class SizeRange; class SizeRangeDefaultTypeInternal; extern SizeRangeDefaultTypeInternal _SizeRange_default_instance_; +class StateFeatureType; +class StateFeatureTypeDefaultTypeInternal; +extern StateFeatureTypeDefaultTypeInternal _StateFeatureType_default_instance_; class StringFeatureType; class StringFeatureTypeDefaultTypeInternal; extern StringFeatureTypeDefaultTypeInternal _StringFeatureType_default_instance_; diff --git a/mlmodel/build/format/FeatureTypes.pb.cc b/mlmodel/build/format/FeatureTypes.pb.cc index b5a9ea1cb..68965166a 100644 --- a/mlmodel/build/format/FeatureTypes.pb.cc +++ b/mlmodel/build/format/FeatureTypes.pb.cc @@ -59,6 +59,10 @@ class SequenceFeatureTypeDefaultTypeInternal : public ::google::protobuf::intern const ::CoreML::Specification::Int64FeatureType* int64type_; const ::CoreML::Specification::StringFeatureType* stringtype_; } _SequenceFeatureType_default_instance_; +class StateFeatureTypeDefaultTypeInternal : public ::google::protobuf::internal::ExplicitlyConstructed { + public: + const ::CoreML::Specification::ArrayFeatureType* arraytype_; +} _StateFeatureType_default_instance_; class FeatureTypeDefaultTypeInternal : public ::google::protobuf::internal::ExplicitlyConstructed { public: const ::CoreML::Specification::Int64FeatureType* int64type_; @@ -68,6 +72,7 @@ class FeatureTypeDefaultTypeInternal : public ::google::protobuf::internal::Expl const ::CoreML::Specification::ArrayFeatureType* multiarraytype_; const ::CoreML::Specification::DictionaryFeatureType* dictionarytype_; const ::CoreML::Specification::SequenceFeatureType* sequencetype_; + const ::CoreML::Specification::StateFeatureType* statetype_; } _FeatureType_default_instance_; namespace protobuf_FeatureTypes_2eproto { @@ -98,6 +103,7 @@ PROTOBUF_CONSTEXPR_VAR ::google::protobuf::internal::ParseTable const { NULL, NULL, 0, -1, -1, false }, { NULL, NULL, 0, -1, -1, false }, { NULL, NULL, 0, -1, -1, false }, + { NULL, NULL, 0, -1, -1, false }, }; @@ -116,6 +122,7 @@ void TableStruct::Shutdown() { _ArrayFeatureType_default_instance_.Shutdown(); _DictionaryFeatureType_default_instance_.Shutdown(); _SequenceFeatureType_default_instance_.Shutdown(); + _StateFeatureType_default_instance_.Shutdown(); _FeatureType_default_instance_.Shutdown(); } @@ -137,6 +144,7 @@ void TableStruct::InitDefaultsImpl() { _ArrayFeatureType_default_instance_.DefaultConstruct(); _DictionaryFeatureType_default_instance_.DefaultConstruct(); _SequenceFeatureType_default_instance_.DefaultConstruct(); + _StateFeatureType_default_instance_.DefaultConstruct(); _FeatureType_default_instance_.DefaultConstruct(); _ImageFeatureType_ImageSizeRange_default_instance_.get_mutable()->widthrange_ = const_cast< ::CoreML::Specification::SizeRange*>( ::CoreML::Specification::SizeRange::internal_default_instance()); @@ -870,7 +878,7 @@ ::google::protobuf::uint64 SizeRange::lowerbound() const { return lowerbound_; } void SizeRange::set_lowerbound(::google::protobuf::uint64 value) { - + lowerbound_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SizeRange.lowerBound) } @@ -884,7 +892,7 @@ ::google::protobuf::int64 SizeRange::upperbound() const { return upperbound_; } void SizeRange::set_upperbound(::google::protobuf::int64 value) { - + upperbound_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SizeRange.upperBound) } @@ -1115,7 +1123,7 @@ ::google::protobuf::uint64 ImageFeatureType_ImageSize::width() const { return width_; } void ImageFeatureType_ImageSize::set_width(::google::protobuf::uint64 value) { - + width_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ImageFeatureType.ImageSize.width) } @@ -1129,7 +1137,7 @@ ::google::protobuf::uint64 ImageFeatureType_ImageSize::height() const { return height_; } void ImageFeatureType_ImageSize::set_height(::google::protobuf::uint64 value) { - + height_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ImageFeatureType.ImageSize.height) } @@ -1594,7 +1602,7 @@ const ::CoreML::Specification::SizeRange& ImageFeatureType_ImageSizeRange::width : *::CoreML::Specification::SizeRange::internal_default_instance(); } ::CoreML::Specification::SizeRange* ImageFeatureType_ImageSizeRange::mutable_widthrange() { - + if (widthrange_ == NULL) { widthrange_ = new ::CoreML::Specification::SizeRange; } @@ -1603,7 +1611,7 @@ ::CoreML::Specification::SizeRange* ImageFeatureType_ImageSizeRange::mutable_wid } ::CoreML::Specification::SizeRange* ImageFeatureType_ImageSizeRange::release_widthrange() { // @@protoc_insertion_point(field_release:CoreML.Specification.ImageFeatureType.ImageSizeRange.widthRange) - + ::CoreML::Specification::SizeRange* temp = widthrange_; widthrange_ = NULL; return temp; @@ -1612,9 +1620,9 @@ void ImageFeatureType_ImageSizeRange::set_allocated_widthrange(::CoreML::Specifi delete widthrange_; widthrange_ = widthrange; if (widthrange) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ImageFeatureType.ImageSizeRange.widthRange) } @@ -1633,7 +1641,7 @@ const ::CoreML::Specification::SizeRange& ImageFeatureType_ImageSizeRange::heigh : *::CoreML::Specification::SizeRange::internal_default_instance(); } ::CoreML::Specification::SizeRange* ImageFeatureType_ImageSizeRange::mutable_heightrange() { - + if (heightrange_ == NULL) { heightrange_ = new ::CoreML::Specification::SizeRange; } @@ -1642,7 +1650,7 @@ ::CoreML::Specification::SizeRange* ImageFeatureType_ImageSizeRange::mutable_hei } ::CoreML::Specification::SizeRange* ImageFeatureType_ImageSizeRange::release_heightrange() { // @@protoc_insertion_point(field_release:CoreML.Specification.ImageFeatureType.ImageSizeRange.heightRange) - + ::CoreML::Specification::SizeRange* temp = heightrange_; heightrange_ = NULL; return temp; @@ -1651,9 +1659,9 @@ void ImageFeatureType_ImageSizeRange::set_allocated_heightrange(::CoreML::Specif delete heightrange_; heightrange_ = heightrange; if (heightrange) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ImageFeatureType.ImageSizeRange.heightRange) } @@ -2026,7 +2034,7 @@ ::google::protobuf::int64 ImageFeatureType::width() const { return width_; } void ImageFeatureType::set_width(::google::protobuf::int64 value) { - + width_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ImageFeatureType.width) } @@ -2040,7 +2048,7 @@ ::google::protobuf::int64 ImageFeatureType::height() const { return height_; } void ImageFeatureType::set_height(::google::protobuf::int64 value) { - + height_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ImageFeatureType.height) } @@ -2150,7 +2158,7 @@ ::CoreML::Specification::ImageFeatureType_ColorSpace ImageFeatureType::colorspac return static_cast< ::CoreML::Specification::ImageFeatureType_ColorSpace >(colorspace_); } void ImageFeatureType::set_colorspace(::CoreML::Specification::ImageFeatureType_ColorSpace value) { - + colorspace_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ImageFeatureType.colorSpace) } @@ -3343,7 +3351,7 @@ ::CoreML::Specification::ArrayFeatureType_ArrayDataType ArrayFeatureType::dataty return static_cast< ::CoreML::Specification::ArrayFeatureType_ArrayDataType >(datatype_); } void ArrayFeatureType::set_datatype(::CoreML::Specification::ArrayFeatureType_ArrayDataType value) { - + datatype_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ArrayFeatureType.dataType) } @@ -4323,7 +4331,7 @@ const ::CoreML::Specification::SizeRange& SequenceFeatureType::sizerange() const : *::CoreML::Specification::SizeRange::internal_default_instance(); } ::CoreML::Specification::SizeRange* SequenceFeatureType::mutable_sizerange() { - + if (sizerange_ == NULL) { sizerange_ = new ::CoreML::Specification::SizeRange; } @@ -4332,7 +4340,7 @@ ::CoreML::Specification::SizeRange* SequenceFeatureType::mutable_sizerange() { } ::CoreML::Specification::SizeRange* SequenceFeatureType::release_sizerange() { // @@protoc_insertion_point(field_release:CoreML.Specification.SequenceFeatureType.sizeRange) - + ::CoreML::Specification::SizeRange* temp = sizerange_; sizerange_ = NULL; return temp; @@ -4341,9 +4349,9 @@ void SequenceFeatureType::set_allocated_sizerange(::CoreML::Specification::SizeR delete sizerange_; sizerange_ = sizerange; if (sizerange) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.SequenceFeatureType.sizeRange) } @@ -4361,6 +4369,283 @@ SequenceFeatureType::TypeCase SequenceFeatureType::Type_case() const { // =================================================================== +#if !defined(_MSC_VER) || _MSC_VER >= 1900 +const int StateFeatureType::kArrayTypeFieldNumber; +#endif // !defined(_MSC_VER) || _MSC_VER >= 1900 + +StateFeatureType::StateFeatureType() + : ::google::protobuf::MessageLite(), _internal_metadata_(NULL) { + if (GOOGLE_PREDICT_TRUE(this != internal_default_instance())) { + protobuf_FeatureTypes_2eproto::InitDefaults(); + } + SharedCtor(); + // @@protoc_insertion_point(constructor:CoreML.Specification.StateFeatureType) +} +StateFeatureType::StateFeatureType(const StateFeatureType& from) + : ::google::protobuf::MessageLite(), + _internal_metadata_(NULL), + _cached_size_(0) { + _internal_metadata_.MergeFrom(from._internal_metadata_); + clear_has_Type(); + switch (from.Type_case()) { + case kArrayType: { + mutable_arraytype()->::CoreML::Specification::ArrayFeatureType::MergeFrom(from.arraytype()); + break; + } + case TYPE_NOT_SET: { + break; + } + } + // @@protoc_insertion_point(copy_constructor:CoreML.Specification.StateFeatureType) +} + +void StateFeatureType::SharedCtor() { + clear_has_Type(); + _cached_size_ = 0; +} + +StateFeatureType::~StateFeatureType() { + // @@protoc_insertion_point(destructor:CoreML.Specification.StateFeatureType) + SharedDtor(); +} + +void StateFeatureType::SharedDtor() { + if (has_Type()) { + clear_Type(); + } +} + +void StateFeatureType::SetCachedSize(int size) const { + GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN(); + _cached_size_ = size; + GOOGLE_SAFE_CONCURRENT_WRITES_END(); +} +const StateFeatureType& StateFeatureType::default_instance() { + protobuf_FeatureTypes_2eproto::InitDefaults(); + return *internal_default_instance(); +} + +StateFeatureType* StateFeatureType::New(::google::protobuf::Arena* arena) const { + StateFeatureType* n = new StateFeatureType; + if (arena != NULL) { + arena->Own(n); + } + return n; +} + +void StateFeatureType::clear_Type() { +// @@protoc_insertion_point(one_of_clear_start:CoreML.Specification.StateFeatureType) + switch (Type_case()) { + case kArrayType: { + delete Type_.arraytype_; + break; + } + case TYPE_NOT_SET: { + break; + } + } + _oneof_case_[0] = TYPE_NOT_SET; +} + + +void StateFeatureType::Clear() { +// @@protoc_insertion_point(message_clear_start:CoreML.Specification.StateFeatureType) + clear_Type(); +} + +bool StateFeatureType::MergePartialFromCodedStream( + ::google::protobuf::io::CodedInputStream* input) { +#define DO_(EXPRESSION) if (!GOOGLE_PREDICT_TRUE(EXPRESSION)) goto failure + ::google::protobuf::uint32 tag; + // @@protoc_insertion_point(parse_start:CoreML.Specification.StateFeatureType) + for (;;) { + ::std::pair< ::google::protobuf::uint32, bool> p = input->ReadTagWithCutoffNoLastTag(127u); + tag = p.first; + if (!p.second) goto handle_unusual; + switch (::google::protobuf::internal::WireFormatLite::GetTagFieldNumber(tag)) { + // .CoreML.Specification.ArrayFeatureType arrayType = 1; + case 1: { + if (static_cast< ::google::protobuf::uint8>(tag) == + static_cast< ::google::protobuf::uint8>(10u)) { + DO_(::google::protobuf::internal::WireFormatLite::ReadMessageNoVirtual( + input, mutable_arraytype())); + } else { + goto handle_unusual; + } + break; + } + + default: { + handle_unusual: + if (tag == 0 || + ::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_END_GROUP) { + goto success; + } + DO_(::google::protobuf::internal::WireFormatLite::SkipField(input, tag)); + break; + } + } + } +success: + // @@protoc_insertion_point(parse_success:CoreML.Specification.StateFeatureType) + return true; +failure: + // @@protoc_insertion_point(parse_failure:CoreML.Specification.StateFeatureType) + return false; +#undef DO_ +} + +void StateFeatureType::SerializeWithCachedSizes( + ::google::protobuf::io::CodedOutputStream* output) const { + // @@protoc_insertion_point(serialize_start:CoreML.Specification.StateFeatureType) + ::google::protobuf::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // .CoreML.Specification.ArrayFeatureType arrayType = 1; + if (has_arraytype()) { + ::google::protobuf::internal::WireFormatLite::WriteMessage( + 1, *Type_.arraytype_, output); + } + + // @@protoc_insertion_point(serialize_end:CoreML.Specification.StateFeatureType) +} + +size_t StateFeatureType::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:CoreML.Specification.StateFeatureType) + size_t total_size = 0; + + switch (Type_case()) { + // .CoreML.Specification.ArrayFeatureType arrayType = 1; + case kArrayType: { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::MessageSizeNoVirtual( + *Type_.arraytype_); + break; + } + case TYPE_NOT_SET: { + break; + } + } + int cached_size = ::google::protobuf::internal::ToCachedSize(total_size); + GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN(); + _cached_size_ = cached_size; + GOOGLE_SAFE_CONCURRENT_WRITES_END(); + return total_size; +} + +void StateFeatureType::CheckTypeAndMergeFrom( + const ::google::protobuf::MessageLite& from) { + MergeFrom(*::google::protobuf::down_cast(&from)); +} + +void StateFeatureType::MergeFrom(const StateFeatureType& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:CoreML.Specification.StateFeatureType) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom(from._internal_metadata_); + ::google::protobuf::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + switch (from.Type_case()) { + case kArrayType: { + mutable_arraytype()->::CoreML::Specification::ArrayFeatureType::MergeFrom(from.arraytype()); + break; + } + case TYPE_NOT_SET: { + break; + } + } +} + +void StateFeatureType::CopyFrom(const StateFeatureType& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:CoreML.Specification.StateFeatureType) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool StateFeatureType::IsInitialized() const { + return true; +} + +void StateFeatureType::Swap(StateFeatureType* other) { + if (other == this) return; + InternalSwap(other); +} +void StateFeatureType::InternalSwap(StateFeatureType* other) { + std::swap(Type_, other->Type_); + std::swap(_oneof_case_[0], other->_oneof_case_[0]); + std::swap(_cached_size_, other->_cached_size_); +} + +::std::string StateFeatureType::GetTypeName() const { + return "CoreML.Specification.StateFeatureType"; +} + +#if PROTOBUF_INLINE_NOT_IN_HEADERS +// StateFeatureType + +// .CoreML.Specification.ArrayFeatureType arrayType = 1; +bool StateFeatureType::has_arraytype() const { + return Type_case() == kArrayType; +} +void StateFeatureType::set_has_arraytype() { + _oneof_case_[0] = kArrayType; +} +void StateFeatureType::clear_arraytype() { + if (has_arraytype()) { + delete Type_.arraytype_; + clear_has_Type(); + } +} + const ::CoreML::Specification::ArrayFeatureType& StateFeatureType::arraytype() const { + // @@protoc_insertion_point(field_get:CoreML.Specification.StateFeatureType.arrayType) + return has_arraytype() + ? *Type_.arraytype_ + : ::CoreML::Specification::ArrayFeatureType::default_instance(); +} +::CoreML::Specification::ArrayFeatureType* StateFeatureType::mutable_arraytype() { + if (!has_arraytype()) { + clear_Type(); + set_has_arraytype(); + Type_.arraytype_ = new ::CoreML::Specification::ArrayFeatureType; + } + // @@protoc_insertion_point(field_mutable:CoreML.Specification.StateFeatureType.arrayType) + return Type_.arraytype_; +} +::CoreML::Specification::ArrayFeatureType* StateFeatureType::release_arraytype() { + // @@protoc_insertion_point(field_release:CoreML.Specification.StateFeatureType.arrayType) + if (has_arraytype()) { + clear_has_Type(); + ::CoreML::Specification::ArrayFeatureType* temp = Type_.arraytype_; + Type_.arraytype_ = NULL; + return temp; + } else { + return NULL; + } +} +void StateFeatureType::set_allocated_arraytype(::CoreML::Specification::ArrayFeatureType* arraytype) { + clear_Type(); + if (arraytype) { + set_has_arraytype(); + Type_.arraytype_ = arraytype; + } + // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.StateFeatureType.arrayType) +} + +bool StateFeatureType::has_Type() const { + return Type_case() != TYPE_NOT_SET; +} +void StateFeatureType::clear_has_Type() { + _oneof_case_[0] = TYPE_NOT_SET; +} +StateFeatureType::TypeCase StateFeatureType::Type_case() const { + return StateFeatureType::TypeCase(_oneof_case_[0]); +} +#endif // PROTOBUF_INLINE_NOT_IN_HEADERS + +// =================================================================== + #if !defined(_MSC_VER) || _MSC_VER >= 1900 const int FeatureType::kInt64TypeFieldNumber; const int FeatureType::kDoubleTypeFieldNumber; @@ -4369,6 +4654,7 @@ const int FeatureType::kImageTypeFieldNumber; const int FeatureType::kMultiArrayTypeFieldNumber; const int FeatureType::kDictionaryTypeFieldNumber; const int FeatureType::kSequenceTypeFieldNumber; +const int FeatureType::kStateTypeFieldNumber; const int FeatureType::kIsOptionalFieldNumber; #endif // !defined(_MSC_VER) || _MSC_VER >= 1900 @@ -4416,6 +4702,10 @@ FeatureType::FeatureType(const FeatureType& from) mutable_sequencetype()->::CoreML::Specification::SequenceFeatureType::MergeFrom(from.sequencetype()); break; } + case kStateType: { + mutable_statetype()->::CoreML::Specification::StateFeatureType::MergeFrom(from.statetype()); + break; + } case TYPE_NOT_SET: { break; } @@ -4489,6 +4779,10 @@ void FeatureType::clear_Type() { delete Type_.sequencetype_; break; } + case kStateType: { + delete Type_.statetype_; + break; + } case TYPE_NOT_SET: { break; } @@ -4597,6 +4891,18 @@ bool FeatureType::MergePartialFromCodedStream( break; } + // .CoreML.Specification.StateFeatureType stateType = 8; + case 8: { + if (static_cast< ::google::protobuf::uint8>(tag) == + static_cast< ::google::protobuf::uint8>(66u)) { + DO_(::google::protobuf::internal::WireFormatLite::ReadMessageNoVirtual( + input, mutable_statetype())); + } else { + goto handle_unusual; + } + break; + } + // bool isOptional = 1000; case 1000: { if (static_cast< ::google::protobuf::uint8>(tag) == @@ -4680,6 +4986,12 @@ void FeatureType::SerializeWithCachedSizes( 7, *Type_.sequencetype_, output); } + // .CoreML.Specification.StateFeatureType stateType = 8; + if (has_statetype()) { + ::google::protobuf::internal::WireFormatLite::WriteMessage( + 8, *Type_.statetype_, output); + } + // bool isOptional = 1000; if (this->isoptional() != 0) { ::google::protobuf::internal::WireFormatLite::WriteBool(1000, this->isoptional(), output); @@ -4747,6 +5059,13 @@ size_t FeatureType::ByteSizeLong() const { *Type_.sequencetype_); break; } + // .CoreML.Specification.StateFeatureType stateType = 8; + case kStateType: { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::MessageSizeNoVirtual( + *Type_.statetype_); + break; + } case TYPE_NOT_SET: { break; } @@ -4802,6 +5121,10 @@ void FeatureType::MergeFrom(const FeatureType& from) { mutable_sequencetype()->::CoreML::Specification::SequenceFeatureType::MergeFrom(from.sequencetype()); break; } + case kStateType: { + mutable_statetype()->::CoreML::Specification::StateFeatureType::MergeFrom(from.statetype()); + break; + } case TYPE_NOT_SET: { break; } @@ -5173,6 +5496,54 @@ void FeatureType::set_allocated_sequencetype(::CoreML::Specification::SequenceFe // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.FeatureType.sequenceType) } +// .CoreML.Specification.StateFeatureType stateType = 8; +bool FeatureType::has_statetype() const { + return Type_case() == kStateType; +} +void FeatureType::set_has_statetype() { + _oneof_case_[0] = kStateType; +} +void FeatureType::clear_statetype() { + if (has_statetype()) { + delete Type_.statetype_; + clear_has_Type(); + } +} + const ::CoreML::Specification::StateFeatureType& FeatureType::statetype() const { + // @@protoc_insertion_point(field_get:CoreML.Specification.FeatureType.stateType) + return has_statetype() + ? *Type_.statetype_ + : ::CoreML::Specification::StateFeatureType::default_instance(); +} +::CoreML::Specification::StateFeatureType* FeatureType::mutable_statetype() { + if (!has_statetype()) { + clear_Type(); + set_has_statetype(); + Type_.statetype_ = new ::CoreML::Specification::StateFeatureType; + } + // @@protoc_insertion_point(field_mutable:CoreML.Specification.FeatureType.stateType) + return Type_.statetype_; +} +::CoreML::Specification::StateFeatureType* FeatureType::release_statetype() { + // @@protoc_insertion_point(field_release:CoreML.Specification.FeatureType.stateType) + if (has_statetype()) { + clear_has_Type(); + ::CoreML::Specification::StateFeatureType* temp = Type_.statetype_; + Type_.statetype_ = NULL; + return temp; + } else { + return NULL; + } +} +void FeatureType::set_allocated_statetype(::CoreML::Specification::StateFeatureType* statetype) { + clear_Type(); + if (statetype) { + set_has_statetype(); + Type_.statetype_ = statetype; + } + // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.FeatureType.stateType) +} + // bool isOptional = 1000; void FeatureType::clear_isoptional() { isoptional_ = false; @@ -5182,7 +5553,7 @@ bool FeatureType::isoptional() const { return isoptional_; } void FeatureType::set_isoptional(bool value) { - + isoptional_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.FeatureType.isOptional) } diff --git a/mlmodel/build/format/FeatureTypes.pb.h b/mlmodel/build/format/FeatureTypes.pb.h index afb8c9943..49a217e0f 100644 --- a/mlmodel/build/format/FeatureTypes.pb.h +++ b/mlmodel/build/format/FeatureTypes.pb.h @@ -74,6 +74,9 @@ extern SequenceFeatureTypeDefaultTypeInternal _SequenceFeatureType_default_insta class SizeRange; class SizeRangeDefaultTypeInternal; extern SizeRangeDefaultTypeInternal _SizeRange_default_instance_; +class StateFeatureType; +class StateFeatureTypeDefaultTypeInternal; +extern StateFeatureTypeDefaultTypeInternal _StateFeatureType_default_instance_; class StringFeatureType; class StringFeatureTypeDefaultTypeInternal; extern StringFeatureTypeDefaultTypeInternal _StringFeatureType_default_instance_; @@ -1524,6 +1527,103 @@ class SequenceFeatureType : public ::google::protobuf::MessageLite /* @@protoc_i }; // ------------------------------------------------------------------- +class StateFeatureType : public ::google::protobuf::MessageLite /* @@protoc_insertion_point(class_definition:CoreML.Specification.StateFeatureType) */ { + public: + StateFeatureType(); + virtual ~StateFeatureType(); + + StateFeatureType(const StateFeatureType& from); + + inline StateFeatureType& operator=(const StateFeatureType& from) { + CopyFrom(from); + return *this; + } + + static const StateFeatureType& default_instance(); + + enum TypeCase { + kArrayType = 1, + TYPE_NOT_SET = 0, + }; + + static inline const StateFeatureType* internal_default_instance() { + return reinterpret_cast( + &_StateFeatureType_default_instance_); + } + static PROTOBUF_CONSTEXPR int const kIndexInFileMessages = + 14; + + void Swap(StateFeatureType* other); + + // implements Message ---------------------------------------------- + + inline StateFeatureType* New() const PROTOBUF_FINAL { return New(NULL); } + + StateFeatureType* New(::google::protobuf::Arena* arena) const PROTOBUF_FINAL; + void CheckTypeAndMergeFrom(const ::google::protobuf::MessageLite& from) + PROTOBUF_FINAL; + void CopyFrom(const StateFeatureType& from); + void MergeFrom(const StateFeatureType& from); + void Clear() PROTOBUF_FINAL; + bool IsInitialized() const PROTOBUF_FINAL; + + size_t ByteSizeLong() const PROTOBUF_FINAL; + bool MergePartialFromCodedStream( + ::google::protobuf::io::CodedInputStream* input) PROTOBUF_FINAL; + void SerializeWithCachedSizes( + ::google::protobuf::io::CodedOutputStream* output) const PROTOBUF_FINAL; + void DiscardUnknownFields(); + int GetCachedSize() const PROTOBUF_FINAL { return _cached_size_; } + private: + void SharedCtor(); + void SharedDtor(); + void SetCachedSize(int size) const; + void InternalSwap(StateFeatureType* other); + private: + inline ::google::protobuf::Arena* GetArenaNoVirtual() const { + return NULL; + } + inline void* MaybeArenaPtr() const { + return NULL; + } + public: + + ::std::string GetTypeName() const PROTOBUF_FINAL; + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + // .CoreML.Specification.ArrayFeatureType arrayType = 1; + bool has_arraytype() const; + void clear_arraytype(); + static const int kArrayTypeFieldNumber = 1; + const ::CoreML::Specification::ArrayFeatureType& arraytype() const; + ::CoreML::Specification::ArrayFeatureType* mutable_arraytype(); + ::CoreML::Specification::ArrayFeatureType* release_arraytype(); + void set_allocated_arraytype(::CoreML::Specification::ArrayFeatureType* arraytype); + + TypeCase Type_case() const; + // @@protoc_insertion_point(class_scope:CoreML.Specification.StateFeatureType) + private: + void set_has_arraytype(); + + inline bool has_Type() const; + void clear_Type(); + inline void clear_has_Type(); + + ::google::protobuf::internal::InternalMetadataWithArenaLite _internal_metadata_; + union TypeUnion { + TypeUnion() {} + ::CoreML::Specification::ArrayFeatureType* arraytype_; + } Type_; + mutable int _cached_size_; + ::google::protobuf::uint32 _oneof_case_[1]; + + friend struct protobuf_FeatureTypes_2eproto::TableStruct; +}; +// ------------------------------------------------------------------- + class FeatureType : public ::google::protobuf::MessageLite /* @@protoc_insertion_point(class_definition:CoreML.Specification.FeatureType) */ { public: FeatureType(); @@ -1546,6 +1646,7 @@ class FeatureType : public ::google::protobuf::MessageLite /* @@protoc_insertion kMultiArrayType = 5, kDictionaryType = 6, kSequenceType = 7, + kStateType = 8, TYPE_NOT_SET = 0, }; @@ -1554,7 +1655,7 @@ class FeatureType : public ::google::protobuf::MessageLite /* @@protoc_insertion &_FeatureType_default_instance_); } static PROTOBUF_CONSTEXPR int const kIndexInFileMessages = - 14; + 15; void Swap(FeatureType* other); @@ -1666,6 +1767,15 @@ class FeatureType : public ::google::protobuf::MessageLite /* @@protoc_insertion ::CoreML::Specification::SequenceFeatureType* release_sequencetype(); void set_allocated_sequencetype(::CoreML::Specification::SequenceFeatureType* sequencetype); + // .CoreML.Specification.StateFeatureType stateType = 8; + bool has_statetype() const; + void clear_statetype(); + static const int kStateTypeFieldNumber = 8; + const ::CoreML::Specification::StateFeatureType& statetype() const; + ::CoreML::Specification::StateFeatureType* mutable_statetype(); + ::CoreML::Specification::StateFeatureType* release_statetype(); + void set_allocated_statetype(::CoreML::Specification::StateFeatureType* statetype); + TypeCase Type_case() const; // @@protoc_insertion_point(class_scope:CoreML.Specification.FeatureType) private: @@ -1676,6 +1786,7 @@ class FeatureType : public ::google::protobuf::MessageLite /* @@protoc_insertion void set_has_multiarraytype(); void set_has_dictionarytype(); void set_has_sequencetype(); + void set_has_statetype(); inline bool has_Type() const; void clear_Type(); @@ -1692,6 +1803,7 @@ class FeatureType : public ::google::protobuf::MessageLite /* @@protoc_insertion ::CoreML::Specification::ArrayFeatureType* multiarraytype_; ::CoreML::Specification::DictionaryFeatureType* dictionarytype_; ::CoreML::Specification::SequenceFeatureType* sequencetype_; + ::CoreML::Specification::StateFeatureType* statetype_; } Type_; mutable int _cached_size_; ::google::protobuf::uint32 _oneof_case_[1]; @@ -1727,7 +1839,7 @@ inline ::google::protobuf::uint64 SizeRange::lowerbound() const { return lowerbound_; } inline void SizeRange::set_lowerbound(::google::protobuf::uint64 value) { - + lowerbound_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SizeRange.lowerBound) } @@ -1741,7 +1853,7 @@ inline ::google::protobuf::int64 SizeRange::upperbound() const { return upperbound_; } inline void SizeRange::set_upperbound(::google::protobuf::int64 value) { - + upperbound_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SizeRange.upperBound) } @@ -1759,7 +1871,7 @@ inline ::google::protobuf::uint64 ImageFeatureType_ImageSize::width() const { return width_; } inline void ImageFeatureType_ImageSize::set_width(::google::protobuf::uint64 value) { - + width_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ImageFeatureType.ImageSize.width) } @@ -1773,7 +1885,7 @@ inline ::google::protobuf::uint64 ImageFeatureType_ImageSize::height() const { return height_; } inline void ImageFeatureType_ImageSize::set_height(::google::protobuf::uint64 value) { - + height_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ImageFeatureType.ImageSize.height) } @@ -1830,7 +1942,7 @@ inline const ::CoreML::Specification::SizeRange& ImageFeatureType_ImageSizeRange : *::CoreML::Specification::SizeRange::internal_default_instance(); } inline ::CoreML::Specification::SizeRange* ImageFeatureType_ImageSizeRange::mutable_widthrange() { - + if (widthrange_ == NULL) { widthrange_ = new ::CoreML::Specification::SizeRange; } @@ -1839,7 +1951,7 @@ inline ::CoreML::Specification::SizeRange* ImageFeatureType_ImageSizeRange::muta } inline ::CoreML::Specification::SizeRange* ImageFeatureType_ImageSizeRange::release_widthrange() { // @@protoc_insertion_point(field_release:CoreML.Specification.ImageFeatureType.ImageSizeRange.widthRange) - + ::CoreML::Specification::SizeRange* temp = widthrange_; widthrange_ = NULL; return temp; @@ -1848,9 +1960,9 @@ inline void ImageFeatureType_ImageSizeRange::set_allocated_widthrange(::CoreML:: delete widthrange_; widthrange_ = widthrange; if (widthrange) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ImageFeatureType.ImageSizeRange.widthRange) } @@ -1869,7 +1981,7 @@ inline const ::CoreML::Specification::SizeRange& ImageFeatureType_ImageSizeRange : *::CoreML::Specification::SizeRange::internal_default_instance(); } inline ::CoreML::Specification::SizeRange* ImageFeatureType_ImageSizeRange::mutable_heightrange() { - + if (heightrange_ == NULL) { heightrange_ = new ::CoreML::Specification::SizeRange; } @@ -1878,7 +1990,7 @@ inline ::CoreML::Specification::SizeRange* ImageFeatureType_ImageSizeRange::muta } inline ::CoreML::Specification::SizeRange* ImageFeatureType_ImageSizeRange::release_heightrange() { // @@protoc_insertion_point(field_release:CoreML.Specification.ImageFeatureType.ImageSizeRange.heightRange) - + ::CoreML::Specification::SizeRange* temp = heightrange_; heightrange_ = NULL; return temp; @@ -1887,9 +1999,9 @@ inline void ImageFeatureType_ImageSizeRange::set_allocated_heightrange(::CoreML: delete heightrange_; heightrange_ = heightrange; if (heightrange) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ImageFeatureType.ImageSizeRange.heightRange) } @@ -1907,7 +2019,7 @@ inline ::google::protobuf::int64 ImageFeatureType::width() const { return width_; } inline void ImageFeatureType::set_width(::google::protobuf::int64 value) { - + width_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ImageFeatureType.width) } @@ -1921,7 +2033,7 @@ inline ::google::protobuf::int64 ImageFeatureType::height() const { return height_; } inline void ImageFeatureType::set_height(::google::protobuf::int64 value) { - + height_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ImageFeatureType.height) } @@ -2031,7 +2143,7 @@ inline ::CoreML::Specification::ImageFeatureType_ColorSpace ImageFeatureType::co return static_cast< ::CoreML::Specification::ImageFeatureType_ColorSpace >(colorspace_); } inline void ImageFeatureType::set_colorspace(::CoreML::Specification::ImageFeatureType_ColorSpace value) { - + colorspace_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ImageFeatureType.colorSpace) } @@ -2190,7 +2302,7 @@ inline ::CoreML::Specification::ArrayFeatureType_ArrayDataType ArrayFeatureType: return static_cast< ::CoreML::Specification::ArrayFeatureType_ArrayDataType >(datatype_); } inline void ArrayFeatureType::set_datatype(::CoreML::Specification::ArrayFeatureType_ArrayDataType value) { - + datatype_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ArrayFeatureType.dataType) } @@ -2619,7 +2731,7 @@ inline const ::CoreML::Specification::SizeRange& SequenceFeatureType::sizerange( : *::CoreML::Specification::SizeRange::internal_default_instance(); } inline ::CoreML::Specification::SizeRange* SequenceFeatureType::mutable_sizerange() { - + if (sizerange_ == NULL) { sizerange_ = new ::CoreML::Specification::SizeRange; } @@ -2628,7 +2740,7 @@ inline ::CoreML::Specification::SizeRange* SequenceFeatureType::mutable_sizerang } inline ::CoreML::Specification::SizeRange* SequenceFeatureType::release_sizerange() { // @@protoc_insertion_point(field_release:CoreML.Specification.SequenceFeatureType.sizeRange) - + ::CoreML::Specification::SizeRange* temp = sizerange_; sizerange_ = NULL; return temp; @@ -2637,9 +2749,9 @@ inline void SequenceFeatureType::set_allocated_sizerange(::CoreML::Specification delete sizerange_; sizerange_ = sizerange; if (sizerange) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.SequenceFeatureType.sizeRange) } @@ -2655,6 +2767,67 @@ inline SequenceFeatureType::TypeCase SequenceFeatureType::Type_case() const { } // ------------------------------------------------------------------- +// StateFeatureType + +// .CoreML.Specification.ArrayFeatureType arrayType = 1; +inline bool StateFeatureType::has_arraytype() const { + return Type_case() == kArrayType; +} +inline void StateFeatureType::set_has_arraytype() { + _oneof_case_[0] = kArrayType; +} +inline void StateFeatureType::clear_arraytype() { + if (has_arraytype()) { + delete Type_.arraytype_; + clear_has_Type(); + } +} +inline const ::CoreML::Specification::ArrayFeatureType& StateFeatureType::arraytype() const { + // @@protoc_insertion_point(field_get:CoreML.Specification.StateFeatureType.arrayType) + return has_arraytype() + ? *Type_.arraytype_ + : ::CoreML::Specification::ArrayFeatureType::default_instance(); +} +inline ::CoreML::Specification::ArrayFeatureType* StateFeatureType::mutable_arraytype() { + if (!has_arraytype()) { + clear_Type(); + set_has_arraytype(); + Type_.arraytype_ = new ::CoreML::Specification::ArrayFeatureType; + } + // @@protoc_insertion_point(field_mutable:CoreML.Specification.StateFeatureType.arrayType) + return Type_.arraytype_; +} +inline ::CoreML::Specification::ArrayFeatureType* StateFeatureType::release_arraytype() { + // @@protoc_insertion_point(field_release:CoreML.Specification.StateFeatureType.arrayType) + if (has_arraytype()) { + clear_has_Type(); + ::CoreML::Specification::ArrayFeatureType* temp = Type_.arraytype_; + Type_.arraytype_ = NULL; + return temp; + } else { + return NULL; + } +} +inline void StateFeatureType::set_allocated_arraytype(::CoreML::Specification::ArrayFeatureType* arraytype) { + clear_Type(); + if (arraytype) { + set_has_arraytype(); + Type_.arraytype_ = arraytype; + } + // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.StateFeatureType.arrayType) +} + +inline bool StateFeatureType::has_Type() const { + return Type_case() != TYPE_NOT_SET; +} +inline void StateFeatureType::clear_has_Type() { + _oneof_case_[0] = TYPE_NOT_SET; +} +inline StateFeatureType::TypeCase StateFeatureType::Type_case() const { + return StateFeatureType::TypeCase(_oneof_case_[0]); +} +// ------------------------------------------------------------------- + // FeatureType // .CoreML.Specification.Int64FeatureType int64Type = 1; @@ -2993,6 +3166,54 @@ inline void FeatureType::set_allocated_sequencetype(::CoreML::Specification::Seq // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.FeatureType.sequenceType) } +// .CoreML.Specification.StateFeatureType stateType = 8; +inline bool FeatureType::has_statetype() const { + return Type_case() == kStateType; +} +inline void FeatureType::set_has_statetype() { + _oneof_case_[0] = kStateType; +} +inline void FeatureType::clear_statetype() { + if (has_statetype()) { + delete Type_.statetype_; + clear_has_Type(); + } +} +inline const ::CoreML::Specification::StateFeatureType& FeatureType::statetype() const { + // @@protoc_insertion_point(field_get:CoreML.Specification.FeatureType.stateType) + return has_statetype() + ? *Type_.statetype_ + : ::CoreML::Specification::StateFeatureType::default_instance(); +} +inline ::CoreML::Specification::StateFeatureType* FeatureType::mutable_statetype() { + if (!has_statetype()) { + clear_Type(); + set_has_statetype(); + Type_.statetype_ = new ::CoreML::Specification::StateFeatureType; + } + // @@protoc_insertion_point(field_mutable:CoreML.Specification.FeatureType.stateType) + return Type_.statetype_; +} +inline ::CoreML::Specification::StateFeatureType* FeatureType::release_statetype() { + // @@protoc_insertion_point(field_release:CoreML.Specification.FeatureType.stateType) + if (has_statetype()) { + clear_has_Type(); + ::CoreML::Specification::StateFeatureType* temp = Type_.statetype_; + Type_.statetype_ = NULL; + return temp; + } else { + return NULL; + } +} +inline void FeatureType::set_allocated_statetype(::CoreML::Specification::StateFeatureType* statetype) { + clear_Type(); + if (statetype) { + set_has_statetype(); + Type_.statetype_ = statetype; + } + // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.FeatureType.stateType) +} + // bool isOptional = 1000; inline void FeatureType::clear_isoptional() { isoptional_ = false; @@ -3002,7 +3223,7 @@ inline bool FeatureType::isoptional() const { return isoptional_; } inline void FeatureType::set_isoptional(bool value) { - + isoptional_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.FeatureType.isOptional) } @@ -3045,6 +3266,8 @@ inline FeatureType::TypeCase FeatureType::Type_case() const { // ------------------------------------------------------------------- +// ------------------------------------------------------------------- + // @@protoc_insertion_point(namespace_scope) diff --git a/mlmodel/build/format/FeatureTypes_enums.h b/mlmodel/build/format/FeatureTypes_enums.h index 3095f492d..09ba0aff6 100644 --- a/mlmodel/build/format/FeatureTypes_enums.h +++ b/mlmodel/build/format/FeatureTypes_enums.h @@ -114,6 +114,22 @@ static const char * MLSequenceFeatureTypeType_Name(MLSequenceFeatureTypeType x) return "INVALID"; } +enum MLStateFeatureTypeType: int { + MLStateFeatureTypeType_arrayType = 1, + MLStateFeatureTypeType_NOT_SET = 0, +}; + +__attribute__((__unused__)) +static const char * MLStateFeatureTypeType_Name(MLStateFeatureTypeType x) { + switch (x) { + case MLStateFeatureTypeType_arrayType: + return "MLStateFeatureTypeType_arrayType"; + case MLStateFeatureTypeType_NOT_SET: + return "INVALID"; + } + return "INVALID"; +} + enum MLFeatureTypeType: int { MLFeatureTypeType_int64Type = 1, MLFeatureTypeType_doubleType = 2, @@ -122,6 +138,7 @@ enum MLFeatureTypeType: int { MLFeatureTypeType_multiArrayType = 5, MLFeatureTypeType_dictionaryType = 6, MLFeatureTypeType_sequenceType = 7, + MLFeatureTypeType_stateType = 8, MLFeatureTypeType_NOT_SET = 0, }; @@ -142,6 +159,8 @@ static const char * MLFeatureTypeType_Name(MLFeatureTypeType x) { return "MLFeatureTypeType_dictionaryType"; case MLFeatureTypeType_sequenceType: return "MLFeatureTypeType_sequenceType"; + case MLFeatureTypeType_stateType: + return "MLFeatureTypeType_stateType"; case MLFeatureTypeType_NOT_SET: return "INVALID"; } diff --git a/mlmodel/build/format/GLMClassifier.pb.h b/mlmodel/build/format/GLMClassifier.pb.h index b5703a55f..fab03b9e3 100644 --- a/mlmodel/build/format/GLMClassifier.pb.h +++ b/mlmodel/build/format/GLMClassifier.pb.h @@ -114,6 +114,9 @@ extern SequenceFeatureTypeDefaultTypeInternal _SequenceFeatureType_default_insta class SizeRange; class SizeRangeDefaultTypeInternal; extern SizeRangeDefaultTypeInternal _SizeRange_default_instance_; +class StateFeatureType; +class StateFeatureTypeDefaultTypeInternal; +extern StateFeatureTypeDefaultTypeInternal _StateFeatureType_default_instance_; class StringFeatureType; class StringFeatureTypeDefaultTypeInternal; extern StringFeatureTypeDefaultTypeInternal _StringFeatureType_default_instance_; @@ -552,7 +555,7 @@ inline ::CoreML::Specification::GLMClassifier_PostEvaluationTransform GLMClassif return static_cast< ::CoreML::Specification::GLMClassifier_PostEvaluationTransform >(postevaluationtransform_); } inline void GLMClassifier::set_postevaluationtransform(::CoreML::Specification::GLMClassifier_PostEvaluationTransform value) { - + postevaluationtransform_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.GLMClassifier.postEvaluationTransform) } @@ -566,7 +569,7 @@ inline ::CoreML::Specification::GLMClassifier_ClassEncoding GLMClassifier::class return static_cast< ::CoreML::Specification::GLMClassifier_ClassEncoding >(classencoding_); } inline void GLMClassifier::set_classencoding(::CoreML::Specification::GLMClassifier_ClassEncoding value) { - + classencoding_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.GLMClassifier.classEncoding) } diff --git a/mlmodel/build/format/Gazetteer.pb.h b/mlmodel/build/format/Gazetteer.pb.h index a14305271..7b9dabdc7 100644 --- a/mlmodel/build/format/Gazetteer.pb.h +++ b/mlmodel/build/format/Gazetteer.pb.h @@ -107,6 +107,9 @@ extern SequenceFeatureTypeDefaultTypeInternal _SequenceFeatureType_default_insta class SizeRange; class SizeRangeDefaultTypeInternal; extern SizeRangeDefaultTypeInternal _SizeRange_default_instance_; +class StateFeatureType; +class StateFeatureTypeDefaultTypeInternal; +extern StateFeatureTypeDefaultTypeInternal _StateFeatureType_default_instance_; class StringFeatureType; class StringFeatureTypeDefaultTypeInternal; extern StringFeatureTypeDefaultTypeInternal _StringFeatureType_default_instance_; @@ -302,7 +305,7 @@ inline ::google::protobuf::uint32 Gazetteer::revision() const { return revision_; } inline void Gazetteer::set_revision(::google::protobuf::uint32 value) { - + revision_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.CoreMLModels.Gazetteer.revision) } @@ -316,13 +319,13 @@ inline const ::std::string& Gazetteer::language() const { return language_.GetNoArena(); } inline void Gazetteer::set_language(const ::std::string& value) { - + language_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.CoreMLModels.Gazetteer.language) } #if LANG_CXX11 inline void Gazetteer::set_language(::std::string&& value) { - + language_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.CoreMLModels.Gazetteer.language) @@ -330,31 +333,31 @@ inline void Gazetteer::set_language(::std::string&& value) { #endif inline void Gazetteer::set_language(const char* value) { GOOGLE_DCHECK(value != NULL); - + language_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.CoreMLModels.Gazetteer.language) } inline void Gazetteer::set_language(const char* value, size_t size) { - + language_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.CoreMLModels.Gazetteer.language) } inline ::std::string* Gazetteer::mutable_language() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.CoreMLModels.Gazetteer.language) return language_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* Gazetteer::release_language() { // @@protoc_insertion_point(field_release:CoreML.Specification.CoreMLModels.Gazetteer.language) - + return language_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void Gazetteer::set_allocated_language(::std::string* language) { if (language != NULL) { - + } else { - + } language_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), language); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.CoreMLModels.Gazetteer.language) @@ -369,13 +372,13 @@ inline const ::std::string& Gazetteer::modelparameterdata() const { return modelparameterdata_.GetNoArena(); } inline void Gazetteer::set_modelparameterdata(const ::std::string& value) { - + modelparameterdata_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.CoreMLModels.Gazetteer.modelParameterData) } #if LANG_CXX11 inline void Gazetteer::set_modelparameterdata(::std::string&& value) { - + modelparameterdata_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.CoreMLModels.Gazetteer.modelParameterData) @@ -383,31 +386,31 @@ inline void Gazetteer::set_modelparameterdata(::std::string&& value) { #endif inline void Gazetteer::set_modelparameterdata(const char* value) { GOOGLE_DCHECK(value != NULL); - + modelparameterdata_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.CoreMLModels.Gazetteer.modelParameterData) } inline void Gazetteer::set_modelparameterdata(const void* value, size_t size) { - + modelparameterdata_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.CoreMLModels.Gazetteer.modelParameterData) } inline ::std::string* Gazetteer::mutable_modelparameterdata() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.CoreMLModels.Gazetteer.modelParameterData) return modelparameterdata_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* Gazetteer::release_modelparameterdata() { // @@protoc_insertion_point(field_release:CoreML.Specification.CoreMLModels.Gazetteer.modelParameterData) - + return modelparameterdata_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void Gazetteer::set_allocated_modelparameterdata(::std::string* modelparameterdata) { if (modelparameterdata != NULL) { - + } else { - + } modelparameterdata_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), modelparameterdata); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.CoreMLModels.Gazetteer.modelParameterData) diff --git a/mlmodel/build/format/Imputer.pb.h b/mlmodel/build/format/Imputer.pb.h index 256ef8e26..fea5f51cb 100644 --- a/mlmodel/build/format/Imputer.pb.h +++ b/mlmodel/build/format/Imputer.pb.h @@ -110,6 +110,9 @@ extern SequenceFeatureTypeDefaultTypeInternal _SequenceFeatureType_default_insta class SizeRange; class SizeRangeDefaultTypeInternal; extern SizeRangeDefaultTypeInternal _SizeRange_default_instance_; +class StateFeatureType; +class StateFeatureTypeDefaultTypeInternal; +extern StateFeatureTypeDefaultTypeInternal _StateFeatureType_default_instance_; class StringFeatureType; class StringFeatureTypeDefaultTypeInternal; extern StringFeatureTypeDefaultTypeInternal _StringFeatureType_default_instance_; diff --git a/mlmodel/build/format/ItemSimilarityRecommender.pb.h b/mlmodel/build/format/ItemSimilarityRecommender.pb.h index 7a4ff1b73..2f5d678b9 100644 --- a/mlmodel/build/format/ItemSimilarityRecommender.pb.h +++ b/mlmodel/build/format/ItemSimilarityRecommender.pb.h @@ -116,6 +116,9 @@ extern SequenceFeatureTypeDefaultTypeInternal _SequenceFeatureType_default_insta class SizeRange; class SizeRangeDefaultTypeInternal; extern SizeRangeDefaultTypeInternal _SizeRange_default_instance_; +class StateFeatureType; +class StateFeatureTypeDefaultTypeInternal; +extern StateFeatureTypeDefaultTypeInternal _StateFeatureType_default_instance_; class StringFeatureType; class StringFeatureTypeDefaultTypeInternal; extern StringFeatureTypeDefaultTypeInternal _StringFeatureType_default_instance_; @@ -551,7 +554,7 @@ inline ::google::protobuf::uint64 ItemSimilarityRecommender_ConnectedItem::itemi return itemid_; } inline void ItemSimilarityRecommender_ConnectedItem::set_itemid(::google::protobuf::uint64 value) { - + itemid_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ItemSimilarityRecommender.ConnectedItem.itemId) } @@ -565,7 +568,7 @@ inline double ItemSimilarityRecommender_ConnectedItem::similarityscore() const { return similarityscore_; } inline void ItemSimilarityRecommender_ConnectedItem::set_similarityscore(double value) { - + similarityscore_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ItemSimilarityRecommender.ConnectedItem.similarityScore) } @@ -583,7 +586,7 @@ inline ::google::protobuf::uint64 ItemSimilarityRecommender_SimilarItems::itemid return itemid_; } inline void ItemSimilarityRecommender_SimilarItems::set_itemid(::google::protobuf::uint64 value) { - + itemid_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ItemSimilarityRecommender.SimilarItems.itemId) } @@ -627,7 +630,7 @@ inline double ItemSimilarityRecommender_SimilarItems::itemscoreadjustment() cons return itemscoreadjustment_; } inline void ItemSimilarityRecommender_SimilarItems::set_itemscoreadjustment(double value) { - + itemscoreadjustment_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ItemSimilarityRecommender.SimilarItems.itemScoreAdjustment) } @@ -680,7 +683,7 @@ inline const ::CoreML::Specification::StringVector& ItemSimilarityRecommender::i : *::CoreML::Specification::StringVector::internal_default_instance(); } inline ::CoreML::Specification::StringVector* ItemSimilarityRecommender::mutable_itemstringids() { - + if (itemstringids_ == NULL) { itemstringids_ = new ::CoreML::Specification::StringVector; } @@ -689,7 +692,7 @@ inline ::CoreML::Specification::StringVector* ItemSimilarityRecommender::mutable } inline ::CoreML::Specification::StringVector* ItemSimilarityRecommender::release_itemstringids() { // @@protoc_insertion_point(field_release:CoreML.Specification.ItemSimilarityRecommender.itemStringIds) - + ::CoreML::Specification::StringVector* temp = itemstringids_; itemstringids_ = NULL; return temp; @@ -698,9 +701,9 @@ inline void ItemSimilarityRecommender::set_allocated_itemstringids(::CoreML::Spe delete itemstringids_; itemstringids_ = itemstringids; if (itemstringids) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ItemSimilarityRecommender.itemStringIds) } @@ -719,7 +722,7 @@ inline const ::CoreML::Specification::Int64Vector& ItemSimilarityRecommender::it : *::CoreML::Specification::Int64Vector::internal_default_instance(); } inline ::CoreML::Specification::Int64Vector* ItemSimilarityRecommender::mutable_itemint64ids() { - + if (itemint64ids_ == NULL) { itemint64ids_ = new ::CoreML::Specification::Int64Vector; } @@ -728,7 +731,7 @@ inline ::CoreML::Specification::Int64Vector* ItemSimilarityRecommender::mutable_ } inline ::CoreML::Specification::Int64Vector* ItemSimilarityRecommender::release_itemint64ids() { // @@protoc_insertion_point(field_release:CoreML.Specification.ItemSimilarityRecommender.itemInt64Ids) - + ::CoreML::Specification::Int64Vector* temp = itemint64ids_; itemint64ids_ = NULL; return temp; @@ -737,9 +740,9 @@ inline void ItemSimilarityRecommender::set_allocated_itemint64ids(::CoreML::Spec delete itemint64ids_; itemint64ids_ = itemint64ids; if (itemint64ids) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ItemSimilarityRecommender.itemInt64Ids) } @@ -753,13 +756,13 @@ inline const ::std::string& ItemSimilarityRecommender::iteminputfeaturename() co return iteminputfeaturename_.GetNoArena(); } inline void ItemSimilarityRecommender::set_iteminputfeaturename(const ::std::string& value) { - + iteminputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.ItemSimilarityRecommender.itemInputFeatureName) } #if LANG_CXX11 inline void ItemSimilarityRecommender::set_iteminputfeaturename(::std::string&& value) { - + iteminputfeaturename_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.ItemSimilarityRecommender.itemInputFeatureName) @@ -767,31 +770,31 @@ inline void ItemSimilarityRecommender::set_iteminputfeaturename(::std::string&& #endif inline void ItemSimilarityRecommender::set_iteminputfeaturename(const char* value) { GOOGLE_DCHECK(value != NULL); - + iteminputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.ItemSimilarityRecommender.itemInputFeatureName) } inline void ItemSimilarityRecommender::set_iteminputfeaturename(const char* value, size_t size) { - + iteminputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.ItemSimilarityRecommender.itemInputFeatureName) } inline ::std::string* ItemSimilarityRecommender::mutable_iteminputfeaturename() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.ItemSimilarityRecommender.itemInputFeatureName) return iteminputfeaturename_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* ItemSimilarityRecommender::release_iteminputfeaturename() { // @@protoc_insertion_point(field_release:CoreML.Specification.ItemSimilarityRecommender.itemInputFeatureName) - + return iteminputfeaturename_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void ItemSimilarityRecommender::set_allocated_iteminputfeaturename(::std::string* iteminputfeaturename) { if (iteminputfeaturename != NULL) { - + } else { - + } iteminputfeaturename_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), iteminputfeaturename); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ItemSimilarityRecommender.itemInputFeatureName) @@ -806,13 +809,13 @@ inline const ::std::string& ItemSimilarityRecommender::numrecommendationsinputfe return numrecommendationsinputfeaturename_.GetNoArena(); } inline void ItemSimilarityRecommender::set_numrecommendationsinputfeaturename(const ::std::string& value) { - + numrecommendationsinputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.ItemSimilarityRecommender.numRecommendationsInputFeatureName) } #if LANG_CXX11 inline void ItemSimilarityRecommender::set_numrecommendationsinputfeaturename(::std::string&& value) { - + numrecommendationsinputfeaturename_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.ItemSimilarityRecommender.numRecommendationsInputFeatureName) @@ -820,31 +823,31 @@ inline void ItemSimilarityRecommender::set_numrecommendationsinputfeaturename(:: #endif inline void ItemSimilarityRecommender::set_numrecommendationsinputfeaturename(const char* value) { GOOGLE_DCHECK(value != NULL); - + numrecommendationsinputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.ItemSimilarityRecommender.numRecommendationsInputFeatureName) } inline void ItemSimilarityRecommender::set_numrecommendationsinputfeaturename(const char* value, size_t size) { - + numrecommendationsinputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.ItemSimilarityRecommender.numRecommendationsInputFeatureName) } inline ::std::string* ItemSimilarityRecommender::mutable_numrecommendationsinputfeaturename() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.ItemSimilarityRecommender.numRecommendationsInputFeatureName) return numrecommendationsinputfeaturename_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* ItemSimilarityRecommender::release_numrecommendationsinputfeaturename() { // @@protoc_insertion_point(field_release:CoreML.Specification.ItemSimilarityRecommender.numRecommendationsInputFeatureName) - + return numrecommendationsinputfeaturename_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void ItemSimilarityRecommender::set_allocated_numrecommendationsinputfeaturename(::std::string* numrecommendationsinputfeaturename) { if (numrecommendationsinputfeaturename != NULL) { - + } else { - + } numrecommendationsinputfeaturename_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), numrecommendationsinputfeaturename); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ItemSimilarityRecommender.numRecommendationsInputFeatureName) @@ -859,13 +862,13 @@ inline const ::std::string& ItemSimilarityRecommender::itemrestrictioninputfeatu return itemrestrictioninputfeaturename_.GetNoArena(); } inline void ItemSimilarityRecommender::set_itemrestrictioninputfeaturename(const ::std::string& value) { - + itemrestrictioninputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.ItemSimilarityRecommender.itemRestrictionInputFeatureName) } #if LANG_CXX11 inline void ItemSimilarityRecommender::set_itemrestrictioninputfeaturename(::std::string&& value) { - + itemrestrictioninputfeaturename_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.ItemSimilarityRecommender.itemRestrictionInputFeatureName) @@ -873,31 +876,31 @@ inline void ItemSimilarityRecommender::set_itemrestrictioninputfeaturename(::std #endif inline void ItemSimilarityRecommender::set_itemrestrictioninputfeaturename(const char* value) { GOOGLE_DCHECK(value != NULL); - + itemrestrictioninputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.ItemSimilarityRecommender.itemRestrictionInputFeatureName) } inline void ItemSimilarityRecommender::set_itemrestrictioninputfeaturename(const char* value, size_t size) { - + itemrestrictioninputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.ItemSimilarityRecommender.itemRestrictionInputFeatureName) } inline ::std::string* ItemSimilarityRecommender::mutable_itemrestrictioninputfeaturename() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.ItemSimilarityRecommender.itemRestrictionInputFeatureName) return itemrestrictioninputfeaturename_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* ItemSimilarityRecommender::release_itemrestrictioninputfeaturename() { // @@protoc_insertion_point(field_release:CoreML.Specification.ItemSimilarityRecommender.itemRestrictionInputFeatureName) - + return itemrestrictioninputfeaturename_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void ItemSimilarityRecommender::set_allocated_itemrestrictioninputfeaturename(::std::string* itemrestrictioninputfeaturename) { if (itemrestrictioninputfeaturename != NULL) { - + } else { - + } itemrestrictioninputfeaturename_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), itemrestrictioninputfeaturename); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ItemSimilarityRecommender.itemRestrictionInputFeatureName) @@ -912,13 +915,13 @@ inline const ::std::string& ItemSimilarityRecommender::itemexclusioninputfeature return itemexclusioninputfeaturename_.GetNoArena(); } inline void ItemSimilarityRecommender::set_itemexclusioninputfeaturename(const ::std::string& value) { - + itemexclusioninputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.ItemSimilarityRecommender.itemExclusionInputFeatureName) } #if LANG_CXX11 inline void ItemSimilarityRecommender::set_itemexclusioninputfeaturename(::std::string&& value) { - + itemexclusioninputfeaturename_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.ItemSimilarityRecommender.itemExclusionInputFeatureName) @@ -926,31 +929,31 @@ inline void ItemSimilarityRecommender::set_itemexclusioninputfeaturename(::std:: #endif inline void ItemSimilarityRecommender::set_itemexclusioninputfeaturename(const char* value) { GOOGLE_DCHECK(value != NULL); - + itemexclusioninputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.ItemSimilarityRecommender.itemExclusionInputFeatureName) } inline void ItemSimilarityRecommender::set_itemexclusioninputfeaturename(const char* value, size_t size) { - + itemexclusioninputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.ItemSimilarityRecommender.itemExclusionInputFeatureName) } inline ::std::string* ItemSimilarityRecommender::mutable_itemexclusioninputfeaturename() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.ItemSimilarityRecommender.itemExclusionInputFeatureName) return itemexclusioninputfeaturename_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* ItemSimilarityRecommender::release_itemexclusioninputfeaturename() { // @@protoc_insertion_point(field_release:CoreML.Specification.ItemSimilarityRecommender.itemExclusionInputFeatureName) - + return itemexclusioninputfeaturename_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void ItemSimilarityRecommender::set_allocated_itemexclusioninputfeaturename(::std::string* itemexclusioninputfeaturename) { if (itemexclusioninputfeaturename != NULL) { - + } else { - + } itemexclusioninputfeaturename_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), itemexclusioninputfeaturename); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ItemSimilarityRecommender.itemExclusionInputFeatureName) @@ -965,13 +968,13 @@ inline const ::std::string& ItemSimilarityRecommender::recommendeditemlistoutput return recommendeditemlistoutputfeaturename_.GetNoArena(); } inline void ItemSimilarityRecommender::set_recommendeditemlistoutputfeaturename(const ::std::string& value) { - + recommendeditemlistoutputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.ItemSimilarityRecommender.recommendedItemListOutputFeatureName) } #if LANG_CXX11 inline void ItemSimilarityRecommender::set_recommendeditemlistoutputfeaturename(::std::string&& value) { - + recommendeditemlistoutputfeaturename_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.ItemSimilarityRecommender.recommendedItemListOutputFeatureName) @@ -979,31 +982,31 @@ inline void ItemSimilarityRecommender::set_recommendeditemlistoutputfeaturename( #endif inline void ItemSimilarityRecommender::set_recommendeditemlistoutputfeaturename(const char* value) { GOOGLE_DCHECK(value != NULL); - + recommendeditemlistoutputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.ItemSimilarityRecommender.recommendedItemListOutputFeatureName) } inline void ItemSimilarityRecommender::set_recommendeditemlistoutputfeaturename(const char* value, size_t size) { - + recommendeditemlistoutputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.ItemSimilarityRecommender.recommendedItemListOutputFeatureName) } inline ::std::string* ItemSimilarityRecommender::mutable_recommendeditemlistoutputfeaturename() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.ItemSimilarityRecommender.recommendedItemListOutputFeatureName) return recommendeditemlistoutputfeaturename_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* ItemSimilarityRecommender::release_recommendeditemlistoutputfeaturename() { // @@protoc_insertion_point(field_release:CoreML.Specification.ItemSimilarityRecommender.recommendedItemListOutputFeatureName) - + return recommendeditemlistoutputfeaturename_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void ItemSimilarityRecommender::set_allocated_recommendeditemlistoutputfeaturename(::std::string* recommendeditemlistoutputfeaturename) { if (recommendeditemlistoutputfeaturename != NULL) { - + } else { - + } recommendeditemlistoutputfeaturename_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), recommendeditemlistoutputfeaturename); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ItemSimilarityRecommender.recommendedItemListOutputFeatureName) @@ -1018,13 +1021,13 @@ inline const ::std::string& ItemSimilarityRecommender::recommendeditemscoreoutpu return recommendeditemscoreoutputfeaturename_.GetNoArena(); } inline void ItemSimilarityRecommender::set_recommendeditemscoreoutputfeaturename(const ::std::string& value) { - + recommendeditemscoreoutputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.ItemSimilarityRecommender.recommendedItemScoreOutputFeatureName) } #if LANG_CXX11 inline void ItemSimilarityRecommender::set_recommendeditemscoreoutputfeaturename(::std::string&& value) { - + recommendeditemscoreoutputfeaturename_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.ItemSimilarityRecommender.recommendedItemScoreOutputFeatureName) @@ -1032,31 +1035,31 @@ inline void ItemSimilarityRecommender::set_recommendeditemscoreoutputfeaturename #endif inline void ItemSimilarityRecommender::set_recommendeditemscoreoutputfeaturename(const char* value) { GOOGLE_DCHECK(value != NULL); - + recommendeditemscoreoutputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.ItemSimilarityRecommender.recommendedItemScoreOutputFeatureName) } inline void ItemSimilarityRecommender::set_recommendeditemscoreoutputfeaturename(const char* value, size_t size) { - + recommendeditemscoreoutputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.ItemSimilarityRecommender.recommendedItemScoreOutputFeatureName) } inline ::std::string* ItemSimilarityRecommender::mutable_recommendeditemscoreoutputfeaturename() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.ItemSimilarityRecommender.recommendedItemScoreOutputFeatureName) return recommendeditemscoreoutputfeaturename_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* ItemSimilarityRecommender::release_recommendeditemscoreoutputfeaturename() { // @@protoc_insertion_point(field_release:CoreML.Specification.ItemSimilarityRecommender.recommendedItemScoreOutputFeatureName) - + return recommendeditemscoreoutputfeaturename_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void ItemSimilarityRecommender::set_allocated_recommendeditemscoreoutputfeaturename(::std::string* recommendeditemscoreoutputfeaturename) { if (recommendeditemscoreoutputfeaturename != NULL) { - + } else { - + } recommendeditemscoreoutputfeaturename_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), recommendeditemscoreoutputfeaturename); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ItemSimilarityRecommender.recommendedItemScoreOutputFeatureName) diff --git a/mlmodel/build/format/LinkedModel.pb.h b/mlmodel/build/format/LinkedModel.pb.h index 7881b618e..6527ec533 100644 --- a/mlmodel/build/format/LinkedModel.pb.h +++ b/mlmodel/build/format/LinkedModel.pb.h @@ -122,6 +122,9 @@ extern SequenceFeatureTypeDefaultTypeInternal _SequenceFeatureType_default_insta class SizeRange; class SizeRangeDefaultTypeInternal; extern SizeRangeDefaultTypeInternal _SizeRange_default_instance_; +class StateFeatureType; +class StateFeatureTypeDefaultTypeInternal; +extern StateFeatureTypeDefaultTypeInternal _StateFeatureType_default_instance_; class StringFeatureType; class StringFeatureTypeDefaultTypeInternal; extern StringFeatureTypeDefaultTypeInternal _StringFeatureType_default_instance_; @@ -434,7 +437,7 @@ inline const ::CoreML::Specification::StringParameter& LinkedModelFile::linkedmo : *::CoreML::Specification::StringParameter::internal_default_instance(); } inline ::CoreML::Specification::StringParameter* LinkedModelFile::mutable_linkedmodelfilename() { - + if (linkedmodelfilename_ == NULL) { linkedmodelfilename_ = new ::CoreML::Specification::StringParameter; } @@ -443,7 +446,7 @@ inline ::CoreML::Specification::StringParameter* LinkedModelFile::mutable_linked } inline ::CoreML::Specification::StringParameter* LinkedModelFile::release_linkedmodelfilename() { // @@protoc_insertion_point(field_release:CoreML.Specification.LinkedModelFile.linkedModelFileName) - + ::CoreML::Specification::StringParameter* temp = linkedmodelfilename_; linkedmodelfilename_ = NULL; return temp; @@ -452,9 +455,9 @@ inline void LinkedModelFile::set_allocated_linkedmodelfilename(::CoreML::Specifi delete linkedmodelfilename_; linkedmodelfilename_ = linkedmodelfilename; if (linkedmodelfilename) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.LinkedModelFile.linkedModelFileName) } @@ -473,7 +476,7 @@ inline const ::CoreML::Specification::StringParameter& LinkedModelFile::linkedmo : *::CoreML::Specification::StringParameter::internal_default_instance(); } inline ::CoreML::Specification::StringParameter* LinkedModelFile::mutable_linkedmodelsearchpath() { - + if (linkedmodelsearchpath_ == NULL) { linkedmodelsearchpath_ = new ::CoreML::Specification::StringParameter; } @@ -482,7 +485,7 @@ inline ::CoreML::Specification::StringParameter* LinkedModelFile::mutable_linked } inline ::CoreML::Specification::StringParameter* LinkedModelFile::release_linkedmodelsearchpath() { // @@protoc_insertion_point(field_release:CoreML.Specification.LinkedModelFile.linkedModelSearchPath) - + ::CoreML::Specification::StringParameter* temp = linkedmodelsearchpath_; linkedmodelsearchpath_ = NULL; return temp; @@ -491,9 +494,9 @@ inline void LinkedModelFile::set_allocated_linkedmodelsearchpath(::CoreML::Speci delete linkedmodelsearchpath_; linkedmodelsearchpath_ = linkedmodelsearchpath; if (linkedmodelsearchpath) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.LinkedModelFile.linkedModelSearchPath) } diff --git a/mlmodel/build/format/MIL.pb.cc b/mlmodel/build/format/MIL.pb.cc index c265cff06..884ab24e8 100644 --- a/mlmodel/build/format/MIL.pb.cc +++ b/mlmodel/build/format/MIL.pb.cc @@ -54,6 +54,7 @@ class ValueTypeDefaultTypeInternal : public ::google::protobuf::internal::Explic const ::CoreML::Specification::MILSpec::ListType* listtype_; const ::CoreML::Specification::MILSpec::TupleType* tupletype_; const ::CoreML::Specification::MILSpec::DictionaryType* dictionarytype_; + const ::CoreML::Specification::MILSpec::StateType* statetype_; } _ValueType_default_instance_; class TensorType_AttributesEntryDefaultTypeInternal : public ::google::protobuf::internal::ExplicitlyConstructed { } _TensorType_AttributesEntry_default_instance_; @@ -65,6 +66,8 @@ class ListTypeDefaultTypeInternal : public ::google::protobuf::internal::Explici } _ListType_default_instance_; class DictionaryTypeDefaultTypeInternal : public ::google::protobuf::internal::ExplicitlyConstructed { } _DictionaryType_default_instance_; +class StateTypeDefaultTypeInternal : public ::google::protobuf::internal::ExplicitlyConstructed { +} _StateType_default_instance_; class Dimension_ConstantDimensionDefaultTypeInternal : public ::google::protobuf::internal::ExplicitlyConstructed { } _Dimension_ConstantDimension_default_instance_; class Dimension_UnknownDimensionDefaultTypeInternal : public ::google::protobuf::internal::ExplicitlyConstructed { @@ -172,6 +175,7 @@ PROTOBUF_CONSTEXPR_VAR ::google::protobuf::internal::ParseTable const { NULL, NULL, 0, -1, -1, false }, { NULL, NULL, 0, -1, -1, false }, { NULL, NULL, 0, -1, -1, false }, + { NULL, NULL, 0, -1, -1, false }, }; @@ -188,6 +192,7 @@ void TableStruct::Shutdown() { _TupleType_default_instance_.Shutdown(); _ListType_default_instance_.Shutdown(); _DictionaryType_default_instance_.Shutdown(); + _StateType_default_instance_.Shutdown(); _Dimension_ConstantDimension_default_instance_.Shutdown(); _Dimension_UnknownDimension_default_instance_.Shutdown(); _Dimension_default_instance_.Shutdown(); @@ -232,6 +237,7 @@ void TableStruct::InitDefaultsImpl() { _TupleType_default_instance_.DefaultConstruct(); _ListType_default_instance_.DefaultConstruct(); _DictionaryType_default_instance_.DefaultConstruct(); + _StateType_default_instance_.DefaultConstruct(); _Dimension_ConstantDimension_default_instance_.DefaultConstruct(); _Dimension_UnknownDimension_default_instance_.DefaultConstruct(); _Dimension_default_instance_.DefaultConstruct(); @@ -276,6 +282,8 @@ void TableStruct::InitDefaultsImpl() { ::CoreML::Specification::MILSpec::ValueType::internal_default_instance()); _DictionaryType_default_instance_.get_mutable()->valuetype_ = const_cast< ::CoreML::Specification::MILSpec::ValueType*>( ::CoreML::Specification::MILSpec::ValueType::internal_default_instance()); + _StateType_default_instance_.get_mutable()->wrappedtype_ = const_cast< ::CoreML::Specification::MILSpec::ValueType*>( + ::CoreML::Specification::MILSpec::ValueType::internal_default_instance()); _Value_default_instance_.get_mutable()->type_ = const_cast< ::CoreML::Specification::MILSpec::ValueType*>( ::CoreML::Specification::MILSpec::ValueType::internal_default_instance()); _DictionaryValue_KeyValuePair_default_instance_.get_mutable()->key_ = const_cast< ::CoreML::Specification::MILSpec::Value*>( @@ -321,10 +329,18 @@ bool DataType_IsValid(int value) { case 22: case 23: case 24: + case 25: case 31: case 32: case 33: case 34: + case 35: + case 36: + case 37: + case 38: + case 39: + case 40: + case 41: return true; default: return false; @@ -759,7 +775,7 @@ ::google::protobuf::int64 Program::version() const { return version_; } void Program::set_version(::google::protobuf::int64 value) { - + version_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.MILSpec.Program.version) } @@ -791,13 +807,13 @@ const ::std::string& Program::docstring() const { return docstring_.GetNoArena(); } void Program::set_docstring(const ::std::string& value) { - + docstring_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.MILSpec.Program.docString) } #if LANG_CXX11 void Program::set_docstring(::std::string&& value) { - + docstring_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.MILSpec.Program.docString) @@ -805,31 +821,31 @@ void Program::set_docstring(::std::string&& value) { #endif void Program::set_docstring(const char* value) { GOOGLE_DCHECK(value != NULL); - + docstring_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.MILSpec.Program.docString) } void Program::set_docstring(const char* value, size_t size) { - + docstring_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.MILSpec.Program.docString) } ::std::string* Program::mutable_docstring() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.MILSpec.Program.docString) return docstring_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } ::std::string* Program::release_docstring() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.Program.docString) - + return docstring_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } void Program::set_allocated_docstring(::std::string* docstring) { if (docstring != NULL) { - + } else { - + } docstring_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), docstring); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.Program.docString) @@ -1312,13 +1328,13 @@ const ::std::string& Function::opset() const { return opset_.GetNoArena(); } void Function::set_opset(const ::std::string& value) { - + opset_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.MILSpec.Function.opset) } #if LANG_CXX11 void Function::set_opset(::std::string&& value) { - + opset_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.MILSpec.Function.opset) @@ -1326,31 +1342,31 @@ void Function::set_opset(::std::string&& value) { #endif void Function::set_opset(const char* value) { GOOGLE_DCHECK(value != NULL); - + opset_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.MILSpec.Function.opset) } void Function::set_opset(const char* value, size_t size) { - + opset_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.MILSpec.Function.opset) } ::std::string* Function::mutable_opset() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.MILSpec.Function.opset) return opset_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } ::std::string* Function::release_opset() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.Function.opset) - + return opset_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } void Function::set_allocated_opset(::std::string* opset) { if (opset != NULL) { - + } else { - + } opset_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), opset); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.Function.opset) @@ -2984,13 +3000,13 @@ const ::std::string& Operation::type() const { return type_.GetNoArena(); } void Operation::set_type(const ::std::string& value) { - + type_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.MILSpec.Operation.type) } #if LANG_CXX11 void Operation::set_type(::std::string&& value) { - + type_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.MILSpec.Operation.type) @@ -2998,31 +3014,31 @@ void Operation::set_type(::std::string&& value) { #endif void Operation::set_type(const char* value) { GOOGLE_DCHECK(value != NULL); - + type_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.MILSpec.Operation.type) } void Operation::set_type(const char* value, size_t size) { - + type_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.MILSpec.Operation.type) } ::std::string* Operation::mutable_type() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.MILSpec.Operation.type) return type_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } ::std::string* Operation::release_type() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.Operation.type) - + return type_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } void Operation::set_allocated_type(::std::string* type) { if (type != NULL) { - + } else { - + } type_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), type); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.Operation.type) @@ -3370,13 +3386,13 @@ const ::std::string& NamedValueType::name() const { return name_.GetNoArena(); } void NamedValueType::set_name(const ::std::string& value) { - + name_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.MILSpec.NamedValueType.name) } #if LANG_CXX11 void NamedValueType::set_name(::std::string&& value) { - + name_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.MILSpec.NamedValueType.name) @@ -3384,31 +3400,31 @@ void NamedValueType::set_name(::std::string&& value) { #endif void NamedValueType::set_name(const char* value) { GOOGLE_DCHECK(value != NULL); - + name_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.MILSpec.NamedValueType.name) } void NamedValueType::set_name(const char* value, size_t size) { - + name_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.MILSpec.NamedValueType.name) } ::std::string* NamedValueType::mutable_name() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.MILSpec.NamedValueType.name) return name_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } ::std::string* NamedValueType::release_name() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.NamedValueType.name) - + return name_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } void NamedValueType::set_allocated_name(::std::string* name) { if (name != NULL) { - + } else { - + } name_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), name); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.NamedValueType.name) @@ -3428,7 +3444,7 @@ const ::CoreML::Specification::MILSpec::ValueType& NamedValueType::type() const : *::CoreML::Specification::MILSpec::ValueType::internal_default_instance(); } ::CoreML::Specification::MILSpec::ValueType* NamedValueType::mutable_type() { - + if (type_ == NULL) { type_ = new ::CoreML::Specification::MILSpec::ValueType; } @@ -3437,7 +3453,7 @@ ::CoreML::Specification::MILSpec::ValueType* NamedValueType::mutable_type() { } ::CoreML::Specification::MILSpec::ValueType* NamedValueType::release_type() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.NamedValueType.type) - + ::CoreML::Specification::MILSpec::ValueType* temp = type_; type_ = NULL; return temp; @@ -3446,9 +3462,9 @@ void NamedValueType::set_allocated_type(::CoreML::Specification::MILSpec::ValueT delete type_; type_ = type; if (type) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.NamedValueType.type) } @@ -3462,6 +3478,7 @@ const int ValueType::kTensorTypeFieldNumber; const int ValueType::kListTypeFieldNumber; const int ValueType::kTupleTypeFieldNumber; const int ValueType::kDictionaryTypeFieldNumber; +const int ValueType::kStateTypeFieldNumber; #endif // !defined(_MSC_VER) || _MSC_VER >= 1900 ValueType::ValueType() @@ -3495,6 +3512,10 @@ ValueType::ValueType(const ValueType& from) mutable_dictionarytype()->::CoreML::Specification::MILSpec::DictionaryType::MergeFrom(from.dictionarytype()); break; } + case kStateType: { + mutable_statetype()->::CoreML::Specification::MILSpec::StateType::MergeFrom(from.statetype()); + break; + } case TYPE_NOT_SET: { break; } @@ -3555,6 +3576,10 @@ void ValueType::clear_type() { delete type_.dictionarytype_; break; } + case kStateType: { + delete type_.statetype_; + break; + } case TYPE_NOT_SET: { break; } @@ -3626,6 +3651,18 @@ bool ValueType::MergePartialFromCodedStream( break; } + // .CoreML.Specification.MILSpec.StateType stateType = 5; + case 5: { + if (static_cast< ::google::protobuf::uint8>(tag) == + static_cast< ::google::protobuf::uint8>(42u)) { + DO_(::google::protobuf::internal::WireFormatLite::ReadMessageNoVirtual( + input, mutable_statetype())); + } else { + goto handle_unusual; + } + break; + } + default: { handle_unusual: if (tag == 0 || @@ -3677,6 +3714,12 @@ void ValueType::SerializeWithCachedSizes( 4, *type_.dictionarytype_, output); } + // .CoreML.Specification.MILSpec.StateType stateType = 5; + if (has_statetype()) { + ::google::protobuf::internal::WireFormatLite::WriteMessage( + 5, *type_.statetype_, output); + } + // @@protoc_insertion_point(serialize_end:CoreML.Specification.MILSpec.ValueType) } @@ -3713,6 +3756,13 @@ size_t ValueType::ByteSizeLong() const { *type_.dictionarytype_); break; } + // .CoreML.Specification.MILSpec.StateType stateType = 5; + case kStateType: { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::MessageSizeNoVirtual( + *type_.statetype_); + break; + } case TYPE_NOT_SET: { break; } @@ -3753,6 +3803,10 @@ void ValueType::MergeFrom(const ValueType& from) { mutable_dictionarytype()->::CoreML::Specification::MILSpec::DictionaryType::MergeFrom(from.dictionarytype()); break; } + case kStateType: { + mutable_statetype()->::CoreML::Specification::MILSpec::StateType::MergeFrom(from.statetype()); + break; + } case TYPE_NOT_SET: { break; } @@ -3979,6 +4033,54 @@ void ValueType::set_allocated_dictionarytype(::CoreML::Specification::MILSpec::D // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.ValueType.dictionaryType) } +// .CoreML.Specification.MILSpec.StateType stateType = 5; +bool ValueType::has_statetype() const { + return type_case() == kStateType; +} +void ValueType::set_has_statetype() { + _oneof_case_[0] = kStateType; +} +void ValueType::clear_statetype() { + if (has_statetype()) { + delete type_.statetype_; + clear_has_type(); + } +} + const ::CoreML::Specification::MILSpec::StateType& ValueType::statetype() const { + // @@protoc_insertion_point(field_get:CoreML.Specification.MILSpec.ValueType.stateType) + return has_statetype() + ? *type_.statetype_ + : ::CoreML::Specification::MILSpec::StateType::default_instance(); +} +::CoreML::Specification::MILSpec::StateType* ValueType::mutable_statetype() { + if (!has_statetype()) { + clear_type(); + set_has_statetype(); + type_.statetype_ = new ::CoreML::Specification::MILSpec::StateType; + } + // @@protoc_insertion_point(field_mutable:CoreML.Specification.MILSpec.ValueType.stateType) + return type_.statetype_; +} +::CoreML::Specification::MILSpec::StateType* ValueType::release_statetype() { + // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.ValueType.stateType) + if (has_statetype()) { + clear_has_type(); + ::CoreML::Specification::MILSpec::StateType* temp = type_.statetype_; + type_.statetype_ = NULL; + return temp; + } else { + return NULL; + } +} +void ValueType::set_allocated_statetype(::CoreML::Specification::MILSpec::StateType* statetype) { + clear_type(); + if (statetype) { + set_has_statetype(); + type_.statetype_ = statetype; + } + // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.ValueType.stateType) +} + bool ValueType::has_type() const { return type_case() != TYPE_NOT_SET; } @@ -4345,7 +4447,7 @@ ::CoreML::Specification::MILSpec::DataType TensorType::datatype() const { return static_cast< ::CoreML::Specification::MILSpec::DataType >(datatype_); } void TensorType::set_datatype(::CoreML::Specification::MILSpec::DataType value) { - + datatype_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.MILSpec.TensorType.dataType) } @@ -4359,7 +4461,7 @@ ::google::protobuf::int64 TensorType::rank() const { return rank_; } void TensorType::set_rank(::google::protobuf::int64 value) { - + rank_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.MILSpec.TensorType.rank) } @@ -4872,7 +4974,7 @@ const ::CoreML::Specification::MILSpec::ValueType& ListType::type() const { : *::CoreML::Specification::MILSpec::ValueType::internal_default_instance(); } ::CoreML::Specification::MILSpec::ValueType* ListType::mutable_type() { - + if (type_ == NULL) { type_ = new ::CoreML::Specification::MILSpec::ValueType; } @@ -4881,7 +4983,7 @@ ::CoreML::Specification::MILSpec::ValueType* ListType::mutable_type() { } ::CoreML::Specification::MILSpec::ValueType* ListType::release_type() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.ListType.type) - + ::CoreML::Specification::MILSpec::ValueType* temp = type_; type_ = NULL; return temp; @@ -4890,9 +4992,9 @@ void ListType::set_allocated_type(::CoreML::Specification::MILSpec::ValueType* t delete type_; type_ = type; if (type) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.ListType.type) } @@ -4911,7 +5013,7 @@ const ::CoreML::Specification::MILSpec::Dimension& ListType::length() const { : *::CoreML::Specification::MILSpec::Dimension::internal_default_instance(); } ::CoreML::Specification::MILSpec::Dimension* ListType::mutable_length() { - + if (length_ == NULL) { length_ = new ::CoreML::Specification::MILSpec::Dimension; } @@ -4920,7 +5022,7 @@ ::CoreML::Specification::MILSpec::Dimension* ListType::mutable_length() { } ::CoreML::Specification::MILSpec::Dimension* ListType::release_length() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.ListType.length) - + ::CoreML::Specification::MILSpec::Dimension* temp = length_; length_ = NULL; return temp; @@ -4929,9 +5031,9 @@ void ListType::set_allocated_length(::CoreML::Specification::MILSpec::Dimension* delete length_; length_ = length; if (length) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.ListType.length) } @@ -5184,7 +5286,7 @@ const ::CoreML::Specification::MILSpec::ValueType& DictionaryType::keytype() con : *::CoreML::Specification::MILSpec::ValueType::internal_default_instance(); } ::CoreML::Specification::MILSpec::ValueType* DictionaryType::mutable_keytype() { - + if (keytype_ == NULL) { keytype_ = new ::CoreML::Specification::MILSpec::ValueType; } @@ -5193,7 +5295,7 @@ ::CoreML::Specification::MILSpec::ValueType* DictionaryType::mutable_keytype() { } ::CoreML::Specification::MILSpec::ValueType* DictionaryType::release_keytype() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.DictionaryType.keyType) - + ::CoreML::Specification::MILSpec::ValueType* temp = keytype_; keytype_ = NULL; return temp; @@ -5202,9 +5304,9 @@ void DictionaryType::set_allocated_keytype(::CoreML::Specification::MILSpec::Val delete keytype_; keytype_ = keytype; if (keytype) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.DictionaryType.keyType) } @@ -5223,7 +5325,7 @@ const ::CoreML::Specification::MILSpec::ValueType& DictionaryType::valuetype() c : *::CoreML::Specification::MILSpec::ValueType::internal_default_instance(); } ::CoreML::Specification::MILSpec::ValueType* DictionaryType::mutable_valuetype() { - + if (valuetype_ == NULL) { valuetype_ = new ::CoreML::Specification::MILSpec::ValueType; } @@ -5232,7 +5334,7 @@ ::CoreML::Specification::MILSpec::ValueType* DictionaryType::mutable_valuetype() } ::CoreML::Specification::MILSpec::ValueType* DictionaryType::release_valuetype() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.DictionaryType.valueType) - + ::CoreML::Specification::MILSpec::ValueType* temp = valuetype_; valuetype_ = NULL; return temp; @@ -5241,9 +5343,9 @@ void DictionaryType::set_allocated_valuetype(::CoreML::Specification::MILSpec::V delete valuetype_; valuetype_ = valuetype; if (valuetype) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.DictionaryType.valueType) } @@ -5252,6 +5354,236 @@ void DictionaryType::set_allocated_valuetype(::CoreML::Specification::MILSpec::V // =================================================================== +#if !defined(_MSC_VER) || _MSC_VER >= 1900 +const int StateType::kWrappedTypeFieldNumber; +#endif // !defined(_MSC_VER) || _MSC_VER >= 1900 + +StateType::StateType() + : ::google::protobuf::MessageLite(), _internal_metadata_(NULL) { + if (GOOGLE_PREDICT_TRUE(this != internal_default_instance())) { + protobuf_MIL_2eproto::InitDefaults(); + } + SharedCtor(); + // @@protoc_insertion_point(constructor:CoreML.Specification.MILSpec.StateType) +} +StateType::StateType(const StateType& from) + : ::google::protobuf::MessageLite(), + _internal_metadata_(NULL), + _cached_size_(0) { + _internal_metadata_.MergeFrom(from._internal_metadata_); + if (from.has_wrappedtype()) { + wrappedtype_ = new ::CoreML::Specification::MILSpec::ValueType(*from.wrappedtype_); + } else { + wrappedtype_ = NULL; + } + // @@protoc_insertion_point(copy_constructor:CoreML.Specification.MILSpec.StateType) +} + +void StateType::SharedCtor() { + wrappedtype_ = NULL; + _cached_size_ = 0; +} + +StateType::~StateType() { + // @@protoc_insertion_point(destructor:CoreML.Specification.MILSpec.StateType) + SharedDtor(); +} + +void StateType::SharedDtor() { + if (this != internal_default_instance()) { + delete wrappedtype_; + } +} + +void StateType::SetCachedSize(int size) const { + GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN(); + _cached_size_ = size; + GOOGLE_SAFE_CONCURRENT_WRITES_END(); +} +const StateType& StateType::default_instance() { + protobuf_MIL_2eproto::InitDefaults(); + return *internal_default_instance(); +} + +StateType* StateType::New(::google::protobuf::Arena* arena) const { + StateType* n = new StateType; + if (arena != NULL) { + arena->Own(n); + } + return n; +} + +void StateType::Clear() { +// @@protoc_insertion_point(message_clear_start:CoreML.Specification.MILSpec.StateType) + if (GetArenaNoVirtual() == NULL && wrappedtype_ != NULL) { + delete wrappedtype_; + } + wrappedtype_ = NULL; +} + +bool StateType::MergePartialFromCodedStream( + ::google::protobuf::io::CodedInputStream* input) { +#define DO_(EXPRESSION) if (!GOOGLE_PREDICT_TRUE(EXPRESSION)) goto failure + ::google::protobuf::uint32 tag; + // @@protoc_insertion_point(parse_start:CoreML.Specification.MILSpec.StateType) + for (;;) { + ::std::pair< ::google::protobuf::uint32, bool> p = input->ReadTagWithCutoffNoLastTag(127u); + tag = p.first; + if (!p.second) goto handle_unusual; + switch (::google::protobuf::internal::WireFormatLite::GetTagFieldNumber(tag)) { + // .CoreML.Specification.MILSpec.ValueType wrappedType = 1; + case 1: { + if (static_cast< ::google::protobuf::uint8>(tag) == + static_cast< ::google::protobuf::uint8>(10u)) { + DO_(::google::protobuf::internal::WireFormatLite::ReadMessageNoVirtual( + input, mutable_wrappedtype())); + } else { + goto handle_unusual; + } + break; + } + + default: { + handle_unusual: + if (tag == 0 || + ::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_END_GROUP) { + goto success; + } + DO_(::google::protobuf::internal::WireFormatLite::SkipField(input, tag)); + break; + } + } + } +success: + // @@protoc_insertion_point(parse_success:CoreML.Specification.MILSpec.StateType) + return true; +failure: + // @@protoc_insertion_point(parse_failure:CoreML.Specification.MILSpec.StateType) + return false; +#undef DO_ +} + +void StateType::SerializeWithCachedSizes( + ::google::protobuf::io::CodedOutputStream* output) const { + // @@protoc_insertion_point(serialize_start:CoreML.Specification.MILSpec.StateType) + ::google::protobuf::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // .CoreML.Specification.MILSpec.ValueType wrappedType = 1; + if (this->has_wrappedtype()) { + ::google::protobuf::internal::WireFormatLite::WriteMessage( + 1, *this->wrappedtype_, output); + } + + // @@protoc_insertion_point(serialize_end:CoreML.Specification.MILSpec.StateType) +} + +size_t StateType::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:CoreML.Specification.MILSpec.StateType) + size_t total_size = 0; + + // .CoreML.Specification.MILSpec.ValueType wrappedType = 1; + if (this->has_wrappedtype()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::MessageSizeNoVirtual( + *this->wrappedtype_); + } + + int cached_size = ::google::protobuf::internal::ToCachedSize(total_size); + GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN(); + _cached_size_ = cached_size; + GOOGLE_SAFE_CONCURRENT_WRITES_END(); + return total_size; +} + +void StateType::CheckTypeAndMergeFrom( + const ::google::protobuf::MessageLite& from) { + MergeFrom(*::google::protobuf::down_cast(&from)); +} + +void StateType::MergeFrom(const StateType& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:CoreML.Specification.MILSpec.StateType) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom(from._internal_metadata_); + ::google::protobuf::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + if (from.has_wrappedtype()) { + mutable_wrappedtype()->::CoreML::Specification::MILSpec::ValueType::MergeFrom(from.wrappedtype()); + } +} + +void StateType::CopyFrom(const StateType& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:CoreML.Specification.MILSpec.StateType) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool StateType::IsInitialized() const { + return true; +} + +void StateType::Swap(StateType* other) { + if (other == this) return; + InternalSwap(other); +} +void StateType::InternalSwap(StateType* other) { + std::swap(wrappedtype_, other->wrappedtype_); + std::swap(_cached_size_, other->_cached_size_); +} + +::std::string StateType::GetTypeName() const { + return "CoreML.Specification.MILSpec.StateType"; +} + +#if PROTOBUF_INLINE_NOT_IN_HEADERS +// StateType + +// .CoreML.Specification.MILSpec.ValueType wrappedType = 1; +bool StateType::has_wrappedtype() const { + return this != internal_default_instance() && wrappedtype_ != NULL; +} +void StateType::clear_wrappedtype() { + if (GetArenaNoVirtual() == NULL && wrappedtype_ != NULL) delete wrappedtype_; + wrappedtype_ = NULL; +} +const ::CoreML::Specification::MILSpec::ValueType& StateType::wrappedtype() const { + // @@protoc_insertion_point(field_get:CoreML.Specification.MILSpec.StateType.wrappedType) + return wrappedtype_ != NULL ? *wrappedtype_ + : *::CoreML::Specification::MILSpec::ValueType::internal_default_instance(); +} +::CoreML::Specification::MILSpec::ValueType* StateType::mutable_wrappedtype() { + + if (wrappedtype_ == NULL) { + wrappedtype_ = new ::CoreML::Specification::MILSpec::ValueType; + } + // @@protoc_insertion_point(field_mutable:CoreML.Specification.MILSpec.StateType.wrappedType) + return wrappedtype_; +} +::CoreML::Specification::MILSpec::ValueType* StateType::release_wrappedtype() { + // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.StateType.wrappedType) + + ::CoreML::Specification::MILSpec::ValueType* temp = wrappedtype_; + wrappedtype_ = NULL; + return temp; +} +void StateType::set_allocated_wrappedtype(::CoreML::Specification::MILSpec::ValueType* wrappedtype) { + delete wrappedtype_; + wrappedtype_ = wrappedtype; + if (wrappedtype) { + + } else { + + } + // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.StateType.wrappedType) +} + +#endif // PROTOBUF_INLINE_NOT_IN_HEADERS + +// =================================================================== + #if !defined(_MSC_VER) || _MSC_VER >= 1900 const int Dimension_ConstantDimension::kSizeFieldNumber; #endif // !defined(_MSC_VER) || _MSC_VER >= 1900 @@ -5439,7 +5771,7 @@ ::google::protobuf::uint64 Dimension_ConstantDimension::size() const { return size_; } void Dimension_ConstantDimension::set_size(::google::protobuf::uint64 value) { - + size_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.MILSpec.Dimension.ConstantDimension.size) } @@ -5633,7 +5965,7 @@ bool Dimension_UnknownDimension::variadic() const { return variadic_; } void Dimension_UnknownDimension::set_variadic(bool value) { - + variadic_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.MILSpec.Dimension.UnknownDimension.variadic) } @@ -6773,13 +7105,13 @@ const ::std::string& Value_BlobFileValue::filename() const { return filename_.GetNoArena(); } void Value_BlobFileValue::set_filename(const ::std::string& value) { - + filename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.MILSpec.Value.BlobFileValue.fileName) } #if LANG_CXX11 void Value_BlobFileValue::set_filename(::std::string&& value) { - + filename_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.MILSpec.Value.BlobFileValue.fileName) @@ -6787,31 +7119,31 @@ void Value_BlobFileValue::set_filename(::std::string&& value) { #endif void Value_BlobFileValue::set_filename(const char* value) { GOOGLE_DCHECK(value != NULL); - + filename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.MILSpec.Value.BlobFileValue.fileName) } void Value_BlobFileValue::set_filename(const char* value, size_t size) { - + filename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.MILSpec.Value.BlobFileValue.fileName) } ::std::string* Value_BlobFileValue::mutable_filename() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.MILSpec.Value.BlobFileValue.fileName) return filename_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } ::std::string* Value_BlobFileValue::release_filename() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.Value.BlobFileValue.fileName) - + return filename_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } void Value_BlobFileValue::set_allocated_filename(::std::string* filename) { if (filename != NULL) { - + } else { - + } filename_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), filename); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.Value.BlobFileValue.fileName) @@ -6826,7 +7158,7 @@ ::google::protobuf::uint64 Value_BlobFileValue::offset() const { return offset_; } void Value_BlobFileValue::set_offset(::google::protobuf::uint64 value) { - + offset_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.MILSpec.Value.BlobFileValue.offset) } @@ -7187,13 +7519,13 @@ const ::std::string& Value::docstring() const { return docstring_.GetNoArena(); } void Value::set_docstring(const ::std::string& value) { - + docstring_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.MILSpec.Value.docString) } #if LANG_CXX11 void Value::set_docstring(::std::string&& value) { - + docstring_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.MILSpec.Value.docString) @@ -7201,31 +7533,31 @@ void Value::set_docstring(::std::string&& value) { #endif void Value::set_docstring(const char* value) { GOOGLE_DCHECK(value != NULL); - + docstring_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.MILSpec.Value.docString) } void Value::set_docstring(const char* value, size_t size) { - + docstring_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.MILSpec.Value.docString) } ::std::string* Value::mutable_docstring() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.MILSpec.Value.docString) return docstring_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } ::std::string* Value::release_docstring() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.Value.docString) - + return docstring_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } void Value::set_allocated_docstring(::std::string* docstring) { if (docstring != NULL) { - + } else { - + } docstring_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), docstring); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.Value.docString) @@ -7245,7 +7577,7 @@ const ::CoreML::Specification::MILSpec::ValueType& Value::type() const { : *::CoreML::Specification::MILSpec::ValueType::internal_default_instance(); } ::CoreML::Specification::MILSpec::ValueType* Value::mutable_type() { - + if (type_ == NULL) { type_ = new ::CoreML::Specification::MILSpec::ValueType; } @@ -7254,7 +7586,7 @@ ::CoreML::Specification::MILSpec::ValueType* Value::mutable_type() { } ::CoreML::Specification::MILSpec::ValueType* Value::release_type() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.Value.type) - + ::CoreML::Specification::MILSpec::ValueType* temp = type_; type_ = NULL; return temp; @@ -7263,9 +7595,9 @@ void Value::set_allocated_type(::CoreML::Specification::MILSpec::ValueType* type delete type_; type_ = type; if (type) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.Value.type) } @@ -8951,13 +9283,13 @@ const ::std::string& TensorValue_RepeatedBytes::values() const { return values_.GetNoArena(); } void TensorValue_RepeatedBytes::set_values(const ::std::string& value) { - + values_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.MILSpec.TensorValue.RepeatedBytes.values) } #if LANG_CXX11 void TensorValue_RepeatedBytes::set_values(::std::string&& value) { - + values_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.MILSpec.TensorValue.RepeatedBytes.values) @@ -8965,31 +9297,31 @@ void TensorValue_RepeatedBytes::set_values(::std::string&& value) { #endif void TensorValue_RepeatedBytes::set_values(const char* value) { GOOGLE_DCHECK(value != NULL); - + values_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.MILSpec.TensorValue.RepeatedBytes.values) } void TensorValue_RepeatedBytes::set_values(const void* value, size_t size) { - + values_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.MILSpec.TensorValue.RepeatedBytes.values) } ::std::string* TensorValue_RepeatedBytes::mutable_values() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.MILSpec.TensorValue.RepeatedBytes.values) return values_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } ::std::string* TensorValue_RepeatedBytes::release_values() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.TensorValue.RepeatedBytes.values) - + return values_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } void TensorValue_RepeatedBytes::set_allocated_values(::std::string* values) { if (values != NULL) { - + } else { - + } values_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), values); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.TensorValue.RepeatedBytes.values) @@ -10460,7 +10792,7 @@ const ::CoreML::Specification::MILSpec::Value& DictionaryValue_KeyValuePair::key : *::CoreML::Specification::MILSpec::Value::internal_default_instance(); } ::CoreML::Specification::MILSpec::Value* DictionaryValue_KeyValuePair::mutable_key() { - + if (key_ == NULL) { key_ = new ::CoreML::Specification::MILSpec::Value; } @@ -10469,7 +10801,7 @@ ::CoreML::Specification::MILSpec::Value* DictionaryValue_KeyValuePair::mutable_k } ::CoreML::Specification::MILSpec::Value* DictionaryValue_KeyValuePair::release_key() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.DictionaryValue.KeyValuePair.key) - + ::CoreML::Specification::MILSpec::Value* temp = key_; key_ = NULL; return temp; @@ -10478,9 +10810,9 @@ void DictionaryValue_KeyValuePair::set_allocated_key(::CoreML::Specification::MI delete key_; key_ = key; if (key) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.DictionaryValue.KeyValuePair.key) } @@ -10499,7 +10831,7 @@ const ::CoreML::Specification::MILSpec::Value& DictionaryValue_KeyValuePair::val : *::CoreML::Specification::MILSpec::Value::internal_default_instance(); } ::CoreML::Specification::MILSpec::Value* DictionaryValue_KeyValuePair::mutable_value() { - + if (value_ == NULL) { value_ = new ::CoreML::Specification::MILSpec::Value; } @@ -10508,7 +10840,7 @@ ::CoreML::Specification::MILSpec::Value* DictionaryValue_KeyValuePair::mutable_v } ::CoreML::Specification::MILSpec::Value* DictionaryValue_KeyValuePair::release_value() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.DictionaryValue.KeyValuePair.value) - + ::CoreML::Specification::MILSpec::Value* temp = value_; value_ = NULL; return temp; @@ -10517,9 +10849,9 @@ void DictionaryValue_KeyValuePair::set_allocated_value(::CoreML::Specification:: delete value_; value_ = value; if (value) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.DictionaryValue.KeyValuePair.value) } diff --git a/mlmodel/build/format/MIL.pb.h b/mlmodel/build/format/MIL.pb.h index 6776cba76..404d42278 100644 --- a/mlmodel/build/format/MIL.pb.h +++ b/mlmodel/build/format/MIL.pb.h @@ -101,6 +101,9 @@ extern Program_AttributesEntryDefaultTypeInternal _Program_AttributesEntry_defau class Program_FunctionsEntry; class Program_FunctionsEntryDefaultTypeInternal; extern Program_FunctionsEntryDefaultTypeInternal _Program_FunctionsEntry_default_instance_; +class StateType; +class StateTypeDefaultTypeInternal; +extern StateTypeDefaultTypeInternal _StateType_default_instance_; class TensorType; class TensorTypeDefaultTypeInternal; extern TensorTypeDefaultTypeInternal _TensorType_default_instance_; @@ -175,6 +178,8 @@ enum DataType { UNUSED_TYPE = 0, BOOL = 1, STRING = 2, + FLOAT8E4M3FN = 40, + FLOAT8E5M2 = 41, FLOAT16 = 10, FLOAT32 = 11, FLOAT64 = 12, @@ -183,16 +188,22 @@ enum DataType { INT16 = 22, INT32 = 23, INT64 = 24, + INT4 = 25, UINT8 = 31, UINT16 = 32, UINT32 = 33, UINT64 = 34, + UINT4 = 35, + UINT2 = 36, + UINT1 = 37, + UINT6 = 38, + UINT3 = 39, DataType_INT_MIN_SENTINEL_DO_NOT_USE_ = ::google::protobuf::kint32min, DataType_INT_MAX_SENTINEL_DO_NOT_USE_ = ::google::protobuf::kint32max }; bool DataType_IsValid(int value); const DataType DataType_MIN = UNUSED_TYPE; -const DataType DataType_MAX = UINT64; +const DataType DataType_MAX = FLOAT8E5M2; const int DataType_ARRAYSIZE = DataType_MAX + 1; // =================================================================== @@ -1125,6 +1136,7 @@ class ValueType : public ::google::protobuf::MessageLite /* @@protoc_insertion_p kListType = 2, kTupleType = 3, kDictionaryType = 4, + kStateType = 5, TYPE_NOT_SET = 0, }; @@ -1212,6 +1224,15 @@ class ValueType : public ::google::protobuf::MessageLite /* @@protoc_insertion_p ::CoreML::Specification::MILSpec::DictionaryType* release_dictionarytype(); void set_allocated_dictionarytype(::CoreML::Specification::MILSpec::DictionaryType* dictionarytype); + // .CoreML.Specification.MILSpec.StateType stateType = 5; + bool has_statetype() const; + void clear_statetype(); + static const int kStateTypeFieldNumber = 5; + const ::CoreML::Specification::MILSpec::StateType& statetype() const; + ::CoreML::Specification::MILSpec::StateType* mutable_statetype(); + ::CoreML::Specification::MILSpec::StateType* release_statetype(); + void set_allocated_statetype(::CoreML::Specification::MILSpec::StateType* statetype); + TypeCase type_case() const; // @@protoc_insertion_point(class_scope:CoreML.Specification.MILSpec.ValueType) private: @@ -1219,6 +1240,7 @@ class ValueType : public ::google::protobuf::MessageLite /* @@protoc_insertion_p void set_has_listtype(); void set_has_tupletype(); void set_has_dictionarytype(); + void set_has_statetype(); inline bool has_type() const; void clear_type(); @@ -1231,6 +1253,7 @@ class ValueType : public ::google::protobuf::MessageLite /* @@protoc_insertion_p ::CoreML::Specification::MILSpec::ListType* listtype_; ::CoreML::Specification::MILSpec::TupleType* tupletype_; ::CoreML::Specification::MILSpec::DictionaryType* dictionarytype_; + ::CoreML::Specification::MILSpec::StateType* statetype_; } type_; mutable int _cached_size_; ::google::protobuf::uint32 _oneof_case_[1]; @@ -1630,6 +1653,87 @@ class DictionaryType : public ::google::protobuf::MessageLite /* @@protoc_insert }; // ------------------------------------------------------------------- +class StateType : public ::google::protobuf::MessageLite /* @@protoc_insertion_point(class_definition:CoreML.Specification.MILSpec.StateType) */ { + public: + StateType(); + virtual ~StateType(); + + StateType(const StateType& from); + + inline StateType& operator=(const StateType& from) { + CopyFrom(from); + return *this; + } + + static const StateType& default_instance(); + + static inline const StateType* internal_default_instance() { + return reinterpret_cast( + &_StateType_default_instance_); + } + static PROTOBUF_CONSTEXPR int const kIndexInFileMessages = + 20; + + void Swap(StateType* other); + + // implements Message ---------------------------------------------- + + inline StateType* New() const PROTOBUF_FINAL { return New(NULL); } + + StateType* New(::google::protobuf::Arena* arena) const PROTOBUF_FINAL; + void CheckTypeAndMergeFrom(const ::google::protobuf::MessageLite& from) + PROTOBUF_FINAL; + void CopyFrom(const StateType& from); + void MergeFrom(const StateType& from); + void Clear() PROTOBUF_FINAL; + bool IsInitialized() const PROTOBUF_FINAL; + + size_t ByteSizeLong() const PROTOBUF_FINAL; + bool MergePartialFromCodedStream( + ::google::protobuf::io::CodedInputStream* input) PROTOBUF_FINAL; + void SerializeWithCachedSizes( + ::google::protobuf::io::CodedOutputStream* output) const PROTOBUF_FINAL; + void DiscardUnknownFields(); + int GetCachedSize() const PROTOBUF_FINAL { return _cached_size_; } + private: + void SharedCtor(); + void SharedDtor(); + void SetCachedSize(int size) const; + void InternalSwap(StateType* other); + private: + inline ::google::protobuf::Arena* GetArenaNoVirtual() const { + return NULL; + } + inline void* MaybeArenaPtr() const { + return NULL; + } + public: + + ::std::string GetTypeName() const PROTOBUF_FINAL; + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + // .CoreML.Specification.MILSpec.ValueType wrappedType = 1; + bool has_wrappedtype() const; + void clear_wrappedtype(); + static const int kWrappedTypeFieldNumber = 1; + const ::CoreML::Specification::MILSpec::ValueType& wrappedtype() const; + ::CoreML::Specification::MILSpec::ValueType* mutable_wrappedtype(); + ::CoreML::Specification::MILSpec::ValueType* release_wrappedtype(); + void set_allocated_wrappedtype(::CoreML::Specification::MILSpec::ValueType* wrappedtype); + + // @@protoc_insertion_point(class_scope:CoreML.Specification.MILSpec.StateType) + private: + + ::google::protobuf::internal::InternalMetadataWithArenaLite _internal_metadata_; + ::CoreML::Specification::MILSpec::ValueType* wrappedtype_; + mutable int _cached_size_; + friend struct protobuf_MIL_2eproto::TableStruct; +}; +// ------------------------------------------------------------------- + class Dimension_ConstantDimension : public ::google::protobuf::MessageLite /* @@protoc_insertion_point(class_definition:CoreML.Specification.MILSpec.Dimension.ConstantDimension) */ { public: Dimension_ConstantDimension(); @@ -1649,7 +1753,7 @@ class Dimension_ConstantDimension : public ::google::protobuf::MessageLite /* @@ &_Dimension_ConstantDimension_default_instance_); } static PROTOBUF_CONSTEXPR int const kIndexInFileMessages = - 20; + 21; void Swap(Dimension_ConstantDimension* other); @@ -1727,7 +1831,7 @@ class Dimension_UnknownDimension : public ::google::protobuf::MessageLite /* @@p &_Dimension_UnknownDimension_default_instance_); } static PROTOBUF_CONSTEXPR int const kIndexInFileMessages = - 21; + 22; void Swap(Dimension_UnknownDimension* other); @@ -1811,7 +1915,7 @@ class Dimension : public ::google::protobuf::MessageLite /* @@protoc_insertion_p &_Dimension_default_instance_); } static PROTOBUF_CONSTEXPR int const kIndexInFileMessages = - 22; + 23; void Swap(Dimension* other); @@ -1925,7 +2029,7 @@ class Value_ImmediateValue : public ::google::protobuf::MessageLite /* @@protoc_ &_Value_ImmediateValue_default_instance_); } static PROTOBUF_CONSTEXPR int const kIndexInFileMessages = - 23; + 24; void Swap(Value_ImmediateValue* other); @@ -2050,7 +2154,7 @@ class Value_BlobFileValue : public ::google::protobuf::MessageLite /* @@protoc_i &_Value_BlobFileValue_default_instance_); } static PROTOBUF_CONSTEXPR int const kIndexInFileMessages = - 24; + 25; void Swap(Value_BlobFileValue* other); @@ -2149,7 +2253,7 @@ class Value : public ::google::protobuf::MessageLite /* @@protoc_insertion_point &_Value_default_instance_); } static PROTOBUF_CONSTEXPR int const kIndexInFileMessages = - 25; + 26; void Swap(Value* other); @@ -2280,7 +2384,7 @@ class TensorValue_RepeatedFloats : public ::google::protobuf::MessageLite /* @@p &_TensorValue_RepeatedFloats_default_instance_); } static PROTOBUF_CONSTEXPR int const kIndexInFileMessages = - 26; + 27; void Swap(TensorValue_RepeatedFloats* other); @@ -2365,7 +2469,7 @@ class TensorValue_RepeatedDoubles : public ::google::protobuf::MessageLite /* @@ &_TensorValue_RepeatedDoubles_default_instance_); } static PROTOBUF_CONSTEXPR int const kIndexInFileMessages = - 27; + 28; void Swap(TensorValue_RepeatedDoubles* other); @@ -2450,7 +2554,7 @@ class TensorValue_RepeatedInts : public ::google::protobuf::MessageLite /* @@pro &_TensorValue_RepeatedInts_default_instance_); } static PROTOBUF_CONSTEXPR int const kIndexInFileMessages = - 28; + 29; void Swap(TensorValue_RepeatedInts* other); @@ -2535,7 +2639,7 @@ class TensorValue_RepeatedLongInts : public ::google::protobuf::MessageLite /* @ &_TensorValue_RepeatedLongInts_default_instance_); } static PROTOBUF_CONSTEXPR int const kIndexInFileMessages = - 29; + 30; void Swap(TensorValue_RepeatedLongInts* other); @@ -2620,7 +2724,7 @@ class TensorValue_RepeatedBools : public ::google::protobuf::MessageLite /* @@pr &_TensorValue_RepeatedBools_default_instance_); } static PROTOBUF_CONSTEXPR int const kIndexInFileMessages = - 30; + 31; void Swap(TensorValue_RepeatedBools* other); @@ -2705,7 +2809,7 @@ class TensorValue_RepeatedStrings : public ::google::protobuf::MessageLite /* @@ &_TensorValue_RepeatedStrings_default_instance_); } static PROTOBUF_CONSTEXPR int const kIndexInFileMessages = - 31; + 32; void Swap(TensorValue_RepeatedStrings* other); @@ -2799,7 +2903,7 @@ class TensorValue_RepeatedBytes : public ::google::protobuf::MessageLite /* @@pr &_TensorValue_RepeatedBytes_default_instance_); } static PROTOBUF_CONSTEXPR int const kIndexInFileMessages = - 32; + 33; void Swap(TensorValue_RepeatedBytes* other); @@ -2896,7 +3000,7 @@ class TensorValue : public ::google::protobuf::MessageLite /* @@protoc_insertion &_TensorValue_default_instance_); } static PROTOBUF_CONSTEXPR int const kIndexInFileMessages = - 33; + 34; void Swap(TensorValue* other); @@ -3062,7 +3166,7 @@ class TupleValue : public ::google::protobuf::MessageLite /* @@protoc_insertion_ &_TupleValue_default_instance_); } static PROTOBUF_CONSTEXPR int const kIndexInFileMessages = - 34; + 35; void Swap(TupleValue* other); @@ -3146,7 +3250,7 @@ class ListValue : public ::google::protobuf::MessageLite /* @@protoc_insertion_p &_ListValue_default_instance_); } static PROTOBUF_CONSTEXPR int const kIndexInFileMessages = - 35; + 36; void Swap(ListValue* other); @@ -3230,7 +3334,7 @@ class DictionaryValue_KeyValuePair : public ::google::protobuf::MessageLite /* @ &_DictionaryValue_KeyValuePair_default_instance_); } static PROTOBUF_CONSTEXPR int const kIndexInFileMessages = - 36; + 37; void Swap(DictionaryValue_KeyValuePair* other); @@ -3321,7 +3425,7 @@ class DictionaryValue : public ::google::protobuf::MessageLite /* @@protoc_inser &_DictionaryValue_default_instance_); } static PROTOBUF_CONSTEXPR int const kIndexInFileMessages = - 37; + 38; void Swap(DictionaryValue* other); @@ -3407,7 +3511,7 @@ inline ::google::protobuf::int64 Program::version() const { return version_; } inline void Program::set_version(::google::protobuf::int64 value) { - + version_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.MILSpec.Program.version) } @@ -3439,13 +3543,13 @@ inline const ::std::string& Program::docstring() const { return docstring_.GetNoArena(); } inline void Program::set_docstring(const ::std::string& value) { - + docstring_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.MILSpec.Program.docString) } #if LANG_CXX11 inline void Program::set_docstring(::std::string&& value) { - + docstring_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.MILSpec.Program.docString) @@ -3453,31 +3557,31 @@ inline void Program::set_docstring(::std::string&& value) { #endif inline void Program::set_docstring(const char* value) { GOOGLE_DCHECK(value != NULL); - + docstring_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.MILSpec.Program.docString) } inline void Program::set_docstring(const char* value, size_t size) { - + docstring_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.MILSpec.Program.docString) } inline ::std::string* Program::mutable_docstring() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.MILSpec.Program.docString) return docstring_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* Program::release_docstring() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.Program.docString) - + return docstring_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void Program::set_allocated_docstring(::std::string* docstring) { if (docstring != NULL) { - + } else { - + } docstring_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), docstring); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.Program.docString) @@ -3548,13 +3652,13 @@ inline const ::std::string& Function::opset() const { return opset_.GetNoArena(); } inline void Function::set_opset(const ::std::string& value) { - + opset_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.MILSpec.Function.opset) } #if LANG_CXX11 inline void Function::set_opset(::std::string&& value) { - + opset_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.MILSpec.Function.opset) @@ -3562,31 +3666,31 @@ inline void Function::set_opset(::std::string&& value) { #endif inline void Function::set_opset(const char* value) { GOOGLE_DCHECK(value != NULL); - + opset_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.MILSpec.Function.opset) } inline void Function::set_opset(const char* value, size_t size) { - + opset_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.MILSpec.Function.opset) } inline ::std::string* Function::mutable_opset() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.MILSpec.Function.opset) return opset_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* Function::release_opset() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.Function.opset) - + return opset_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void Function::set_allocated_opset(::std::string* opset) { if (opset != NULL) { - + } else { - + } opset_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), opset); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.Function.opset) @@ -3988,13 +4092,13 @@ inline const ::std::string& Operation::type() const { return type_.GetNoArena(); } inline void Operation::set_type(const ::std::string& value) { - + type_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.MILSpec.Operation.type) } #if LANG_CXX11 inline void Operation::set_type(::std::string&& value) { - + type_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.MILSpec.Operation.type) @@ -4002,31 +4106,31 @@ inline void Operation::set_type(::std::string&& value) { #endif inline void Operation::set_type(const char* value) { GOOGLE_DCHECK(value != NULL); - + type_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.MILSpec.Operation.type) } inline void Operation::set_type(const char* value, size_t size) { - + type_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.MILSpec.Operation.type) } inline ::std::string* Operation::mutable_type() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.MILSpec.Operation.type) return type_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* Operation::release_type() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.Operation.type) - + return type_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void Operation::set_allocated_type(::std::string* type) { if (type != NULL) { - + } else { - + } type_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), type); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.Operation.type) @@ -4141,13 +4245,13 @@ inline const ::std::string& NamedValueType::name() const { return name_.GetNoArena(); } inline void NamedValueType::set_name(const ::std::string& value) { - + name_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.MILSpec.NamedValueType.name) } #if LANG_CXX11 inline void NamedValueType::set_name(::std::string&& value) { - + name_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.MILSpec.NamedValueType.name) @@ -4155,31 +4259,31 @@ inline void NamedValueType::set_name(::std::string&& value) { #endif inline void NamedValueType::set_name(const char* value) { GOOGLE_DCHECK(value != NULL); - + name_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.MILSpec.NamedValueType.name) } inline void NamedValueType::set_name(const char* value, size_t size) { - + name_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.MILSpec.NamedValueType.name) } inline ::std::string* NamedValueType::mutable_name() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.MILSpec.NamedValueType.name) return name_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* NamedValueType::release_name() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.NamedValueType.name) - + return name_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void NamedValueType::set_allocated_name(::std::string* name) { if (name != NULL) { - + } else { - + } name_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), name); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.NamedValueType.name) @@ -4199,7 +4303,7 @@ inline const ::CoreML::Specification::MILSpec::ValueType& NamedValueType::type() : *::CoreML::Specification::MILSpec::ValueType::internal_default_instance(); } inline ::CoreML::Specification::MILSpec::ValueType* NamedValueType::mutable_type() { - + if (type_ == NULL) { type_ = new ::CoreML::Specification::MILSpec::ValueType; } @@ -4208,7 +4312,7 @@ inline ::CoreML::Specification::MILSpec::ValueType* NamedValueType::mutable_type } inline ::CoreML::Specification::MILSpec::ValueType* NamedValueType::release_type() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.NamedValueType.type) - + ::CoreML::Specification::MILSpec::ValueType* temp = type_; type_ = NULL; return temp; @@ -4217,9 +4321,9 @@ inline void NamedValueType::set_allocated_type(::CoreML::Specification::MILSpec: delete type_; type_ = type; if (type) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.NamedValueType.type) } @@ -4420,6 +4524,54 @@ inline void ValueType::set_allocated_dictionarytype(::CoreML::Specification::MIL // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.ValueType.dictionaryType) } +// .CoreML.Specification.MILSpec.StateType stateType = 5; +inline bool ValueType::has_statetype() const { + return type_case() == kStateType; +} +inline void ValueType::set_has_statetype() { + _oneof_case_[0] = kStateType; +} +inline void ValueType::clear_statetype() { + if (has_statetype()) { + delete type_.statetype_; + clear_has_type(); + } +} +inline const ::CoreML::Specification::MILSpec::StateType& ValueType::statetype() const { + // @@protoc_insertion_point(field_get:CoreML.Specification.MILSpec.ValueType.stateType) + return has_statetype() + ? *type_.statetype_ + : ::CoreML::Specification::MILSpec::StateType::default_instance(); +} +inline ::CoreML::Specification::MILSpec::StateType* ValueType::mutable_statetype() { + if (!has_statetype()) { + clear_type(); + set_has_statetype(); + type_.statetype_ = new ::CoreML::Specification::MILSpec::StateType; + } + // @@protoc_insertion_point(field_mutable:CoreML.Specification.MILSpec.ValueType.stateType) + return type_.statetype_; +} +inline ::CoreML::Specification::MILSpec::StateType* ValueType::release_statetype() { + // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.ValueType.stateType) + if (has_statetype()) { + clear_has_type(); + ::CoreML::Specification::MILSpec::StateType* temp = type_.statetype_; + type_.statetype_ = NULL; + return temp; + } else { + return NULL; + } +} +inline void ValueType::set_allocated_statetype(::CoreML::Specification::MILSpec::StateType* statetype) { + clear_type(); + if (statetype) { + set_has_statetype(); + type_.statetype_ = statetype; + } + // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.ValueType.stateType) +} + inline bool ValueType::has_type() const { return type_case() != TYPE_NOT_SET; } @@ -4444,7 +4596,7 @@ inline ::CoreML::Specification::MILSpec::DataType TensorType::datatype() const { return static_cast< ::CoreML::Specification::MILSpec::DataType >(datatype_); } inline void TensorType::set_datatype(::CoreML::Specification::MILSpec::DataType value) { - + datatype_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.MILSpec.TensorType.dataType) } @@ -4458,7 +4610,7 @@ inline ::google::protobuf::int64 TensorType::rank() const { return rank_; } inline void TensorType::set_rank(::google::protobuf::int64 value) { - + rank_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.MILSpec.TensorType.rank) } @@ -4563,7 +4715,7 @@ inline const ::CoreML::Specification::MILSpec::ValueType& ListType::type() const : *::CoreML::Specification::MILSpec::ValueType::internal_default_instance(); } inline ::CoreML::Specification::MILSpec::ValueType* ListType::mutable_type() { - + if (type_ == NULL) { type_ = new ::CoreML::Specification::MILSpec::ValueType; } @@ -4572,7 +4724,7 @@ inline ::CoreML::Specification::MILSpec::ValueType* ListType::mutable_type() { } inline ::CoreML::Specification::MILSpec::ValueType* ListType::release_type() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.ListType.type) - + ::CoreML::Specification::MILSpec::ValueType* temp = type_; type_ = NULL; return temp; @@ -4581,9 +4733,9 @@ inline void ListType::set_allocated_type(::CoreML::Specification::MILSpec::Value delete type_; type_ = type; if (type) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.ListType.type) } @@ -4602,7 +4754,7 @@ inline const ::CoreML::Specification::MILSpec::Dimension& ListType::length() con : *::CoreML::Specification::MILSpec::Dimension::internal_default_instance(); } inline ::CoreML::Specification::MILSpec::Dimension* ListType::mutable_length() { - + if (length_ == NULL) { length_ = new ::CoreML::Specification::MILSpec::Dimension; } @@ -4611,7 +4763,7 @@ inline ::CoreML::Specification::MILSpec::Dimension* ListType::mutable_length() { } inline ::CoreML::Specification::MILSpec::Dimension* ListType::release_length() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.ListType.length) - + ::CoreML::Specification::MILSpec::Dimension* temp = length_; length_ = NULL; return temp; @@ -4620,9 +4772,9 @@ inline void ListType::set_allocated_length(::CoreML::Specification::MILSpec::Dim delete length_; length_ = length; if (length) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.ListType.length) } @@ -4645,7 +4797,7 @@ inline const ::CoreML::Specification::MILSpec::ValueType& DictionaryType::keytyp : *::CoreML::Specification::MILSpec::ValueType::internal_default_instance(); } inline ::CoreML::Specification::MILSpec::ValueType* DictionaryType::mutable_keytype() { - + if (keytype_ == NULL) { keytype_ = new ::CoreML::Specification::MILSpec::ValueType; } @@ -4654,7 +4806,7 @@ inline ::CoreML::Specification::MILSpec::ValueType* DictionaryType::mutable_keyt } inline ::CoreML::Specification::MILSpec::ValueType* DictionaryType::release_keytype() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.DictionaryType.keyType) - + ::CoreML::Specification::MILSpec::ValueType* temp = keytype_; keytype_ = NULL; return temp; @@ -4663,9 +4815,9 @@ inline void DictionaryType::set_allocated_keytype(::CoreML::Specification::MILSp delete keytype_; keytype_ = keytype; if (keytype) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.DictionaryType.keyType) } @@ -4684,7 +4836,7 @@ inline const ::CoreML::Specification::MILSpec::ValueType& DictionaryType::valuet : *::CoreML::Specification::MILSpec::ValueType::internal_default_instance(); } inline ::CoreML::Specification::MILSpec::ValueType* DictionaryType::mutable_valuetype() { - + if (valuetype_ == NULL) { valuetype_ = new ::CoreML::Specification::MILSpec::ValueType; } @@ -4693,7 +4845,7 @@ inline ::CoreML::Specification::MILSpec::ValueType* DictionaryType::mutable_valu } inline ::CoreML::Specification::MILSpec::ValueType* DictionaryType::release_valuetype() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.DictionaryType.valueType) - + ::CoreML::Specification::MILSpec::ValueType* temp = valuetype_; valuetype_ = NULL; return temp; @@ -4702,15 +4854,58 @@ inline void DictionaryType::set_allocated_valuetype(::CoreML::Specification::MIL delete valuetype_; valuetype_ = valuetype; if (valuetype) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.DictionaryType.valueType) } // ------------------------------------------------------------------- +// StateType + +// .CoreML.Specification.MILSpec.ValueType wrappedType = 1; +inline bool StateType::has_wrappedtype() const { + return this != internal_default_instance() && wrappedtype_ != NULL; +} +inline void StateType::clear_wrappedtype() { + if (GetArenaNoVirtual() == NULL && wrappedtype_ != NULL) delete wrappedtype_; + wrappedtype_ = NULL; +} +inline const ::CoreML::Specification::MILSpec::ValueType& StateType::wrappedtype() const { + // @@protoc_insertion_point(field_get:CoreML.Specification.MILSpec.StateType.wrappedType) + return wrappedtype_ != NULL ? *wrappedtype_ + : *::CoreML::Specification::MILSpec::ValueType::internal_default_instance(); +} +inline ::CoreML::Specification::MILSpec::ValueType* StateType::mutable_wrappedtype() { + + if (wrappedtype_ == NULL) { + wrappedtype_ = new ::CoreML::Specification::MILSpec::ValueType; + } + // @@protoc_insertion_point(field_mutable:CoreML.Specification.MILSpec.StateType.wrappedType) + return wrappedtype_; +} +inline ::CoreML::Specification::MILSpec::ValueType* StateType::release_wrappedtype() { + // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.StateType.wrappedType) + + ::CoreML::Specification::MILSpec::ValueType* temp = wrappedtype_; + wrappedtype_ = NULL; + return temp; +} +inline void StateType::set_allocated_wrappedtype(::CoreML::Specification::MILSpec::ValueType* wrappedtype) { + delete wrappedtype_; + wrappedtype_ = wrappedtype; + if (wrappedtype) { + + } else { + + } + // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.StateType.wrappedType) +} + +// ------------------------------------------------------------------- + // Dimension_ConstantDimension // uint64 size = 1; @@ -4722,7 +4917,7 @@ inline ::google::protobuf::uint64 Dimension_ConstantDimension::size() const { return size_; } inline void Dimension_ConstantDimension::set_size(::google::protobuf::uint64 value) { - + size_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.MILSpec.Dimension.ConstantDimension.size) } @@ -4740,7 +4935,7 @@ inline bool Dimension_UnknownDimension::variadic() const { return variadic_; } inline void Dimension_UnknownDimension::set_variadic(bool value) { - + variadic_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.MILSpec.Dimension.UnknownDimension.variadic) } @@ -5072,13 +5267,13 @@ inline const ::std::string& Value_BlobFileValue::filename() const { return filename_.GetNoArena(); } inline void Value_BlobFileValue::set_filename(const ::std::string& value) { - + filename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.MILSpec.Value.BlobFileValue.fileName) } #if LANG_CXX11 inline void Value_BlobFileValue::set_filename(::std::string&& value) { - + filename_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.MILSpec.Value.BlobFileValue.fileName) @@ -5086,31 +5281,31 @@ inline void Value_BlobFileValue::set_filename(::std::string&& value) { #endif inline void Value_BlobFileValue::set_filename(const char* value) { GOOGLE_DCHECK(value != NULL); - + filename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.MILSpec.Value.BlobFileValue.fileName) } inline void Value_BlobFileValue::set_filename(const char* value, size_t size) { - + filename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.MILSpec.Value.BlobFileValue.fileName) } inline ::std::string* Value_BlobFileValue::mutable_filename() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.MILSpec.Value.BlobFileValue.fileName) return filename_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* Value_BlobFileValue::release_filename() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.Value.BlobFileValue.fileName) - + return filename_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void Value_BlobFileValue::set_allocated_filename(::std::string* filename) { if (filename != NULL) { - + } else { - + } filename_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), filename); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.Value.BlobFileValue.fileName) @@ -5125,7 +5320,7 @@ inline ::google::protobuf::uint64 Value_BlobFileValue::offset() const { return offset_; } inline void Value_BlobFileValue::set_offset(::google::protobuf::uint64 value) { - + offset_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.MILSpec.Value.BlobFileValue.offset) } @@ -5143,13 +5338,13 @@ inline const ::std::string& Value::docstring() const { return docstring_.GetNoArena(); } inline void Value::set_docstring(const ::std::string& value) { - + docstring_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.MILSpec.Value.docString) } #if LANG_CXX11 inline void Value::set_docstring(::std::string&& value) { - + docstring_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.MILSpec.Value.docString) @@ -5157,31 +5352,31 @@ inline void Value::set_docstring(::std::string&& value) { #endif inline void Value::set_docstring(const char* value) { GOOGLE_DCHECK(value != NULL); - + docstring_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.MILSpec.Value.docString) } inline void Value::set_docstring(const char* value, size_t size) { - + docstring_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.MILSpec.Value.docString) } inline ::std::string* Value::mutable_docstring() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.MILSpec.Value.docString) return docstring_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* Value::release_docstring() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.Value.docString) - + return docstring_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void Value::set_allocated_docstring(::std::string* docstring) { if (docstring != NULL) { - + } else { - + } docstring_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), docstring); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.Value.docString) @@ -5201,7 +5396,7 @@ inline const ::CoreML::Specification::MILSpec::ValueType& Value::type() const { : *::CoreML::Specification::MILSpec::ValueType::internal_default_instance(); } inline ::CoreML::Specification::MILSpec::ValueType* Value::mutable_type() { - + if (type_ == NULL) { type_ = new ::CoreML::Specification::MILSpec::ValueType; } @@ -5210,7 +5405,7 @@ inline ::CoreML::Specification::MILSpec::ValueType* Value::mutable_type() { } inline ::CoreML::Specification::MILSpec::ValueType* Value::release_type() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.Value.type) - + ::CoreML::Specification::MILSpec::ValueType* temp = type_; type_ = NULL; return temp; @@ -5219,9 +5414,9 @@ inline void Value::set_allocated_type(::CoreML::Specification::MILSpec::ValueTyp delete type_; type_ = type; if (type) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.Value.type) } @@ -5587,13 +5782,13 @@ inline const ::std::string& TensorValue_RepeatedBytes::values() const { return values_.GetNoArena(); } inline void TensorValue_RepeatedBytes::set_values(const ::std::string& value) { - + values_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.MILSpec.TensorValue.RepeatedBytes.values) } #if LANG_CXX11 inline void TensorValue_RepeatedBytes::set_values(::std::string&& value) { - + values_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.MILSpec.TensorValue.RepeatedBytes.values) @@ -5601,31 +5796,31 @@ inline void TensorValue_RepeatedBytes::set_values(::std::string&& value) { #endif inline void TensorValue_RepeatedBytes::set_values(const char* value) { GOOGLE_DCHECK(value != NULL); - + values_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.MILSpec.TensorValue.RepeatedBytes.values) } inline void TensorValue_RepeatedBytes::set_values(const void* value, size_t size) { - + values_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.MILSpec.TensorValue.RepeatedBytes.values) } inline ::std::string* TensorValue_RepeatedBytes::mutable_values() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.MILSpec.TensorValue.RepeatedBytes.values) return values_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* TensorValue_RepeatedBytes::release_values() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.TensorValue.RepeatedBytes.values) - + return values_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void TensorValue_RepeatedBytes::set_allocated_values(::std::string* values) { if (values != NULL) { - + } else { - + } values_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), values); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.TensorValue.RepeatedBytes.values) @@ -6066,7 +6261,7 @@ inline const ::CoreML::Specification::MILSpec::Value& DictionaryValue_KeyValuePa : *::CoreML::Specification::MILSpec::Value::internal_default_instance(); } inline ::CoreML::Specification::MILSpec::Value* DictionaryValue_KeyValuePair::mutable_key() { - + if (key_ == NULL) { key_ = new ::CoreML::Specification::MILSpec::Value; } @@ -6075,7 +6270,7 @@ inline ::CoreML::Specification::MILSpec::Value* DictionaryValue_KeyValuePair::mu } inline ::CoreML::Specification::MILSpec::Value* DictionaryValue_KeyValuePair::release_key() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.DictionaryValue.KeyValuePair.key) - + ::CoreML::Specification::MILSpec::Value* temp = key_; key_ = NULL; return temp; @@ -6084,9 +6279,9 @@ inline void DictionaryValue_KeyValuePair::set_allocated_key(::CoreML::Specificat delete key_; key_ = key; if (key) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.DictionaryValue.KeyValuePair.key) } @@ -6105,7 +6300,7 @@ inline const ::CoreML::Specification::MILSpec::Value& DictionaryValue_KeyValuePa : *::CoreML::Specification::MILSpec::Value::internal_default_instance(); } inline ::CoreML::Specification::MILSpec::Value* DictionaryValue_KeyValuePair::mutable_value() { - + if (value_ == NULL) { value_ = new ::CoreML::Specification::MILSpec::Value; } @@ -6114,7 +6309,7 @@ inline ::CoreML::Specification::MILSpec::Value* DictionaryValue_KeyValuePair::mu } inline ::CoreML::Specification::MILSpec::Value* DictionaryValue_KeyValuePair::release_value() { // @@protoc_insertion_point(field_release:CoreML.Specification.MILSpec.DictionaryValue.KeyValuePair.value) - + ::CoreML::Specification::MILSpec::Value* temp = value_; value_ = NULL; return temp; @@ -6123,9 +6318,9 @@ inline void DictionaryValue_KeyValuePair::set_allocated_value(::CoreML::Specific delete value_; value_ = value; if (value) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MILSpec.DictionaryValue.KeyValuePair.value) } @@ -6239,6 +6434,8 @@ DictionaryValue::values() const { // ------------------------------------------------------------------- +// ------------------------------------------------------------------- + // @@protoc_insertion_point(namespace_scope) diff --git a/mlmodel/build/format/MIL_enums.h b/mlmodel/build/format/MIL_enums.h index 911353d4f..71166c5c6 100644 --- a/mlmodel/build/format/MIL_enums.h +++ b/mlmodel/build/format/MIL_enums.h @@ -4,6 +4,8 @@ enum MLDataType: int { MLDataTypeUNUSED_TYPE = 0, MLDataTypeBOOL = 1, MLDataTypeSTRING = 2, + MLDataTypeFLOAT8E4M3FN = 40, + MLDataTypeFLOAT8E5M2 = 41, MLDataTypeFLOAT16 = 10, MLDataTypeFLOAT32 = 11, MLDataTypeFLOAT64 = 12, @@ -12,10 +14,16 @@ enum MLDataType: int { MLDataTypeINT16 = 22, MLDataTypeINT32 = 23, MLDataTypeINT64 = 24, + MLDataTypeINT4 = 25, MLDataTypeUINT8 = 31, MLDataTypeUINT16 = 32, MLDataTypeUINT32 = 33, MLDataTypeUINT64 = 34, + MLDataTypeUINT4 = 35, + MLDataTypeUINT2 = 36, + MLDataTypeUINT1 = 37, + MLDataTypeUINT6 = 38, + MLDataTypeUINT3 = 39, }; enum MLBindingbinding: int { @@ -42,6 +50,7 @@ enum MLValueTypetype: int { MLValueTypetype_listType = 2, MLValueTypetype_tupleType = 3, MLValueTypetype_dictionaryType = 4, + MLValueTypetype_stateType = 5, MLValueTypetype_NOT_SET = 0, }; @@ -56,6 +65,8 @@ static const char * MLValueTypetype_Name(MLValueTypetype x) { return "MLValueTypetype_tupleType"; case MLValueTypetype_dictionaryType: return "MLValueTypetype_dictionaryType"; + case MLValueTypetype_stateType: + return "MLValueTypetype_stateType"; case MLValueTypetype_NOT_SET: return "INVALID"; } diff --git a/mlmodel/build/format/Model.pb.cc b/mlmodel/build/format/Model.pb.cc index 8ab5c09de..a68528b35 100644 --- a/mlmodel/build/format/Model.pb.cc +++ b/mlmodel/build/format/Model.pb.cc @@ -28,6 +28,8 @@ class Metadata_UserDefinedEntryDefaultTypeInternal : public ::google::protobuf:: } _Metadata_UserDefinedEntry_default_instance_; class MetadataDefaultTypeInternal : public ::google::protobuf::internal::ExplicitlyConstructed { } _Metadata_default_instance_; +class FunctionDescriptionDefaultTypeInternal : public ::google::protobuf::internal::ExplicitlyConstructed { +} _FunctionDescription_default_instance_; class ModelDescriptionDefaultTypeInternal : public ::google::protobuf::internal::ExplicitlyConstructed { } _ModelDescription_default_instance_; class SerializedModelDefaultTypeInternal : public ::google::protobuf::internal::ExplicitlyConstructed { @@ -95,6 +97,7 @@ PROTOBUF_CONSTEXPR_VAR ::google::protobuf::internal::ParseTable const { NULL, NULL, 0, -1, -1, false }, { NULL, NULL, 0, -1, -1, false }, { NULL, NULL, 0, -1, -1, false }, + { NULL, NULL, 0, -1, -1, false }, }; @@ -104,6 +107,7 @@ void TableStruct::Shutdown() { _PipelineRegressor_default_instance_.Shutdown(); _FeatureDescription_default_instance_.Shutdown(); _Metadata_default_instance_.Shutdown(); + _FunctionDescription_default_instance_.Shutdown(); _ModelDescription_default_instance_.Shutdown(); _SerializedModel_default_instance_.Shutdown(); _Model_default_instance_.Shutdown(); @@ -150,6 +154,7 @@ void TableStruct::InitDefaultsImpl() { _FeatureDescription_default_instance_.DefaultConstruct(); _Metadata_UserDefinedEntry_default_instance_.DefaultConstruct(); _Metadata_default_instance_.DefaultConstruct(); + _FunctionDescription_default_instance_.DefaultConstruct(); _ModelDescription_default_instance_.DefaultConstruct(); _SerializedModel_default_instance_.DefaultConstruct(); _Model_default_instance_.DefaultConstruct(); @@ -747,7 +752,7 @@ const ::CoreML::Specification::Pipeline& PipelineClassifier::pipeline() const { : *::CoreML::Specification::Pipeline::internal_default_instance(); } ::CoreML::Specification::Pipeline* PipelineClassifier::mutable_pipeline() { - + if (pipeline_ == NULL) { pipeline_ = new ::CoreML::Specification::Pipeline; } @@ -756,7 +761,7 @@ ::CoreML::Specification::Pipeline* PipelineClassifier::mutable_pipeline() { } ::CoreML::Specification::Pipeline* PipelineClassifier::release_pipeline() { // @@protoc_insertion_point(field_release:CoreML.Specification.PipelineClassifier.pipeline) - + ::CoreML::Specification::Pipeline* temp = pipeline_; pipeline_ = NULL; return temp; @@ -765,9 +770,9 @@ void PipelineClassifier::set_allocated_pipeline(::CoreML::Specification::Pipelin delete pipeline_; pipeline_ = pipeline; if (pipeline) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.PipelineClassifier.pipeline) } @@ -977,7 +982,7 @@ const ::CoreML::Specification::Pipeline& PipelineRegressor::pipeline() const { : *::CoreML::Specification::Pipeline::internal_default_instance(); } ::CoreML::Specification::Pipeline* PipelineRegressor::mutable_pipeline() { - + if (pipeline_ == NULL) { pipeline_ = new ::CoreML::Specification::Pipeline; } @@ -986,7 +991,7 @@ ::CoreML::Specification::Pipeline* PipelineRegressor::mutable_pipeline() { } ::CoreML::Specification::Pipeline* PipelineRegressor::release_pipeline() { // @@protoc_insertion_point(field_release:CoreML.Specification.PipelineRegressor.pipeline) - + ::CoreML::Specification::Pipeline* temp = pipeline_; pipeline_ = NULL; return temp; @@ -995,9 +1000,9 @@ void PipelineRegressor::set_allocated_pipeline(::CoreML::Specification::Pipeline delete pipeline_; pipeline_ = pipeline; if (pipeline) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.PipelineRegressor.pipeline) } @@ -1294,13 +1299,13 @@ const ::std::string& FeatureDescription::name() const { return name_.GetNoArena(); } void FeatureDescription::set_name(const ::std::string& value) { - + name_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.FeatureDescription.name) } #if LANG_CXX11 void FeatureDescription::set_name(::std::string&& value) { - + name_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.FeatureDescription.name) @@ -1308,31 +1313,31 @@ void FeatureDescription::set_name(::std::string&& value) { #endif void FeatureDescription::set_name(const char* value) { GOOGLE_DCHECK(value != NULL); - + name_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.FeatureDescription.name) } void FeatureDescription::set_name(const char* value, size_t size) { - + name_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.FeatureDescription.name) } ::std::string* FeatureDescription::mutable_name() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.FeatureDescription.name) return name_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } ::std::string* FeatureDescription::release_name() { // @@protoc_insertion_point(field_release:CoreML.Specification.FeatureDescription.name) - + return name_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } void FeatureDescription::set_allocated_name(::std::string* name) { if (name != NULL) { - + } else { - + } name_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), name); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.FeatureDescription.name) @@ -1347,13 +1352,13 @@ const ::std::string& FeatureDescription::shortdescription() const { return shortdescription_.GetNoArena(); } void FeatureDescription::set_shortdescription(const ::std::string& value) { - + shortdescription_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.FeatureDescription.shortDescription) } #if LANG_CXX11 void FeatureDescription::set_shortdescription(::std::string&& value) { - + shortdescription_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.FeatureDescription.shortDescription) @@ -1361,31 +1366,31 @@ void FeatureDescription::set_shortdescription(::std::string&& value) { #endif void FeatureDescription::set_shortdescription(const char* value) { GOOGLE_DCHECK(value != NULL); - + shortdescription_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.FeatureDescription.shortDescription) } void FeatureDescription::set_shortdescription(const char* value, size_t size) { - + shortdescription_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.FeatureDescription.shortDescription) } ::std::string* FeatureDescription::mutable_shortdescription() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.FeatureDescription.shortDescription) return shortdescription_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } ::std::string* FeatureDescription::release_shortdescription() { // @@protoc_insertion_point(field_release:CoreML.Specification.FeatureDescription.shortDescription) - + return shortdescription_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } void FeatureDescription::set_allocated_shortdescription(::std::string* shortdescription) { if (shortdescription != NULL) { - + } else { - + } shortdescription_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), shortdescription); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.FeatureDescription.shortDescription) @@ -1405,7 +1410,7 @@ const ::CoreML::Specification::FeatureType& FeatureDescription::type() const { : *::CoreML::Specification::FeatureType::internal_default_instance(); } ::CoreML::Specification::FeatureType* FeatureDescription::mutable_type() { - + if (type_ == NULL) { type_ = new ::CoreML::Specification::FeatureType; } @@ -1414,7 +1419,7 @@ ::CoreML::Specification::FeatureType* FeatureDescription::mutable_type() { } ::CoreML::Specification::FeatureType* FeatureDescription::release_type() { // @@protoc_insertion_point(field_release:CoreML.Specification.FeatureDescription.type) - + ::CoreML::Specification::FeatureType* temp = type_; type_ = NULL; return temp; @@ -1423,9 +1428,9 @@ void FeatureDescription::set_allocated_type(::CoreML::Specification::FeatureType delete type_; type_ = type; if (type) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.FeatureDescription.type) } @@ -1875,13 +1880,13 @@ const ::std::string& Metadata::shortdescription() const { return shortdescription_.GetNoArena(); } void Metadata::set_shortdescription(const ::std::string& value) { - + shortdescription_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.Metadata.shortDescription) } #if LANG_CXX11 void Metadata::set_shortdescription(::std::string&& value) { - + shortdescription_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.Metadata.shortDescription) @@ -1889,31 +1894,31 @@ void Metadata::set_shortdescription(::std::string&& value) { #endif void Metadata::set_shortdescription(const char* value) { GOOGLE_DCHECK(value != NULL); - + shortdescription_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.Metadata.shortDescription) } void Metadata::set_shortdescription(const char* value, size_t size) { - + shortdescription_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.Metadata.shortDescription) } ::std::string* Metadata::mutable_shortdescription() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.Metadata.shortDescription) return shortdescription_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } ::std::string* Metadata::release_shortdescription() { // @@protoc_insertion_point(field_release:CoreML.Specification.Metadata.shortDescription) - + return shortdescription_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } void Metadata::set_allocated_shortdescription(::std::string* shortdescription) { if (shortdescription != NULL) { - + } else { - + } shortdescription_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), shortdescription); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.Metadata.shortDescription) @@ -1928,13 +1933,13 @@ const ::std::string& Metadata::versionstring() const { return versionstring_.GetNoArena(); } void Metadata::set_versionstring(const ::std::string& value) { - + versionstring_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.Metadata.versionString) } #if LANG_CXX11 void Metadata::set_versionstring(::std::string&& value) { - + versionstring_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.Metadata.versionString) @@ -1942,31 +1947,31 @@ void Metadata::set_versionstring(::std::string&& value) { #endif void Metadata::set_versionstring(const char* value) { GOOGLE_DCHECK(value != NULL); - + versionstring_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.Metadata.versionString) } void Metadata::set_versionstring(const char* value, size_t size) { - + versionstring_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.Metadata.versionString) } ::std::string* Metadata::mutable_versionstring() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.Metadata.versionString) return versionstring_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } ::std::string* Metadata::release_versionstring() { // @@protoc_insertion_point(field_release:CoreML.Specification.Metadata.versionString) - + return versionstring_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } void Metadata::set_allocated_versionstring(::std::string* versionstring) { if (versionstring != NULL) { - + } else { - + } versionstring_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), versionstring); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.Metadata.versionString) @@ -1981,13 +1986,13 @@ const ::std::string& Metadata::author() const { return author_.GetNoArena(); } void Metadata::set_author(const ::std::string& value) { - + author_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.Metadata.author) } #if LANG_CXX11 void Metadata::set_author(::std::string&& value) { - + author_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.Metadata.author) @@ -1995,31 +2000,31 @@ void Metadata::set_author(::std::string&& value) { #endif void Metadata::set_author(const char* value) { GOOGLE_DCHECK(value != NULL); - + author_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.Metadata.author) } void Metadata::set_author(const char* value, size_t size) { - + author_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.Metadata.author) } ::std::string* Metadata::mutable_author() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.Metadata.author) return author_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } ::std::string* Metadata::release_author() { // @@protoc_insertion_point(field_release:CoreML.Specification.Metadata.author) - + return author_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } void Metadata::set_allocated_author(::std::string* author) { if (author != NULL) { - + } else { - + } author_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), author); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.Metadata.author) @@ -2034,13 +2039,13 @@ const ::std::string& Metadata::license() const { return license_.GetNoArena(); } void Metadata::set_license(const ::std::string& value) { - + license_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.Metadata.license) } #if LANG_CXX11 void Metadata::set_license(::std::string&& value) { - + license_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.Metadata.license) @@ -2048,52 +2053,689 @@ void Metadata::set_license(::std::string&& value) { #endif void Metadata::set_license(const char* value) { GOOGLE_DCHECK(value != NULL); - + license_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.Metadata.license) } void Metadata::set_license(const char* value, size_t size) { - + license_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.Metadata.license) } ::std::string* Metadata::mutable_license() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.Metadata.license) return license_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } ::std::string* Metadata::release_license() { // @@protoc_insertion_point(field_release:CoreML.Specification.Metadata.license) - + return license_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } -void Metadata::set_allocated_license(::std::string* license) { - if (license != NULL) { - - } else { - - } - license_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), license); - // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.Metadata.license) +void Metadata::set_allocated_license(::std::string* license) { + if (license != NULL) { + + } else { + + } + license_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), license); + // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.Metadata.license) +} + +// map userDefined = 100; +int Metadata::userdefined_size() const { + return userdefined_.size(); +} +void Metadata::clear_userdefined() { + userdefined_.Clear(); +} + const ::google::protobuf::Map< ::std::string, ::std::string >& +Metadata::userdefined() const { + // @@protoc_insertion_point(field_map:CoreML.Specification.Metadata.userDefined) + return userdefined_.GetMap(); +} + ::google::protobuf::Map< ::std::string, ::std::string >* +Metadata::mutable_userdefined() { + // @@protoc_insertion_point(field_mutable_map:CoreML.Specification.Metadata.userDefined) + return userdefined_.MutableMap(); +} + +#endif // PROTOBUF_INLINE_NOT_IN_HEADERS + +// =================================================================== + +#if !defined(_MSC_VER) || _MSC_VER >= 1900 +const int FunctionDescription::kNameFieldNumber; +const int FunctionDescription::kInputFieldNumber; +const int FunctionDescription::kOutputFieldNumber; +const int FunctionDescription::kStateFieldNumber; +const int FunctionDescription::kPredictedFeatureNameFieldNumber; +const int FunctionDescription::kPredictedProbabilitiesNameFieldNumber; +#endif // !defined(_MSC_VER) || _MSC_VER >= 1900 + +FunctionDescription::FunctionDescription() + : ::google::protobuf::MessageLite(), _internal_metadata_(NULL) { + if (GOOGLE_PREDICT_TRUE(this != internal_default_instance())) { + protobuf_Model_2eproto::InitDefaults(); + } + SharedCtor(); + // @@protoc_insertion_point(constructor:CoreML.Specification.FunctionDescription) +} +FunctionDescription::FunctionDescription(const FunctionDescription& from) + : ::google::protobuf::MessageLite(), + _internal_metadata_(NULL), + input_(from.input_), + output_(from.output_), + state_(from.state_), + _cached_size_(0) { + _internal_metadata_.MergeFrom(from._internal_metadata_); + name_.UnsafeSetDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); + if (from.name().size() > 0) { + name_.AssignWithDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), from.name_); + } + predictedfeaturename_.UnsafeSetDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); + if (from.predictedfeaturename().size() > 0) { + predictedfeaturename_.AssignWithDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), from.predictedfeaturename_); + } + predictedprobabilitiesname_.UnsafeSetDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); + if (from.predictedprobabilitiesname().size() > 0) { + predictedprobabilitiesname_.AssignWithDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), from.predictedprobabilitiesname_); + } + // @@protoc_insertion_point(copy_constructor:CoreML.Specification.FunctionDescription) +} + +void FunctionDescription::SharedCtor() { + name_.UnsafeSetDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); + predictedfeaturename_.UnsafeSetDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); + predictedprobabilitiesname_.UnsafeSetDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); + _cached_size_ = 0; +} + +FunctionDescription::~FunctionDescription() { + // @@protoc_insertion_point(destructor:CoreML.Specification.FunctionDescription) + SharedDtor(); +} + +void FunctionDescription::SharedDtor() { + name_.DestroyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); + predictedfeaturename_.DestroyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); + predictedprobabilitiesname_.DestroyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} + +void FunctionDescription::SetCachedSize(int size) const { + GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN(); + _cached_size_ = size; + GOOGLE_SAFE_CONCURRENT_WRITES_END(); +} +const FunctionDescription& FunctionDescription::default_instance() { + protobuf_Model_2eproto::InitDefaults(); + return *internal_default_instance(); +} + +FunctionDescription* FunctionDescription::New(::google::protobuf::Arena* arena) const { + FunctionDescription* n = new FunctionDescription; + if (arena != NULL) { + arena->Own(n); + } + return n; +} + +void FunctionDescription::Clear() { +// @@protoc_insertion_point(message_clear_start:CoreML.Specification.FunctionDescription) + input_.Clear(); + output_.Clear(); + state_.Clear(); + name_.ClearToEmptyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); + predictedfeaturename_.ClearToEmptyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); + predictedprobabilitiesname_.ClearToEmptyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} + +bool FunctionDescription::MergePartialFromCodedStream( + ::google::protobuf::io::CodedInputStream* input) { +#define DO_(EXPRESSION) if (!GOOGLE_PREDICT_TRUE(EXPRESSION)) goto failure + ::google::protobuf::uint32 tag; + // @@protoc_insertion_point(parse_start:CoreML.Specification.FunctionDescription) + for (;;) { + ::std::pair< ::google::protobuf::uint32, bool> p = input->ReadTagWithCutoffNoLastTag(127u); + tag = p.first; + if (!p.second) goto handle_unusual; + switch (::google::protobuf::internal::WireFormatLite::GetTagFieldNumber(tag)) { + // string name = 1; + case 1: { + if (static_cast< ::google::protobuf::uint8>(tag) == + static_cast< ::google::protobuf::uint8>(10u)) { + DO_(::google::protobuf::internal::WireFormatLite::ReadString( + input, this->mutable_name())); + DO_(::google::protobuf::internal::WireFormatLite::VerifyUtf8String( + this->name().data(), this->name().length(), + ::google::protobuf::internal::WireFormatLite::PARSE, + "CoreML.Specification.FunctionDescription.name")); + } else { + goto handle_unusual; + } + break; + } + + // repeated .CoreML.Specification.FeatureDescription input = 2; + case 2: { + if (static_cast< ::google::protobuf::uint8>(tag) == + static_cast< ::google::protobuf::uint8>(18u)) { + DO_(::google::protobuf::internal::WireFormatLite::ReadMessageNoVirtual( + input, add_input())); + } else { + goto handle_unusual; + } + break; + } + + // repeated .CoreML.Specification.FeatureDescription output = 3; + case 3: { + if (static_cast< ::google::protobuf::uint8>(tag) == + static_cast< ::google::protobuf::uint8>(26u)) { + DO_(::google::protobuf::internal::WireFormatLite::ReadMessageNoVirtual( + input, add_output())); + } else { + goto handle_unusual; + } + break; + } + + // string predictedFeatureName = 4; + case 4: { + if (static_cast< ::google::protobuf::uint8>(tag) == + static_cast< ::google::protobuf::uint8>(34u)) { + DO_(::google::protobuf::internal::WireFormatLite::ReadString( + input, this->mutable_predictedfeaturename())); + DO_(::google::protobuf::internal::WireFormatLite::VerifyUtf8String( + this->predictedfeaturename().data(), this->predictedfeaturename().length(), + ::google::protobuf::internal::WireFormatLite::PARSE, + "CoreML.Specification.FunctionDescription.predictedFeatureName")); + } else { + goto handle_unusual; + } + break; + } + + // string predictedProbabilitiesName = 5; + case 5: { + if (static_cast< ::google::protobuf::uint8>(tag) == + static_cast< ::google::protobuf::uint8>(42u)) { + DO_(::google::protobuf::internal::WireFormatLite::ReadString( + input, this->mutable_predictedprobabilitiesname())); + DO_(::google::protobuf::internal::WireFormatLite::VerifyUtf8String( + this->predictedprobabilitiesname().data(), this->predictedprobabilitiesname().length(), + ::google::protobuf::internal::WireFormatLite::PARSE, + "CoreML.Specification.FunctionDescription.predictedProbabilitiesName")); + } else { + goto handle_unusual; + } + break; + } + + // repeated .CoreML.Specification.FeatureDescription state = 6; + case 6: { + if (static_cast< ::google::protobuf::uint8>(tag) == + static_cast< ::google::protobuf::uint8>(50u)) { + DO_(::google::protobuf::internal::WireFormatLite::ReadMessageNoVirtual( + input, add_state())); + } else { + goto handle_unusual; + } + break; + } + + default: { + handle_unusual: + if (tag == 0 || + ::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_END_GROUP) { + goto success; + } + DO_(::google::protobuf::internal::WireFormatLite::SkipField(input, tag)); + break; + } + } + } +success: + // @@protoc_insertion_point(parse_success:CoreML.Specification.FunctionDescription) + return true; +failure: + // @@protoc_insertion_point(parse_failure:CoreML.Specification.FunctionDescription) + return false; +#undef DO_ +} + +void FunctionDescription::SerializeWithCachedSizes( + ::google::protobuf::io::CodedOutputStream* output) const { + // @@protoc_insertion_point(serialize_start:CoreML.Specification.FunctionDescription) + ::google::protobuf::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // string name = 1; + if (this->name().size() > 0) { + ::google::protobuf::internal::WireFormatLite::VerifyUtf8String( + this->name().data(), this->name().length(), + ::google::protobuf::internal::WireFormatLite::SERIALIZE, + "CoreML.Specification.FunctionDescription.name"); + ::google::protobuf::internal::WireFormatLite::WriteStringMaybeAliased( + 1, this->name(), output); + } + + // repeated .CoreML.Specification.FeatureDescription input = 2; + for (unsigned int i = 0, n = this->input_size(); i < n; i++) { + ::google::protobuf::internal::WireFormatLite::WriteMessage( + 2, this->input(i), output); + } + + // repeated .CoreML.Specification.FeatureDescription output = 3; + for (unsigned int i = 0, n = this->output_size(); i < n; i++) { + ::google::protobuf::internal::WireFormatLite::WriteMessage( + 3, this->output(i), output); + } + + // string predictedFeatureName = 4; + if (this->predictedfeaturename().size() > 0) { + ::google::protobuf::internal::WireFormatLite::VerifyUtf8String( + this->predictedfeaturename().data(), this->predictedfeaturename().length(), + ::google::protobuf::internal::WireFormatLite::SERIALIZE, + "CoreML.Specification.FunctionDescription.predictedFeatureName"); + ::google::protobuf::internal::WireFormatLite::WriteStringMaybeAliased( + 4, this->predictedfeaturename(), output); + } + + // string predictedProbabilitiesName = 5; + if (this->predictedprobabilitiesname().size() > 0) { + ::google::protobuf::internal::WireFormatLite::VerifyUtf8String( + this->predictedprobabilitiesname().data(), this->predictedprobabilitiesname().length(), + ::google::protobuf::internal::WireFormatLite::SERIALIZE, + "CoreML.Specification.FunctionDescription.predictedProbabilitiesName"); + ::google::protobuf::internal::WireFormatLite::WriteStringMaybeAliased( + 5, this->predictedprobabilitiesname(), output); + } + + // repeated .CoreML.Specification.FeatureDescription state = 6; + for (unsigned int i = 0, n = this->state_size(); i < n; i++) { + ::google::protobuf::internal::WireFormatLite::WriteMessage( + 6, this->state(i), output); + } + + // @@protoc_insertion_point(serialize_end:CoreML.Specification.FunctionDescription) +} + +size_t FunctionDescription::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:CoreML.Specification.FunctionDescription) + size_t total_size = 0; + + // repeated .CoreML.Specification.FeatureDescription input = 2; + { + unsigned int count = this->input_size(); + total_size += 1UL * count; + for (unsigned int i = 0; i < count; i++) { + total_size += + ::google::protobuf::internal::WireFormatLite::MessageSizeNoVirtual( + this->input(i)); + } + } + + // repeated .CoreML.Specification.FeatureDescription output = 3; + { + unsigned int count = this->output_size(); + total_size += 1UL * count; + for (unsigned int i = 0; i < count; i++) { + total_size += + ::google::protobuf::internal::WireFormatLite::MessageSizeNoVirtual( + this->output(i)); + } + } + + // repeated .CoreML.Specification.FeatureDescription state = 6; + { + unsigned int count = this->state_size(); + total_size += 1UL * count; + for (unsigned int i = 0; i < count; i++) { + total_size += + ::google::protobuf::internal::WireFormatLite::MessageSizeNoVirtual( + this->state(i)); + } + } + + // string name = 1; + if (this->name().size() > 0) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::StringSize( + this->name()); + } + + // string predictedFeatureName = 4; + if (this->predictedfeaturename().size() > 0) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::StringSize( + this->predictedfeaturename()); + } + + // string predictedProbabilitiesName = 5; + if (this->predictedprobabilitiesname().size() > 0) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::StringSize( + this->predictedprobabilitiesname()); + } + + int cached_size = ::google::protobuf::internal::ToCachedSize(total_size); + GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN(); + _cached_size_ = cached_size; + GOOGLE_SAFE_CONCURRENT_WRITES_END(); + return total_size; +} + +void FunctionDescription::CheckTypeAndMergeFrom( + const ::google::protobuf::MessageLite& from) { + MergeFrom(*::google::protobuf::down_cast(&from)); +} + +void FunctionDescription::MergeFrom(const FunctionDescription& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:CoreML.Specification.FunctionDescription) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom(from._internal_metadata_); + ::google::protobuf::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + input_.MergeFrom(from.input_); + output_.MergeFrom(from.output_); + state_.MergeFrom(from.state_); + if (from.name().size() > 0) { + + name_.AssignWithDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), from.name_); + } + if (from.predictedfeaturename().size() > 0) { + + predictedfeaturename_.AssignWithDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), from.predictedfeaturename_); + } + if (from.predictedprobabilitiesname().size() > 0) { + + predictedprobabilitiesname_.AssignWithDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), from.predictedprobabilitiesname_); + } +} + +void FunctionDescription::CopyFrom(const FunctionDescription& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:CoreML.Specification.FunctionDescription) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool FunctionDescription::IsInitialized() const { + return true; +} + +void FunctionDescription::Swap(FunctionDescription* other) { + if (other == this) return; + InternalSwap(other); +} +void FunctionDescription::InternalSwap(FunctionDescription* other) { + input_.InternalSwap(&other->input_); + output_.InternalSwap(&other->output_); + state_.InternalSwap(&other->state_); + name_.Swap(&other->name_); + predictedfeaturename_.Swap(&other->predictedfeaturename_); + predictedprobabilitiesname_.Swap(&other->predictedprobabilitiesname_); + std::swap(_cached_size_, other->_cached_size_); +} + +::std::string FunctionDescription::GetTypeName() const { + return "CoreML.Specification.FunctionDescription"; +} + +#if PROTOBUF_INLINE_NOT_IN_HEADERS +// FunctionDescription + +// string name = 1; +void FunctionDescription::clear_name() { + name_.ClearToEmptyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +const ::std::string& FunctionDescription::name() const { + // @@protoc_insertion_point(field_get:CoreML.Specification.FunctionDescription.name) + return name_.GetNoArena(); +} +void FunctionDescription::set_name(const ::std::string& value) { + + name_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); + // @@protoc_insertion_point(field_set:CoreML.Specification.FunctionDescription.name) +} +#if LANG_CXX11 +void FunctionDescription::set_name(::std::string&& value) { + + name_.SetNoArena( + &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.FunctionDescription.name) +} +#endif +void FunctionDescription::set_name(const char* value) { + GOOGLE_DCHECK(value != NULL); + + name_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:CoreML.Specification.FunctionDescription.name) +} +void FunctionDescription::set_name(const char* value, size_t size) { + + name_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.FunctionDescription.name) +} +::std::string* FunctionDescription::mutable_name() { + + // @@protoc_insertion_point(field_mutable:CoreML.Specification.FunctionDescription.name) + return name_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +::std::string* FunctionDescription::release_name() { + // @@protoc_insertion_point(field_release:CoreML.Specification.FunctionDescription.name) + + return name_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +void FunctionDescription::set_allocated_name(::std::string* name) { + if (name != NULL) { + + } else { + + } + name_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), name); + // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.FunctionDescription.name) +} + +// repeated .CoreML.Specification.FeatureDescription input = 2; +int FunctionDescription::input_size() const { + return input_.size(); +} +void FunctionDescription::clear_input() { + input_.Clear(); +} +const ::CoreML::Specification::FeatureDescription& FunctionDescription::input(int index) const { + // @@protoc_insertion_point(field_get:CoreML.Specification.FunctionDescription.input) + return input_.Get(index); +} +::CoreML::Specification::FeatureDescription* FunctionDescription::mutable_input(int index) { + // @@protoc_insertion_point(field_mutable:CoreML.Specification.FunctionDescription.input) + return input_.Mutable(index); +} +::CoreML::Specification::FeatureDescription* FunctionDescription::add_input() { + // @@protoc_insertion_point(field_add:CoreML.Specification.FunctionDescription.input) + return input_.Add(); +} +::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription >* +FunctionDescription::mutable_input() { + // @@protoc_insertion_point(field_mutable_list:CoreML.Specification.FunctionDescription.input) + return &input_; +} +const ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription >& +FunctionDescription::input() const { + // @@protoc_insertion_point(field_list:CoreML.Specification.FunctionDescription.input) + return input_; +} + +// repeated .CoreML.Specification.FeatureDescription output = 3; +int FunctionDescription::output_size() const { + return output_.size(); +} +void FunctionDescription::clear_output() { + output_.Clear(); +} +const ::CoreML::Specification::FeatureDescription& FunctionDescription::output(int index) const { + // @@protoc_insertion_point(field_get:CoreML.Specification.FunctionDescription.output) + return output_.Get(index); +} +::CoreML::Specification::FeatureDescription* FunctionDescription::mutable_output(int index) { + // @@protoc_insertion_point(field_mutable:CoreML.Specification.FunctionDescription.output) + return output_.Mutable(index); +} +::CoreML::Specification::FeatureDescription* FunctionDescription::add_output() { + // @@protoc_insertion_point(field_add:CoreML.Specification.FunctionDescription.output) + return output_.Add(); +} +::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription >* +FunctionDescription::mutable_output() { + // @@protoc_insertion_point(field_mutable_list:CoreML.Specification.FunctionDescription.output) + return &output_; +} +const ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription >& +FunctionDescription::output() const { + // @@protoc_insertion_point(field_list:CoreML.Specification.FunctionDescription.output) + return output_; +} + +// repeated .CoreML.Specification.FeatureDescription state = 6; +int FunctionDescription::state_size() const { + return state_.size(); +} +void FunctionDescription::clear_state() { + state_.Clear(); +} +const ::CoreML::Specification::FeatureDescription& FunctionDescription::state(int index) const { + // @@protoc_insertion_point(field_get:CoreML.Specification.FunctionDescription.state) + return state_.Get(index); +} +::CoreML::Specification::FeatureDescription* FunctionDescription::mutable_state(int index) { + // @@protoc_insertion_point(field_mutable:CoreML.Specification.FunctionDescription.state) + return state_.Mutable(index); +} +::CoreML::Specification::FeatureDescription* FunctionDescription::add_state() { + // @@protoc_insertion_point(field_add:CoreML.Specification.FunctionDescription.state) + return state_.Add(); +} +::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription >* +FunctionDescription::mutable_state() { + // @@protoc_insertion_point(field_mutable_list:CoreML.Specification.FunctionDescription.state) + return &state_; +} +const ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription >& +FunctionDescription::state() const { + // @@protoc_insertion_point(field_list:CoreML.Specification.FunctionDescription.state) + return state_; +} + +// string predictedFeatureName = 4; +void FunctionDescription::clear_predictedfeaturename() { + predictedfeaturename_.ClearToEmptyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +const ::std::string& FunctionDescription::predictedfeaturename() const { + // @@protoc_insertion_point(field_get:CoreML.Specification.FunctionDescription.predictedFeatureName) + return predictedfeaturename_.GetNoArena(); +} +void FunctionDescription::set_predictedfeaturename(const ::std::string& value) { + + predictedfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); + // @@protoc_insertion_point(field_set:CoreML.Specification.FunctionDescription.predictedFeatureName) +} +#if LANG_CXX11 +void FunctionDescription::set_predictedfeaturename(::std::string&& value) { + + predictedfeaturename_.SetNoArena( + &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.FunctionDescription.predictedFeatureName) +} +#endif +void FunctionDescription::set_predictedfeaturename(const char* value) { + GOOGLE_DCHECK(value != NULL); + + predictedfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:CoreML.Specification.FunctionDescription.predictedFeatureName) +} +void FunctionDescription::set_predictedfeaturename(const char* value, size_t size) { + + predictedfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.FunctionDescription.predictedFeatureName) +} +::std::string* FunctionDescription::mutable_predictedfeaturename() { + + // @@protoc_insertion_point(field_mutable:CoreML.Specification.FunctionDescription.predictedFeatureName) + return predictedfeaturename_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +::std::string* FunctionDescription::release_predictedfeaturename() { + // @@protoc_insertion_point(field_release:CoreML.Specification.FunctionDescription.predictedFeatureName) + + return predictedfeaturename_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +void FunctionDescription::set_allocated_predictedfeaturename(::std::string* predictedfeaturename) { + if (predictedfeaturename != NULL) { + + } else { + + } + predictedfeaturename_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), predictedfeaturename); + // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.FunctionDescription.predictedFeatureName) +} + +// string predictedProbabilitiesName = 5; +void FunctionDescription::clear_predictedprobabilitiesname() { + predictedprobabilitiesname_.ClearToEmptyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +const ::std::string& FunctionDescription::predictedprobabilitiesname() const { + // @@protoc_insertion_point(field_get:CoreML.Specification.FunctionDescription.predictedProbabilitiesName) + return predictedprobabilitiesname_.GetNoArena(); +} +void FunctionDescription::set_predictedprobabilitiesname(const ::std::string& value) { + + predictedprobabilitiesname_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); + // @@protoc_insertion_point(field_set:CoreML.Specification.FunctionDescription.predictedProbabilitiesName) +} +#if LANG_CXX11 +void FunctionDescription::set_predictedprobabilitiesname(::std::string&& value) { + + predictedprobabilitiesname_.SetNoArena( + &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.FunctionDescription.predictedProbabilitiesName) +} +#endif +void FunctionDescription::set_predictedprobabilitiesname(const char* value) { + GOOGLE_DCHECK(value != NULL); + + predictedprobabilitiesname_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:CoreML.Specification.FunctionDescription.predictedProbabilitiesName) } +void FunctionDescription::set_predictedprobabilitiesname(const char* value, size_t size) { -// map userDefined = 100; -int Metadata::userdefined_size() const { - return userdefined_.size(); + predictedprobabilitiesname_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.FunctionDescription.predictedProbabilitiesName) } -void Metadata::clear_userdefined() { - userdefined_.Clear(); +::std::string* FunctionDescription::mutable_predictedprobabilitiesname() { + + // @@protoc_insertion_point(field_mutable:CoreML.Specification.FunctionDescription.predictedProbabilitiesName) + return predictedprobabilitiesname_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } - const ::google::protobuf::Map< ::std::string, ::std::string >& -Metadata::userdefined() const { - // @@protoc_insertion_point(field_map:CoreML.Specification.Metadata.userDefined) - return userdefined_.GetMap(); +::std::string* FunctionDescription::release_predictedprobabilitiesname() { + // @@protoc_insertion_point(field_release:CoreML.Specification.FunctionDescription.predictedProbabilitiesName) + + return predictedprobabilitiesname_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } - ::google::protobuf::Map< ::std::string, ::std::string >* -Metadata::mutable_userdefined() { - // @@protoc_insertion_point(field_mutable_map:CoreML.Specification.Metadata.userDefined) - return userdefined_.MutableMap(); +void FunctionDescription::set_allocated_predictedprobabilitiesname(::std::string* predictedprobabilitiesname) { + if (predictedprobabilitiesname != NULL) { + + } else { + + } + predictedprobabilitiesname_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), predictedprobabilitiesname); + // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.FunctionDescription.predictedProbabilitiesName) } #endif // PROTOBUF_INLINE_NOT_IN_HEADERS @@ -2101,12 +2743,15 @@ Metadata::mutable_userdefined() { // =================================================================== #if !defined(_MSC_VER) || _MSC_VER >= 1900 +const int ModelDescription::kFunctionsFieldNumber; +const int ModelDescription::kDefaultFunctionNameFieldNumber; +const int ModelDescription::kMetadataFieldNumber; const int ModelDescription::kInputFieldNumber; const int ModelDescription::kOutputFieldNumber; +const int ModelDescription::kStateFieldNumber; const int ModelDescription::kPredictedFeatureNameFieldNumber; const int ModelDescription::kPredictedProbabilitiesNameFieldNumber; const int ModelDescription::kTrainingInputFieldNumber; -const int ModelDescription::kMetadataFieldNumber; #endif // !defined(_MSC_VER) || _MSC_VER >= 1900 ModelDescription::ModelDescription() @@ -2122,6 +2767,8 @@ ModelDescription::ModelDescription(const ModelDescription& from) _internal_metadata_(NULL), input_(from.input_), output_(from.output_), + state_(from.state_), + functions_(from.functions_), traininginput_(from.traininginput_), _cached_size_(0) { _internal_metadata_.MergeFrom(from._internal_metadata_); @@ -2133,6 +2780,10 @@ ModelDescription::ModelDescription(const ModelDescription& from) if (from.predictedprobabilitiesname().size() > 0) { predictedprobabilitiesname_.AssignWithDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), from.predictedprobabilitiesname_); } + defaultfunctionname_.UnsafeSetDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); + if (from.defaultfunctionname().size() > 0) { + defaultfunctionname_.AssignWithDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), from.defaultfunctionname_); + } if (from.has_metadata()) { metadata_ = new ::CoreML::Specification::Metadata(*from.metadata_); } else { @@ -2144,6 +2795,7 @@ ModelDescription::ModelDescription(const ModelDescription& from) void ModelDescription::SharedCtor() { predictedfeaturename_.UnsafeSetDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); predictedprobabilitiesname_.UnsafeSetDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); + defaultfunctionname_.UnsafeSetDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); metadata_ = NULL; _cached_size_ = 0; } @@ -2156,6 +2808,7 @@ ModelDescription::~ModelDescription() { void ModelDescription::SharedDtor() { predictedfeaturename_.DestroyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); predictedprobabilitiesname_.DestroyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); + defaultfunctionname_.DestroyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); if (this != internal_default_instance()) { delete metadata_; } @@ -2183,9 +2836,12 @@ void ModelDescription::Clear() { // @@protoc_insertion_point(message_clear_start:CoreML.Specification.ModelDescription) input_.Clear(); output_.Clear(); + state_.Clear(); + functions_.Clear(); traininginput_.Clear(); predictedfeaturename_.ClearToEmptyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); predictedprobabilitiesname_.ClearToEmptyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); + defaultfunctionname_.ClearToEmptyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); if (GetArenaNoVirtual() == NULL && metadata_ != NULL) { delete metadata_; } @@ -2258,6 +2914,46 @@ bool ModelDescription::MergePartialFromCodedStream( break; } + // repeated .CoreML.Specification.FeatureDescription state = 13; + case 13: { + if (static_cast< ::google::protobuf::uint8>(tag) == + static_cast< ::google::protobuf::uint8>(106u)) { + DO_(::google::protobuf::internal::WireFormatLite::ReadMessageNoVirtual( + input, add_state())); + } else { + goto handle_unusual; + } + break; + } + + // repeated .CoreML.Specification.FunctionDescription functions = 20; + case 20: { + if (static_cast< ::google::protobuf::uint8>(tag) == + static_cast< ::google::protobuf::uint8>(162u)) { + DO_(::google::protobuf::internal::WireFormatLite::ReadMessageNoVirtual( + input, add_functions())); + } else { + goto handle_unusual; + } + break; + } + + // string defaultFunctionName = 21; + case 21: { + if (static_cast< ::google::protobuf::uint8>(tag) == + static_cast< ::google::protobuf::uint8>(170u)) { + DO_(::google::protobuf::internal::WireFormatLite::ReadString( + input, this->mutable_defaultfunctionname())); + DO_(::google::protobuf::internal::WireFormatLite::VerifyUtf8String( + this->defaultfunctionname().data(), this->defaultfunctionname().length(), + ::google::protobuf::internal::WireFormatLite::PARSE, + "CoreML.Specification.ModelDescription.defaultFunctionName")); + } else { + goto handle_unusual; + } + break; + } + // repeated .CoreML.Specification.FeatureDescription trainingInput = 50; case 50: { if (static_cast< ::google::protobuf::uint8>(tag) == @@ -2341,6 +3037,28 @@ void ModelDescription::SerializeWithCachedSizes( 12, this->predictedprobabilitiesname(), output); } + // repeated .CoreML.Specification.FeatureDescription state = 13; + for (unsigned int i = 0, n = this->state_size(); i < n; i++) { + ::google::protobuf::internal::WireFormatLite::WriteMessage( + 13, this->state(i), output); + } + + // repeated .CoreML.Specification.FunctionDescription functions = 20; + for (unsigned int i = 0, n = this->functions_size(); i < n; i++) { + ::google::protobuf::internal::WireFormatLite::WriteMessage( + 20, this->functions(i), output); + } + + // string defaultFunctionName = 21; + if (this->defaultfunctionname().size() > 0) { + ::google::protobuf::internal::WireFormatLite::VerifyUtf8String( + this->defaultfunctionname().data(), this->defaultfunctionname().length(), + ::google::protobuf::internal::WireFormatLite::SERIALIZE, + "CoreML.Specification.ModelDescription.defaultFunctionName"); + ::google::protobuf::internal::WireFormatLite::WriteStringMaybeAliased( + 21, this->defaultfunctionname(), output); + } + // repeated .CoreML.Specification.FeatureDescription trainingInput = 50; for (unsigned int i = 0, n = this->traininginput_size(); i < n; i++) { ::google::protobuf::internal::WireFormatLite::WriteMessage( @@ -2382,6 +3100,28 @@ size_t ModelDescription::ByteSizeLong() const { } } + // repeated .CoreML.Specification.FeatureDescription state = 13; + { + unsigned int count = this->state_size(); + total_size += 1UL * count; + for (unsigned int i = 0; i < count; i++) { + total_size += + ::google::protobuf::internal::WireFormatLite::MessageSizeNoVirtual( + this->state(i)); + } + } + + // repeated .CoreML.Specification.FunctionDescription functions = 20; + { + unsigned int count = this->functions_size(); + total_size += 2UL * count; + for (unsigned int i = 0; i < count; i++) { + total_size += + ::google::protobuf::internal::WireFormatLite::MessageSizeNoVirtual( + this->functions(i)); + } + } + // repeated .CoreML.Specification.FeatureDescription trainingInput = 50; { unsigned int count = this->traininginput_size(); @@ -2407,6 +3147,13 @@ size_t ModelDescription::ByteSizeLong() const { this->predictedprobabilitiesname()); } + // string defaultFunctionName = 21; + if (this->defaultfunctionname().size() > 0) { + total_size += 2 + + ::google::protobuf::internal::WireFormatLite::StringSize( + this->defaultfunctionname()); + } + // .CoreML.Specification.Metadata metadata = 100; if (this->has_metadata()) { total_size += 2 + @@ -2435,6 +3182,8 @@ void ModelDescription::MergeFrom(const ModelDescription& from) { input_.MergeFrom(from.input_); output_.MergeFrom(from.output_); + state_.MergeFrom(from.state_); + functions_.MergeFrom(from.functions_); traininginput_.MergeFrom(from.traininginput_); if (from.predictedfeaturename().size() > 0) { @@ -2444,6 +3193,10 @@ void ModelDescription::MergeFrom(const ModelDescription& from) { predictedprobabilitiesname_.AssignWithDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), from.predictedprobabilitiesname_); } + if (from.defaultfunctionname().size() > 0) { + + defaultfunctionname_.AssignWithDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), from.defaultfunctionname_); + } if (from.has_metadata()) { mutable_metadata()->::CoreML::Specification::Metadata::MergeFrom(from.metadata()); } @@ -2467,9 +3220,12 @@ void ModelDescription::Swap(ModelDescription* other) { void ModelDescription::InternalSwap(ModelDescription* other) { input_.InternalSwap(&other->input_); output_.InternalSwap(&other->output_); + state_.InternalSwap(&other->state_); + functions_.InternalSwap(&other->functions_); traininginput_.InternalSwap(&other->traininginput_); predictedfeaturename_.Swap(&other->predictedfeaturename_); predictedprobabilitiesname_.Swap(&other->predictedprobabilitiesname_); + defaultfunctionname_.Swap(&other->defaultfunctionname_); std::swap(metadata_, other->metadata_); std::swap(_cached_size_, other->_cached_size_); } @@ -2481,6 +3237,128 @@ ::std::string ModelDescription::GetTypeName() const { #if PROTOBUF_INLINE_NOT_IN_HEADERS // ModelDescription +// repeated .CoreML.Specification.FunctionDescription functions = 20; +int ModelDescription::functions_size() const { + return functions_.size(); +} +void ModelDescription::clear_functions() { + functions_.Clear(); +} +const ::CoreML::Specification::FunctionDescription& ModelDescription::functions(int index) const { + // @@protoc_insertion_point(field_get:CoreML.Specification.ModelDescription.functions) + return functions_.Get(index); +} +::CoreML::Specification::FunctionDescription* ModelDescription::mutable_functions(int index) { + // @@protoc_insertion_point(field_mutable:CoreML.Specification.ModelDescription.functions) + return functions_.Mutable(index); +} +::CoreML::Specification::FunctionDescription* ModelDescription::add_functions() { + // @@protoc_insertion_point(field_add:CoreML.Specification.ModelDescription.functions) + return functions_.Add(); +} +::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FunctionDescription >* +ModelDescription::mutable_functions() { + // @@protoc_insertion_point(field_mutable_list:CoreML.Specification.ModelDescription.functions) + return &functions_; +} +const ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FunctionDescription >& +ModelDescription::functions() const { + // @@protoc_insertion_point(field_list:CoreML.Specification.ModelDescription.functions) + return functions_; +} + +// string defaultFunctionName = 21; +void ModelDescription::clear_defaultfunctionname() { + defaultfunctionname_.ClearToEmptyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +const ::std::string& ModelDescription::defaultfunctionname() const { + // @@protoc_insertion_point(field_get:CoreML.Specification.ModelDescription.defaultFunctionName) + return defaultfunctionname_.GetNoArena(); +} +void ModelDescription::set_defaultfunctionname(const ::std::string& value) { + + defaultfunctionname_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); + // @@protoc_insertion_point(field_set:CoreML.Specification.ModelDescription.defaultFunctionName) +} +#if LANG_CXX11 +void ModelDescription::set_defaultfunctionname(::std::string&& value) { + + defaultfunctionname_.SetNoArena( + &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.ModelDescription.defaultFunctionName) +} +#endif +void ModelDescription::set_defaultfunctionname(const char* value) { + GOOGLE_DCHECK(value != NULL); + + defaultfunctionname_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:CoreML.Specification.ModelDescription.defaultFunctionName) +} +void ModelDescription::set_defaultfunctionname(const char* value, size_t size) { + + defaultfunctionname_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.ModelDescription.defaultFunctionName) +} +::std::string* ModelDescription::mutable_defaultfunctionname() { + + // @@protoc_insertion_point(field_mutable:CoreML.Specification.ModelDescription.defaultFunctionName) + return defaultfunctionname_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +::std::string* ModelDescription::release_defaultfunctionname() { + // @@protoc_insertion_point(field_release:CoreML.Specification.ModelDescription.defaultFunctionName) + + return defaultfunctionname_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +void ModelDescription::set_allocated_defaultfunctionname(::std::string* defaultfunctionname) { + if (defaultfunctionname != NULL) { + + } else { + + } + defaultfunctionname_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), defaultfunctionname); + // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ModelDescription.defaultFunctionName) +} + +// .CoreML.Specification.Metadata metadata = 100; +bool ModelDescription::has_metadata() const { + return this != internal_default_instance() && metadata_ != NULL; +} +void ModelDescription::clear_metadata() { + if (GetArenaNoVirtual() == NULL && metadata_ != NULL) delete metadata_; + metadata_ = NULL; +} +const ::CoreML::Specification::Metadata& ModelDescription::metadata() const { + // @@protoc_insertion_point(field_get:CoreML.Specification.ModelDescription.metadata) + return metadata_ != NULL ? *metadata_ + : *::CoreML::Specification::Metadata::internal_default_instance(); +} +::CoreML::Specification::Metadata* ModelDescription::mutable_metadata() { + + if (metadata_ == NULL) { + metadata_ = new ::CoreML::Specification::Metadata; + } + // @@protoc_insertion_point(field_mutable:CoreML.Specification.ModelDescription.metadata) + return metadata_; +} +::CoreML::Specification::Metadata* ModelDescription::release_metadata() { + // @@protoc_insertion_point(field_release:CoreML.Specification.ModelDescription.metadata) + + ::CoreML::Specification::Metadata* temp = metadata_; + metadata_ = NULL; + return temp; +} +void ModelDescription::set_allocated_metadata(::CoreML::Specification::Metadata* metadata) { + delete metadata_; + metadata_ = metadata; + if (metadata) { + + } else { + + } + // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ModelDescription.metadata) +} + // repeated .CoreML.Specification.FeatureDescription input = 1; int ModelDescription::input_size() const { return input_.size(); @@ -2541,6 +3419,36 @@ ModelDescription::output() const { return output_; } +// repeated .CoreML.Specification.FeatureDescription state = 13; +int ModelDescription::state_size() const { + return state_.size(); +} +void ModelDescription::clear_state() { + state_.Clear(); +} +const ::CoreML::Specification::FeatureDescription& ModelDescription::state(int index) const { + // @@protoc_insertion_point(field_get:CoreML.Specification.ModelDescription.state) + return state_.Get(index); +} +::CoreML::Specification::FeatureDescription* ModelDescription::mutable_state(int index) { + // @@protoc_insertion_point(field_mutable:CoreML.Specification.ModelDescription.state) + return state_.Mutable(index); +} +::CoreML::Specification::FeatureDescription* ModelDescription::add_state() { + // @@protoc_insertion_point(field_add:CoreML.Specification.ModelDescription.state) + return state_.Add(); +} +::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription >* +ModelDescription::mutable_state() { + // @@protoc_insertion_point(field_mutable_list:CoreML.Specification.ModelDescription.state) + return &state_; +} +const ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription >& +ModelDescription::state() const { + // @@protoc_insertion_point(field_list:CoreML.Specification.ModelDescription.state) + return state_; +} + // string predictedFeatureName = 11; void ModelDescription::clear_predictedfeaturename() { predictedfeaturename_.ClearToEmptyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); @@ -2550,13 +3458,13 @@ const ::std::string& ModelDescription::predictedfeaturename() const { return predictedfeaturename_.GetNoArena(); } void ModelDescription::set_predictedfeaturename(const ::std::string& value) { - + predictedfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.ModelDescription.predictedFeatureName) } #if LANG_CXX11 void ModelDescription::set_predictedfeaturename(::std::string&& value) { - + predictedfeaturename_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.ModelDescription.predictedFeatureName) @@ -2564,31 +3472,31 @@ void ModelDescription::set_predictedfeaturename(::std::string&& value) { #endif void ModelDescription::set_predictedfeaturename(const char* value) { GOOGLE_DCHECK(value != NULL); - + predictedfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.ModelDescription.predictedFeatureName) } void ModelDescription::set_predictedfeaturename(const char* value, size_t size) { - + predictedfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.ModelDescription.predictedFeatureName) } ::std::string* ModelDescription::mutable_predictedfeaturename() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.ModelDescription.predictedFeatureName) return predictedfeaturename_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } ::std::string* ModelDescription::release_predictedfeaturename() { // @@protoc_insertion_point(field_release:CoreML.Specification.ModelDescription.predictedFeatureName) - + return predictedfeaturename_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } void ModelDescription::set_allocated_predictedfeaturename(::std::string* predictedfeaturename) { if (predictedfeaturename != NULL) { - + } else { - + } predictedfeaturename_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), predictedfeaturename); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ModelDescription.predictedFeatureName) @@ -2603,13 +3511,13 @@ const ::std::string& ModelDescription::predictedprobabilitiesname() const { return predictedprobabilitiesname_.GetNoArena(); } void ModelDescription::set_predictedprobabilitiesname(const ::std::string& value) { - + predictedprobabilitiesname_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.ModelDescription.predictedProbabilitiesName) } #if LANG_CXX11 void ModelDescription::set_predictedprobabilitiesname(::std::string&& value) { - + predictedprobabilitiesname_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.ModelDescription.predictedProbabilitiesName) @@ -2617,31 +3525,31 @@ void ModelDescription::set_predictedprobabilitiesname(::std::string&& value) { #endif void ModelDescription::set_predictedprobabilitiesname(const char* value) { GOOGLE_DCHECK(value != NULL); - + predictedprobabilitiesname_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.ModelDescription.predictedProbabilitiesName) } void ModelDescription::set_predictedprobabilitiesname(const char* value, size_t size) { - + predictedprobabilitiesname_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.ModelDescription.predictedProbabilitiesName) } ::std::string* ModelDescription::mutable_predictedprobabilitiesname() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.ModelDescription.predictedProbabilitiesName) return predictedprobabilitiesname_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } ::std::string* ModelDescription::release_predictedprobabilitiesname() { // @@protoc_insertion_point(field_release:CoreML.Specification.ModelDescription.predictedProbabilitiesName) - + return predictedprobabilitiesname_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } void ModelDescription::set_allocated_predictedprobabilitiesname(::std::string* predictedprobabilitiesname) { if (predictedprobabilitiesname != NULL) { - + } else { - + } predictedprobabilitiesname_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), predictedprobabilitiesname); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ModelDescription.predictedProbabilitiesName) @@ -2677,45 +3585,6 @@ ModelDescription::traininginput() const { return traininginput_; } -// .CoreML.Specification.Metadata metadata = 100; -bool ModelDescription::has_metadata() const { - return this != internal_default_instance() && metadata_ != NULL; -} -void ModelDescription::clear_metadata() { - if (GetArenaNoVirtual() == NULL && metadata_ != NULL) delete metadata_; - metadata_ = NULL; -} -const ::CoreML::Specification::Metadata& ModelDescription::metadata() const { - // @@protoc_insertion_point(field_get:CoreML.Specification.ModelDescription.metadata) - return metadata_ != NULL ? *metadata_ - : *::CoreML::Specification::Metadata::internal_default_instance(); -} -::CoreML::Specification::Metadata* ModelDescription::mutable_metadata() { - - if (metadata_ == NULL) { - metadata_ = new ::CoreML::Specification::Metadata; - } - // @@protoc_insertion_point(field_mutable:CoreML.Specification.ModelDescription.metadata) - return metadata_; -} -::CoreML::Specification::Metadata* ModelDescription::release_metadata() { - // @@protoc_insertion_point(field_release:CoreML.Specification.ModelDescription.metadata) - - ::CoreML::Specification::Metadata* temp = metadata_; - metadata_ = NULL; - return temp; -} -void ModelDescription::set_allocated_metadata(::CoreML::Specification::Metadata* metadata) { - delete metadata_; - metadata_ = metadata; - if (metadata) { - - } else { - - } - // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ModelDescription.metadata) -} - #endif // PROTOBUF_INLINE_NOT_IN_HEADERS // =================================================================== @@ -2957,13 +3826,13 @@ const ::std::string& SerializedModel::identifier() const { return identifier_.GetNoArena(); } void SerializedModel::set_identifier(const ::std::string& value) { - + identifier_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.SerializedModel.identifier) } #if LANG_CXX11 void SerializedModel::set_identifier(::std::string&& value) { - + identifier_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.SerializedModel.identifier) @@ -2971,31 +3840,31 @@ void SerializedModel::set_identifier(::std::string&& value) { #endif void SerializedModel::set_identifier(const char* value) { GOOGLE_DCHECK(value != NULL); - + identifier_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.SerializedModel.identifier) } void SerializedModel::set_identifier(const char* value, size_t size) { - + identifier_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.SerializedModel.identifier) } ::std::string* SerializedModel::mutable_identifier() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.SerializedModel.identifier) return identifier_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } ::std::string* SerializedModel::release_identifier() { // @@protoc_insertion_point(field_release:CoreML.Specification.SerializedModel.identifier) - + return identifier_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } void SerializedModel::set_allocated_identifier(::std::string* identifier) { if (identifier != NULL) { - + } else { - + } identifier_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), identifier); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.SerializedModel.identifier) @@ -3010,13 +3879,13 @@ const ::std::string& SerializedModel::model() const { return model_.GetNoArena(); } void SerializedModel::set_model(const ::std::string& value) { - + model_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.SerializedModel.model) } #if LANG_CXX11 void SerializedModel::set_model(::std::string&& value) { - + model_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.SerializedModel.model) @@ -3024,31 +3893,31 @@ void SerializedModel::set_model(::std::string&& value) { #endif void SerializedModel::set_model(const char* value) { GOOGLE_DCHECK(value != NULL); - + model_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.SerializedModel.model) } void SerializedModel::set_model(const void* value, size_t size) { - + model_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.SerializedModel.model) } ::std::string* SerializedModel::mutable_model() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.SerializedModel.model) return model_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } ::std::string* SerializedModel::release_model() { // @@protoc_insertion_point(field_release:CoreML.Specification.SerializedModel.model) - + return model_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } void SerializedModel::set_allocated_model(::std::string* model) { if (model != NULL) { - + } else { - + } model_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), model); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.SerializedModel.model) @@ -4760,7 +5629,7 @@ ::google::protobuf::int32 Model::specificationversion() const { return specificationversion_; } void Model::set_specificationversion(::google::protobuf::int32 value) { - + specificationversion_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Model.specificationVersion) } @@ -4779,7 +5648,7 @@ const ::CoreML::Specification::ModelDescription& Model::description() const { : *::CoreML::Specification::ModelDescription::internal_default_instance(); } ::CoreML::Specification::ModelDescription* Model::mutable_description() { - + if (description_ == NULL) { description_ = new ::CoreML::Specification::ModelDescription; } @@ -4788,7 +5657,7 @@ ::CoreML::Specification::ModelDescription* Model::mutable_description() { } ::CoreML::Specification::ModelDescription* Model::release_description() { // @@protoc_insertion_point(field_release:CoreML.Specification.Model.description) - + ::CoreML::Specification::ModelDescription* temp = description_; description_ = NULL; return temp; @@ -4797,9 +5666,9 @@ void Model::set_allocated_description(::CoreML::Specification::ModelDescription* delete description_; description_ = description; if (description) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.Model.description) } @@ -4813,7 +5682,7 @@ bool Model::isupdatable() const { return isupdatable_; } void Model::set_isupdatable(bool value) { - + isupdatable_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Model.isUpdatable) } diff --git a/mlmodel/build/format/Model.pb.h b/mlmodel/build/format/Model.pb.h index 58895f832..7dcf13e84 100644 --- a/mlmodel/build/format/Model.pb.h +++ b/mlmodel/build/format/Model.pb.h @@ -367,6 +367,9 @@ extern FloorDivBroadcastableLayerParamsDefaultTypeInternal _FloorDivBroadcastabl class FloorLayerParams; class FloorLayerParamsDefaultTypeInternal; extern FloorLayerParamsDefaultTypeInternal _FloorLayerParams_default_instance_; +class FunctionDescription; +class FunctionDescriptionDefaultTypeInternal; +extern FunctionDescriptionDefaultTypeInternal _FunctionDescription_default_instance_; class GLMClassifier; class GLMClassifierDefaultTypeInternal; extern GLMClassifierDefaultTypeInternal _GLMClassifier_default_instance_; @@ -880,6 +883,9 @@ extern SqueezeLayerParamsDefaultTypeInternal _SqueezeLayerParams_default_instanc class StackLayerParams; class StackLayerParamsDefaultTypeInternal; extern StackLayerParamsDefaultTypeInternal _StackLayerParams_default_instance_; +class StateFeatureType; +class StateFeatureTypeDefaultTypeInternal; +extern StateFeatureTypeDefaultTypeInternal _StateFeatureType_default_instance_; class StringFeatureType; class StringFeatureTypeDefaultTypeInternal; extern StringFeatureTypeDefaultTypeInternal _StringFeatureType_default_instance_; @@ -1072,6 +1078,9 @@ extern Program_AttributesEntryDefaultTypeInternal _Program_AttributesEntry_defau class Program_FunctionsEntry; class Program_FunctionsEntryDefaultTypeInternal; extern Program_FunctionsEntryDefaultTypeInternal _Program_FunctionsEntry_default_instance_; +class StateType; +class StateTypeDefaultTypeInternal; +extern StateTypeDefaultTypeInternal _StateType_default_instance_; class TensorType; class TensorTypeDefaultTypeInternal; extern TensorTypeDefaultTypeInternal _TensorType_default_instance_; @@ -1681,6 +1690,161 @@ class Metadata : public ::google::protobuf::MessageLite /* @@protoc_insertion_po }; // ------------------------------------------------------------------- +class FunctionDescription : public ::google::protobuf::MessageLite /* @@protoc_insertion_point(class_definition:CoreML.Specification.FunctionDescription) */ { + public: + FunctionDescription(); + virtual ~FunctionDescription(); + + FunctionDescription(const FunctionDescription& from); + + inline FunctionDescription& operator=(const FunctionDescription& from) { + CopyFrom(from); + return *this; + } + + static const FunctionDescription& default_instance(); + + static inline const FunctionDescription* internal_default_instance() { + return reinterpret_cast( + &_FunctionDescription_default_instance_); + } + static PROTOBUF_CONSTEXPR int const kIndexInFileMessages = + 6; + + void Swap(FunctionDescription* other); + + // implements Message ---------------------------------------------- + + inline FunctionDescription* New() const PROTOBUF_FINAL { return New(NULL); } + + FunctionDescription* New(::google::protobuf::Arena* arena) const PROTOBUF_FINAL; + void CheckTypeAndMergeFrom(const ::google::protobuf::MessageLite& from) + PROTOBUF_FINAL; + void CopyFrom(const FunctionDescription& from); + void MergeFrom(const FunctionDescription& from); + void Clear() PROTOBUF_FINAL; + bool IsInitialized() const PROTOBUF_FINAL; + + size_t ByteSizeLong() const PROTOBUF_FINAL; + bool MergePartialFromCodedStream( + ::google::protobuf::io::CodedInputStream* input) PROTOBUF_FINAL; + void SerializeWithCachedSizes( + ::google::protobuf::io::CodedOutputStream* output) const PROTOBUF_FINAL; + void DiscardUnknownFields(); + int GetCachedSize() const PROTOBUF_FINAL { return _cached_size_; } + private: + void SharedCtor(); + void SharedDtor(); + void SetCachedSize(int size) const; + void InternalSwap(FunctionDescription* other); + private: + inline ::google::protobuf::Arena* GetArenaNoVirtual() const { + return NULL; + } + inline void* MaybeArenaPtr() const { + return NULL; + } + public: + + ::std::string GetTypeName() const PROTOBUF_FINAL; + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + // repeated .CoreML.Specification.FeatureDescription input = 2; + int input_size() const; + void clear_input(); + static const int kInputFieldNumber = 2; + const ::CoreML::Specification::FeatureDescription& input(int index) const; + ::CoreML::Specification::FeatureDescription* mutable_input(int index); + ::CoreML::Specification::FeatureDescription* add_input(); + ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription >* + mutable_input(); + const ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription >& + input() const; + + // repeated .CoreML.Specification.FeatureDescription output = 3; + int output_size() const; + void clear_output(); + static const int kOutputFieldNumber = 3; + const ::CoreML::Specification::FeatureDescription& output(int index) const; + ::CoreML::Specification::FeatureDescription* mutable_output(int index); + ::CoreML::Specification::FeatureDescription* add_output(); + ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription >* + mutable_output(); + const ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription >& + output() const; + + // repeated .CoreML.Specification.FeatureDescription state = 6; + int state_size() const; + void clear_state(); + static const int kStateFieldNumber = 6; + const ::CoreML::Specification::FeatureDescription& state(int index) const; + ::CoreML::Specification::FeatureDescription* mutable_state(int index); + ::CoreML::Specification::FeatureDescription* add_state(); + ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription >* + mutable_state(); + const ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription >& + state() const; + + // string name = 1; + void clear_name(); + static const int kNameFieldNumber = 1; + const ::std::string& name() const; + void set_name(const ::std::string& value); + #if LANG_CXX11 + void set_name(::std::string&& value); + #endif + void set_name(const char* value); + void set_name(const char* value, size_t size); + ::std::string* mutable_name(); + ::std::string* release_name(); + void set_allocated_name(::std::string* name); + + // string predictedFeatureName = 4; + void clear_predictedfeaturename(); + static const int kPredictedFeatureNameFieldNumber = 4; + const ::std::string& predictedfeaturename() const; + void set_predictedfeaturename(const ::std::string& value); + #if LANG_CXX11 + void set_predictedfeaturename(::std::string&& value); + #endif + void set_predictedfeaturename(const char* value); + void set_predictedfeaturename(const char* value, size_t size); + ::std::string* mutable_predictedfeaturename(); + ::std::string* release_predictedfeaturename(); + void set_allocated_predictedfeaturename(::std::string* predictedfeaturename); + + // string predictedProbabilitiesName = 5; + void clear_predictedprobabilitiesname(); + static const int kPredictedProbabilitiesNameFieldNumber = 5; + const ::std::string& predictedprobabilitiesname() const; + void set_predictedprobabilitiesname(const ::std::string& value); + #if LANG_CXX11 + void set_predictedprobabilitiesname(::std::string&& value); + #endif + void set_predictedprobabilitiesname(const char* value); + void set_predictedprobabilitiesname(const char* value, size_t size); + ::std::string* mutable_predictedprobabilitiesname(); + ::std::string* release_predictedprobabilitiesname(); + void set_allocated_predictedprobabilitiesname(::std::string* predictedprobabilitiesname); + + // @@protoc_insertion_point(class_scope:CoreML.Specification.FunctionDescription) + private: + + ::google::protobuf::internal::InternalMetadataWithArenaLite _internal_metadata_; + ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription > input_; + ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription > output_; + ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription > state_; + ::google::protobuf::internal::ArenaStringPtr name_; + ::google::protobuf::internal::ArenaStringPtr predictedfeaturename_; + ::google::protobuf::internal::ArenaStringPtr predictedprobabilitiesname_; + mutable int _cached_size_; + friend struct protobuf_Model_2eproto::TableStruct; +}; +// ------------------------------------------------------------------- + class ModelDescription : public ::google::protobuf::MessageLite /* @@protoc_insertion_point(class_definition:CoreML.Specification.ModelDescription) */ { public: ModelDescription(); @@ -1700,7 +1864,7 @@ class ModelDescription : public ::google::protobuf::MessageLite /* @@protoc_inse &_ModelDescription_default_instance_); } static PROTOBUF_CONSTEXPR int const kIndexInFileMessages = - 6; + 7; void Swap(ModelDescription* other); @@ -1767,6 +1931,30 @@ class ModelDescription : public ::google::protobuf::MessageLite /* @@protoc_inse const ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription >& output() const; + // repeated .CoreML.Specification.FeatureDescription state = 13; + int state_size() const; + void clear_state(); + static const int kStateFieldNumber = 13; + const ::CoreML::Specification::FeatureDescription& state(int index) const; + ::CoreML::Specification::FeatureDescription* mutable_state(int index); + ::CoreML::Specification::FeatureDescription* add_state(); + ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription >* + mutable_state(); + const ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription >& + state() const; + + // repeated .CoreML.Specification.FunctionDescription functions = 20; + int functions_size() const; + void clear_functions(); + static const int kFunctionsFieldNumber = 20; + const ::CoreML::Specification::FunctionDescription& functions(int index) const; + ::CoreML::Specification::FunctionDescription* mutable_functions(int index); + ::CoreML::Specification::FunctionDescription* add_functions(); + ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FunctionDescription >* + mutable_functions(); + const ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FunctionDescription >& + functions() const; + // repeated .CoreML.Specification.FeatureDescription trainingInput = 50; int traininginput_size() const; void clear_traininginput(); @@ -1807,6 +1995,20 @@ class ModelDescription : public ::google::protobuf::MessageLite /* @@protoc_inse ::std::string* release_predictedprobabilitiesname(); void set_allocated_predictedprobabilitiesname(::std::string* predictedprobabilitiesname); + // string defaultFunctionName = 21; + void clear_defaultfunctionname(); + static const int kDefaultFunctionNameFieldNumber = 21; + const ::std::string& defaultfunctionname() const; + void set_defaultfunctionname(const ::std::string& value); + #if LANG_CXX11 + void set_defaultfunctionname(::std::string&& value); + #endif + void set_defaultfunctionname(const char* value); + void set_defaultfunctionname(const char* value, size_t size); + ::std::string* mutable_defaultfunctionname(); + ::std::string* release_defaultfunctionname(); + void set_allocated_defaultfunctionname(::std::string* defaultfunctionname); + // .CoreML.Specification.Metadata metadata = 100; bool has_metadata() const; void clear_metadata(); @@ -1822,9 +2024,12 @@ class ModelDescription : public ::google::protobuf::MessageLite /* @@protoc_inse ::google::protobuf::internal::InternalMetadataWithArenaLite _internal_metadata_; ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription > input_; ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription > output_; + ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription > state_; + ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FunctionDescription > functions_; ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription > traininginput_; ::google::protobuf::internal::ArenaStringPtr predictedfeaturename_; ::google::protobuf::internal::ArenaStringPtr predictedprobabilitiesname_; + ::google::protobuf::internal::ArenaStringPtr defaultfunctionname_; ::CoreML::Specification::Metadata* metadata_; mutable int _cached_size_; friend struct protobuf_Model_2eproto::TableStruct; @@ -1850,7 +2055,7 @@ class SerializedModel : public ::google::protobuf::MessageLite /* @@protoc_inser &_SerializedModel_default_instance_); } static PROTOBUF_CONSTEXPR int const kIndexInFileMessages = - 7; + 8; void Swap(SerializedModel* other); @@ -1992,7 +2197,7 @@ class Model : public ::google::protobuf::MessageLite /* @@protoc_insertion_point &_Model_default_instance_); } static PROTOBUF_CONSTEXPR int const kIndexInFileMessages = - 8; + 9; void Swap(Model* other); @@ -2608,7 +2813,7 @@ inline const ::CoreML::Specification::Pipeline& PipelineClassifier::pipeline() c : *::CoreML::Specification::Pipeline::internal_default_instance(); } inline ::CoreML::Specification::Pipeline* PipelineClassifier::mutable_pipeline() { - + if (pipeline_ == NULL) { pipeline_ = new ::CoreML::Specification::Pipeline; } @@ -2617,7 +2822,7 @@ inline ::CoreML::Specification::Pipeline* PipelineClassifier::mutable_pipeline() } inline ::CoreML::Specification::Pipeline* PipelineClassifier::release_pipeline() { // @@protoc_insertion_point(field_release:CoreML.Specification.PipelineClassifier.pipeline) - + ::CoreML::Specification::Pipeline* temp = pipeline_; pipeline_ = NULL; return temp; @@ -2626,9 +2831,9 @@ inline void PipelineClassifier::set_allocated_pipeline(::CoreML::Specification:: delete pipeline_; pipeline_ = pipeline; if (pipeline) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.PipelineClassifier.pipeline) } @@ -2651,7 +2856,7 @@ inline const ::CoreML::Specification::Pipeline& PipelineRegressor::pipeline() co : *::CoreML::Specification::Pipeline::internal_default_instance(); } inline ::CoreML::Specification::Pipeline* PipelineRegressor::mutable_pipeline() { - + if (pipeline_ == NULL) { pipeline_ = new ::CoreML::Specification::Pipeline; } @@ -2660,7 +2865,7 @@ inline ::CoreML::Specification::Pipeline* PipelineRegressor::mutable_pipeline() } inline ::CoreML::Specification::Pipeline* PipelineRegressor::release_pipeline() { // @@protoc_insertion_point(field_release:CoreML.Specification.PipelineRegressor.pipeline) - + ::CoreML::Specification::Pipeline* temp = pipeline_; pipeline_ = NULL; return temp; @@ -2669,9 +2874,9 @@ inline void PipelineRegressor::set_allocated_pipeline(::CoreML::Specification::P delete pipeline_; pipeline_ = pipeline; if (pipeline) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.PipelineRegressor.pipeline) } @@ -2689,13 +2894,13 @@ inline const ::std::string& FeatureDescription::name() const { return name_.GetNoArena(); } inline void FeatureDescription::set_name(const ::std::string& value) { - + name_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.FeatureDescription.name) } #if LANG_CXX11 inline void FeatureDescription::set_name(::std::string&& value) { - + name_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.FeatureDescription.name) @@ -2703,31 +2908,31 @@ inline void FeatureDescription::set_name(::std::string&& value) { #endif inline void FeatureDescription::set_name(const char* value) { GOOGLE_DCHECK(value != NULL); - + name_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.FeatureDescription.name) } inline void FeatureDescription::set_name(const char* value, size_t size) { - + name_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.FeatureDescription.name) } inline ::std::string* FeatureDescription::mutable_name() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.FeatureDescription.name) return name_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* FeatureDescription::release_name() { // @@protoc_insertion_point(field_release:CoreML.Specification.FeatureDescription.name) - + return name_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void FeatureDescription::set_allocated_name(::std::string* name) { if (name != NULL) { - + } else { - + } name_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), name); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.FeatureDescription.name) @@ -2742,13 +2947,13 @@ inline const ::std::string& FeatureDescription::shortdescription() const { return shortdescription_.GetNoArena(); } inline void FeatureDescription::set_shortdescription(const ::std::string& value) { - + shortdescription_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.FeatureDescription.shortDescription) } #if LANG_CXX11 inline void FeatureDescription::set_shortdescription(::std::string&& value) { - + shortdescription_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.FeatureDescription.shortDescription) @@ -2756,31 +2961,31 @@ inline void FeatureDescription::set_shortdescription(::std::string&& value) { #endif inline void FeatureDescription::set_shortdescription(const char* value) { GOOGLE_DCHECK(value != NULL); - + shortdescription_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.FeatureDescription.shortDescription) } inline void FeatureDescription::set_shortdescription(const char* value, size_t size) { - + shortdescription_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.FeatureDescription.shortDescription) } inline ::std::string* FeatureDescription::mutable_shortdescription() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.FeatureDescription.shortDescription) return shortdescription_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* FeatureDescription::release_shortdescription() { // @@protoc_insertion_point(field_release:CoreML.Specification.FeatureDescription.shortDescription) - + return shortdescription_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void FeatureDescription::set_allocated_shortdescription(::std::string* shortdescription) { if (shortdescription != NULL) { - + } else { - + } shortdescription_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), shortdescription); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.FeatureDescription.shortDescription) @@ -2800,7 +3005,7 @@ inline const ::CoreML::Specification::FeatureType& FeatureDescription::type() co : *::CoreML::Specification::FeatureType::internal_default_instance(); } inline ::CoreML::Specification::FeatureType* FeatureDescription::mutable_type() { - + if (type_ == NULL) { type_ = new ::CoreML::Specification::FeatureType; } @@ -2809,7 +3014,7 @@ inline ::CoreML::Specification::FeatureType* FeatureDescription::mutable_type() } inline ::CoreML::Specification::FeatureType* FeatureDescription::release_type() { // @@protoc_insertion_point(field_release:CoreML.Specification.FeatureDescription.type) - + ::CoreML::Specification::FeatureType* temp = type_; type_ = NULL; return temp; @@ -2818,9 +3023,9 @@ inline void FeatureDescription::set_allocated_type(::CoreML::Specification::Feat delete type_; type_ = type; if (type) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.FeatureDescription.type) } @@ -2840,13 +3045,13 @@ inline const ::std::string& Metadata::shortdescription() const { return shortdescription_.GetNoArena(); } inline void Metadata::set_shortdescription(const ::std::string& value) { - + shortdescription_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.Metadata.shortDescription) } #if LANG_CXX11 inline void Metadata::set_shortdescription(::std::string&& value) { - + shortdescription_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.Metadata.shortDescription) @@ -2854,31 +3059,31 @@ inline void Metadata::set_shortdescription(::std::string&& value) { #endif inline void Metadata::set_shortdescription(const char* value) { GOOGLE_DCHECK(value != NULL); - + shortdescription_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.Metadata.shortDescription) } inline void Metadata::set_shortdescription(const char* value, size_t size) { - + shortdescription_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.Metadata.shortDescription) } inline ::std::string* Metadata::mutable_shortdescription() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.Metadata.shortDescription) return shortdescription_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* Metadata::release_shortdescription() { // @@protoc_insertion_point(field_release:CoreML.Specification.Metadata.shortDescription) - + return shortdescription_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void Metadata::set_allocated_shortdescription(::std::string* shortdescription) { if (shortdescription != NULL) { - + } else { - + } shortdescription_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), shortdescription); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.Metadata.shortDescription) @@ -2893,13 +3098,13 @@ inline const ::std::string& Metadata::versionstring() const { return versionstring_.GetNoArena(); } inline void Metadata::set_versionstring(const ::std::string& value) { - + versionstring_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.Metadata.versionString) } #if LANG_CXX11 inline void Metadata::set_versionstring(::std::string&& value) { - + versionstring_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.Metadata.versionString) @@ -2907,31 +3112,31 @@ inline void Metadata::set_versionstring(::std::string&& value) { #endif inline void Metadata::set_versionstring(const char* value) { GOOGLE_DCHECK(value != NULL); - + versionstring_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.Metadata.versionString) } inline void Metadata::set_versionstring(const char* value, size_t size) { - + versionstring_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.Metadata.versionString) } inline ::std::string* Metadata::mutable_versionstring() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.Metadata.versionString) return versionstring_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* Metadata::release_versionstring() { // @@protoc_insertion_point(field_release:CoreML.Specification.Metadata.versionString) - + return versionstring_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void Metadata::set_allocated_versionstring(::std::string* versionstring) { if (versionstring != NULL) { - + } else { - + } versionstring_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), versionstring); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.Metadata.versionString) @@ -2946,13 +3151,13 @@ inline const ::std::string& Metadata::author() const { return author_.GetNoArena(); } inline void Metadata::set_author(const ::std::string& value) { - + author_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.Metadata.author) } #if LANG_CXX11 inline void Metadata::set_author(::std::string&& value) { - + author_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.Metadata.author) @@ -2960,31 +3165,31 @@ inline void Metadata::set_author(::std::string&& value) { #endif inline void Metadata::set_author(const char* value) { GOOGLE_DCHECK(value != NULL); - + author_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.Metadata.author) } inline void Metadata::set_author(const char* value, size_t size) { - + author_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.Metadata.author) } inline ::std::string* Metadata::mutable_author() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.Metadata.author) return author_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* Metadata::release_author() { // @@protoc_insertion_point(field_release:CoreML.Specification.Metadata.author) - + return author_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void Metadata::set_allocated_author(::std::string* author) { if (author != NULL) { - + } else { - + } author_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), author); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.Metadata.author) @@ -2999,13 +3204,13 @@ inline const ::std::string& Metadata::license() const { return license_.GetNoArena(); } inline void Metadata::set_license(const ::std::string& value) { - + license_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.Metadata.license) } #if LANG_CXX11 inline void Metadata::set_license(::std::string&& value) { - + license_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.Metadata.license) @@ -3013,31 +3218,31 @@ inline void Metadata::set_license(::std::string&& value) { #endif inline void Metadata::set_license(const char* value) { GOOGLE_DCHECK(value != NULL); - + license_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.Metadata.license) } inline void Metadata::set_license(const char* value, size_t size) { - + license_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.Metadata.license) } inline ::std::string* Metadata::mutable_license() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.Metadata.license) return license_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* Metadata::release_license() { // @@protoc_insertion_point(field_release:CoreML.Specification.Metadata.license) - + return license_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void Metadata::set_allocated_license(::std::string* license) { if (license != NULL) { - + } else { - + } license_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), license); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.Metadata.license) @@ -3063,8 +3268,383 @@ Metadata::mutable_userdefined() { // ------------------------------------------------------------------- +// FunctionDescription + +// string name = 1; +inline void FunctionDescription::clear_name() { + name_.ClearToEmptyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +inline const ::std::string& FunctionDescription::name() const { + // @@protoc_insertion_point(field_get:CoreML.Specification.FunctionDescription.name) + return name_.GetNoArena(); +} +inline void FunctionDescription::set_name(const ::std::string& value) { + + name_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); + // @@protoc_insertion_point(field_set:CoreML.Specification.FunctionDescription.name) +} +#if LANG_CXX11 +inline void FunctionDescription::set_name(::std::string&& value) { + + name_.SetNoArena( + &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.FunctionDescription.name) +} +#endif +inline void FunctionDescription::set_name(const char* value) { + GOOGLE_DCHECK(value != NULL); + + name_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:CoreML.Specification.FunctionDescription.name) +} +inline void FunctionDescription::set_name(const char* value, size_t size) { + + name_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.FunctionDescription.name) +} +inline ::std::string* FunctionDescription::mutable_name() { + + // @@protoc_insertion_point(field_mutable:CoreML.Specification.FunctionDescription.name) + return name_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +inline ::std::string* FunctionDescription::release_name() { + // @@protoc_insertion_point(field_release:CoreML.Specification.FunctionDescription.name) + + return name_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +inline void FunctionDescription::set_allocated_name(::std::string* name) { + if (name != NULL) { + + } else { + + } + name_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), name); + // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.FunctionDescription.name) +} + +// repeated .CoreML.Specification.FeatureDescription input = 2; +inline int FunctionDescription::input_size() const { + return input_.size(); +} +inline void FunctionDescription::clear_input() { + input_.Clear(); +} +inline const ::CoreML::Specification::FeatureDescription& FunctionDescription::input(int index) const { + // @@protoc_insertion_point(field_get:CoreML.Specification.FunctionDescription.input) + return input_.Get(index); +} +inline ::CoreML::Specification::FeatureDescription* FunctionDescription::mutable_input(int index) { + // @@protoc_insertion_point(field_mutable:CoreML.Specification.FunctionDescription.input) + return input_.Mutable(index); +} +inline ::CoreML::Specification::FeatureDescription* FunctionDescription::add_input() { + // @@protoc_insertion_point(field_add:CoreML.Specification.FunctionDescription.input) + return input_.Add(); +} +inline ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription >* +FunctionDescription::mutable_input() { + // @@protoc_insertion_point(field_mutable_list:CoreML.Specification.FunctionDescription.input) + return &input_; +} +inline const ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription >& +FunctionDescription::input() const { + // @@protoc_insertion_point(field_list:CoreML.Specification.FunctionDescription.input) + return input_; +} + +// repeated .CoreML.Specification.FeatureDescription output = 3; +inline int FunctionDescription::output_size() const { + return output_.size(); +} +inline void FunctionDescription::clear_output() { + output_.Clear(); +} +inline const ::CoreML::Specification::FeatureDescription& FunctionDescription::output(int index) const { + // @@protoc_insertion_point(field_get:CoreML.Specification.FunctionDescription.output) + return output_.Get(index); +} +inline ::CoreML::Specification::FeatureDescription* FunctionDescription::mutable_output(int index) { + // @@protoc_insertion_point(field_mutable:CoreML.Specification.FunctionDescription.output) + return output_.Mutable(index); +} +inline ::CoreML::Specification::FeatureDescription* FunctionDescription::add_output() { + // @@protoc_insertion_point(field_add:CoreML.Specification.FunctionDescription.output) + return output_.Add(); +} +inline ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription >* +FunctionDescription::mutable_output() { + // @@protoc_insertion_point(field_mutable_list:CoreML.Specification.FunctionDescription.output) + return &output_; +} +inline const ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription >& +FunctionDescription::output() const { + // @@protoc_insertion_point(field_list:CoreML.Specification.FunctionDescription.output) + return output_; +} + +// repeated .CoreML.Specification.FeatureDescription state = 6; +inline int FunctionDescription::state_size() const { + return state_.size(); +} +inline void FunctionDescription::clear_state() { + state_.Clear(); +} +inline const ::CoreML::Specification::FeatureDescription& FunctionDescription::state(int index) const { + // @@protoc_insertion_point(field_get:CoreML.Specification.FunctionDescription.state) + return state_.Get(index); +} +inline ::CoreML::Specification::FeatureDescription* FunctionDescription::mutable_state(int index) { + // @@protoc_insertion_point(field_mutable:CoreML.Specification.FunctionDescription.state) + return state_.Mutable(index); +} +inline ::CoreML::Specification::FeatureDescription* FunctionDescription::add_state() { + // @@protoc_insertion_point(field_add:CoreML.Specification.FunctionDescription.state) + return state_.Add(); +} +inline ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription >* +FunctionDescription::mutable_state() { + // @@protoc_insertion_point(field_mutable_list:CoreML.Specification.FunctionDescription.state) + return &state_; +} +inline const ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription >& +FunctionDescription::state() const { + // @@protoc_insertion_point(field_list:CoreML.Specification.FunctionDescription.state) + return state_; +} + +// string predictedFeatureName = 4; +inline void FunctionDescription::clear_predictedfeaturename() { + predictedfeaturename_.ClearToEmptyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +inline const ::std::string& FunctionDescription::predictedfeaturename() const { + // @@protoc_insertion_point(field_get:CoreML.Specification.FunctionDescription.predictedFeatureName) + return predictedfeaturename_.GetNoArena(); +} +inline void FunctionDescription::set_predictedfeaturename(const ::std::string& value) { + + predictedfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); + // @@protoc_insertion_point(field_set:CoreML.Specification.FunctionDescription.predictedFeatureName) +} +#if LANG_CXX11 +inline void FunctionDescription::set_predictedfeaturename(::std::string&& value) { + + predictedfeaturename_.SetNoArena( + &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.FunctionDescription.predictedFeatureName) +} +#endif +inline void FunctionDescription::set_predictedfeaturename(const char* value) { + GOOGLE_DCHECK(value != NULL); + + predictedfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:CoreML.Specification.FunctionDescription.predictedFeatureName) +} +inline void FunctionDescription::set_predictedfeaturename(const char* value, size_t size) { + + predictedfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.FunctionDescription.predictedFeatureName) +} +inline ::std::string* FunctionDescription::mutable_predictedfeaturename() { + + // @@protoc_insertion_point(field_mutable:CoreML.Specification.FunctionDescription.predictedFeatureName) + return predictedfeaturename_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +inline ::std::string* FunctionDescription::release_predictedfeaturename() { + // @@protoc_insertion_point(field_release:CoreML.Specification.FunctionDescription.predictedFeatureName) + + return predictedfeaturename_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +inline void FunctionDescription::set_allocated_predictedfeaturename(::std::string* predictedfeaturename) { + if (predictedfeaturename != NULL) { + + } else { + + } + predictedfeaturename_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), predictedfeaturename); + // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.FunctionDescription.predictedFeatureName) +} + +// string predictedProbabilitiesName = 5; +inline void FunctionDescription::clear_predictedprobabilitiesname() { + predictedprobabilitiesname_.ClearToEmptyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +inline const ::std::string& FunctionDescription::predictedprobabilitiesname() const { + // @@protoc_insertion_point(field_get:CoreML.Specification.FunctionDescription.predictedProbabilitiesName) + return predictedprobabilitiesname_.GetNoArena(); +} +inline void FunctionDescription::set_predictedprobabilitiesname(const ::std::string& value) { + + predictedprobabilitiesname_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); + // @@protoc_insertion_point(field_set:CoreML.Specification.FunctionDescription.predictedProbabilitiesName) +} +#if LANG_CXX11 +inline void FunctionDescription::set_predictedprobabilitiesname(::std::string&& value) { + + predictedprobabilitiesname_.SetNoArena( + &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.FunctionDescription.predictedProbabilitiesName) +} +#endif +inline void FunctionDescription::set_predictedprobabilitiesname(const char* value) { + GOOGLE_DCHECK(value != NULL); + + predictedprobabilitiesname_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:CoreML.Specification.FunctionDescription.predictedProbabilitiesName) +} +inline void FunctionDescription::set_predictedprobabilitiesname(const char* value, size_t size) { + + predictedprobabilitiesname_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.FunctionDescription.predictedProbabilitiesName) +} +inline ::std::string* FunctionDescription::mutable_predictedprobabilitiesname() { + + // @@protoc_insertion_point(field_mutable:CoreML.Specification.FunctionDescription.predictedProbabilitiesName) + return predictedprobabilitiesname_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +inline ::std::string* FunctionDescription::release_predictedprobabilitiesname() { + // @@protoc_insertion_point(field_release:CoreML.Specification.FunctionDescription.predictedProbabilitiesName) + + return predictedprobabilitiesname_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +inline void FunctionDescription::set_allocated_predictedprobabilitiesname(::std::string* predictedprobabilitiesname) { + if (predictedprobabilitiesname != NULL) { + + } else { + + } + predictedprobabilitiesname_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), predictedprobabilitiesname); + // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.FunctionDescription.predictedProbabilitiesName) +} + +// ------------------------------------------------------------------- + // ModelDescription +// repeated .CoreML.Specification.FunctionDescription functions = 20; +inline int ModelDescription::functions_size() const { + return functions_.size(); +} +inline void ModelDescription::clear_functions() { + functions_.Clear(); +} +inline const ::CoreML::Specification::FunctionDescription& ModelDescription::functions(int index) const { + // @@protoc_insertion_point(field_get:CoreML.Specification.ModelDescription.functions) + return functions_.Get(index); +} +inline ::CoreML::Specification::FunctionDescription* ModelDescription::mutable_functions(int index) { + // @@protoc_insertion_point(field_mutable:CoreML.Specification.ModelDescription.functions) + return functions_.Mutable(index); +} +inline ::CoreML::Specification::FunctionDescription* ModelDescription::add_functions() { + // @@protoc_insertion_point(field_add:CoreML.Specification.ModelDescription.functions) + return functions_.Add(); +} +inline ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FunctionDescription >* +ModelDescription::mutable_functions() { + // @@protoc_insertion_point(field_mutable_list:CoreML.Specification.ModelDescription.functions) + return &functions_; +} +inline const ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FunctionDescription >& +ModelDescription::functions() const { + // @@protoc_insertion_point(field_list:CoreML.Specification.ModelDescription.functions) + return functions_; +} + +// string defaultFunctionName = 21; +inline void ModelDescription::clear_defaultfunctionname() { + defaultfunctionname_.ClearToEmptyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +inline const ::std::string& ModelDescription::defaultfunctionname() const { + // @@protoc_insertion_point(field_get:CoreML.Specification.ModelDescription.defaultFunctionName) + return defaultfunctionname_.GetNoArena(); +} +inline void ModelDescription::set_defaultfunctionname(const ::std::string& value) { + + defaultfunctionname_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); + // @@protoc_insertion_point(field_set:CoreML.Specification.ModelDescription.defaultFunctionName) +} +#if LANG_CXX11 +inline void ModelDescription::set_defaultfunctionname(::std::string&& value) { + + defaultfunctionname_.SetNoArena( + &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.ModelDescription.defaultFunctionName) +} +#endif +inline void ModelDescription::set_defaultfunctionname(const char* value) { + GOOGLE_DCHECK(value != NULL); + + defaultfunctionname_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:CoreML.Specification.ModelDescription.defaultFunctionName) +} +inline void ModelDescription::set_defaultfunctionname(const char* value, size_t size) { + + defaultfunctionname_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.ModelDescription.defaultFunctionName) +} +inline ::std::string* ModelDescription::mutable_defaultfunctionname() { + + // @@protoc_insertion_point(field_mutable:CoreML.Specification.ModelDescription.defaultFunctionName) + return defaultfunctionname_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +inline ::std::string* ModelDescription::release_defaultfunctionname() { + // @@protoc_insertion_point(field_release:CoreML.Specification.ModelDescription.defaultFunctionName) + + return defaultfunctionname_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +inline void ModelDescription::set_allocated_defaultfunctionname(::std::string* defaultfunctionname) { + if (defaultfunctionname != NULL) { + + } else { + + } + defaultfunctionname_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), defaultfunctionname); + // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ModelDescription.defaultFunctionName) +} + +// .CoreML.Specification.Metadata metadata = 100; +inline bool ModelDescription::has_metadata() const { + return this != internal_default_instance() && metadata_ != NULL; +} +inline void ModelDescription::clear_metadata() { + if (GetArenaNoVirtual() == NULL && metadata_ != NULL) delete metadata_; + metadata_ = NULL; +} +inline const ::CoreML::Specification::Metadata& ModelDescription::metadata() const { + // @@protoc_insertion_point(field_get:CoreML.Specification.ModelDescription.metadata) + return metadata_ != NULL ? *metadata_ + : *::CoreML::Specification::Metadata::internal_default_instance(); +} +inline ::CoreML::Specification::Metadata* ModelDescription::mutable_metadata() { + + if (metadata_ == NULL) { + metadata_ = new ::CoreML::Specification::Metadata; + } + // @@protoc_insertion_point(field_mutable:CoreML.Specification.ModelDescription.metadata) + return metadata_; +} +inline ::CoreML::Specification::Metadata* ModelDescription::release_metadata() { + // @@protoc_insertion_point(field_release:CoreML.Specification.ModelDescription.metadata) + + ::CoreML::Specification::Metadata* temp = metadata_; + metadata_ = NULL; + return temp; +} +inline void ModelDescription::set_allocated_metadata(::CoreML::Specification::Metadata* metadata) { + delete metadata_; + metadata_ = metadata; + if (metadata) { + + } else { + + } + // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ModelDescription.metadata) +} + // repeated .CoreML.Specification.FeatureDescription input = 1; inline int ModelDescription::input_size() const { return input_.size(); @@ -3125,6 +3705,36 @@ ModelDescription::output() const { return output_; } +// repeated .CoreML.Specification.FeatureDescription state = 13; +inline int ModelDescription::state_size() const { + return state_.size(); +} +inline void ModelDescription::clear_state() { + state_.Clear(); +} +inline const ::CoreML::Specification::FeatureDescription& ModelDescription::state(int index) const { + // @@protoc_insertion_point(field_get:CoreML.Specification.ModelDescription.state) + return state_.Get(index); +} +inline ::CoreML::Specification::FeatureDescription* ModelDescription::mutable_state(int index) { + // @@protoc_insertion_point(field_mutable:CoreML.Specification.ModelDescription.state) + return state_.Mutable(index); +} +inline ::CoreML::Specification::FeatureDescription* ModelDescription::add_state() { + // @@protoc_insertion_point(field_add:CoreML.Specification.ModelDescription.state) + return state_.Add(); +} +inline ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription >* +ModelDescription::mutable_state() { + // @@protoc_insertion_point(field_mutable_list:CoreML.Specification.ModelDescription.state) + return &state_; +} +inline const ::google::protobuf::RepeatedPtrField< ::CoreML::Specification::FeatureDescription >& +ModelDescription::state() const { + // @@protoc_insertion_point(field_list:CoreML.Specification.ModelDescription.state) + return state_; +} + // string predictedFeatureName = 11; inline void ModelDescription::clear_predictedfeaturename() { predictedfeaturename_.ClearToEmptyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); @@ -3134,13 +3744,13 @@ inline const ::std::string& ModelDescription::predictedfeaturename() const { return predictedfeaturename_.GetNoArena(); } inline void ModelDescription::set_predictedfeaturename(const ::std::string& value) { - + predictedfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.ModelDescription.predictedFeatureName) } #if LANG_CXX11 inline void ModelDescription::set_predictedfeaturename(::std::string&& value) { - + predictedfeaturename_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.ModelDescription.predictedFeatureName) @@ -3148,31 +3758,31 @@ inline void ModelDescription::set_predictedfeaturename(::std::string&& value) { #endif inline void ModelDescription::set_predictedfeaturename(const char* value) { GOOGLE_DCHECK(value != NULL); - + predictedfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.ModelDescription.predictedFeatureName) } inline void ModelDescription::set_predictedfeaturename(const char* value, size_t size) { - + predictedfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.ModelDescription.predictedFeatureName) } inline ::std::string* ModelDescription::mutable_predictedfeaturename() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.ModelDescription.predictedFeatureName) return predictedfeaturename_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* ModelDescription::release_predictedfeaturename() { // @@protoc_insertion_point(field_release:CoreML.Specification.ModelDescription.predictedFeatureName) - + return predictedfeaturename_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void ModelDescription::set_allocated_predictedfeaturename(::std::string* predictedfeaturename) { if (predictedfeaturename != NULL) { - + } else { - + } predictedfeaturename_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), predictedfeaturename); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ModelDescription.predictedFeatureName) @@ -3187,13 +3797,13 @@ inline const ::std::string& ModelDescription::predictedprobabilitiesname() const return predictedprobabilitiesname_.GetNoArena(); } inline void ModelDescription::set_predictedprobabilitiesname(const ::std::string& value) { - + predictedprobabilitiesname_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.ModelDescription.predictedProbabilitiesName) } #if LANG_CXX11 inline void ModelDescription::set_predictedprobabilitiesname(::std::string&& value) { - + predictedprobabilitiesname_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.ModelDescription.predictedProbabilitiesName) @@ -3201,31 +3811,31 @@ inline void ModelDescription::set_predictedprobabilitiesname(::std::string&& val #endif inline void ModelDescription::set_predictedprobabilitiesname(const char* value) { GOOGLE_DCHECK(value != NULL); - + predictedprobabilitiesname_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.ModelDescription.predictedProbabilitiesName) } inline void ModelDescription::set_predictedprobabilitiesname(const char* value, size_t size) { - + predictedprobabilitiesname_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.ModelDescription.predictedProbabilitiesName) } inline ::std::string* ModelDescription::mutable_predictedprobabilitiesname() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.ModelDescription.predictedProbabilitiesName) return predictedprobabilitiesname_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* ModelDescription::release_predictedprobabilitiesname() { // @@protoc_insertion_point(field_release:CoreML.Specification.ModelDescription.predictedProbabilitiesName) - + return predictedprobabilitiesname_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void ModelDescription::set_allocated_predictedprobabilitiesname(::std::string* predictedprobabilitiesname) { if (predictedprobabilitiesname != NULL) { - + } else { - + } predictedprobabilitiesname_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), predictedprobabilitiesname); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ModelDescription.predictedProbabilitiesName) @@ -3261,45 +3871,6 @@ ModelDescription::traininginput() const { return traininginput_; } -// .CoreML.Specification.Metadata metadata = 100; -inline bool ModelDescription::has_metadata() const { - return this != internal_default_instance() && metadata_ != NULL; -} -inline void ModelDescription::clear_metadata() { - if (GetArenaNoVirtual() == NULL && metadata_ != NULL) delete metadata_; - metadata_ = NULL; -} -inline const ::CoreML::Specification::Metadata& ModelDescription::metadata() const { - // @@protoc_insertion_point(field_get:CoreML.Specification.ModelDescription.metadata) - return metadata_ != NULL ? *metadata_ - : *::CoreML::Specification::Metadata::internal_default_instance(); -} -inline ::CoreML::Specification::Metadata* ModelDescription::mutable_metadata() { - - if (metadata_ == NULL) { - metadata_ = new ::CoreML::Specification::Metadata; - } - // @@protoc_insertion_point(field_mutable:CoreML.Specification.ModelDescription.metadata) - return metadata_; -} -inline ::CoreML::Specification::Metadata* ModelDescription::release_metadata() { - // @@protoc_insertion_point(field_release:CoreML.Specification.ModelDescription.metadata) - - ::CoreML::Specification::Metadata* temp = metadata_; - metadata_ = NULL; - return temp; -} -inline void ModelDescription::set_allocated_metadata(::CoreML::Specification::Metadata* metadata) { - delete metadata_; - metadata_ = metadata; - if (metadata) { - - } else { - - } - // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ModelDescription.metadata) -} - // ------------------------------------------------------------------- // SerializedModel @@ -3313,13 +3884,13 @@ inline const ::std::string& SerializedModel::identifier() const { return identifier_.GetNoArena(); } inline void SerializedModel::set_identifier(const ::std::string& value) { - + identifier_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.SerializedModel.identifier) } #if LANG_CXX11 inline void SerializedModel::set_identifier(::std::string&& value) { - + identifier_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.SerializedModel.identifier) @@ -3327,31 +3898,31 @@ inline void SerializedModel::set_identifier(::std::string&& value) { #endif inline void SerializedModel::set_identifier(const char* value) { GOOGLE_DCHECK(value != NULL); - + identifier_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.SerializedModel.identifier) } inline void SerializedModel::set_identifier(const char* value, size_t size) { - + identifier_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.SerializedModel.identifier) } inline ::std::string* SerializedModel::mutable_identifier() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.SerializedModel.identifier) return identifier_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* SerializedModel::release_identifier() { // @@protoc_insertion_point(field_release:CoreML.Specification.SerializedModel.identifier) - + return identifier_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void SerializedModel::set_allocated_identifier(::std::string* identifier) { if (identifier != NULL) { - + } else { - + } identifier_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), identifier); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.SerializedModel.identifier) @@ -3366,13 +3937,13 @@ inline const ::std::string& SerializedModel::model() const { return model_.GetNoArena(); } inline void SerializedModel::set_model(const ::std::string& value) { - + model_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.SerializedModel.model) } #if LANG_CXX11 inline void SerializedModel::set_model(::std::string&& value) { - + model_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.SerializedModel.model) @@ -3380,31 +3951,31 @@ inline void SerializedModel::set_model(::std::string&& value) { #endif inline void SerializedModel::set_model(const char* value) { GOOGLE_DCHECK(value != NULL); - + model_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.SerializedModel.model) } inline void SerializedModel::set_model(const void* value, size_t size) { - + model_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.SerializedModel.model) } inline ::std::string* SerializedModel::mutable_model() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.SerializedModel.model) return model_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* SerializedModel::release_model() { // @@protoc_insertion_point(field_release:CoreML.Specification.SerializedModel.model) - + return model_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void SerializedModel::set_allocated_model(::std::string* model) { if (model != NULL) { - + } else { - + } model_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), model); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.SerializedModel.model) @@ -3423,7 +3994,7 @@ inline ::google::protobuf::int32 Model::specificationversion() const { return specificationversion_; } inline void Model::set_specificationversion(::google::protobuf::int32 value) { - + specificationversion_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Model.specificationVersion) } @@ -3442,7 +4013,7 @@ inline const ::CoreML::Specification::ModelDescription& Model::description() con : *::CoreML::Specification::ModelDescription::internal_default_instance(); } inline ::CoreML::Specification::ModelDescription* Model::mutable_description() { - + if (description_ == NULL) { description_ = new ::CoreML::Specification::ModelDescription; } @@ -3451,7 +4022,7 @@ inline ::CoreML::Specification::ModelDescription* Model::mutable_description() { } inline ::CoreML::Specification::ModelDescription* Model::release_description() { // @@protoc_insertion_point(field_release:CoreML.Specification.Model.description) - + ::CoreML::Specification::ModelDescription* temp = description_; description_ = NULL; return temp; @@ -3460,9 +4031,9 @@ inline void Model::set_allocated_description(::CoreML::Specification::ModelDescr delete description_; description_ = description; if (description) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.Model.description) } @@ -3476,7 +4047,7 @@ inline bool Model::isupdatable() const { return isupdatable_; } inline void Model::set_isupdatable(bool value) { - + isupdatable_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Model.isUpdatable) } @@ -5283,6 +5854,8 @@ inline Model::TypeCase Model::Type_case() const { // ------------------------------------------------------------------- +// ------------------------------------------------------------------- + // @@protoc_insertion_point(namespace_scope) diff --git a/mlmodel/build/format/NearestNeighbors.pb.h b/mlmodel/build/format/NearestNeighbors.pb.h index b0990b8f3..694a2de19 100644 --- a/mlmodel/build/format/NearestNeighbors.pb.h +++ b/mlmodel/build/format/NearestNeighbors.pb.h @@ -135,6 +135,9 @@ extern SizeRangeDefaultTypeInternal _SizeRange_default_instance_; class SquaredEuclideanDistance; class SquaredEuclideanDistanceDefaultTypeInternal; extern SquaredEuclideanDistanceDefaultTypeInternal _SquaredEuclideanDistance_default_instance_; +class StateFeatureType; +class StateFeatureTypeDefaultTypeInternal; +extern StateFeatureTypeDefaultTypeInternal _StateFeatureType_default_instance_; class StringFeatureType; class StringFeatureTypeDefaultTypeInternal; extern StringFeatureTypeDefaultTypeInternal _StringFeatureType_default_instance_; @@ -925,7 +928,7 @@ inline const ::CoreML::Specification::NearestNeighborsIndex& KNearestNeighborsCl : *::CoreML::Specification::NearestNeighborsIndex::internal_default_instance(); } inline ::CoreML::Specification::NearestNeighborsIndex* KNearestNeighborsClassifier::mutable_nearestneighborsindex() { - + if (nearestneighborsindex_ == NULL) { nearestneighborsindex_ = new ::CoreML::Specification::NearestNeighborsIndex; } @@ -934,7 +937,7 @@ inline ::CoreML::Specification::NearestNeighborsIndex* KNearestNeighborsClassifi } inline ::CoreML::Specification::NearestNeighborsIndex* KNearestNeighborsClassifier::release_nearestneighborsindex() { // @@protoc_insertion_point(field_release:CoreML.Specification.KNearestNeighborsClassifier.nearestNeighborsIndex) - + ::CoreML::Specification::NearestNeighborsIndex* temp = nearestneighborsindex_; nearestneighborsindex_ = NULL; return temp; @@ -943,9 +946,9 @@ inline void KNearestNeighborsClassifier::set_allocated_nearestneighborsindex(::C delete nearestneighborsindex_; nearestneighborsindex_ = nearestneighborsindex; if (nearestneighborsindex) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.KNearestNeighborsClassifier.nearestNeighborsIndex) } @@ -964,7 +967,7 @@ inline const ::CoreML::Specification::Int64Parameter& KNearestNeighborsClassifie : *::CoreML::Specification::Int64Parameter::internal_default_instance(); } inline ::CoreML::Specification::Int64Parameter* KNearestNeighborsClassifier::mutable_numberofneighbors() { - + if (numberofneighbors_ == NULL) { numberofneighbors_ = new ::CoreML::Specification::Int64Parameter; } @@ -973,7 +976,7 @@ inline ::CoreML::Specification::Int64Parameter* KNearestNeighborsClassifier::mut } inline ::CoreML::Specification::Int64Parameter* KNearestNeighborsClassifier::release_numberofneighbors() { // @@protoc_insertion_point(field_release:CoreML.Specification.KNearestNeighborsClassifier.numberOfNeighbors) - + ::CoreML::Specification::Int64Parameter* temp = numberofneighbors_; numberofneighbors_ = NULL; return temp; @@ -982,9 +985,9 @@ inline void KNearestNeighborsClassifier::set_allocated_numberofneighbors(::CoreM delete numberofneighbors_; numberofneighbors_ = numberofneighbors; if (numberofneighbors) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.KNearestNeighborsClassifier.numberOfNeighbors) } @@ -1345,7 +1348,7 @@ inline ::google::protobuf::int32 NearestNeighborsIndex::numberofdimensions() con return numberofdimensions_; } inline void NearestNeighborsIndex::set_numberofdimensions(::google::protobuf::int32 value) { - + numberofdimensions_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.NearestNeighborsIndex.numberOfDimensions) } @@ -1567,7 +1570,7 @@ inline ::google::protobuf::int32 SingleKdTreeIndex::leafsize() const { return leafsize_; } inline void SingleKdTreeIndex::set_leafsize(::google::protobuf::int32 value) { - + leafsize_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SingleKdTreeIndex.leafSize) } diff --git a/mlmodel/build/format/NeuralNetwork.pb.h b/mlmodel/build/format/NeuralNetwork.pb.h index fabdf7005..eae1fe889 100644 --- a/mlmodel/build/format/NeuralNetwork.pb.h +++ b/mlmodel/build/format/NeuralNetwork.pb.h @@ -687,6 +687,9 @@ extern SqueezeLayerParamsDefaultTypeInternal _SqueezeLayerParams_default_instanc class StackLayerParams; class StackLayerParamsDefaultTypeInternal; extern StackLayerParamsDefaultTypeInternal _StackLayerParams_default_instance_; +class StateFeatureType; +class StateFeatureTypeDefaultTypeInternal; +extern StateFeatureTypeDefaultTypeInternal _StateFeatureType_default_instance_; class StringFeatureType; class StringFeatureTypeDefaultTypeInternal; extern StringFeatureTypeDefaultTypeInternal _StringFeatureType_default_instance_; @@ -22424,7 +22427,7 @@ inline ::CoreML::Specification::NeuralNetworkMultiArrayShapeMapping NeuralNetwor return static_cast< ::CoreML::Specification::NeuralNetworkMultiArrayShapeMapping >(arrayinputshapemapping_); } inline void NeuralNetwork::set_arrayinputshapemapping(::CoreML::Specification::NeuralNetworkMultiArrayShapeMapping value) { - + arrayinputshapemapping_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.NeuralNetwork.arrayInputShapeMapping) } @@ -22438,7 +22441,7 @@ inline ::CoreML::Specification::NeuralNetworkImageShapeMapping NeuralNetwork::im return static_cast< ::CoreML::Specification::NeuralNetworkImageShapeMapping >(imageinputshapemapping_); } inline void NeuralNetwork::set_imageinputshapemapping(::CoreML::Specification::NeuralNetworkImageShapeMapping value) { - + imageinputshapemapping_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.NeuralNetwork.imageInputShapeMapping) } @@ -22457,7 +22460,7 @@ inline const ::CoreML::Specification::NetworkUpdateParameters& NeuralNetwork::up : *::CoreML::Specification::NetworkUpdateParameters::internal_default_instance(); } inline ::CoreML::Specification::NetworkUpdateParameters* NeuralNetwork::mutable_updateparams() { - + if (updateparams_ == NULL) { updateparams_ = new ::CoreML::Specification::NetworkUpdateParameters; } @@ -22466,7 +22469,7 @@ inline ::CoreML::Specification::NetworkUpdateParameters* NeuralNetwork::mutable_ } inline ::CoreML::Specification::NetworkUpdateParameters* NeuralNetwork::release_updateparams() { // @@protoc_insertion_point(field_release:CoreML.Specification.NeuralNetwork.updateParams) - + ::CoreML::Specification::NetworkUpdateParameters* temp = updateparams_; updateparams_ = NULL; return temp; @@ -22475,9 +22478,9 @@ inline void NeuralNetwork::set_allocated_updateparams(::CoreML::Specification::N delete updateparams_; updateparams_ = updateparams; if (updateparams) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.NeuralNetwork.updateParams) } @@ -22495,7 +22498,7 @@ inline float NeuralNetworkImageScaler::channelscale() const { return channelscale_; } inline void NeuralNetworkImageScaler::set_channelscale(float value) { - + channelscale_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.NeuralNetworkImageScaler.channelScale) } @@ -22509,7 +22512,7 @@ inline float NeuralNetworkImageScaler::bluebias() const { return bluebias_; } inline void NeuralNetworkImageScaler::set_bluebias(float value) { - + bluebias_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.NeuralNetworkImageScaler.blueBias) } @@ -22523,7 +22526,7 @@ inline float NeuralNetworkImageScaler::greenbias() const { return greenbias_; } inline void NeuralNetworkImageScaler::set_greenbias(float value) { - + greenbias_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.NeuralNetworkImageScaler.greenBias) } @@ -22537,7 +22540,7 @@ inline float NeuralNetworkImageScaler::redbias() const { return redbias_; } inline void NeuralNetworkImageScaler::set_redbias(float value) { - + redbias_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.NeuralNetworkImageScaler.redBias) } @@ -22551,7 +22554,7 @@ inline float NeuralNetworkImageScaler::graybias() const { return graybias_; } inline void NeuralNetworkImageScaler::set_graybias(float value) { - + graybias_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.NeuralNetworkImageScaler.grayBias) } @@ -22603,13 +22606,13 @@ inline const ::std::string& NeuralNetworkPreprocessing::featurename() const { return featurename_.GetNoArena(); } inline void NeuralNetworkPreprocessing::set_featurename(const ::std::string& value) { - + featurename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.NeuralNetworkPreprocessing.featureName) } #if LANG_CXX11 inline void NeuralNetworkPreprocessing::set_featurename(::std::string&& value) { - + featurename_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.NeuralNetworkPreprocessing.featureName) @@ -22617,31 +22620,31 @@ inline void NeuralNetworkPreprocessing::set_featurename(::std::string&& value) { #endif inline void NeuralNetworkPreprocessing::set_featurename(const char* value) { GOOGLE_DCHECK(value != NULL); - + featurename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.NeuralNetworkPreprocessing.featureName) } inline void NeuralNetworkPreprocessing::set_featurename(const char* value, size_t size) { - + featurename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.NeuralNetworkPreprocessing.featureName) } inline ::std::string* NeuralNetworkPreprocessing::mutable_featurename() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.NeuralNetworkPreprocessing.featureName) return featurename_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* NeuralNetworkPreprocessing::release_featurename() { // @@protoc_insertion_point(field_release:CoreML.Specification.NeuralNetworkPreprocessing.featureName) - + return featurename_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void NeuralNetworkPreprocessing::set_allocated_featurename(::std::string* featurename) { if (featurename != NULL) { - + } else { - + } featurename_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), featurename); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.NeuralNetworkPreprocessing.featureName) @@ -22769,7 +22772,7 @@ inline float ActivationLeakyReLU::alpha() const { return alpha_; } inline void ActivationLeakyReLU::set_alpha(float value) { - + alpha_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ActivationLeakyReLU.alpha) } @@ -22791,7 +22794,7 @@ inline float ActivationScaledTanh::alpha() const { return alpha_; } inline void ActivationScaledTanh::set_alpha(float value) { - + alpha_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ActivationScaledTanh.alpha) } @@ -22805,7 +22808,7 @@ inline float ActivationScaledTanh::beta() const { return beta_; } inline void ActivationScaledTanh::set_beta(float value) { - + beta_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ActivationScaledTanh.beta) } @@ -22827,7 +22830,7 @@ inline float ActivationLinear::alpha() const { return alpha_; } inline void ActivationLinear::set_alpha(float value) { - + alpha_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ActivationLinear.alpha) } @@ -22841,7 +22844,7 @@ inline float ActivationLinear::beta() const { return beta_; } inline void ActivationLinear::set_beta(float value) { - + beta_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ActivationLinear.beta) } @@ -22859,7 +22862,7 @@ inline float ActivationSigmoidHard::alpha() const { return alpha_; } inline void ActivationSigmoidHard::set_alpha(float value) { - + alpha_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ActivationSigmoidHard.alpha) } @@ -22873,7 +22876,7 @@ inline float ActivationSigmoidHard::beta() const { return beta_; } inline void ActivationSigmoidHard::set_beta(float value) { - + beta_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ActivationSigmoidHard.beta) } @@ -22896,7 +22899,7 @@ inline const ::CoreML::Specification::WeightParams& ActivationPReLU::alpha() con : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* ActivationPReLU::mutable_alpha() { - + if (alpha_ == NULL) { alpha_ = new ::CoreML::Specification::WeightParams; } @@ -22905,7 +22908,7 @@ inline ::CoreML::Specification::WeightParams* ActivationPReLU::mutable_alpha() { } inline ::CoreML::Specification::WeightParams* ActivationPReLU::release_alpha() { // @@protoc_insertion_point(field_release:CoreML.Specification.ActivationPReLU.alpha) - + ::CoreML::Specification::WeightParams* temp = alpha_; alpha_ = NULL; return temp; @@ -22914,9 +22917,9 @@ inline void ActivationPReLU::set_allocated_alpha(::CoreML::Specification::Weight delete alpha_; alpha_ = alpha; if (alpha) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ActivationPReLU.alpha) } @@ -22934,7 +22937,7 @@ inline float ActivationELU::alpha() const { return alpha_; } inline void ActivationELU::set_alpha(float value) { - + alpha_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ActivationELU.alpha) } @@ -22952,7 +22955,7 @@ inline float ActivationThresholdedReLU::alpha() const { return alpha_; } inline void ActivationThresholdedReLU::set_alpha(float value) { - + alpha_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ActivationThresholdedReLU.alpha) } @@ -22983,7 +22986,7 @@ inline const ::CoreML::Specification::WeightParams& ActivationParametricSoftplus : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* ActivationParametricSoftplus::mutable_alpha() { - + if (alpha_ == NULL) { alpha_ = new ::CoreML::Specification::WeightParams; } @@ -22992,7 +22995,7 @@ inline ::CoreML::Specification::WeightParams* ActivationParametricSoftplus::muta } inline ::CoreML::Specification::WeightParams* ActivationParametricSoftplus::release_alpha() { // @@protoc_insertion_point(field_release:CoreML.Specification.ActivationParametricSoftplus.alpha) - + ::CoreML::Specification::WeightParams* temp = alpha_; alpha_ = NULL; return temp; @@ -23001,9 +23004,9 @@ inline void ActivationParametricSoftplus::set_allocated_alpha(::CoreML::Specific delete alpha_; alpha_ = alpha; if (alpha) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ActivationParametricSoftplus.alpha) } @@ -23022,7 +23025,7 @@ inline const ::CoreML::Specification::WeightParams& ActivationParametricSoftplus : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* ActivationParametricSoftplus::mutable_beta() { - + if (beta_ == NULL) { beta_ = new ::CoreML::Specification::WeightParams; } @@ -23031,7 +23034,7 @@ inline ::CoreML::Specification::WeightParams* ActivationParametricSoftplus::muta } inline ::CoreML::Specification::WeightParams* ActivationParametricSoftplus::release_beta() { // @@protoc_insertion_point(field_release:CoreML.Specification.ActivationParametricSoftplus.beta) - + ::CoreML::Specification::WeightParams* temp = beta_; beta_ = NULL; return temp; @@ -23040,9 +23043,9 @@ inline void ActivationParametricSoftplus::set_allocated_beta(::CoreML::Specifica delete beta_; beta_ = beta; if (beta) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ActivationParametricSoftplus.beta) } @@ -23697,7 +23700,7 @@ inline ::google::protobuf::uint32 Tensor::rank() const { return rank_; } inline void Tensor::set_rank(::google::protobuf::uint32 value) { - + rank_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Tensor.rank) } @@ -23745,13 +23748,13 @@ inline const ::std::string& NeuralNetworkLayer::name() const { return name_.GetNoArena(); } inline void NeuralNetworkLayer::set_name(const ::std::string& value) { - + name_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.NeuralNetworkLayer.name) } #if LANG_CXX11 inline void NeuralNetworkLayer::set_name(::std::string&& value) { - + name_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.NeuralNetworkLayer.name) @@ -23759,31 +23762,31 @@ inline void NeuralNetworkLayer::set_name(::std::string&& value) { #endif inline void NeuralNetworkLayer::set_name(const char* value) { GOOGLE_DCHECK(value != NULL); - + name_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.NeuralNetworkLayer.name) } inline void NeuralNetworkLayer::set_name(const char* value, size_t size) { - + name_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.NeuralNetworkLayer.name) } inline ::std::string* NeuralNetworkLayer::mutable_name() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.NeuralNetworkLayer.name) return name_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* NeuralNetworkLayer::release_name() { // @@protoc_insertion_point(field_release:CoreML.Specification.NeuralNetworkLayer.name) - + return name_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void NeuralNetworkLayer::set_allocated_name(::std::string* name) { if (name != NULL) { - + } else { - + } name_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), name); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.NeuralNetworkLayer.name) @@ -23996,7 +23999,7 @@ inline bool NeuralNetworkLayer::isupdatable() const { return isupdatable_; } inline void NeuralNetworkLayer::set_isupdatable(bool value) { - + isupdatable_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.NeuralNetworkLayer.isUpdatable) } @@ -31612,7 +31615,7 @@ inline const ::CoreML::Specification::NeuralNetwork& BranchLayerParams::ifbranch : *::CoreML::Specification::NeuralNetwork::internal_default_instance(); } inline ::CoreML::Specification::NeuralNetwork* BranchLayerParams::mutable_ifbranch() { - + if (ifbranch_ == NULL) { ifbranch_ = new ::CoreML::Specification::NeuralNetwork; } @@ -31621,7 +31624,7 @@ inline ::CoreML::Specification::NeuralNetwork* BranchLayerParams::mutable_ifbran } inline ::CoreML::Specification::NeuralNetwork* BranchLayerParams::release_ifbranch() { // @@protoc_insertion_point(field_release:CoreML.Specification.BranchLayerParams.ifBranch) - + ::CoreML::Specification::NeuralNetwork* temp = ifbranch_; ifbranch_ = NULL; return temp; @@ -31630,9 +31633,9 @@ inline void BranchLayerParams::set_allocated_ifbranch(::CoreML::Specification::N delete ifbranch_; ifbranch_ = ifbranch; if (ifbranch) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.BranchLayerParams.ifBranch) } @@ -31651,7 +31654,7 @@ inline const ::CoreML::Specification::NeuralNetwork& BranchLayerParams::elsebran : *::CoreML::Specification::NeuralNetwork::internal_default_instance(); } inline ::CoreML::Specification::NeuralNetwork* BranchLayerParams::mutable_elsebranch() { - + if (elsebranch_ == NULL) { elsebranch_ = new ::CoreML::Specification::NeuralNetwork; } @@ -31660,7 +31663,7 @@ inline ::CoreML::Specification::NeuralNetwork* BranchLayerParams::mutable_elsebr } inline ::CoreML::Specification::NeuralNetwork* BranchLayerParams::release_elsebranch() { // @@protoc_insertion_point(field_release:CoreML.Specification.BranchLayerParams.elseBranch) - + ::CoreML::Specification::NeuralNetwork* temp = elsebranch_; elsebranch_ = NULL; return temp; @@ -31669,9 +31672,9 @@ inline void BranchLayerParams::set_allocated_elsebranch(::CoreML::Specification: delete elsebranch_; elsebranch_ = elsebranch; if (elsebranch) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.BranchLayerParams.elseBranch) } @@ -31689,7 +31692,7 @@ inline ::google::protobuf::uint64 LoopLayerParams::maxloopiterations() const { return maxloopiterations_; } inline void LoopLayerParams::set_maxloopiterations(::google::protobuf::uint64 value) { - + maxloopiterations_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.LoopLayerParams.maxLoopIterations) } @@ -31703,13 +31706,13 @@ inline const ::std::string& LoopLayerParams::conditionvar() const { return conditionvar_.GetNoArena(); } inline void LoopLayerParams::set_conditionvar(const ::std::string& value) { - + conditionvar_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.LoopLayerParams.conditionVar) } #if LANG_CXX11 inline void LoopLayerParams::set_conditionvar(::std::string&& value) { - + conditionvar_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.LoopLayerParams.conditionVar) @@ -31717,31 +31720,31 @@ inline void LoopLayerParams::set_conditionvar(::std::string&& value) { #endif inline void LoopLayerParams::set_conditionvar(const char* value) { GOOGLE_DCHECK(value != NULL); - + conditionvar_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.LoopLayerParams.conditionVar) } inline void LoopLayerParams::set_conditionvar(const char* value, size_t size) { - + conditionvar_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.LoopLayerParams.conditionVar) } inline ::std::string* LoopLayerParams::mutable_conditionvar() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.LoopLayerParams.conditionVar) return conditionvar_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* LoopLayerParams::release_conditionvar() { // @@protoc_insertion_point(field_release:CoreML.Specification.LoopLayerParams.conditionVar) - + return conditionvar_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void LoopLayerParams::set_allocated_conditionvar(::std::string* conditionvar) { if (conditionvar != NULL) { - + } else { - + } conditionvar_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), conditionvar); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.LoopLayerParams.conditionVar) @@ -31761,7 +31764,7 @@ inline const ::CoreML::Specification::NeuralNetwork& LoopLayerParams::conditionn : *::CoreML::Specification::NeuralNetwork::internal_default_instance(); } inline ::CoreML::Specification::NeuralNetwork* LoopLayerParams::mutable_conditionnetwork() { - + if (conditionnetwork_ == NULL) { conditionnetwork_ = new ::CoreML::Specification::NeuralNetwork; } @@ -31770,7 +31773,7 @@ inline ::CoreML::Specification::NeuralNetwork* LoopLayerParams::mutable_conditio } inline ::CoreML::Specification::NeuralNetwork* LoopLayerParams::release_conditionnetwork() { // @@protoc_insertion_point(field_release:CoreML.Specification.LoopLayerParams.conditionNetwork) - + ::CoreML::Specification::NeuralNetwork* temp = conditionnetwork_; conditionnetwork_ = NULL; return temp; @@ -31779,9 +31782,9 @@ inline void LoopLayerParams::set_allocated_conditionnetwork(::CoreML::Specificat delete conditionnetwork_; conditionnetwork_ = conditionnetwork; if (conditionnetwork) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.LoopLayerParams.conditionNetwork) } @@ -31800,7 +31803,7 @@ inline const ::CoreML::Specification::NeuralNetwork& LoopLayerParams::bodynetwor : *::CoreML::Specification::NeuralNetwork::internal_default_instance(); } inline ::CoreML::Specification::NeuralNetwork* LoopLayerParams::mutable_bodynetwork() { - + if (bodynetwork_ == NULL) { bodynetwork_ = new ::CoreML::Specification::NeuralNetwork; } @@ -31809,7 +31812,7 @@ inline ::CoreML::Specification::NeuralNetwork* LoopLayerParams::mutable_bodynetw } inline ::CoreML::Specification::NeuralNetwork* LoopLayerParams::release_bodynetwork() { // @@protoc_insertion_point(field_release:CoreML.Specification.LoopLayerParams.bodyNetwork) - + ::CoreML::Specification::NeuralNetwork* temp = bodynetwork_; bodynetwork_ = NULL; return temp; @@ -31818,9 +31821,9 @@ inline void LoopLayerParams::set_allocated_bodynetwork(::CoreML::Specification:: delete bodynetwork_; bodynetwork_ = bodynetwork; if (bodynetwork) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.LoopLayerParams.bodyNetwork) } @@ -31850,7 +31853,7 @@ inline float GreaterThanLayerParams::alpha() const { return alpha_; } inline void GreaterThanLayerParams::set_alpha(float value) { - + alpha_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.GreaterThanLayerParams.alpha) } @@ -31868,7 +31871,7 @@ inline float GreaterEqualLayerParams::alpha() const { return alpha_; } inline void GreaterEqualLayerParams::set_alpha(float value) { - + alpha_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.GreaterEqualLayerParams.alpha) } @@ -31886,7 +31889,7 @@ inline float LessThanLayerParams::alpha() const { return alpha_; } inline void LessThanLayerParams::set_alpha(float value) { - + alpha_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.LessThanLayerParams.alpha) } @@ -31904,7 +31907,7 @@ inline float LessEqualLayerParams::alpha() const { return alpha_; } inline void LessEqualLayerParams::set_alpha(float value) { - + alpha_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.LessEqualLayerParams.alpha) } @@ -31922,7 +31925,7 @@ inline float EqualLayerParams::alpha() const { return alpha_; } inline void EqualLayerParams::set_alpha(float value) { - + alpha_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.EqualLayerParams.alpha) } @@ -31940,7 +31943,7 @@ inline float NotEqualLayerParams::alpha() const { return alpha_; } inline void NotEqualLayerParams::set_alpha(float value) { - + alpha_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.NotEqualLayerParams.alpha) } @@ -31974,7 +31977,7 @@ inline ::google::protobuf::uint64 BorderAmounts_EdgeSizes::startedgesize() const return startedgesize_; } inline void BorderAmounts_EdgeSizes::set_startedgesize(::google::protobuf::uint64 value) { - + startedgesize_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.BorderAmounts.EdgeSizes.startEdgeSize) } @@ -31988,7 +31991,7 @@ inline ::google::protobuf::uint64 BorderAmounts_EdgeSizes::endedgesize() const { return endedgesize_; } inline void BorderAmounts_EdgeSizes::set_endedgesize(::google::protobuf::uint64 value) { - + endedgesize_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.BorderAmounts.EdgeSizes.endEdgeSize) } @@ -32045,7 +32048,7 @@ inline const ::CoreML::Specification::BorderAmounts& ValidPadding::paddingamount : *::CoreML::Specification::BorderAmounts::internal_default_instance(); } inline ::CoreML::Specification::BorderAmounts* ValidPadding::mutable_paddingamounts() { - + if (paddingamounts_ == NULL) { paddingamounts_ = new ::CoreML::Specification::BorderAmounts; } @@ -32054,7 +32057,7 @@ inline ::CoreML::Specification::BorderAmounts* ValidPadding::mutable_paddingamou } inline ::CoreML::Specification::BorderAmounts* ValidPadding::release_paddingamounts() { // @@protoc_insertion_point(field_release:CoreML.Specification.ValidPadding.paddingAmounts) - + ::CoreML::Specification::BorderAmounts* temp = paddingamounts_; paddingamounts_ = NULL; return temp; @@ -32063,9 +32066,9 @@ inline void ValidPadding::set_allocated_paddingamounts(::CoreML::Specification:: delete paddingamounts_; paddingamounts_ = paddingamounts; if (paddingamounts) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ValidPadding.paddingAmounts) } @@ -32083,7 +32086,7 @@ inline ::CoreML::Specification::SamePadding_SamePaddingMode SamePadding::asymmet return static_cast< ::CoreML::Specification::SamePadding_SamePaddingMode >(asymmetrymode_); } inline void SamePadding::set_asymmetrymode(::CoreML::Specification::SamePadding_SamePaddingMode value) { - + asymmetrymode_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SamePadding.asymmetryMode) } @@ -32101,7 +32104,7 @@ inline ::CoreML::Specification::SamplingMode_Method SamplingMode::samplingmethod return static_cast< ::CoreML::Specification::SamplingMode_Method >(samplingmethod_); } inline void SamplingMode::set_samplingmethod(::CoreML::Specification::SamplingMode_Method value) { - + samplingmethod_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SamplingMode.samplingMethod) } @@ -32119,7 +32122,7 @@ inline ::CoreML::Specification::BoxCoordinatesMode_Coordinates BoxCoordinatesMod return static_cast< ::CoreML::Specification::BoxCoordinatesMode_Coordinates >(boxmode_); } inline void BoxCoordinatesMode::set_boxmode(::CoreML::Specification::BoxCoordinatesMode_Coordinates value) { - + boxmode_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.BoxCoordinatesMode.boxMode) } @@ -32167,13 +32170,13 @@ inline const ::std::string& WeightParams::float16value() const { return float16value_.GetNoArena(); } inline void WeightParams::set_float16value(const ::std::string& value) { - + float16value_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.WeightParams.float16Value) } #if LANG_CXX11 inline void WeightParams::set_float16value(::std::string&& value) { - + float16value_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.WeightParams.float16Value) @@ -32181,31 +32184,31 @@ inline void WeightParams::set_float16value(::std::string&& value) { #endif inline void WeightParams::set_float16value(const char* value) { GOOGLE_DCHECK(value != NULL); - + float16value_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.WeightParams.float16Value) } inline void WeightParams::set_float16value(const void* value, size_t size) { - + float16value_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.WeightParams.float16Value) } inline ::std::string* WeightParams::mutable_float16value() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.WeightParams.float16Value) return float16value_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* WeightParams::release_float16value() { // @@protoc_insertion_point(field_release:CoreML.Specification.WeightParams.float16Value) - + return float16value_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void WeightParams::set_allocated_float16value(::std::string* float16value) { if (float16value != NULL) { - + } else { - + } float16value_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), float16value); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.WeightParams.float16Value) @@ -32220,13 +32223,13 @@ inline const ::std::string& WeightParams::rawvalue() const { return rawvalue_.GetNoArena(); } inline void WeightParams::set_rawvalue(const ::std::string& value) { - + rawvalue_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.WeightParams.rawValue) } #if LANG_CXX11 inline void WeightParams::set_rawvalue(::std::string&& value) { - + rawvalue_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.WeightParams.rawValue) @@ -32234,31 +32237,31 @@ inline void WeightParams::set_rawvalue(::std::string&& value) { #endif inline void WeightParams::set_rawvalue(const char* value) { GOOGLE_DCHECK(value != NULL); - + rawvalue_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.WeightParams.rawValue) } inline void WeightParams::set_rawvalue(const void* value, size_t size) { - + rawvalue_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.WeightParams.rawValue) } inline ::std::string* WeightParams::mutable_rawvalue() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.WeightParams.rawValue) return rawvalue_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* WeightParams::release_rawvalue() { // @@protoc_insertion_point(field_release:CoreML.Specification.WeightParams.rawValue) - + return rawvalue_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void WeightParams::set_allocated_rawvalue(::std::string* rawvalue) { if (rawvalue != NULL) { - + } else { - + } rawvalue_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), rawvalue); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.WeightParams.rawValue) @@ -32273,13 +32276,13 @@ inline const ::std::string& WeightParams::int8rawvalue() const { return int8rawvalue_.GetNoArena(); } inline void WeightParams::set_int8rawvalue(const ::std::string& value) { - + int8rawvalue_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.WeightParams.int8RawValue) } #if LANG_CXX11 inline void WeightParams::set_int8rawvalue(::std::string&& value) { - + int8rawvalue_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.WeightParams.int8RawValue) @@ -32287,31 +32290,31 @@ inline void WeightParams::set_int8rawvalue(::std::string&& value) { #endif inline void WeightParams::set_int8rawvalue(const char* value) { GOOGLE_DCHECK(value != NULL); - + int8rawvalue_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.WeightParams.int8RawValue) } inline void WeightParams::set_int8rawvalue(const void* value, size_t size) { - + int8rawvalue_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.WeightParams.int8RawValue) } inline ::std::string* WeightParams::mutable_int8rawvalue() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.WeightParams.int8RawValue) return int8rawvalue_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* WeightParams::release_int8rawvalue() { // @@protoc_insertion_point(field_release:CoreML.Specification.WeightParams.int8RawValue) - + return int8rawvalue_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void WeightParams::set_allocated_int8rawvalue(::std::string* int8rawvalue) { if (int8rawvalue != NULL) { - + } else { - + } int8rawvalue_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), int8rawvalue); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.WeightParams.int8RawValue) @@ -32331,7 +32334,7 @@ inline const ::CoreML::Specification::QuantizationParams& WeightParams::quantiza : *::CoreML::Specification::QuantizationParams::internal_default_instance(); } inline ::CoreML::Specification::QuantizationParams* WeightParams::mutable_quantization() { - + if (quantization_ == NULL) { quantization_ = new ::CoreML::Specification::QuantizationParams; } @@ -32340,7 +32343,7 @@ inline ::CoreML::Specification::QuantizationParams* WeightParams::mutable_quanti } inline ::CoreML::Specification::QuantizationParams* WeightParams::release_quantization() { // @@protoc_insertion_point(field_release:CoreML.Specification.WeightParams.quantization) - + ::CoreML::Specification::QuantizationParams* temp = quantization_; quantization_ = NULL; return temp; @@ -32349,9 +32352,9 @@ inline void WeightParams::set_allocated_quantization(::CoreML::Specification::Qu delete quantization_; quantization_ = quantization; if (quantization) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.WeightParams.quantization) } @@ -32365,7 +32368,7 @@ inline bool WeightParams::isupdatable() const { return isupdatable_; } inline void WeightParams::set_isupdatable(bool value) { - + isupdatable_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.WeightParams.isUpdatable) } @@ -32383,7 +32386,7 @@ inline ::google::protobuf::uint64 QuantizationParams::numberofbits() const { return numberofbits_; } inline void QuantizationParams::set_numberofbits(::google::protobuf::uint64 value) { - + numberofbits_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.QuantizationParams.numberOfBits) } @@ -32604,7 +32607,7 @@ inline ::google::protobuf::uint64 ConvolutionLayerParams::outputchannels() const return outputchannels_; } inline void ConvolutionLayerParams::set_outputchannels(::google::protobuf::uint64 value) { - + outputchannels_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ConvolutionLayerParams.outputChannels) } @@ -32618,7 +32621,7 @@ inline ::google::protobuf::uint64 ConvolutionLayerParams::kernelchannels() const return kernelchannels_; } inline void ConvolutionLayerParams::set_kernelchannels(::google::protobuf::uint64 value) { - + kernelchannels_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ConvolutionLayerParams.kernelChannels) } @@ -32632,7 +32635,7 @@ inline ::google::protobuf::uint64 ConvolutionLayerParams::ngroups() const { return ngroups_; } inline void ConvolutionLayerParams::set_ngroups(::google::protobuf::uint64 value) { - + ngroups_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ConvolutionLayerParams.nGroups) } @@ -32832,7 +32835,7 @@ inline bool ConvolutionLayerParams::isdeconvolution() const { return isdeconvolution_; } inline void ConvolutionLayerParams::set_isdeconvolution(bool value) { - + isdeconvolution_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ConvolutionLayerParams.isDeconvolution) } @@ -32846,7 +32849,7 @@ inline bool ConvolutionLayerParams::hasbias() const { return hasbias_; } inline void ConvolutionLayerParams::set_hasbias(bool value) { - + hasbias_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ConvolutionLayerParams.hasBias) } @@ -32865,7 +32868,7 @@ inline const ::CoreML::Specification::WeightParams& ConvolutionLayerParams::weig : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* ConvolutionLayerParams::mutable_weights() { - + if (weights_ == NULL) { weights_ = new ::CoreML::Specification::WeightParams; } @@ -32874,7 +32877,7 @@ inline ::CoreML::Specification::WeightParams* ConvolutionLayerParams::mutable_we } inline ::CoreML::Specification::WeightParams* ConvolutionLayerParams::release_weights() { // @@protoc_insertion_point(field_release:CoreML.Specification.ConvolutionLayerParams.weights) - + ::CoreML::Specification::WeightParams* temp = weights_; weights_ = NULL; return temp; @@ -32883,9 +32886,9 @@ inline void ConvolutionLayerParams::set_allocated_weights(::CoreML::Specificatio delete weights_; weights_ = weights; if (weights) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ConvolutionLayerParams.weights) } @@ -32904,7 +32907,7 @@ inline const ::CoreML::Specification::WeightParams& ConvolutionLayerParams::bias : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* ConvolutionLayerParams::mutable_bias() { - + if (bias_ == NULL) { bias_ = new ::CoreML::Specification::WeightParams; } @@ -32913,7 +32916,7 @@ inline ::CoreML::Specification::WeightParams* ConvolutionLayerParams::mutable_bi } inline ::CoreML::Specification::WeightParams* ConvolutionLayerParams::release_bias() { // @@protoc_insertion_point(field_release:CoreML.Specification.ConvolutionLayerParams.bias) - + ::CoreML::Specification::WeightParams* temp = bias_; bias_ = NULL; return temp; @@ -32922,9 +32925,9 @@ inline void ConvolutionLayerParams::set_allocated_bias(::CoreML::Specification:: delete bias_; bias_ = bias; if (bias) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ConvolutionLayerParams.bias) } @@ -32981,7 +32984,7 @@ inline ::google::protobuf::int32 Convolution3DLayerParams::outputchannels() cons return outputchannels_; } inline void Convolution3DLayerParams::set_outputchannels(::google::protobuf::int32 value) { - + outputchannels_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Convolution3DLayerParams.outputChannels) } @@ -32995,7 +32998,7 @@ inline ::google::protobuf::int32 Convolution3DLayerParams::inputchannels() const return inputchannels_; } inline void Convolution3DLayerParams::set_inputchannels(::google::protobuf::int32 value) { - + inputchannels_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Convolution3DLayerParams.inputChannels) } @@ -33009,7 +33012,7 @@ inline ::google::protobuf::int32 Convolution3DLayerParams::ngroups() const { return ngroups_; } inline void Convolution3DLayerParams::set_ngroups(::google::protobuf::int32 value) { - + ngroups_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Convolution3DLayerParams.nGroups) } @@ -33023,7 +33026,7 @@ inline ::google::protobuf::int32 Convolution3DLayerParams::kerneldepth() const { return kerneldepth_; } inline void Convolution3DLayerParams::set_kerneldepth(::google::protobuf::int32 value) { - + kerneldepth_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Convolution3DLayerParams.kernelDepth) } @@ -33037,7 +33040,7 @@ inline ::google::protobuf::int32 Convolution3DLayerParams::kernelheight() const return kernelheight_; } inline void Convolution3DLayerParams::set_kernelheight(::google::protobuf::int32 value) { - + kernelheight_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Convolution3DLayerParams.kernelHeight) } @@ -33051,7 +33054,7 @@ inline ::google::protobuf::int32 Convolution3DLayerParams::kernelwidth() const { return kernelwidth_; } inline void Convolution3DLayerParams::set_kernelwidth(::google::protobuf::int32 value) { - + kernelwidth_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Convolution3DLayerParams.kernelWidth) } @@ -33065,7 +33068,7 @@ inline ::google::protobuf::int32 Convolution3DLayerParams::stridedepth() const { return stridedepth_; } inline void Convolution3DLayerParams::set_stridedepth(::google::protobuf::int32 value) { - + stridedepth_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Convolution3DLayerParams.strideDepth) } @@ -33079,7 +33082,7 @@ inline ::google::protobuf::int32 Convolution3DLayerParams::strideheight() const return strideheight_; } inline void Convolution3DLayerParams::set_strideheight(::google::protobuf::int32 value) { - + strideheight_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Convolution3DLayerParams.strideHeight) } @@ -33093,7 +33096,7 @@ inline ::google::protobuf::int32 Convolution3DLayerParams::stridewidth() const { return stridewidth_; } inline void Convolution3DLayerParams::set_stridewidth(::google::protobuf::int32 value) { - + stridewidth_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Convolution3DLayerParams.strideWidth) } @@ -33107,7 +33110,7 @@ inline ::google::protobuf::int32 Convolution3DLayerParams::dilationdepth() const return dilationdepth_; } inline void Convolution3DLayerParams::set_dilationdepth(::google::protobuf::int32 value) { - + dilationdepth_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Convolution3DLayerParams.dilationDepth) } @@ -33121,7 +33124,7 @@ inline ::google::protobuf::int32 Convolution3DLayerParams::dilationheight() cons return dilationheight_; } inline void Convolution3DLayerParams::set_dilationheight(::google::protobuf::int32 value) { - + dilationheight_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Convolution3DLayerParams.dilationHeight) } @@ -33135,7 +33138,7 @@ inline ::google::protobuf::int32 Convolution3DLayerParams::dilationwidth() const return dilationwidth_; } inline void Convolution3DLayerParams::set_dilationwidth(::google::protobuf::int32 value) { - + dilationwidth_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Convolution3DLayerParams.dilationWidth) } @@ -33149,7 +33152,7 @@ inline bool Convolution3DLayerParams::hasbias() const { return hasbias_; } inline void Convolution3DLayerParams::set_hasbias(bool value) { - + hasbias_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Convolution3DLayerParams.hasBias) } @@ -33168,7 +33171,7 @@ inline const ::CoreML::Specification::WeightParams& Convolution3DLayerParams::we : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* Convolution3DLayerParams::mutable_weights() { - + if (weights_ == NULL) { weights_ = new ::CoreML::Specification::WeightParams; } @@ -33177,7 +33180,7 @@ inline ::CoreML::Specification::WeightParams* Convolution3DLayerParams::mutable_ } inline ::CoreML::Specification::WeightParams* Convolution3DLayerParams::release_weights() { // @@protoc_insertion_point(field_release:CoreML.Specification.Convolution3DLayerParams.weights) - + ::CoreML::Specification::WeightParams* temp = weights_; weights_ = NULL; return temp; @@ -33186,9 +33189,9 @@ inline void Convolution3DLayerParams::set_allocated_weights(::CoreML::Specificat delete weights_; weights_ = weights; if (weights) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.Convolution3DLayerParams.weights) } @@ -33207,7 +33210,7 @@ inline const ::CoreML::Specification::WeightParams& Convolution3DLayerParams::bi : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* Convolution3DLayerParams::mutable_bias() { - + if (bias_ == NULL) { bias_ = new ::CoreML::Specification::WeightParams; } @@ -33216,7 +33219,7 @@ inline ::CoreML::Specification::WeightParams* Convolution3DLayerParams::mutable_ } inline ::CoreML::Specification::WeightParams* Convolution3DLayerParams::release_bias() { // @@protoc_insertion_point(field_release:CoreML.Specification.Convolution3DLayerParams.bias) - + ::CoreML::Specification::WeightParams* temp = bias_; bias_ = NULL; return temp; @@ -33225,9 +33228,9 @@ inline void Convolution3DLayerParams::set_allocated_bias(::CoreML::Specification delete bias_; bias_ = bias; if (bias) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.Convolution3DLayerParams.bias) } @@ -33241,7 +33244,7 @@ inline ::CoreML::Specification::Convolution3DLayerParams_PaddingType Convolution return static_cast< ::CoreML::Specification::Convolution3DLayerParams_PaddingType >(paddingtype_); } inline void Convolution3DLayerParams::set_paddingtype(::CoreML::Specification::Convolution3DLayerParams_PaddingType value) { - + paddingtype_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Convolution3DLayerParams.paddingType) } @@ -33255,7 +33258,7 @@ inline ::google::protobuf::int32 Convolution3DLayerParams::custompaddingfront() return custompaddingfront_; } inline void Convolution3DLayerParams::set_custompaddingfront(::google::protobuf::int32 value) { - + custompaddingfront_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Convolution3DLayerParams.customPaddingFront) } @@ -33269,7 +33272,7 @@ inline ::google::protobuf::int32 Convolution3DLayerParams::custompaddingback() c return custompaddingback_; } inline void Convolution3DLayerParams::set_custompaddingback(::google::protobuf::int32 value) { - + custompaddingback_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Convolution3DLayerParams.customPaddingBack) } @@ -33283,7 +33286,7 @@ inline ::google::protobuf::int32 Convolution3DLayerParams::custompaddingtop() co return custompaddingtop_; } inline void Convolution3DLayerParams::set_custompaddingtop(::google::protobuf::int32 value) { - + custompaddingtop_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Convolution3DLayerParams.customPaddingTop) } @@ -33297,7 +33300,7 @@ inline ::google::protobuf::int32 Convolution3DLayerParams::custompaddingbottom() return custompaddingbottom_; } inline void Convolution3DLayerParams::set_custompaddingbottom(::google::protobuf::int32 value) { - + custompaddingbottom_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Convolution3DLayerParams.customPaddingBottom) } @@ -33311,7 +33314,7 @@ inline ::google::protobuf::int32 Convolution3DLayerParams::custompaddingleft() c return custompaddingleft_; } inline void Convolution3DLayerParams::set_custompaddingleft(::google::protobuf::int32 value) { - + custompaddingleft_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Convolution3DLayerParams.customPaddingLeft) } @@ -33325,7 +33328,7 @@ inline ::google::protobuf::int32 Convolution3DLayerParams::custompaddingright() return custompaddingright_; } inline void Convolution3DLayerParams::set_custompaddingright(::google::protobuf::int32 value) { - + custompaddingright_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Convolution3DLayerParams.customPaddingRight) } @@ -33339,7 +33342,7 @@ inline bool Convolution3DLayerParams::isdeconvolution() const { return isdeconvolution_; } inline void Convolution3DLayerParams::set_isdeconvolution(bool value) { - + isdeconvolution_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Convolution3DLayerParams.isDeconvolution) } @@ -33387,7 +33390,7 @@ inline ::google::protobuf::uint64 InnerProductLayerParams::inputchannels() const return inputchannels_; } inline void InnerProductLayerParams::set_inputchannels(::google::protobuf::uint64 value) { - + inputchannels_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.InnerProductLayerParams.inputChannels) } @@ -33401,7 +33404,7 @@ inline ::google::protobuf::uint64 InnerProductLayerParams::outputchannels() cons return outputchannels_; } inline void InnerProductLayerParams::set_outputchannels(::google::protobuf::uint64 value) { - + outputchannels_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.InnerProductLayerParams.outputChannels) } @@ -33415,7 +33418,7 @@ inline bool InnerProductLayerParams::hasbias() const { return hasbias_; } inline void InnerProductLayerParams::set_hasbias(bool value) { - + hasbias_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.InnerProductLayerParams.hasBias) } @@ -33434,7 +33437,7 @@ inline const ::CoreML::Specification::WeightParams& InnerProductLayerParams::wei : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* InnerProductLayerParams::mutable_weights() { - + if (weights_ == NULL) { weights_ = new ::CoreML::Specification::WeightParams; } @@ -33443,7 +33446,7 @@ inline ::CoreML::Specification::WeightParams* InnerProductLayerParams::mutable_w } inline ::CoreML::Specification::WeightParams* InnerProductLayerParams::release_weights() { // @@protoc_insertion_point(field_release:CoreML.Specification.InnerProductLayerParams.weights) - + ::CoreML::Specification::WeightParams* temp = weights_; weights_ = NULL; return temp; @@ -33452,9 +33455,9 @@ inline void InnerProductLayerParams::set_allocated_weights(::CoreML::Specificati delete weights_; weights_ = weights; if (weights) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.InnerProductLayerParams.weights) } @@ -33473,7 +33476,7 @@ inline const ::CoreML::Specification::WeightParams& InnerProductLayerParams::bia : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* InnerProductLayerParams::mutable_bias() { - + if (bias_ == NULL) { bias_ = new ::CoreML::Specification::WeightParams; } @@ -33482,7 +33485,7 @@ inline ::CoreML::Specification::WeightParams* InnerProductLayerParams::mutable_b } inline ::CoreML::Specification::WeightParams* InnerProductLayerParams::release_bias() { // @@protoc_insertion_point(field_release:CoreML.Specification.InnerProductLayerParams.bias) - + ::CoreML::Specification::WeightParams* temp = bias_; bias_ = NULL; return temp; @@ -33491,9 +33494,9 @@ inline void InnerProductLayerParams::set_allocated_bias(::CoreML::Specification: delete bias_; bias_ = bias; if (bias) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.InnerProductLayerParams.bias) } @@ -33507,7 +33510,7 @@ inline bool InnerProductLayerParams::int8dynamicquantize() const { return int8dynamicquantize_; } inline void InnerProductLayerParams::set_int8dynamicquantize(bool value) { - + int8dynamicquantize_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.InnerProductLayerParams.int8DynamicQuantize) } @@ -33525,7 +33528,7 @@ inline ::google::protobuf::uint64 EmbeddingLayerParams::inputdim() const { return inputdim_; } inline void EmbeddingLayerParams::set_inputdim(::google::protobuf::uint64 value) { - + inputdim_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.EmbeddingLayerParams.inputDim) } @@ -33539,7 +33542,7 @@ inline ::google::protobuf::uint64 EmbeddingLayerParams::outputchannels() const { return outputchannels_; } inline void EmbeddingLayerParams::set_outputchannels(::google::protobuf::uint64 value) { - + outputchannels_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.EmbeddingLayerParams.outputChannels) } @@ -33553,7 +33556,7 @@ inline bool EmbeddingLayerParams::hasbias() const { return hasbias_; } inline void EmbeddingLayerParams::set_hasbias(bool value) { - + hasbias_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.EmbeddingLayerParams.hasBias) } @@ -33572,7 +33575,7 @@ inline const ::CoreML::Specification::WeightParams& EmbeddingLayerParams::weight : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* EmbeddingLayerParams::mutable_weights() { - + if (weights_ == NULL) { weights_ = new ::CoreML::Specification::WeightParams; } @@ -33581,7 +33584,7 @@ inline ::CoreML::Specification::WeightParams* EmbeddingLayerParams::mutable_weig } inline ::CoreML::Specification::WeightParams* EmbeddingLayerParams::release_weights() { // @@protoc_insertion_point(field_release:CoreML.Specification.EmbeddingLayerParams.weights) - + ::CoreML::Specification::WeightParams* temp = weights_; weights_ = NULL; return temp; @@ -33590,9 +33593,9 @@ inline void EmbeddingLayerParams::set_allocated_weights(::CoreML::Specification: delete weights_; weights_ = weights; if (weights) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.EmbeddingLayerParams.weights) } @@ -33611,7 +33614,7 @@ inline const ::CoreML::Specification::WeightParams& EmbeddingLayerParams::bias() : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* EmbeddingLayerParams::mutable_bias() { - + if (bias_ == NULL) { bias_ = new ::CoreML::Specification::WeightParams; } @@ -33620,7 +33623,7 @@ inline ::CoreML::Specification::WeightParams* EmbeddingLayerParams::mutable_bias } inline ::CoreML::Specification::WeightParams* EmbeddingLayerParams::release_bias() { // @@protoc_insertion_point(field_release:CoreML.Specification.EmbeddingLayerParams.bias) - + ::CoreML::Specification::WeightParams* temp = bias_; bias_ = NULL; return temp; @@ -33629,9 +33632,9 @@ inline void EmbeddingLayerParams::set_allocated_bias(::CoreML::Specification::We delete bias_; bias_ = bias; if (bias) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.EmbeddingLayerParams.bias) } @@ -33649,7 +33652,7 @@ inline ::google::protobuf::uint64 EmbeddingNDLayerParams::vocabsize() const { return vocabsize_; } inline void EmbeddingNDLayerParams::set_vocabsize(::google::protobuf::uint64 value) { - + vocabsize_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.EmbeddingNDLayerParams.vocabSize) } @@ -33663,7 +33666,7 @@ inline ::google::protobuf::uint64 EmbeddingNDLayerParams::embeddingsize() const return embeddingsize_; } inline void EmbeddingNDLayerParams::set_embeddingsize(::google::protobuf::uint64 value) { - + embeddingsize_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.EmbeddingNDLayerParams.embeddingSize) } @@ -33677,7 +33680,7 @@ inline bool EmbeddingNDLayerParams::hasbias() const { return hasbias_; } inline void EmbeddingNDLayerParams::set_hasbias(bool value) { - + hasbias_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.EmbeddingNDLayerParams.hasBias) } @@ -33696,7 +33699,7 @@ inline const ::CoreML::Specification::WeightParams& EmbeddingNDLayerParams::weig : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* EmbeddingNDLayerParams::mutable_weights() { - + if (weights_ == NULL) { weights_ = new ::CoreML::Specification::WeightParams; } @@ -33705,7 +33708,7 @@ inline ::CoreML::Specification::WeightParams* EmbeddingNDLayerParams::mutable_we } inline ::CoreML::Specification::WeightParams* EmbeddingNDLayerParams::release_weights() { // @@protoc_insertion_point(field_release:CoreML.Specification.EmbeddingNDLayerParams.weights) - + ::CoreML::Specification::WeightParams* temp = weights_; weights_ = NULL; return temp; @@ -33714,9 +33717,9 @@ inline void EmbeddingNDLayerParams::set_allocated_weights(::CoreML::Specificatio delete weights_; weights_ = weights; if (weights) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.EmbeddingNDLayerParams.weights) } @@ -33735,7 +33738,7 @@ inline const ::CoreML::Specification::WeightParams& EmbeddingNDLayerParams::bias : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* EmbeddingNDLayerParams::mutable_bias() { - + if (bias_ == NULL) { bias_ = new ::CoreML::Specification::WeightParams; } @@ -33744,7 +33747,7 @@ inline ::CoreML::Specification::WeightParams* EmbeddingNDLayerParams::mutable_bi } inline ::CoreML::Specification::WeightParams* EmbeddingNDLayerParams::release_bias() { // @@protoc_insertion_point(field_release:CoreML.Specification.EmbeddingNDLayerParams.bias) - + ::CoreML::Specification::WeightParams* temp = bias_; bias_ = NULL; return temp; @@ -33753,9 +33756,9 @@ inline void EmbeddingNDLayerParams::set_allocated_bias(::CoreML::Specification:: delete bias_; bias_ = bias; if (bias) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.EmbeddingNDLayerParams.bias) } @@ -33773,7 +33776,7 @@ inline ::google::protobuf::uint64 BatchnormLayerParams::channels() const { return channels_; } inline void BatchnormLayerParams::set_channels(::google::protobuf::uint64 value) { - + channels_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.BatchnormLayerParams.channels) } @@ -33787,7 +33790,7 @@ inline bool BatchnormLayerParams::computemeanvar() const { return computemeanvar_; } inline void BatchnormLayerParams::set_computemeanvar(bool value) { - + computemeanvar_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.BatchnormLayerParams.computeMeanVar) } @@ -33801,7 +33804,7 @@ inline bool BatchnormLayerParams::instancenormalization() const { return instancenormalization_; } inline void BatchnormLayerParams::set_instancenormalization(bool value) { - + instancenormalization_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.BatchnormLayerParams.instanceNormalization) } @@ -33815,7 +33818,7 @@ inline float BatchnormLayerParams::epsilon() const { return epsilon_; } inline void BatchnormLayerParams::set_epsilon(float value) { - + epsilon_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.BatchnormLayerParams.epsilon) } @@ -33834,7 +33837,7 @@ inline const ::CoreML::Specification::WeightParams& BatchnormLayerParams::gamma( : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* BatchnormLayerParams::mutable_gamma() { - + if (gamma_ == NULL) { gamma_ = new ::CoreML::Specification::WeightParams; } @@ -33843,7 +33846,7 @@ inline ::CoreML::Specification::WeightParams* BatchnormLayerParams::mutable_gamm } inline ::CoreML::Specification::WeightParams* BatchnormLayerParams::release_gamma() { // @@protoc_insertion_point(field_release:CoreML.Specification.BatchnormLayerParams.gamma) - + ::CoreML::Specification::WeightParams* temp = gamma_; gamma_ = NULL; return temp; @@ -33852,9 +33855,9 @@ inline void BatchnormLayerParams::set_allocated_gamma(::CoreML::Specification::W delete gamma_; gamma_ = gamma; if (gamma) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.BatchnormLayerParams.gamma) } @@ -33873,7 +33876,7 @@ inline const ::CoreML::Specification::WeightParams& BatchnormLayerParams::beta() : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* BatchnormLayerParams::mutable_beta() { - + if (beta_ == NULL) { beta_ = new ::CoreML::Specification::WeightParams; } @@ -33882,7 +33885,7 @@ inline ::CoreML::Specification::WeightParams* BatchnormLayerParams::mutable_beta } inline ::CoreML::Specification::WeightParams* BatchnormLayerParams::release_beta() { // @@protoc_insertion_point(field_release:CoreML.Specification.BatchnormLayerParams.beta) - + ::CoreML::Specification::WeightParams* temp = beta_; beta_ = NULL; return temp; @@ -33891,9 +33894,9 @@ inline void BatchnormLayerParams::set_allocated_beta(::CoreML::Specification::We delete beta_; beta_ = beta; if (beta) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.BatchnormLayerParams.beta) } @@ -33912,7 +33915,7 @@ inline const ::CoreML::Specification::WeightParams& BatchnormLayerParams::mean() : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* BatchnormLayerParams::mutable_mean() { - + if (mean_ == NULL) { mean_ = new ::CoreML::Specification::WeightParams; } @@ -33921,7 +33924,7 @@ inline ::CoreML::Specification::WeightParams* BatchnormLayerParams::mutable_mean } inline ::CoreML::Specification::WeightParams* BatchnormLayerParams::release_mean() { // @@protoc_insertion_point(field_release:CoreML.Specification.BatchnormLayerParams.mean) - + ::CoreML::Specification::WeightParams* temp = mean_; mean_ = NULL; return temp; @@ -33930,9 +33933,9 @@ inline void BatchnormLayerParams::set_allocated_mean(::CoreML::Specification::We delete mean_; mean_ = mean; if (mean) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.BatchnormLayerParams.mean) } @@ -33951,7 +33954,7 @@ inline const ::CoreML::Specification::WeightParams& BatchnormLayerParams::varian : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* BatchnormLayerParams::mutable_variance() { - + if (variance_ == NULL) { variance_ = new ::CoreML::Specification::WeightParams; } @@ -33960,7 +33963,7 @@ inline ::CoreML::Specification::WeightParams* BatchnormLayerParams::mutable_vari } inline ::CoreML::Specification::WeightParams* BatchnormLayerParams::release_variance() { // @@protoc_insertion_point(field_release:CoreML.Specification.BatchnormLayerParams.variance) - + ::CoreML::Specification::WeightParams* temp = variance_; variance_ = NULL; return temp; @@ -33969,9 +33972,9 @@ inline void BatchnormLayerParams::set_allocated_variance(::CoreML::Specification delete variance_; variance_ = variance; if (variance) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.BatchnormLayerParams.variance) } @@ -34023,7 +34026,7 @@ inline ::CoreML::Specification::PoolingLayerParams_PoolingType PoolingLayerParam return static_cast< ::CoreML::Specification::PoolingLayerParams_PoolingType >(type_); } inline void PoolingLayerParams::set_type(::CoreML::Specification::PoolingLayerParams_PoolingType value) { - + type_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.PoolingLayerParams.type) } @@ -34241,7 +34244,7 @@ inline bool PoolingLayerParams::avgpoolexcludepadding() const { return avgpoolexcludepadding_; } inline void PoolingLayerParams::set_avgpoolexcludepadding(bool value) { - + avgpoolexcludepadding_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.PoolingLayerParams.avgPoolExcludePadding) } @@ -34255,7 +34258,7 @@ inline bool PoolingLayerParams::globalpooling() const { return globalpooling_; } inline void PoolingLayerParams::set_globalpooling(bool value) { - + globalpooling_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.PoolingLayerParams.globalPooling) } @@ -34282,7 +34285,7 @@ inline ::CoreML::Specification::Pooling3DLayerParams_PoolingType3D Pooling3DLaye return static_cast< ::CoreML::Specification::Pooling3DLayerParams_PoolingType3D >(type_); } inline void Pooling3DLayerParams::set_type(::CoreML::Specification::Pooling3DLayerParams_PoolingType3D value) { - + type_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Pooling3DLayerParams.type) } @@ -34296,7 +34299,7 @@ inline ::google::protobuf::int32 Pooling3DLayerParams::kerneldepth() const { return kerneldepth_; } inline void Pooling3DLayerParams::set_kerneldepth(::google::protobuf::int32 value) { - + kerneldepth_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Pooling3DLayerParams.kernelDepth) } @@ -34310,7 +34313,7 @@ inline ::google::protobuf::int32 Pooling3DLayerParams::kernelheight() const { return kernelheight_; } inline void Pooling3DLayerParams::set_kernelheight(::google::protobuf::int32 value) { - + kernelheight_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Pooling3DLayerParams.kernelHeight) } @@ -34324,7 +34327,7 @@ inline ::google::protobuf::int32 Pooling3DLayerParams::kernelwidth() const { return kernelwidth_; } inline void Pooling3DLayerParams::set_kernelwidth(::google::protobuf::int32 value) { - + kernelwidth_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Pooling3DLayerParams.kernelWidth) } @@ -34338,7 +34341,7 @@ inline ::google::protobuf::int32 Pooling3DLayerParams::stridedepth() const { return stridedepth_; } inline void Pooling3DLayerParams::set_stridedepth(::google::protobuf::int32 value) { - + stridedepth_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Pooling3DLayerParams.strideDepth) } @@ -34352,7 +34355,7 @@ inline ::google::protobuf::int32 Pooling3DLayerParams::strideheight() const { return strideheight_; } inline void Pooling3DLayerParams::set_strideheight(::google::protobuf::int32 value) { - + strideheight_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Pooling3DLayerParams.strideHeight) } @@ -34366,7 +34369,7 @@ inline ::google::protobuf::int32 Pooling3DLayerParams::stridewidth() const { return stridewidth_; } inline void Pooling3DLayerParams::set_stridewidth(::google::protobuf::int32 value) { - + stridewidth_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Pooling3DLayerParams.strideWidth) } @@ -34380,7 +34383,7 @@ inline ::CoreML::Specification::Pooling3DLayerParams_Pooling3DPaddingType Poolin return static_cast< ::CoreML::Specification::Pooling3DLayerParams_Pooling3DPaddingType >(paddingtype_); } inline void Pooling3DLayerParams::set_paddingtype(::CoreML::Specification::Pooling3DLayerParams_Pooling3DPaddingType value) { - + paddingtype_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Pooling3DLayerParams.paddingType) } @@ -34394,7 +34397,7 @@ inline ::google::protobuf::int32 Pooling3DLayerParams::custompaddingfront() cons return custompaddingfront_; } inline void Pooling3DLayerParams::set_custompaddingfront(::google::protobuf::int32 value) { - + custompaddingfront_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Pooling3DLayerParams.customPaddingFront) } @@ -34408,7 +34411,7 @@ inline ::google::protobuf::int32 Pooling3DLayerParams::custompaddingback() const return custompaddingback_; } inline void Pooling3DLayerParams::set_custompaddingback(::google::protobuf::int32 value) { - + custompaddingback_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Pooling3DLayerParams.customPaddingBack) } @@ -34422,7 +34425,7 @@ inline ::google::protobuf::int32 Pooling3DLayerParams::custompaddingtop() const return custompaddingtop_; } inline void Pooling3DLayerParams::set_custompaddingtop(::google::protobuf::int32 value) { - + custompaddingtop_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Pooling3DLayerParams.customPaddingTop) } @@ -34436,7 +34439,7 @@ inline ::google::protobuf::int32 Pooling3DLayerParams::custompaddingbottom() con return custompaddingbottom_; } inline void Pooling3DLayerParams::set_custompaddingbottom(::google::protobuf::int32 value) { - + custompaddingbottom_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Pooling3DLayerParams.customPaddingBottom) } @@ -34450,7 +34453,7 @@ inline ::google::protobuf::int32 Pooling3DLayerParams::custompaddingleft() const return custompaddingleft_; } inline void Pooling3DLayerParams::set_custompaddingleft(::google::protobuf::int32 value) { - + custompaddingleft_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Pooling3DLayerParams.customPaddingLeft) } @@ -34464,7 +34467,7 @@ inline ::google::protobuf::int32 Pooling3DLayerParams::custompaddingright() cons return custompaddingright_; } inline void Pooling3DLayerParams::set_custompaddingright(::google::protobuf::int32 value) { - + custompaddingright_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Pooling3DLayerParams.customPaddingRight) } @@ -34478,7 +34481,7 @@ inline bool Pooling3DLayerParams::countexcludepadding() const { return countexcludepadding_; } inline void Pooling3DLayerParams::set_countexcludepadding(bool value) { - + countexcludepadding_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Pooling3DLayerParams.countExcludePadding) } @@ -34496,7 +34499,7 @@ inline ::CoreML::Specification::GlobalPooling3DLayerParams_GlobalPoolingType3D G return static_cast< ::CoreML::Specification::GlobalPooling3DLayerParams_GlobalPoolingType3D >(type_); } inline void GlobalPooling3DLayerParams::set_type(::CoreML::Specification::GlobalPooling3DLayerParams_GlobalPoolingType3D value) { - + type_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.GlobalPooling3DLayerParams.type) } @@ -34514,7 +34517,7 @@ inline float PaddingLayerParams_PaddingConstant::value() const { return value_; } inline void PaddingLayerParams_PaddingConstant::set_value(float value) { - + value_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.PaddingLayerParams.PaddingConstant.value) } @@ -34689,7 +34692,7 @@ inline const ::CoreML::Specification::BorderAmounts& PaddingLayerParams::padding : *::CoreML::Specification::BorderAmounts::internal_default_instance(); } inline ::CoreML::Specification::BorderAmounts* PaddingLayerParams::mutable_paddingamounts() { - + if (paddingamounts_ == NULL) { paddingamounts_ = new ::CoreML::Specification::BorderAmounts; } @@ -34698,7 +34701,7 @@ inline ::CoreML::Specification::BorderAmounts* PaddingLayerParams::mutable_paddi } inline ::CoreML::Specification::BorderAmounts* PaddingLayerParams::release_paddingamounts() { // @@protoc_insertion_point(field_release:CoreML.Specification.PaddingLayerParams.paddingAmounts) - + ::CoreML::Specification::BorderAmounts* temp = paddingamounts_; paddingamounts_ = NULL; return temp; @@ -34707,9 +34710,9 @@ inline void PaddingLayerParams::set_allocated_paddingamounts(::CoreML::Specifica delete paddingamounts_; paddingamounts_ = paddingamounts; if (paddingamounts) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.PaddingLayerParams.paddingAmounts) } @@ -34736,7 +34739,7 @@ inline bool ConcatLayerParams::sequenceconcat() const { return sequenceconcat_; } inline void ConcatLayerParams::set_sequenceconcat(bool value) { - + sequenceconcat_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ConcatLayerParams.sequenceConcat) } @@ -34754,7 +34757,7 @@ inline float LRNLayerParams::alpha() const { return alpha_; } inline void LRNLayerParams::set_alpha(float value) { - + alpha_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.LRNLayerParams.alpha) } @@ -34768,7 +34771,7 @@ inline float LRNLayerParams::beta() const { return beta_; } inline void LRNLayerParams::set_beta(float value) { - + beta_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.LRNLayerParams.beta) } @@ -34782,7 +34785,7 @@ inline ::google::protobuf::uint64 LRNLayerParams::localsize() const { return localsize_; } inline void LRNLayerParams::set_localsize(::google::protobuf::uint64 value) { - + localsize_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.LRNLayerParams.localSize) } @@ -34796,7 +34799,7 @@ inline float LRNLayerParams::k() const { return k_; } inline void LRNLayerParams::set_k(float value) { - + k_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.LRNLayerParams.k) } @@ -34818,7 +34821,7 @@ inline ::google::protobuf::uint64 SplitLayerParams::noutputs() const { return noutputs_; } inline void SplitLayerParams::set_noutputs(::google::protobuf::uint64 value) { - + noutputs_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SplitLayerParams.nOutputs) } @@ -34836,7 +34839,7 @@ inline float AddLayerParams::alpha() const { return alpha_; } inline void AddLayerParams::set_alpha(float value) { - + alpha_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.AddLayerParams.alpha) } @@ -34854,7 +34857,7 @@ inline float MultiplyLayerParams::alpha() const { return alpha_; } inline void MultiplyLayerParams::set_alpha(float value) { - + alpha_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.MultiplyLayerParams.alpha) } @@ -34872,7 +34875,7 @@ inline ::CoreML::Specification::UnaryFunctionLayerParams_Operation UnaryFunction return static_cast< ::CoreML::Specification::UnaryFunctionLayerParams_Operation >(type_); } inline void UnaryFunctionLayerParams::set_type(::CoreML::Specification::UnaryFunctionLayerParams_Operation value) { - + type_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.UnaryFunctionLayerParams.type) } @@ -34886,7 +34889,7 @@ inline float UnaryFunctionLayerParams::alpha() const { return alpha_; } inline void UnaryFunctionLayerParams::set_alpha(float value) { - + alpha_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.UnaryFunctionLayerParams.alpha) } @@ -34900,7 +34903,7 @@ inline float UnaryFunctionLayerParams::epsilon() const { return epsilon_; } inline void UnaryFunctionLayerParams::set_epsilon(float value) { - + epsilon_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.UnaryFunctionLayerParams.epsilon) } @@ -34914,7 +34917,7 @@ inline float UnaryFunctionLayerParams::shift() const { return shift_; } inline void UnaryFunctionLayerParams::set_shift(float value) { - + shift_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.UnaryFunctionLayerParams.shift) } @@ -34928,7 +34931,7 @@ inline float UnaryFunctionLayerParams::scale() const { return scale_; } inline void UnaryFunctionLayerParams::set_scale(float value) { - + scale_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.UnaryFunctionLayerParams.scale) } @@ -35006,7 +35009,7 @@ inline ::CoreML::Specification::UpsampleLayerParams_InterpolationMode UpsampleLa return static_cast< ::CoreML::Specification::UpsampleLayerParams_InterpolationMode >(mode_); } inline void UpsampleLayerParams::set_mode(::CoreML::Specification::UpsampleLayerParams_InterpolationMode value) { - + mode_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.UpsampleLayerParams.mode) } @@ -35020,7 +35023,7 @@ inline ::CoreML::Specification::UpsampleLayerParams_LinearUpsampleMode UpsampleL return static_cast< ::CoreML::Specification::UpsampleLayerParams_LinearUpsampleMode >(linearupsamplemode_); } inline void UpsampleLayerParams::set_linearupsamplemode(::CoreML::Specification::UpsampleLayerParams_LinearUpsampleMode value) { - + linearupsamplemode_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.UpsampleLayerParams.linearUpsampleMode) } @@ -35073,7 +35076,7 @@ inline const ::CoreML::Specification::SamplingMode& ResizeBilinearLayerParams::m : *::CoreML::Specification::SamplingMode::internal_default_instance(); } inline ::CoreML::Specification::SamplingMode* ResizeBilinearLayerParams::mutable_mode() { - + if (mode_ == NULL) { mode_ = new ::CoreML::Specification::SamplingMode; } @@ -35082,7 +35085,7 @@ inline ::CoreML::Specification::SamplingMode* ResizeBilinearLayerParams::mutable } inline ::CoreML::Specification::SamplingMode* ResizeBilinearLayerParams::release_mode() { // @@protoc_insertion_point(field_release:CoreML.Specification.ResizeBilinearLayerParams.mode) - + ::CoreML::Specification::SamplingMode* temp = mode_; mode_ = NULL; return temp; @@ -35091,9 +35094,9 @@ inline void ResizeBilinearLayerParams::set_allocated_mode(::CoreML::Specificatio delete mode_; mode_ = mode; if (mode) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ResizeBilinearLayerParams.mode) } @@ -35141,7 +35144,7 @@ inline bool CropResizeLayerParams::normalizedcoordinates() const { return normalizedcoordinates_; } inline void CropResizeLayerParams::set_normalizedcoordinates(bool value) { - + normalizedcoordinates_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.CropResizeLayerParams.normalizedCoordinates) } @@ -35160,7 +35163,7 @@ inline const ::CoreML::Specification::SamplingMode& CropResizeLayerParams::mode( : *::CoreML::Specification::SamplingMode::internal_default_instance(); } inline ::CoreML::Specification::SamplingMode* CropResizeLayerParams::mutable_mode() { - + if (mode_ == NULL) { mode_ = new ::CoreML::Specification::SamplingMode; } @@ -35169,7 +35172,7 @@ inline ::CoreML::Specification::SamplingMode* CropResizeLayerParams::mutable_mod } inline ::CoreML::Specification::SamplingMode* CropResizeLayerParams::release_mode() { // @@protoc_insertion_point(field_release:CoreML.Specification.CropResizeLayerParams.mode) - + ::CoreML::Specification::SamplingMode* temp = mode_; mode_ = NULL; return temp; @@ -35178,9 +35181,9 @@ inline void CropResizeLayerParams::set_allocated_mode(::CoreML::Specification::S delete mode_; mode_ = mode; if (mode) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.CropResizeLayerParams.mode) } @@ -35199,7 +35202,7 @@ inline const ::CoreML::Specification::BoxCoordinatesMode& CropResizeLayerParams: : *::CoreML::Specification::BoxCoordinatesMode::internal_default_instance(); } inline ::CoreML::Specification::BoxCoordinatesMode* CropResizeLayerParams::mutable_boxindicesmode() { - + if (boxindicesmode_ == NULL) { boxindicesmode_ = new ::CoreML::Specification::BoxCoordinatesMode; } @@ -35208,7 +35211,7 @@ inline ::CoreML::Specification::BoxCoordinatesMode* CropResizeLayerParams::mutab } inline ::CoreML::Specification::BoxCoordinatesMode* CropResizeLayerParams::release_boxindicesmode() { // @@protoc_insertion_point(field_release:CoreML.Specification.CropResizeLayerParams.boxIndicesMode) - + ::CoreML::Specification::BoxCoordinatesMode* temp = boxindicesmode_; boxindicesmode_ = NULL; return temp; @@ -35217,9 +35220,9 @@ inline void CropResizeLayerParams::set_allocated_boxindicesmode(::CoreML::Specif delete boxindicesmode_; boxindicesmode_ = boxindicesmode; if (boxindicesmode) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.CropResizeLayerParams.boxIndicesMode) } @@ -35233,7 +35236,7 @@ inline float CropResizeLayerParams::spatialscale() const { return spatialscale_; } inline void CropResizeLayerParams::set_spatialscale(float value) { - + spatialscale_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.CropResizeLayerParams.spatialScale) } @@ -35286,7 +35289,7 @@ inline const ::CoreML::Specification::WeightParams& BiasLayerParams::bias() cons : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* BiasLayerParams::mutable_bias() { - + if (bias_ == NULL) { bias_ = new ::CoreML::Specification::WeightParams; } @@ -35295,7 +35298,7 @@ inline ::CoreML::Specification::WeightParams* BiasLayerParams::mutable_bias() { } inline ::CoreML::Specification::WeightParams* BiasLayerParams::release_bias() { // @@protoc_insertion_point(field_release:CoreML.Specification.BiasLayerParams.bias) - + ::CoreML::Specification::WeightParams* temp = bias_; bias_ = NULL; return temp; @@ -35304,9 +35307,9 @@ inline void BiasLayerParams::set_allocated_bias(::CoreML::Specification::WeightP delete bias_; bias_ = bias; if (bias) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.BiasLayerParams.bias) } @@ -35359,7 +35362,7 @@ inline const ::CoreML::Specification::WeightParams& ScaleLayerParams::scale() co : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* ScaleLayerParams::mutable_scale() { - + if (scale_ == NULL) { scale_ = new ::CoreML::Specification::WeightParams; } @@ -35368,7 +35371,7 @@ inline ::CoreML::Specification::WeightParams* ScaleLayerParams::mutable_scale() } inline ::CoreML::Specification::WeightParams* ScaleLayerParams::release_scale() { // @@protoc_insertion_point(field_release:CoreML.Specification.ScaleLayerParams.scale) - + ::CoreML::Specification::WeightParams* temp = scale_; scale_ = NULL; return temp; @@ -35377,9 +35380,9 @@ inline void ScaleLayerParams::set_allocated_scale(::CoreML::Specification::Weigh delete scale_; scale_ = scale; if (scale) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ScaleLayerParams.scale) } @@ -35393,7 +35396,7 @@ inline bool ScaleLayerParams::hasbias() const { return hasbias_; } inline void ScaleLayerParams::set_hasbias(bool value) { - + hasbias_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ScaleLayerParams.hasBias) } @@ -35442,7 +35445,7 @@ inline const ::CoreML::Specification::WeightParams& ScaleLayerParams::bias() con : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* ScaleLayerParams::mutable_bias() { - + if (bias_ == NULL) { bias_ = new ::CoreML::Specification::WeightParams; } @@ -35451,7 +35454,7 @@ inline ::CoreML::Specification::WeightParams* ScaleLayerParams::mutable_bias() { } inline ::CoreML::Specification::WeightParams* ScaleLayerParams::release_bias() { // @@protoc_insertion_point(field_release:CoreML.Specification.ScaleLayerParams.bias) - + ::CoreML::Specification::WeightParams* temp = bias_; bias_ = NULL; return temp; @@ -35460,9 +35463,9 @@ inline void ScaleLayerParams::set_allocated_bias(::CoreML::Specification::Weight delete bias_; bias_ = bias; if (bias) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.ScaleLayerParams.bias) } @@ -35515,7 +35518,7 @@ inline const ::CoreML::Specification::WeightParams& LoadConstantLayerParams::dat : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* LoadConstantLayerParams::mutable_data() { - + if (data_ == NULL) { data_ = new ::CoreML::Specification::WeightParams; } @@ -35524,7 +35527,7 @@ inline ::CoreML::Specification::WeightParams* LoadConstantLayerParams::mutable_d } inline ::CoreML::Specification::WeightParams* LoadConstantLayerParams::release_data() { // @@protoc_insertion_point(field_release:CoreML.Specification.LoadConstantLayerParams.data) - + ::CoreML::Specification::WeightParams* temp = data_; data_ = NULL; return temp; @@ -35533,9 +35536,9 @@ inline void LoadConstantLayerParams::set_allocated_data(::CoreML::Specification: delete data_; data_ = data; if (data) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.LoadConstantLayerParams.data) } @@ -35553,7 +35556,7 @@ inline float L2NormalizeLayerParams::epsilon() const { return epsilon_; } inline void L2NormalizeLayerParams::set_epsilon(float value) { - + epsilon_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.L2NormalizeLayerParams.epsilon) } @@ -35571,7 +35574,7 @@ inline ::CoreML::Specification::FlattenLayerParams_FlattenOrder FlattenLayerPara return static_cast< ::CoreML::Specification::FlattenLayerParams_FlattenOrder >(mode_); } inline void FlattenLayerParams::set_mode(::CoreML::Specification::FlattenLayerParams_FlattenOrder value) { - + mode_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.FlattenLayerParams.mode) } @@ -35619,7 +35622,7 @@ inline ::CoreML::Specification::ReshapeLayerParams_ReshapeOrder ReshapeLayerPara return static_cast< ::CoreML::Specification::ReshapeLayerParams_ReshapeOrder >(mode_); } inline void ReshapeLayerParams::set_mode(::CoreML::Specification::ReshapeLayerParams_ReshapeOrder value) { - + mode_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ReshapeLayerParams.mode) } @@ -35671,7 +35674,7 @@ inline ::CoreML::Specification::ReorganizeDataLayerParams_ReorganizationType Reo return static_cast< ::CoreML::Specification::ReorganizeDataLayerParams_ReorganizationType >(mode_); } inline void ReorganizeDataLayerParams::set_mode(::CoreML::Specification::ReorganizeDataLayerParams_ReorganizationType value) { - + mode_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ReorganizeDataLayerParams.mode) } @@ -35685,7 +35688,7 @@ inline ::google::protobuf::uint64 ReorganizeDataLayerParams::blocksize() const { return blocksize_; } inline void ReorganizeDataLayerParams::set_blocksize(::google::protobuf::uint64 value) { - + blocksize_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ReorganizeDataLayerParams.blockSize) } @@ -35703,7 +35706,7 @@ inline ::google::protobuf::int64 SliceLayerParams::startindex() const { return startindex_; } inline void SliceLayerParams::set_startindex(::google::protobuf::int64 value) { - + startindex_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SliceLayerParams.startIndex) } @@ -35717,7 +35720,7 @@ inline ::google::protobuf::int64 SliceLayerParams::endindex() const { return endindex_; } inline void SliceLayerParams::set_endindex(::google::protobuf::int64 value) { - + endindex_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SliceLayerParams.endIndex) } @@ -35731,7 +35734,7 @@ inline ::google::protobuf::uint64 SliceLayerParams::stride() const { return stride_; } inline void SliceLayerParams::set_stride(::google::protobuf::uint64 value) { - + stride_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SliceLayerParams.stride) } @@ -35745,7 +35748,7 @@ inline ::CoreML::Specification::SliceLayerParams_SliceAxis SliceLayerParams::axi return static_cast< ::CoreML::Specification::SliceLayerParams_SliceAxis >(axis_); } inline void SliceLayerParams::set_axis(::CoreML::Specification::SliceLayerParams_SliceAxis value) { - + axis_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SliceLayerParams.axis) } @@ -35763,7 +35766,7 @@ inline ::CoreML::Specification::ReduceLayerParams_ReduceOperation ReduceLayerPar return static_cast< ::CoreML::Specification::ReduceLayerParams_ReduceOperation >(mode_); } inline void ReduceLayerParams::set_mode(::CoreML::Specification::ReduceLayerParams_ReduceOperation value) { - + mode_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ReduceLayerParams.mode) } @@ -35777,7 +35780,7 @@ inline float ReduceLayerParams::epsilon() const { return epsilon_; } inline void ReduceLayerParams::set_epsilon(float value) { - + epsilon_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ReduceLayerParams.epsilon) } @@ -35791,7 +35794,7 @@ inline ::CoreML::Specification::ReduceLayerParams_ReduceAxis ReduceLayerParams:: return static_cast< ::CoreML::Specification::ReduceLayerParams_ReduceAxis >(axis_); } inline void ReduceLayerParams::set_axis(::CoreML::Specification::ReduceLayerParams_ReduceAxis value) { - + axis_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ReduceLayerParams.axis) } @@ -35814,7 +35817,7 @@ inline const ::CoreML::Specification::BorderAmounts& CropLayerParams::cropamount : *::CoreML::Specification::BorderAmounts::internal_default_instance(); } inline ::CoreML::Specification::BorderAmounts* CropLayerParams::mutable_cropamounts() { - + if (cropamounts_ == NULL) { cropamounts_ = new ::CoreML::Specification::BorderAmounts; } @@ -35823,7 +35826,7 @@ inline ::CoreML::Specification::BorderAmounts* CropLayerParams::mutable_cropamou } inline ::CoreML::Specification::BorderAmounts* CropLayerParams::release_cropamounts() { // @@protoc_insertion_point(field_release:CoreML.Specification.CropLayerParams.cropAmounts) - + ::CoreML::Specification::BorderAmounts* temp = cropamounts_; cropamounts_ = NULL; return temp; @@ -35832,9 +35835,9 @@ inline void CropLayerParams::set_allocated_cropamounts(::CoreML::Specification:: delete cropamounts_; cropamounts_ = cropamounts; if (cropamounts) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.CropLayerParams.cropAmounts) } @@ -35894,7 +35897,7 @@ inline bool DotProductLayerParams::cosinesimilarity() const { return cosinesimilarity_; } inline void DotProductLayerParams::set_cosinesimilarity(bool value) { - + cosinesimilarity_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.DotProductLayerParams.cosineSimilarity) } @@ -35912,7 +35915,7 @@ inline bool MeanVarianceNormalizeLayerParams::acrosschannels() const { return acrosschannels_; } inline void MeanVarianceNormalizeLayerParams::set_acrosschannels(bool value) { - + acrosschannels_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.MeanVarianceNormalizeLayerParams.acrossChannels) } @@ -35926,7 +35929,7 @@ inline bool MeanVarianceNormalizeLayerParams::normalizevariance() const { return normalizevariance_; } inline void MeanVarianceNormalizeLayerParams::set_normalizevariance(bool value) { - + normalizevariance_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.MeanVarianceNormalizeLayerParams.normalizeVariance) } @@ -35940,7 +35943,7 @@ inline float MeanVarianceNormalizeLayerParams::epsilon() const { return epsilon_; } inline void MeanVarianceNormalizeLayerParams::set_epsilon(float value) { - + epsilon_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.MeanVarianceNormalizeLayerParams.epsilon) } @@ -35958,7 +35961,7 @@ inline ::google::protobuf::uint64 SequenceRepeatLayerParams::nrepetitions() cons return nrepetitions_; } inline void SequenceRepeatLayerParams::set_nrepetitions(::google::protobuf::uint64 value) { - + nrepetitions_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SequenceRepeatLayerParams.nRepetitions) } @@ -35976,7 +35979,7 @@ inline ::google::protobuf::uint64 SimpleRecurrentLayerParams::inputvectorsize() return inputvectorsize_; } inline void SimpleRecurrentLayerParams::set_inputvectorsize(::google::protobuf::uint64 value) { - + inputvectorsize_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SimpleRecurrentLayerParams.inputVectorSize) } @@ -35990,7 +35993,7 @@ inline ::google::protobuf::uint64 SimpleRecurrentLayerParams::outputvectorsize() return outputvectorsize_; } inline void SimpleRecurrentLayerParams::set_outputvectorsize(::google::protobuf::uint64 value) { - + outputvectorsize_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SimpleRecurrentLayerParams.outputVectorSize) } @@ -36009,7 +36012,7 @@ inline const ::CoreML::Specification::ActivationParams& SimpleRecurrentLayerPara : *::CoreML::Specification::ActivationParams::internal_default_instance(); } inline ::CoreML::Specification::ActivationParams* SimpleRecurrentLayerParams::mutable_activation() { - + if (activation_ == NULL) { activation_ = new ::CoreML::Specification::ActivationParams; } @@ -36018,7 +36021,7 @@ inline ::CoreML::Specification::ActivationParams* SimpleRecurrentLayerParams::mu } inline ::CoreML::Specification::ActivationParams* SimpleRecurrentLayerParams::release_activation() { // @@protoc_insertion_point(field_release:CoreML.Specification.SimpleRecurrentLayerParams.activation) - + ::CoreML::Specification::ActivationParams* temp = activation_; activation_ = NULL; return temp; @@ -36027,9 +36030,9 @@ inline void SimpleRecurrentLayerParams::set_allocated_activation(::CoreML::Speci delete activation_; activation_ = activation; if (activation) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.SimpleRecurrentLayerParams.activation) } @@ -36043,7 +36046,7 @@ inline bool SimpleRecurrentLayerParams::sequenceoutput() const { return sequenceoutput_; } inline void SimpleRecurrentLayerParams::set_sequenceoutput(bool value) { - + sequenceoutput_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SimpleRecurrentLayerParams.sequenceOutput) } @@ -36057,7 +36060,7 @@ inline bool SimpleRecurrentLayerParams::hasbiasvector() const { return hasbiasvector_; } inline void SimpleRecurrentLayerParams::set_hasbiasvector(bool value) { - + hasbiasvector_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SimpleRecurrentLayerParams.hasBiasVector) } @@ -36076,7 +36079,7 @@ inline const ::CoreML::Specification::WeightParams& SimpleRecurrentLayerParams:: : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* SimpleRecurrentLayerParams::mutable_weightmatrix() { - + if (weightmatrix_ == NULL) { weightmatrix_ = new ::CoreML::Specification::WeightParams; } @@ -36085,7 +36088,7 @@ inline ::CoreML::Specification::WeightParams* SimpleRecurrentLayerParams::mutabl } inline ::CoreML::Specification::WeightParams* SimpleRecurrentLayerParams::release_weightmatrix() { // @@protoc_insertion_point(field_release:CoreML.Specification.SimpleRecurrentLayerParams.weightMatrix) - + ::CoreML::Specification::WeightParams* temp = weightmatrix_; weightmatrix_ = NULL; return temp; @@ -36094,9 +36097,9 @@ inline void SimpleRecurrentLayerParams::set_allocated_weightmatrix(::CoreML::Spe delete weightmatrix_; weightmatrix_ = weightmatrix; if (weightmatrix) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.SimpleRecurrentLayerParams.weightMatrix) } @@ -36115,7 +36118,7 @@ inline const ::CoreML::Specification::WeightParams& SimpleRecurrentLayerParams:: : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* SimpleRecurrentLayerParams::mutable_recursionmatrix() { - + if (recursionmatrix_ == NULL) { recursionmatrix_ = new ::CoreML::Specification::WeightParams; } @@ -36124,7 +36127,7 @@ inline ::CoreML::Specification::WeightParams* SimpleRecurrentLayerParams::mutabl } inline ::CoreML::Specification::WeightParams* SimpleRecurrentLayerParams::release_recursionmatrix() { // @@protoc_insertion_point(field_release:CoreML.Specification.SimpleRecurrentLayerParams.recursionMatrix) - + ::CoreML::Specification::WeightParams* temp = recursionmatrix_; recursionmatrix_ = NULL; return temp; @@ -36133,9 +36136,9 @@ inline void SimpleRecurrentLayerParams::set_allocated_recursionmatrix(::CoreML:: delete recursionmatrix_; recursionmatrix_ = recursionmatrix; if (recursionmatrix) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.SimpleRecurrentLayerParams.recursionMatrix) } @@ -36154,7 +36157,7 @@ inline const ::CoreML::Specification::WeightParams& SimpleRecurrentLayerParams:: : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* SimpleRecurrentLayerParams::mutable_biasvector() { - + if (biasvector_ == NULL) { biasvector_ = new ::CoreML::Specification::WeightParams; } @@ -36163,7 +36166,7 @@ inline ::CoreML::Specification::WeightParams* SimpleRecurrentLayerParams::mutabl } inline ::CoreML::Specification::WeightParams* SimpleRecurrentLayerParams::release_biasvector() { // @@protoc_insertion_point(field_release:CoreML.Specification.SimpleRecurrentLayerParams.biasVector) - + ::CoreML::Specification::WeightParams* temp = biasvector_; biasvector_ = NULL; return temp; @@ -36172,9 +36175,9 @@ inline void SimpleRecurrentLayerParams::set_allocated_biasvector(::CoreML::Speci delete biasvector_; biasvector_ = biasvector; if (biasvector) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.SimpleRecurrentLayerParams.biasVector) } @@ -36188,7 +36191,7 @@ inline bool SimpleRecurrentLayerParams::reverseinput() const { return reverseinput_; } inline void SimpleRecurrentLayerParams::set_reverseinput(bool value) { - + reverseinput_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SimpleRecurrentLayerParams.reverseInput) } @@ -36206,7 +36209,7 @@ inline ::google::protobuf::uint64 GRULayerParams::inputvectorsize() const { return inputvectorsize_; } inline void GRULayerParams::set_inputvectorsize(::google::protobuf::uint64 value) { - + inputvectorsize_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.GRULayerParams.inputVectorSize) } @@ -36220,7 +36223,7 @@ inline ::google::protobuf::uint64 GRULayerParams::outputvectorsize() const { return outputvectorsize_; } inline void GRULayerParams::set_outputvectorsize(::google::protobuf::uint64 value) { - + outputvectorsize_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.GRULayerParams.outputVectorSize) } @@ -36264,7 +36267,7 @@ inline bool GRULayerParams::sequenceoutput() const { return sequenceoutput_; } inline void GRULayerParams::set_sequenceoutput(bool value) { - + sequenceoutput_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.GRULayerParams.sequenceOutput) } @@ -36278,7 +36281,7 @@ inline bool GRULayerParams::hasbiasvectors() const { return hasbiasvectors_; } inline void GRULayerParams::set_hasbiasvectors(bool value) { - + hasbiasvectors_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.GRULayerParams.hasBiasVectors) } @@ -36297,7 +36300,7 @@ inline const ::CoreML::Specification::WeightParams& GRULayerParams::updategatewe : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* GRULayerParams::mutable_updategateweightmatrix() { - + if (updategateweightmatrix_ == NULL) { updategateweightmatrix_ = new ::CoreML::Specification::WeightParams; } @@ -36306,7 +36309,7 @@ inline ::CoreML::Specification::WeightParams* GRULayerParams::mutable_updategate } inline ::CoreML::Specification::WeightParams* GRULayerParams::release_updategateweightmatrix() { // @@protoc_insertion_point(field_release:CoreML.Specification.GRULayerParams.updateGateWeightMatrix) - + ::CoreML::Specification::WeightParams* temp = updategateweightmatrix_; updategateweightmatrix_ = NULL; return temp; @@ -36315,9 +36318,9 @@ inline void GRULayerParams::set_allocated_updategateweightmatrix(::CoreML::Speci delete updategateweightmatrix_; updategateweightmatrix_ = updategateweightmatrix; if (updategateweightmatrix) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.GRULayerParams.updateGateWeightMatrix) } @@ -36336,7 +36339,7 @@ inline const ::CoreML::Specification::WeightParams& GRULayerParams::resetgatewei : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* GRULayerParams::mutable_resetgateweightmatrix() { - + if (resetgateweightmatrix_ == NULL) { resetgateweightmatrix_ = new ::CoreML::Specification::WeightParams; } @@ -36345,7 +36348,7 @@ inline ::CoreML::Specification::WeightParams* GRULayerParams::mutable_resetgatew } inline ::CoreML::Specification::WeightParams* GRULayerParams::release_resetgateweightmatrix() { // @@protoc_insertion_point(field_release:CoreML.Specification.GRULayerParams.resetGateWeightMatrix) - + ::CoreML::Specification::WeightParams* temp = resetgateweightmatrix_; resetgateweightmatrix_ = NULL; return temp; @@ -36354,9 +36357,9 @@ inline void GRULayerParams::set_allocated_resetgateweightmatrix(::CoreML::Specif delete resetgateweightmatrix_; resetgateweightmatrix_ = resetgateweightmatrix; if (resetgateweightmatrix) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.GRULayerParams.resetGateWeightMatrix) } @@ -36375,7 +36378,7 @@ inline const ::CoreML::Specification::WeightParams& GRULayerParams::outputgatewe : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* GRULayerParams::mutable_outputgateweightmatrix() { - + if (outputgateweightmatrix_ == NULL) { outputgateweightmatrix_ = new ::CoreML::Specification::WeightParams; } @@ -36384,7 +36387,7 @@ inline ::CoreML::Specification::WeightParams* GRULayerParams::mutable_outputgate } inline ::CoreML::Specification::WeightParams* GRULayerParams::release_outputgateweightmatrix() { // @@protoc_insertion_point(field_release:CoreML.Specification.GRULayerParams.outputGateWeightMatrix) - + ::CoreML::Specification::WeightParams* temp = outputgateweightmatrix_; outputgateweightmatrix_ = NULL; return temp; @@ -36393,9 +36396,9 @@ inline void GRULayerParams::set_allocated_outputgateweightmatrix(::CoreML::Speci delete outputgateweightmatrix_; outputgateweightmatrix_ = outputgateweightmatrix; if (outputgateweightmatrix) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.GRULayerParams.outputGateWeightMatrix) } @@ -36414,7 +36417,7 @@ inline const ::CoreML::Specification::WeightParams& GRULayerParams::updategatere : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* GRULayerParams::mutable_updategaterecursionmatrix() { - + if (updategaterecursionmatrix_ == NULL) { updategaterecursionmatrix_ = new ::CoreML::Specification::WeightParams; } @@ -36423,7 +36426,7 @@ inline ::CoreML::Specification::WeightParams* GRULayerParams::mutable_updategate } inline ::CoreML::Specification::WeightParams* GRULayerParams::release_updategaterecursionmatrix() { // @@protoc_insertion_point(field_release:CoreML.Specification.GRULayerParams.updateGateRecursionMatrix) - + ::CoreML::Specification::WeightParams* temp = updategaterecursionmatrix_; updategaterecursionmatrix_ = NULL; return temp; @@ -36432,9 +36435,9 @@ inline void GRULayerParams::set_allocated_updategaterecursionmatrix(::CoreML::Sp delete updategaterecursionmatrix_; updategaterecursionmatrix_ = updategaterecursionmatrix; if (updategaterecursionmatrix) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.GRULayerParams.updateGateRecursionMatrix) } @@ -36453,7 +36456,7 @@ inline const ::CoreML::Specification::WeightParams& GRULayerParams::resetgaterec : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* GRULayerParams::mutable_resetgaterecursionmatrix() { - + if (resetgaterecursionmatrix_ == NULL) { resetgaterecursionmatrix_ = new ::CoreML::Specification::WeightParams; } @@ -36462,7 +36465,7 @@ inline ::CoreML::Specification::WeightParams* GRULayerParams::mutable_resetgater } inline ::CoreML::Specification::WeightParams* GRULayerParams::release_resetgaterecursionmatrix() { // @@protoc_insertion_point(field_release:CoreML.Specification.GRULayerParams.resetGateRecursionMatrix) - + ::CoreML::Specification::WeightParams* temp = resetgaterecursionmatrix_; resetgaterecursionmatrix_ = NULL; return temp; @@ -36471,9 +36474,9 @@ inline void GRULayerParams::set_allocated_resetgaterecursionmatrix(::CoreML::Spe delete resetgaterecursionmatrix_; resetgaterecursionmatrix_ = resetgaterecursionmatrix; if (resetgaterecursionmatrix) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.GRULayerParams.resetGateRecursionMatrix) } @@ -36492,7 +36495,7 @@ inline const ::CoreML::Specification::WeightParams& GRULayerParams::outputgatere : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* GRULayerParams::mutable_outputgaterecursionmatrix() { - + if (outputgaterecursionmatrix_ == NULL) { outputgaterecursionmatrix_ = new ::CoreML::Specification::WeightParams; } @@ -36501,7 +36504,7 @@ inline ::CoreML::Specification::WeightParams* GRULayerParams::mutable_outputgate } inline ::CoreML::Specification::WeightParams* GRULayerParams::release_outputgaterecursionmatrix() { // @@protoc_insertion_point(field_release:CoreML.Specification.GRULayerParams.outputGateRecursionMatrix) - + ::CoreML::Specification::WeightParams* temp = outputgaterecursionmatrix_; outputgaterecursionmatrix_ = NULL; return temp; @@ -36510,9 +36513,9 @@ inline void GRULayerParams::set_allocated_outputgaterecursionmatrix(::CoreML::Sp delete outputgaterecursionmatrix_; outputgaterecursionmatrix_ = outputgaterecursionmatrix; if (outputgaterecursionmatrix) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.GRULayerParams.outputGateRecursionMatrix) } @@ -36531,7 +36534,7 @@ inline const ::CoreML::Specification::WeightParams& GRULayerParams::updategatebi : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* GRULayerParams::mutable_updategatebiasvector() { - + if (updategatebiasvector_ == NULL) { updategatebiasvector_ = new ::CoreML::Specification::WeightParams; } @@ -36540,7 +36543,7 @@ inline ::CoreML::Specification::WeightParams* GRULayerParams::mutable_updategate } inline ::CoreML::Specification::WeightParams* GRULayerParams::release_updategatebiasvector() { // @@protoc_insertion_point(field_release:CoreML.Specification.GRULayerParams.updateGateBiasVector) - + ::CoreML::Specification::WeightParams* temp = updategatebiasvector_; updategatebiasvector_ = NULL; return temp; @@ -36549,9 +36552,9 @@ inline void GRULayerParams::set_allocated_updategatebiasvector(::CoreML::Specifi delete updategatebiasvector_; updategatebiasvector_ = updategatebiasvector; if (updategatebiasvector) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.GRULayerParams.updateGateBiasVector) } @@ -36570,7 +36573,7 @@ inline const ::CoreML::Specification::WeightParams& GRULayerParams::resetgatebia : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* GRULayerParams::mutable_resetgatebiasvector() { - + if (resetgatebiasvector_ == NULL) { resetgatebiasvector_ = new ::CoreML::Specification::WeightParams; } @@ -36579,7 +36582,7 @@ inline ::CoreML::Specification::WeightParams* GRULayerParams::mutable_resetgateb } inline ::CoreML::Specification::WeightParams* GRULayerParams::release_resetgatebiasvector() { // @@protoc_insertion_point(field_release:CoreML.Specification.GRULayerParams.resetGateBiasVector) - + ::CoreML::Specification::WeightParams* temp = resetgatebiasvector_; resetgatebiasvector_ = NULL; return temp; @@ -36588,9 +36591,9 @@ inline void GRULayerParams::set_allocated_resetgatebiasvector(::CoreML::Specific delete resetgatebiasvector_; resetgatebiasvector_ = resetgatebiasvector; if (resetgatebiasvector) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.GRULayerParams.resetGateBiasVector) } @@ -36609,7 +36612,7 @@ inline const ::CoreML::Specification::WeightParams& GRULayerParams::outputgatebi : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* GRULayerParams::mutable_outputgatebiasvector() { - + if (outputgatebiasvector_ == NULL) { outputgatebiasvector_ = new ::CoreML::Specification::WeightParams; } @@ -36618,7 +36621,7 @@ inline ::CoreML::Specification::WeightParams* GRULayerParams::mutable_outputgate } inline ::CoreML::Specification::WeightParams* GRULayerParams::release_outputgatebiasvector() { // @@protoc_insertion_point(field_release:CoreML.Specification.GRULayerParams.outputGateBiasVector) - + ::CoreML::Specification::WeightParams* temp = outputgatebiasvector_; outputgatebiasvector_ = NULL; return temp; @@ -36627,9 +36630,9 @@ inline void GRULayerParams::set_allocated_outputgatebiasvector(::CoreML::Specifi delete outputgatebiasvector_; outputgatebiasvector_ = outputgatebiasvector; if (outputgatebiasvector) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.GRULayerParams.outputGateBiasVector) } @@ -36643,7 +36646,7 @@ inline bool GRULayerParams::reverseinput() const { return reverseinput_; } inline void GRULayerParams::set_reverseinput(bool value) { - + reverseinput_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.GRULayerParams.reverseInput) } @@ -36661,7 +36664,7 @@ inline bool LSTMParams::sequenceoutput() const { return sequenceoutput_; } inline void LSTMParams::set_sequenceoutput(bool value) { - + sequenceoutput_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.LSTMParams.sequenceOutput) } @@ -36675,7 +36678,7 @@ inline bool LSTMParams::hasbiasvectors() const { return hasbiasvectors_; } inline void LSTMParams::set_hasbiasvectors(bool value) { - + hasbiasvectors_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.LSTMParams.hasBiasVectors) } @@ -36689,7 +36692,7 @@ inline bool LSTMParams::forgetbias() const { return forgetbias_; } inline void LSTMParams::set_forgetbias(bool value) { - + forgetbias_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.LSTMParams.forgetBias) } @@ -36703,7 +36706,7 @@ inline bool LSTMParams::haspeepholevectors() const { return haspeepholevectors_; } inline void LSTMParams::set_haspeepholevectors(bool value) { - + haspeepholevectors_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.LSTMParams.hasPeepholeVectors) } @@ -36717,7 +36720,7 @@ inline bool LSTMParams::coupledinputandforgetgate() const { return coupledinputandforgetgate_; } inline void LSTMParams::set_coupledinputandforgetgate(bool value) { - + coupledinputandforgetgate_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.LSTMParams.coupledInputAndForgetGate) } @@ -36731,7 +36734,7 @@ inline float LSTMParams::cellclipthreshold() const { return cellclipthreshold_; } inline void LSTMParams::set_cellclipthreshold(float value) { - + cellclipthreshold_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.LSTMParams.cellClipThreshold) } @@ -36754,7 +36757,7 @@ inline const ::CoreML::Specification::WeightParams& LSTMWeightParams::inputgatew : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_inputgateweightmatrix() { - + if (inputgateweightmatrix_ == NULL) { inputgateweightmatrix_ = new ::CoreML::Specification::WeightParams; } @@ -36763,7 +36766,7 @@ inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_inputgat } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::release_inputgateweightmatrix() { // @@protoc_insertion_point(field_release:CoreML.Specification.LSTMWeightParams.inputGateWeightMatrix) - + ::CoreML::Specification::WeightParams* temp = inputgateweightmatrix_; inputgateweightmatrix_ = NULL; return temp; @@ -36772,9 +36775,9 @@ inline void LSTMWeightParams::set_allocated_inputgateweightmatrix(::CoreML::Spec delete inputgateweightmatrix_; inputgateweightmatrix_ = inputgateweightmatrix; if (inputgateweightmatrix) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.LSTMWeightParams.inputGateWeightMatrix) } @@ -36793,7 +36796,7 @@ inline const ::CoreML::Specification::WeightParams& LSTMWeightParams::forgetgate : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_forgetgateweightmatrix() { - + if (forgetgateweightmatrix_ == NULL) { forgetgateweightmatrix_ = new ::CoreML::Specification::WeightParams; } @@ -36802,7 +36805,7 @@ inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_forgetga } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::release_forgetgateweightmatrix() { // @@protoc_insertion_point(field_release:CoreML.Specification.LSTMWeightParams.forgetGateWeightMatrix) - + ::CoreML::Specification::WeightParams* temp = forgetgateweightmatrix_; forgetgateweightmatrix_ = NULL; return temp; @@ -36811,9 +36814,9 @@ inline void LSTMWeightParams::set_allocated_forgetgateweightmatrix(::CoreML::Spe delete forgetgateweightmatrix_; forgetgateweightmatrix_ = forgetgateweightmatrix; if (forgetgateweightmatrix) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.LSTMWeightParams.forgetGateWeightMatrix) } @@ -36832,7 +36835,7 @@ inline const ::CoreML::Specification::WeightParams& LSTMWeightParams::blockinput : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_blockinputweightmatrix() { - + if (blockinputweightmatrix_ == NULL) { blockinputweightmatrix_ = new ::CoreML::Specification::WeightParams; } @@ -36841,7 +36844,7 @@ inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_blockinp } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::release_blockinputweightmatrix() { // @@protoc_insertion_point(field_release:CoreML.Specification.LSTMWeightParams.blockInputWeightMatrix) - + ::CoreML::Specification::WeightParams* temp = blockinputweightmatrix_; blockinputweightmatrix_ = NULL; return temp; @@ -36850,9 +36853,9 @@ inline void LSTMWeightParams::set_allocated_blockinputweightmatrix(::CoreML::Spe delete blockinputweightmatrix_; blockinputweightmatrix_ = blockinputweightmatrix; if (blockinputweightmatrix) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.LSTMWeightParams.blockInputWeightMatrix) } @@ -36871,7 +36874,7 @@ inline const ::CoreML::Specification::WeightParams& LSTMWeightParams::outputgate : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_outputgateweightmatrix() { - + if (outputgateweightmatrix_ == NULL) { outputgateweightmatrix_ = new ::CoreML::Specification::WeightParams; } @@ -36880,7 +36883,7 @@ inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_outputga } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::release_outputgateweightmatrix() { // @@protoc_insertion_point(field_release:CoreML.Specification.LSTMWeightParams.outputGateWeightMatrix) - + ::CoreML::Specification::WeightParams* temp = outputgateweightmatrix_; outputgateweightmatrix_ = NULL; return temp; @@ -36889,9 +36892,9 @@ inline void LSTMWeightParams::set_allocated_outputgateweightmatrix(::CoreML::Spe delete outputgateweightmatrix_; outputgateweightmatrix_ = outputgateweightmatrix; if (outputgateweightmatrix) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.LSTMWeightParams.outputGateWeightMatrix) } @@ -36910,7 +36913,7 @@ inline const ::CoreML::Specification::WeightParams& LSTMWeightParams::inputgater : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_inputgaterecursionmatrix() { - + if (inputgaterecursionmatrix_ == NULL) { inputgaterecursionmatrix_ = new ::CoreML::Specification::WeightParams; } @@ -36919,7 +36922,7 @@ inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_inputgat } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::release_inputgaterecursionmatrix() { // @@protoc_insertion_point(field_release:CoreML.Specification.LSTMWeightParams.inputGateRecursionMatrix) - + ::CoreML::Specification::WeightParams* temp = inputgaterecursionmatrix_; inputgaterecursionmatrix_ = NULL; return temp; @@ -36928,9 +36931,9 @@ inline void LSTMWeightParams::set_allocated_inputgaterecursionmatrix(::CoreML::S delete inputgaterecursionmatrix_; inputgaterecursionmatrix_ = inputgaterecursionmatrix; if (inputgaterecursionmatrix) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.LSTMWeightParams.inputGateRecursionMatrix) } @@ -36949,7 +36952,7 @@ inline const ::CoreML::Specification::WeightParams& LSTMWeightParams::forgetgate : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_forgetgaterecursionmatrix() { - + if (forgetgaterecursionmatrix_ == NULL) { forgetgaterecursionmatrix_ = new ::CoreML::Specification::WeightParams; } @@ -36958,7 +36961,7 @@ inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_forgetga } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::release_forgetgaterecursionmatrix() { // @@protoc_insertion_point(field_release:CoreML.Specification.LSTMWeightParams.forgetGateRecursionMatrix) - + ::CoreML::Specification::WeightParams* temp = forgetgaterecursionmatrix_; forgetgaterecursionmatrix_ = NULL; return temp; @@ -36967,9 +36970,9 @@ inline void LSTMWeightParams::set_allocated_forgetgaterecursionmatrix(::CoreML:: delete forgetgaterecursionmatrix_; forgetgaterecursionmatrix_ = forgetgaterecursionmatrix; if (forgetgaterecursionmatrix) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.LSTMWeightParams.forgetGateRecursionMatrix) } @@ -36988,7 +36991,7 @@ inline const ::CoreML::Specification::WeightParams& LSTMWeightParams::blockinput : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_blockinputrecursionmatrix() { - + if (blockinputrecursionmatrix_ == NULL) { blockinputrecursionmatrix_ = new ::CoreML::Specification::WeightParams; } @@ -36997,7 +37000,7 @@ inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_blockinp } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::release_blockinputrecursionmatrix() { // @@protoc_insertion_point(field_release:CoreML.Specification.LSTMWeightParams.blockInputRecursionMatrix) - + ::CoreML::Specification::WeightParams* temp = blockinputrecursionmatrix_; blockinputrecursionmatrix_ = NULL; return temp; @@ -37006,9 +37009,9 @@ inline void LSTMWeightParams::set_allocated_blockinputrecursionmatrix(::CoreML:: delete blockinputrecursionmatrix_; blockinputrecursionmatrix_ = blockinputrecursionmatrix; if (blockinputrecursionmatrix) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.LSTMWeightParams.blockInputRecursionMatrix) } @@ -37027,7 +37030,7 @@ inline const ::CoreML::Specification::WeightParams& LSTMWeightParams::outputgate : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_outputgaterecursionmatrix() { - + if (outputgaterecursionmatrix_ == NULL) { outputgaterecursionmatrix_ = new ::CoreML::Specification::WeightParams; } @@ -37036,7 +37039,7 @@ inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_outputga } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::release_outputgaterecursionmatrix() { // @@protoc_insertion_point(field_release:CoreML.Specification.LSTMWeightParams.outputGateRecursionMatrix) - + ::CoreML::Specification::WeightParams* temp = outputgaterecursionmatrix_; outputgaterecursionmatrix_ = NULL; return temp; @@ -37045,9 +37048,9 @@ inline void LSTMWeightParams::set_allocated_outputgaterecursionmatrix(::CoreML:: delete outputgaterecursionmatrix_; outputgaterecursionmatrix_ = outputgaterecursionmatrix; if (outputgaterecursionmatrix) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.LSTMWeightParams.outputGateRecursionMatrix) } @@ -37066,7 +37069,7 @@ inline const ::CoreML::Specification::WeightParams& LSTMWeightParams::inputgateb : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_inputgatebiasvector() { - + if (inputgatebiasvector_ == NULL) { inputgatebiasvector_ = new ::CoreML::Specification::WeightParams; } @@ -37075,7 +37078,7 @@ inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_inputgat } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::release_inputgatebiasvector() { // @@protoc_insertion_point(field_release:CoreML.Specification.LSTMWeightParams.inputGateBiasVector) - + ::CoreML::Specification::WeightParams* temp = inputgatebiasvector_; inputgatebiasvector_ = NULL; return temp; @@ -37084,9 +37087,9 @@ inline void LSTMWeightParams::set_allocated_inputgatebiasvector(::CoreML::Specif delete inputgatebiasvector_; inputgatebiasvector_ = inputgatebiasvector; if (inputgatebiasvector) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.LSTMWeightParams.inputGateBiasVector) } @@ -37105,7 +37108,7 @@ inline const ::CoreML::Specification::WeightParams& LSTMWeightParams::forgetgate : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_forgetgatebiasvector() { - + if (forgetgatebiasvector_ == NULL) { forgetgatebiasvector_ = new ::CoreML::Specification::WeightParams; } @@ -37114,7 +37117,7 @@ inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_forgetga } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::release_forgetgatebiasvector() { // @@protoc_insertion_point(field_release:CoreML.Specification.LSTMWeightParams.forgetGateBiasVector) - + ::CoreML::Specification::WeightParams* temp = forgetgatebiasvector_; forgetgatebiasvector_ = NULL; return temp; @@ -37123,9 +37126,9 @@ inline void LSTMWeightParams::set_allocated_forgetgatebiasvector(::CoreML::Speci delete forgetgatebiasvector_; forgetgatebiasvector_ = forgetgatebiasvector; if (forgetgatebiasvector) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.LSTMWeightParams.forgetGateBiasVector) } @@ -37144,7 +37147,7 @@ inline const ::CoreML::Specification::WeightParams& LSTMWeightParams::blockinput : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_blockinputbiasvector() { - + if (blockinputbiasvector_ == NULL) { blockinputbiasvector_ = new ::CoreML::Specification::WeightParams; } @@ -37153,7 +37156,7 @@ inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_blockinp } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::release_blockinputbiasvector() { // @@protoc_insertion_point(field_release:CoreML.Specification.LSTMWeightParams.blockInputBiasVector) - + ::CoreML::Specification::WeightParams* temp = blockinputbiasvector_; blockinputbiasvector_ = NULL; return temp; @@ -37162,9 +37165,9 @@ inline void LSTMWeightParams::set_allocated_blockinputbiasvector(::CoreML::Speci delete blockinputbiasvector_; blockinputbiasvector_ = blockinputbiasvector; if (blockinputbiasvector) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.LSTMWeightParams.blockInputBiasVector) } @@ -37183,7 +37186,7 @@ inline const ::CoreML::Specification::WeightParams& LSTMWeightParams::outputgate : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_outputgatebiasvector() { - + if (outputgatebiasvector_ == NULL) { outputgatebiasvector_ = new ::CoreML::Specification::WeightParams; } @@ -37192,7 +37195,7 @@ inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_outputga } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::release_outputgatebiasvector() { // @@protoc_insertion_point(field_release:CoreML.Specification.LSTMWeightParams.outputGateBiasVector) - + ::CoreML::Specification::WeightParams* temp = outputgatebiasvector_; outputgatebiasvector_ = NULL; return temp; @@ -37201,9 +37204,9 @@ inline void LSTMWeightParams::set_allocated_outputgatebiasvector(::CoreML::Speci delete outputgatebiasvector_; outputgatebiasvector_ = outputgatebiasvector; if (outputgatebiasvector) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.LSTMWeightParams.outputGateBiasVector) } @@ -37222,7 +37225,7 @@ inline const ::CoreML::Specification::WeightParams& LSTMWeightParams::inputgatep : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_inputgatepeepholevector() { - + if (inputgatepeepholevector_ == NULL) { inputgatepeepholevector_ = new ::CoreML::Specification::WeightParams; } @@ -37231,7 +37234,7 @@ inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_inputgat } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::release_inputgatepeepholevector() { // @@protoc_insertion_point(field_release:CoreML.Specification.LSTMWeightParams.inputGatePeepholeVector) - + ::CoreML::Specification::WeightParams* temp = inputgatepeepholevector_; inputgatepeepholevector_ = NULL; return temp; @@ -37240,9 +37243,9 @@ inline void LSTMWeightParams::set_allocated_inputgatepeepholevector(::CoreML::Sp delete inputgatepeepholevector_; inputgatepeepholevector_ = inputgatepeepholevector; if (inputgatepeepholevector) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.LSTMWeightParams.inputGatePeepholeVector) } @@ -37261,7 +37264,7 @@ inline const ::CoreML::Specification::WeightParams& LSTMWeightParams::forgetgate : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_forgetgatepeepholevector() { - + if (forgetgatepeepholevector_ == NULL) { forgetgatepeepholevector_ = new ::CoreML::Specification::WeightParams; } @@ -37270,7 +37273,7 @@ inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_forgetga } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::release_forgetgatepeepholevector() { // @@protoc_insertion_point(field_release:CoreML.Specification.LSTMWeightParams.forgetGatePeepholeVector) - + ::CoreML::Specification::WeightParams* temp = forgetgatepeepholevector_; forgetgatepeepholevector_ = NULL; return temp; @@ -37279,9 +37282,9 @@ inline void LSTMWeightParams::set_allocated_forgetgatepeepholevector(::CoreML::S delete forgetgatepeepholevector_; forgetgatepeepholevector_ = forgetgatepeepholevector; if (forgetgatepeepholevector) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.LSTMWeightParams.forgetGatePeepholeVector) } @@ -37300,7 +37303,7 @@ inline const ::CoreML::Specification::WeightParams& LSTMWeightParams::outputgate : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_outputgatepeepholevector() { - + if (outputgatepeepholevector_ == NULL) { outputgatepeepholevector_ = new ::CoreML::Specification::WeightParams; } @@ -37309,7 +37312,7 @@ inline ::CoreML::Specification::WeightParams* LSTMWeightParams::mutable_outputga } inline ::CoreML::Specification::WeightParams* LSTMWeightParams::release_outputgatepeepholevector() { // @@protoc_insertion_point(field_release:CoreML.Specification.LSTMWeightParams.outputGatePeepholeVector) - + ::CoreML::Specification::WeightParams* temp = outputgatepeepholevector_; outputgatepeepholevector_ = NULL; return temp; @@ -37318,9 +37321,9 @@ inline void LSTMWeightParams::set_allocated_outputgatepeepholevector(::CoreML::S delete outputgatepeepholevector_; outputgatepeepholevector_ = outputgatepeepholevector; if (outputgatepeepholevector) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.LSTMWeightParams.outputGatePeepholeVector) } @@ -37338,7 +37341,7 @@ inline ::google::protobuf::uint64 UniDirectionalLSTMLayerParams::inputvectorsize return inputvectorsize_; } inline void UniDirectionalLSTMLayerParams::set_inputvectorsize(::google::protobuf::uint64 value) { - + inputvectorsize_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.UniDirectionalLSTMLayerParams.inputVectorSize) } @@ -37352,7 +37355,7 @@ inline ::google::protobuf::uint64 UniDirectionalLSTMLayerParams::outputvectorsiz return outputvectorsize_; } inline void UniDirectionalLSTMLayerParams::set_outputvectorsize(::google::protobuf::uint64 value) { - + outputvectorsize_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.UniDirectionalLSTMLayerParams.outputVectorSize) } @@ -37401,7 +37404,7 @@ inline const ::CoreML::Specification::LSTMParams& UniDirectionalLSTMLayerParams: : *::CoreML::Specification::LSTMParams::internal_default_instance(); } inline ::CoreML::Specification::LSTMParams* UniDirectionalLSTMLayerParams::mutable_params() { - + if (params_ == NULL) { params_ = new ::CoreML::Specification::LSTMParams; } @@ -37410,7 +37413,7 @@ inline ::CoreML::Specification::LSTMParams* UniDirectionalLSTMLayerParams::mutab } inline ::CoreML::Specification::LSTMParams* UniDirectionalLSTMLayerParams::release_params() { // @@protoc_insertion_point(field_release:CoreML.Specification.UniDirectionalLSTMLayerParams.params) - + ::CoreML::Specification::LSTMParams* temp = params_; params_ = NULL; return temp; @@ -37419,9 +37422,9 @@ inline void UniDirectionalLSTMLayerParams::set_allocated_params(::CoreML::Specif delete params_; params_ = params; if (params) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.UniDirectionalLSTMLayerParams.params) } @@ -37440,7 +37443,7 @@ inline const ::CoreML::Specification::LSTMWeightParams& UniDirectionalLSTMLayerP : *::CoreML::Specification::LSTMWeightParams::internal_default_instance(); } inline ::CoreML::Specification::LSTMWeightParams* UniDirectionalLSTMLayerParams::mutable_weightparams() { - + if (weightparams_ == NULL) { weightparams_ = new ::CoreML::Specification::LSTMWeightParams; } @@ -37449,7 +37452,7 @@ inline ::CoreML::Specification::LSTMWeightParams* UniDirectionalLSTMLayerParams: } inline ::CoreML::Specification::LSTMWeightParams* UniDirectionalLSTMLayerParams::release_weightparams() { // @@protoc_insertion_point(field_release:CoreML.Specification.UniDirectionalLSTMLayerParams.weightParams) - + ::CoreML::Specification::LSTMWeightParams* temp = weightparams_; weightparams_ = NULL; return temp; @@ -37458,9 +37461,9 @@ inline void UniDirectionalLSTMLayerParams::set_allocated_weightparams(::CoreML:: delete weightparams_; weightparams_ = weightparams; if (weightparams) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.UniDirectionalLSTMLayerParams.weightParams) } @@ -37474,7 +37477,7 @@ inline bool UniDirectionalLSTMLayerParams::reverseinput() const { return reverseinput_; } inline void UniDirectionalLSTMLayerParams::set_reverseinput(bool value) { - + reverseinput_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.UniDirectionalLSTMLayerParams.reverseInput) } @@ -37492,7 +37495,7 @@ inline ::google::protobuf::uint64 BiDirectionalLSTMLayerParams::inputvectorsize( return inputvectorsize_; } inline void BiDirectionalLSTMLayerParams::set_inputvectorsize(::google::protobuf::uint64 value) { - + inputvectorsize_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.BiDirectionalLSTMLayerParams.inputVectorSize) } @@ -37506,7 +37509,7 @@ inline ::google::protobuf::uint64 BiDirectionalLSTMLayerParams::outputvectorsize return outputvectorsize_; } inline void BiDirectionalLSTMLayerParams::set_outputvectorsize(::google::protobuf::uint64 value) { - + outputvectorsize_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.BiDirectionalLSTMLayerParams.outputVectorSize) } @@ -37585,7 +37588,7 @@ inline const ::CoreML::Specification::LSTMParams& BiDirectionalLSTMLayerParams:: : *::CoreML::Specification::LSTMParams::internal_default_instance(); } inline ::CoreML::Specification::LSTMParams* BiDirectionalLSTMLayerParams::mutable_params() { - + if (params_ == NULL) { params_ = new ::CoreML::Specification::LSTMParams; } @@ -37594,7 +37597,7 @@ inline ::CoreML::Specification::LSTMParams* BiDirectionalLSTMLayerParams::mutabl } inline ::CoreML::Specification::LSTMParams* BiDirectionalLSTMLayerParams::release_params() { // @@protoc_insertion_point(field_release:CoreML.Specification.BiDirectionalLSTMLayerParams.params) - + ::CoreML::Specification::LSTMParams* temp = params_; params_ = NULL; return temp; @@ -37603,9 +37606,9 @@ inline void BiDirectionalLSTMLayerParams::set_allocated_params(::CoreML::Specifi delete params_; params_ = params; if (params) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.BiDirectionalLSTMLayerParams.params) } @@ -37879,13 +37882,13 @@ inline const ::std::string& CustomLayerParams::classname() const { return classname_.GetNoArena(); } inline void CustomLayerParams::set_classname(const ::std::string& value) { - + classname_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.CustomLayerParams.className) } #if LANG_CXX11 inline void CustomLayerParams::set_classname(::std::string&& value) { - + classname_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.CustomLayerParams.className) @@ -37893,31 +37896,31 @@ inline void CustomLayerParams::set_classname(::std::string&& value) { #endif inline void CustomLayerParams::set_classname(const char* value) { GOOGLE_DCHECK(value != NULL); - + classname_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.CustomLayerParams.className) } inline void CustomLayerParams::set_classname(const char* value, size_t size) { - + classname_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.CustomLayerParams.className) } inline ::std::string* CustomLayerParams::mutable_classname() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.CustomLayerParams.className) return classname_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* CustomLayerParams::release_classname() { // @@protoc_insertion_point(field_release:CoreML.Specification.CustomLayerParams.className) - + return classname_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void CustomLayerParams::set_allocated_classname(::std::string* classname) { if (classname != NULL) { - + } else { - + } classname_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), classname); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.CustomLayerParams.className) @@ -37980,13 +37983,13 @@ inline const ::std::string& CustomLayerParams::description() const { return description_.GetNoArena(); } inline void CustomLayerParams::set_description(const ::std::string& value) { - + description_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.CustomLayerParams.description) } #if LANG_CXX11 inline void CustomLayerParams::set_description(::std::string&& value) { - + description_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.CustomLayerParams.description) @@ -37994,31 +37997,31 @@ inline void CustomLayerParams::set_description(::std::string&& value) { #endif inline void CustomLayerParams::set_description(const char* value) { GOOGLE_DCHECK(value != NULL); - + description_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.CustomLayerParams.description) } inline void CustomLayerParams::set_description(const char* value, size_t size) { - + description_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.CustomLayerParams.description) } inline ::std::string* CustomLayerParams::mutable_description() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.CustomLayerParams.description) return description_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* CustomLayerParams::release_description() { // @@protoc_insertion_point(field_release:CoreML.Specification.CustomLayerParams.description) - + return description_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void CustomLayerParams::set_allocated_description(::std::string* description) { if (description != NULL) { - + } else { - + } description_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), description); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.CustomLayerParams.description) @@ -38071,7 +38074,7 @@ inline bool BatchedMatMulLayerParams::transposea() const { return transposea_; } inline void BatchedMatMulLayerParams::set_transposea(bool value) { - + transposea_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.BatchedMatMulLayerParams.transposeA) } @@ -38085,7 +38088,7 @@ inline bool BatchedMatMulLayerParams::transposeb() const { return transposeb_; } inline void BatchedMatMulLayerParams::set_transposeb(bool value) { - + transposeb_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.BatchedMatMulLayerParams.transposeB) } @@ -38099,7 +38102,7 @@ inline ::google::protobuf::uint64 BatchedMatMulLayerParams::weightmatrixfirstdim return weightmatrixfirstdimension_; } inline void BatchedMatMulLayerParams::set_weightmatrixfirstdimension(::google::protobuf::uint64 value) { - + weightmatrixfirstdimension_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.BatchedMatMulLayerParams.weightMatrixFirstDimension) } @@ -38113,7 +38116,7 @@ inline ::google::protobuf::uint64 BatchedMatMulLayerParams::weightmatrixseconddi return weightmatrixseconddimension_; } inline void BatchedMatMulLayerParams::set_weightmatrixseconddimension(::google::protobuf::uint64 value) { - + weightmatrixseconddimension_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.BatchedMatMulLayerParams.weightMatrixSecondDimension) } @@ -38127,7 +38130,7 @@ inline bool BatchedMatMulLayerParams::hasbias() const { return hasbias_; } inline void BatchedMatMulLayerParams::set_hasbias(bool value) { - + hasbias_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.BatchedMatMulLayerParams.hasBias) } @@ -38146,7 +38149,7 @@ inline const ::CoreML::Specification::WeightParams& BatchedMatMulLayerParams::we : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* BatchedMatMulLayerParams::mutable_weights() { - + if (weights_ == NULL) { weights_ = new ::CoreML::Specification::WeightParams; } @@ -38155,7 +38158,7 @@ inline ::CoreML::Specification::WeightParams* BatchedMatMulLayerParams::mutable_ } inline ::CoreML::Specification::WeightParams* BatchedMatMulLayerParams::release_weights() { // @@protoc_insertion_point(field_release:CoreML.Specification.BatchedMatMulLayerParams.weights) - + ::CoreML::Specification::WeightParams* temp = weights_; weights_ = NULL; return temp; @@ -38164,9 +38167,9 @@ inline void BatchedMatMulLayerParams::set_allocated_weights(::CoreML::Specificat delete weights_; weights_ = weights; if (weights) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.BatchedMatMulLayerParams.weights) } @@ -38185,7 +38188,7 @@ inline const ::CoreML::Specification::WeightParams& BatchedMatMulLayerParams::bi : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* BatchedMatMulLayerParams::mutable_bias() { - + if (bias_ == NULL) { bias_ = new ::CoreML::Specification::WeightParams; } @@ -38194,7 +38197,7 @@ inline ::CoreML::Specification::WeightParams* BatchedMatMulLayerParams::mutable_ } inline ::CoreML::Specification::WeightParams* BatchedMatMulLayerParams::release_bias() { // @@protoc_insertion_point(field_release:CoreML.Specification.BatchedMatMulLayerParams.bias) - + ::CoreML::Specification::WeightParams* temp = bias_; bias_ = NULL; return temp; @@ -38203,9 +38206,9 @@ inline void BatchedMatMulLayerParams::set_allocated_bias(::CoreML::Specification delete bias_; bias_ = bias; if (bias) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.BatchedMatMulLayerParams.bias) } @@ -38219,7 +38222,7 @@ inline bool BatchedMatMulLayerParams::int8dynamicquantize() const { return int8dynamicquantize_; } inline void BatchedMatMulLayerParams::set_int8dynamicquantize(bool value) { - + int8dynamicquantize_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.BatchedMatMulLayerParams.int8DynamicQuantize) } @@ -38237,7 +38240,7 @@ inline ::google::protobuf::int64 ConcatNDLayerParams::axis() const { return axis_; } inline void ConcatNDLayerParams::set_axis(::google::protobuf::int64 value) { - + axis_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ConcatNDLayerParams.axis) } @@ -38251,7 +38254,7 @@ inline bool ConcatNDLayerParams::interleave() const { return interleave_; } inline void ConcatNDLayerParams::set_interleave(bool value) { - + interleave_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ConcatNDLayerParams.interleave) } @@ -38269,7 +38272,7 @@ inline ::google::protobuf::int64 SoftmaxNDLayerParams::axis() const { return axis_; } inline void SoftmaxNDLayerParams::set_axis(::google::protobuf::int64 value) { - + axis_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SoftmaxNDLayerParams.axis) } @@ -38321,7 +38324,7 @@ inline ::google::protobuf::int64 ReverseSeqLayerParams::batchaxis() const { return batchaxis_; } inline void ReverseSeqLayerParams::set_batchaxis(::google::protobuf::int64 value) { - + batchaxis_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ReverseSeqLayerParams.batchAxis) } @@ -38335,7 +38338,7 @@ inline ::google::protobuf::int64 ReverseSeqLayerParams::sequenceaxis() const { return sequenceaxis_; } inline void ReverseSeqLayerParams::set_sequenceaxis(::google::protobuf::int64 value) { - + sequenceaxis_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ReverseSeqLayerParams.sequenceAxis) } @@ -38388,7 +38391,7 @@ inline const ::CoreML::Specification::WeightParams& LoadConstantNDLayerParams::d : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* LoadConstantNDLayerParams::mutable_data() { - + if (data_ == NULL) { data_ = new ::CoreML::Specification::WeightParams; } @@ -38397,7 +38400,7 @@ inline ::CoreML::Specification::WeightParams* LoadConstantNDLayerParams::mutable } inline ::CoreML::Specification::WeightParams* LoadConstantNDLayerParams::release_data() { // @@protoc_insertion_point(field_release:CoreML.Specification.LoadConstantNDLayerParams.data) - + ::CoreML::Specification::WeightParams* temp = data_; data_ = NULL; return temp; @@ -38406,9 +38409,9 @@ inline void LoadConstantNDLayerParams::set_allocated_data(::CoreML::Specificatio delete data_; data_ = data; if (data) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.LoadConstantNDLayerParams.data) } @@ -38426,7 +38429,7 @@ inline float FillLikeLayerParams::value() const { return value_; } inline void FillLikeLayerParams::set_value(float value) { - + value_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.FillLikeLayerParams.value) } @@ -38444,7 +38447,7 @@ inline float FillStaticLayerParams::value() const { return value_; } inline void FillStaticLayerParams::set_value(float value) { - + value_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.FillStaticLayerParams.value) } @@ -38492,7 +38495,7 @@ inline float FillDynamicLayerParams::value() const { return value_; } inline void FillDynamicLayerParams::set_value(float value) { - + value_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.FillDynamicLayerParams.value) } @@ -38574,7 +38577,7 @@ inline ::google::protobuf::int64 MatrixBandPartLayerParams::numlower() const { return numlower_; } inline void MatrixBandPartLayerParams::set_numlower(::google::protobuf::int64 value) { - + numlower_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.MatrixBandPartLayerParams.numLower) } @@ -38588,7 +38591,7 @@ inline ::google::protobuf::int64 MatrixBandPartLayerParams::numupper() const { return numupper_; } inline void MatrixBandPartLayerParams::set_numupper(::google::protobuf::int64 value) { - + numupper_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.MatrixBandPartLayerParams.numUpper) } @@ -38606,7 +38609,7 @@ inline ::google::protobuf::int64 UpperTriangularLayerParams::k() const { return k_; } inline void UpperTriangularLayerParams::set_k(::google::protobuf::int64 value) { - + k_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.UpperTriangularLayerParams.k) } @@ -38624,7 +38627,7 @@ inline ::google::protobuf::int64 LowerTriangularLayerParams::k() const { return k_; } inline void LowerTriangularLayerParams::set_k(::google::protobuf::int64 value) { - + k_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.LowerTriangularLayerParams.k) } @@ -38716,7 +38719,7 @@ inline ::google::protobuf::int64 GatherLayerParams::axis() const { return axis_; } inline void GatherLayerParams::set_axis(::google::protobuf::int64 value) { - + axis_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.GatherLayerParams.axis) } @@ -38734,7 +38737,7 @@ inline ::google::protobuf::int64 ScatterLayerParams::axis() const { return axis_; } inline void ScatterLayerParams::set_axis(::google::protobuf::int64 value) { - + axis_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ScatterLayerParams.axis) } @@ -38748,7 +38751,7 @@ inline ::CoreML::Specification::ScatterMode ScatterLayerParams::mode() const { return static_cast< ::CoreML::Specification::ScatterMode >(mode_); } inline void ScatterLayerParams::set_mode(::CoreML::Specification::ScatterMode value) { - + mode_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ScatterLayerParams.mode) } @@ -38770,7 +38773,7 @@ inline ::CoreML::Specification::ScatterMode ScatterNDLayerParams::mode() const { return static_cast< ::CoreML::Specification::ScatterMode >(mode_); } inline void ScatterNDLayerParams::set_mode(::CoreML::Specification::ScatterMode value) { - + mode_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ScatterNDLayerParams.mode) } @@ -38788,7 +38791,7 @@ inline ::google::protobuf::int64 GatherAlongAxisLayerParams::axis() const { return axis_; } inline void GatherAlongAxisLayerParams::set_axis(::google::protobuf::int64 value) { - + axis_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.GatherAlongAxisLayerParams.axis) } @@ -38806,7 +38809,7 @@ inline ::google::protobuf::int64 ScatterAlongAxisLayerParams::axis() const { return axis_; } inline void ScatterAlongAxisLayerParams::set_axis(::google::protobuf::int64 value) { - + axis_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ScatterAlongAxisLayerParams.axis) } @@ -38820,7 +38823,7 @@ inline ::CoreML::Specification::ScatterMode ScatterAlongAxisLayerParams::mode() return static_cast< ::CoreML::Specification::ScatterMode >(mode_); } inline void ScatterAlongAxisLayerParams::set_mode(::CoreML::Specification::ScatterMode value) { - + mode_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ScatterAlongAxisLayerParams.mode) } @@ -38838,7 +38841,7 @@ inline ::google::protobuf::int64 StackLayerParams::axis() const { return axis_; } inline void StackLayerParams::set_axis(::google::protobuf::int64 value) { - + axis_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.StackLayerParams.axis) } @@ -38890,7 +38893,7 @@ inline float ConstantPaddingLayerParams::value() const { return value_; } inline void ConstantPaddingLayerParams::set_value(float value) { - + value_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ConstantPaddingLayerParams.value) } @@ -38934,7 +38937,7 @@ inline bool ConstantPaddingLayerParams::padtogivenoutputsizemode() const { return padtogivenoutputsizemode_; } inline void ConstantPaddingLayerParams::set_padtogivenoutputsizemode(bool value) { - + padtogivenoutputsizemode_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ConstantPaddingLayerParams.padToGivenOutputSizeMode) } @@ -38952,7 +38955,7 @@ inline ::google::protobuf::int64 RandomNormalLikeLayerParams::seed() const { return seed_; } inline void RandomNormalLikeLayerParams::set_seed(::google::protobuf::int64 value) { - + seed_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RandomNormalLikeLayerParams.seed) } @@ -38966,7 +38969,7 @@ inline float RandomNormalLikeLayerParams::mean() const { return mean_; } inline void RandomNormalLikeLayerParams::set_mean(float value) { - + mean_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RandomNormalLikeLayerParams.mean) } @@ -38980,7 +38983,7 @@ inline float RandomNormalLikeLayerParams::stddev() const { return stddev_; } inline void RandomNormalLikeLayerParams::set_stddev(float value) { - + stddev_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RandomNormalLikeLayerParams.stdDev) } @@ -38998,7 +39001,7 @@ inline ::google::protobuf::int64 RandomNormalStaticLayerParams::seed() const { return seed_; } inline void RandomNormalStaticLayerParams::set_seed(::google::protobuf::int64 value) { - + seed_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RandomNormalStaticLayerParams.seed) } @@ -39012,7 +39015,7 @@ inline float RandomNormalStaticLayerParams::mean() const { return mean_; } inline void RandomNormalStaticLayerParams::set_mean(float value) { - + mean_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RandomNormalStaticLayerParams.mean) } @@ -39026,7 +39029,7 @@ inline float RandomNormalStaticLayerParams::stddev() const { return stddev_; } inline void RandomNormalStaticLayerParams::set_stddev(float value) { - + stddev_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RandomNormalStaticLayerParams.stdDev) } @@ -39074,7 +39077,7 @@ inline ::google::protobuf::int64 RandomNormalDynamicLayerParams::seed() const { return seed_; } inline void RandomNormalDynamicLayerParams::set_seed(::google::protobuf::int64 value) { - + seed_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RandomNormalDynamicLayerParams.seed) } @@ -39088,7 +39091,7 @@ inline float RandomNormalDynamicLayerParams::mean() const { return mean_; } inline void RandomNormalDynamicLayerParams::set_mean(float value) { - + mean_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RandomNormalDynamicLayerParams.mean) } @@ -39102,7 +39105,7 @@ inline float RandomNormalDynamicLayerParams::stddev() const { return stddev_; } inline void RandomNormalDynamicLayerParams::set_stddev(float value) { - + stddev_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RandomNormalDynamicLayerParams.stdDev) } @@ -39120,7 +39123,7 @@ inline ::google::protobuf::int64 RandomUniformLikeLayerParams::seed() const { return seed_; } inline void RandomUniformLikeLayerParams::set_seed(::google::protobuf::int64 value) { - + seed_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RandomUniformLikeLayerParams.seed) } @@ -39134,7 +39137,7 @@ inline float RandomUniformLikeLayerParams::minval() const { return minval_; } inline void RandomUniformLikeLayerParams::set_minval(float value) { - + minval_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RandomUniformLikeLayerParams.minVal) } @@ -39148,7 +39151,7 @@ inline float RandomUniformLikeLayerParams::maxval() const { return maxval_; } inline void RandomUniformLikeLayerParams::set_maxval(float value) { - + maxval_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RandomUniformLikeLayerParams.maxVal) } @@ -39166,7 +39169,7 @@ inline ::google::protobuf::int64 RandomUniformStaticLayerParams::seed() const { return seed_; } inline void RandomUniformStaticLayerParams::set_seed(::google::protobuf::int64 value) { - + seed_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RandomUniformStaticLayerParams.seed) } @@ -39180,7 +39183,7 @@ inline float RandomUniformStaticLayerParams::minval() const { return minval_; } inline void RandomUniformStaticLayerParams::set_minval(float value) { - + minval_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RandomUniformStaticLayerParams.minVal) } @@ -39194,7 +39197,7 @@ inline float RandomUniformStaticLayerParams::maxval() const { return maxval_; } inline void RandomUniformStaticLayerParams::set_maxval(float value) { - + maxval_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RandomUniformStaticLayerParams.maxVal) } @@ -39242,7 +39245,7 @@ inline ::google::protobuf::int64 RandomUniformDynamicLayerParams::seed() const { return seed_; } inline void RandomUniformDynamicLayerParams::set_seed(::google::protobuf::int64 value) { - + seed_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RandomUniformDynamicLayerParams.seed) } @@ -39256,7 +39259,7 @@ inline float RandomUniformDynamicLayerParams::minval() const { return minval_; } inline void RandomUniformDynamicLayerParams::set_minval(float value) { - + minval_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RandomUniformDynamicLayerParams.minVal) } @@ -39270,7 +39273,7 @@ inline float RandomUniformDynamicLayerParams::maxval() const { return maxval_; } inline void RandomUniformDynamicLayerParams::set_maxval(float value) { - + maxval_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RandomUniformDynamicLayerParams.maxVal) } @@ -39288,7 +39291,7 @@ inline ::google::protobuf::int64 RandomBernoulliLikeLayerParams::seed() const { return seed_; } inline void RandomBernoulliLikeLayerParams::set_seed(::google::protobuf::int64 value) { - + seed_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RandomBernoulliLikeLayerParams.seed) } @@ -39302,7 +39305,7 @@ inline float RandomBernoulliLikeLayerParams::prob() const { return prob_; } inline void RandomBernoulliLikeLayerParams::set_prob(float value) { - + prob_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RandomBernoulliLikeLayerParams.prob) } @@ -39320,7 +39323,7 @@ inline ::google::protobuf::int64 RandomBernoulliStaticLayerParams::seed() const return seed_; } inline void RandomBernoulliStaticLayerParams::set_seed(::google::protobuf::int64 value) { - + seed_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RandomBernoulliStaticLayerParams.seed) } @@ -39334,7 +39337,7 @@ inline float RandomBernoulliStaticLayerParams::prob() const { return prob_; } inline void RandomBernoulliStaticLayerParams::set_prob(float value) { - + prob_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RandomBernoulliStaticLayerParams.prob) } @@ -39382,7 +39385,7 @@ inline ::google::protobuf::int64 RandomBernoulliDynamicLayerParams::seed() const return seed_; } inline void RandomBernoulliDynamicLayerParams::set_seed(::google::protobuf::int64 value) { - + seed_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RandomBernoulliDynamicLayerParams.seed) } @@ -39396,7 +39399,7 @@ inline float RandomBernoulliDynamicLayerParams::prob() const { return prob_; } inline void RandomBernoulliDynamicLayerParams::set_prob(float value) { - + prob_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RandomBernoulliDynamicLayerParams.prob) } @@ -39414,7 +39417,7 @@ inline ::google::protobuf::int64 CategoricalDistributionLayerParams::seed() cons return seed_; } inline void CategoricalDistributionLayerParams::set_seed(::google::protobuf::int64 value) { - + seed_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.CategoricalDistributionLayerParams.seed) } @@ -39428,7 +39431,7 @@ inline ::google::protobuf::int64 CategoricalDistributionLayerParams::numsamples( return numsamples_; } inline void CategoricalDistributionLayerParams::set_numsamples(::google::protobuf::int64 value) { - + numsamples_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.CategoricalDistributionLayerParams.numSamples) } @@ -39442,7 +39445,7 @@ inline bool CategoricalDistributionLayerParams::islogits() const { return islogits_; } inline void CategoricalDistributionLayerParams::set_islogits(bool value) { - + islogits_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.CategoricalDistributionLayerParams.isLogits) } @@ -39456,7 +39459,7 @@ inline float CategoricalDistributionLayerParams::eps() const { return eps_; } inline void CategoricalDistributionLayerParams::set_eps(float value) { - + eps_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.CategoricalDistributionLayerParams.eps) } @@ -39470,7 +39473,7 @@ inline float CategoricalDistributionLayerParams::temperature() const { return temperature_; } inline void CategoricalDistributionLayerParams::set_temperature(float value) { - + temperature_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.CategoricalDistributionLayerParams.temperature) } @@ -39518,7 +39521,7 @@ inline bool ReduceL1LayerParams::keepdims() const { return keepdims_; } inline void ReduceL1LayerParams::set_keepdims(bool value) { - + keepdims_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ReduceL1LayerParams.keepDims) } @@ -39532,7 +39535,7 @@ inline bool ReduceL1LayerParams::reduceall() const { return reduceall_; } inline void ReduceL1LayerParams::set_reduceall(bool value) { - + reduceall_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ReduceL1LayerParams.reduceAll) } @@ -39580,7 +39583,7 @@ inline bool ReduceL2LayerParams::keepdims() const { return keepdims_; } inline void ReduceL2LayerParams::set_keepdims(bool value) { - + keepdims_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ReduceL2LayerParams.keepDims) } @@ -39594,7 +39597,7 @@ inline bool ReduceL2LayerParams::reduceall() const { return reduceall_; } inline void ReduceL2LayerParams::set_reduceall(bool value) { - + reduceall_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ReduceL2LayerParams.reduceAll) } @@ -39642,7 +39645,7 @@ inline bool ReduceMaxLayerParams::keepdims() const { return keepdims_; } inline void ReduceMaxLayerParams::set_keepdims(bool value) { - + keepdims_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ReduceMaxLayerParams.keepDims) } @@ -39656,7 +39659,7 @@ inline bool ReduceMaxLayerParams::reduceall() const { return reduceall_; } inline void ReduceMaxLayerParams::set_reduceall(bool value) { - + reduceall_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ReduceMaxLayerParams.reduceAll) } @@ -39704,7 +39707,7 @@ inline bool ReduceMinLayerParams::keepdims() const { return keepdims_; } inline void ReduceMinLayerParams::set_keepdims(bool value) { - + keepdims_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ReduceMinLayerParams.keepDims) } @@ -39718,7 +39721,7 @@ inline bool ReduceMinLayerParams::reduceall() const { return reduceall_; } inline void ReduceMinLayerParams::set_reduceall(bool value) { - + reduceall_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ReduceMinLayerParams.reduceAll) } @@ -39766,7 +39769,7 @@ inline bool ReduceSumLayerParams::keepdims() const { return keepdims_; } inline void ReduceSumLayerParams::set_keepdims(bool value) { - + keepdims_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ReduceSumLayerParams.keepDims) } @@ -39780,7 +39783,7 @@ inline bool ReduceSumLayerParams::reduceall() const { return reduceall_; } inline void ReduceSumLayerParams::set_reduceall(bool value) { - + reduceall_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ReduceSumLayerParams.reduceAll) } @@ -39828,7 +39831,7 @@ inline bool ReduceProdLayerParams::keepdims() const { return keepdims_; } inline void ReduceProdLayerParams::set_keepdims(bool value) { - + keepdims_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ReduceProdLayerParams.keepDims) } @@ -39842,7 +39845,7 @@ inline bool ReduceProdLayerParams::reduceall() const { return reduceall_; } inline void ReduceProdLayerParams::set_reduceall(bool value) { - + reduceall_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ReduceProdLayerParams.reduceAll) } @@ -39890,7 +39893,7 @@ inline bool ReduceMeanLayerParams::keepdims() const { return keepdims_; } inline void ReduceMeanLayerParams::set_keepdims(bool value) { - + keepdims_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ReduceMeanLayerParams.keepDims) } @@ -39904,7 +39907,7 @@ inline bool ReduceMeanLayerParams::reduceall() const { return reduceall_; } inline void ReduceMeanLayerParams::set_reduceall(bool value) { - + reduceall_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ReduceMeanLayerParams.reduceAll) } @@ -39952,7 +39955,7 @@ inline bool ReduceLogSumLayerParams::keepdims() const { return keepdims_; } inline void ReduceLogSumLayerParams::set_keepdims(bool value) { - + keepdims_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ReduceLogSumLayerParams.keepDims) } @@ -39966,7 +39969,7 @@ inline bool ReduceLogSumLayerParams::reduceall() const { return reduceall_; } inline void ReduceLogSumLayerParams::set_reduceall(bool value) { - + reduceall_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ReduceLogSumLayerParams.reduceAll) } @@ -40014,7 +40017,7 @@ inline bool ReduceSumSquareLayerParams::keepdims() const { return keepdims_; } inline void ReduceSumSquareLayerParams::set_keepdims(bool value) { - + keepdims_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ReduceSumSquareLayerParams.keepDims) } @@ -40028,7 +40031,7 @@ inline bool ReduceSumSquareLayerParams::reduceall() const { return reduceall_; } inline void ReduceSumSquareLayerParams::set_reduceall(bool value) { - + reduceall_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ReduceSumSquareLayerParams.reduceAll) } @@ -40076,7 +40079,7 @@ inline bool ReduceLogSumExpLayerParams::keepdims() const { return keepdims_; } inline void ReduceLogSumExpLayerParams::set_keepdims(bool value) { - + keepdims_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ReduceLogSumExpLayerParams.keepDims) } @@ -40090,7 +40093,7 @@ inline bool ReduceLogSumExpLayerParams::reduceall() const { return reduceall_; } inline void ReduceLogSumExpLayerParams::set_reduceall(bool value) { - + reduceall_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ReduceLogSumExpLayerParams.reduceAll) } @@ -40142,7 +40145,7 @@ inline ::google::protobuf::int64 FlattenTo2DLayerParams::axis() const { return axis_; } inline void FlattenTo2DLayerParams::set_axis(::google::protobuf::int64 value) { - + axis_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.FlattenTo2DLayerParams.axis) } @@ -40232,7 +40235,7 @@ inline bool SqueezeLayerParams::squeezeall() const { return squeezeall_; } inline void SqueezeLayerParams::set_squeezeall(bool value) { - + squeezeall_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SqueezeLayerParams.squeezeAll) } @@ -40250,7 +40253,7 @@ inline ::google::protobuf::int64 TopKLayerParams::axis() const { return axis_; } inline void TopKLayerParams::set_axis(::google::protobuf::int64 value) { - + axis_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.TopKLayerParams.axis) } @@ -40264,7 +40267,7 @@ inline ::google::protobuf::uint64 TopKLayerParams::k() const { return k_; } inline void TopKLayerParams::set_k(::google::protobuf::uint64 value) { - + k_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.TopKLayerParams.K) } @@ -40278,7 +40281,7 @@ inline bool TopKLayerParams::usebottomk() const { return usebottomk_; } inline void TopKLayerParams::set_usebottomk(bool value) { - + usebottomk_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.TopKLayerParams.useBottomK) } @@ -40296,7 +40299,7 @@ inline ::google::protobuf::int64 ArgMaxLayerParams::axis() const { return axis_; } inline void ArgMaxLayerParams::set_axis(::google::protobuf::int64 value) { - + axis_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ArgMaxLayerParams.axis) } @@ -40310,7 +40313,7 @@ inline bool ArgMaxLayerParams::removedim() const { return removedim_; } inline void ArgMaxLayerParams::set_removedim(bool value) { - + removedim_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ArgMaxLayerParams.removeDim) } @@ -40328,7 +40331,7 @@ inline ::google::protobuf::int64 ArgMinLayerParams::axis() const { return axis_; } inline void ArgMinLayerParams::set_axis(::google::protobuf::int64 value) { - + axis_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ArgMinLayerParams.axis) } @@ -40342,7 +40345,7 @@ inline bool ArgMinLayerParams::removedim() const { return removedim_; } inline void ArgMinLayerParams::set_removedim(bool value) { - + removedim_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ArgMinLayerParams.removeDim) } @@ -40360,7 +40363,7 @@ inline ::google::protobuf::int64 SplitNDLayerParams::axis() const { return axis_; } inline void SplitNDLayerParams::set_axis(::google::protobuf::int64 value) { - + axis_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SplitNDLayerParams.axis) } @@ -40374,7 +40377,7 @@ inline ::google::protobuf::uint64 SplitNDLayerParams::numsplits() const { return numsplits_; } inline void SplitNDLayerParams::set_numsplits(::google::protobuf::uint64 value) { - + numsplits_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SplitNDLayerParams.numSplits) } @@ -40438,7 +40441,7 @@ inline float ClipLayerParams::minval() const { return minval_; } inline void ClipLayerParams::set_minval(float value) { - + minval_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ClipLayerParams.minVal) } @@ -40452,7 +40455,7 @@ inline float ClipLayerParams::maxval() const { return maxval_; } inline void ClipLayerParams::set_maxval(float value) { - + maxval_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ClipLayerParams.maxVal) } @@ -40850,7 +40853,7 @@ inline ::CoreML::Specification::GeluLayerParams_GeluMode GeluLayerParams::mode() return static_cast< ::CoreML::Specification::GeluLayerParams_GeluMode >(mode_); } inline void GeluLayerParams::set_mode(::CoreML::Specification::GeluLayerParams_GeluMode value) { - + mode_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.GeluLayerParams.mode) } @@ -40868,7 +40871,7 @@ inline float RangeStaticLayerParams::endvalue() const { return endvalue_; } inline void RangeStaticLayerParams::set_endvalue(float value) { - + endvalue_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RangeStaticLayerParams.endValue) } @@ -40882,7 +40885,7 @@ inline float RangeStaticLayerParams::startvalue() const { return startvalue_; } inline void RangeStaticLayerParams::set_startvalue(float value) { - + startvalue_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RangeStaticLayerParams.startValue) } @@ -40896,7 +40899,7 @@ inline float RangeStaticLayerParams::stepsizevalue() const { return stepsizevalue_; } inline void RangeStaticLayerParams::set_stepsizevalue(float value) { - + stepsizevalue_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RangeStaticLayerParams.stepSizeValue) } @@ -40914,7 +40917,7 @@ inline float RangeDynamicLayerParams::startvalue() const { return startvalue_; } inline void RangeDynamicLayerParams::set_startvalue(float value) { - + startvalue_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RangeDynamicLayerParams.startValue) } @@ -40928,7 +40931,7 @@ inline float RangeDynamicLayerParams::stepsizevalue() const { return stepsizevalue_; } inline void RangeDynamicLayerParams::set_stepsizevalue(float value) { - + stepsizevalue_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RangeDynamicLayerParams.stepSizeValue) } @@ -40946,7 +40949,7 @@ inline ::google::protobuf::int64 SlidingWindowsLayerParams::axis() const { return axis_; } inline void SlidingWindowsLayerParams::set_axis(::google::protobuf::int64 value) { - + axis_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SlidingWindowsLayerParams.axis) } @@ -40960,7 +40963,7 @@ inline ::google::protobuf::uint64 SlidingWindowsLayerParams::windowsize() const return windowsize_; } inline void SlidingWindowsLayerParams::set_windowsize(::google::protobuf::uint64 value) { - + windowsize_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SlidingWindowsLayerParams.windowSize) } @@ -40974,7 +40977,7 @@ inline ::google::protobuf::uint64 SlidingWindowsLayerParams::step() const { return step_; } inline void SlidingWindowsLayerParams::set_step(::google::protobuf::uint64 value) { - + step_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SlidingWindowsLayerParams.step) } @@ -41022,7 +41025,7 @@ inline float LayerNormalizationLayerParams::eps() const { return eps_; } inline void LayerNormalizationLayerParams::set_eps(float value) { - + eps_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.LayerNormalizationLayerParams.eps) } @@ -41041,7 +41044,7 @@ inline const ::CoreML::Specification::WeightParams& LayerNormalizationLayerParam : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* LayerNormalizationLayerParams::mutable_gamma() { - + if (gamma_ == NULL) { gamma_ = new ::CoreML::Specification::WeightParams; } @@ -41050,7 +41053,7 @@ inline ::CoreML::Specification::WeightParams* LayerNormalizationLayerParams::mut } inline ::CoreML::Specification::WeightParams* LayerNormalizationLayerParams::release_gamma() { // @@protoc_insertion_point(field_release:CoreML.Specification.LayerNormalizationLayerParams.gamma) - + ::CoreML::Specification::WeightParams* temp = gamma_; gamma_ = NULL; return temp; @@ -41059,9 +41062,9 @@ inline void LayerNormalizationLayerParams::set_allocated_gamma(::CoreML::Specifi delete gamma_; gamma_ = gamma; if (gamma) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.LayerNormalizationLayerParams.gamma) } @@ -41080,7 +41083,7 @@ inline const ::CoreML::Specification::WeightParams& LayerNormalizationLayerParam : *::CoreML::Specification::WeightParams::internal_default_instance(); } inline ::CoreML::Specification::WeightParams* LayerNormalizationLayerParams::mutable_beta() { - + if (beta_ == NULL) { beta_ = new ::CoreML::Specification::WeightParams; } @@ -41089,7 +41092,7 @@ inline ::CoreML::Specification::WeightParams* LayerNormalizationLayerParams::mut } inline ::CoreML::Specification::WeightParams* LayerNormalizationLayerParams::release_beta() { // @@protoc_insertion_point(field_release:CoreML.Specification.LayerNormalizationLayerParams.beta) - + ::CoreML::Specification::WeightParams* temp = beta_; beta_ = NULL; return temp; @@ -41098,9 +41101,9 @@ inline void LayerNormalizationLayerParams::set_allocated_beta(::CoreML::Specific delete beta_; beta_ = beta; if (beta) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.LayerNormalizationLayerParams.beta) } @@ -41118,7 +41121,7 @@ inline float NonMaximumSuppressionLayerParams::iouthreshold() const { return iouthreshold_; } inline void NonMaximumSuppressionLayerParams::set_iouthreshold(float value) { - + iouthreshold_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.NonMaximumSuppressionLayerParams.iouThreshold) } @@ -41132,7 +41135,7 @@ inline float NonMaximumSuppressionLayerParams::scorethreshold() const { return scorethreshold_; } inline void NonMaximumSuppressionLayerParams::set_scorethreshold(float value) { - + scorethreshold_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.NonMaximumSuppressionLayerParams.scoreThreshold) } @@ -41146,7 +41149,7 @@ inline ::google::protobuf::uint64 NonMaximumSuppressionLayerParams::maxboxes() c return maxboxes_; } inline void NonMaximumSuppressionLayerParams::set_maxboxes(::google::protobuf::uint64 value) { - + maxboxes_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.NonMaximumSuppressionLayerParams.maxBoxes) } @@ -41160,7 +41163,7 @@ inline bool NonMaximumSuppressionLayerParams::perclasssuppression() const { return perclasssuppression_; } inline void NonMaximumSuppressionLayerParams::set_perclasssuppression(bool value) { - + perclasssuppression_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.NonMaximumSuppressionLayerParams.perClassSuppression) } @@ -41178,7 +41181,7 @@ inline float ClampedReLULayerParams::alpha() const { return alpha_; } inline void ClampedReLULayerParams::set_alpha(float value) { - + alpha_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ClampedReLULayerParams.alpha) } @@ -41192,7 +41195,7 @@ inline float ClampedReLULayerParams::beta() const { return beta_; } inline void ClampedReLULayerParams::set_beta(float value) { - + beta_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ClampedReLULayerParams.beta) } @@ -41210,7 +41213,7 @@ inline ::google::protobuf::int64 ArgSortLayerParams::axis() const { return axis_; } inline void ArgSortLayerParams::set_axis(::google::protobuf::int64 value) { - + axis_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ArgSortLayerParams.axis) } @@ -41224,7 +41227,7 @@ inline bool ArgSortLayerParams::descending() const { return descending_; } inline void ArgSortLayerParams::set_descending(bool value) { - + descending_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.ArgSortLayerParams.descending) } @@ -41242,7 +41245,7 @@ inline ::google::protobuf::int64 SliceBySizeLayerParams::size() const { return size_; } inline void SliceBySizeLayerParams::set_size(::google::protobuf::int64 value) { - + size_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SliceBySizeLayerParams.size) } @@ -41256,7 +41259,7 @@ inline ::google::protobuf::int64 SliceBySizeLayerParams::axis() const { return axis_; } inline void SliceBySizeLayerParams::set_axis(::google::protobuf::int64 value) { - + axis_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SliceBySizeLayerParams.axis) } @@ -41334,7 +41337,7 @@ inline ::CoreML::Specification::NeuralNetworkMultiArrayShapeMapping NeuralNetwor return static_cast< ::CoreML::Specification::NeuralNetworkMultiArrayShapeMapping >(arrayinputshapemapping_); } inline void NeuralNetworkClassifier::set_arrayinputshapemapping(::CoreML::Specification::NeuralNetworkMultiArrayShapeMapping value) { - + arrayinputshapemapping_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.NeuralNetworkClassifier.arrayInputShapeMapping) } @@ -41348,7 +41351,7 @@ inline ::CoreML::Specification::NeuralNetworkImageShapeMapping NeuralNetworkClas return static_cast< ::CoreML::Specification::NeuralNetworkImageShapeMapping >(imageinputshapemapping_); } inline void NeuralNetworkClassifier::set_imageinputshapemapping(::CoreML::Specification::NeuralNetworkImageShapeMapping value) { - + imageinputshapemapping_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.NeuralNetworkClassifier.imageInputShapeMapping) } @@ -41367,7 +41370,7 @@ inline const ::CoreML::Specification::NetworkUpdateParameters& NeuralNetworkClas : *::CoreML::Specification::NetworkUpdateParameters::internal_default_instance(); } inline ::CoreML::Specification::NetworkUpdateParameters* NeuralNetworkClassifier::mutable_updateparams() { - + if (updateparams_ == NULL) { updateparams_ = new ::CoreML::Specification::NetworkUpdateParameters; } @@ -41376,7 +41379,7 @@ inline ::CoreML::Specification::NetworkUpdateParameters* NeuralNetworkClassifier } inline ::CoreML::Specification::NetworkUpdateParameters* NeuralNetworkClassifier::release_updateparams() { // @@protoc_insertion_point(field_release:CoreML.Specification.NeuralNetworkClassifier.updateParams) - + ::CoreML::Specification::NetworkUpdateParameters* temp = updateparams_; updateparams_ = NULL; return temp; @@ -41385,9 +41388,9 @@ inline void NeuralNetworkClassifier::set_allocated_updateparams(::CoreML::Specif delete updateparams_; updateparams_ = updateparams; if (updateparams) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.NeuralNetworkClassifier.updateParams) } @@ -41497,13 +41500,13 @@ inline const ::std::string& NeuralNetworkClassifier::labelprobabilitylayername() return labelprobabilitylayername_.GetNoArena(); } inline void NeuralNetworkClassifier::set_labelprobabilitylayername(const ::std::string& value) { - + labelprobabilitylayername_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.NeuralNetworkClassifier.labelProbabilityLayerName) } #if LANG_CXX11 inline void NeuralNetworkClassifier::set_labelprobabilitylayername(::std::string&& value) { - + labelprobabilitylayername_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.NeuralNetworkClassifier.labelProbabilityLayerName) @@ -41511,31 +41514,31 @@ inline void NeuralNetworkClassifier::set_labelprobabilitylayername(::std::string #endif inline void NeuralNetworkClassifier::set_labelprobabilitylayername(const char* value) { GOOGLE_DCHECK(value != NULL); - + labelprobabilitylayername_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.NeuralNetworkClassifier.labelProbabilityLayerName) } inline void NeuralNetworkClassifier::set_labelprobabilitylayername(const char* value, size_t size) { - + labelprobabilitylayername_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.NeuralNetworkClassifier.labelProbabilityLayerName) } inline ::std::string* NeuralNetworkClassifier::mutable_labelprobabilitylayername() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.NeuralNetworkClassifier.labelProbabilityLayerName) return labelprobabilitylayername_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* NeuralNetworkClassifier::release_labelprobabilitylayername() { // @@protoc_insertion_point(field_release:CoreML.Specification.NeuralNetworkClassifier.labelProbabilityLayerName) - + return labelprobabilitylayername_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void NeuralNetworkClassifier::set_allocated_labelprobabilitylayername(::std::string* labelprobabilitylayername) { if (labelprobabilitylayername != NULL) { - + } else { - + } labelprobabilitylayername_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), labelprobabilitylayername); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.NeuralNetworkClassifier.labelProbabilityLayerName) @@ -41563,7 +41566,7 @@ inline ::google::protobuf::uint64 OneHotLayerParams::onehotvectorsize() const { return onehotvectorsize_; } inline void OneHotLayerParams::set_onehotvectorsize(::google::protobuf::uint64 value) { - + onehotvectorsize_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.OneHotLayerParams.oneHotVectorSize) } @@ -41577,7 +41580,7 @@ inline ::google::protobuf::int64 OneHotLayerParams::axis() const { return axis_; } inline void OneHotLayerParams::set_axis(::google::protobuf::int64 value) { - + axis_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.OneHotLayerParams.axis) } @@ -41591,7 +41594,7 @@ inline float OneHotLayerParams::onvalue() const { return onvalue_; } inline void OneHotLayerParams::set_onvalue(float value) { - + onvalue_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.OneHotLayerParams.onValue) } @@ -41605,7 +41608,7 @@ inline float OneHotLayerParams::offvalue() const { return offvalue_; } inline void OneHotLayerParams::set_offvalue(float value) { - + offvalue_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.OneHotLayerParams.offValue) } @@ -41623,7 +41626,7 @@ inline ::google::protobuf::int64 CumSumLayerParams::axis() const { return axis_; } inline void CumSumLayerParams::set_axis(::google::protobuf::int64 value) { - + axis_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.CumSumLayerParams.axis) } @@ -41637,7 +41640,7 @@ inline bool CumSumLayerParams::excludefinalsum() const { return excludefinalsum_; } inline void CumSumLayerParams::set_excludefinalsum(bool value) { - + excludefinalsum_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.CumSumLayerParams.excludeFinalSum) } @@ -41651,7 +41654,7 @@ inline bool CumSumLayerParams::reverse() const { return reverse_; } inline void CumSumLayerParams::set_reverse(bool value) { - + reverse_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.CumSumLayerParams.reverse) } @@ -41729,7 +41732,7 @@ inline ::CoreML::Specification::NeuralNetworkMultiArrayShapeMapping NeuralNetwor return static_cast< ::CoreML::Specification::NeuralNetworkMultiArrayShapeMapping >(arrayinputshapemapping_); } inline void NeuralNetworkRegressor::set_arrayinputshapemapping(::CoreML::Specification::NeuralNetworkMultiArrayShapeMapping value) { - + arrayinputshapemapping_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.NeuralNetworkRegressor.arrayInputShapeMapping) } @@ -41743,7 +41746,7 @@ inline ::CoreML::Specification::NeuralNetworkImageShapeMapping NeuralNetworkRegr return static_cast< ::CoreML::Specification::NeuralNetworkImageShapeMapping >(imageinputshapemapping_); } inline void NeuralNetworkRegressor::set_imageinputshapemapping(::CoreML::Specification::NeuralNetworkImageShapeMapping value) { - + imageinputshapemapping_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.NeuralNetworkRegressor.imageInputShapeMapping) } @@ -41762,7 +41765,7 @@ inline const ::CoreML::Specification::NetworkUpdateParameters& NeuralNetworkRegr : *::CoreML::Specification::NetworkUpdateParameters::internal_default_instance(); } inline ::CoreML::Specification::NetworkUpdateParameters* NeuralNetworkRegressor::mutable_updateparams() { - + if (updateparams_ == NULL) { updateparams_ = new ::CoreML::Specification::NetworkUpdateParameters; } @@ -41771,7 +41774,7 @@ inline ::CoreML::Specification::NetworkUpdateParameters* NeuralNetworkRegressor: } inline ::CoreML::Specification::NetworkUpdateParameters* NeuralNetworkRegressor::release_updateparams() { // @@protoc_insertion_point(field_release:CoreML.Specification.NeuralNetworkRegressor.updateParams) - + ::CoreML::Specification::NetworkUpdateParameters* temp = updateparams_; updateparams_ = NULL; return temp; @@ -41780,9 +41783,9 @@ inline void NeuralNetworkRegressor::set_allocated_updateparams(::CoreML::Specifi delete updateparams_; updateparams_ = updateparams; if (updateparams) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.NeuralNetworkRegressor.updateParams) } @@ -41835,7 +41838,7 @@ inline const ::CoreML::Specification::Optimizer& NetworkUpdateParameters::optimi : *::CoreML::Specification::Optimizer::internal_default_instance(); } inline ::CoreML::Specification::Optimizer* NetworkUpdateParameters::mutable_optimizer() { - + if (optimizer_ == NULL) { optimizer_ = new ::CoreML::Specification::Optimizer; } @@ -41844,7 +41847,7 @@ inline ::CoreML::Specification::Optimizer* NetworkUpdateParameters::mutable_opti } inline ::CoreML::Specification::Optimizer* NetworkUpdateParameters::release_optimizer() { // @@protoc_insertion_point(field_release:CoreML.Specification.NetworkUpdateParameters.optimizer) - + ::CoreML::Specification::Optimizer* temp = optimizer_; optimizer_ = NULL; return temp; @@ -41853,9 +41856,9 @@ inline void NetworkUpdateParameters::set_allocated_optimizer(::CoreML::Specifica delete optimizer_; optimizer_ = optimizer; if (optimizer) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.NetworkUpdateParameters.optimizer) } @@ -41874,7 +41877,7 @@ inline const ::CoreML::Specification::Int64Parameter& NetworkUpdateParameters::e : *::CoreML::Specification::Int64Parameter::internal_default_instance(); } inline ::CoreML::Specification::Int64Parameter* NetworkUpdateParameters::mutable_epochs() { - + if (epochs_ == NULL) { epochs_ = new ::CoreML::Specification::Int64Parameter; } @@ -41883,7 +41886,7 @@ inline ::CoreML::Specification::Int64Parameter* NetworkUpdateParameters::mutable } inline ::CoreML::Specification::Int64Parameter* NetworkUpdateParameters::release_epochs() { // @@protoc_insertion_point(field_release:CoreML.Specification.NetworkUpdateParameters.epochs) - + ::CoreML::Specification::Int64Parameter* temp = epochs_; epochs_ = NULL; return temp; @@ -41892,9 +41895,9 @@ inline void NetworkUpdateParameters::set_allocated_epochs(::CoreML::Specificatio delete epochs_; epochs_ = epochs; if (epochs) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.NetworkUpdateParameters.epochs) } @@ -41913,7 +41916,7 @@ inline const ::CoreML::Specification::BoolParameter& NetworkUpdateParameters::sh : *::CoreML::Specification::BoolParameter::internal_default_instance(); } inline ::CoreML::Specification::BoolParameter* NetworkUpdateParameters::mutable_shuffle() { - + if (shuffle_ == NULL) { shuffle_ = new ::CoreML::Specification::BoolParameter; } @@ -41922,7 +41925,7 @@ inline ::CoreML::Specification::BoolParameter* NetworkUpdateParameters::mutable_ } inline ::CoreML::Specification::BoolParameter* NetworkUpdateParameters::release_shuffle() { // @@protoc_insertion_point(field_release:CoreML.Specification.NetworkUpdateParameters.shuffle) - + ::CoreML::Specification::BoolParameter* temp = shuffle_; shuffle_ = NULL; return temp; @@ -41931,9 +41934,9 @@ inline void NetworkUpdateParameters::set_allocated_shuffle(::CoreML::Specificati delete shuffle_; shuffle_ = shuffle; if (shuffle) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.NetworkUpdateParameters.shuffle) } @@ -41952,7 +41955,7 @@ inline const ::CoreML::Specification::Int64Parameter& NetworkUpdateParameters::s : *::CoreML::Specification::Int64Parameter::internal_default_instance(); } inline ::CoreML::Specification::Int64Parameter* NetworkUpdateParameters::mutable_seed() { - + if (seed_ == NULL) { seed_ = new ::CoreML::Specification::Int64Parameter; } @@ -41961,7 +41964,7 @@ inline ::CoreML::Specification::Int64Parameter* NetworkUpdateParameters::mutable } inline ::CoreML::Specification::Int64Parameter* NetworkUpdateParameters::release_seed() { // @@protoc_insertion_point(field_release:CoreML.Specification.NetworkUpdateParameters.seed) - + ::CoreML::Specification::Int64Parameter* temp = seed_; seed_ = NULL; return temp; @@ -41970,9 +41973,9 @@ inline void NetworkUpdateParameters::set_allocated_seed(::CoreML::Specification: delete seed_; seed_ = seed; if (seed) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.NetworkUpdateParameters.seed) } @@ -41990,13 +41993,13 @@ inline const ::std::string& LossLayer::name() const { return name_.GetNoArena(); } inline void LossLayer::set_name(const ::std::string& value) { - + name_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.LossLayer.name) } #if LANG_CXX11 inline void LossLayer::set_name(::std::string&& value) { - + name_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.LossLayer.name) @@ -42004,31 +42007,31 @@ inline void LossLayer::set_name(::std::string&& value) { #endif inline void LossLayer::set_name(const char* value) { GOOGLE_DCHECK(value != NULL); - + name_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.LossLayer.name) } inline void LossLayer::set_name(const char* value, size_t size) { - + name_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.LossLayer.name) } inline ::std::string* LossLayer::mutable_name() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.LossLayer.name) return name_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* LossLayer::release_name() { // @@protoc_insertion_point(field_release:CoreML.Specification.LossLayer.name) - + return name_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void LossLayer::set_allocated_name(::std::string* name) { if (name != NULL) { - + } else { - + } name_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), name); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.LossLayer.name) @@ -42152,13 +42155,13 @@ inline const ::std::string& CategoricalCrossEntropyLossLayer::input() const { return input_.GetNoArena(); } inline void CategoricalCrossEntropyLossLayer::set_input(const ::std::string& value) { - + input_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.CategoricalCrossEntropyLossLayer.input) } #if LANG_CXX11 inline void CategoricalCrossEntropyLossLayer::set_input(::std::string&& value) { - + input_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.CategoricalCrossEntropyLossLayer.input) @@ -42166,31 +42169,31 @@ inline void CategoricalCrossEntropyLossLayer::set_input(::std::string&& value) { #endif inline void CategoricalCrossEntropyLossLayer::set_input(const char* value) { GOOGLE_DCHECK(value != NULL); - + input_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.CategoricalCrossEntropyLossLayer.input) } inline void CategoricalCrossEntropyLossLayer::set_input(const char* value, size_t size) { - + input_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.CategoricalCrossEntropyLossLayer.input) } inline ::std::string* CategoricalCrossEntropyLossLayer::mutable_input() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.CategoricalCrossEntropyLossLayer.input) return input_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* CategoricalCrossEntropyLossLayer::release_input() { // @@protoc_insertion_point(field_release:CoreML.Specification.CategoricalCrossEntropyLossLayer.input) - + return input_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void CategoricalCrossEntropyLossLayer::set_allocated_input(::std::string* input) { if (input != NULL) { - + } else { - + } input_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), input); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.CategoricalCrossEntropyLossLayer.input) @@ -42205,13 +42208,13 @@ inline const ::std::string& CategoricalCrossEntropyLossLayer::target() const { return target_.GetNoArena(); } inline void CategoricalCrossEntropyLossLayer::set_target(const ::std::string& value) { - + target_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.CategoricalCrossEntropyLossLayer.target) } #if LANG_CXX11 inline void CategoricalCrossEntropyLossLayer::set_target(::std::string&& value) { - + target_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.CategoricalCrossEntropyLossLayer.target) @@ -42219,31 +42222,31 @@ inline void CategoricalCrossEntropyLossLayer::set_target(::std::string&& value) #endif inline void CategoricalCrossEntropyLossLayer::set_target(const char* value) { GOOGLE_DCHECK(value != NULL); - + target_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.CategoricalCrossEntropyLossLayer.target) } inline void CategoricalCrossEntropyLossLayer::set_target(const char* value, size_t size) { - + target_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.CategoricalCrossEntropyLossLayer.target) } inline ::std::string* CategoricalCrossEntropyLossLayer::mutable_target() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.CategoricalCrossEntropyLossLayer.target) return target_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* CategoricalCrossEntropyLossLayer::release_target() { // @@protoc_insertion_point(field_release:CoreML.Specification.CategoricalCrossEntropyLossLayer.target) - + return target_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void CategoricalCrossEntropyLossLayer::set_allocated_target(::std::string* target) { if (target != NULL) { - + } else { - + } target_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), target); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.CategoricalCrossEntropyLossLayer.target) @@ -42262,13 +42265,13 @@ inline const ::std::string& MeanSquaredErrorLossLayer::input() const { return input_.GetNoArena(); } inline void MeanSquaredErrorLossLayer::set_input(const ::std::string& value) { - + input_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.MeanSquaredErrorLossLayer.input) } #if LANG_CXX11 inline void MeanSquaredErrorLossLayer::set_input(::std::string&& value) { - + input_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.MeanSquaredErrorLossLayer.input) @@ -42276,31 +42279,31 @@ inline void MeanSquaredErrorLossLayer::set_input(::std::string&& value) { #endif inline void MeanSquaredErrorLossLayer::set_input(const char* value) { GOOGLE_DCHECK(value != NULL); - + input_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.MeanSquaredErrorLossLayer.input) } inline void MeanSquaredErrorLossLayer::set_input(const char* value, size_t size) { - + input_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.MeanSquaredErrorLossLayer.input) } inline ::std::string* MeanSquaredErrorLossLayer::mutable_input() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.MeanSquaredErrorLossLayer.input) return input_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* MeanSquaredErrorLossLayer::release_input() { // @@protoc_insertion_point(field_release:CoreML.Specification.MeanSquaredErrorLossLayer.input) - + return input_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void MeanSquaredErrorLossLayer::set_allocated_input(::std::string* input) { if (input != NULL) { - + } else { - + } input_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), input); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MeanSquaredErrorLossLayer.input) @@ -42315,13 +42318,13 @@ inline const ::std::string& MeanSquaredErrorLossLayer::target() const { return target_.GetNoArena(); } inline void MeanSquaredErrorLossLayer::set_target(const ::std::string& value) { - + target_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.MeanSquaredErrorLossLayer.target) } #if LANG_CXX11 inline void MeanSquaredErrorLossLayer::set_target(::std::string&& value) { - + target_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.MeanSquaredErrorLossLayer.target) @@ -42329,31 +42332,31 @@ inline void MeanSquaredErrorLossLayer::set_target(::std::string&& value) { #endif inline void MeanSquaredErrorLossLayer::set_target(const char* value) { GOOGLE_DCHECK(value != NULL); - + target_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.MeanSquaredErrorLossLayer.target) } inline void MeanSquaredErrorLossLayer::set_target(const char* value, size_t size) { - + target_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.MeanSquaredErrorLossLayer.target) } inline ::std::string* MeanSquaredErrorLossLayer::mutable_target() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.MeanSquaredErrorLossLayer.target) return target_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* MeanSquaredErrorLossLayer::release_target() { // @@protoc_insertion_point(field_release:CoreML.Specification.MeanSquaredErrorLossLayer.target) - + return target_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void MeanSquaredErrorLossLayer::set_allocated_target(::std::string* target) { if (target != NULL) { - + } else { - + } target_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), target); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.MeanSquaredErrorLossLayer.target) @@ -42486,7 +42489,7 @@ inline const ::CoreML::Specification::DoubleParameter& SGDOptimizer::learningrat : *::CoreML::Specification::DoubleParameter::internal_default_instance(); } inline ::CoreML::Specification::DoubleParameter* SGDOptimizer::mutable_learningrate() { - + if (learningrate_ == NULL) { learningrate_ = new ::CoreML::Specification::DoubleParameter; } @@ -42495,7 +42498,7 @@ inline ::CoreML::Specification::DoubleParameter* SGDOptimizer::mutable_learningr } inline ::CoreML::Specification::DoubleParameter* SGDOptimizer::release_learningrate() { // @@protoc_insertion_point(field_release:CoreML.Specification.SGDOptimizer.learningRate) - + ::CoreML::Specification::DoubleParameter* temp = learningrate_; learningrate_ = NULL; return temp; @@ -42504,9 +42507,9 @@ inline void SGDOptimizer::set_allocated_learningrate(::CoreML::Specification::Do delete learningrate_; learningrate_ = learningrate; if (learningrate) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.SGDOptimizer.learningRate) } @@ -42525,7 +42528,7 @@ inline const ::CoreML::Specification::Int64Parameter& SGDOptimizer::minibatchsiz : *::CoreML::Specification::Int64Parameter::internal_default_instance(); } inline ::CoreML::Specification::Int64Parameter* SGDOptimizer::mutable_minibatchsize() { - + if (minibatchsize_ == NULL) { minibatchsize_ = new ::CoreML::Specification::Int64Parameter; } @@ -42534,7 +42537,7 @@ inline ::CoreML::Specification::Int64Parameter* SGDOptimizer::mutable_minibatchs } inline ::CoreML::Specification::Int64Parameter* SGDOptimizer::release_minibatchsize() { // @@protoc_insertion_point(field_release:CoreML.Specification.SGDOptimizer.miniBatchSize) - + ::CoreML::Specification::Int64Parameter* temp = minibatchsize_; minibatchsize_ = NULL; return temp; @@ -42543,9 +42546,9 @@ inline void SGDOptimizer::set_allocated_minibatchsize(::CoreML::Specification::I delete minibatchsize_; minibatchsize_ = minibatchsize; if (minibatchsize) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.SGDOptimizer.miniBatchSize) } @@ -42564,7 +42567,7 @@ inline const ::CoreML::Specification::DoubleParameter& SGDOptimizer::momentum() : *::CoreML::Specification::DoubleParameter::internal_default_instance(); } inline ::CoreML::Specification::DoubleParameter* SGDOptimizer::mutable_momentum() { - + if (momentum_ == NULL) { momentum_ = new ::CoreML::Specification::DoubleParameter; } @@ -42573,7 +42576,7 @@ inline ::CoreML::Specification::DoubleParameter* SGDOptimizer::mutable_momentum( } inline ::CoreML::Specification::DoubleParameter* SGDOptimizer::release_momentum() { // @@protoc_insertion_point(field_release:CoreML.Specification.SGDOptimizer.momentum) - + ::CoreML::Specification::DoubleParameter* temp = momentum_; momentum_ = NULL; return temp; @@ -42582,9 +42585,9 @@ inline void SGDOptimizer::set_allocated_momentum(::CoreML::Specification::Double delete momentum_; momentum_ = momentum; if (momentum) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.SGDOptimizer.momentum) } @@ -42607,7 +42610,7 @@ inline const ::CoreML::Specification::DoubleParameter& AdamOptimizer::learningra : *::CoreML::Specification::DoubleParameter::internal_default_instance(); } inline ::CoreML::Specification::DoubleParameter* AdamOptimizer::mutable_learningrate() { - + if (learningrate_ == NULL) { learningrate_ = new ::CoreML::Specification::DoubleParameter; } @@ -42616,7 +42619,7 @@ inline ::CoreML::Specification::DoubleParameter* AdamOptimizer::mutable_learning } inline ::CoreML::Specification::DoubleParameter* AdamOptimizer::release_learningrate() { // @@protoc_insertion_point(field_release:CoreML.Specification.AdamOptimizer.learningRate) - + ::CoreML::Specification::DoubleParameter* temp = learningrate_; learningrate_ = NULL; return temp; @@ -42625,9 +42628,9 @@ inline void AdamOptimizer::set_allocated_learningrate(::CoreML::Specification::D delete learningrate_; learningrate_ = learningrate; if (learningrate) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.AdamOptimizer.learningRate) } @@ -42646,7 +42649,7 @@ inline const ::CoreML::Specification::Int64Parameter& AdamOptimizer::minibatchsi : *::CoreML::Specification::Int64Parameter::internal_default_instance(); } inline ::CoreML::Specification::Int64Parameter* AdamOptimizer::mutable_minibatchsize() { - + if (minibatchsize_ == NULL) { minibatchsize_ = new ::CoreML::Specification::Int64Parameter; } @@ -42655,7 +42658,7 @@ inline ::CoreML::Specification::Int64Parameter* AdamOptimizer::mutable_minibatch } inline ::CoreML::Specification::Int64Parameter* AdamOptimizer::release_minibatchsize() { // @@protoc_insertion_point(field_release:CoreML.Specification.AdamOptimizer.miniBatchSize) - + ::CoreML::Specification::Int64Parameter* temp = minibatchsize_; minibatchsize_ = NULL; return temp; @@ -42664,9 +42667,9 @@ inline void AdamOptimizer::set_allocated_minibatchsize(::CoreML::Specification:: delete minibatchsize_; minibatchsize_ = minibatchsize; if (minibatchsize) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.AdamOptimizer.miniBatchSize) } @@ -42685,7 +42688,7 @@ inline const ::CoreML::Specification::DoubleParameter& AdamOptimizer::beta1() co : *::CoreML::Specification::DoubleParameter::internal_default_instance(); } inline ::CoreML::Specification::DoubleParameter* AdamOptimizer::mutable_beta1() { - + if (beta1_ == NULL) { beta1_ = new ::CoreML::Specification::DoubleParameter; } @@ -42694,7 +42697,7 @@ inline ::CoreML::Specification::DoubleParameter* AdamOptimizer::mutable_beta1() } inline ::CoreML::Specification::DoubleParameter* AdamOptimizer::release_beta1() { // @@protoc_insertion_point(field_release:CoreML.Specification.AdamOptimizer.beta1) - + ::CoreML::Specification::DoubleParameter* temp = beta1_; beta1_ = NULL; return temp; @@ -42703,9 +42706,9 @@ inline void AdamOptimizer::set_allocated_beta1(::CoreML::Specification::DoublePa delete beta1_; beta1_ = beta1; if (beta1) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.AdamOptimizer.beta1) } @@ -42724,7 +42727,7 @@ inline const ::CoreML::Specification::DoubleParameter& AdamOptimizer::beta2() co : *::CoreML::Specification::DoubleParameter::internal_default_instance(); } inline ::CoreML::Specification::DoubleParameter* AdamOptimizer::mutable_beta2() { - + if (beta2_ == NULL) { beta2_ = new ::CoreML::Specification::DoubleParameter; } @@ -42733,7 +42736,7 @@ inline ::CoreML::Specification::DoubleParameter* AdamOptimizer::mutable_beta2() } inline ::CoreML::Specification::DoubleParameter* AdamOptimizer::release_beta2() { // @@protoc_insertion_point(field_release:CoreML.Specification.AdamOptimizer.beta2) - + ::CoreML::Specification::DoubleParameter* temp = beta2_; beta2_ = NULL; return temp; @@ -42742,9 +42745,9 @@ inline void AdamOptimizer::set_allocated_beta2(::CoreML::Specification::DoublePa delete beta2_; beta2_ = beta2; if (beta2) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.AdamOptimizer.beta2) } @@ -42763,7 +42766,7 @@ inline const ::CoreML::Specification::DoubleParameter& AdamOptimizer::eps() cons : *::CoreML::Specification::DoubleParameter::internal_default_instance(); } inline ::CoreML::Specification::DoubleParameter* AdamOptimizer::mutable_eps() { - + if (eps_ == NULL) { eps_ = new ::CoreML::Specification::DoubleParameter; } @@ -42772,7 +42775,7 @@ inline ::CoreML::Specification::DoubleParameter* AdamOptimizer::mutable_eps() { } inline ::CoreML::Specification::DoubleParameter* AdamOptimizer::release_eps() { // @@protoc_insertion_point(field_release:CoreML.Specification.AdamOptimizer.eps) - + ::CoreML::Specification::DoubleParameter* temp = eps_; eps_ = NULL; return temp; @@ -42781,9 +42784,9 @@ inline void AdamOptimizer::set_allocated_eps(::CoreML::Specification::DoublePara delete eps_; eps_ = eps; if (eps) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.AdamOptimizer.eps) } diff --git a/mlmodel/build/format/NonMaximumSuppression.pb.h b/mlmodel/build/format/NonMaximumSuppression.pb.h index 30b7d5e73..3f53bd1e8 100644 --- a/mlmodel/build/format/NonMaximumSuppression.pb.h +++ b/mlmodel/build/format/NonMaximumSuppression.pb.h @@ -113,6 +113,9 @@ extern SequenceFeatureTypeDefaultTypeInternal _SequenceFeatureType_default_insta class SizeRange; class SizeRangeDefaultTypeInternal; extern SizeRangeDefaultTypeInternal _SizeRange_default_instance_; +class StateFeatureType; +class StateFeatureTypeDefaultTypeInternal; +extern StateFeatureTypeDefaultTypeInternal _StateFeatureType_default_instance_; class StringFeatureType; class StringFeatureTypeDefaultTypeInternal; extern StringFeatureTypeDefaultTypeInternal _StringFeatureType_default_instance_; @@ -485,7 +488,7 @@ inline bool NonMaximumSuppression_PickTop::perclass() const { return perclass_; } inline void NonMaximumSuppression_PickTop::set_perclass(bool value) { - + perclass_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.NonMaximumSuppression.PickTop.perClass) } @@ -647,7 +650,7 @@ inline double NonMaximumSuppression::iouthreshold() const { return iouthreshold_; } inline void NonMaximumSuppression::set_iouthreshold(double value) { - + iouthreshold_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.NonMaximumSuppression.iouThreshold) } @@ -661,7 +664,7 @@ inline double NonMaximumSuppression::confidencethreshold() const { return confidencethreshold_; } inline void NonMaximumSuppression::set_confidencethreshold(double value) { - + confidencethreshold_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.NonMaximumSuppression.confidenceThreshold) } @@ -675,13 +678,13 @@ inline const ::std::string& NonMaximumSuppression::confidenceinputfeaturename() return confidenceinputfeaturename_.GetNoArena(); } inline void NonMaximumSuppression::set_confidenceinputfeaturename(const ::std::string& value) { - + confidenceinputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.NonMaximumSuppression.confidenceInputFeatureName) } #if LANG_CXX11 inline void NonMaximumSuppression::set_confidenceinputfeaturename(::std::string&& value) { - + confidenceinputfeaturename_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.NonMaximumSuppression.confidenceInputFeatureName) @@ -689,31 +692,31 @@ inline void NonMaximumSuppression::set_confidenceinputfeaturename(::std::string& #endif inline void NonMaximumSuppression::set_confidenceinputfeaturename(const char* value) { GOOGLE_DCHECK(value != NULL); - + confidenceinputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.NonMaximumSuppression.confidenceInputFeatureName) } inline void NonMaximumSuppression::set_confidenceinputfeaturename(const char* value, size_t size) { - + confidenceinputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.NonMaximumSuppression.confidenceInputFeatureName) } inline ::std::string* NonMaximumSuppression::mutable_confidenceinputfeaturename() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.NonMaximumSuppression.confidenceInputFeatureName) return confidenceinputfeaturename_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* NonMaximumSuppression::release_confidenceinputfeaturename() { // @@protoc_insertion_point(field_release:CoreML.Specification.NonMaximumSuppression.confidenceInputFeatureName) - + return confidenceinputfeaturename_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void NonMaximumSuppression::set_allocated_confidenceinputfeaturename(::std::string* confidenceinputfeaturename) { if (confidenceinputfeaturename != NULL) { - + } else { - + } confidenceinputfeaturename_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), confidenceinputfeaturename); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.NonMaximumSuppression.confidenceInputFeatureName) @@ -728,13 +731,13 @@ inline const ::std::string& NonMaximumSuppression::coordinatesinputfeaturename() return coordinatesinputfeaturename_.GetNoArena(); } inline void NonMaximumSuppression::set_coordinatesinputfeaturename(const ::std::string& value) { - + coordinatesinputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.NonMaximumSuppression.coordinatesInputFeatureName) } #if LANG_CXX11 inline void NonMaximumSuppression::set_coordinatesinputfeaturename(::std::string&& value) { - + coordinatesinputfeaturename_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.NonMaximumSuppression.coordinatesInputFeatureName) @@ -742,31 +745,31 @@ inline void NonMaximumSuppression::set_coordinatesinputfeaturename(::std::string #endif inline void NonMaximumSuppression::set_coordinatesinputfeaturename(const char* value) { GOOGLE_DCHECK(value != NULL); - + coordinatesinputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.NonMaximumSuppression.coordinatesInputFeatureName) } inline void NonMaximumSuppression::set_coordinatesinputfeaturename(const char* value, size_t size) { - + coordinatesinputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.NonMaximumSuppression.coordinatesInputFeatureName) } inline ::std::string* NonMaximumSuppression::mutable_coordinatesinputfeaturename() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.NonMaximumSuppression.coordinatesInputFeatureName) return coordinatesinputfeaturename_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* NonMaximumSuppression::release_coordinatesinputfeaturename() { // @@protoc_insertion_point(field_release:CoreML.Specification.NonMaximumSuppression.coordinatesInputFeatureName) - + return coordinatesinputfeaturename_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void NonMaximumSuppression::set_allocated_coordinatesinputfeaturename(::std::string* coordinatesinputfeaturename) { if (coordinatesinputfeaturename != NULL) { - + } else { - + } coordinatesinputfeaturename_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), coordinatesinputfeaturename); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.NonMaximumSuppression.coordinatesInputFeatureName) @@ -781,13 +784,13 @@ inline const ::std::string& NonMaximumSuppression::iouthresholdinputfeaturename( return iouthresholdinputfeaturename_.GetNoArena(); } inline void NonMaximumSuppression::set_iouthresholdinputfeaturename(const ::std::string& value) { - + iouthresholdinputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.NonMaximumSuppression.iouThresholdInputFeatureName) } #if LANG_CXX11 inline void NonMaximumSuppression::set_iouthresholdinputfeaturename(::std::string&& value) { - + iouthresholdinputfeaturename_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.NonMaximumSuppression.iouThresholdInputFeatureName) @@ -795,31 +798,31 @@ inline void NonMaximumSuppression::set_iouthresholdinputfeaturename(::std::strin #endif inline void NonMaximumSuppression::set_iouthresholdinputfeaturename(const char* value) { GOOGLE_DCHECK(value != NULL); - + iouthresholdinputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.NonMaximumSuppression.iouThresholdInputFeatureName) } inline void NonMaximumSuppression::set_iouthresholdinputfeaturename(const char* value, size_t size) { - + iouthresholdinputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.NonMaximumSuppression.iouThresholdInputFeatureName) } inline ::std::string* NonMaximumSuppression::mutable_iouthresholdinputfeaturename() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.NonMaximumSuppression.iouThresholdInputFeatureName) return iouthresholdinputfeaturename_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* NonMaximumSuppression::release_iouthresholdinputfeaturename() { // @@protoc_insertion_point(field_release:CoreML.Specification.NonMaximumSuppression.iouThresholdInputFeatureName) - + return iouthresholdinputfeaturename_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void NonMaximumSuppression::set_allocated_iouthresholdinputfeaturename(::std::string* iouthresholdinputfeaturename) { if (iouthresholdinputfeaturename != NULL) { - + } else { - + } iouthresholdinputfeaturename_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), iouthresholdinputfeaturename); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.NonMaximumSuppression.iouThresholdInputFeatureName) @@ -834,13 +837,13 @@ inline const ::std::string& NonMaximumSuppression::confidencethresholdinputfeatu return confidencethresholdinputfeaturename_.GetNoArena(); } inline void NonMaximumSuppression::set_confidencethresholdinputfeaturename(const ::std::string& value) { - + confidencethresholdinputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.NonMaximumSuppression.confidenceThresholdInputFeatureName) } #if LANG_CXX11 inline void NonMaximumSuppression::set_confidencethresholdinputfeaturename(::std::string&& value) { - + confidencethresholdinputfeaturename_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.NonMaximumSuppression.confidenceThresholdInputFeatureName) @@ -848,31 +851,31 @@ inline void NonMaximumSuppression::set_confidencethresholdinputfeaturename(::std #endif inline void NonMaximumSuppression::set_confidencethresholdinputfeaturename(const char* value) { GOOGLE_DCHECK(value != NULL); - + confidencethresholdinputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.NonMaximumSuppression.confidenceThresholdInputFeatureName) } inline void NonMaximumSuppression::set_confidencethresholdinputfeaturename(const char* value, size_t size) { - + confidencethresholdinputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.NonMaximumSuppression.confidenceThresholdInputFeatureName) } inline ::std::string* NonMaximumSuppression::mutable_confidencethresholdinputfeaturename() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.NonMaximumSuppression.confidenceThresholdInputFeatureName) return confidencethresholdinputfeaturename_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* NonMaximumSuppression::release_confidencethresholdinputfeaturename() { // @@protoc_insertion_point(field_release:CoreML.Specification.NonMaximumSuppression.confidenceThresholdInputFeatureName) - + return confidencethresholdinputfeaturename_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void NonMaximumSuppression::set_allocated_confidencethresholdinputfeaturename(::std::string* confidencethresholdinputfeaturename) { if (confidencethresholdinputfeaturename != NULL) { - + } else { - + } confidencethresholdinputfeaturename_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), confidencethresholdinputfeaturename); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.NonMaximumSuppression.confidenceThresholdInputFeatureName) @@ -887,13 +890,13 @@ inline const ::std::string& NonMaximumSuppression::confidenceoutputfeaturename() return confidenceoutputfeaturename_.GetNoArena(); } inline void NonMaximumSuppression::set_confidenceoutputfeaturename(const ::std::string& value) { - + confidenceoutputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.NonMaximumSuppression.confidenceOutputFeatureName) } #if LANG_CXX11 inline void NonMaximumSuppression::set_confidenceoutputfeaturename(::std::string&& value) { - + confidenceoutputfeaturename_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.NonMaximumSuppression.confidenceOutputFeatureName) @@ -901,31 +904,31 @@ inline void NonMaximumSuppression::set_confidenceoutputfeaturename(::std::string #endif inline void NonMaximumSuppression::set_confidenceoutputfeaturename(const char* value) { GOOGLE_DCHECK(value != NULL); - + confidenceoutputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.NonMaximumSuppression.confidenceOutputFeatureName) } inline void NonMaximumSuppression::set_confidenceoutputfeaturename(const char* value, size_t size) { - + confidenceoutputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.NonMaximumSuppression.confidenceOutputFeatureName) } inline ::std::string* NonMaximumSuppression::mutable_confidenceoutputfeaturename() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.NonMaximumSuppression.confidenceOutputFeatureName) return confidenceoutputfeaturename_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* NonMaximumSuppression::release_confidenceoutputfeaturename() { // @@protoc_insertion_point(field_release:CoreML.Specification.NonMaximumSuppression.confidenceOutputFeatureName) - + return confidenceoutputfeaturename_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void NonMaximumSuppression::set_allocated_confidenceoutputfeaturename(::std::string* confidenceoutputfeaturename) { if (confidenceoutputfeaturename != NULL) { - + } else { - + } confidenceoutputfeaturename_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), confidenceoutputfeaturename); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.NonMaximumSuppression.confidenceOutputFeatureName) @@ -940,13 +943,13 @@ inline const ::std::string& NonMaximumSuppression::coordinatesoutputfeaturename( return coordinatesoutputfeaturename_.GetNoArena(); } inline void NonMaximumSuppression::set_coordinatesoutputfeaturename(const ::std::string& value) { - + coordinatesoutputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.NonMaximumSuppression.coordinatesOutputFeatureName) } #if LANG_CXX11 inline void NonMaximumSuppression::set_coordinatesoutputfeaturename(::std::string&& value) { - + coordinatesoutputfeaturename_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.NonMaximumSuppression.coordinatesOutputFeatureName) @@ -954,31 +957,31 @@ inline void NonMaximumSuppression::set_coordinatesoutputfeaturename(::std::strin #endif inline void NonMaximumSuppression::set_coordinatesoutputfeaturename(const char* value) { GOOGLE_DCHECK(value != NULL); - + coordinatesoutputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.NonMaximumSuppression.coordinatesOutputFeatureName) } inline void NonMaximumSuppression::set_coordinatesoutputfeaturename(const char* value, size_t size) { - + coordinatesoutputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.NonMaximumSuppression.coordinatesOutputFeatureName) } inline ::std::string* NonMaximumSuppression::mutable_coordinatesoutputfeaturename() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.NonMaximumSuppression.coordinatesOutputFeatureName) return coordinatesoutputfeaturename_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* NonMaximumSuppression::release_coordinatesoutputfeaturename() { // @@protoc_insertion_point(field_release:CoreML.Specification.NonMaximumSuppression.coordinatesOutputFeatureName) - + return coordinatesoutputfeaturename_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void NonMaximumSuppression::set_allocated_coordinatesoutputfeaturename(::std::string* coordinatesoutputfeaturename) { if (coordinatesoutputfeaturename != NULL) { - + } else { - + } coordinatesoutputfeaturename_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), coordinatesoutputfeaturename); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.NonMaximumSuppression.coordinatesOutputFeatureName) diff --git a/mlmodel/build/format/OneHotEncoder.pb.h b/mlmodel/build/format/OneHotEncoder.pb.h index 2986e1748..1d4534f20 100644 --- a/mlmodel/build/format/OneHotEncoder.pb.h +++ b/mlmodel/build/format/OneHotEncoder.pb.h @@ -111,6 +111,9 @@ extern SequenceFeatureTypeDefaultTypeInternal _SequenceFeatureType_default_insta class SizeRange; class SizeRangeDefaultTypeInternal; extern SizeRangeDefaultTypeInternal _SizeRange_default_instance_; +class StateFeatureType; +class StateFeatureTypeDefaultTypeInternal; +extern StateFeatureTypeDefaultTypeInternal _StateFeatureType_default_instance_; class StringFeatureType; class StringFeatureTypeDefaultTypeInternal; extern StringFeatureTypeDefaultTypeInternal _StringFeatureType_default_instance_; @@ -411,7 +414,7 @@ inline bool OneHotEncoder::outputsparse() const { return outputsparse_; } inline void OneHotEncoder::set_outputsparse(bool value) { - + outputsparse_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.OneHotEncoder.outputSparse) } @@ -425,7 +428,7 @@ inline ::CoreML::Specification::OneHotEncoder_HandleUnknown OneHotEncoder::handl return static_cast< ::CoreML::Specification::OneHotEncoder_HandleUnknown >(handleunknown_); } inline void OneHotEncoder::set_handleunknown(::CoreML::Specification::OneHotEncoder_HandleUnknown value) { - + handleunknown_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.OneHotEncoder.handleUnknown) } diff --git a/mlmodel/build/format/Parameters.pb.h b/mlmodel/build/format/Parameters.pb.h index 0e708b014..0fe4aa206 100644 --- a/mlmodel/build/format/Parameters.pb.h +++ b/mlmodel/build/format/Parameters.pb.h @@ -116,6 +116,9 @@ extern SequenceFeatureTypeDefaultTypeInternal _SequenceFeatureType_default_insta class SizeRange; class SizeRangeDefaultTypeInternal; extern SizeRangeDefaultTypeInternal _SizeRange_default_instance_; +class StateFeatureType; +class StateFeatureTypeDefaultTypeInternal; +extern StateFeatureTypeDefaultTypeInternal _StateFeatureType_default_instance_; class StringFeatureType; class StringFeatureTypeDefaultTypeInternal; extern StringFeatureTypeDefaultTypeInternal _StringFeatureType_default_instance_; @@ -558,7 +561,7 @@ inline ::google::protobuf::int64 Int64Parameter::defaultvalue() const { return defaultvalue_; } inline void Int64Parameter::set_defaultvalue(::google::protobuf::int64 value) { - + defaultvalue_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.Int64Parameter.defaultValue) } @@ -681,7 +684,7 @@ inline double DoubleParameter::defaultvalue() const { return defaultvalue_; } inline void DoubleParameter::set_defaultvalue(double value) { - + defaultvalue_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.DoubleParameter.defaultValue) } @@ -756,13 +759,13 @@ inline const ::std::string& StringParameter::defaultvalue() const { return defaultvalue_.GetNoArena(); } inline void StringParameter::set_defaultvalue(const ::std::string& value) { - + defaultvalue_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.StringParameter.defaultValue) } #if LANG_CXX11 inline void StringParameter::set_defaultvalue(::std::string&& value) { - + defaultvalue_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.StringParameter.defaultValue) @@ -770,31 +773,31 @@ inline void StringParameter::set_defaultvalue(::std::string&& value) { #endif inline void StringParameter::set_defaultvalue(const char* value) { GOOGLE_DCHECK(value != NULL); - + defaultvalue_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.StringParameter.defaultValue) } inline void StringParameter::set_defaultvalue(const char* value, size_t size) { - + defaultvalue_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.StringParameter.defaultValue) } inline ::std::string* StringParameter::mutable_defaultvalue() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.StringParameter.defaultValue) return defaultvalue_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* StringParameter::release_defaultvalue() { // @@protoc_insertion_point(field_release:CoreML.Specification.StringParameter.defaultValue) - + return defaultvalue_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void StringParameter::set_allocated_defaultvalue(::std::string* defaultvalue) { if (defaultvalue != NULL) { - + } else { - + } defaultvalue_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), defaultvalue); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.StringParameter.defaultValue) @@ -813,7 +816,7 @@ inline bool BoolParameter::defaultvalue() const { return defaultvalue_; } inline void BoolParameter::set_defaultvalue(bool value) { - + defaultvalue_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.BoolParameter.defaultValue) } diff --git a/mlmodel/build/format/SVM.pb.h b/mlmodel/build/format/SVM.pb.h index 9be742495..1009309e6 100644 --- a/mlmodel/build/format/SVM.pb.h +++ b/mlmodel/build/format/SVM.pb.h @@ -140,6 +140,9 @@ extern SparseSupportVectorsDefaultTypeInternal _SparseSupportVectors_default_ins class SparseVector; class SparseVectorDefaultTypeInternal; extern SparseVectorDefaultTypeInternal _SparseVector_default_instance_; +class StateFeatureType; +class StateFeatureTypeDefaultTypeInternal; +extern StateFeatureTypeDefaultTypeInternal _StateFeatureType_default_instance_; class StringFeatureType; class StringFeatureTypeDefaultTypeInternal; extern StringFeatureTypeDefaultTypeInternal _StringFeatureType_default_instance_; @@ -1531,7 +1534,7 @@ inline double RBFKernel::gamma() const { return gamma_; } inline void RBFKernel::set_gamma(double value) { - + gamma_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.RBFKernel.gamma) } @@ -1549,7 +1552,7 @@ inline ::google::protobuf::int32 PolyKernel::degree() const { return degree_; } inline void PolyKernel::set_degree(::google::protobuf::int32 value) { - + degree_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.PolyKernel.degree) } @@ -1563,7 +1566,7 @@ inline double PolyKernel::c() const { return c_; } inline void PolyKernel::set_c(double value) { - + c_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.PolyKernel.c) } @@ -1577,7 +1580,7 @@ inline double PolyKernel::gamma() const { return gamma_; } inline void PolyKernel::set_gamma(double value) { - + gamma_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.PolyKernel.gamma) } @@ -1595,7 +1598,7 @@ inline double SigmoidKernel::gamma() const { return gamma_; } inline void SigmoidKernel::set_gamma(double value) { - + gamma_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SigmoidKernel.gamma) } @@ -1609,7 +1612,7 @@ inline double SigmoidKernel::c() const { return c_; } inline void SigmoidKernel::set_c(double value) { - + c_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SigmoidKernel.c) } @@ -1832,7 +1835,7 @@ inline ::google::protobuf::int32 SparseNode::index() const { return index_; } inline void SparseNode::set_index(::google::protobuf::int32 value) { - + index_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SparseNode.index) } @@ -1846,7 +1849,7 @@ inline double SparseNode::value() const { return value_; } inline void SparseNode::set_value(double value) { - + value_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SparseNode.value) } @@ -2039,7 +2042,7 @@ inline const ::CoreML::Specification::Kernel& SupportVectorRegressor::kernel() c : *::CoreML::Specification::Kernel::internal_default_instance(); } inline ::CoreML::Specification::Kernel* SupportVectorRegressor::mutable_kernel() { - + if (kernel_ == NULL) { kernel_ = new ::CoreML::Specification::Kernel; } @@ -2048,7 +2051,7 @@ inline ::CoreML::Specification::Kernel* SupportVectorRegressor::mutable_kernel() } inline ::CoreML::Specification::Kernel* SupportVectorRegressor::release_kernel() { // @@protoc_insertion_point(field_release:CoreML.Specification.SupportVectorRegressor.kernel) - + ::CoreML::Specification::Kernel* temp = kernel_; kernel_ = NULL; return temp; @@ -2057,9 +2060,9 @@ inline void SupportVectorRegressor::set_allocated_kernel(::CoreML::Specification delete kernel_; kernel_ = kernel; if (kernel) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.SupportVectorRegressor.kernel) } @@ -2174,7 +2177,7 @@ inline const ::CoreML::Specification::Coefficients& SupportVectorRegressor::coef : *::CoreML::Specification::Coefficients::internal_default_instance(); } inline ::CoreML::Specification::Coefficients* SupportVectorRegressor::mutable_coefficients() { - + if (coefficients_ == NULL) { coefficients_ = new ::CoreML::Specification::Coefficients; } @@ -2183,7 +2186,7 @@ inline ::CoreML::Specification::Coefficients* SupportVectorRegressor::mutable_co } inline ::CoreML::Specification::Coefficients* SupportVectorRegressor::release_coefficients() { // @@protoc_insertion_point(field_release:CoreML.Specification.SupportVectorRegressor.coefficients) - + ::CoreML::Specification::Coefficients* temp = coefficients_; coefficients_ = NULL; return temp; @@ -2192,9 +2195,9 @@ inline void SupportVectorRegressor::set_allocated_coefficients(::CoreML::Specifi delete coefficients_; coefficients_ = coefficients; if (coefficients) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.SupportVectorRegressor.coefficients) } @@ -2208,7 +2211,7 @@ inline double SupportVectorRegressor::rho() const { return rho_; } inline void SupportVectorRegressor::set_rho(double value) { - + rho_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.SupportVectorRegressor.rho) } @@ -2240,7 +2243,7 @@ inline const ::CoreML::Specification::Kernel& SupportVectorClassifier::kernel() : *::CoreML::Specification::Kernel::internal_default_instance(); } inline ::CoreML::Specification::Kernel* SupportVectorClassifier::mutable_kernel() { - + if (kernel_ == NULL) { kernel_ = new ::CoreML::Specification::Kernel; } @@ -2249,7 +2252,7 @@ inline ::CoreML::Specification::Kernel* SupportVectorClassifier::mutable_kernel( } inline ::CoreML::Specification::Kernel* SupportVectorClassifier::release_kernel() { // @@protoc_insertion_point(field_release:CoreML.Specification.SupportVectorClassifier.kernel) - + ::CoreML::Specification::Kernel* temp = kernel_; kernel_ = NULL; return temp; @@ -2258,9 +2261,9 @@ inline void SupportVectorClassifier::set_allocated_kernel(::CoreML::Specificatio delete kernel_; kernel_ = kernel; if (kernel) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.SupportVectorClassifier.kernel) } diff --git a/mlmodel/build/format/TextClassifier.pb.h b/mlmodel/build/format/TextClassifier.pb.h index 1634d956d..024d61de9 100644 --- a/mlmodel/build/format/TextClassifier.pb.h +++ b/mlmodel/build/format/TextClassifier.pb.h @@ -107,6 +107,9 @@ extern SequenceFeatureTypeDefaultTypeInternal _SequenceFeatureType_default_insta class SizeRange; class SizeRangeDefaultTypeInternal; extern SizeRangeDefaultTypeInternal _SizeRange_default_instance_; +class StateFeatureType; +class StateFeatureTypeDefaultTypeInternal; +extern StateFeatureTypeDefaultTypeInternal _StateFeatureType_default_instance_; class StringFeatureType; class StringFeatureTypeDefaultTypeInternal; extern StringFeatureTypeDefaultTypeInternal _StringFeatureType_default_instance_; @@ -302,7 +305,7 @@ inline ::google::protobuf::uint32 TextClassifier::revision() const { return revision_; } inline void TextClassifier::set_revision(::google::protobuf::uint32 value) { - + revision_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.CoreMLModels.TextClassifier.revision) } @@ -316,13 +319,13 @@ inline const ::std::string& TextClassifier::language() const { return language_.GetNoArena(); } inline void TextClassifier::set_language(const ::std::string& value) { - + language_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.CoreMLModels.TextClassifier.language) } #if LANG_CXX11 inline void TextClassifier::set_language(::std::string&& value) { - + language_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.CoreMLModels.TextClassifier.language) @@ -330,31 +333,31 @@ inline void TextClassifier::set_language(::std::string&& value) { #endif inline void TextClassifier::set_language(const char* value) { GOOGLE_DCHECK(value != NULL); - + language_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.CoreMLModels.TextClassifier.language) } inline void TextClassifier::set_language(const char* value, size_t size) { - + language_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.CoreMLModels.TextClassifier.language) } inline ::std::string* TextClassifier::mutable_language() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.CoreMLModels.TextClassifier.language) return language_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* TextClassifier::release_language() { // @@protoc_insertion_point(field_release:CoreML.Specification.CoreMLModels.TextClassifier.language) - + return language_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void TextClassifier::set_allocated_language(::std::string* language) { if (language != NULL) { - + } else { - + } language_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), language); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.CoreMLModels.TextClassifier.language) @@ -369,13 +372,13 @@ inline const ::std::string& TextClassifier::modelparameterdata() const { return modelparameterdata_.GetNoArena(); } inline void TextClassifier::set_modelparameterdata(const ::std::string& value) { - + modelparameterdata_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.CoreMLModels.TextClassifier.modelParameterData) } #if LANG_CXX11 inline void TextClassifier::set_modelparameterdata(::std::string&& value) { - + modelparameterdata_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.CoreMLModels.TextClassifier.modelParameterData) @@ -383,31 +386,31 @@ inline void TextClassifier::set_modelparameterdata(::std::string&& value) { #endif inline void TextClassifier::set_modelparameterdata(const char* value) { GOOGLE_DCHECK(value != NULL); - + modelparameterdata_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.CoreMLModels.TextClassifier.modelParameterData) } inline void TextClassifier::set_modelparameterdata(const void* value, size_t size) { - + modelparameterdata_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.CoreMLModels.TextClassifier.modelParameterData) } inline ::std::string* TextClassifier::mutable_modelparameterdata() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.CoreMLModels.TextClassifier.modelParameterData) return modelparameterdata_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* TextClassifier::release_modelparameterdata() { // @@protoc_insertion_point(field_release:CoreML.Specification.CoreMLModels.TextClassifier.modelParameterData) - + return modelparameterdata_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void TextClassifier::set_allocated_modelparameterdata(::std::string* modelparameterdata) { if (modelparameterdata != NULL) { - + } else { - + } modelparameterdata_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), modelparameterdata); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.CoreMLModels.TextClassifier.modelParameterData) diff --git a/mlmodel/build/format/TreeEnsemble.pb.h b/mlmodel/build/format/TreeEnsemble.pb.h index a4e647a7d..b81e38576 100644 --- a/mlmodel/build/format/TreeEnsemble.pb.h +++ b/mlmodel/build/format/TreeEnsemble.pb.h @@ -108,6 +108,9 @@ extern SequenceFeatureTypeDefaultTypeInternal _SequenceFeatureType_default_insta class SizeRange; class SizeRangeDefaultTypeInternal; extern SizeRangeDefaultTypeInternal _SizeRange_default_instance_; +class StateFeatureType; +class StateFeatureTypeDefaultTypeInternal; +extern StateFeatureTypeDefaultTypeInternal _StateFeatureType_default_instance_; class StringFeatureType; class StringFeatureTypeDefaultTypeInternal; extern StringFeatureTypeDefaultTypeInternal _StringFeatureType_default_instance_; @@ -787,7 +790,7 @@ inline ::google::protobuf::uint64 TreeEnsembleParameters_TreeNode_EvaluationInfo return evaluationindex_; } inline void TreeEnsembleParameters_TreeNode_EvaluationInfo::set_evaluationindex(::google::protobuf::uint64 value) { - + evaluationindex_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.TreeEnsembleParameters.TreeNode.EvaluationInfo.evaluationIndex) } @@ -801,7 +804,7 @@ inline double TreeEnsembleParameters_TreeNode_EvaluationInfo::evaluationvalue() return evaluationvalue_; } inline void TreeEnsembleParameters_TreeNode_EvaluationInfo::set_evaluationvalue(double value) { - + evaluationvalue_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.TreeEnsembleParameters.TreeNode.EvaluationInfo.evaluationValue) } @@ -819,7 +822,7 @@ inline ::google::protobuf::uint64 TreeEnsembleParameters_TreeNode::treeid() cons return treeid_; } inline void TreeEnsembleParameters_TreeNode::set_treeid(::google::protobuf::uint64 value) { - + treeid_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.TreeEnsembleParameters.TreeNode.treeId) } @@ -833,7 +836,7 @@ inline ::google::protobuf::uint64 TreeEnsembleParameters_TreeNode::nodeid() cons return nodeid_; } inline void TreeEnsembleParameters_TreeNode::set_nodeid(::google::protobuf::uint64 value) { - + nodeid_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.TreeEnsembleParameters.TreeNode.nodeId) } @@ -847,7 +850,7 @@ inline ::CoreML::Specification::TreeEnsembleParameters_TreeNode_TreeNodeBehavior return static_cast< ::CoreML::Specification::TreeEnsembleParameters_TreeNode_TreeNodeBehavior >(nodebehavior_); } inline void TreeEnsembleParameters_TreeNode::set_nodebehavior(::CoreML::Specification::TreeEnsembleParameters_TreeNode_TreeNodeBehavior value) { - + nodebehavior_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.TreeEnsembleParameters.TreeNode.nodeBehavior) } @@ -861,7 +864,7 @@ inline ::google::protobuf::uint64 TreeEnsembleParameters_TreeNode::branchfeature return branchfeatureindex_; } inline void TreeEnsembleParameters_TreeNode::set_branchfeatureindex(::google::protobuf::uint64 value) { - + branchfeatureindex_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.TreeEnsembleParameters.TreeNode.branchFeatureIndex) } @@ -875,7 +878,7 @@ inline double TreeEnsembleParameters_TreeNode::branchfeaturevalue() const { return branchfeaturevalue_; } inline void TreeEnsembleParameters_TreeNode::set_branchfeaturevalue(double value) { - + branchfeaturevalue_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.TreeEnsembleParameters.TreeNode.branchFeatureValue) } @@ -889,7 +892,7 @@ inline ::google::protobuf::uint64 TreeEnsembleParameters_TreeNode::truechildnode return truechildnodeid_; } inline void TreeEnsembleParameters_TreeNode::set_truechildnodeid(::google::protobuf::uint64 value) { - + truechildnodeid_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.TreeEnsembleParameters.TreeNode.trueChildNodeId) } @@ -903,7 +906,7 @@ inline ::google::protobuf::uint64 TreeEnsembleParameters_TreeNode::falsechildnod return falsechildnodeid_; } inline void TreeEnsembleParameters_TreeNode::set_falsechildnodeid(::google::protobuf::uint64 value) { - + falsechildnodeid_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.TreeEnsembleParameters.TreeNode.falseChildNodeId) } @@ -917,7 +920,7 @@ inline bool TreeEnsembleParameters_TreeNode::missingvaluetrackstruechild() const return missingvaluetrackstruechild_; } inline void TreeEnsembleParameters_TreeNode::set_missingvaluetrackstruechild(bool value) { - + missingvaluetrackstruechild_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.TreeEnsembleParameters.TreeNode.missingValueTracksTrueChild) } @@ -961,7 +964,7 @@ inline double TreeEnsembleParameters_TreeNode::relativehitrate() const { return relativehitrate_; } inline void TreeEnsembleParameters_TreeNode::set_relativehitrate(double value) { - + relativehitrate_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.TreeEnsembleParameters.TreeNode.relativeHitRate) } @@ -1009,7 +1012,7 @@ inline ::google::protobuf::uint64 TreeEnsembleParameters::numpredictiondimension return numpredictiondimensions_; } inline void TreeEnsembleParameters::set_numpredictiondimensions(::google::protobuf::uint64 value) { - + numpredictiondimensions_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.TreeEnsembleParameters.numPredictionDimensions) } @@ -1062,7 +1065,7 @@ inline const ::CoreML::Specification::TreeEnsembleParameters& TreeEnsembleClassi : *::CoreML::Specification::TreeEnsembleParameters::internal_default_instance(); } inline ::CoreML::Specification::TreeEnsembleParameters* TreeEnsembleClassifier::mutable_treeensemble() { - + if (treeensemble_ == NULL) { treeensemble_ = new ::CoreML::Specification::TreeEnsembleParameters; } @@ -1071,7 +1074,7 @@ inline ::CoreML::Specification::TreeEnsembleParameters* TreeEnsembleClassifier:: } inline ::CoreML::Specification::TreeEnsembleParameters* TreeEnsembleClassifier::release_treeensemble() { // @@protoc_insertion_point(field_release:CoreML.Specification.TreeEnsembleClassifier.treeEnsemble) - + ::CoreML::Specification::TreeEnsembleParameters* temp = treeensemble_; treeensemble_ = NULL; return temp; @@ -1080,9 +1083,9 @@ inline void TreeEnsembleClassifier::set_allocated_treeensemble(::CoreML::Specifi delete treeensemble_; treeensemble_ = treeensemble; if (treeensemble) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.TreeEnsembleClassifier.treeEnsemble) } @@ -1096,7 +1099,7 @@ inline ::CoreML::Specification::TreeEnsemblePostEvaluationTransform TreeEnsemble return static_cast< ::CoreML::Specification::TreeEnsemblePostEvaluationTransform >(postevaluationtransform_); } inline void TreeEnsembleClassifier::set_postevaluationtransform(::CoreML::Specification::TreeEnsemblePostEvaluationTransform value) { - + postevaluationtransform_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.TreeEnsembleClassifier.postEvaluationTransform) } @@ -1224,7 +1227,7 @@ inline const ::CoreML::Specification::TreeEnsembleParameters& TreeEnsembleRegres : *::CoreML::Specification::TreeEnsembleParameters::internal_default_instance(); } inline ::CoreML::Specification::TreeEnsembleParameters* TreeEnsembleRegressor::mutable_treeensemble() { - + if (treeensemble_ == NULL) { treeensemble_ = new ::CoreML::Specification::TreeEnsembleParameters; } @@ -1233,7 +1236,7 @@ inline ::CoreML::Specification::TreeEnsembleParameters* TreeEnsembleRegressor::m } inline ::CoreML::Specification::TreeEnsembleParameters* TreeEnsembleRegressor::release_treeensemble() { // @@protoc_insertion_point(field_release:CoreML.Specification.TreeEnsembleRegressor.treeEnsemble) - + ::CoreML::Specification::TreeEnsembleParameters* temp = treeensemble_; treeensemble_ = NULL; return temp; @@ -1242,9 +1245,9 @@ inline void TreeEnsembleRegressor::set_allocated_treeensemble(::CoreML::Specific delete treeensemble_; treeensemble_ = treeensemble; if (treeensemble) { - + } else { - + } // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.TreeEnsembleRegressor.treeEnsemble) } @@ -1258,7 +1261,7 @@ inline ::CoreML::Specification::TreeEnsemblePostEvaluationTransform TreeEnsemble return static_cast< ::CoreML::Specification::TreeEnsemblePostEvaluationTransform >(postevaluationtransform_); } inline void TreeEnsembleRegressor::set_postevaluationtransform(::CoreML::Specification::TreeEnsemblePostEvaluationTransform value) { - + postevaluationtransform_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.TreeEnsembleRegressor.postEvaluationTransform) } diff --git a/mlmodel/build/format/WordEmbedding.pb.h b/mlmodel/build/format/WordEmbedding.pb.h index e9a3e74ce..bb7383442 100644 --- a/mlmodel/build/format/WordEmbedding.pb.h +++ b/mlmodel/build/format/WordEmbedding.pb.h @@ -107,6 +107,9 @@ extern SequenceFeatureTypeDefaultTypeInternal _SequenceFeatureType_default_insta class SizeRange; class SizeRangeDefaultTypeInternal; extern SizeRangeDefaultTypeInternal _SizeRange_default_instance_; +class StateFeatureType; +class StateFeatureTypeDefaultTypeInternal; +extern StateFeatureTypeDefaultTypeInternal _StateFeatureType_default_instance_; class StringFeatureType; class StringFeatureTypeDefaultTypeInternal; extern StringFeatureTypeDefaultTypeInternal _StringFeatureType_default_instance_; @@ -276,7 +279,7 @@ inline ::google::protobuf::uint32 WordEmbedding::revision() const { return revision_; } inline void WordEmbedding::set_revision(::google::protobuf::uint32 value) { - + revision_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.CoreMLModels.WordEmbedding.revision) } @@ -290,13 +293,13 @@ inline const ::std::string& WordEmbedding::language() const { return language_.GetNoArena(); } inline void WordEmbedding::set_language(const ::std::string& value) { - + language_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.CoreMLModels.WordEmbedding.language) } #if LANG_CXX11 inline void WordEmbedding::set_language(::std::string&& value) { - + language_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.CoreMLModels.WordEmbedding.language) @@ -304,31 +307,31 @@ inline void WordEmbedding::set_language(::std::string&& value) { #endif inline void WordEmbedding::set_language(const char* value) { GOOGLE_DCHECK(value != NULL); - + language_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.CoreMLModels.WordEmbedding.language) } inline void WordEmbedding::set_language(const char* value, size_t size) { - + language_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.CoreMLModels.WordEmbedding.language) } inline ::std::string* WordEmbedding::mutable_language() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.CoreMLModels.WordEmbedding.language) return language_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* WordEmbedding::release_language() { // @@protoc_insertion_point(field_release:CoreML.Specification.CoreMLModels.WordEmbedding.language) - + return language_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void WordEmbedding::set_allocated_language(::std::string* language) { if (language != NULL) { - + } else { - + } language_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), language); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.CoreMLModels.WordEmbedding.language) @@ -343,13 +346,13 @@ inline const ::std::string& WordEmbedding::modelparameterdata() const { return modelparameterdata_.GetNoArena(); } inline void WordEmbedding::set_modelparameterdata(const ::std::string& value) { - + modelparameterdata_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.CoreMLModels.WordEmbedding.modelParameterData) } #if LANG_CXX11 inline void WordEmbedding::set_modelparameterdata(::std::string&& value) { - + modelparameterdata_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.CoreMLModels.WordEmbedding.modelParameterData) @@ -357,31 +360,31 @@ inline void WordEmbedding::set_modelparameterdata(::std::string&& value) { #endif inline void WordEmbedding::set_modelparameterdata(const char* value) { GOOGLE_DCHECK(value != NULL); - + modelparameterdata_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.CoreMLModels.WordEmbedding.modelParameterData) } inline void WordEmbedding::set_modelparameterdata(const void* value, size_t size) { - + modelparameterdata_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.CoreMLModels.WordEmbedding.modelParameterData) } inline ::std::string* WordEmbedding::mutable_modelparameterdata() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.CoreMLModels.WordEmbedding.modelParameterData) return modelparameterdata_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* WordEmbedding::release_modelparameterdata() { // @@protoc_insertion_point(field_release:CoreML.Specification.CoreMLModels.WordEmbedding.modelParameterData) - + return modelparameterdata_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void WordEmbedding::set_allocated_modelparameterdata(::std::string* modelparameterdata) { if (modelparameterdata != NULL) { - + } else { - + } modelparameterdata_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), modelparameterdata); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.CoreMLModels.WordEmbedding.modelParameterData) diff --git a/mlmodel/build/format/WordTagger.pb.h b/mlmodel/build/format/WordTagger.pb.h index 64bc24642..b9a51acc2 100644 --- a/mlmodel/build/format/WordTagger.pb.h +++ b/mlmodel/build/format/WordTagger.pb.h @@ -107,6 +107,9 @@ extern SequenceFeatureTypeDefaultTypeInternal _SequenceFeatureType_default_insta class SizeRange; class SizeRangeDefaultTypeInternal; extern SizeRangeDefaultTypeInternal _SizeRange_default_instance_; +class StateFeatureType; +class StateFeatureTypeDefaultTypeInternal; +extern StateFeatureTypeDefaultTypeInternal _StateFeatureType_default_instance_; class StringFeatureType; class StringFeatureTypeDefaultTypeInternal; extern StringFeatureTypeDefaultTypeInternal _StringFeatureType_default_instance_; @@ -362,7 +365,7 @@ inline ::google::protobuf::uint32 WordTagger::revision() const { return revision_; } inline void WordTagger::set_revision(::google::protobuf::uint32 value) { - + revision_ = value; // @@protoc_insertion_point(field_set:CoreML.Specification.CoreMLModels.WordTagger.revision) } @@ -376,13 +379,13 @@ inline const ::std::string& WordTagger::language() const { return language_.GetNoArena(); } inline void WordTagger::set_language(const ::std::string& value) { - + language_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.CoreMLModels.WordTagger.language) } #if LANG_CXX11 inline void WordTagger::set_language(::std::string&& value) { - + language_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.CoreMLModels.WordTagger.language) @@ -390,31 +393,31 @@ inline void WordTagger::set_language(::std::string&& value) { #endif inline void WordTagger::set_language(const char* value) { GOOGLE_DCHECK(value != NULL); - + language_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.CoreMLModels.WordTagger.language) } inline void WordTagger::set_language(const char* value, size_t size) { - + language_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.CoreMLModels.WordTagger.language) } inline ::std::string* WordTagger::mutable_language() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.CoreMLModels.WordTagger.language) return language_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* WordTagger::release_language() { // @@protoc_insertion_point(field_release:CoreML.Specification.CoreMLModels.WordTagger.language) - + return language_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void WordTagger::set_allocated_language(::std::string* language) { if (language != NULL) { - + } else { - + } language_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), language); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.CoreMLModels.WordTagger.language) @@ -429,13 +432,13 @@ inline const ::std::string& WordTagger::tokensoutputfeaturename() const { return tokensoutputfeaturename_.GetNoArena(); } inline void WordTagger::set_tokensoutputfeaturename(const ::std::string& value) { - + tokensoutputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.CoreMLModels.WordTagger.tokensOutputFeatureName) } #if LANG_CXX11 inline void WordTagger::set_tokensoutputfeaturename(::std::string&& value) { - + tokensoutputfeaturename_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.CoreMLModels.WordTagger.tokensOutputFeatureName) @@ -443,31 +446,31 @@ inline void WordTagger::set_tokensoutputfeaturename(::std::string&& value) { #endif inline void WordTagger::set_tokensoutputfeaturename(const char* value) { GOOGLE_DCHECK(value != NULL); - + tokensoutputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.CoreMLModels.WordTagger.tokensOutputFeatureName) } inline void WordTagger::set_tokensoutputfeaturename(const char* value, size_t size) { - + tokensoutputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.CoreMLModels.WordTagger.tokensOutputFeatureName) } inline ::std::string* WordTagger::mutable_tokensoutputfeaturename() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.CoreMLModels.WordTagger.tokensOutputFeatureName) return tokensoutputfeaturename_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* WordTagger::release_tokensoutputfeaturename() { // @@protoc_insertion_point(field_release:CoreML.Specification.CoreMLModels.WordTagger.tokensOutputFeatureName) - + return tokensoutputfeaturename_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void WordTagger::set_allocated_tokensoutputfeaturename(::std::string* tokensoutputfeaturename) { if (tokensoutputfeaturename != NULL) { - + } else { - + } tokensoutputfeaturename_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), tokensoutputfeaturename); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.CoreMLModels.WordTagger.tokensOutputFeatureName) @@ -482,13 +485,13 @@ inline const ::std::string& WordTagger::tokentagsoutputfeaturename() const { return tokentagsoutputfeaturename_.GetNoArena(); } inline void WordTagger::set_tokentagsoutputfeaturename(const ::std::string& value) { - + tokentagsoutputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.CoreMLModels.WordTagger.tokenTagsOutputFeatureName) } #if LANG_CXX11 inline void WordTagger::set_tokentagsoutputfeaturename(::std::string&& value) { - + tokentagsoutputfeaturename_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.CoreMLModels.WordTagger.tokenTagsOutputFeatureName) @@ -496,31 +499,31 @@ inline void WordTagger::set_tokentagsoutputfeaturename(::std::string&& value) { #endif inline void WordTagger::set_tokentagsoutputfeaturename(const char* value) { GOOGLE_DCHECK(value != NULL); - + tokentagsoutputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.CoreMLModels.WordTagger.tokenTagsOutputFeatureName) } inline void WordTagger::set_tokentagsoutputfeaturename(const char* value, size_t size) { - + tokentagsoutputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.CoreMLModels.WordTagger.tokenTagsOutputFeatureName) } inline ::std::string* WordTagger::mutable_tokentagsoutputfeaturename() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.CoreMLModels.WordTagger.tokenTagsOutputFeatureName) return tokentagsoutputfeaturename_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* WordTagger::release_tokentagsoutputfeaturename() { // @@protoc_insertion_point(field_release:CoreML.Specification.CoreMLModels.WordTagger.tokenTagsOutputFeatureName) - + return tokentagsoutputfeaturename_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void WordTagger::set_allocated_tokentagsoutputfeaturename(::std::string* tokentagsoutputfeaturename) { if (tokentagsoutputfeaturename != NULL) { - + } else { - + } tokentagsoutputfeaturename_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), tokentagsoutputfeaturename); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.CoreMLModels.WordTagger.tokenTagsOutputFeatureName) @@ -535,13 +538,13 @@ inline const ::std::string& WordTagger::tokenlocationsoutputfeaturename() const return tokenlocationsoutputfeaturename_.GetNoArena(); } inline void WordTagger::set_tokenlocationsoutputfeaturename(const ::std::string& value) { - + tokenlocationsoutputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.CoreMLModels.WordTagger.tokenLocationsOutputFeatureName) } #if LANG_CXX11 inline void WordTagger::set_tokenlocationsoutputfeaturename(::std::string&& value) { - + tokenlocationsoutputfeaturename_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.CoreMLModels.WordTagger.tokenLocationsOutputFeatureName) @@ -549,31 +552,31 @@ inline void WordTagger::set_tokenlocationsoutputfeaturename(::std::string&& valu #endif inline void WordTagger::set_tokenlocationsoutputfeaturename(const char* value) { GOOGLE_DCHECK(value != NULL); - + tokenlocationsoutputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.CoreMLModels.WordTagger.tokenLocationsOutputFeatureName) } inline void WordTagger::set_tokenlocationsoutputfeaturename(const char* value, size_t size) { - + tokenlocationsoutputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.CoreMLModels.WordTagger.tokenLocationsOutputFeatureName) } inline ::std::string* WordTagger::mutable_tokenlocationsoutputfeaturename() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.CoreMLModels.WordTagger.tokenLocationsOutputFeatureName) return tokenlocationsoutputfeaturename_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* WordTagger::release_tokenlocationsoutputfeaturename() { // @@protoc_insertion_point(field_release:CoreML.Specification.CoreMLModels.WordTagger.tokenLocationsOutputFeatureName) - + return tokenlocationsoutputfeaturename_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void WordTagger::set_allocated_tokenlocationsoutputfeaturename(::std::string* tokenlocationsoutputfeaturename) { if (tokenlocationsoutputfeaturename != NULL) { - + } else { - + } tokenlocationsoutputfeaturename_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), tokenlocationsoutputfeaturename); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.CoreMLModels.WordTagger.tokenLocationsOutputFeatureName) @@ -588,13 +591,13 @@ inline const ::std::string& WordTagger::tokenlengthsoutputfeaturename() const { return tokenlengthsoutputfeaturename_.GetNoArena(); } inline void WordTagger::set_tokenlengthsoutputfeaturename(const ::std::string& value) { - + tokenlengthsoutputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.CoreMLModels.WordTagger.tokenLengthsOutputFeatureName) } #if LANG_CXX11 inline void WordTagger::set_tokenlengthsoutputfeaturename(::std::string&& value) { - + tokenlengthsoutputfeaturename_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.CoreMLModels.WordTagger.tokenLengthsOutputFeatureName) @@ -602,31 +605,31 @@ inline void WordTagger::set_tokenlengthsoutputfeaturename(::std::string&& value) #endif inline void WordTagger::set_tokenlengthsoutputfeaturename(const char* value) { GOOGLE_DCHECK(value != NULL); - + tokenlengthsoutputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.CoreMLModels.WordTagger.tokenLengthsOutputFeatureName) } inline void WordTagger::set_tokenlengthsoutputfeaturename(const char* value, size_t size) { - + tokenlengthsoutputfeaturename_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.CoreMLModels.WordTagger.tokenLengthsOutputFeatureName) } inline ::std::string* WordTagger::mutable_tokenlengthsoutputfeaturename() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.CoreMLModels.WordTagger.tokenLengthsOutputFeatureName) return tokenlengthsoutputfeaturename_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* WordTagger::release_tokenlengthsoutputfeaturename() { // @@protoc_insertion_point(field_release:CoreML.Specification.CoreMLModels.WordTagger.tokenLengthsOutputFeatureName) - + return tokenlengthsoutputfeaturename_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void WordTagger::set_allocated_tokenlengthsoutputfeaturename(::std::string* tokenlengthsoutputfeaturename) { if (tokenlengthsoutputfeaturename != NULL) { - + } else { - + } tokenlengthsoutputfeaturename_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), tokenlengthsoutputfeaturename); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.CoreMLModels.WordTagger.tokenLengthsOutputFeatureName) @@ -641,13 +644,13 @@ inline const ::std::string& WordTagger::modelparameterdata() const { return modelparameterdata_.GetNoArena(); } inline void WordTagger::set_modelparameterdata(const ::std::string& value) { - + modelparameterdata_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); // @@protoc_insertion_point(field_set:CoreML.Specification.CoreMLModels.WordTagger.modelParameterData) } #if LANG_CXX11 inline void WordTagger::set_modelparameterdata(::std::string&& value) { - + modelparameterdata_.SetNoArena( &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); // @@protoc_insertion_point(field_set_rvalue:CoreML.Specification.CoreMLModels.WordTagger.modelParameterData) @@ -655,31 +658,31 @@ inline void WordTagger::set_modelparameterdata(::std::string&& value) { #endif inline void WordTagger::set_modelparameterdata(const char* value) { GOOGLE_DCHECK(value != NULL); - + modelparameterdata_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); // @@protoc_insertion_point(field_set_char:CoreML.Specification.CoreMLModels.WordTagger.modelParameterData) } inline void WordTagger::set_modelparameterdata(const void* value, size_t size) { - + modelparameterdata_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(reinterpret_cast(value), size)); // @@protoc_insertion_point(field_set_pointer:CoreML.Specification.CoreMLModels.WordTagger.modelParameterData) } inline ::std::string* WordTagger::mutable_modelparameterdata() { - + // @@protoc_insertion_point(field_mutable:CoreML.Specification.CoreMLModels.WordTagger.modelParameterData) return modelparameterdata_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline ::std::string* WordTagger::release_modelparameterdata() { // @@protoc_insertion_point(field_release:CoreML.Specification.CoreMLModels.WordTagger.modelParameterData) - + return modelparameterdata_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); } inline void WordTagger::set_allocated_modelparameterdata(::std::string* modelparameterdata) { if (modelparameterdata != NULL) { - + } else { - + } modelparameterdata_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), modelparameterdata); // @@protoc_insertion_point(field_set_allocated:CoreML.Specification.CoreMLModels.WordTagger.modelParameterData) diff --git a/mlmodel/format/DataStructures.proto b/mlmodel/format/DataStructures.proto index a373ff8e2..6cd2d1ee6 100644 --- a/mlmodel/format/DataStructures.proto +++ b/mlmodel/format/DataStructures.proto @@ -116,7 +116,7 @@ message DoubleRange { * * The table must have only valid values; do not use `NaN`, `+/- INF`, * or negative values. The application is responsible for inter/extrapolating - * appropriate confidence threshold based on the application's specific need. + * approprate confidence threshold based on the application's specific need. */ message PrecisionRecallCurve { FloatVector precisionValues = 1; diff --git a/mlmodel/format/FeatureTypes.proto b/mlmodel/format/FeatureTypes.proto index 382143923..0a16f240e 100644 --- a/mlmodel/format/FeatureTypes.proto +++ b/mlmodel/format/FeatureTypes.proto @@ -207,6 +207,12 @@ message SequenceFeatureType { SizeRange sizeRange = 101; } +message StateFeatureType { + oneof Type { + ArrayFeatureType arrayType = 1; + } +} + /* * A feature, which may be optional. */ @@ -219,8 +225,8 @@ message FeatureType { ArrayFeatureType multiArrayType = 5; DictionaryFeatureType dictionaryType = 6; SequenceFeatureType sequenceType = 7; + StateFeatureType stateType = 8; } bool isOptional = 1000; } - diff --git a/mlmodel/format/Gazetteer.proto b/mlmodel/format/Gazetteer.proto index c0cc1d9d1..8dac370e7 100644 --- a/mlmodel/format/Gazetteer.proto +++ b/mlmodel/format/Gazetteer.proto @@ -21,7 +21,7 @@ message Gazetteer { * iOS, tvOS 13.0+, macOS 10.15+ */ uint32 revision = 1; - + /* * Stores the language of the model, as specified in BCP-47 format, * e.g. "en-US". See https://tools.ietf.org/html/bcp47 @@ -32,12 +32,12 @@ message Gazetteer { * Natural Language framework's efficient representation of a gazetter. */ bytes modelParameterData = 100; - + /* * Stores the set of output class labels */ oneof ClassLabels { StringVector stringClassLabels = 200; } - + } diff --git a/mlmodel/format/LinkedModel.proto b/mlmodel/format/LinkedModel.proto index 4ff873613..7b5263c3a 100644 --- a/mlmodel/format/LinkedModel.proto +++ b/mlmodel/format/LinkedModel.proto @@ -38,5 +38,3 @@ message LinkedModelFile { // - $BUNDLE_IDENTIFIER(identifier) - Looks in Bunde with given identifier StringParameter linkedModelSearchPath = 2; } - - diff --git a/mlmodel/format/MIL.proto b/mlmodel/format/MIL.proto index 78ed2de48..af7c3e004 100644 --- a/mlmodel/format/MIL.proto +++ b/mlmodel/format/MIL.proto @@ -184,6 +184,7 @@ message ValueType { ListType listType = 2; TupleType tupleType = 3; DictionaryType dictionaryType = 4; + StateType stateType = 5; } } @@ -199,6 +200,8 @@ enum DataType { STRING = 2; // arbitrary sequence of bytes // Floats + FLOAT8E4M3FN = 40; + FLOAT8E5M2 = 41; FLOAT16 = 10; FLOAT32 = 11; FLOAT64 = 12; @@ -209,12 +212,19 @@ enum DataType { INT16 = 22; INT32 = 23; INT64 = 24; + INT4 = 25; // UInts UINT8 = 31; UINT16 = 32; UINT32 = 33; UINT64 = 34; + + UINT4 = 35; + UINT2 = 36; + UINT1 = 37; + UINT6 = 38; + UINT3 = 39; } message TensorType { @@ -252,6 +262,10 @@ message DictionaryType { ValueType valueType = 2; } +message StateType { + ValueType wrappedType = 1; +} + message Dimension { oneof dimension { ConstantDimension constant = 1; diff --git a/mlmodel/format/Model.proto b/mlmodel/format/Model.proto index ce934cb96..46014470a 100644 --- a/mlmodel/format/Model.proto +++ b/mlmodel/format/Model.proto @@ -9,8 +9,7 @@ * and can be any one of the following types: * * Neural Networks - * - ``MILSpec.Program`` - * - ``NeuralNetwork`` + * - `NeuralNetwork` * * Regressors * - ``GLMRegressor`` @@ -20,42 +19,42 @@ * - ``BayesianProbitRegressor`` * * Classifiers - * - ``NeuralNetworkClassifier`` - * - ``TreeEnsembleClassifier`` - * - ``GLMClassifier`` - * - ``SupportVectorClassifier`` - * - ``KNearestNeighborsClassifier`` + * - `NeuralNetworkClassifier` + * - `TreeEnsembleClassifier` + * - `GLMClassifier` + * - `SupportVectorClassifier` + * - `KNearestNeighborsClassifier` * * Other models - * - ``CustomModel`` - * - ``TextClassifier`` - * - ``WordTagger`` - * - ``Gazetteer`` - * - ``WordEmbedding`` - * - ``VisionFeaturePrint`` - * - ``LinkedModel`` - * - ``SoundAnalysisPreprocessing`` - * - ``ItemSimilarityRecommender`` - * - ``ClassConfidenceThresholding`` + * - `CustomModel` + * - `TextClassifier` + * - `WordTagger` + * - `Gazetteer` + * - `WordEmbedding` + * - `VisionFeaturePrint` + * - `LinkedModel` + * - `SoundAnalysisPreprocessing` + * - `ItemSimilarityRecommender` + * - `ClassConfidenceThresholding` * * Feature Engineering - * - ``Imputer`` - * - ``Scaler`` - * - ``Normalizer`` - * - ``OneHotEncoder`` - * - ``CategoricalMapping`` - * - ``FeatureVectorizer`` - * - ``DictVectorizer`` - * - ``ArrayFeatureExtractor`` - * - ``NonMaximumSuppression`` + * - `Imputer` + * - `Scaler` + * - `Normalizer` + * - `OneHotEncoder` + * - `CategoricalMapping` + * - `FeatureVectorizer` + * - `DictVectorizer` + * - `ArrayFeatureExtractor` + * - `NonMaximumSuppression` * * Pipelines - * - ``PipelineClassifier`` - * - ``PipelineRegressor`` - * - ``Pipeline`` + * - `PipelineClassifier` + * - `PipelineRegressor` + * - `Pipeline` * * Simple Mathematical Functions - * - ``Identity`` + * - `Identity` */ syntax = "proto3"; @@ -122,7 +121,7 @@ message PipelineRegressor { } /* - * A feature description + * A feature description * consisting of a name, short description, and type. */ message FeatureDescription { @@ -145,6 +144,35 @@ message Metadata { map userDefined = 100; } +/* + * A description of a function. + */ +message FunctionDescription { + // The function name. + string name = 1; + + // Input feature descriptions for the function. + repeated FeatureDescription input = 2; + + // Output feature descriptions for the function. + repeated FeatureDescription output = 3; + + // State feature descriptions for the function. + // + // The `type` of each feature description must be `StateFeatureType`. + repeated FeatureDescription state = 6; + + // [Required for regressor and classifier functions]: the name + // to give to an output feature containing the prediction. + string predictedFeatureName = 4; + + // [Optional for classifier functions]: the name to give to an + // output feature containing a dictionary mapping class + // labels to their predicted probabilities. If not specified, + // the dictionary will not be returned by the model. + string predictedProbabilitiesName = 5; +} + /* * A description of a model, * consisting of descriptions of its input and output features. @@ -155,9 +183,48 @@ message Metadata { * (``predictedProbabilitiesName``). */ message ModelDescription { + // Functions in the model. + // + // Some model types (e.g. ML Program) support multiple functions. For + // example, a large language model might have "prompt" and "extend" + // functions. Each has a different input and output behavior, but + // they are in a same model and share resources. + // + // If the model has more than one function, use the multiple + // function configuration and declare the feature descriptions and + // associated properties at function level. + // + // If the model type doesn't support multiple functions or the + // model has just "main" function, declare the feature + // descriptions and associated properties at the model level. + // + // Note: feature descriptions and associated properties mentioned + // above include input, output, state, predictedFeatureName, + // predictedProbabilitiesName, and trainingInput fields. + repeated FunctionDescription functions = 20; + + // The default function. + // + // The default function is the one that is automatically used when + // one doesn't explicitly specify. + // + // The value must be one of the names in `functions` message + // above. If `functions` is empty, this field must not be present. + string defaultFunctionName = 21; + + // The metadata (e.g. author, licence, etc) of the model. + Metadata metadata = 100; + + // Use these fields below only when `functions` above is empty. + repeated FeatureDescription input = 1; repeated FeatureDescription output = 10; + // State feature descriptions for the function. + // + // The `type` of each feature description must be `StateFeatureType`. + repeated FeatureDescription state = 13; + // [Required for regressor and classifier models]: the name // to give to an output feature containing the prediction. string predictedFeatureName = 11; @@ -169,8 +236,6 @@ message ModelDescription { string predictedProbabilitiesName = 12; repeated FeatureDescription trainingInput = 50; - - Metadata metadata = 100; } message SerializedModel { @@ -262,11 +327,14 @@ message SerializedModel { * - iOS 17 ops * - Scene print v2 * - ClassConfidenceThresholding model + * + * 9 : iOS 18, macOS 15, tvOS 18, watchOS 11 (Core ML 8) + * - multiple functions */ message Model { int32 specificationVersion = 1; ModelDescription description = 2; - + /* * Following model types support on-device update: * @@ -276,7 +344,7 @@ message Model { * - KNearestNeighborsClassifier */ bool isUpdatable = 10; - + // start at 200 here // model specific parameters: oneof Type { @@ -310,7 +378,7 @@ message Model { // Precision Recall Curve 'container'' ClassConfidenceThresholding classConfidenceThresholding = 560; - + // feature engineering starts at 600 OneHotEncoder oneHotEncoder = 600; Imputer imputer = 601; @@ -336,7 +404,7 @@ message Model { CoreMLModels.Gazetteer gazetteer = 2004; CoreMLModels.WordEmbedding wordEmbedding = 2005; CoreMLModels.AudioFeaturePrint audioFeaturePrint = 2006; - + // Reserved private messages start at 3000 // These messages are subject to change with no notice or support. SerializedModel serializedModel = 3000; diff --git a/mlmodel/format/NeuralNetwork.proto b/mlmodel/format/NeuralNetwork.proto index d6261d271..f2bdb68c0 100644 --- a/mlmodel/format/NeuralNetwork.proto +++ b/mlmodel/format/NeuralNetwork.proto @@ -1551,7 +1551,7 @@ message ConvolutionLayerParams { * * Output * A blob of rank 5. - * The output blob's shape is ``[batch, channelsOut, depthOut, heightOut, without]``. + * The output blob's shape is ``[batch, channelsOut, depthOut, heightOut, widthOut]``. * * Type of padding can be ``custom``, ``valid``, or ``same``. Padded values are all zeros. * Output spatial dimensions depend on the the type of padding. For details, refer to the @@ -1701,11 +1701,11 @@ message Convolution3DLayerParams { * Used when the `PaddingType` is `CustomPadding`, otherwise ignored by other padding types. */ int32 customPaddingRight = 85; - + /* Flag to specify if this is Convolution Transpose or not. */ bool isDeconvolution = 86; - + /* * The output shape, which has length 3 ``[D_out, H_out, W_out]``. * This is used only for deconvolution (``isDeconvolution == true``). @@ -2017,33 +2017,33 @@ message PoolingLayerParams { * +----+----+ */ message Pooling3DLayerParams { - + enum PoolingType3D { MAX = 0; AVERAGE = 1; } - + // Whether to use Max or Average PoolingType3D type = 1; - + // Depth of the pooling region. int32 kernelDepth = 2; - + // Height of the pooling region. int32 kernelHeight = 3; - + // Width of the pooling region. int32 kernelWidth = 4; - + // Stride along the depth direction int32 strideDepth = 5; - + // Stride along the height direction int32 strideHeight = 6; - + // Stride along the width direction int32 strideWidth = 7; - + /* * The type of padding. * All padding types pad the input shape with zeros. @@ -2065,25 +2065,25 @@ message Pooling3DLayerParams { SAME = 2; } Pooling3DPaddingType paddingType = 15; - + // Padding before the input in the depth direction. int32 customPaddingFront = 8; - + // Padding after the input in the depth direction. int32 customPaddingBack = 9; - + // Padding before the input in the height direction. int32 customPaddingTop = 10; - + // Padding after the input in the height direction. int32 customPaddingBottom = 11; - + // Padding before the input in the width direction. int32 customPaddingLeft = 12; - + // Padding after the input in the width direction. int32 customPaddingRight = 13; - + // If true, exclude zeros from padding in Average pooling. Meaningless in Max Pooling. bool countExcludePadding = 14; } @@ -2124,12 +2124,12 @@ message Pooling3DLayerParams { * +----+ */ message GlobalPooling3DLayerParams { - + enum GlobalPoolingType3D { MAX = 0; AVERAGE = 1; } - + // Whether to use Max or Average GlobalPoolingType3D type = 1; } @@ -4029,14 +4029,14 @@ message ConcatNDLayerParams { * Dimension along which to concatenate. Supports negative values of the parameter 'axis'. */ int64 axis = 1; - + /* * (Only available in Core ML Specification >= 5 (iOS >= 14, macOS >= 11.0) * Interleave option. If True, concatenation is done via interleaving the inputs. * This requires all inputs to have the exact same shape. */ bool interleave = 2; - + } @@ -6264,7 +6264,7 @@ message ArgSortLayerParams { } /* - * A layer that does slice operation by providing size to be extracted + * A layer that does slice operation by providing size to be extracted * from the given input tensor. * * Requires 2 inputs and produces 1 output. diff --git a/mlmodel/format/NonMaximumSuppression.proto b/mlmodel/format/NonMaximumSuppression.proto index 4cc2dce72..047f74bdb 100644 --- a/mlmodel/format/NonMaximumSuppression.proto +++ b/mlmodel/format/NonMaximumSuppression.proto @@ -89,7 +89,7 @@ message NonMaximumSuppression { * larger than boxes after suppression, the unused boxes are filled with * zero confidence. If the prediction is handled by Core Vision, it is also * important that confidences are defined with the following semantics: - * + * * 1. Confidences should be between 0 and 1 * 2. The sum of the confidences for a prediction should not exceed 1, but is * allowed to be less than 1 diff --git a/mlmodel/format/README.rst b/mlmodel/format/README.rst index 97ea0b1da..e55b0c3a1 100644 --- a/mlmodel/format/README.rst +++ b/mlmodel/format/README.rst @@ -2,9 +2,9 @@ Core ML Model Format Specification ================================== -This directory contains the protobuf message definitions +This directory contains the protobuf message definitions that comprise the Core ML model document (``.mlmodel``) format. The top-level message is ``Model``, which is defined in ``Model.proto``. -Other message types describe data structures, feature types, -feature engineering model types, and predictive model types. \ No newline at end of file +Other message types describe data structures, feature types, +feature engineering model types, and predictive model types. diff --git a/mlmodel/src/Comparison.cpp b/mlmodel/src/Comparison.cpp index 2ff7567b2..ad24ae789 100644 --- a/mlmodel/src/Comparison.cpp +++ b/mlmodel/src/Comparison.cpp @@ -10,11 +10,11 @@ #include namespace CoreML { - + namespace Specification { - + #pragma mark Model container and metadata/interface - + bool operator==(const Model& a, const Model& b) { if (a.specificationversion() != b.specificationversion()) { return false; @@ -25,7 +25,7 @@ namespace CoreML { if (a.Type_case() != b.Type_case()) { return false; } - + // if everything else matches, check the model-specific parameters switch (a.Type_case()) { case Model::kPipelineClassifier: @@ -113,7 +113,7 @@ namespace CoreML { return true; } } - + bool operator==(const Metadata& a, const Metadata& b) { if (a.shortdescription() != b.shortdescription()) { @@ -130,7 +130,7 @@ namespace CoreML { } return true; } - + bool operator==(const ModelDescription& a, const ModelDescription& b) { if (a.input() != b.input()) { @@ -150,7 +150,7 @@ namespace CoreML { } return true; } - + bool operator==(const FeatureDescription& a, const FeatureDescription& b) { if (a.name() != b.name()) { @@ -164,7 +164,7 @@ namespace CoreML { } return true; } - + bool isEquivalent(const FeatureDescription& a, const FeatureDescription& b) { if (a.name() != b.name()) { @@ -212,7 +212,7 @@ namespace CoreML { } return true; } - + static inline bool compareDictionaryTypes(const Specification::FeatureType& x, const Specification::FeatureType& y) { const auto& xp = x.dictionarytype(); @@ -222,7 +222,7 @@ namespace CoreML { } return true; } - + static inline bool compareImageTypes(const Specification::FeatureType& x, const Specification::FeatureType& y) { const auto& xp = x.imagetype(); @@ -249,28 +249,48 @@ namespace CoreML { } // TODO: Compare sizes - + return true; } + static inline bool compareStateTypes(const Specification::FeatureType& x, + const Specification::FeatureType& y) { + const auto& xp = x.statetype().arraytype(); + const auto& yp = y.statetype().arraytype(); + if (xp.datatype() != yp.datatype()) { + return false; + } + + if (rankOf(xp) != rankOf(yp)) { + return false; + } + + for (int i=0; i optional should be allowed, but // it's using this operator== to test that, and failing. // We should eventually fix that by making a notion of "is valid as type" method // for FeatureType, such that T is always a valid optional but not the other // way around. - + /* if (a.isoptional() != b.isoptional()) { return false; } */ - - + + if (a.Type_case() != b.Type_case()) { return false; } @@ -288,13 +308,15 @@ namespace CoreML { return compareImageTypes(a, b); case Specification::FeatureType::kSequenceType: return compareSequenceTypes(a,b); + case Specification::FeatureType::kStateType: + return compareStateTypes(a,b); case Specification::FeatureType::TYPE_NOT_SET: return true; } } - + #pragma mark Pipelines - + bool operator==(const Pipeline& a, const Pipeline& b) { if (a.models_size() != b.models_size()) { @@ -307,19 +329,19 @@ namespace CoreML { } return true; } - + bool operator==(const PipelineClassifier& a, const PipelineClassifier& b) { return a.pipeline() == b.pipeline(); } - + bool operator==(const PipelineRegressor& a, const PipelineRegressor& b) { return a.pipeline() == b.pipeline(); } - + #pragma mark Regressors - + bool operator==(const GLMRegressor& a, const GLMRegressor& b) { if (a.weights() != b.weights()) { @@ -333,12 +355,12 @@ namespace CoreML { } return true; } - + bool operator==(const GLMRegressor_DoubleArray& a, const GLMRegressor_DoubleArray& b) { return a.value() == b.value(); } - + bool operator==(const SupportVectorRegressor& a, const SupportVectorRegressor& b) { if (a.kernel() != b.kernel()) { @@ -362,7 +384,7 @@ namespace CoreML { return false; } } - + bool operator==(const Kernel& a, const Kernel& b) { if (a.kernel_case() != b.kernel_case()) { @@ -381,12 +403,12 @@ namespace CoreML { return false; } } - + bool operator==(const RBFKernel& a, const RBFKernel& b) { return a.gamma() == b.gamma(); } - + bool operator==(const PolyKernel& a, const PolyKernel& b) { if (a.degree() != b.degree()) { @@ -400,7 +422,7 @@ namespace CoreML { } return true; } - + bool operator==(const SigmoidKernel& a, const SigmoidKernel& b) { if (a.gamma() != b.gamma()) { @@ -411,17 +433,17 @@ namespace CoreML { } return true; } - + bool operator==(const Coefficients& a, const Coefficients& b) { return a.alpha() == b.alpha(); } - + bool operator==(const SparseVector& a, const SparseVector& b) { return a.nodes() == b.nodes(); } - + bool operator==(const SparseNode& a, const SparseNode& b) { if (a.index() != b.index()) { @@ -432,12 +454,12 @@ namespace CoreML { } return true; } - + bool operator==(const DenseVector& a, const DenseVector& b) { return a.values() == b.values(); } - + bool operator==(const TreeEnsembleRegressor& a, const TreeEnsembleRegressor& b) { if (a.postevaluationtransform() != b.postevaluationtransform()) { @@ -445,7 +467,7 @@ namespace CoreML { } return a.treeensemble() == b.treeensemble(); } - + bool operator==(const TreeEnsembleParameters& a, const TreeEnsembleParameters& b) { if (a.nodes() != b.nodes()) { @@ -459,7 +481,7 @@ namespace CoreML { } return true; } - + bool operator==(const TreeEnsembleParameters_TreeNode& a, const TreeEnsembleParameters_TreeNode& b) { if (a.treeid() != b.treeid()) { @@ -494,7 +516,7 @@ namespace CoreML { } return true; } - + bool operator==(const TreeEnsembleParameters_TreeNode_EvaluationInfo& a, const TreeEnsembleParameters_TreeNode_EvaluationInfo& b) { if (a.evaluationindex() != b.evaluationindex()) { @@ -505,7 +527,7 @@ namespace CoreML { } return true; } - + bool operator==(const NeuralNetworkRegressor& a, const NeuralNetworkRegressor& b) { if (a.layers() != b.layers()) { @@ -516,19 +538,19 @@ namespace CoreML { } return true; } - + bool operator==(const NeuralNetworkLayer& a, const NeuralNetworkLayer& b) { if (a.layer_case() != b.layer_case()) { return false; } - + // TODO -- not implemented! // for now, all neural network layers are not equal. assert(false); return false; } - + bool operator==(const NeuralNetworkPreprocessing& a, const NeuralNetworkPreprocessing& b) { if (a.featurename() != b.featurename()) { @@ -546,7 +568,7 @@ namespace CoreML { return false; } } - + bool operator==(const NeuralNetworkImageScaler& a, const NeuralNetworkImageScaler& b) { if (a.redbias() != b.redbias()) { @@ -566,7 +588,7 @@ namespace CoreML { } return true; } - + bool operator==(const NeuralNetworkMeanImage& a, const NeuralNetworkMeanImage& b) { if (a.meanimage() != b.meanimage()) { @@ -574,7 +596,7 @@ namespace CoreML { } return true; } - + bool operator==(const BayesianProbitRegressor& a, const BayesianProbitRegressor& b) { #pragma unused(a) @@ -584,7 +606,7 @@ namespace CoreML { } #pragma mark Classifiers - + bool operator==(const GLMClassifier& a, const GLMClassifier& b) { if (a.weights() != b.weights()) { @@ -611,12 +633,12 @@ namespace CoreML { return true; } } - + bool operator==(const GLMClassifier_DoubleArray& a, const GLMClassifier_DoubleArray& b) { return a.value() == b.value(); } - + bool operator==(const SupportVectorClassifier& a, const SupportVectorClassifier& b) { if (a.kernel() != b.kernel()) { @@ -666,7 +688,7 @@ namespace CoreML { return true; } } - + bool operator==(const TreeEnsembleClassifier& a, const TreeEnsembleClassifier& b) { if (a.treeensemble() != b.treeensemble()) { @@ -688,7 +710,7 @@ namespace CoreML { } return true; } - + bool operator==(const NeuralNetworkClassifier& a, const NeuralNetworkClassifier& b) { if (a.layers() != b.layers()) { @@ -709,7 +731,7 @@ namespace CoreML { return true; } } - + bool operator==(const KNearestNeighborsClassifier& a, const KNearestNeighborsClassifier& b) { auto aIndex = a.nearestneighborsindex(); @@ -779,7 +801,7 @@ namespace CoreML { return false; } } - + #pragma mark Generic models bool operator==(const NeuralNetwork& a, @@ -847,35 +869,35 @@ namespace CoreML { bool operator==(const CoreMLModels::WordTagger& a, const CoreMLModels::WordTagger& b) { - + if (a.revision()!= b.revision()) { return false; } - + if (a.language()!= b.language()) { return false; } - + if (a.tokensoutputfeaturename() != b.tokensoutputfeaturename()) { return false; } - + if (a.tokentagsoutputfeaturename() != b.tokentagsoutputfeaturename()) { return false; } - + if (a.tokenlocationsoutputfeaturename() != b.tokenlocationsoutputfeaturename()) { return false; } - + if (a.tokenlengthsoutputfeaturename() != b.tokenlengthsoutputfeaturename()) { return false; } - + if (a.Tags_case()!= b.Tags_case()) { return false; } - + switch (a.Tags_case()) { case CoreMLModels::WordTagger::kStringTags: if (a.stringtags() != b.stringtags()) { @@ -885,36 +907,36 @@ namespace CoreML { case CoreMLModels::WordTagger::TAGS_NOT_SET: break; } - + if (a.modelparameterdata().size() != b.modelparameterdata().size()) { return false; } - + size_t s = a.modelparameterdata().size(); if (s > 0) { if (memcmp(&a.modelparameterdata()[0], &b.modelparameterdata()[0], s)) { return false; } } - + return true; } bool operator==(const CoreMLModels::TextClassifier& a, const CoreMLModels::TextClassifier& b) { - + if (a.revision()!= b.revision()) { return false; } - + if (a.language()!= b.language()) { return false; } - + if (a.ClassLabels_case()!= b.ClassLabels_case()) { return false; } - + switch (a.ClassLabels_case()) { case CoreMLModels::TextClassifier::kStringClassLabels: if (a.stringclasslabels()!= b.stringclasslabels()) { @@ -924,36 +946,36 @@ namespace CoreML { case CoreMLModels::TextClassifier::CLASSLABELS_NOT_SET: break; } - + if (a.modelparameterdata().size() != b.modelparameterdata().size()) { return false; } - + size_t s = a.modelparameterdata().size(); if (s > 0) { if (memcmp(&a.modelparameterdata()[0], &b.modelparameterdata()[0], s)) { return false; } } - + return true; } - + bool operator==(const CoreMLModels::Gazetteer& a, const CoreMLModels::Gazetteer& b) { - + if (a.revision()!= b.revision()) { return false; } - + if (a.language()!= b.language()) { return false; } - + if (a.ClassLabels_case()!= b.ClassLabels_case()) { return false; } - + switch (a.ClassLabels_case()) { case CoreMLModels::Gazetteer::kStringClassLabels: if (a.stringclasslabels()!= b.stringclasslabels()) { @@ -963,43 +985,43 @@ namespace CoreML { case CoreMLModels::Gazetteer::CLASSLABELS_NOT_SET: break; } - + if (a.modelparameterdata().size() != b.modelparameterdata().size()) { return false; } - + size_t s = a.modelparameterdata().size(); if (s > 0) { if (memcmp(&a.modelparameterdata()[0], &b.modelparameterdata()[0], s)) { return false; } } - + return true; } bool operator==(const CoreMLModels::WordEmbedding& a, const CoreMLModels::WordEmbedding& b) { - + if (a.revision()!= b.revision()) { return false; } - + if (a.language()!= b.language()) { return false; } - + if (a.modelparameterdata().size() != b.modelparameterdata().size()) { return false; } - + size_t s = a.modelparameterdata().size(); if (s > 0) { if (memcmp(&a.modelparameterdata()[0], &b.modelparameterdata()[0], s)) { return false; } } - + return true; } @@ -1009,7 +1031,7 @@ namespace CoreML { if (a.VisionFeaturePrintType_case() != b.VisionFeaturePrintType_case()) { return false; } - + switch (a.VisionFeaturePrintType_case()) { case CoreMLModels::VisionFeaturePrint::kScene: if (a.scene().version() != b.scene().version()) { @@ -1032,7 +1054,7 @@ namespace CoreML { case CoreMLModels::VisionFeaturePrint::VISIONFEATUREPRINTTYPE_NOT_SET: break; } - + return true; } @@ -1042,7 +1064,7 @@ namespace CoreML { if (a.AudioFeaturePrintType_case() != b.AudioFeaturePrintType_case()) { return false; } - + switch (a.AudioFeaturePrintType_case()) { case CoreMLModels::AudioFeaturePrint::kSound: if (a.sound().version() != b.sound().version()) { @@ -1052,7 +1074,7 @@ namespace CoreML { case CoreMLModels::AudioFeaturePrint::AUDIOFEATUREPRINTTYPE_NOT_SET: break; } - + return true; } @@ -1095,12 +1117,12 @@ namespace CoreML { } return true; } - + bool operator==(const Imputer& a, const Imputer& b) { if (a.ImputedValue_case() != b.ImputedValue_case()) { return false; } - + switch (a.ImputedValue_case()) { case Imputer::kImputedDoubleValue: if( !(a.imputeddoublevalue() == b.imputeddoublevalue())) { @@ -1140,12 +1162,12 @@ namespace CoreML { // OK to return here, as this just means it's uninitialized. return true; } - + // Now test the replacement value. if (a.ReplaceValue_case() != b.ReplaceValue_case()) { return false; } - + switch(a.ReplaceValue_case()) { case Imputer::kReplaceDoubleValue: { if( ! ( (std::isnan(a.replacedoublevalue()) && std::isnan(b.replacedoublevalue())) @@ -1168,16 +1190,16 @@ namespace CoreML { break; } } - + // Done testing all of this. return true; } - + bool operator==(const FeatureVectorizer& a, const FeatureVectorizer& b) { return a.inputlist() == b.inputlist(); } - + bool operator==(const FeatureVectorizer_InputColumn& a, const FeatureVectorizer_InputColumn& b) { if (a.inputcolumn() != b.inputcolumn()) { @@ -1188,7 +1210,7 @@ namespace CoreML { } return true; } - + bool operator==(const DictVectorizer& a, const DictVectorizer& b) { if (a.Map_case() != b.Map_case()) { @@ -1203,7 +1225,7 @@ namespace CoreML { return true; } } - + bool operator==(const Scaler& a, const Scaler& b) { if (a.shiftvalue() != b.shiftvalue()) { @@ -1214,7 +1236,7 @@ namespace CoreML { } return true; } - + bool operator==(const NonMaximumSuppression& a, const NonMaximumSuppression& b) { // Parameters @@ -1224,7 +1246,7 @@ namespace CoreML { if (a.confidencethreshold() != b.confidencethreshold()) { return false; } - + // Input and outputs feature names if (a.confidenceinputfeaturename() != b.confidenceinputfeaturename()) { return false; @@ -1244,12 +1266,12 @@ namespace CoreML { if (a.coordinatesoutputfeaturename() != b.coordinatesoutputfeaturename()) { return false; } - + // Same suppression method if (a.SuppressionMethod_case() != b.SuppressionMethod_case()) { return false; } - + // Method-specific parameters if (a.SuppressionMethod_case() == NonMaximumSuppression::SuppressionMethodCase::kPickTop) { if (a.picktop().perclass() != b.picktop().perclass()) { @@ -1258,7 +1280,7 @@ namespace CoreML { } return true; } - + bool operator==(const CategoricalMapping& a, const CategoricalMapping& b) { if (a.MappingType_case() != b.MappingType_case()) { @@ -1273,18 +1295,18 @@ namespace CoreML { return true; } } - + bool operator==(const Normalizer& a, const Normalizer& b) { return a.normtype() == b.normtype(); } - + bool operator==(const ArrayFeatureExtractor& a, const ArrayFeatureExtractor& b) { return a.extractindex() == b.extractindex(); } #pragma mark Recommenders - + bool operator==(const ItemSimilarityRecommender& a, const ItemSimilarityRecommender& b) { try { @@ -1297,7 +1319,7 @@ namespace CoreML { #pragma mark Data structures - + template bool vectorsEqual(const T& a, const T& b) { if (a.vector_size() != b.vector_size()) { @@ -1310,22 +1332,22 @@ namespace CoreML { } return true; } - + bool operator==(const Int64Vector& a, const Int64Vector& b) { return vectorsEqual(a, b); } - + bool operator==(const StringVector& a, const StringVector& b) { return vectorsEqual(a, b); } - + bool operator==(const DoubleVector& a, const DoubleVector& b) { return vectorsEqual(a, b); } - + template bool mapsEqual(const T& a, const T& b) { if (a.map_size() != b.map_size()) { @@ -1338,12 +1360,12 @@ namespace CoreML { } return true; } - + bool operator==(const StringToInt64Map& a, const StringToInt64Map& b) { return mapsEqual(a, b); } - + bool operator==(const Int64ToStringMap& a, const Int64ToStringMap& b) { return mapsEqual(a, b); @@ -1352,12 +1374,12 @@ namespace CoreML { const StringToDoubleMap& b) { return mapsEqual(a, b); } - + bool operator==(const Int64ToDoubleMap& a, const Int64ToDoubleMap& b) { return mapsEqual(a, b); } - + #pragma mark not-equal // // Type-specific operator!= overloads necessitated by rdar://98724060 ([libcxx release blocker] Failed to build CoreML_tests (unconstrained operator!=)) diff --git a/mlmodel/src/DataType.cpp b/mlmodel/src/DataType.cpp index 3e9647876..451d52755 100644 --- a/mlmodel/src/DataType.cpp +++ b/mlmodel/src/DataType.cpp @@ -44,21 +44,21 @@ namespace CoreML { FeatureType::FeatureType(const Specification::FeatureType& wrapped) : m_type(std::make_shared(wrapped)) { } - + // simple types #define WRAP_SIMPLE_TYPE(T, U) \ FeatureType FeatureType::T() { return FeatureType(U); } - + WRAP_SIMPLE_TYPE(Int64, MLFeatureTypeType_int64Type) WRAP_SIMPLE_TYPE(String, MLFeatureTypeType_stringType) WRAP_SIMPLE_TYPE(Image, MLFeatureTypeType_imageType) /* TODO image is not simple type */ WRAP_SIMPLE_TYPE(Double, MLFeatureTypeType_doubleType) - + // parametric types FeatureType FeatureType::Array(const std::vector shape, MLArrayDataType dataType) { FeatureType out(MLFeatureTypeType_multiArrayType); Specification::ArrayFeatureType *params = out->mutable_multiarraytype(); - + for (int64_t s : shape) { params->add_shape(s); } @@ -69,12 +69,12 @@ FeatureType FeatureType::T() { return FeatureType(U); } FeatureType FeatureType::Array(const std::vector shape) { return Array(shape,MLArrayDataTypeDOUBLE); } - + FeatureType FeatureType::Dictionary(MLDictionaryFeatureTypeKeyType keyType) { FeatureType out(MLFeatureTypeType_dictionaryType); Specification::DictionaryFeatureType *params = out->mutable_dictionarytype(); - + switch (keyType) { case MLDictionaryFeatureTypeKeyType_int64KeyType: params->mutable_int64keytype(); @@ -88,28 +88,28 @@ FeatureType FeatureType::T() { return FeatureType(U); } return out; } - + // operators const Specification::FeatureType& FeatureType::operator*() const { return *m_type; } - + Specification::FeatureType& FeatureType::operator*() { return *m_type; } - + const Specification::FeatureType* FeatureType::operator->() const { return m_type.get(); } - + Specification::FeatureType* FeatureType::operator->() { return m_type.get(); } - + bool FeatureType::operator==(const FeatureType& other) const { return *m_type == *other.m_type; } - + bool FeatureType::operator!=(const FeatureType& other) const { return !(*this == other); } @@ -130,11 +130,13 @@ FeatureType FeatureType::T() { return FeatureType(U); } return "String"; case Specification::FeatureType::kSequenceType: return "Sequence"; + case Specification::FeatureType::kStateType: + return "State"; case Specification::FeatureType::TYPE_NOT_SET: return "Invalid"; } } - + static std::string keyTypeToString(Specification::DictionaryFeatureType::KeyTypeCase tag) { switch (tag) { case Specification::DictionaryFeatureType::kInt64KeyType: @@ -416,6 +418,18 @@ FeatureType FeatureType::T() { return FeatureType(U); } ss << ")"; break; } + case Specification::FeatureType::kStateType: + { + const Specification::ArrayFeatureType& params = m_type->statetype().arraytype(); + ss << " (" << dataTypeToString(params.datatype()); + std::vector shape = defaultShapeOf(params); + if (shape.size() > 0) { + ss << " "; + ss << dimensionsToString(shape); + } + ss << ")"; + break; + } default: break; } @@ -494,13 +508,20 @@ FeatureType FeatureType::T() { return FeatureType(U); } dict["sizeRange"] = rangeToString((int64_t)params.sizerange().lowerbound(), params.sizerange().upperbound(),true); break; } + case Specification::FeatureType::kStateType: + { + const Specification::ArrayFeatureType& params = m_type->statetype().arraytype(); + dict["dataType"] = dataTypeToString(params.datatype()); + dict["shape"] = dimensionsToString(defaultShapeOf(params),true); + break; + } default: break; } return dict; } - + Specification::FeatureType* FeatureType::allocateCopy() { // we call new here, but don't free! // this method should only be called immediately prior to passing the diff --git a/mlmodel/src/Globals.hpp b/mlmodel/src/Globals.hpp index ca0cf87a6..e76ac8d56 100644 --- a/mlmodel/src/Globals.hpp +++ b/mlmodel/src/Globals.hpp @@ -62,7 +62,10 @@ namespace CoreML { // version 8: static const int32_t MLMODEL_SPECIFICATION_VERSION_IOS17 = 8; - static const int32_t MLMODEL_SPECIFICATION_VERSION_NEWEST = MLMODEL_SPECIFICATION_VERSION_IOS17; + // version 9: + static const int32_t MLMODEL_SPECIFICATION_VERSION_IOS18 = 9; + + static const int32_t MLMODEL_SPECIFICATION_VERSION_NEWEST = MLMODEL_SPECIFICATION_VERSION_IOS18; } diff --git a/mlmodel/src/MILBlob/Blob/BlobDataType.hpp b/mlmodel/src/MILBlob/Blob/BlobDataType.hpp index 1db5587c3..4dee4cc06 100644 --- a/mlmodel/src/MILBlob/Blob/BlobDataType.hpp +++ b/mlmodel/src/MILBlob/Blob/BlobDataType.hpp @@ -7,19 +7,19 @@ #include "MILBlob/Bf16.hpp" #include "MILBlob/Fp16.hpp" +#include "MILBlob/Fp8.hpp" +#include "MILBlob/SubByteTypes.hpp" namespace MILBlob { namespace Blob { -enum class BlobDataType : uint32_t -{ +enum class BlobDataType : uint32_t { // *** WARNING *** - // for binary compatibility, values should ONLY be added at the end. + // For binary compatibility, values should ONLY be added at the end. // // this file needs to remain in sync across multiple repos. // please be cognizant of that when making changes to the // format. - Float16 = 1, Float32 = 2, UInt8 = 3, @@ -27,6 +27,16 @@ enum class BlobDataType : uint32_t BFloat16 = 5, Int16 = 6, UInt16 = 7, + Int4 = 8, + UInt1 = 9, + UInt2 = 10, + UInt4 = 11, + UInt3 = 12, + UInt6 = 13, + Int32 = 14, + UInt32 = 15, + Float8E4M3FN = 16, + Float8E5M2 = 17, }; template @@ -42,6 +52,16 @@ struct BlobDataTypeTraits { static constexpr BlobDataType DataType = BlobDataType::Float16; }; +template <> +struct BlobDataTypeTraits { + static constexpr BlobDataType DataType = BlobDataType::Float8E4M3FN; +}; + +template <> +struct BlobDataTypeTraits { + static constexpr BlobDataType DataType = BlobDataType::Float8E5M2; +}; + template <> struct BlobDataTypeTraits { static constexpr BlobDataType DataType = BlobDataType::BFloat16; @@ -67,5 +87,45 @@ struct BlobDataTypeTraits { static constexpr BlobDataType DataType = BlobDataType::UInt16; }; +template <> +struct BlobDataTypeTraits { + static constexpr BlobDataType DataType = BlobDataType::Int32; +}; + +template <> +struct BlobDataTypeTraits { + static constexpr BlobDataType DataType = BlobDataType::UInt32; +}; + +template <> +struct BlobDataTypeTraits { + static constexpr BlobDataType DataType = BlobDataType::Int4; +}; + +template <> +struct BlobDataTypeTraits { + static constexpr BlobDataType DataType = BlobDataType::UInt6; +}; + +template <> +struct BlobDataTypeTraits { + static constexpr BlobDataType DataType = BlobDataType::UInt4; +}; + +template <> +struct BlobDataTypeTraits { + static constexpr BlobDataType DataType = BlobDataType::UInt3; +}; + +template <> +struct BlobDataTypeTraits { + static constexpr BlobDataType DataType = BlobDataType::UInt2; +}; + +template <> +struct BlobDataTypeTraits { + static constexpr BlobDataType DataType = BlobDataType::UInt1; +}; + } // namespace Blob } // namespace MILBlob diff --git a/mlmodel/src/MILBlob/Blob/StorageFormat.hpp b/mlmodel/src/MILBlob/Blob/StorageFormat.hpp index 224ad1371..135669b97 100644 --- a/mlmodel/src/MILBlob/Blob/StorageFormat.hpp +++ b/mlmodel/src/MILBlob/Blob/StorageFormat.hpp @@ -14,7 +14,7 @@ namespace Blob { // ---: Blob Storage File Format :--- // Default file format for CoreML (iOS15 onwards) // -// ---: File structure :--- +// ---: File sturcture :--- // File is structured as below: // 1. Storage header: `struct storage_header` // 2. Followed by pair: `struct blob_metadata` and `raw_data` @@ -42,20 +42,25 @@ constexpr uint64_t BlobMetadataSentinel = 0xDEADBEEF; /** * blob_metadata: stores information of blob present in weight file + * + * Before ios18, the reserved fields were uninitialized and could have any values if not specified. + * From ios18 on, the reserved fields are initialized to 0 by default. + * To extend the format, make sure to bump the version number in storage_header. */ struct blob_metadata { uint32_t sentinel = BlobMetadataSentinel; // for validating correctness of metadata. - BlobDataType mil_dtype; // data type of the blob data. - uint64_t sizeInBytes; // size of the blob data in bytes. - uint64_t offset; // offset in file for blob data. - + BlobDataType mil_dtype = BlobDataType::Float16; // data type of the blob data. + uint64_t sizeInBytes = 0; // size of the blob data in bytes. + uint64_t offset = 0; // offset in file for blob data. + uint64_t padding_size_in_bits = 0; // describes the number of unused bits in this blob, + // required to calculate the actual size for spans of + // sub-btye-sized types. Unused otherwise // Reserve fields - uint64_t reserved_0; - uint64_t reserved_1; - uint64_t reserved_2; - uint64_t reserved_3; - uint64_t reserved_4; + uint64_t reserved_1 = 0; + uint64_t reserved_2 = 0; + uint64_t reserved_3 = 0; + uint64_t reserved_4 = 0; }; /** diff --git a/mlmodel/src/MILBlob/Blob/StorageReader.cpp b/mlmodel/src/MILBlob/Blob/StorageReader.cpp index 65ede7742..3f4147035 100644 --- a/mlmodel/src/MILBlob/Blob/StorageReader.cpp +++ b/mlmodel/src/MILBlob/Blob/StorageReader.cpp @@ -9,6 +9,7 @@ #include "MILBlob/Blob/StorageFormat.hpp" #include "MILBlob/Blob/StorageReader.hpp" #include "MILBlob/Fp16.hpp" +#include "MILBlob/Fp8.hpp" #include "MILBlob/Util/SpanCast.hpp" #include @@ -54,7 +55,49 @@ class StorageReader::Impl final { } template - Util::Span GetDataView(uint64_t offset) const; + Util::Span GetDataViewForByteAligned(uint64_t offset) const + { + auto metadata = GetAndCheckMetadata(offset, BlobDataTypeTraits::DataType); + + return Util::SpanCast(m_reader->ReadData(metadata.offset, metadata.sizeInBytes)); + } + + template + Util::Span GetDataViewForSubByteSized(uint64_t offset) const + { + auto metadata = GetAndCheckMetadata(offset, BlobDataTypeTraits::DataType); + + Util::Span rawSpan = m_reader->ReadData(metadata.offset, metadata.sizeInBytes); + + MILVerifyIsTrue(metadata.padding_size_in_bits < 8, + std::runtime_error, + "8 or more bits of padding for sub-byte sized data is incorrect"); + + if constexpr (MILBlob::SubByteIsByteAligned()) { + MILVerifyIsTrue(metadata.padding_size_in_bits % T::SizeInBits == 0, + std::runtime_error, + "Invalid padding for byte-aligned sub-byte-sized type"); + } + + // metadata.sizeInBytes includes the padding to make the data byte aligned + + size_t numBits = metadata.sizeInBytes * 8; + numBits -= metadata.padding_size_in_bits; + MILVerifyIsTrue(numBits % T::SizeInBits == 0, std::runtime_error, "Invalid padding for blob"); + size_t numElements = numBits / T::SizeInBits; + + return Util::CastToBitSpan(rawSpan, numElements); + } + + template + Util::Span GetDataView(uint64_t offset) const + { + if constexpr (MILBlob::IsSubByteSized::value) { + return this->GetDataViewForSubByteSized(offset); + } else { + return this->GetDataViewForByteAligned(offset); + } + } uint64_t GetDataOffset(uint64_t offset) const { @@ -62,6 +105,12 @@ class StorageReader::Impl final { return metadata.offset; } + uint64_t GetDataPaddingInBits(uint64_t offset) const + { + auto metadata = GetMetadata(offset); + return metadata.padding_size_in_bits; + } + uint64_t GetDataSize(uint64_t metadataOffset) const { auto metadata = GetMetadata(metadataOffset); @@ -118,23 +167,23 @@ class StorageReader::Impl final { std::call_once(m_loadedFlag, [&load]() { load(); }); } + blob_metadata GetAndCheckMetadata(uint64_t offset, MILBlob::Blob::BlobDataType blobDType) const + { + auto metadata = GetMetadata(offset); + + MILVerifyIsTrue(metadata.mil_dtype == blobDType, + std::runtime_error, + "Metadata data type does not match requested type."); + + return metadata; + } + const std::string m_filePath; mutable std::once_flag m_loadedFlag; mutable std::unique_ptr m_reader; }; -template -Util::Span StorageReader::Impl::GetDataView(uint64_t offset) const -{ - auto metadata = GetMetadata(offset); - - MILVerifyIsTrue(metadata.mil_dtype == BlobDataTypeTraits::DataType, - std::runtime_error, - "Metadata data type does not match requested type."); - return Util::SpanCast(m_reader->ReadData(metadata.offset, metadata.sizeInBytes)); -} - // -------------------------------------------------------------------------------------- StorageReader::~StorageReader() = default; @@ -152,6 +201,18 @@ Util::Span StorageReader::GetDataView(uint64_t offset) con return m_impl->GetDataView(offset); } +// StorageReader::GetDataView specializations for sub byte types +#define DECLARE_SUB_BYTE_TYPE(TYPE_NAME) \ + template <> \ + Util::Span StorageReader::GetDataView(uint64_t offset) const \ + { \ + return m_impl->GetDataView(offset); \ + } + +#include "MILBlob/SubByteTypeList.hpp" + +#undef DECLARE_SUB_BYTE_TYPE + template <> Util::Span StorageReader::GetDataView(uint64_t offset) const { @@ -164,6 +225,18 @@ Util::Span StorageReader::GetDataView(uint64_t offset) const return m_impl->GetDataView(offset); } +template <> +Util::Span StorageReader::GetDataView(uint64_t offset) const +{ + return m_impl->GetDataView(offset); +} + +template <> +Util::Span StorageReader::GetDataView(uint64_t offset) const +{ + return m_impl->GetDataView(offset); +} + template <> Util::Span StorageReader::GetDataView(uint64_t offset) const { @@ -188,6 +261,18 @@ Util::Span StorageReader::GetDataView(uint64_t offset) return m_impl->GetDataView(offset); } +template <> +Util::Span StorageReader::GetDataView(uint64_t offset) const +{ + return m_impl->GetDataView(offset); +} + +template <> +Util::Span StorageReader::GetDataView(uint64_t offset) const +{ + return m_impl->GetDataView(offset); +} + Util::Span StorageReader::GetRawDataView(uint64_t offset) const { return m_impl->GetRawDataView(offset); @@ -217,3 +302,8 @@ std::vector StorageReader::GetAllOffsets() const { return m_impl->GetAllOffsets(); } + +uint64_t StorageReader::GetDataPaddingInBits(uint64_t metadataOffset) const +{ + return m_impl->GetDataPaddingInBits(metadataOffset); +} diff --git a/mlmodel/src/MILBlob/Blob/StorageReader.hpp b/mlmodel/src/MILBlob/Blob/StorageReader.hpp index 3bb073a00..bc8c7b687 100644 --- a/mlmodel/src/MILBlob/Blob/StorageReader.hpp +++ b/mlmodel/src/MILBlob/Blob/StorageReader.hpp @@ -6,8 +6,10 @@ #pragma once #include "MILBlob/Bf16.hpp" -#include "MILBlob/Fp16.hpp" #include "MILBlob/Blob/BlobDataType.hpp" +#include "MILBlob/Fp16.hpp" +#include "MILBlob/Fp8.hpp" +#include "MILBlob/SubByteTypes.hpp" #include "MILBlob/Util/Span.hpp" #include #include @@ -22,12 +24,16 @@ namespace Blob { * Memory-mapping is performed lazily on first access to the underlying data. * * This file format supports the following types: + * - uint1,2,4 + * - int4 * - uint8_t * - Bf16 * - Fp16 * - float * - int16_t * - uint16_t + * - int32_t + * - uint32_t */ class StorageReader final { public: @@ -85,25 +91,47 @@ class StorageReader final { /** Returns a vector containing the metadata offsets for all blobs in the file, in order. */ std::vector GetAllOffsets() const; + uint64_t GetDataPaddingInBits(uint64_t metadataOffset) const; + private: class Impl; const std::unique_ptr m_impl; }; +template <> +Util::Span StorageReader::GetDataView(uint64_t) const; template <> Util::Span StorageReader::GetDataView(uint64_t) const; template <> Util::Span StorageReader::GetDataView(uint64_t) const; template <> +Util::Span StorageReader::GetDataView(uint64_t) const; +template <> +Util::Span StorageReader::GetDataView(uint64_t) const; +template <> Util::Span StorageReader::GetDataView(uint64_t) const; template <> Util::Span StorageReader::GetDataView(uint64_t) const; template <> Util::Span StorageReader::GetDataView(uint64_t) const; template <> +Util::Span StorageReader::GetDataView(uint64_t) const; +template <> +Util::Span StorageReader::GetDataView(uint64_t) const; +template <> +Util::Span StorageReader::GetDataView(uint64_t) const; +template <> +Util::Span StorageReader::GetDataView(uint64_t) const; +template <> +Util::Span StorageReader::GetDataView(uint64_t) const; +template <> Util::Span StorageReader::GetDataView(uint64_t) const; template <> Util::Span StorageReader::GetDataView(uint64_t) const; +template <> +Util::Span StorageReader::GetDataView(uint64_t) const; +template <> +Util::Span StorageReader::GetDataView(uint64_t) const; } // namespace Blob } // namespace MILBlob diff --git a/mlmodel/src/MILBlob/Blob/StorageWriter.cpp b/mlmodel/src/MILBlob/Blob/StorageWriter.cpp index b57774d1c..2cc077e9c 100644 --- a/mlmodel/src/MILBlob/Blob/StorageWriter.cpp +++ b/mlmodel/src/MILBlob/Blob/StorageWriter.cpp @@ -8,6 +8,7 @@ #include "MILBlob/Blob/StorageFormat.hpp" #include "MILBlob/Blob/StorageWriter.hpp" #include "MILBlob/Fp16.hpp" +#include "MILBlob/Fp8.hpp" #include "MILBlob/Util/Span.hpp" #include "MILBlob/Util/SpanCast.hpp" @@ -74,13 +75,41 @@ class StorageWriter::Impl final { storage_header m_header; }; +template +uint64_t SpanSizeInBytes(Util::Span data) +{ + if constexpr (MILBlob::IsSubByteSized::value) { + auto uint8Span = MILBlob::Util::CastFromBitSpan(data); + return SpanSizeInBytes(uint8Span); + } else { + return data.Size() * sizeof(T); + } +} + +template +void WritePaddingBits(blob_metadata& metadata, size_t numElements) +{ + // types aligned to byte boundaries don't need this padding + if constexpr (MILBlob::IsSubByteSized::value) { + metadata.padding_size_in_bits = 0; + std::size_t numBitsRemaining = (numElements * T::SizeInBits) % 8; + if (numBitsRemaining != 0) { + metadata.padding_size_in_bits = 8 - numBitsRemaining; + } + } +} + template uint64_t StorageWriter::Impl::WriteData(Util::Span data) { // 1. Write data blob_metadata metadata; metadata.mil_dtype = BlobDataTypeTraits::type>::DataType; - metadata.sizeInBytes = data.Size() * sizeof(T); + metadata.sizeInBytes = SpanSizeInBytes(data); + + // populate padding_size_in_bits, if we're writing a sub-byte-sized type + WritePaddingBits>(metadata, data.Size()); + // Get offset for data auto metadataOffset = m_fileWriter->GetNextAlignedOffset(); // metadata is 64 bit aligned. @@ -94,7 +123,13 @@ uint64_t StorageWriter::Impl::WriteData(Util::Span data) MILVerifyIsTrue(metadataOffset == actualMetadataOffset, std::runtime_error, "[MIL StorageWriter]: Metadata written to different offset than expected."); - auto actualDataOffset = m_fileWriter->AppendData(Util::SpanCast(data)); + Util::Span byteSpan; + if constexpr (MILBlob::IsSubByteSized::value) { + byteSpan = Util::CastFromBitSpan(data); + } else { + byteSpan = Util::SpanCast(data); + } + auto actualDataOffset = m_fileWriter->AppendData(byteSpan); MILVerifyIsTrue(dataOffset == actualDataOffset, std::runtime_error, "[MIL StorageWriter]: Metadata written to different offset than expected."); @@ -127,12 +162,30 @@ uint64_t StorageWriter::WriteData(Util::Span data) return m_impl->WriteData(data); } +template <> +uint64_t StorageWriter::WriteData(Util::Span data) +{ + return m_impl->WriteData(data); +} + template <> uint64_t StorageWriter::WriteData(Util::Span data) { return m_impl->WriteData(data); } +template <> +uint64_t StorageWriter::WriteData(Util::Span data) +{ + return m_impl->WriteData(data); +} + +template <> +uint64_t StorageWriter::WriteData(Util::Span data) +{ + return m_impl->WriteData(data); +} + template <> uint64_t StorageWriter::WriteData(Util::Span data) { @@ -151,6 +204,24 @@ uint64_t StorageWriter::WriteData(Util::Span data) return m_impl->WriteData(data); } +template <> +uint64_t StorageWriter::WriteData(Util::Span data) +{ + return m_impl->WriteData(data); +} + +// Implement WriteData forwarding stubs for all sub byte types +#define DECLARE_SUB_BYTE_TYPE(TYPE_NAME) \ + template <> \ + uint64_t StorageWriter::WriteData(Util::Span data) \ + { \ + return m_impl->WriteData(data); \ + } + +#include "MILBlob/SubByteTypeList.hpp" + +#undef DECLARE_SUB_BYTE_TYPE + template <> uint64_t StorageWriter::WriteData(Util::Span data) { diff --git a/mlmodel/src/MILBlob/Blob/StorageWriter.hpp b/mlmodel/src/MILBlob/Blob/StorageWriter.hpp index 00a1423a5..58e3c95ca 100644 --- a/mlmodel/src/MILBlob/Blob/StorageWriter.hpp +++ b/mlmodel/src/MILBlob/Blob/StorageWriter.hpp @@ -7,6 +7,8 @@ #include "MILBlob/Bf16.hpp" #include "MILBlob/Fp16.hpp" +#include "MILBlob/Fp8.hpp" +#include "MILBlob/SubByteTypes.hpp" #include "MILBlob/Util/Span.hpp" #include #include @@ -47,6 +49,8 @@ class StorageWriter final { const std::unique_ptr m_impl; }; +template <> +uint64_t StorageWriter::WriteData(Util::Span); template <> uint64_t StorageWriter::WriteData(Util::Span); template <> @@ -56,11 +60,29 @@ uint64_t StorageWriter::WriteData(Util::Span); template <> uint64_t StorageWriter::WriteData(Util::Span); template <> +uint64_t StorageWriter::WriteData(Util::Span); +template <> +uint64_t StorageWriter::WriteData(Util::Span); +template <> uint64_t StorageWriter::WriteData(Util::Span); template <> uint64_t StorageWriter::WriteData(Util::Span); template <> +uint64_t StorageWriter::WriteData(Util::Span); +template <> +uint64_t StorageWriter::WriteData(Util::Span); +template <> +uint64_t StorageWriter::WriteData(Util::Span); +template <> +uint64_t StorageWriter::WriteData(Util::Span); +template <> +uint64_t StorageWriter::WriteData(Util::Span); +template <> +uint64_t StorageWriter::WriteData(Util::Span); +template <> uint64_t StorageWriter::WriteData(Util::Span); +template <> +uint64_t StorageWriter::WriteData(Util::Span); } // namespace Blob } // namespace MILBlob diff --git a/mlmodel/src/MILBlob/Fp8.cpp b/mlmodel/src/MILBlob/Fp8.cpp new file mode 100644 index 000000000..d21f17729 --- /dev/null +++ b/mlmodel/src/MILBlob/Fp8.cpp @@ -0,0 +1,188 @@ +// Copyright (c) 2021, Apple Inc. All rights reserved. +// +// Use of this source code is governed by a BSD-3-clause license that can be +// found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +#include "MILBlob/Fp8.hpp" + +#include + +using namespace MILBlob; + +// Some global constants. +constexpr uint8_t fp32MantissaBits = 23; +constexpr int8_t fp32ExponentBias = 127; + +// Helper function to handle Fp32 -> Fp8 exponent and mantissa. +template +void HandleFp32ToFp8ExponentMantissa(FP8_CAST& fp8, FloatCast& fp32) +{ + int32_t unbiasedExponent = fp32.components.exponent - fp32ExponentBias; + if (unbiasedExponent + FP8_TYPE::fp8ExponentBias > 0) { + // Normal. + fp8.components.exponent = uint8_t(fp32.components.exponent - fp32ExponentBias + FP8_TYPE::fp8ExponentBias); + } else { + // Denormal. + FloatCast fp32_bias; + fp32_bias.components.sign = fp32.components.sign; + fp32_bias.components.exponent = -1 * FP8_TYPE::fp8ExponentBias + fp32ExponentBias + 1; + fp32_bias.components.mantissa = 0; + fp32.f += fp32_bias.f; + fp8.components.exponent = 0; + } + if ((fp32.components.mantissa & ((0x1 << (fp32MantissaBits - FP8_TYPE::fp8MantissaBits)) - 1)) != 0) { + throw std::range_error("FP8 SetFloat requires rounding for the given value."); + } + fp8.components.mantissa = fp32.components.mantissa >> (fp32MantissaBits - FP8_TYPE::fp8MantissaBits); +} + +// Helper function to handle normalizing the denormalized case for fp8. +// For denormalized fp8's, we need to normalize by subtracting a bias of 2^(1 - fp8ExponentBias) +template +void HandleFp8ToFp32Denormalize(FP8_CAST& fp8, FloatCast& fp32) +{ + if (fp8.components.exponent == 0 && fp8.components.mantissa != 0) { + fp32.components.exponent++; + FloatCast fp32_bias; + fp32_bias.components.sign = fp8.components.sign; + fp32_bias.components.exponent = fp32.components.exponent; + fp32_bias.components.mantissa = 0; + fp32.f -= fp32_bias.f; + } +} + +// Helper function to handle exponent and mantissa for Fp8 -> Fp32 conversion. +template +void HandleFp8ToFp32ExponentMantissa(const FP8_CAST& fp8, FloatCast& fp32) +{ + if (fp8.components.exponent == 0 && fp8.components.mantissa == 0) { + fp32.components.exponent = 0; + fp32.components.mantissa = 0; + return; + } + int32_t unbiasedExponent = fp8.components.exponent - FP8_TYPE::fp8ExponentBias; + fp32.components.exponent = uint32_t(unbiasedExponent + fp32ExponentBias); + fp32.components.mantissa = + uint32_t(int32_t(fp8.components.mantissa << (fp32MantissaBits - FP8_TYPE::fp8MantissaBits))); +} + +float Fp8E5M2::GetFloat() const +{ + FloatCast fp32 = {.f = 0}; + // Set the sign bit. + fp32.components.sign = data.components.sign; + + // Standard NaN/Inf case. We just use the fp8 mantissa as there's + // no strong requirements for mantissa in the NaN case. + if (data.components.exponent == (0x1 << fp8ExponentBits) - 1) { + fp32.components.exponent = 0xFF; + fp32.components.mantissa = data.components.mantissa; + return fp32.f; + } + HandleFp8ToFp32ExponentMantissa(data, fp32); + HandleFp8ToFp32Denormalize(data, fp32); + return fp32.f; +} + +float Fp8E4M3FN::GetFloat() const +{ + FloatCast fp32 = {.f = 0}; + // Set the sign bit. + fp32.components.sign = data.components.sign; + // NaN case, infinity is not supported. We just use the mantissa from the fp8. + if (data.components.exponent == (0x1 << fp8ExponentBits) - 1 && data.components.mantissa == 0x7) { + fp32.components.exponent = 0xFF; + fp32.components.mantissa = data.components.mantissa; + return fp32.f; + } + HandleFp8ToFp32ExponentMantissa(data, fp32); + HandleFp8ToFp32Denormalize(data, fp32); + return fp32.f; +} + +void Fp8E5M2::SetFloat(float f) +{ + FloatCast fp32 = {.f = f}; + data = {.byte = 0}; + // Set sign bit. + data.components.sign = fp32.components.sign; + + // If f is nan or inf, set exponent to all 1's. + if (std::isnan(f)) { + data.components.exponent = (0x1 << fp8ExponentBits) - 1; + data.components.mantissa = 1; + } else if (std::isinf(f)) { + data.components.exponent = (0x1 << fp8ExponentBits) - 1; + data.components.mantissa = 0; + } else if (f == 0) { + data.components.exponent = 0; + data.components.mantissa = 0; + } else { + int32_t unbiasedExponent = fp32.components.exponent - fp32ExponentBias; + // Float is normal or denormal, check the exponent and set it. + // For now, we throw on over/underflows. There are alternative ways to handle + // this (round to zero). + if (unbiasedExponent > fp8ExponentBias) { + throw std::range_error("Fp8E5M2 SetFloat exponent overflow."); + } else if (unbiasedExponent < (-1 * fp8ExponentBias - int32_t(fp8MantissaBits) + 1)) { + throw std::range_error("Fp8E5M2 SetFloat exponent underflow."); + } + HandleFp32ToFp8ExponentMantissa(data, fp32); + } +} + +void Fp8E4M3FN::SetFloat(float f) +{ + FloatCast fp32 = {.f = f}; + data = {.byte = 0}; + // Set sign bit. + data.components.sign = fp32.components.sign; + + // If f is nan or inf, set exponent to all 1's. + if (std::isnan(f)) { + data.components.exponent = (0x1 << fp8ExponentBits) - 1; + data.components.mantissa = 7; + } else if (std::isinf(f)) { + throw std::range_error("Fp8E4M3FN SetFloat infinity not supported."); + } else if (f == 0) { + data.components.exponent = 0; + data.components.mantissa = 0; + } else { + int32_t unbiasedExponent = fp32.components.exponent - fp32ExponentBias; + // Float is normal or denormal, check the exponent and set it. + // For now, we throw on over/underflows. There are alternative ways to handle + // this (round to zero). + if (unbiasedExponent > fp8ExponentBias + 1) { + throw std::range_error("Fp8E4M3FN SetFloat exponent overflow."); + } else if (unbiasedExponent < (-1 * fp8ExponentBias - int32_t(fp8MantissaBits) + 1)) { + // Underflow occurs when the exponent is below the minimum denormal value. + // This means unbiased exponent is less than -fp8ExponentBias - fp8MantissaBits + 1 + throw std::range_error("Fp8E4M3FN SetFloat exponent underflow."); + } + HandleFp32ToFp8ExponentMantissa(data, fp32); + } +} + +Fp8E5M2 Fp8E5M2::FromFloat(float f) +{ + Fp8E5M2 result; + result.SetFloat(f); + return result; +} + +Fp8E4M3FN Fp8E4M3FN::FromFloat(float f) +{ + Fp8E4M3FN result; + result.SetFloat(f); + return result; +} + +bool Fp8E5M2::IsNaN() const +{ + return (data.components.exponent == 0x1F && data.components.mantissa != 0); +} + +bool Fp8E4M3FN::IsNaN() const +{ + return (data.components.exponent == 0xF && data.components.mantissa == 7); +} diff --git a/mlmodel/src/MILBlob/Fp8.hpp b/mlmodel/src/MILBlob/Fp8.hpp new file mode 100644 index 000000000..1a99e9e69 --- /dev/null +++ b/mlmodel/src/MILBlob/Fp8.hpp @@ -0,0 +1,107 @@ +// Copyright (c) 2021, Apple Inc. All rights reserved. +// +// Use of this source code is governed by a BSD-3-clause license that can be +// found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +#pragma once + +#include +#include + +namespace MILBlob { + +// General helper typedef to help process an FP32 in different forms/its +// constituent components. +typedef union { + float f; + uint32_t bytes; + struct { + uint32_t mantissa : 23; + uint32_t exponent : 8; + uint32_t sign : 1; + } components; +} FloatCast; + +// Macro for FP8 types. +#define DECLARE_FP8_TYPE(NAME, EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS) \ + struct NAME { \ + typedef union { \ + uint8_t byte; \ + struct { \ + uint8_t mantissa : MANTISSA_BITS; \ + uint8_t exponent : EXPONENT_BITS; \ + uint8_t sign : 1; \ + } components; \ + } Cast; \ + explicit NAME(uint8_t d) \ + { \ + data.byte = d; \ + }; \ + NAME() \ + { \ + data.byte = 0; \ + } \ + static NAME FromFloat(float f); \ + float GetFloat() const; \ + void SetFloat(float f); \ + uint8_t GetByte() const \ + { \ + return data.byte; \ + } \ + void SetByte(uint8_t byte) \ + { \ + data.byte = byte; \ + } \ + bool IsNaN() const; \ + Cast data; \ + static constexpr int8_t fp8ExponentBias = EXPONENT_BIAS; \ + static constexpr uint8_t fp8ExponentBits = EXPONENT_BITS; \ + static constexpr uint8_t fp8MantissaBits = MANTISSA_BITS; \ + static_assert(fp8ExponentBits + fp8MantissaBits == 7, "Number of exponent and mantissa bits should be 7"); \ + }; \ + inline bool operator==(const NAME& first, const NAME& second) noexcept \ + { \ + if ((first.data.byte & 0x7F) == 0 && (second.data.byte & 0x7F) == 0) { \ + return true; \ + } \ + if (first.IsNaN() && second.IsNaN()) { \ + return false; \ + } \ + return first.data.byte == second.data.byte; \ + } \ + inline bool operator!=(const NAME& first, const NAME& second) noexcept \ + { \ + if ((first.data.byte & 0x7F) == 0 && (second.data.byte & 0x7F) == 0) { \ + return false; \ + } \ + if (first.IsNaN() && second.IsNaN()) { \ + return true; \ + } \ + return first.data.byte != second.data.byte; \ + } + +// Define the types. +DECLARE_FP8_TYPE(Fp8E5M2, 5, 2, 15) +DECLARE_FP8_TYPE(Fp8E4M3FN, 4, 3, 7) + +} // namespace MILBlob + +namespace std { + +template <> +struct hash { + size_t operator()(const MILBlob::Fp8E5M2& fp) const + { + return fp.data.byte; + } +}; + +template <> +struct hash { + size_t operator()(const MILBlob::Fp8E4M3FN& fp) const + { + return fp.data.byte; + } +}; + +} // namespace std diff --git a/mlmodel/src/MILBlob/SubByteTypeList.hpp b/mlmodel/src/MILBlob/SubByteTypeList.hpp new file mode 100644 index 000000000..295313c33 --- /dev/null +++ b/mlmodel/src/MILBlob/SubByteTypeList.hpp @@ -0,0 +1,13 @@ +// Copyright (c) 2021, Apple Inc. All rights reserved. +// +// Use of this source code is governed by a BSD-3-clause license that can be +// found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +// Listing of sub-byte-sized types in MIL +// Template file used for generating stub functionality +DECLARE_SUB_BYTE_TYPE(Int4) +DECLARE_SUB_BYTE_TYPE(UInt6) +DECLARE_SUB_BYTE_TYPE(UInt4) +DECLARE_SUB_BYTE_TYPE(UInt3) +DECLARE_SUB_BYTE_TYPE(UInt2) +DECLARE_SUB_BYTE_TYPE(UInt1) diff --git a/mlmodel/src/MILBlob/SubByteTypes.cpp b/mlmodel/src/MILBlob/SubByteTypes.cpp new file mode 100644 index 000000000..e2611bd6e --- /dev/null +++ b/mlmodel/src/MILBlob/SubByteTypes.cpp @@ -0,0 +1,209 @@ +// Copyright (c) 2021, Apple Inc. All rights reserved. +// +// Use of this source code is governed by a BSD-3-clause license that can be +// found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +#include "MILBlob/Util/Verify.hpp" + +#include "MILBlob/SubByteTypes.hpp" +#include "MILBlob/Util/SubByteConversionUtils.hpp" +#include +#include + +namespace MILBlob { + +struct IndexAndOffset { + uint64_t index; + uint8_t offset; +}; + +static IndexAndOffset GetIndexAndOffsetForSubByteValue(uint64_t i, uint8_t numBits) +{ + IndexAndOffset ret; + + uint64_t startBit = numBits * i; + + ret.index = startBit / 8; + ret.offset = startBit % 8; + + return ret; +} + +template +std::vector PackSubByteVecForNonByteAligned(Util::Span span) +{ + std::vector ret(MILBlob::SizeInBytes(span.Size()), 0); + + for (uint64_t i = 0; i < span.Size(); i++) { + MILVerifyIsTrue(span[i] <= T::MAX && span[i] >= T::MIN, + std::range_error, + "Value " + std::to_string(span[i]) + " is outside allowed subbyte datatype range [" + + std::to_string(T::MIN) + ", " + std::to_string(T::MAX) + "]."); + + auto indexAndOffset = GetIndexAndOffsetForSubByteValue(i, T::SizeInBits); + auto idx = indexAndOffset.index; + auto offset = indexAndOffset.offset; + + ret[idx] |= ((uint8_t)(span[i] << offset)); + if (offset > 8 - T::SizeInBits) { + // part of the i'th element of span spills over to idx+1 + // uint8_t rshift = T::SizeInBits - (8 - offset); + uint8_t rshift = 8 - offset; + ret[idx + 1] |= ((uint8_t)span[i] >> rshift); + } + } + + return ret; +} + +template +std::vector PackSubByteVecImpl(Util::Span vec) +{ + if constexpr (!MILBlob::SubByteIsByteAligned()) { + return PackSubByteVecForNonByteAligned(vec); + } + const auto ElementsPerByte = 8 / T::SizeInBits; + std::vector ret(MILBlob::SizeInBytes(vec.Size())); + for (size_t i = 0; i < vec.Size(); i++) { + size_t shiftAmmount = T::SizeInBits * (i % ElementsPerByte); + MILVerifyIsTrue(vec[i] <= T::MAX && vec[i] >= T::MIN, + std::range_error, + "Value " + std::to_string(vec[i]) + " is outside allowed subbyte datatype range [" + + std::to_string(T::MIN) + ", " + std::to_string(T::MAX) + "]."); + ret[i / ElementsPerByte] |= (static_cast((vec[i] & T::BitMask) << shiftAmmount)); + } + return ret; +} + +#define DEFINE_PACK_SUB_BYTE_VEC(TYPE) \ + std::vector PackSubByteVec(const std::vector& vec) \ + { \ + using impl_t = decltype(TYPE::data); \ + Util::Span int8Span(reinterpret_cast(vec.data()), vec.size()); \ + return PackSubByteVecImpl(int8Span); \ + } + +#define DECLARE_SUB_BYTE_TYPE(TYPE_NAME) DEFINE_PACK_SUB_BYTE_VEC(TYPE_NAME) +#include "MILBlob/SubByteTypeList.hpp" +#undef DECLARE_SUB_BYTE_TYPE + +#define DEFINE_UNPACK_SUB_BYTE_VEC(TYPE) \ + template <> \ + std::vector UnPackSubByteVec(const std::vector& vec, size_t numElements) \ + { \ + return UnPackSubByteVecImpl(vec, numElements); \ + } + +template +std::vector UnPackSubByteVecImpl(const std::vector& vec, size_t numElements) +{ + std::vector ret(numElements); + MILVerifyIsTrue( + vec.size() == MILBlob::SizeInBytes(numElements), + std::invalid_argument, + "Unpacking to sub-byte type vector has invalid number of elements. Sub-byte vector with NumElements " + "requires exactly vec.size() bytes."); + Util::Span subByteSpan((typename MILBlob::Util::voidType::type)(vec.data()), numElements); + for (size_t i = 0; i < numElements; i++) { + ret[i] = subByteSpan.ValueAt(i); + } + return ret; +} + +#define DECLARE_SUB_BYTE_TYPE(TYPE_NAME) DEFINE_UNPACK_SUB_BYTE_VEC(TYPE_NAME) +#include "MILBlob/SubByteTypeList.hpp" +#undef DECLARE_SUB_BYTE_TYPE + +template <> +std::vector PackInt8Span(Util::Span unpackedValues) +{ + return PackSubByteVecImpl(unpackedValues); +} + +template <> +std::vector PackUInt8Span(Util::Span unpackedValues) +{ + return PackSubByteVecImpl(unpackedValues); +} + +template <> +std::vector PackUInt8Span(Util::Span unpackedValues) +{ + return PackSubByteVecImpl(unpackedValues); +} + +template <> +std::vector PackUInt8Span(Util::Span unpackedValues) +{ + return PackSubByteVecImpl(unpackedValues); +} + +template <> +std::vector PackUInt8Span(Util::Span unpackedValues) +{ + return PackSubByteVecImpl(unpackedValues); +} + +template <> +std::vector PackUInt8Span(Util::Span unpackedValues) +{ + return PackSubByteVecImpl(unpackedValues); +} + +// Class methods for Int4, UInt4, etc. +#define IMPLEMENT_METHODS_FOR_SUB_BYTE_TYPE(TYPE_NAME) \ + TYPE_NAME::TYPE_NAME(decltype(TYPE_NAME::data) d) \ + { \ + MILVerifyIsTrue(d <= TYPE_NAME::MAX && d >= TYPE_NAME::MIN, \ + std::range_error, \ + #TYPE_NAME " value is out of range."); \ + data = d; \ + } \ + /* static */ TYPE_NAME TYPE_NAME::FromInt(int i) \ + { \ + TYPE_NAME result; \ + result.SetInt(i); \ + return result; \ + } \ + int TYPE_NAME::GetInt() const \ + { \ + return static_cast(data); \ + } \ + void TYPE_NAME::SetInt(int i) \ + { \ + MILVerifyIsTrue(i <= TYPE_NAME::MAX && i >= TYPE_NAME::MIN, \ + std::range_error, \ + #TYPE_NAME " value is out of range."); \ + data = static_cast(i); \ + return; \ + } \ + bool operator==(const TYPE_NAME& first, const TYPE_NAME& second) noexcept \ + { \ + return first.data == second.data; \ + } \ + bool operator!=(const TYPE_NAME& first, const TYPE_NAME& second) noexcept \ + { \ + return first.data != second.data; \ + } \ + static_assert(sizeof(TYPE_NAME) == 1, #TYPE_NAME " struct must be of size 1 byte"); + +#define DECLARE_SUB_BYTE_TYPE(TYPE_NAME) IMPLEMENT_METHODS_FOR_SUB_BYTE_TYPE(TYPE_NAME) +#include "MILBlob/SubByteTypeList.hpp" +#undef DECLARE_SUB_BYTE_TYPE + +}; // namespace MILBlob + +namespace std { + +// +128 here so that casting i.data to size_t, for T==Int4, is safe +#define DEFINE_HASH_FOR_SUB_BYTE_TYPE(TYPE) \ + size_t hash::operator()(const MILBlob::TYPE& i) const \ + { \ + return static_cast(i.data + 128); \ + } + +#define DECLARE_SUB_BYTE_TYPE(TYPE_NAME) DEFINE_HASH_FOR_SUB_BYTE_TYPE(TYPE_NAME) +#include "MILBlob/SubByteTypeList.hpp" +#undef DECLARE_SUB_BYTE_TYPE + +} // namespace std diff --git a/mlmodel/src/MILBlob/SubByteTypes.hpp b/mlmodel/src/MILBlob/SubByteTypes.hpp new file mode 100644 index 000000000..96be5e7a5 --- /dev/null +++ b/mlmodel/src/MILBlob/SubByteTypes.hpp @@ -0,0 +1,134 @@ +// Copyright (c) 2021, Apple Inc. All rights reserved. +// +// Use of this source code is governed by a BSD-3-clause license that can be +// found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +#pragma once + +#include +#include +#include +#include + +// A sub-byte type of is represented in MIL by a byte-sized struct which wraps +// an value of type IMPL_TYPE +#define DEFINE_SUB_BYTE_TYPE(NAME, IMPL_TYPE, BIT_SIZE, MASK, MAX_VAL, MIN_VAL) \ + struct NAME { \ + explicit NAME(IMPL_TYPE d); \ + NAME() : data(0) {} \ + static NAME FromInt(int i); \ + int GetInt() const; \ + void SetInt(int i); \ + IMPL_TYPE data; \ + static constexpr uint8_t SizeInBits = BIT_SIZE; \ + static constexpr uint8_t BitMask = MASK; \ + static constexpr IMPL_TYPE MAX = MAX_VAL; \ + static constexpr IMPL_TYPE MIN = MIN_VAL; \ + static_assert(MAX >= MIN, "Incompatible values for MIN and MAX"); \ + }; + +// Declares the following exports for sub-byte-type NAME +// operator == +// operator != +// +// Packs a sub byte vector into uint8_t representation since a vector of sub byte type +// cannot be packed. +// std::vector PackSubByteVec(const std::vector& vec); +// +// Unpacks a sub byte vector in uint8_t representation to a vector of the sub byte type. +// template <> +// std::vector UnPackSubByteVec(const std::vector& vec, size_t numElements); +#define DECLARE_SUB_BYTE_TYPE_METHODS(NAME) \ + bool operator==(const NAME& first, const NAME& second) noexcept; \ + bool operator!=(const NAME& first, const NAME& second) noexcept; \ + std::vector PackSubByteVec(const std::vector& vec); \ + template <> \ + std::vector UnPackSubByteVec(const std::vector& vec, size_t numElements); + +namespace MILBlob { + +template +class IsSubByteSized { + struct S { + char a; + char b; + }; + template + static char Tester(decltype(&U::SizeInBits)); + template + static S Tester(...); + +public: + enum { + value = sizeof(Tester(0)) == sizeof(char) + }; +}; + +template +constexpr bool SubByteIsByteAligned() +{ + return (8 / T::SizeInBits) * T::SizeInBits == 8; +} + +template +constexpr std::size_t SizeInBytes(std::size_t numElements) +{ + return (std::size_t)std::ceil((numElements * T::SizeInBits) / 8.0); +} + +template +std::vector UnPackSubByteVec(const std::vector& vec, std::size_t numElements); + +DEFINE_SUB_BYTE_TYPE(Int4, int8_t, 4, 0xF, 7, -8) +DECLARE_SUB_BYTE_TYPE_METHODS(Int4) + +DEFINE_SUB_BYTE_TYPE(UInt6, uint8_t, 6, 0b111111, 63, 0) +DECLARE_SUB_BYTE_TYPE_METHODS(UInt6) + +DEFINE_SUB_BYTE_TYPE(UInt4, uint8_t, 4, 0xF, 15, 0) +DECLARE_SUB_BYTE_TYPE_METHODS(UInt4) + +DEFINE_SUB_BYTE_TYPE(UInt3, uint8_t, 3, 0b111, 7, 0) +DECLARE_SUB_BYTE_TYPE_METHODS(UInt3) + +DEFINE_SUB_BYTE_TYPE(UInt2, uint8_t, 2, 0b11, 3, 0) +DECLARE_SUB_BYTE_TYPE_METHODS(UInt2) + +DEFINE_SUB_BYTE_TYPE(UInt1, uint8_t, 1, 0b1, 1, 0) +DECLARE_SUB_BYTE_TYPE_METHODS(UInt1) + +} // namespace MILBlob + +namespace std { + +template <> +struct hash { + size_t operator()(const MILBlob::Int4& i) const; +}; + +template <> +struct hash { + size_t operator()(const MILBlob::UInt6& i) const; +}; + +template <> +struct hash { + size_t operator()(const MILBlob::UInt4& i) const; +}; + +template <> +struct hash { + size_t operator()(const MILBlob::UInt3& i) const; +}; + +template <> +struct hash { + size_t operator()(const MILBlob::UInt2& i) const; +}; + +template <> +struct hash { + size_t operator()(const MILBlob::UInt1& i) const; +}; + +} // namespace std diff --git a/mlmodel/src/MILBlob/Util/Span.hpp b/mlmodel/src/MILBlob/Util/Span.hpp index 6ed298b1a..9ce9a8596 100644 --- a/mlmodel/src/MILBlob/Util/Span.hpp +++ b/mlmodel/src/MILBlob/Util/Span.hpp @@ -5,6 +5,7 @@ #pragma once +#include "MILBlob/SubByteTypes.hpp" #include "MILBlob/Util/Verify.hpp" #include #include @@ -97,10 +98,15 @@ class SpanSize final { // If Extent is specified, Span supports compile-time bounds checking // when the Get<> method is used. // -// This version of Span also supports iterating slices and dimensions -// of multi-dimensional contiguous memory blocks. +// For underlying types of at least byte-size, this version of Span also +// supports iterating slices and dimensions of multi-dimensional +// contiguous memory blocks. +// +// For sub-byte types, only basic access to the data pointer and size +// are supported. //---------------------------------------------------------------------- +// Span types of at least byte-size. template class Span final { public: @@ -122,6 +128,8 @@ class Span final { template using IsIndexValid = span_helpers::IsIndexValid; + static_assert(!MILBlob::IsSubByteSized::value, "Sub byte-sized types must use the reduced Span implementation"); + class SliceIterator final { public: SliceIterator(pointer p, size_t stride) : m_ptr(p), m_stride(stride) {} @@ -417,6 +425,186 @@ class Span final { SpanSize m_size; }; +template +struct voidType { + using type = void*; +}; +template +struct voidType::value>::type> { + using type = const void*; +}; +// Specializations for sub-byte types. +// This should ideally be implemented with std::enable_if but that involves an ABI breaking change. +// The pointer referenced by m_ptr and returned by Data() is byte aligned and packed, with possible +// padding in the last byte. +#define DEFINE_SPAN_CLASS_FOR_SUBBYTE(subByteType) \ +public: \ + template \ + using SpanSize = span_helpers::SpanSize; \ + \ + template \ + using IsDynamicExtent = span_helpers::IsDynamicExtent; \ + \ + ~Span() = default; \ + \ + Span(const Span&) = default; \ + Span(Span&&) noexcept = default; \ + \ + Span& operator=(const Span&) = default; \ + Span& operator=(Span&&) noexcept = default; \ + \ + /** Implicit copy constructor for converting a mutable span to a const span. Extent and type must be the same. */ \ + template ::value && \ + std::is_same::type>::value, \ + int>::type = 0> \ + Span(const Span& other) : m_ptr(other.Data()) \ + , m_size(other.Size()) \ + {} \ + \ + /** Implicit move constructor for converting a mutable span to a const span. Extent and type must be the same. */ \ + template ::value && \ + std::is_same::type>::value, \ + int>::type = 0> \ + Span(Span&& other) : m_ptr(other.Data()) \ + , m_size(other.Size()) \ + {} \ + \ + template ::value, int>::type = 0> \ + Span() : m_ptr(nullptr) \ + , m_size(0) \ + {} \ + \ + template ::value, int>::type = 0> \ + explicit Span(voidType::type p) : m_ptr(p) \ + {} \ + \ + template ::value, int>::type = 0> \ + Span(voidType::type p, size_t size) : m_ptr(size == 0 ? nullptr : p) \ + , m_size(size) \ + {} \ + \ + voidType::type Data() const \ + { \ + return m_ptr; \ + } \ + \ + size_t Size() const \ + { \ + return m_size.Size(); \ + } \ + \ + constexpr bool IsEmpty() const \ + { \ + return Size() == 0; \ + } \ + template \ + Span StaticResize() const \ + { \ + MILVerifyIsTrue(NewExtent <= Size(), std::range_error, "index out of bounds"); \ + return Span(Data()); \ + } \ + \ + std::remove_const::type ValueAt(std::size_t index) \ + { \ + if (index >= Size()) { \ + throw std::out_of_range("index out of bounds."); \ + } \ + using nonConstSubByteType = std::remove_const::type; \ + using impl_t = decltype(nonConstSubByteType::data); \ + \ + uint8_t bitSize = nonConstSubByteType::SizeInBits; \ + size_t elementIndex = index % Size(); \ + size_t packedBitsIndex = elementIndex * bitSize / 8; \ + size_t startBitIndex = elementIndex * bitSize % 8; \ + uint8_t bitMask = static_cast(nonConstSubByteType::BitMask << startBitIndex); \ + uint8_t restoredElement_uint8 = (*((const uint8_t*)Data() + packedBitsIndex) & bitMask) >> startBitIndex; \ + \ + /* For non-byte-aligned dtypes like UInt3, the required bits can be spread across 2 bytes. \ + Create mask and retrieve bits from the second byte if needed. \ + Look at SpanTests::testSubByteUIntValueAt*/ \ + size_t retrievedBits = 8 - startBitIndex; \ + if (retrievedBits < bitSize) { \ + bitMask = 0; \ + for (size_t i = 0; i < (bitSize - retrievedBits); ++i) { \ + bitMask |= 1 << i; \ + } \ + restoredElement_uint8 |= (*((const uint8_t*)Data() + packedBitsIndex + 1) & bitMask) << retrievedBits; \ + } \ + \ + /* If sign=1, fill all 1s in the prefix. \ + e.g., say the Int4 value is 1011 which is -5 in 2s complement. At this point, restoredElement_uint8 is \ + 00001011. To represent -5 correctly in 1 byte, we fill prefix 1s, resulting in 111110111. */ \ + if (nonConstSubByteType::MIN < 0) { \ + uint8_t sign_bit = (restoredElement_uint8 >> (bitSize - 1)) & 1; \ + if (sign_bit == 1) { \ + for (size_t i = 0; i < 8 - bitSize; ++i) { \ + restoredElement_uint8 |= 1 << (i + bitSize); \ + } \ + } \ + } \ + return nonConstSubByteType(*reinterpret_cast(&restoredElement_uint8)); \ + } \ + \ +private: \ + voidType::type m_ptr; \ + SpanSize m_size; + +template +class Span final { + DEFINE_SPAN_CLASS_FOR_SUBBYTE(Int4) +}; +template +class Span final { + DEFINE_SPAN_CLASS_FOR_SUBBYTE(const Int4) +}; + +template +class Span final { + DEFINE_SPAN_CLASS_FOR_SUBBYTE(UInt6) +}; +template +class Span final { + DEFINE_SPAN_CLASS_FOR_SUBBYTE(const UInt6) +}; + +template +class Span final { + DEFINE_SPAN_CLASS_FOR_SUBBYTE(UInt4) +}; +template +class Span final { + DEFINE_SPAN_CLASS_FOR_SUBBYTE(const UInt4) +}; + +template +class Span final { + DEFINE_SPAN_CLASS_FOR_SUBBYTE(UInt3) +}; +template +class Span final { + DEFINE_SPAN_CLASS_FOR_SUBBYTE(const UInt3) +}; + +template +class Span final { + DEFINE_SPAN_CLASS_FOR_SUBBYTE(UInt2) +}; +template +class Span final { + DEFINE_SPAN_CLASS_FOR_SUBBYTE(const UInt2) +}; + +template +class Span final { + DEFINE_SPAN_CLASS_FOR_SUBBYTE(UInt1) +}; +template +class Span final { + DEFINE_SPAN_CLASS_FOR_SUBBYTE(const UInt1) +}; + // MakeSpan for std::vector yields Span // Examples: // (1) create a mutable span diff --git a/mlmodel/src/MILBlob/Util/SpanCast.hpp b/mlmodel/src/MILBlob/Util/SpanCast.hpp index b3219d293..d6337eef6 100644 --- a/mlmodel/src/MILBlob/Util/SpanCast.hpp +++ b/mlmodel/src/MILBlob/Util/SpanCast.hpp @@ -5,7 +5,9 @@ #pragma once +#include "MILBlob/SubByteTypes.hpp" #include "MILBlob/Util/Span.hpp" +#include namespace MILBlob { namespace Util { @@ -19,10 +21,45 @@ namespace Util { template Span SpanCast(Span span) { + static_assert(!MILBlob::IsSubByteSized::value && !MILBlob::IsSubByteSized::value, + "SpanCast for sub-byte sized types is not supported"); auto ptr = reinterpret_cast(span.Data()); auto size = (span.Size() * sizeof(SourceT)) / sizeof(TargetT); return Span(ptr, size); } +/** + Reinterpret casts the underlying Span to a sub-byte type span. numElements indicates the number of + sub-byte elements in the case where the last byte contains some padding due to round to nearest byte. +*/ + +template ::value, bool> = true> +Span CastToBitSpan(Span span, size_t numElements) +{ + static_assert(std::is_same::value || std::is_same::value, + "CastToBitSpan is only possible when casting from a uint8_t span"); + if (span.Size() != MILBlob::SizeInBytes(numElements)) { + throw std::invalid_argument( + "BitSpanCast to sub-byte type span has invalid number of elements. Sub-byte span with NumElements " + "requires exactly Span.Size() bytes."); + } + return Span((typename MILBlob::Util::voidType::type)(span.Data()), numElements); +} + +/** + Reinterpret casts the underlying sub-byte-sized Span to a Span +*/ +template ::value, bool> = true> +Span CastFromBitSpan(Span span) +{ + size_t numBits = span.Size() * SourceT::SizeInBits; + size_t numElements = numBits / 8; + // need 1 more byte-sized element to hold remainder, if it exists + if (numBits % 8 != 0) { + numElements++; + } + return Span((const uint8_t*)span.Data(), numElements); +} + } // namespace Util } // namespace MILBlob diff --git a/mlmodel/src/MILBlob/Util/SubByteConversionUtils.hpp b/mlmodel/src/MILBlob/Util/SubByteConversionUtils.hpp new file mode 100644 index 000000000..1a5bb8c82 --- /dev/null +++ b/mlmodel/src/MILBlob/Util/SubByteConversionUtils.hpp @@ -0,0 +1,41 @@ +// Copyright (c) 2021, Apple Inc. All rights reserved. +// +// Use of this source code is governed by a BSD-3-clause license that can be +// found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +#pragma once + +#include "MILBlob/Util/Span.hpp" +#include + +namespace MILBlob { + +// This header contains the utils used by coremltools to pack subbyte datatype values. + +// Packs a span of int8_t containing unpacked values into a packed uint8_t vector +template +std::vector PackInt8Span(Util::Span unpackedValues); + +template <> +std::vector PackInt8Span(Util::Span unpackedValues); + +// Packs a span of uint8_t containing unpacked values into a packed uint8_t vector +template +std::vector PackUInt8Span(Util::Span unpackedValues); + +template <> +std::vector PackUInt8Span(Util::Span unpackedValues); + +template <> +std::vector PackUInt8Span(Util::Span unpackedValues); + +template <> +std::vector PackUInt8Span(Util::Span unpackedValues); + +template <> +std::vector PackUInt8Span(Util::Span unpackedValues); + +template <> +std::vector PackUInt8Span(Util::Span unpackedValues); + +} // namespace MILBlob diff --git a/mlmodel/src/Model.cpp b/mlmodel/src/Model.cpp index cbc2cb9c8..3760bc958 100644 --- a/mlmodel/src/Model.cpp +++ b/mlmodel/src/Model.cpp @@ -7,27 +7,61 @@ #include namespace CoreML { - + Model::Model() { m_spec = std::make_shared(); m_spec->set_specificationversion(MLMODEL_SPECIFICATION_VERSION); } - + Model::Model(const Specification::Model& proto) { m_spec = std::make_shared(proto); // We need to check this here because the proto could be overly strict downgradeSpecificationVersion(); } - + Model::Model(const std::string& description) : Model::Model() { Specification::Metadata* metadata = m_spec->mutable_description()->mutable_metadata(); metadata->set_shortdescription(description); } - + Model::Model(const Model& other) = default; Model::~Model() = default; + // rdar://111583564 (Allow Models with no Input Parameters) + bool Model::modelTypeAllowsEmptyInput(MLModelType modelType) { + switch (modelType) { + case MLModelType_neuralNetwork: + case MLModelType_neuralNetworkRegressor: + case MLModelType_neuralNetworkClassifier: + case MLModelType_pipeline: + case MLModelType_pipelineRegressor: + case MLModelType_pipelineClassifier: + case MLModelType_mlProgram: + return true; + default: + return false; + } + } + + bool Model::modelTypeAllowsMultipleFunctions(MLModelType modelType) { + switch (modelType) { + case MLModelType_mlProgram: + return true; + default: + return false; + } + } + + bool Model::modelTypeAllowsStatefulPrediction(MLModelType modelType) { + switch (modelType) { + case MLModelType_mlProgram: + return true; + default: + return false; + } + } + Result Model::validateGeneric(const Specification::Model& model) { // make sure compat version fields are filled in if (model.specificationversion() == 0) { @@ -49,11 +83,12 @@ namespace CoreML { } // validate model interface - Result r = validateModelDescription(model.description(), model.specificationversion()); + auto validationPolicy = ValidationPolicy(static_cast(model.Type_case())); + Result r = validateModelDescription(model.description(), model.specificationversion(), validationPolicy); if (!r.good()) { return r; } - + if (model.isupdatable()){ if (model.specificationversion() < MLMODEL_SPECIFICATION_VERSION_IOS13) { std::string err = "Model specification version for an updatable model must be '" + std::to_string(MLMODEL_SPECIFICATION_VERSION_IOS13) + "' or above."; @@ -127,33 +162,33 @@ namespace CoreML { Result Model::validate() const { return Model::validate(*m_spec); } - + Result Model::load(std::istream& in, Model& out) { if (!in.good()) { return Result(ResultType::UNABLE_TO_OPEN_FILE, "unable to open file for read"); } - + Result r = loadSpecification(*(out.m_spec), in); if (!r.good()) { return r; } // validate on load - + r = out.validate(); return r; } - + Result Model::load(const std::string& path, Model& out) { std::ifstream in(path, std::ios::binary); return load(in, out); } - - // We will only reduce the given specification version if possible. We never increase it here. + + // We will only reduce the given specification version if possible. We never increase it here. void Model::downgradeSpecificationVersion() { CoreML::downgradeSpecificationVersion(m_spec.get()); } - + Result Model::save(std::ostream& out) { if (!out.good()) { return Result(ResultType::UNABLE_TO_OPEN_FILE, @@ -169,10 +204,10 @@ namespace CoreML { if (!r.good()) { return r; } - + return saveSpecification(*m_spec, out); } - + Result Model::save(const std::string& path) { std::ofstream out(path, std::ios::binary); return save(out); @@ -194,7 +229,7 @@ namespace CoreML { } return inputs; } - + SchemaType Model::outputSchema() const { SchemaType outputs; const Specification::ModelDescription& interface = m_spec->description(); @@ -207,7 +242,7 @@ namespace CoreML { } return outputs; } - + Result Model::addInput(const std::string& featureName, FeatureType featureType) { Specification::ModelDescription* interface = m_spec->mutable_description(); @@ -216,7 +251,7 @@ namespace CoreML { arg->set_allocated_type(featureType.allocateCopy()); return Result(); } - + Result Model::addOutput(const std::string& targetName, FeatureType targetType) { Specification::ModelDescription* interface = m_spec->mutable_description(); @@ -225,7 +260,7 @@ namespace CoreML { arg->set_allocated_type(targetType.allocateCopy()); return Result(); } - + MLModelType Model::modelType() const { return static_cast(m_spec->Type_case()); } @@ -233,36 +268,36 @@ namespace CoreML { std::string Model::modelTypeName() const { return MLModelType_Name(modelType()); } - + const Specification::Model& Model::getProto() const { return *m_spec; } - + Specification::Model& Model::getProto() { return *m_spec; } - + Result Model::enforceTypeInvariant(const std::vector& allowedFeatureTypes, FeatureType featureType) { - + for (const FeatureType& t : allowedFeatureTypes) { if (featureType == t) { // no invariant broken -- type matches one of the allowed types return Result(); } } - + return Result::featureTypeInvariantError(allowedFeatureTypes, featureType); } - + bool Model::operator==(const Model& other) const { return *m_spec == *(other.m_spec); } - + bool Model::operator!=(const Model& other) const { return !(*this == other); } - + static void writeFeatureDescription(std::stringstream& ss, const Specification::FeatureDescription& feature) { ss << "\t\t" @@ -275,7 +310,7 @@ namespace CoreML { } ss << "\n"; } - + void Model::toStringStream(std::stringstream& ss) const { ss << "Spec version: " << m_spec->specificationversion() << "\n"; ss << "Model type: " << MLModelType_Name(static_cast(m_spec->Type_case())) << "\n"; @@ -295,7 +330,7 @@ namespace CoreML { ss << "\t" << "Predicted probability name: " << m_spec->description().predictedprobabilitiesname() << "\n"; } } - + std::string Model::toString() const { std::stringstream ss; toStringStream(ss); @@ -320,11 +355,11 @@ _MLModelSpecification::_MLModelSpecification() _MLModelSpecification::_MLModelSpecification(const CoreML::Model& te) { cppFormat.reset(new CoreML::Specification::Model(te.getProto())); } - + _MLModelMetadataSpecification::_MLModelMetadataSpecification() : cppMetadata(std::make_shared()) { } - + _MLModelMetadataSpecification::_MLModelMetadataSpecification(const CoreML::Specification::Metadata& meta) : cppMetadata(std::make_shared(meta)) { diff --git a/mlmodel/src/Model.hpp b/mlmodel/src/Model.hpp index 0aedda508..223a54c63 100644 --- a/mlmodel/src/Model.hpp +++ b/mlmodel/src/Model.hpp @@ -31,7 +31,7 @@ class Model { protected: std::shared_ptr m_spec; - + public: Model(); Model(const std::string& description); @@ -39,6 +39,10 @@ class Model { Model(const Model& other); virtual ~Model(); + static bool modelTypeAllowsEmptyInput(MLModelType modelType); + static bool modelTypeAllowsMultipleFunctions(MLModelType modelType); + static bool modelTypeAllowsStatefulPrediction(MLModelType modelType); + // operator overloads bool operator== (const Model& other) const; bool operator!= (const Model& other) const; @@ -53,7 +57,7 @@ class Model { */ static Result load(std::istream& stream, Model& out); - + /** * Deserializes a MLModel from a file given by a path. * @@ -64,7 +68,7 @@ class Model { */ static Result load(const std::string& path, Model& out); - + /** * Serializes a MLModel to an std::ostream. * @@ -72,7 +76,7 @@ class Model { * @return the result of the save operation, with ResultType::NO_ERROR on success. */ Result save(std::ostream& stream); - + /** * Serializes a MLModel to a file given by a path. * @@ -80,26 +84,26 @@ class Model { * @return result of save operation, with ResultType::NO_ERROR on success. */ Result save(const std::string& path); - + const std::string& shortDescription() const; MLModelType modelType() const; std::string modelTypeName() const; - + /** * Get the schema (name, type) for the inputs of this transform. * * @return input schema for outputs in this transform. */ SchemaType inputSchema() const; - + /** * Get the output schema (name, type) for the inputs of this transform. * * @return output schema for outputs in this transform. */ SchemaType outputSchema() const; - + /** * Enforces type invariant conditions. * @@ -120,8 +124,8 @@ class Model { */ static Result enforceTypeInvariant(const std::vector& allowedFeatureTypes, FeatureType featureType); - - + + /** * Ensures the spec is valid. This gets called every time before the * spec gets added to the MLModel. @@ -130,7 +134,7 @@ class Model { */ static Result validate(const Specification::Model& model); Result validate() const; - + /** * Add an input to the transform-spec. * @@ -139,7 +143,7 @@ class Model { * @return Result type of this operation. */ virtual Result addInput(const std::string& featureName, FeatureType featureType); - + /** * Add an output to the transform-spec. * @@ -148,18 +152,18 @@ class Model { * @return Result type of this operation. */ virtual Result addOutput(const std::string& outputName, FeatureType outputType); - + /** * If a model does not use features from later specification versions, this will * set the spec version so that the model can be executed on older versions of * Core ML. */ void downgradeSpecificationVersion(); - + // TODO -- This seems like a giant hack. This is leaking abstractions. const Specification::Model& getProto() const; Specification::Model& getProto(); - + // string representation (text description) std::string toString() const; void toStringStream(std::stringstream& ss) const; @@ -177,7 +181,7 @@ typedef struct _MLModelSpecification { _MLModelSpecification(const CoreML::Specification::Model&); _MLModelSpecification(const CoreML::Model&); } MLModelSpecification; - + typedef struct _MLModelMetadataSpecification { std::shared_ptr cppMetadata; _MLModelMetadataSpecification(); @@ -189,6 +193,6 @@ typedef struct _MLModelDescriptionSpecification { _MLModelDescriptionSpecification(); _MLModelDescriptionSpecification(const CoreML::Specification::ModelDescription&); } MLModelDescriptionSpecification; - + } #endif diff --git a/mlmodel/src/ResultType.hpp b/mlmodel/src/ResultType.hpp index 3c4d8f2be..2b97ae1d9 100644 --- a/mlmodel/src/ResultType.hpp +++ b/mlmodel/src/ResultType.hpp @@ -43,7 +43,16 @@ enum class ResultType { INVALID_UPDATABLE_MODEL_CONFIGURATION, // NN shaper failure, not necessarily an error - POTENTIALLY_INVALID_NEURAL_NETWORK_SHAPES + POTENTIALLY_INVALID_NEURAL_NETWORK_SHAPES, + + // The model type doesn't support the multi-function. + MODEL_TYPE_DOES_NOT_SUPPORT_MULTI_FUNCTION, + + // The default function name is invalid. + INVALID_DEFAULT_FUNCTION_NAME, + + // The model type requires at least one input feature. + MODEL_TYPE_DOES_NOT_SUPPORT_EMPTY_INPUT, }; } diff --git a/mlmodel/src/TreeEnsembleCommon.cpp b/mlmodel/src/TreeEnsembleCommon.cpp index 6c851c3f1..eb654fe59 100644 --- a/mlmodel/src/TreeEnsembleCommon.cpp +++ b/mlmodel/src/TreeEnsembleCommon.cpp @@ -247,7 +247,7 @@ namespace CoreML { namespace TreeEnsembles { if (nullptr == false_child_node) { continue; // Press on for further validation. Will trigger fatality in null check below at "This indicates there are logic errors above fooling us up; abort." } - + if(false_child_node == n) { std::ostringstream ss; ss << "False child and parent have same ID (TreeID=" << n->tree_id @@ -275,7 +275,7 @@ namespace CoreML { namespace TreeEnsembles { if (nullptr == true_child_node) { continue; // Press on for further validation. Will trigger fatality in null check below at "This indicates there are logic errors above fooling us up; abort." } - + if(true_child_node == n) { std::ostringstream ss; ss << "True child and parent have same ID (TreeID=" << n->tree_id @@ -413,7 +413,7 @@ namespace CoreML { namespace TreeEnsembles { // Now, are there nodes in not connected to any root nodes? // Because no node can have more than one parent, and that root - // nodes are defined by a node having no parent, then there + // nodes are definied by a node having no parent, then there // exists nodes not connected to any root node if and only if // there is a cycle. Thus we can easily test for this. @@ -443,7 +443,7 @@ namespace CoreML { namespace TreeEnsembles { * correctness of the dimension and multiclass options. */ auto tree_ensemble = std::make_shared<_TreeEnsemble>(); - + /** Add in the default values. */ @@ -469,22 +469,22 @@ namespace CoreML { namespace TreeEnsembles { << ") does not match specified output dimension (" << output_dimension << ")."; add_error_message(ss.str()); } - + /** Stage 5: pull out and verify the class-type specific parameters. */ if(m_spec.has_treeensembleregressor()) { tree_ensemble->operation_mode = _TreeEnsemble::OperationMode::REGRESSION_MODE; tree_ensemble->post_processing_transform = static_cast( m_spec.treeensembleregressor().postevaluationtransform()); - + } else if(m_spec.has_treeensembleclassifier()) { //auto tes_cl = m_spec.description().classifiertargets(); - + const auto& classifier = m_spec.treeensembleclassifier(); Specification::Int64Vector int64ClassLabels; Specification::StringVector stringClassLabels; - + switch (classifier.ClassLabels_case()) { case Specification::TreeEnsembleClassifier::kInt64ClassLabels: int64ClassLabels = classifier.int64classlabels(); @@ -498,58 +498,58 @@ namespace CoreML { namespace TreeEnsembles { // not sure if that's the desired outcome. break; } - + tree_ensemble->post_processing_transform = static_cast( m_spec.treeensembleclassifier().postevaluationtransform()); - + size_t n_classes = static_cast(std::max(int64ClassLabels.vector_size(), stringClassLabels.vector_size())); if(n_classes == 0) { - + /** Handle the binary classification mode. */ if(output_dimension == 1) { tree_ensemble->output_classes_string.clear(); tree_ensemble->output_classes_integer = {0, 1}; - + tree_ensemble->operation_mode = _TreeEnsemble::OperationMode::BINARY_CLASSIFICATION_MODE; - + } else { - + tree_ensemble->output_classes_string.clear(); tree_ensemble->output_classes_integer.resize(static_cast(output_dimension)); - + for(size_t i = 0; i < output_dimension; ++i) { tree_ensemble->output_classes_integer[i] = int64_t(i); } - + tree_ensemble->operation_mode = _TreeEnsemble::OperationMode::MULTICLASS_CLASSIFICATION_MODE; } - + } else if(/* Binary classification. */ (output_dimension == 1 && n_classes == 2) - + /* Multiclass classification. */ || (output_dimension >= 2 && n_classes == output_dimension) ) { - + bool binary_classification = (output_dimension == 1); - + if(binary_classification) { tree_ensemble->operation_mode = _TreeEnsemble::OperationMode::BINARY_CLASSIFICATION_MODE; } else { tree_ensemble->operation_mode = _TreeEnsemble::OperationMode::MULTICLASS_CLASSIFICATION_MODE; } - - + + bool integer_classes = int64ClassLabels.vector_size() > stringClassLabels.vector_size(); - + if(integer_classes) { tree_ensemble->output_classes_integer.resize(n_classes); } else { tree_ensemble->output_classes_string.resize(n_classes); } - + assert(n_classes >= 0); assert(n_classes < std::numeric_limits::max()); for(size_t i = 0; i < n_classes; ++i) { @@ -568,14 +568,14 @@ namespace CoreML { namespace TreeEnsembles { add_error_message(ss.str()); } } - + /** Stage 6: If there have been any errors, raise them. * */ if(error_count != 0) { throw std::logic_error("Error(s) in tree structure: \n" + current_error_msg.str()); } - + // And we're done. return tree_ensemble; } diff --git a/mlmodel/src/Utils.cpp b/mlmodel/src/Utils.cpp index 745eb9d1d..2c8595a4a 100644 --- a/mlmodel/src/Utils.cpp +++ b/mlmodel/src/Utils.cpp @@ -121,7 +121,11 @@ void CoreML::downgradeSpecificationVersion(Specification::Model *pModel) { // lets start at the newest specification version and downgrade from there pModel->set_specificationversion(MLMODEL_SPECIFICATION_VERSION_NEWEST); } - + + if (pModel->specificationversion() == MLMODEL_SPECIFICATION_VERSION_IOS18 && !hasIOS18Features(*pModel)) { + pModel->set_specificationversion(MLMODEL_SPECIFICATION_VERSION_IOS17); + } + if (pModel->specificationversion() == MLMODEL_SPECIFICATION_VERSION_IOS17 && !hasIOS17Features(*pModel)) { pModel->set_specificationversion(MLMODEL_SPECIFICATION_VERSION_IOS16); } @@ -332,6 +336,19 @@ bool CoreML::hasFloat16MultiArray(const Specification::Model& model) { return false; } +bool CoreML::hasCoreML8Opsets(const Specification::Model& model) { + if (model.Type_case() == Specification::Model::kMlProgram) { + auto main_iter = model.mlprogram().functions().find("main"); + if (main_iter != model.mlprogram().functions().end()) { + const auto& main = main_iter->second; + if (main.opset() == "CoreML8") { + return true; + } + } + } + return false; +} + bool CoreML::hasCoreML7Opsets(const Specification::Model& model) { if (model.Type_case() == Specification::Model::kMlProgram) { auto main_iter = model.mlprogram().functions().find("main"); @@ -708,6 +725,51 @@ bool CoreML::hasIOS17Features(const Specification::Model& model) { return result; } +bool CoreML::hasIOS18Features(const Specification::Model& model) { + // New in IOS18 features: + // - Language expansion for multilingual BERT used in text classifier and word tagger (revision == 5) + + bool result = false; + + switch (model.Type_case()) { + case Specification::Model::kPipeline: + for (auto &m : model.pipeline().models()) { + result = result || hasIOS18Features(m); + if (result) { + return true; + } + } + break; + case Specification::Model::kPipelineRegressor: + for (auto &m : model.pipelineregressor().pipeline().models()) { + result = result || hasIOS18Features(m); + if (result) { + return true; + } + } + break; + case Specification::Model::kPipelineClassifier: + for (auto &m : model.pipelineclassifier().pipeline().models()) { + result = result || hasIOS18Features(m); + if (result) { + return true; + } + } + break; + case Specification::Model::kWordTagger: + return model.wordtagger().revision() == 5; + case Specification::Model::kTextClassifier: + return model.textclassifier().revision() == 5; + default: + break; + } + + result = result || hasCoreML8Opsets(model); + result = result || hasMultiFunctions(model); + result = result || hasEmptyInput(model); + return result; +} + bool CoreML::hasCustomModel(const Specification::Model& model) { return (model.Type_case() == Specification::Model::kCustomModel); } @@ -1017,3 +1079,13 @@ bool CoreML::hasIOS14NeuralNetworkFeatures(const Specification::Model& model) { } return false; } + +bool CoreML::hasMultiFunctions(const Specification::Model& model) { + const auto& description = model.description(); + return description.functions_size() != 0 || !description.defaultfunctionname().empty(); +} + +bool CoreML::hasEmptyInput(const Specification::Model& model) { + const auto& description = model.description(); + return description.input_size() == 0; +} diff --git a/mlmodel/src/Utils.hpp b/mlmodel/src/Utils.hpp index 94792ed97..9daec541a 100644 --- a/mlmodel/src/Utils.hpp +++ b/mlmodel/src/Utils.hpp @@ -112,6 +112,7 @@ namespace CoreML { bool hasIOS15Features(const Specification::Model& model); bool hasIOS16Features(const Specification::Model& model); bool hasIOS17Features(const Specification::Model& model); + bool hasIOS18Features(const Specification::Model& model); typedef std::pair StringPair; // Returns a vector of pairs of strings, one pair per custom layer instance @@ -149,8 +150,10 @@ namespace CoreML { bool hasGrayscaleFloat16Image(const Specification::Model& model); bool hasCoreML6Opsets(const Specification::Model& model); bool hasCoreML7Opsets(const Specification::Model& model); - + bool hasCoreML8Opsets(const Specification::Model& model); + bool hasMultiFunctions(const Specification::Model& model); bool hasModelOrSubModelProperty(const Specification::Model& model, const std::function &boolFunc); + bool hasEmptyInput(const Specification::Model& model); // We also need a hasNonmaxSupression and hasBayesianProbitRegressor static inline std::vector readFloat16Weights(const Specification::WeightParams& weights) { diff --git a/mlmodel/src/Validation/CategoricalMappingValidator.cpp b/mlmodel/src/Validation/CategoricalMappingValidator.cpp index 817140ff3..5f4f9e6c9 100644 --- a/mlmodel/src/Validation/CategoricalMappingValidator.cpp +++ b/mlmodel/src/Validation/CategoricalMappingValidator.cpp @@ -21,7 +21,7 @@ namespace CoreML { if (!result.good()) { return result; } - + auto mapping_type = format.categoricalmapping().MappingType_case(); auto defval_type = format.categoricalmapping().ValueOnUnknown_case(); Specification::FeatureType::TypeCase requiredInputType; @@ -31,7 +31,7 @@ namespace CoreML { switch(mapping_type) { case Specification::CategoricalMapping::MappingTypeCase::kStringToInt64Map: - + if(defval_type == Specification::CategoricalMapping::ValueOnUnknownCase::kStrValue) { return Result(ResultType::INVALID_MODEL_PARAMETERS, "ValueOnUnknown set to string value while mapping produces int64."); @@ -40,9 +40,9 @@ namespace CoreML { requiredOutputType = Specification::FeatureType::kInt64Type; requiredInputSeqType = Specification::SequenceFeatureType::kStringType; requiredOutputSeqType = Specification::SequenceFeatureType::kInt64Type; - + break; - + case Specification::CategoricalMapping::MappingTypeCase::kInt64ToStringMap: if(defval_type == Specification::CategoricalMapping::ValueOnUnknownCase::kInt64Value) { return Result(ResultType::INVALID_MODEL_PARAMETERS, @@ -52,9 +52,9 @@ namespace CoreML { requiredInputType = Specification::FeatureType::kInt64Type; requiredOutputSeqType = Specification::SequenceFeatureType::kStringType; requiredInputSeqType = Specification::SequenceFeatureType::kInt64Type; - + break; - + case Specification::CategoricalMapping::MappingTypeCase::MAPPINGTYPE_NOT_SET: return Result(ResultType::INVALID_MODEL_PARAMETERS, "Mapping not set."); @@ -65,7 +65,7 @@ namespace CoreML { if (!result.good()) { return result; } - + // Validate the outputs result = validateDescriptionsContainFeatureWithTypes(interface.output(), 1, {requiredOutputType, Specification::FeatureType::kSequenceType}); if (!result.good()) { @@ -83,7 +83,7 @@ namespace CoreML { "of categorical mapping."); } - // Make sure the output is a sequence as well + // Make sure the outupt is a sequence as well if (interface.output(0).type().Type_case() != Specification::FeatureType::kSequenceType) { return Result(ResultType::UNSUPPORTED_FEATURE_TYPE_FOR_MODEL_TYPE, "Output of a sequence categorical mapping must be a sequence"); diff --git a/mlmodel/src/Validation/FeatureVectorizerValidator.cpp b/mlmodel/src/Validation/FeatureVectorizerValidator.cpp index 18cf746eb..26f40d115 100644 --- a/mlmodel/src/Validation/FeatureVectorizerValidator.cpp +++ b/mlmodel/src/Validation/FeatureVectorizerValidator.cpp @@ -9,19 +9,19 @@ #include "../build/format/Model.pb.h" namespace CoreML { - + template <> Result validate(const Specification::Model& format) { const auto& interface = format.description(); - + Result result; - + // Validate its a MLModel type. result = validateModelDescription(interface, format.specificationversion()); if (!result.good()) { return result; } - + // Validate the inputs result = validateDescriptionsContainFeatureWithTypes(interface.input(), 0, @@ -32,23 +32,23 @@ namespace CoreML { if (!result.good()) { return result; } - + // Validate the outputs result = validateDescriptionsContainFeatureWithTypes(interface.output(), 1, {Specification::FeatureType::kMultiArrayType}); if (!result.good()) { return result; } - + // Validate the parameters for (int i = 0; i < format.featurevectorizer().inputlist_size(); i++) { auto& element = format.featurevectorizer().inputlist(i); auto size = element.inputdimensions(); if (size <= 0) { return Result(ResultType::INVALID_MODEL_PARAMETERS, - "Dimension size must be greater than zero."); + "Dimension size must be greater tha zero."); } } - + return result; } } diff --git a/mlmodel/src/Validation/InterfaceValidators.cpp b/mlmodel/src/Validation/InterfaceValidators.cpp index 462404dfb..6e6ec77bb 100644 --- a/mlmodel/src/Validation/InterfaceValidators.cpp +++ b/mlmodel/src/Validation/InterfaceValidators.cpp @@ -11,10 +11,17 @@ #include "ValidatorUtils-inl.hpp" #include "../build/format/Model.pb.h" #include "Globals.hpp" +#include "Model.hpp" +#include "Utils.hpp" namespace CoreML { - + ValidationPolicy::ValidationPolicy(MLModelType modelType) + :allowsEmptyInput(Model::modelTypeAllowsEmptyInput(modelType)), + allowsEmptyOutput(false), + allowsMultipleFunctions(Model::modelTypeAllowsMultipleFunctions(modelType)), + allowsStatefulPrediction(Model::modelTypeAllowsStatefulPrediction(modelType)) + {} Result validateSizeRange(const Specification::SizeRange &range) { if (range.upperbound() > 0 && range.lowerbound() > static_cast(range.upperbound())) { @@ -24,7 +31,7 @@ namespace CoreML { return Result(); } - Result validateFeatureDescription(const Specification::FeatureDescription& desc, int modelVersion, bool isInput) { + Result validateFeatureDescription(const Specification::FeatureDescription& desc, int modelVersion, FeatureIOType featureIOType) { if (desc.name() == "") { return Result(ResultType::INVALID_MODEL_INTERFACE, "Feature description must have a non-empty name."); @@ -36,6 +43,21 @@ namespace CoreML { } const auto& type = desc.type(); + if (type.Type_case() == Specification::FeatureType::kStateType) { + // State features must be declared in the state feature descriptions. (For backward compatibility reason, + // it's also allowed to be in the input feature description.) + if (featureIOType != FeatureIOType::STATE && featureIOType != FeatureIOType::INPUT) { + return Result(ResultType::INVALID_MODEL_INTERFACE, + "State feature '" + desc.name() + "' should only be declared in the state feature description."); + } + } else { + // State feature description shall not have anything but state features. + if (featureIOType == FeatureIOType::STATE) { + return Result(ResultType::INVALID_MODEL_INTERFACE, + "State feature description can declare only state features, but '" + desc.name() + "' is not."); + } + } + switch (type.Type_case()) { case Specification::FeatureType::kDoubleType: case Specification::FeatureType::kInt64Type: @@ -137,7 +159,7 @@ namespace CoreML { } - if (isInput && !hasExplicitDefault && !hasImplictDefault) { + if (featureIOType == FeatureIOType::INPUT && !hasExplicitDefault && !hasImplictDefault) { return Result(ResultType::INVALID_MODEL_INTERFACE, "Description of multiarray feature '" + desc.name() + "' has missing shape constraints."); } @@ -177,7 +199,7 @@ namespace CoreML { case CoreML::Specification::ArrayFeatureType::kDoubleDefaultValue: if (type.multiarraytype().datatype() != Specification::ArrayFeatureType_ArrayDataType_DOUBLE){ return Result(ResultType::INVALID_MODEL_INTERFACE, - "Description of multiarray feature '" + desc.name() + "' has mismatch" + "Description of multiarray feature '" + desc.name() + "' has mistmatch" " between dataType and the type of default optional value."); } break; @@ -185,21 +207,21 @@ namespace CoreML { if (type.multiarraytype().datatype() != Specification::ArrayFeatureType_ArrayDataType_FLOAT32 && type.multiarraytype().datatype() != Specification::ArrayFeatureType_ArrayDataType_FLOAT16){ return Result(ResultType::INVALID_MODEL_INTERFACE, - "Description of multiarray feature '" + desc.name() + "' has mismatch" + "Description of multiarray feature '" + desc.name() + "' has mistmatch" " between dataType and the type of default optional value."); } break; case CoreML::Specification::ArrayFeatureType::kIntDefaultValue: if (type.multiarraytype().datatype() != Specification::ArrayFeatureType_ArrayDataType_INT32){ return Result(ResultType::INVALID_MODEL_INTERFACE, - "Description of multiarray feature '" + desc.name() + "' has mismatch" + "Description of multiarray feature '" + desc.name() + "' has mistmatch" " between dataType and the type of default optional value."); } break; default: break; } - + break; } @@ -337,7 +359,7 @@ namespace CoreML { if (modelVersion < MLMODEL_SPECIFICATION_VERSION_IOS12) { return Result(ResultType::INVALID_MODEL_INTERFACE, - "Sequence types are only valid in specification version >= " + std::to_string(MLMODEL_SPECIFICATION_VERSION_IOS12)+ + "Sequence types are only valid in specification verison >= " + std::to_string(MLMODEL_SPECIFICATION_VERSION_IOS12)+ ". This model has version " + std::to_string(modelVersion)); } @@ -364,6 +386,56 @@ namespace CoreML { } break; } + case Specification::FeatureType::kStateType: + { + if (modelVersion < MLMODEL_SPECIFICATION_VERSION_IOS18) { + return Result(ResultType::INVALID_MODEL_INTERFACE, + "State types are only valid in specification verison >= " + std::to_string(MLMODEL_SPECIFICATION_VERSION_IOS18)+ + ". This model has version " + std::to_string(modelVersion)); + } + + if (type.isoptional()) { + return Result(ResultType::INVALID_MODEL_INTERFACE, + "State feature '" + desc.name() + "' cannot be optional."); + } + + const auto &defaultShape = type.statetype().arraytype().shape(); + bool hasExplicitDefault = (type.statetype().arraytype().shape_size() != 0); + + if (!hasExplicitDefault) { + return Result(ResultType::INVALID_MODEL_INTERFACE, + "Description of State feature '" + desc.name() + "' has missing shape constraints."); + } + + for (int i=0; i < type.statetype().arraytype().shape_size(); i++) { + const auto &value = type.statetype().arraytype().shape(i); + if (value < 0) { + return Result(ResultType::INVALID_MODEL_INTERFACE, + "Description of State feature '" + desc.name() + "' has an invalid shape. " + "Element " + std::to_string(i) + " has non-positive value " + std::to_string(value) + "."); + } + } + + switch (type.statetype().arraytype().datatype()) { + case Specification::ArrayFeatureType_ArrayDataType_FLOAT16: + break; + default: + return Result(ResultType::INVALID_MODEL_INTERFACE, + "Description of State feature '" + desc.name() + "' has an invalid or unspecified dataType. " + "It must be specified as FLOAT16"); + } + + if (type.statetype().arraytype().ShapeFlexibility_case() != Specification::ArrayFeatureType::SHAPEFLEXIBILITY_NOT_SET) { + return Result(ResultType::INVALID_MODEL_INTERFACE, + "State feature '" + desc.name() + "' cannot have flexible shape."); + } + + if (type.statetype().arraytype().defaultOptionalValue_case() != Specification::ArrayFeatureType::DEFAULTOPTIONALVALUE_NOT_SET) { + return Result(ResultType::INVALID_MODEL_INTERFACE, + "State feature '" + desc.name() + "' cannot have default optional value."); + } + break; + } case Specification::FeatureType::TYPE_NOT_SET: // these aren't equal to anything, even themselves return Result(ResultType::INVALID_MODEL_INTERFACE, @@ -374,50 +446,102 @@ namespace CoreML { return Result(); } - Result validateFeatureDescriptions(const Specification::ModelDescription& interface, int modelVersion) { - // a model must have at least one input and one output - if (interface.input_size() < 1) { - return Result(ResultType::INVALID_MODEL_INTERFACE, "Models must have one or more inputs."); + inline Result validateModelLevelFeatureDescriptionsAreEmpty(const Specification::ModelDescription& interface) { + if (interface.input_size() != 0) { + return Result(ResultType::INVALID_MODEL_INTERFACE, "Multi-function model must not use top level input feature description."); } - if (interface.output_size() < 1) { - return Result(ResultType::INVALID_MODEL_INTERFACE, "Models must have one or more outputs."); + + if (interface.output_size() != 0) { + return Result(ResultType::INVALID_MODEL_INTERFACE, "Multi-function model must not use top level output feature description."); } - for (const auto& input : interface.input()) { - Result r = validateFeatureDescription(input, modelVersion, true); - if (!r.good()) { return r; } + if (interface.state_size() != 0) { + return Result(ResultType::INVALID_MODEL_INTERFACE, "Multi-function model must not use top level state feature description."); } - for (const auto& output : interface.output()) { - Result r = validateFeatureDescription(output, modelVersion, false); - if (!r.good()) { return r; } + if (!interface.predictedfeaturename().empty()) { + return Result(ResultType::INVALID_MODEL_INTERFACE, "Multi-function model must not use top level predictedFeatureName field."); + } + + if (!interface.predictedprobabilitiesname().empty()) { + return Result(ResultType::INVALID_MODEL_INTERFACE, "Multi-function model must not use top level predictedProbabilitiesName field."); + } + + if (!interface.traininginput().empty()) { + return Result(ResultType::INVALID_MODEL_INTERFACE, "Multi-function model must not use top level training input feature description."); } - // If we got here, all inputs/outputs seem good independently of each other. return Result(); } - Result validateModelDescription(const Specification::ModelDescription& interface, int modelVersion) { + Result validateMultiFunctionDescription(const Specification::ModelDescription& interface, int modelVersion, const ValidationPolicy& validationPolicy) { + if (!validationPolicy.allowsMultipleFunctions) { + return Result(ResultType::MODEL_TYPE_DOES_NOT_SUPPORT_MULTI_FUNCTION, + "This model type doesn't support multi-function syntax."); + } - Result result = validateFeatureDescriptions(interface, modelVersion); - if (!result.good()) { - return result; + if (modelVersion < MLMODEL_SPECIFICATION_VERSION_IOS18) { + return Result(ResultType::INVALID_COMPATIBILITY_VERSION, + "Multi-function syntax is only valid in specification verison >= " + std::to_string(MLMODEL_SPECIFICATION_VERSION_IOS18)+ + ". This model has version " + std::to_string(modelVersion)); } - return result; + const auto& functions = interface.functions(); + auto functionNames = std::vector(); + for (const auto& function: functions) { + Result r = validateFeatureDescriptions(function, modelVersion, validationPolicy); + if (!r.good()) { + return r; + } + functionNames.push_back(function.name()); + } + + // The default function name must be in function name list. + const auto& defaultFunctionName = interface.defaultfunctionname(); + if (find(functionNames.begin(), functionNames.end(), defaultFunctionName) == functionNames.end()) { + return Result(ResultType::INVALID_DEFAULT_FUNCTION_NAME, + "The default function name '" + defaultFunctionName + "' is not found in the function name list: " + componentsJoinedBy(functionNames, ",")); + } + + return Result(); + } + + Result validateModelDescription(const Specification::ModelDescription& interface, int modelVersion, const ValidationPolicy& validationPolicy) { + Result result; + if (interface.functions_size() > 0 || !interface.defaultfunctionname().empty()) { + // The model uses multi-function configuration. + + // Validate it doesn't use top level feature descriptions + result = validateModelLevelFeatureDescriptionsAreEmpty(interface); + if (!result.good()) { + return result; + } + + result = validateMultiFunctionDescription(interface, modelVersion, validationPolicy); + if (!result.good()) { + return result; + } + } else { + // The model doesn't use multi-function configuration. + Result result = validateFeatureDescriptions(interface, modelVersion, validationPolicy); + if (!result.good()) { + return result; + } + } + return result; } - Result validateRegressorInterface(const Specification::ModelDescription& description, int modelVersion) { - + Result validateRegressorInterface(const Specification::ModelDescription& description, int modelVersion, const ValidationPolicy& validationPolicy) { + if (description.predictedfeaturename() == "") { return Result(ResultType::INVALID_MODEL_INTERFACE, "Specification is missing regressor predictedFeatureName."); } - + // Validate feature descriptions - Result result = validateFeatureDescriptions(description, modelVersion); + Result result = validateFeatureDescriptions(description, modelVersion, validationPolicy); if (!result.good()) { return result; } @@ -431,8 +555,9 @@ namespace CoreML { return result; } - Result validateClassifierFeatureDescriptions(const Specification::ModelDescription& interface, - bool expected_class_is_int64) { + template + inline Result validateClassifierFeatureDescriptions_(const Description& interface, + bool expected_class_is_int64) { const auto& predictedFeatureName = interface.predictedfeaturename(); const auto& probOutputName = interface.predictedprobabilitiesname(); @@ -469,6 +594,16 @@ namespace CoreML { return Result(); } + Result validateClassifierFeatureDescriptions(const Specification::ModelDescription& interface, + bool expected_class_is_int64) { + return validateClassifierFeatureDescriptions_(interface, expected_class_is_int64); + } + + Result validateClassifierFeatureDescriptions(const Specification::FunctionDescription& interface, + bool expected_class_is_int64) { + return validateClassifierFeatureDescriptions_(interface, expected_class_is_int64); + } + /* * Validate optional inputs/outputs. * For most models, optional is not allowed (all inputs/outputs required). @@ -491,7 +626,7 @@ namespace CoreML { } return validateOptionalOutputs(interface); } - + inline Result validateOptionalTree(const Specification::ModelDescription& interface) { return validateOptionalOutputs(interface); } @@ -535,12 +670,16 @@ namespace CoreML { // just need to check that not all inputs are optional bool hasNotOptional = false; for (const auto& input : description.input()) { + if (input.type().Type_case() == Specification::FeatureType::kStateType) { // ignore optionality for State type input (which is always non-optional) + hasNotOptional = true; + continue; + } if (!input.type().isoptional()) { hasNotOptional = true; break; } } - if (!hasNotOptional) { + if (description.input().size() > 0 && !hasNotOptional) { return Result(ResultType::INVALID_MODEL_PARAMETERS, "At least one feature for a neural network must NOT be optional."); } return Result(); @@ -591,7 +730,7 @@ namespace CoreML { return validateOptionalOutputs(format.description()); } - + Result validateCanModelBeUpdatable(const Specification::Model& format) { Result r; switch (format.Type_case()) { @@ -611,4 +750,3 @@ namespace CoreML { } } } - diff --git a/mlmodel/src/Validation/KNearestNeighborsClassifierValidator.cpp b/mlmodel/src/Validation/KNearestNeighborsClassifierValidator.cpp index 23a3c3bf8..fcc6e68b2 100644 --- a/mlmodel/src/Validation/KNearestNeighborsClassifierValidator.cpp +++ b/mlmodel/src/Validation/KNearestNeighborsClassifierValidator.cpp @@ -95,15 +95,15 @@ namespace CoreML { out << "KNearestNeighborsClassifier requires a weighting scheme to be set." << std::endl; return Result(ResultType::INVALID_MODEL_PARAMETERS, out.str()); } - + int intLabelCount = knnClassifier.has_int64classlabels() ? knnClassifier.int64classlabels().vector_size() : 0; int stringLabelCount = knnClassifier.has_stringclasslabels() ? knnClassifier.stringclasslabels().vector_size() : 0; - + int labelCount = MAX(intLabelCount, stringLabelCount); - + auto classLabelCase = knnClassifier.ClassLabels_case(); auto defaultClassLabelIsInt64 = false; - + switch (knnClassifier.DefaultClassLabel_case()) { case Specification::KNearestNeighborsClassifier::kDefaultStringLabel: if (classLabelCase != Specification::KNearestNeighborsClassifier::CLASSLABELS_NOT_SET && @@ -114,7 +114,7 @@ namespace CoreML { } defaultClassLabelIsInt64 = false; break; - + case Specification::KNearestNeighborsClassifier::kDefaultInt64Label: if (classLabelCase != Specification::KNearestNeighborsClassifier::CLASSLABELS_NOT_SET && classLabelCase != Specification::KNearestNeighborsClassifier::kInt64ClassLabels) { @@ -124,7 +124,7 @@ namespace CoreML { } defaultClassLabelIsInt64 = true; break; - + case Specification::KNearestNeighborsClassifier::DEFAULTCLASSLABEL_NOT_SET: if (labelCount == 0) { std::stringstream out; @@ -132,13 +132,13 @@ namespace CoreML { return Result(ResultType::INVALID_MODEL_PARAMETERS, out.str()); } } - + res = validateClassifierInterface(format, format.knearestneighborsclassifier(), true, defaultClassLabelIsInt64); if (!res.good()) { return res; } - + return validateNearestNeighborsIndex(format, labelCount); } diff --git a/mlmodel/src/Validation/NeuralNetwork/NeuralNetworkLayerValidator.cpp b/mlmodel/src/Validation/NeuralNetwork/NeuralNetworkLayerValidator.cpp index 87760d73d..c09864b58 100644 --- a/mlmodel/src/Validation/NeuralNetwork/NeuralNetworkLayerValidator.cpp +++ b/mlmodel/src/Validation/NeuralNetwork/NeuralNetworkLayerValidator.cpp @@ -507,17 +507,17 @@ Result NeuralNetworkSpecValidator::validatePoolingLayer(const Specification::Neu Result NeuralNetworkSpecValidator::validatePooling3dLayer(const Specification::NeuralNetworkLayer& layer) { HANDLE_RESULT_AND_RETURN_ON_ERROR(validateInputCount(layer, 1, 1)); HANDLE_RESULT_AND_RETURN_ON_ERROR(validateOutputCount(layer, 1, 1)) - + if (ndArrayInterpretation) { HANDLE_RESULT_AND_RETURN_ON_ERROR(validateInputOutputRankEquality(layer, "Pooling3d", blobNameToRank)); - // Rank 5 for 2 spatial dimensions, 1 temporal dimension, batch dimension, and 1+ channels. + // Rank 5 for 2 spacial dimensions, 1 temporal dimension, batch dimension, and 1+ channels. HANDLE_RESULT_AND_RETURN_ON_ERROR(validateRankCount(layer, "Pooling3d", 5, -1, blobNameToRank)); } - + // Kernel const auto pooling3d = layer.pooling3d(); - + HANDLE_RESULT_AND_RETURN_ON_ERROR(validatePositive(pooling3d.kerneldepth(), "Kernel Depth")); HANDLE_RESULT_AND_RETURN_ON_ERROR(validatePositive(pooling3d.kernelheight(), "Kernel Height")); HANDLE_RESULT_AND_RETURN_ON_ERROR(validatePositive(pooling3d.kernelwidth(), "Kernel Width")); @@ -526,7 +526,7 @@ Result NeuralNetworkSpecValidator::validatePooling3dLayer(const Specification::N HANDLE_RESULT_AND_RETURN_ON_ERROR(validatePositive(pooling3d.stridedepth(), "Stride Depth")); HANDLE_RESULT_AND_RETURN_ON_ERROR(validatePositive(pooling3d.strideheight(), "Stride Height")); HANDLE_RESULT_AND_RETURN_ON_ERROR(validatePositive(pooling3d.stridewidth(), "Stride Width")); - + // Custom Padding auto paddingType = pooling3d.paddingtype(); HANDLE_RESULT_AND_RETURN_ON_ERROR(validatePooling3dPadding(paddingType, pooling3d.custompaddingfront(), "Front")); @@ -535,7 +535,7 @@ Result NeuralNetworkSpecValidator::validatePooling3dLayer(const Specification::N HANDLE_RESULT_AND_RETURN_ON_ERROR(validatePooling3dPadding(paddingType, pooling3d.custompaddingbottom(), "Bottom")); HANDLE_RESULT_AND_RETURN_ON_ERROR(validatePooling3dPadding(paddingType, pooling3d.custompaddingleft(), "Left")); HANDLE_RESULT_AND_RETURN_ON_ERROR(validatePooling3dPadding(paddingType, pooling3d.custompaddingright(), "Right")); - + return Result(); } @@ -547,7 +547,7 @@ Result NeuralNetworkSpecValidator::validateGlobalPooling3dLayer(const Specificat if (ndArrayInterpretation) { HANDLE_RESULT_AND_RETURN_ON_ERROR(validateInputOutputRankEquality(layer, "Pooling3d", blobNameToRank)); - // Rank 5 for 2 spatial dimensions, 1 temporal dimension, batch dimension, and 1+ channels. + // Rank 5 for 2 spacial dimensions, 1 temporal dimension, batch dimension, and 1+ channels. HANDLE_RESULT_AND_RETURN_ON_ERROR(validateRankCount(layer, "Pooling3d", 5, -1, blobNameToRank)); } @@ -1773,7 +1773,7 @@ Result NeuralNetworkSpecValidator::validateBatchedMatmulLayer(const Specificatio "However, bias is only supported when the layer has 1 input."; return Result(ResultType::INVALID_MODEL_PARAMETERS, err); } - + if (layer.input_size() > 1 && layer.batchedmatmul().int8dynamicquantize()) { std::string err = "BatchedMatMul layer '" + layer.name() + "': cannot use dynamic quantization with 2 inputs."; return Result(ResultType::INVALID_MODEL_PARAMETERS, err); @@ -2341,7 +2341,7 @@ Result NeuralNetworkSpecValidator::validateLoopLayer(const Specification::Neural const auto &bodyNNSpec = params.bodynetwork(); bool isConditionNet = (conditionNNSpec.layers_size() > 0) ? true : false; - // validate some generic requirements for the existence of fields + // validate some generic requirements for the existense of fields if (bodyNNSpec.layers_size() == 0) { std::string err = "Loop Layer '" + std::string(layer.name()) + "' has an empty body network"; return Result(ResultType::INVALID_MODEL_PARAMETERS, err); diff --git a/mlmodel/src/Validation/NeuralNetwork/NeuralNetworkValidator.cpp b/mlmodel/src/Validation/NeuralNetwork/NeuralNetworkValidator.cpp index 612e9e1cf..024cd0646 100644 --- a/mlmodel/src/Validation/NeuralNetwork/NeuralNetworkValidator.cpp +++ b/mlmodel/src/Validation/NeuralNetwork/NeuralNetworkValidator.cpp @@ -36,7 +36,7 @@ NeuralNetworkSpecValidator::NeuralNetworkSpecValidator(const std::map Result NeuralNetworkSpecValidator::validateNeuralNetwork(const T& nn) { - + Result r; - + // Loop over the layers // For each layer, validate the following: - // 1. inputtensor/outputtensor message, rank compatibility with Model input/output ranks + // 1. inputtensor/outputtensor message, rank compatibilty with Model input/output ranks // 2. Check rank consistency across the network for all blobs: ranks are not allowed to change for the same blob // 3. Call layer specific validation function // 4. check that layer's inputs are already present in "availableBlobs" set // 5. check that layer's outputs are NOT already present in "availableBlobs" set // 6. Add the layer's outputs to the "availableBlobs" set for (const auto& layer : nn.layers()) { - + if (!r.good()) { return r; } - + // check for inputtensor message validity if (ndArrayInterpretation) { if (layer.inputtensor_size() != 0) { @@ -422,15 +422,15 @@ Result NeuralNetworkSpecValidator::validateNeuralNetwork(const T& nn) { } } } // inputtensor, outputtensor validity end - + // First, we check the layer for internal correctness // this calls layer wise function r = validateLayer(layer); - + if (!r.good()) { return r; } - + // Check for topological defects: the layer's input must have been produced by a blob we have // already seen. for (const auto& input: layer.input()) { @@ -440,7 +440,7 @@ Result NeuralNetworkSpecValidator::validateNeuralNetwork(const T& nn) { return Result(ResultType::INVALID_MODEL_PARAMETERS, err); } } - + // Check for topological defects: check that the same output isn't being produced in two different places. // unless its the "copy" layer for (const auto& output: layer.output()) { @@ -456,7 +456,7 @@ Result NeuralNetworkSpecValidator::validateNeuralNetwork(const T& nn) { blobs[output].insert(layer.name()); } } // loop over layers - + return Result(); } @@ -465,13 +465,14 @@ Result NeuralNetworkSpecValidator::validateNeuralNetwork(const T& nn) { template Result validateNeuralNetworkTopLevel(const Specification::ModelDescription& interface, const T& nn, std::set& outputBlobNames, - bool isUpdatable) { + bool isUpdatable, + const ValidationPolicy& validationPolicy) { Result r; - + // First calculate the value of the flag "ndArrayInterpretation" // ndArrayInterpretation == False ==> iOS 11/12 (old) execution path can be used, i.e. all tensors are static rank 5. // ndArrayInterpretation == True ==> Tensors can have any rank (including 5). - + bool ndArrayInterpretation = false; bool hasNonIOS12Layer = false; @@ -485,56 +486,59 @@ Result validateNeuralNetworkTopLevel(const Specification::ModelDescription& inte break; } } - + if (nn.arrayinputshapemapping() != Specification::NeuralNetworkMultiArrayShapeMapping::RANK5_ARRAY_MAPPING) { hasNewArrayShapeMapping = true; } - + if (nn.imageinputshapemapping() != Specification::NeuralNetworkImageShapeMapping::RANK5_IMAGE_MAPPING) { hasNewImageShapeMapping = true; } - + for (const auto &layer: nn.layers()) { if (!isIOS12NeuralNetworkLayer(layer)) { hasNonIOS12Layer = true; break; } } - + if (hasNonIOS12Layer || hasNewArrayShapeMapping || hasNewImageShapeMapping) { ndArrayInterpretation = true; } - + if (hasNonIOS12Layer && !hasNewArrayShapeMapping && hasMultiArrayInput) { return Result(ResultType::INVALID_MODEL_INTERFACE, "Neural Network Multi-Array input shape mapping cannot be 'RANK5_ARRAY_MAPPING' if the network contains a layer added in version 4 (iOS 13) or later. Use 'EXACT_ARRAY_MAPPING' instead."); } - + if (!hasNewArrayShapeMapping && hasNewImageShapeMapping && hasMultiArrayInput) { return Result(ResultType::INVALID_MODEL_INTERFACE, "Neural Network Multi-Array input shape mapping cannot be 'RANK5_ARRAY_MAPPING' if the image input Shape mapping is not 'RANK5_IMAGE_MAPPING'"); } - + //==================== End of logic to determine the value of "ndArrayInterpretation" ====================== - - if (interface.input_size() == 0) { + + if (!validationPolicy.allowsEmptyInput && interface.input_size() == 0) { return Result(ResultType::INVALID_MODEL_INTERFACE, "Neural networks require at least one input."); } - + if (interface.output_size() == 0) { return Result(ResultType::INVALID_MODEL_INTERFACE, "Neural networks produce at least one output."); } - + if (nn.layers().size() == 0) { return Result(ResultType::INVALID_MODEL_PARAMETERS, "Neural networks require at least one layer."); } - - if (std::all_of(interface.input().begin(), interface.input().end(), + + if (interface.input_size() > 0 && std::all_of(interface.input().begin(), interface.input().end(), [](const Specification::FeatureDescription& input) { - return input.type().isoptional(); + if (input.type().Type_case() == Specification::FeatureType::kStateType) + return true; // ignores optionality for State type (which is always non-optional). + else + return input.type().isoptional(); })) { return Result(ResultType::INVALID_MODEL_INTERFACE, "Neural networks require at least one non-optional input."); @@ -544,7 +548,7 @@ Result validateNeuralNetworkTopLevel(const Specification::ModelDescription& inte HANDLE_RESULT_AND_RETURN_ON_ERROR(validateInputOutputTypes(interface.input(), ResultReason::MODEL_INPUT_TYPE_INVALID, "inputs")); std::map ioBlobNameToRank; // to collect ranks of input/output blobs from the shapes present in the description - + // populate "ioBlobNameToRank" if (ndArrayInterpretation) { for (const auto& input: interface.input()) { @@ -570,18 +574,18 @@ Result validateNeuralNetworkTopLevel(const Specification::ModelDescription& inte } } } - + // Collect Model input names and do some checking - + // inputBlobs: For each named data blob, the name of the node which produced it (there can be multiple in if-else branches) std::map> inputBlobs; for (const auto& input: interface.input()) { // For input blobs, we'll give them a dummy producing layer name inputBlobs[input.name()] = {"__input"}; if (input.type().Type_case() == Specification::FeatureType::kMultiArrayType) { - + if (!ndArrayInterpretation) { - + // only vector-like (rank 1) or image-like (rank 3) inputs are allowed bool validShapeFound = false; if (input.type().multiarraytype().shape().size() > 0) { @@ -619,7 +623,7 @@ Result validateNeuralNetworkTopLevel(const Specification::ModelDescription& inte if (!validShapeFound) { return Result(ResultType::INVALID_MODEL_INTERFACE, "Input MLMultiArray to neural networks must have dimension 1 (vector) or 3 (image-like arrays)."); } - + } else { // validate input shape when "ndArrayInterpretation" is True if (!(r = validateNdMultiArrayInputType(input.type().multiarraytype())).good()) { @@ -628,18 +632,18 @@ Result validateNeuralNetworkTopLevel(const Specification::ModelDescription& inte } // if else block on spec version to check validity of input shape } } - + // validate the Neural Network message - + // create an object to validate neural network message NeuralNetworkSpecValidator validator(inputBlobs, ioBlobNameToRank, ndArrayInterpretation, 0, ioBlobNameToRank); - + r = validator.validateNeuralNetwork(nn); - + if (!r.good()) { return r; } - + // gather all output blobs of the graph for (auto& blob: validator.blobs){ if (inputBlobs.find(blob.first) == inputBlobs.end()) { @@ -652,7 +656,7 @@ Result validateNeuralNetworkTopLevel(const Specification::ModelDescription& inte } } } - + // Call the shaper: compatibility with iOS 12 if (!ndArrayInterpretation) { // Compute the shapes @@ -664,11 +668,11 @@ Result validateNeuralNetworkTopLevel(const Specification::ModelDescription& inte return Result(ResultType::POTENTIALLY_INVALID_NEURAL_NETWORK_SHAPES, err); } } - + if (!r.good()) { return r; } - + if (isUpdatable) { r = validateUpdatableNeuralNetwork(nn); if (!r.good()) { return r; } @@ -676,28 +680,30 @@ Result validateNeuralNetworkTopLevel(const Specification::ModelDescription& inte r = validateTrainingInputs(interface, nn); if (!r.good()) { return r; } } - + return r; - + } namespace CoreML { - + template <> Result validate(const Specification::Model& format) { + const auto validationPolicy = ValidationPolicy(MLModelType_neuralNetworkClassifier); + // must have classifier parameters - Result r = validateClassifierInterface(format, format.neuralnetworkclassifier()); + Result r = validateClassifierInterface(format, format.neuralnetworkclassifier(), /* allowEmptyLabels */ false, /* defaultClassLabelIsInt64 */ false, validationPolicy); if (!r.good()) { return r; } - + std::set outputBlobNames; - r = validateNeuralNetworkTopLevel(format.description(), format.neuralnetworkclassifier(), outputBlobNames, format.isupdatable()); - + r = validateNeuralNetworkTopLevel(format.description(), format.neuralnetworkclassifier(), outputBlobNames, format.isupdatable(), validationPolicy); + if (!r.good()) { return r; } - + std::string probBlob = format.neuralnetworkclassifier().labelprobabilitylayername(); // Check if the probability blob name was provided in the proto if (!probBlob.empty()) { @@ -707,7 +713,7 @@ namespace CoreML { return Result(ResultType::INVALID_MODEL_PARAMETERS, err); } } - + // Now, we need to check that all the model's output names are either blob names or the extra outputs // for a classifier for (const auto& output : format.description().output()) { @@ -720,59 +726,57 @@ namespace CoreML { } } } - + return r; - + } - + template <> Result validate(const Specification::Model& format) { + auto validationPolicy = ValidationPolicy(MLModelType_neuralNetworkRegressor); + // must have regressor parameters - Result r = validateRegressorInterface(format.description(), format.specificationversion()); + Result r = validateRegressorInterface(format.description(), format.specificationversion(), validationPolicy); if (!r.good()) { return r; } - + std::set outputBlobNames; - return validateNeuralNetworkTopLevel(format.description(), format.neuralnetworkregressor(), outputBlobNames, format.isupdatable()); + return validateNeuralNetworkTopLevel(format.description(), format.neuralnetworkregressor(), outputBlobNames, format.isupdatable(), validationPolicy); } - + template <> Result validate(const Specification::Model& format) { - - - + auto validationPolicy = ValidationPolicy(MLModelType_neuralNetwork); + const auto& interface = format.description(); - + // This isn't true for classifiers and regressors -- need to template specialize it to make these work HANDLE_RESULT_AND_RETURN_ON_ERROR(validateInputOutputTypes(interface.output(), ResultReason::MODEL_OUTPUT_TYPE_INVALID, "outputs")); std::set outputBlobNames; - - Result r = validateNeuralNetworkTopLevel(format.description(), format.neuralnetwork(), outputBlobNames, format.isupdatable()); - + + Result r = validateNeuralNetworkTopLevel(format.description(), format.neuralnetwork(), outputBlobNames, format.isupdatable(), validationPolicy); + if (r.good()) { // Make sure that all of the model interface's outputs are actually produced by some blob for (const auto& output : format.description().output()) { - + const std::string& name = output.name(); - + std::string err; if (outputBlobNames.count(name) == 0) { err = "Interface specifies output '" + name + "' which is not produced by any layer in the neural network."; return Result(ResultType::INVALID_MODEL_INTERFACE, err); } outputBlobNames.erase(name); - + } } - + return r; - + } } - - - diff --git a/mlmodel/src/Validation/NeuralNetwork/NeuralNetworkValidatorUtils.hpp b/mlmodel/src/Validation/NeuralNetwork/NeuralNetworkValidatorUtils.hpp index f3fe2ee79..23fdb683c 100644 --- a/mlmodel/src/Validation/NeuralNetwork/NeuralNetworkValidatorUtils.hpp +++ b/mlmodel/src/Validation/NeuralNetwork/NeuralNetworkValidatorUtils.hpp @@ -33,12 +33,12 @@ inline Result validateTensorMessage(const Specification::Tensor& tensor, const S inline Result checkRank(const Specification::NeuralNetworkLayer& layer, const std::string &layerType, int min, int max, const std::string &blobType, int rank) { - + // blobType: "input" or "output" - + assert( min <= max || max < 0 ); std::string err; - + if (max > 0 && max == min && rank != max) { err = "Layer '" + layer.name() + "' of type '" + layerType + + "' has " + blobType + " rank " + std::to_string(rank) + " but expects rank exactly " + std::to_string(min) + "."; return Result(ResultType::INVALID_MODEL_PARAMETERS, err); @@ -58,17 +58,17 @@ inline Result checkRank(const Specification::NeuralNetworkLayer& layer, inline Result validateRankCount(const Specification::NeuralNetworkLayer& layer, const std::string &layerType, int min, int max, std::map& blobNameToRank) { - + Result r; - + // check that 1st input's rank is within permissible limits if (blobNameToRank.find(layer.input(0)) != blobNameToRank.end()) { int rank = blobNameToRank.at(layer.input(0)); r = checkRank(layer, layerType, min, max, "input", rank); } - + if (!r.good()) {return r;} - + // check that 2nd input's rank is within permissible limits if (blobNameToRank.find(layer.output(0)) != blobNameToRank.end()) { int rank = blobNameToRank.at(layer.output(0)); @@ -79,7 +79,7 @@ inline Result validateRankCount(const Specification::NeuralNetworkLayer& layer, inline Result validateInputOutputRankEquality(const Specification::NeuralNetworkLayer& layer, std::string layerType, std::map& blobNameToRank) { - + if (blobNameToRank.find(layer.input(0)) != blobNameToRank.end() && blobNameToRank.find(layer.output(0)) != blobNameToRank.end()) { if (blobNameToRank.at(layer.input(0)) != blobNameToRank.at(layer.output(0))) { @@ -94,10 +94,10 @@ inline Result validateInputOutputRankEquality(const Specification::NeuralNetwork // Min and max are the minimum and maximum number of possible inputs. // negative values are interpreted as no bound inline Result validateInputCount(const Specification::NeuralNetworkLayer& layer, int min, int max) { - + assert( min <= max || max < 0 ); std::string err; - + if (max > 0 && max == min && layer.input_size() != max) { err = "Layer '" + std::string(layer.name()) + "' of type " + std::to_string(layer.layer_case()) + " has " + std::to_string(layer.input_size()) + " inputs but expects exactly " + std::to_string(min) + "."; return Result(ResultType::INVALID_MODEL_PARAMETERS, err); @@ -122,6 +122,7 @@ inline Result validateInputOutputTypes(const ::google::protobuf::RepeatedPtrFiel switch (feature.type().Type_case()) { case Specification::FeatureType::kImageType: case Specification::FeatureType::kMultiArrayType: + case Specification::FeatureType::kStateType: return true; default: return false; @@ -130,7 +131,7 @@ inline Result validateInputOutputTypes(const ::google::protobuf::RepeatedPtrFiel if (!std::all_of(features.cbegin(), features.cend(), checkFeatures)) { return Result(ResultType::INVALID_MODEL_INTERFACE, reason, - "Neural Networks require " + featureTypesDesc + " to be images or MLMultiArray."); + "Neural Networks require " + featureTypesDesc + " to be images, MLMultiArray, or State."); } return Result(); @@ -158,10 +159,10 @@ inline Result validateNdMultiArrayInputType(const Specification::ArrayFeatureTyp } inline Result validateOutputCount(const Specification::NeuralNetworkLayer& layer, int min, int max) { - + assert( min <= max || max < 0 ); std::string err; - + if (max > 0 && max == min && layer.output_size() != max) { err = "Layer '" + layer.name() + "' of type " + std::to_string(layer.layer_case()) + + " has " + std::to_string(layer.output_size()) + " outputs but expects exactly " + std::to_string(min) + "."; return Result(ResultType::INVALID_MODEL_PARAMETERS, err); @@ -180,7 +181,7 @@ inline Result validateOutputCount(const Specification::NeuralNetworkLayer& layer } inline Result validateRankExists(const Specification::NeuralNetworkLayer& layer) { - + if (layer.inputtensor_size() == 0 || layer.outputtensor_size() == 0) { std::string err = "Layer '" + std::string(layer.name()) + "' must have rank specified for its input and output."; return Result(ResultType::INVALID_MODEL_PARAMETERS, err); @@ -250,7 +251,7 @@ inline bool isWeightParamTypeCompatible(const std::vector& weig inline Result validateLSTMWeightParams(const Specification::LSTMWeightParams& lstmWeightParams, const Specification::LSTMParams lstmParams) { bool has_peephole_vector = lstmParams.haspeepholevectors(); bool has_bias_vector = lstmParams.hasbiasvectors(); - + // Validate all weightParam types match std::vector weightTypes; weightTypes.push_back(valueType(lstmWeightParams.inputgateweightmatrix())); diff --git a/mlmodel/src/Validation/ScalarValidator.cpp b/mlmodel/src/Validation/ScalarValidator.cpp index 0f6bc1f32..21b88d9bc 100644 --- a/mlmodel/src/Validation/ScalarValidator.cpp +++ b/mlmodel/src/Validation/ScalarValidator.cpp @@ -13,11 +13,11 @@ namespace CoreML { template <> Result validate(const Specification::Model& format) { const auto& description = format.description(); - + // Convenience typedefs typedef Specification::FeatureType FT; typedef Specification::Imputer::ReplaceValueCase RVC; - + Result result; // Validate its a MLModel type. @@ -25,32 +25,32 @@ namespace CoreML { if (!result.good()) { return result; } - + // Validate the inputs result = validateDescriptionsContainFeatureWithTypes(description.input(), 1, {FT::kInt64Type, FT::kDoubleType, FT::kMultiArrayType}); - + if (!result.good()) { return result; } - + // Validate the outputs result = validateDescriptionsContainFeatureWithTypes(description.output(), 1, {FT::kInt64Type, FT::kDoubleType, FT::kMultiArrayType}); - + if (!result.good()) { return result; } - - // Make sure the input and output type match. + + // Make sure the input and output type match. // From the above, we know that we have exactly one input and one output type. const auto& input = description.input()[0]; const auto& output = description.output()[0]; - + if(input.type().Type_case() == FT::kInt64Type) { if((output.type().Type_case() != FT::kInt64Type) && (output.type().Type_case() != FT::kDoubleType)) { - + return Result(ResultType::INVALID_MODEL_PARAMETERS, "Input type Int64 must output to Int64 or Double."); } @@ -58,59 +58,59 @@ namespace CoreML { return Result(ResultType::INVALID_MODEL_PARAMETERS, "Type of input feature does not match the output type feature."); } - + // If it's an array, we need to test sizes. if(input.type().Type_case() == FT::kMultiArrayType) { if(input.type().multiarraytype().shape_size() != 1) { return Result(ResultType::INVALID_MODEL_PARAMETERS, "Only 1 dimensional arrays input features are supported by the scaler."); } - + if(output.type().multiarraytype().shape_size() != 1 || (input.type().multiarraytype().shape(0) != output.type().multiarraytype().shape(0))) { return Result(ResultType::INVALID_MODEL_PARAMETERS, "Shape of output array does not match shape of input array."); } - + // Now, make sure that the repeated values make sense. int64_t shift_size = static_cast(format.scaler().shiftvalue_size()); - + if(!(shift_size == 0 || shift_size == 1 || shift_size == input.type().multiarraytype().shape(0))) { - + return Result(ResultType::INVALID_MODEL_PARAMETERS, "For input type array, specified shift values must be empty, a scalar, or a vector of the matching length."); } - + // Now, make sure that the repeated values make sense. int64_t scale_size = static_cast(format.scaler().scalevalue_size()); - + if(!(scale_size == 0 || scale_size == 1 || scale_size == input.type().multiarraytype().shape(0))) { - + return Result(ResultType::INVALID_MODEL_PARAMETERS, "For input type array, specified scale values must be empty, a scalar, or a vector of the matching length."); } } else { // Now, make sure that the repeated values make sense. size_t shift_size = static_cast(format.scaler().shiftvalue_size()); - + if(!(shift_size == 0 || shift_size == 1)) { - + return Result(ResultType::INVALID_MODEL_PARAMETERS, - "For a scalar input type, specified shift value must be empty or a scalar."); + "For a scalar imput type, specified shift value must be empty or a scalar."); } - + // Now, make sure that the repeated values make sense. size_t scale_size = static_cast(format.scaler().scalevalue_size()); - + if(!(scale_size == 0 || scale_size == 1)) { - + return Result(ResultType::INVALID_MODEL_PARAMETERS, "For input type array, specified scale values must be empty or a scalar."); } } - + return result; } } diff --git a/mlmodel/src/Validation/ValidatorUtils-inl.hpp b/mlmodel/src/Validation/ValidatorUtils-inl.hpp index a1b9b5fee..5d601cea2 100644 --- a/mlmodel/src/Validation/ValidatorUtils-inl.hpp +++ b/mlmodel/src/Validation/ValidatorUtils-inl.hpp @@ -17,7 +17,7 @@ #include namespace CoreML { - + enum WeightParamType { FLOAT32, // float32 weights FLOAT16, // float16 weights @@ -26,7 +26,7 @@ namespace CoreML { UNSPECIFIED, // More then one type specified EMPTY // No populated fields }; - + // Returns true if the weight params object has only a single type encoded in it inline bool checkSingleWeightType(const Specification::WeightParams ¶m) { int numFilledIn = 0; @@ -40,7 +40,7 @@ namespace CoreML { numFilledIn++; return (numFilledIn == 1); } - + inline int numberOfWeightType(const Specification::WeightParams ¶m) { int numFilledIn = 0; if (param.floatvalue_size() > 0) @@ -74,7 +74,7 @@ namespace CoreML { } return EMPTY; } - + /* * Utility that make sures the feature types are valid. * @@ -84,7 +84,7 @@ namespace CoreML { */ inline Result validateSchemaTypes(const std::vector& allowedFeatureTypes, const Specification::FeatureDescription& featureDesc) { - + // Check the types auto type = featureDesc.type().Type_case(); for (const auto& t : allowedFeatureTypes) { @@ -93,7 +93,7 @@ namespace CoreML { return Result(); } } - + // Invalid type std::stringstream out; out << "Unsupported type \"" << MLFeatureTypeType_Name(static_cast(featureDesc.type().Type_case())) @@ -139,24 +139,24 @@ namespace CoreML { int maxFeatureCount, const std::vector& allowedFeatureTypes) { Result result; - + // 0 means no maximum fixed feature count. if (maxFeatureCount != 0 && features.size() > maxFeatureCount) { return Result(ResultType::TOO_MANY_FEATURES_FOR_MODEL_TYPE, "Feature descriptions exceeded " + std::to_string(maxFeatureCount)); } - + for (int i = 0; i < features.size(); i++) { result = validateSchemaTypes(allowedFeatureTypes, features[i]); if (!result.good()) { return result; } } - + return result; } /* - * Utility that checks a set of descriptions to validate + * Utility that checks a set of descriptions to validate * there is a feature with a specific name and type in an allowed set */ template @@ -173,8 +173,8 @@ namespace CoreML { return Result(ResultType::INTERFACE_FEATURE_NAME_MISMATCH, "Expected feature '" + name + "' to the model is not present in the model description."); } - - + + static inline int getWeightParamSize(const Specification::WeightParams& weights) { WeightParamType paramValueType = valueType(weights); switch (paramValueType) { @@ -190,7 +190,7 @@ namespace CoreML { return 0; }; - + static inline int getWeightParamSizeInBytes(const Specification::WeightParams& weights) { WeightParamType paramValueType = valueType(weights); switch (paramValueType) { @@ -242,10 +242,33 @@ namespace CoreML { } else { return Result(); } - + }; Result validateSizeRange(const Specification::SizeRange & range); + /** + * Joins each component in the container to a string with a separator. + * + * ``` + * auto v = std::vector({1, 2, 3}) + * componentsJoinedBy(v, ",") // "1, 2, 3" + * ``` + */ + template + std::string componentsJoinedBy(const Container& container, const std::string& sep) { + auto ss = std::ostringstream(); + auto beg = std::begin(container); + auto end = std::end(container); + if (std::distance(beg, end) > 0) { + auto it = beg; + for (; it != end - 1; ++it) { + ss << *it; + ss << sep; + } + ss << *it; + } + return ss.str(); + } } #endif /* ValidatorUtils_h */ diff --git a/mlmodel/src/Validation/Validators.hpp b/mlmodel/src/Validation/Validators.hpp index d78c50337..9d73b08c7 100644 --- a/mlmodel/src/Validation/Validators.hpp +++ b/mlmodel/src/Validation/Validators.hpp @@ -11,6 +11,7 @@ #include "Format.hpp" #include "Result.hpp" +#include "Globals.hpp" #include "../../build/format/Model_enums.h" #include "ValidatorUtils-inl.hpp" @@ -32,49 +33,115 @@ namespace CoreML { */ template Result validate(const Specification::Model& format); - - /* - * Validate feature descriptions in interface have supported names and type info - * - * @param interface Model interface - # @param modelVersion The version of the model for backwards compatibility - * @return Result type of this operation. - */ - Result validateFeatureDescriptions(const Specification::ModelDescription& interface, int modelVersion); + struct ValidationPolicy { + ValidationPolicy() + :allowsEmptyInput(false), + allowsEmptyOutput(false), + allowsMultipleFunctions(false), + allowsStatefulPrediction(false) {} + + /* + * Initializes the policy based on the model type. + */ + ValidationPolicy(MLModelType modelType); + + bool allowsEmptyInput; + bool allowsEmptyOutput; + bool allowsMultipleFunctions; + bool allowsStatefulPrediction; + }; + + enum class FeatureIOType { + INPUT, + OUTPUT, + STATE, + }; /* * Validate an individual feature description * - * @param feature description + * @param feature description # @param modelVersion The version of the model for backwards compatibility * @return Result type of this operation. */ - Result validateFeatureDescription(const Specification::FeatureDescription& desc, int modelVersion, bool isInput = true); + Result validateFeatureDescription(const Specification::FeatureDescription& desc, int modelVersion, FeatureIOType featureIOType = FeatureIOType::INPUT); /* * Validate model interface describes a valid transform * * @param interface Model interface # @param modelVersion The version of the model for backwards compatibility + * @param validationPolicy The validation policy. * @return Result type of this operation. */ - Result validateModelDescription(const Specification::ModelDescription& interface, int modelVersion); + Result validateModelDescription(const Specification::ModelDescription& interface, int modelVersion, const ValidationPolicy& validationPolicy = ValidationPolicy()); /* * Validate model interface describes a valid regressor * * @param interface Model interface + * @param validationPolicy The validation policy. * @return Result type of this operation. */ - Result validateRegressorInterface(const Specification::ModelDescription& interface, int modelVersion); + Result validateRegressorInterface(const Specification::ModelDescription& interface, int modelVersion, const ValidationPolicy& validationPolicy = ValidationPolicy()); /* - * Validate classifier output feature descriptions. + * Validate classifier output feature descriptions of the model. */ Result validateClassifierFeatureDescriptions(const Specification::ModelDescription& interface, bool expected_class_is_int64); + /* + * Validate classifier output feature descriptions of a function model. + */ + Result validateClassifierFeatureDescriptions(const Specification::FunctionDescription& interface, + bool expected_class_is_int64); + + /* + * Validate feature descriptions in interface or function descriptions have supported names and type info. + * + * @param interface Model or Function interface. + * @param modelVersion The version of the model for backwards compatibility. + * @param validationPolicy The validation policy. + * @return Result type of this operation. + */ + template + inline Result validateFeatureDescriptions(const Description& interface, int modelVersion, const ValidationPolicy& validationPolicy) { + if (interface.input_size() < 1) { + if (!validationPolicy.allowsEmptyInput) { + return Result(ResultType::MODEL_TYPE_DOES_NOT_SUPPORT_EMPTY_INPUT, "Models must have one or more inputs."); + } + + if (modelVersion < MLMODEL_SPECIFICATION_VERSION_IOS18) { + return Result(ResultType::INVALID_COMPATIBILITY_VERSION, + "Empty input is only valid in specification verison >= " + std::to_string(MLMODEL_SPECIFICATION_VERSION_IOS18)+ + ". This model has version " + std::to_string(modelVersion)); + } + } + + if (!validationPolicy.allowsEmptyOutput && interface.output_size() < 1) { + return Result(ResultType::INVALID_MODEL_INTERFACE, "Models must have one or more outputs."); + } + + for (const auto& input : interface.input()) { + Result r = validateFeatureDescription(input, modelVersion, FeatureIOType::INPUT); + if (!r.good()) { return r; } + } + + for (const auto& output : interface.output()) { + Result r = validateFeatureDescription(output, modelVersion, FeatureIOType::OUTPUT); + if (!r.good()) { return r; } + } + + for (const auto& state : interface.state()) { + Result r = validateFeatureDescription(state, modelVersion, FeatureIOType::STATE); + if (!r.good()) { return r; } + } + + // If we got here, all inputs/outputs seem good independently of each other. + return Result(); + } /* * Validate model interface describes a valid classifier @@ -86,10 +153,11 @@ namespace CoreML { Result validateClassifierInterface(const T& model, const U& modelParameters, const bool allowEmptyLabels = false, - const bool defaultClassLabelIsInt64 = false) { - + const bool defaultClassLabelIsInt64 = false, + const ValidationPolicy& validationPolicy = ValidationPolicy()) { + bool expected_class_is_int64; - + // validate class labels switch (modelParameters.ClassLabels_case()) { case U::kInt64ClassLabels: @@ -97,31 +165,31 @@ namespace CoreML { return Result(ResultType::INVALID_MODEL_PARAMETERS, "Classifier declared to have Int64 class labels must provide labels."); } - + if(modelParameters.stringclasslabels().vector_size() != 0) { return Result(ResultType::INVALID_MODEL_PARAMETERS, "Classifier declared with Int64 class labels must provide exclusively Int64 class labels."); } - + expected_class_is_int64 = true; - + break; - + case U::kStringClassLabels: if (!allowEmptyLabels && modelParameters.stringclasslabels().vector_size() == 0) { return Result(ResultType::INVALID_MODEL_PARAMETERS, "Classifier declared to have String class labels must provide labels."); } - + if(modelParameters.int64classlabels().vector_size() != 0) { return Result(ResultType::INVALID_MODEL_PARAMETERS, "Classifier declared with String class labels must provide exclusively String class labels."); } - + expected_class_is_int64 = false; - + break; - + case U::CLASSLABELS_NOT_SET: if (!allowEmptyLabels) { return Result(ResultType::INVALID_MODEL_PARAMETERS, "Classifier models must provide class labels."); @@ -130,23 +198,23 @@ namespace CoreML { break; } const Specification::ModelDescription& interface = model.description(); - + // Validate feature descriptions - Result result = validateFeatureDescriptions(interface, model.specificationversion()); + Result result = validateFeatureDescriptions(interface, model.specificationversion(), validationPolicy); if (!result.good()) { return result; } return validateClassifierFeatureDescriptions(interface, expected_class_is_int64); } - + /* * Validate optional inputs/outputs. * For most models, optional is not allowed (all inputs/outputs required). * Some models have different behavior. */ Result validateOptional(const Specification::Model& format); - + /* * Validate if the model type can be set to updatable. */ diff --git a/mlmodel/src/transforms/LinearModel.hpp b/mlmodel/src/transforms/LinearModel.hpp index d27a15a26..0bac14378 100644 --- a/mlmodel/src/transforms/LinearModel.hpp +++ b/mlmodel/src/transforms/LinearModel.hpp @@ -20,7 +20,7 @@ class LinearModel : public Model { const std::string& description); LinearModel(const Model &model); - + /** * Set the weights. * @@ -36,7 +36,7 @@ class LinearModel : public Model { * @return Result type with errors. */ Result setOffsets(std::vector offsets); - + /** * Get offsets/intercepts. * @@ -50,7 +50,7 @@ class LinearModel : public Model { * @return weights. */ std::vector< std::vector> getWeights(); - + }; } diff --git a/mlmodel/src/transforms/LogisticModel.hpp b/mlmodel/src/transforms/LogisticModel.hpp index 6d9aff26d..3170ca4bb 100644 --- a/mlmodel/src/transforms/LogisticModel.hpp +++ b/mlmodel/src/transforms/LogisticModel.hpp @@ -16,7 +16,7 @@ namespace CoreML { * Reader/Writer interface for a GLM. * * A construction class that, in the end, outputs a properly constructed - * specification that is guaranteed to load in an LinearModelSpec class. + * specification that is gauranteed to load in an LinearModelSpec class. * */ class EXPORT LogisticModel : public Model { diff --git a/mlmodel/src/transforms/TreeEnsemble.hpp b/mlmodel/src/transforms/TreeEnsemble.hpp index 02065202e..21c0e474d 100644 --- a/mlmodel/src/transforms/TreeEnsemble.hpp +++ b/mlmodel/src/transforms/TreeEnsemble.hpp @@ -94,12 +94,12 @@ namespace CoreML { void setRelativeNodeHitRate(size_t treeId, size_t nodeId, double v); /** - * Missing values can either track the path of the "true" child or the "false" child. By - * default, they always travel down the false path. Set this to alter this behavior for a + * Missing values can either track the path of the "true" child or the "false" child. By + * default, they always travel down the false path. Set this to alter this behavior for a * given node. */ void setMissingValueBehavior(size_t treeId, size_t nodeId, bool missing_value_tracks_true_child); - + /** * If this is called, a node is created that is marked as a leaf evaluation node, * which means that when this node is triggered, a value is added to the @@ -182,15 +182,15 @@ namespace CoreML { */ TreeEnsembleRegressor(const std::string& predictedValueOutput, const std::string& description = ""); - + void setPostEvaluationTransform(PostEvaluationTransform transform); - + virtual ~TreeEnsembleRegressor(); - + private: Specification::TreeEnsembleRegressor* tree_regressor_parameters; }; - + typedef TreeEnsembleBase::BranchMode BranchMode; typedef TreeEnsembleBase::PostEvaluationTransform PostEvaluationTransform; } diff --git a/mlmodel/tests/InterfaceTests.cpp b/mlmodel/tests/InterfaceTests.cpp index 71e0f3e75..b863a3ea8 100644 --- a/mlmodel/tests/InterfaceTests.cpp +++ b/mlmodel/tests/InterfaceTests.cpp @@ -21,6 +21,27 @@ static Specification::Model& addOptionalField(Specification::Model& model, const return model; } +static void setupMultiArrayFeature(Specification::FeatureDescription *feature, const std::string& name) { + feature->set_name(name); + feature->mutable_type()->mutable_multiarraytype()->set_datatype(::CoreML::Specification::ArrayFeatureType_ArrayDataType_FLOAT32); + feature->mutable_type()->mutable_multiarraytype()->add_shape(1); +} + +static void setupStateFeature(Specification::FeatureDescription *feature, const std::string& name) { + feature->set_name(name); + feature->mutable_type()->mutable_statetype()->mutable_arraytype()->set_datatype(::CoreML::Specification::ArrayFeatureType_ArrayDataType_FLOAT16); + feature->mutable_type()->mutable_statetype()->mutable_arraytype()->add_shape(1); +} + +static ValidationPolicy validationPolicyForStateTests() { + auto validationPolicy = ValidationPolicy(); + validationPolicy.allowsEmptyInput = true; + validationPolicy.allowsEmptyOutput = true; + validationPolicy.allowsMultipleFunctions = true; + validationPolicy.allowsStatefulPrediction = true; + return validationPolicy; +} + int testOptionalInputs() { // Test that all fields are required on a random model (normalizer) Specification::Model m1; @@ -62,92 +83,92 @@ int testFeatureDescriptions() { Specification::Model m; auto *feature = m.mutable_description()->add_input(); - ML_ASSERT_BAD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, true)); + ML_ASSERT_BAD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, FeatureIOType::INPUT)); // Just with a name as invalid feature->set_name("test_input"); - ML_ASSERT_BAD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, true)); + ML_ASSERT_BAD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, FeatureIOType::INPUT)); // Empty type, still invalid feature->mutable_type(); - ML_ASSERT_BAD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, true)); + ML_ASSERT_BAD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, FeatureIOType::INPUT)); // Int64 typ, now its valid feature->mutable_type()->mutable_int64type(); - ML_ASSERT_GOOD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, true)); + ML_ASSERT_GOOD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, FeatureIOType::INPUT)); // String type, valid feature->mutable_type()->mutable_stringtype(); - ML_ASSERT_GOOD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, true)); + ML_ASSERT_GOOD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, FeatureIOType::INPUT)); // Double type, valid feature->mutable_type()->mutable_doubletype(); - ML_ASSERT_GOOD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, true)); + ML_ASSERT_GOOD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, FeatureIOType::INPUT)); // Multiarray type, with no params, invalid feature->mutable_type()->mutable_multiarraytype(); - ML_ASSERT_BAD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, true)); + ML_ASSERT_BAD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, FeatureIOType::INPUT)); // Multiarray type, double with no shape, invalid as input, valid as output feature->mutable_type()->mutable_multiarraytype()->set_datatype(::CoreML::Specification::ArrayFeatureType_ArrayDataType_DOUBLE); - ML_ASSERT_BAD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, true)); + ML_ASSERT_BAD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, FeatureIOType::INPUT)); feature->mutable_type()->mutable_multiarraytype()->set_datatype(::CoreML::Specification::ArrayFeatureType_ArrayDataType_DOUBLE); - ML_ASSERT_GOOD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, false)); + ML_ASSERT_GOOD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, FeatureIOType::OUTPUT)); feature->mutable_type()->mutable_multiarraytype()->set_datatype(::CoreML::Specification::ArrayFeatureType_ArrayDataType_FLOAT32); - ML_ASSERT_BAD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, true)); + ML_ASSERT_BAD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, FeatureIOType::INPUT)); feature->mutable_type()->mutable_multiarraytype()->set_datatype(::CoreML::Specification::ArrayFeatureType_ArrayDataType_FLOAT32); - ML_ASSERT_GOOD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, false)); + ML_ASSERT_GOOD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, FeatureIOType::OUTPUT)); feature->mutable_type()->mutable_multiarraytype()->set_datatype(::CoreML::Specification::ArrayFeatureType_ArrayDataType_INT32); - ML_ASSERT_BAD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, true)); + ML_ASSERT_BAD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, FeatureIOType::INPUT)); feature->mutable_type()->mutable_multiarraytype()->set_datatype(::CoreML::Specification::ArrayFeatureType_ArrayDataType_INT32); - ML_ASSERT_GOOD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, false)); + ML_ASSERT_GOOD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, FeatureIOType::OUTPUT)); // Zero length shape is invalid for inputs, but valid for outputs feature->mutable_type()->mutable_multiarraytype()->mutable_shape(); - ML_ASSERT_BAD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, true)); - ML_ASSERT_GOOD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, false)); + ML_ASSERT_BAD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, FeatureIOType::INPUT)); + ML_ASSERT_GOOD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, FeatureIOType::OUTPUT)); // Non-zero length shape, valid feature->mutable_type()->mutable_multiarraytype()->mutable_shape()->Add(128); - ML_ASSERT_GOOD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, true)); - ML_ASSERT_GOOD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, false)); + ML_ASSERT_GOOD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, FeatureIOType::INPUT)); + ML_ASSERT_GOOD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS11_2, FeatureIOType::OUTPUT)); // Dictionary, with no params, invalid feature->mutable_type()->mutable_dictionarytype(); - ML_ASSERT_BAD(validateFeatureDescription(*feature,true)); + ML_ASSERT_BAD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_NEWEST, FeatureIOType::INPUT)); // With key type, valid feature->mutable_type()->mutable_dictionarytype()->mutable_stringkeytype(); - ML_ASSERT_GOOD(validateFeatureDescription(*feature,true)); + ML_ASSERT_GOOD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_NEWEST, FeatureIOType::INPUT)); feature->mutable_type()->mutable_dictionarytype()->mutable_int64keytype(); - ML_ASSERT_GOOD(validateFeatureDescription(*feature,true)); + ML_ASSERT_GOOD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_NEWEST, FeatureIOType::INPUT)); // Image, with no params, invalid feature->mutable_type()->mutable_imagetype(); - ML_ASSERT_BAD(validateFeatureDescription(*feature,true)); + ML_ASSERT_BAD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_NEWEST, FeatureIOType::INPUT)); // With just width, invalid feature->mutable_type()->mutable_imagetype()->set_width(10); - ML_ASSERT_BAD(validateFeatureDescription(*feature,true)); + ML_ASSERT_BAD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_NEWEST, FeatureIOType::INPUT)); // With both width and height, still invalid because no colorspace feature->mutable_type()->mutable_imagetype()->set_height(20); - ML_ASSERT_BAD(validateFeatureDescription(*feature,true)); + ML_ASSERT_BAD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_NEWEST, FeatureIOType::INPUT)); // Now with colorspace, valid feature->mutable_type()->mutable_imagetype()->set_colorspace(::CoreML::Specification::ImageFeatureType_ColorSpace_BGR); - ML_ASSERT_GOOD(validateFeatureDescription(*feature,true)); + ML_ASSERT_GOOD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_NEWEST, FeatureIOType::INPUT)); feature->mutable_type()->mutable_imagetype()->set_colorspace(::CoreML::Specification::ImageFeatureType_ColorSpace_RGB); - ML_ASSERT_GOOD(validateFeatureDescription(*feature,true)); + ML_ASSERT_GOOD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_NEWEST, FeatureIOType::INPUT)); feature->mutable_type()->mutable_imagetype()->set_colorspace(::CoreML::Specification::ImageFeatureType_ColorSpace_GRAYSCALE); - ML_ASSERT_GOOD(validateFeatureDescription(*feature,true)); + ML_ASSERT_GOOD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_NEWEST, FeatureIOType::INPUT)); feature->mutable_type()->mutable_imagetype()->set_colorspace(::CoreML::Specification::ImageFeatureType_ColorSpace_GRAYSCALE_FLOAT16); - ML_ASSERT_GOOD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS16, true)); + ML_ASSERT_GOOD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_IOS16, FeatureIOType::INPUT)); feature->mutable_type()->mutable_imagetype()->set_colorspace(::CoreML::Specification::ImageFeatureType_ColorSpace_INVALID_COLOR_SPACE); - ML_ASSERT_BAD(validateFeatureDescription(*feature,true)); + ML_ASSERT_BAD(validateFeatureDescription(*feature, MLMODEL_SPECIFICATION_VERSION_NEWEST, FeatureIOType::INPUT)); ////////////////////////////////// // Test more recent shape constraints @@ -161,47 +182,47 @@ int testFeatureDescriptions() { // Make fixed size 6 x 5 feature2->mutable_type()->mutable_imagetype()->set_width(6); feature2->mutable_type()->mutable_imagetype()->set_height(5); - ML_ASSERT_GOOD(validateFeatureDescription(*feature2, MLMODEL_SPECIFICATION_VERSION, true)); + ML_ASSERT_GOOD(validateFeatureDescription(*feature2, MLMODEL_SPECIFICATION_VERSION, FeatureIOType::INPUT)); /// Enumerated // Add flexibility of a single enumerated size 6 x 5 auto *shape = feature2->mutable_type()->mutable_imagetype()->mutable_enumeratedsizes()->add_sizes(); shape->set_width(6); shape->set_height(5); - ML_ASSERT_GOOD(validateFeatureDescription(*feature2, MLMODEL_SPECIFICATION_VERSION, true)); + ML_ASSERT_GOOD(validateFeatureDescription(*feature2, MLMODEL_SPECIFICATION_VERSION, FeatureIOType::INPUT)); // Reset that to a single 10 x 5 which would make the 6 x 5 invalid! shape->set_width(10); shape->set_height(5); - ML_ASSERT_BAD(validateFeatureDescription(*feature2, MLMODEL_SPECIFICATION_VERSION, true)); + ML_ASSERT_BAD(validateFeatureDescription(*feature2, MLMODEL_SPECIFICATION_VERSION, FeatureIOType::INPUT)); // Add 6 x 5 to the list so its now [10x5, 6 x 5] which should make it valid again shape = feature2->mutable_type()->mutable_imagetype()->mutable_enumeratedsizes()->add_sizes(); shape->set_width(6); shape->set_height(5); - ML_ASSERT_GOOD(validateFeatureDescription(*feature2, MLMODEL_SPECIFICATION_VERSION, true)); + ML_ASSERT_GOOD(validateFeatureDescription(*feature2, MLMODEL_SPECIFICATION_VERSION, FeatureIOType::INPUT)); /// Range - // Now make it a range that includes 6 x 5 + // Now make it a range that inclues 6 x 5 auto* size_range = feature2->mutable_type()->mutable_imagetype()->mutable_imagesizerange(); size_range->mutable_widthrange()->set_lowerbound(1); size_range->mutable_widthrange()->set_upperbound(-1); // unbounded size_range->mutable_heightrange()->set_lowerbound(2); size_range->mutable_heightrange()->set_upperbound(5); - ML_ASSERT_GOOD(validateFeatureDescription(*feature2, MLMODEL_SPECIFICATION_VERSION, true)); + ML_ASSERT_GOOD(validateFeatureDescription(*feature2, MLMODEL_SPECIFICATION_VERSION, FeatureIOType::INPUT)); // Now make the range not include 6 x 5 size_range->mutable_widthrange()->set_lowerbound(7); - ML_ASSERT_BAD(validateFeatureDescription(*feature2, MLMODEL_SPECIFICATION_VERSION, true)); + ML_ASSERT_BAD(validateFeatureDescription(*feature2, MLMODEL_SPECIFICATION_VERSION, FeatureIOType::INPUT)); // Fix it to include it again size_range->mutable_widthrange()->set_lowerbound(2); - ML_ASSERT_GOOD(validateFeatureDescription(*feature2, MLMODEL_SPECIFICATION_VERSION, true)); + ML_ASSERT_GOOD(validateFeatureDescription(*feature2, MLMODEL_SPECIFICATION_VERSION, FeatureIOType::INPUT)); // Fail due to upper bound can't be larger than lower size_range->mutable_widthrange()->set_upperbound(1); - ML_ASSERT_BAD(validateFeatureDescription(*feature2, MLMODEL_SPECIFICATION_VERSION, true)); + ML_ASSERT_BAD(validateFeatureDescription(*feature2, MLMODEL_SPECIFICATION_VERSION, FeatureIOType::INPUT)); ///////////// @@ -212,7 +233,7 @@ int testFeatureDescriptions() { // 10 x 5 default size array_type->add_shape(10); array_type->add_shape(5); - ML_ASSERT_GOOD(validateFeatureDescription(*feature2, MLMODEL_SPECIFICATION_VERSION, true)); + ML_ASSERT_GOOD(validateFeatureDescription(*feature2, MLMODEL_SPECIFICATION_VERSION, FeatureIOType::INPUT)); // Range // Now specify ranges (>1 x [5...20]) @@ -223,11 +244,11 @@ int testFeatureDescriptions() { auto rangeForDim1 = array_type->mutable_shaperange()->add_sizeranges(); rangeForDim1->set_lowerbound(5); rangeForDim1->set_upperbound(20); - ML_ASSERT_GOOD(validateFeatureDescription(*feature2, MLMODEL_SPECIFICATION_VERSION, true)); + ML_ASSERT_GOOD(validateFeatureDescription(*feature2, MLMODEL_SPECIFICATION_VERSION, FeatureIOType::INPUT)); // Change to (>1 x [6..20]) which is not consistent with 10 x 5 rangeForDim1->set_lowerbound(6); - ML_ASSERT_BAD(validateFeatureDescription(*feature2, MLMODEL_SPECIFICATION_VERSION, true)); + ML_ASSERT_BAD(validateFeatureDescription(*feature2, MLMODEL_SPECIFICATION_VERSION, FeatureIOType::INPUT)); // Enumerated auto eshape1 = array_type->mutable_enumeratedshapes()->add_shapes(); @@ -235,14 +256,260 @@ int testFeatureDescriptions() { eshape1->add_shape(2); // Now allow [ 6x2 ] which is inconsistent with default 10 x 5 - ML_ASSERT_BAD(validateFeatureDescription(*feature2, MLMODEL_SPECIFICATION_VERSION, true)); + ML_ASSERT_BAD(validateFeatureDescription(*feature2, MLMODEL_SPECIFICATION_VERSION, FeatureIOType::INPUT)); // Add another one to make the set [6x2 , 10x5] which is consistent auto eshape2 = array_type->mutable_enumeratedshapes()->add_shapes(); eshape2->add_shape(10); eshape2->add_shape(5); - ML_ASSERT_GOOD(validateFeatureDescription(*feature2, MLMODEL_SPECIFICATION_VERSION, true)); + ML_ASSERT_GOOD(validateFeatureDescription(*feature2, MLMODEL_SPECIFICATION_VERSION, FeatureIOType::INPUT)); + + return 0; +} + +int testMultiFunctionSpecificationVersion() { + Specification::Model m; + + auto validationPolicy = ValidationPolicy(); + validationPolicy.allowsEmptyInput = true; + validationPolicy.allowsEmptyOutput = false; + validationPolicy.allowsMultipleFunctions = true; + + auto *description = m.mutable_description(); + description->set_defaultfunctionname("foo"); + auto *function = description->add_functions(); + function->set_name("foo"); + + setupMultiArrayFeature(function->add_input(), "x"); + setupMultiArrayFeature(function->add_output(), "y"); + + // Check model specification version requirements. + ML_ASSERT_BAD_WITH_TYPE(validateModelDescription(*description, MLMODEL_SPECIFICATION_VERSION_IOS17, validationPolicy), ResultType::INVALID_COMPATIBILITY_VERSION); + ML_ASSERT_GOOD(validateModelDescription(*description, MLMODEL_SPECIFICATION_VERSION_IOS18, validationPolicy)); + + return 0; +} + +int testMultiFunctionDefaultFunctionName() { + Specification::Model m; + + auto validationPolicy = ValidationPolicy(); + validationPolicy.allowsEmptyInput = true; + validationPolicy.allowsEmptyOutput = true; + validationPolicy.allowsMultipleFunctions = true; + + auto *description = m.mutable_description(); + description->set_defaultfunctionname("foo"); + ML_ASSERT_BAD_WITH_TYPE(validateModelDescription(*description, MLMODEL_SPECIFICATION_VERSION_IOS18, validationPolicy), ResultType::INVALID_DEFAULT_FUNCTION_NAME); + + auto *function = description->add_functions(); + function->set_name("bar"); + ML_ASSERT_BAD_WITH_TYPE(validateModelDescription(*description, MLMODEL_SPECIFICATION_VERSION_IOS18, validationPolicy), ResultType::INVALID_DEFAULT_FUNCTION_NAME); + + function->set_name("foo"); + ML_ASSERT_GOOD(validateModelDescription(*description, MLMODEL_SPECIFICATION_VERSION_IOS18, validationPolicy)); + + return 0; +} + +int testMultiFunctionTopLevelFeatureDescriptionsMustBeEmpty() { + Specification::Model m; + + auto validationPolicy = ValidationPolicy(); + validationPolicy.allowsEmptyInput = true; + validationPolicy.allowsEmptyOutput = true; + validationPolicy.allowsMultipleFunctions = true; + + auto *description = m.mutable_description(); + description->set_defaultfunctionname("foo"); + auto *function = description->add_functions(); + function->set_name("foo"); + + ML_ASSERT_GOOD(validateModelDescription(*description, MLMODEL_SPECIFICATION_VERSION_IOS18, validationPolicy)); + + description->add_input()->set_name("x"); + ML_ASSERT_BAD_WITH_TYPE(validateModelDescription(*description, MLMODEL_SPECIFICATION_VERSION_IOS18, validationPolicy), ResultType::INVALID_MODEL_INTERFACE); + description->clear_input(); + + description->add_output()->set_name("y"); + ML_ASSERT_BAD_WITH_TYPE(validateModelDescription(*description, MLMODEL_SPECIFICATION_VERSION_IOS18, validationPolicy), ResultType::INVALID_MODEL_INTERFACE); + description->clear_output(); + + description->add_traininginput()->set_name("z"); + ML_ASSERT_BAD_WITH_TYPE(validateModelDescription(*description, MLMODEL_SPECIFICATION_VERSION_IOS18, validationPolicy), ResultType::INVALID_MODEL_INTERFACE); + description->clear_traininginput(); + + description->add_state()->set_name("s"); + ML_ASSERT_BAD_WITH_TYPE(validateModelDescription(*description, MLMODEL_SPECIFICATION_VERSION_IOS18, validationPolicy), ResultType::INVALID_MODEL_INTERFACE); + description->clear_state(); + + description->set_predictedfeaturename("f"); + ML_ASSERT_BAD_WITH_TYPE(validateModelDescription(*description, MLMODEL_SPECIFICATION_VERSION_IOS18, validationPolicy), ResultType::INVALID_MODEL_INTERFACE); + description->clear_predictedfeaturename(); + + description->set_predictedprobabilitiesname("f"); + ML_ASSERT_BAD_WITH_TYPE(validateModelDescription(*description, MLMODEL_SPECIFICATION_VERSION_IOS18, validationPolicy), ResultType::INVALID_MODEL_INTERFACE); + description->clear_predictedprobabilitiesname(); + + return 0; +} + +int testMultiFunctionEmptyInput() { + Specification::Model m; + + auto validationPolicy = ValidationPolicy(); + validationPolicy.allowsEmptyInput = true; + validationPolicy.allowsEmptyOutput = true; + validationPolicy.allowsMultipleFunctions = true; + + auto *description = m.mutable_description(); + description->set_defaultfunctionname("foo"); + auto *function = description->add_functions(); + function->set_name("foo"); + + ML_ASSERT_GOOD(validateModelDescription(*description, MLMODEL_SPECIFICATION_VERSION_IOS18, validationPolicy)); + + validationPolicy.allowsEmptyInput = false; + ML_ASSERT_BAD_WITH_TYPE(validateModelDescription(*description, MLMODEL_SPECIFICATION_VERSION_IOS18, validationPolicy), ResultType::MODEL_TYPE_DOES_NOT_SUPPORT_EMPTY_INPUT); + + return 0; +} + +int testMultiFunctionAllowed() { + Specification::Model m; + + auto validationPolicy = ValidationPolicy(); + validationPolicy.allowsEmptyInput = true; + validationPolicy.allowsEmptyOutput = true; + validationPolicy.allowsMultipleFunctions = true; + + auto *description = m.mutable_description(); + description->set_defaultfunctionname("foo"); + auto *function = description->add_functions(); + function->set_name("foo"); + + ML_ASSERT_GOOD(validateModelDescription(*description, MLMODEL_SPECIFICATION_VERSION_IOS18, validationPolicy)); + + validationPolicy.allowsMultipleFunctions = false; + ML_ASSERT_BAD_WITH_TYPE(validateModelDescription(*description, MLMODEL_SPECIFICATION_VERSION_IOS18, validationPolicy), ResultType::MODEL_TYPE_DOES_NOT_SUPPORT_MULTI_FUNCTION); + + return 0; +} + +int testStateSpecificationVersion() { + Specification::Model m; + auto *description = m.mutable_description(); + setupStateFeature(description->add_state(), "x"); + + // Check model specification version requirements. + auto validationPolicy = validationPolicyForStateTests(); + ML_ASSERT_BAD_WITH_TYPE(validateModelDescription(*description, MLMODEL_SPECIFICATION_VERSION_IOS17, validationPolicy), ResultType::INVALID_COMPATIBILITY_VERSION); + ML_ASSERT_GOOD(validateModelDescription(*description, MLMODEL_SPECIFICATION_VERSION_IOS18, validationPolicy)); + + return 0; +} + +/// For backward compabitility reason, it is OK to declare state features in the input descriptions. +int testStateFeatureDescriptionInInputs() { + Specification::Model m; + auto *description = m.mutable_description(); + + setupStateFeature(description->add_input(), "x"); + + // Check model specification version requirements. + auto validationPolicy = validationPolicyForStateTests(); + ML_ASSERT_GOOD(validateModelDescription(*description, MLMODEL_SPECIFICATION_VERSION_NEWEST, validationPolicy)); + + return 0; +} + +int testStateFeatureIsNotFP16_shouldFail() { + Specification::Model m; + auto *description = m.mutable_description(); + auto *state = description->add_state(); + + setupStateFeature(state, "x"); + state->mutable_type()->mutable_statetype()->mutable_arraytype()->set_datatype(::CoreML::Specification::ArrayFeatureType_ArrayDataType_FLOAT32); + + // Check model specification version requirements. + auto validationPolicy = validationPolicyForStateTests(); + ML_ASSERT_BAD_WITH_TYPE(validateModelDescription(*description, MLMODEL_SPECIFICATION_VERSION_NEWEST, validationPolicy), ResultType::INVALID_MODEL_INTERFACE); + + return 0; +} + +int testStateFeatureIsOptional_shouldFail() { + Specification::Model m; + auto *description = m.mutable_description(); + auto *state = description->add_state(); + + setupStateFeature(state, "x"); + state->mutable_type()->set_isoptional(true); + + // Check model specification version requirements. + auto validationPolicy = validationPolicyForStateTests(); + ML_ASSERT_BAD_WITH_TYPE(validateModelDescription(*description, MLMODEL_SPECIFICATION_VERSION_NEWEST, validationPolicy), ResultType::INVALID_MODEL_INTERFACE); + + return 0; +} + +int testStateFeatureHasNoDefaultShape_shouldFail() { + Specification::Model m; + auto *description = m.mutable_description(); + auto *state = description->add_state(); + + setupStateFeature(state, "x"); + state->mutable_type()->mutable_statetype()->mutable_arraytype()->clear_shape(); + + // Check model specification version requirements. + auto validationPolicy = validationPolicyForStateTests(); + ML_ASSERT_BAD_WITH_TYPE(validateModelDescription(*description, MLMODEL_SPECIFICATION_VERSION_NEWEST, validationPolicy), ResultType::INVALID_MODEL_INTERFACE); + + return 0; +} + +int testStateFeatureHasNoArrayType_shouldFail() { + Specification::Model m; + auto *description = m.mutable_description(); + auto *state = description->add_state(); + + setupStateFeature(state, "x"); + state->mutable_type()->mutable_statetype()->clear_arraytype(); + + // Check model specification version requirements. + auto validationPolicy = validationPolicyForStateTests(); + ML_ASSERT_BAD_WITH_TYPE(validateModelDescription(*description, MLMODEL_SPECIFICATION_VERSION_NEWEST, validationPolicy), ResultType::INVALID_MODEL_INTERFACE); + + return 0; +} + +int testStateFeature_ArrayUsesRangeFlexibleShape_shouldFail() { + Specification::Model m; + auto *description = m.mutable_description(); + auto *state = description->add_state(); + + setupStateFeature(state, "x"); + state->mutable_type()->mutable_statetype()->mutable_arraytype()->mutable_shaperange(); + + // Check model specification version requirements. + auto validationPolicy = validationPolicyForStateTests(); + ML_ASSERT_BAD_WITH_TYPE(validateModelDescription(*description, MLMODEL_SPECIFICATION_VERSION_NEWEST, validationPolicy), ResultType::INVALID_MODEL_INTERFACE); + + return 0; +} + +int testStateFeature_ArrayUsesEnumeratedFlexibleShape_shouldFail() { + Specification::Model m; + auto *description = m.mutable_description(); + auto *state = description->add_state(); + + setupStateFeature(state, "x"); + state->mutable_type()->mutable_statetype()->mutable_arraytype()->mutable_enumeratedshapes(); + + // Check model specification version requirements. + auto validationPolicy = validationPolicyForStateTests(); + ML_ASSERT_BAD_WITH_TYPE(validateModelDescription(*description, MLMODEL_SPECIFICATION_VERSION_NEWEST, validationPolicy), ResultType::INVALID_MODEL_INTERFACE); return 0; } diff --git a/mlmodel/tests/MILBlob/AutoDeleteTempFile.cpp b/mlmodel/tests/MILBlob/AutoDeleteTempFile.cpp index e9a398438..d5347104f 100644 --- a/mlmodel/tests/MILBlob/AutoDeleteTempFile.cpp +++ b/mlmodel/tests/MILBlob/AutoDeleteTempFile.cpp @@ -9,10 +9,10 @@ #include #include #include -#include -#include #include #include +#include +#include using namespace MILBlob::TestUtil; diff --git a/mlmodel/tests/MILBlob/AutoDeleteTempFile.hpp b/mlmodel/tests/MILBlob/AutoDeleteTempFile.hpp index 4ba6bc557..7dd974936 100644 --- a/mlmodel/tests/MILBlob/AutoDeleteTempFile.hpp +++ b/mlmodel/tests/MILBlob/AutoDeleteTempFile.hpp @@ -12,8 +12,7 @@ namespace TestUtil { class AutoDeleteTempFile { public: - enum FileType - { + enum FileType { FILE = 0, DIR = 1 }; diff --git a/mlmodel/tests/MILBlob/BlobUtils.cpp b/mlmodel/tests/MILBlob/BlobUtils.cpp index 0f746deff..e6d17c310 100644 --- a/mlmodel/tests/MILBlob/BlobUtils.cpp +++ b/mlmodel/tests/MILBlob/BlobUtils.cpp @@ -162,7 +162,199 @@ AutoDeleteTempFile MakeStorageTempFileWith3Records() // 896 BYTES // DATA 7 0xe8d0, 0x007e, 0x0000, 0x0000, + // Padding + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + // 960 BYTES + // METADATA 8 + 0xBEEF, 0xDEAD, 0x0008, 0x0000, // sentinel=0xDEADBEEF, mil_dtype=int4 + 0x0004, 0x0000, 0x0000, 0x0000, // sizeInBytes=4 bytes + 0x0400, 0x0000, 0x0000, 0x0000, // offset + 0x0000, 0x0000, 0x0000, 0x0000, // padding_size_in_bits + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_1 + 0x0001, 0x0000, 0x0000, 0x0000, // reserved_2 + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_3 + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_4 + // 1024 BYTES + // DATA 8 + 0xe8d0, 0x007e, 0x0000, 0x0000, + // Padding + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + // 1088 BYTES + // METADATA 9 + 0xBEEF, 0xDEAD, 0x000B, 0x0000, // sentinel=0xDEADBEEF, mil_dtype=uint4 + 0x0003, 0x0000, 0x0000, 0x0000, // sizeInBytes=3 bytes + 0x0480, 0x0000, 0x0000, 0x0000, // offset + 0x0004, 0x0000, 0x0000, 0x0000, // padding_size_in_bits = 4 + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_1 + 0x0001, 0x0000, 0x0000, 0x0000, // reserved_2 + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_3 + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_4 + // 1152 BYTES + // DATA 9 + 0xe8d1, 0x107c, 0x0000, 0x0000, + + // Padding + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + // 1216 BYTES + // METADATA 10 + 0xBEEF, 0xDEAD, 0x0009, 0x0000, // sentinel=0xDEADBEEF, mil_dtype=uint1 + 0x0003, 0x0000, 0x0000, 0x0000, // sizeInBytes=3 bytes + 0x0500, 0x0000, 0x0000, 0x0000, // offset + 0x0007, 0x0000, 0x0000, 0x0000, // padding_size_in_bits = 7 + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_1 + 0x0001, 0x0000, 0x0000, 0x0000, // reserved_2 + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_3 + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_4 + // 1280 BYTES + // DATA 10 + 0xEC24, 0xFFF7, 0x0000, 0x0000, + + // Padding + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + // 1344 BYTES + // METADATA 11 + 0xBEEF, 0xDEAD, 0x000A, 0x0000, // sentinel=0xDEADBEEF, mil_dtype=uint2 + 0x0002, 0x0000, 0x0000, 0x0000, // sizeInBytes=2 bytes + 0x0580, 0x0000, 0x0000, 0x0000, // offset + 0x0002, 0x0000, 0x0000, 0x0000, // padding_size_in_bits = 2 + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_1 + 0x0001, 0x0000, 0x0000, 0x0000, // reserved_2 + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_3 + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_4 + // 1408 BYTES + // DATA 11 + 0xEC24, 0x0000, 0x0000, 0x0000, + + // Padding + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + // 1472 BYTES + // METADATA 12 + 0xBEEF, 0xDEAD, 0x000C, 0x0000, // sentinel=0xDEADBEEF, mil_dtype=uint3 + 0x0004, 0x0000, 0x0000, 0x0000, // sizeInBytes=4 bytes + 0x0600, 0x0000, 0x0000, 0x0000, // offset + 0x0005, 0x0000, 0x0000, 0x0000, // padding_size_in_bits = 5 (9 elements) + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_1 + 0x0001, 0x0000, 0x0000, 0x0000, // reserved_2 + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_3 + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_4 + // 1536 BYTES + // DATA 12 + 0xEC24, 0x1D45, 0x0000, 0x0000, + // Padding + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + // 1600 BYTES + // METADATA 13 + 0xBEEF, 0xDEAD, 0x000E, 0x0000, // sentinel=0xDEADBEEF, mil_dtype=int32 + 0x0004, 0x0000, 0x0000, 0x0000, // sizeInBytes=4 bytes + 0x0680, 0x0000, 0x0000, 0x0000, // offset + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_0 + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_1 + 0x0001, 0x0000, 0x0000, 0x0000, // reserved_2 + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_3 + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_4 + // 1664 BYTES + // DATA 13 + 0x0C24, 0x0000, 0x0000, 0x0000, + // Padding + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + // 1728 BYTES + // METADATA 14 + 0xBEEF, 0xDEAD, 0x000F, 0x0000, // sentinel=0xDEADBEEF, mil_dtype=uint32 + 0x0008, 0x0000, 0x0000, 0x0000, // sizeInBytes=8 bytes + 0x0700, 0x0000, 0x0000, 0x0000, // offset + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_0 + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_1 + 0x0001, 0x0000, 0x0000, 0x0000, // reserved_2 + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_3 + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_4 + // 1792 BYTES + // DATA 14 + 0x0C24, 0x0000, 0xBEEF, 0x0000, + // Padding + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + // 1856 BYTES + // METADATA 15 + 0xBEEF, 0xDEAD, 0x0011, 0x0000, // sentinel=0xDEADBEEF, mil_dtype=Fp8E5M2 + 0x0004, 0x0000, 0x0000, 0x0000, // sizeInBytes=4 bytes + 0x0780, 0x0000, 0x0000, 0x0000, // offset + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_0 + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_1 + 0x0001, 0x0000, 0x0000, 0x0000, // reserved_2 + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_3 + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_4 + // 1920 BYTES + // DATA 15 + 0xEFBE, 0xCA00, 0x0000, 0x0000, + // Padding + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, + // 1984 BYTES + // METADATA 16 + 0xBEEF, 0xDEAD, 0x0010, 0x0000, // sentinel=0xDEADBEEF, mil_dtype=Fp8E4M3 + 0x0004, 0x0000, 0x0000, 0x0000, // sizeInBytes=4 bytes + 0x0780, 0x0000, 0x0000, 0x0000, // offset + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_0 + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_1 + 0x0001, 0x0000, 0x0000, 0x0000, // reserved_2 + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_3 + 0x0000, 0x0000, 0x0000, 0x0000, // reserved_4 + // 2048 BYTES + // DATA 15 + 0xEFBE, 0xCA00, 0x0000, 0x0000, }; // clang-format on diff --git a/mlmodel/tests/MILBlob/FileWriterTests.cpp b/mlmodel/tests/MILBlob/FileWriterTests.cpp index 4db7f5cd7..59c9bd820 100644 --- a/mlmodel/tests/MILBlob/FileWriterTests.cpp +++ b/mlmodel/tests/MILBlob/FileWriterTests.cpp @@ -198,4 +198,3 @@ int testFileWriterTestsReadData() return 0; } - diff --git a/mlmodel/tests/MILBlob/MMapFileReaderTests.cpp b/mlmodel/tests/MILBlob/MMapFileReaderTests.cpp index 7e3c4646a..01fa8c485 100644 --- a/mlmodel/tests/MILBlob/MMapFileReaderTests.cpp +++ b/mlmodel/tests/MILBlob/MMapFileReaderTests.cpp @@ -109,4 +109,3 @@ int testMMapFileReaderTestsReadStruct() return 0; } - diff --git a/mlmodel/tests/MILBlob/SpanCastTests.cpp b/mlmodel/tests/MILBlob/SpanCastTests.cpp index 2de82c8d0..c326ba6ec 100644 --- a/mlmodel/tests/MILBlob/SpanCastTests.cpp +++ b/mlmodel/tests/MILBlob/SpanCastTests.cpp @@ -3,6 +3,7 @@ // Use of this source code is governed by a BSD-3-clause license that can be // found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause +#include "MILBlob/SubByteTypes.hpp" #include "MILBlob/Util/SpanCast.hpp" #include "framework/TestUtils.hpp" #include "MLModelTests.hpp" @@ -36,3 +37,36 @@ int testSpanCastTestsBasics() return 0; } +int testSpanCastTestsToInt4() +{ + std::vector v; + v.emplace_back(0x20); + v.emplace_back(0x64); + Span span = MakeSpan(v); + + // Valid casts. + CastToBitSpan(span, 3); + CastToBitSpan(span, 4); + + // Invalid due to size being too short or too long. + ML_ASSERT_THROWS(CastToBitSpan(span, 2), std::invalid_argument); + ML_ASSERT_THROWS(CastToBitSpan(span, 5), std::invalid_argument); + + return 0; +} + +int testSpanCastTestsFromInt4() +{ + std::vector v; + v.emplace_back(0x20); + v.emplace_back(0x64); + Span span = MakeSpan(v); + { + Span int4Span = CastToBitSpan(span, 3); + ML_ASSERT_EQ(int4Span.Size(), 3); + Span uint8Span = CastFromBitSpan(int4Span); + ML_ASSERT_EQ(uint8Span.Size(), 2); + } + + return 0; +} diff --git a/mlmodel/tests/MILBlob/SpanTests.cpp b/mlmodel/tests/MILBlob/SpanTests.cpp index ca3aa2d93..946024df0 100644 --- a/mlmodel/tests/MILBlob/SpanTests.cpp +++ b/mlmodel/tests/MILBlob/SpanTests.cpp @@ -3,7 +3,10 @@ // Use of this source code is governed by a BSD-3-clause license that can be // found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause +#include "MILBlob/SubByteTypes.hpp" #include "MILBlob/Util/Span.hpp" +#include "MILBlob/Util/SpanCast.hpp" +#include "MILBlob/Util/SubByteConversionUtils.hpp" #include "framework/TestUtils.hpp" #include "MLModelTests.hpp" @@ -656,3 +659,216 @@ int testSpanTestsIterationMultipleDims() return 0; } +int testSpanTestsInt4() +{ + std::vector v; + v.emplace_back(0x20); + v.emplace_back(0x64); + Span span = MakeSpan(v); + + Span int4Span = CastToBitSpan(span, 4); + + for (size_t i = 0; i < int4Span.Size(); ++i) { + uint8_t byteBlock = ((uint8_t*)int4Span.Data())[i / 2]; + Int4 value; + if (i % 2) { + value.SetInt((byteBlock & 0xf0) >> 4); + } else { + value.SetInt(byteBlock & 0x0f); + } + ML_ASSERT_EQ(value, Int4(static_cast(i * 2))); + } + + return 0; +} + +int testSpanTestsSubbyteIntValueAt() +{ + std::vector v; + v.emplace_back(0x20); + v.emplace_back(0x9B); + v.emplace_back(0x08); + Span span = MakeSpan(v); + + Span int4Span = CastToBitSpan(span, 5); + + ML_ASSERT_EQ(int4Span.ValueAt(0), Int4(0)); + ML_ASSERT_EQ(int4Span.ValueAt(1), Int4(2)); + ML_ASSERT_EQ(int4Span.ValueAt(2), Int4(-5)); + ML_ASSERT_EQ(int4Span.ValueAt(3), Int4(-7)); + ML_ASSERT_EQ(int4Span.ValueAt(4), Int4(-8)); + ML_ASSERT_THROWS(int4Span.ValueAt(7), std::out_of_range); + + return 0; +} + +int testSpanTestsSubByteUIntValueAt() +{ + { + const std::vector v = {0x20, 0x9B, 0x0F}; + Span span = MakeSpan(v); + + Span uint4Span = CastToBitSpan(span, 5); + + ML_ASSERT_EQ(uint4Span.ValueAt(0), UInt4(0)); + ML_ASSERT_EQ(uint4Span.ValueAt(1), UInt4(2)); + ML_ASSERT_EQ(uint4Span.ValueAt(2), UInt4(11)); + ML_ASSERT_EQ(uint4Span.ValueAt(3), UInt4(9)); + ML_ASSERT_EQ(uint4Span.ValueAt(4), UInt4(15)); + ML_ASSERT_THROWS(uint4Span.ValueAt(5), std::out_of_range); + } + + { + std::vector v; + v.emplace_back(0x20); + v.emplace_back(0x9B); + v.emplace_back(0x0F); + Span span = MakeSpan(v); + + Span uint2Span = CastToBitSpan(span, 9); + + ML_ASSERT_EQ(uint2Span.ValueAt(0), UInt2(0)); + ML_ASSERT_EQ(uint2Span.ValueAt(2), UInt2(2)); + ML_ASSERT_EQ(uint2Span.ValueAt(4), UInt2(3)); + ML_ASSERT_EQ(uint2Span.ValueAt(6), UInt2(1)); + ML_ASSERT_EQ(uint2Span.ValueAt(8), UInt2(3)); + ML_ASSERT_THROWS(uint2Span.ValueAt(9), std::out_of_range); + } + + { + std::vector v = {0x9B}; + Span span = MakeSpan(v); + + Span uint1Span = CastToBitSpan(span, 8); + + ML_ASSERT_EQ(uint1Span.ValueAt(0), UInt1(1)); + ML_ASSERT_EQ(uint1Span.ValueAt(2), UInt1(0)); + ML_ASSERT_EQ(uint1Span.ValueAt(4), UInt1(1)); + ML_ASSERT_EQ(uint1Span.ValueAt(6), UInt1(0)); + ML_ASSERT_THROWS(uint1Span.ValueAt(9), std::out_of_range); + } + + { + // Bytes to Uint3 decimals decomposed from LSB to MSB + // 0xFA - 1111 1010 => 11,111,010 => 7,2 + // 0x5B - 0101 1011 => 0,101,101,1 => 5,5,7 + // 0x0E - 0000 1110 => 000,011,10 => 0,3,4 + std::vector v = {0xFA, 0x5B, 0x0E}; + Span span = MakeSpan(v); + + Span uint3Span = CastToBitSpan(span, 8); + + ML_ASSERT_EQ(uint3Span.ValueAt(0), UInt3(2)); + ML_ASSERT_EQ(uint3Span.ValueAt(2), UInt3(7)); + ML_ASSERT_EQ(uint3Span.ValueAt(4), UInt3(5)); + ML_ASSERT_EQ(uint3Span.ValueAt(5), UInt3(4)); + ML_ASSERT_THROWS(uint3Span.ValueAt(9), std::out_of_range); + } + + return 0; +} + +int testSpanTestsConstInt4() +{ + std::vector v; + v.emplace_back(0x20); + v.emplace_back(0x64); + Span span = MakeSpan(v); + + Span int4Span = CastToBitSpan(span, 4); + + for (size_t i = 0; i < int4Span.Size(); ++i) { + uint8_t byteBlock = ((uint8_t*)int4Span.Data())[i / 2]; + Int4 value; + if (i % 2) { + value.SetInt((byteBlock & 0xf0) >> 4); + } else { + value.SetInt(byteBlock & 0x0f); + } + ML_ASSERT_EQ(value, Int4(static_cast(i * 2))); + } + + { + // Test that supplied span is too small to hold the requested number of elements + Span int4SpanTooSmall; + ML_ASSERT_THROWS(int4SpanTooSmall = CastToBitSpan(span, 5), std::invalid_argument); + } + { + // Test that supplied span is too large to hold the requested number of elements + Span int4SpanTooSmall; + ML_ASSERT_THROWS(int4SpanTooSmall = CastToBitSpan(span, 2), std::invalid_argument); + } + + return 0; +} + +int testSpanTestsConstUInt4() +{ + std::vector v; + v.emplace_back(0x20); + v.emplace_back(0x64); + Span span = MakeSpan(v); + + Span uint4Span = CastToBitSpan(span, 4); + + for (size_t i = 0; i < uint4Span.Size(); ++i) { + uint8_t byteBlock = ((uint8_t*)uint4Span.Data())[i / 2]; + UInt4 value; + if (i % 2) { + value.SetInt((byteBlock & 0xf0) >> 4); + } else { + value.SetInt(byteBlock & 0x0f); + } + ML_ASSERT_EQ(value, UInt4(static_cast(i * 2))); + } + + { + // Test that supplied span is too small to hold the requested number of elements + Span uint4SpanTooSmall; + ML_ASSERT_THROWS(uint4SpanTooSmall = CastToBitSpan(span, 5), std::invalid_argument); + } + { + // Test that supplied span is too large to hold the requested number of elements + Span uint4SpanTooLarge; + ML_ASSERT_THROWS(uint4SpanTooLarge = CastToBitSpan(span, 2), std::invalid_argument); + } + + return 0; +} + +template +class SpanHasAtMethod { + struct S { + char a; + char b; + }; + template + static char Tester(decltype(&U::At)); + template + static S Tester(...); + +public: + enum { + value = sizeof(Tester(0)) == sizeof(char) + }; +}; + +int testSpanTestsSpanOverload() +{ + ML_ASSERT_EQ(true, SpanHasAtMethod>::value); + ML_ASSERT_EQ(false, SpanHasAtMethod>::value); + + ML_ASSERT_EQ(true, MILBlob::IsSubByteSized::value); + ML_ASSERT_EQ(true, MILBlob::IsSubByteSized::value); + ML_ASSERT_EQ(true, MILBlob::IsSubByteSized::value); + ML_ASSERT_EQ(true, MILBlob::IsSubByteSized::value); + ML_ASSERT_EQ(true, MILBlob::IsSubByteSized::value); + ML_ASSERT_EQ(true, MILBlob::IsSubByteSized::value); + + ML_ASSERT_EQ(true, MILBlob::SubByteIsByteAligned()); + ML_ASSERT_EQ(true, MILBlob::SubByteIsByteAligned()); + ML_ASSERT_EQ(false, MILBlob::SubByteIsByteAligned()); + ML_ASSERT_EQ(false, MILBlob::SubByteIsByteAligned()); + + return 0; +} diff --git a/mlmodel/tests/MILBlob/StorageReaderTests.cpp b/mlmodel/tests/MILBlob/StorageReaderTests.cpp index b7dbc2813..e6b873ac7 100644 --- a/mlmodel/tests/MILBlob/StorageReaderTests.cpp +++ b/mlmodel/tests/MILBlob/StorageReaderTests.cpp @@ -96,7 +96,6 @@ int testStorageReaderTestsTruncatedMetadata() StorageReader reader(tempfile.GetFilename()); ML_ASSERT_THROWS_WITH_MESSAGE(reader.GetDataView(64), std::range_error, "index out of bounds"); - return 0; } @@ -136,7 +135,6 @@ int testStorageReaderTestsTruncatedData() StorageReader reader(tempfile.GetFilename()); ML_ASSERT_THROWS_WITH_MESSAGE(reader.GetDataView(64), std::range_error, "index out of bounds"); - return 0; } @@ -276,6 +274,49 @@ int testStorageReaderTestsThreeRecords() ML_ASSERT_SPAN_EQ(data, Util::MakeSpan(expectedValues)); } + { // read Int4 weights from metadata t + auto int4Data = reader.GetDataView(960); + ML_ASSERT_EQ(int4Data.Size(), 8); + auto uint8Data = MILBlob::Util::CastFromBitSpan(int4Data); + auto data = MILBlob::Util::SpanCast(uint8Data); + + std::vector expectedValues = {uint16_t(0xe8d0), uint16_t(0x007e)}; + ML_ASSERT_SPAN_EQ(data, Util::MakeSpan(expectedValues)); + } + { + auto uint3Data = reader.GetDataView(1472); + ML_ASSERT_EQ(uint3Data.Size(), 9); + auto uint8Data = MILBlob::Util::CastFromBitSpan(uint3Data); + auto data = MILBlob::Util::SpanCast(uint8Data); + + std::vector expectedValues = {uint16_t(0xEC24), uint16_t(0x1D45)}; + ML_ASSERT_SPAN_EQ(data, Util::MakeSpan(expectedValues)); + } + { + auto int32Data = reader.GetDataView(1600); + ML_ASSERT_EQ(int32Data.Size(), 1); + std::vector expectedValues = {int32_t(0x0C24)}; + ML_ASSERT_SPAN_EQ(int32Data, Util::MakeSpan(expectedValues)); + } + { + auto uint32Data = reader.GetDataView(1728); + ML_ASSERT_EQ(uint32Data.Size(), 2); + std::vector expectedValues = {uint32_t(0x0C24), uint32_t(0xBEEF)}; + ML_ASSERT_SPAN_EQ(uint32Data, Util::MakeSpan(expectedValues)); + } + { + auto fp8E5M2Data = reader.GetDataView(1856); + ML_ASSERT_EQ(fp8E5M2Data.Size(), 4); + std::vector expectedValues = {Fp8E5M2(0xBE), Fp8E5M2(0xEF), Fp8E5M2(0x00), Fp8E5M2(0xCA)}; + ML_ASSERT_SPAN_EQ(fp8E5M2Data, Util::MakeSpan(expectedValues)); + } + { + auto fp8E4M3FNData = reader.GetDataView(1984); + ML_ASSERT_EQ(fp8E4M3FNData.Size(), 4); + std::vector expectedValues = {Fp8E4M3FN(0xBE), Fp8E4M3FN(0xEF), Fp8E4M3FN(0x00), Fp8E4M3FN(0xCA)}; + ML_ASSERT_SPAN_EQ(fp8E4M3FNData, Util::MakeSpan(expectedValues)); + } + return 0; } @@ -309,7 +350,7 @@ int testStorageReaderTestsRawData() ML_ASSERT_SPAN_EQ(data, Util::SpanCast(Util::MakeSpan(expectedValues))); } - { // read Bf16 weights from metadata 4 + { // read Bf16 weights from metadata 5 auto data = reader.GetRawDataView(576); ML_ASSERT_EQ(data.Size(), size_t(8)); @@ -317,7 +358,7 @@ int testStorageReaderTestsRawData() ML_ASSERT_SPAN_EQ(data, Util::SpanCast(Util::MakeSpan(expectedValues))); } - { // read int16_t weights from metadata 5 + { // read int16_t weights from metadata 6 auto data = reader.GetRawDataView(704); ML_ASSERT_EQ(data.Size(), size_t(4)); @@ -325,7 +366,7 @@ int testStorageReaderTestsRawData() ML_ASSERT_SPAN_EQ(data, Util::SpanCast(Util::MakeSpan(expectedValue))); } - { // read uint16_t weights from metadata 5 + { // read uint16_t weights from metadata 7 auto data = reader.GetRawDataView(832); ML_ASSERT_EQ(data.Size(), size_t(4)); @@ -333,6 +374,81 @@ int testStorageReaderTestsRawData() ML_ASSERT_SPAN_EQ(data, Util::SpanCast(Util::MakeSpan(expectedValue))); } + { // read Int4 weights from metadata 8 + auto data = reader.GetRawDataView(960); + ML_ASSERT_EQ(data.Size(), size_t(4)); + + // remember int4's are actually stored here, so this vector type is immaterial + // (cant materialize an int4 span from an int4 vector) + std::vector expectedValue = {uint16_t(0xe8d0), uint16_t(0x7e)}; + ML_ASSERT_SPAN_EQ(data, Util::SpanCast(Util::MakeSpan(expectedValue))); + } + + { // read UInt4 weights from metadata 9 + auto data = reader.GetRawDataView(1088); + ML_ASSERT_EQ(data.Size(), size_t(3)); + + std::vector expectedValue = {uint8_t(0xd1), uint8_t(0xe8), uint8_t(0x7c)}; + ML_ASSERT_SPAN_EQ(data, Util::SpanCast(Util::MakeSpan(expectedValue))); + } + { // read UInt1 weights from metadata 10 + auto data = reader.GetRawDataView(1216); + ML_ASSERT_EQ(data.Size(), size_t(3)); + + std::vector expectedValue = {uint8_t(0x24), uint8_t(0xec), uint8_t(0xf7)}; + ML_ASSERT_SPAN_EQ(data, Util::SpanCast(Util::MakeSpan(expectedValue))); + } + { // read UInt2 weights from metadata 11 + auto data = reader.GetRawDataView(1344); + ML_ASSERT_EQ(data.Size(), size_t(2)); + + std::vector expectedValue = {uint8_t(0x24), uint8_t(0xec)}; + ML_ASSERT_SPAN_EQ(data, Util::SpanCast(Util::MakeSpan(expectedValue))); + } + + { // read UInt3 weights from metadata 12 + auto data = reader.GetRawDataView(1472); + ML_ASSERT_EQ(data.Size(), size_t(4)); + + std::vector expectedValue = {uint8_t(0x24), uint8_t(0xEC), uint8_t(0x45), uint8_t(0x1D)}; + ML_ASSERT_SPAN_EQ(data, Util::SpanCast(Util::MakeSpan(expectedValue))); + } + { // read Int32 weights from metadata 13 + auto data = reader.GetRawDataView(1600); + ML_ASSERT_EQ(data.Size(), size_t(4)); + + std::vector expectedValue = {uint8_t(0x24), uint8_t(0x0C), uint8_t(0x00), uint8_t(0x00)}; + ML_ASSERT_SPAN_EQ(data, Util::SpanCast(Util::MakeSpan(expectedValue))); + } + { // read Uint32 weights from metadata 14 + auto data = reader.GetRawDataView(1728); + ML_ASSERT_EQ(data.Size(), size_t(8)); + + std::vector expectedValue = {uint8_t(0x24), + uint8_t(0x0C), + uint8_t(0x00), + uint8_t(0x00), + uint8_t(0xEF), + uint8_t(0xBE), + uint8_t(0x00), + uint8_t(0x00)}; + ML_ASSERT_SPAN_EQ(data, Util::SpanCast(Util::MakeSpan(expectedValue))); + } + { // read Fp8E5M2 weights from metadata 15 + auto data = reader.GetRawDataView(1856); + ML_ASSERT_EQ(data.Size(), size_t(4)); + + std::vector expectedValue = {uint8_t(0xBE), uint8_t(0xEF), uint8_t(0x00), uint8_t(0xCA)}; + ML_ASSERT_SPAN_EQ(data, Util::SpanCast(Util::MakeSpan(expectedValue))); + } + { // read Fp8E4M3FN weights from metadata 16 + auto data = reader.GetRawDataView(1984); + ML_ASSERT_EQ(data.Size(), size_t(4)); + + std::vector expectedValue = {uint8_t(0xBE), uint8_t(0xEF), uint8_t(0x00), uint8_t(0xCA)}; + ML_ASSERT_SPAN_EQ(data, Util::SpanCast(Util::MakeSpan(expectedValue))); + } + return 0; } @@ -366,6 +482,48 @@ int testStorageReaderTestsDataOffset() ML_ASSERT_EQ(BlobDataType::BFloat16, reader.GetDataType(576)); } + { // read data offset for UInt1 weights from metadata 9 + ML_ASSERT_EQ(uint64_t(1280), reader.GetDataOffset(1216)); + ML_ASSERT_EQ(uint64_t(3), reader.GetDataSize(1216)); + ML_ASSERT_EQ(BlobDataType::UInt1, reader.GetDataType(1216)); + ML_ASSERT_EQ(7, reader.GetDataPaddingInBits(1216)); + } + + { // read data offset for UInt3 weights from metadata 12 + ML_ASSERT_EQ(uint64_t(0x600), reader.GetDataOffset(1472)); + ML_ASSERT_EQ(uint64_t(4), reader.GetDataSize(1472)); + ML_ASSERT_EQ(BlobDataType::UInt3, reader.GetDataType(1472)); + ML_ASSERT_EQ(5, reader.GetDataPaddingInBits(1472)); + } + + { // read data offset for Int32 weights from metadata 13 + ML_ASSERT_EQ(uint64_t(1664), reader.GetDataOffset(1600)); + ML_ASSERT_EQ(uint64_t(4), reader.GetDataSize(1600)); + ML_ASSERT_EQ(BlobDataType::Int32, reader.GetDataType(1600)); + ML_ASSERT_EQ(0, reader.GetDataPaddingInBits(1600)); + } + + { // read data offset for UInt32 weights from metadata 14 + ML_ASSERT_EQ(uint64_t(1792), reader.GetDataOffset(1728)); + ML_ASSERT_EQ(uint64_t(8), reader.GetDataSize(1728)); + ML_ASSERT_EQ(BlobDataType::UInt32, reader.GetDataType(1728)); + ML_ASSERT_EQ(0, reader.GetDataPaddingInBits(1728)); + } + + { // read data offset for Fp8E5M2 weights from metadata 15 + ML_ASSERT_EQ(uint64_t(1920), reader.GetDataOffset(1856)); + ML_ASSERT_EQ(uint64_t(4), reader.GetDataSize(1856)); + ML_ASSERT_EQ(BlobDataType::Float8E5M2, reader.GetDataType(1856)); + ML_ASSERT_EQ(0, reader.GetDataPaddingInBits(1856)); + } + + { // read data offset for Fp8E4M3FN weights from metadata 15 + ML_ASSERT_EQ(uint64_t(1920), reader.GetDataOffset(1984)); + ML_ASSERT_EQ(uint64_t(4), reader.GetDataSize(1984)); + ML_ASSERT_EQ(BlobDataType::Float8E4M3FN, reader.GetDataType(1984)); + ML_ASSERT_EQ(0, reader.GetDataPaddingInBits(1984)); + } + return 0; } diff --git a/mlmodel/tests/MILBlob/StorageWriterTests.cpp b/mlmodel/tests/MILBlob/StorageWriterTests.cpp index 04f561d57..880cca6c9 100644 --- a/mlmodel/tests/MILBlob/StorageWriterTests.cpp +++ b/mlmodel/tests/MILBlob/StorageWriterTests.cpp @@ -7,6 +7,8 @@ #include "MILBlob/Blob/StorageFormat.hpp" #include "MILBlob/Blob/StorageWriter.hpp" #include "MILBlob/Fp16.hpp" +#include "MILBlob/Fp8.hpp" +#include "MILBlob/Util/SpanCast.hpp" #include "AutoDeleteTempFile.hpp" #include "BlobUtils.hpp" #include "framework/TestUtils.hpp" @@ -29,21 +31,38 @@ namespace { return header.count == count && header.version == uint32_t(2); } +template +[[nodiscard]] bool IsCorrectMetadataForSubByte(const std::string& filePath, + uint64_t offset, + uint64_t entryCount, + BlobDataType dataType) +{ + blob_metadata metadata; + TestUtil::ReadData(filePath, metadata, offset); + uint64_t occupiedBits = (metadata.sizeInBytes * 8 - metadata.padding_size_in_bits); + bool sizesMatch = (occupiedBits / T::SizeInBits) == entryCount; + return metadata.sentinel == BlobMetadataSentinel && metadata.mil_dtype == dataType && sizesMatch; +} + template [[nodiscard]] bool IsCorrectMetadata(const std::string& filePath, uint64_t offset, uint64_t entryCount, BlobDataType dataType) { - blob_metadata metadata; - TestUtil::ReadData(filePath, metadata, offset); - - return metadata.sentinel == BlobMetadataSentinel && metadata.mil_dtype == dataType && - metadata.sizeInBytes == entryCount * sizeof(T) && metadata.offset % DefaultStorageAlignment == 0; + if constexpr (MILBlob::IsSubByteSized::value) { + return IsCorrectMetadataForSubByte(filePath, offset, entryCount, dataType); + } else { + blob_metadata metadata; + TestUtil::ReadData(filePath, metadata, offset); + + return metadata.sentinel == BlobMetadataSentinel && metadata.mil_dtype == dataType && + metadata.sizeInBytes == entryCount * sizeof(T) && metadata.offset % DefaultStorageAlignment == 0; + } } template -[[nodiscard]] bool IsCorrectData(const std::string& filePath, uint64_t offset, Util::Span expectedSpan) +[[nodiscard]] bool IsCorrectDataImpl(const std::string& filePath, uint64_t offset, Util::Span expectedSpan) { blob_metadata metadata; TestUtil::ReadData(filePath, metadata, offset); @@ -56,6 +75,55 @@ template std::equal(outputSpan.begin(), outputSpan.end(), expectedSpan.begin()); } +template +[[nodiscard]] bool IsCorrectSubByteData(const std::string& filePath, uint64_t offset, Util::Span expectedSpan) +{ + blob_metadata metadata; + TestUtil::ReadData(filePath, metadata, offset); + // sizing this int8 vector with the int4 span size is an overestimation - should be + // fine for this test since we need the buffer only anyway + std::vector v(expectedSpan.Size()); + Util::Span outputSpan((void*)v.data(), expectedSpan.Size()); + TestUtil::ReadBlobFile(filePath, metadata.offset, outputSpan); + + auto ourBytesPtr = static_cast(outputSpan.Data()); + auto otherBytesPtr = static_cast(expectedSpan.Data()); + + // scan bytes up to but not including padding + std::size_t numBits = outputSpan.Size() * T::SizeInBits; + std::size_t remainderBits = numBits % 8; + + std::size_t numFullBytes = numBits / 8; + std::size_t numRemainingElements = remainderBits / T::SizeInBits; + + for (size_t i = 0; i < numFullBytes; i++) { + if (ourBytesPtr[i] != otherBytesPtr[i]) { + return false; + } + } + + // scan remainder, ignore garbage bits + for (size_t i = 0; i < numRemainingElements; i++) { + auto mask = T::BitMask << (i * T::SizeInBits); + auto ourVal = ourBytesPtr[numFullBytes].data & mask; + auto otherVal = otherBytesPtr[numFullBytes].data & mask; + if (ourVal != otherVal) { + return false; + } + } + return true; +} + +template +[[nodiscard]] bool IsCorrectData(const std::string& filePath, uint64_t offset, Util::Span expectedSpan) +{ + if constexpr (MILBlob::IsSubByteSized::value) { + return IsCorrectSubByteData(filePath, offset, expectedSpan); + } else { + return IsCorrectDataImpl(filePath, offset, expectedSpan); + } +} + } // anonymous namespace int testStorageWriterTestsSupportedTypes() @@ -112,6 +180,38 @@ int testStorageWriterTestsSupportedTypes() ML_ASSERT(IsCorrectData(filePath, offset, expectedSpan)); } + // Writing Fp8E4M3FN values + { + const std::vector val = {Fp8E4M3FN(0xCA), Fp8E4M3FN(0xBE), Fp8E4M3FN(0x80), Fp8E4M3FN(0x00)}; + auto expectedSpan = Util::MakeSpan(val); + uint64_t offset = 0; + { + StorageWriter writer(tempfile.GetFilename(), /* truncateFile */ false); + offset = writer.WriteData(expectedSpan); + } + + ML_ASSERT_EQ(offset % DefaultStorageAlignment, uint64_t(0)); + ML_ASSERT(IsCorrectHeader(filePath, ++headerCount /*count*/)); + ML_ASSERT(IsCorrectMetadata(filePath, offset, 4, BlobDataType::Float8E4M3FN)); + ML_ASSERT(IsCorrectData(filePath, offset, expectedSpan)); + } + + // Writing Fp8E5M2 values + { + const std::vector val = {Fp8E5M2(0xCA), Fp8E5M2(0xBE), Fp8E5M2(0x80), Fp8E5M2(0x00)}; + auto expectedSpan = Util::MakeSpan(val); + uint64_t offset = 0; + { + StorageWriter writer(tempfile.GetFilename(), /* truncateFile */ false); + offset = writer.WriteData(expectedSpan); + } + + ML_ASSERT_EQ(offset % DefaultStorageAlignment, uint64_t(0)); + ML_ASSERT(IsCorrectHeader(filePath, ++headerCount /*count*/)); + ML_ASSERT(IsCorrectMetadata(filePath, offset, 4, BlobDataType::Float8E5M2)); + ML_ASSERT(IsCorrectData(filePath, offset, expectedSpan)); + } + // Writing bf16 values { const std::vector val = {Bf16(0x12), Bf16(0x00), Bf16(0x124), Bf16(0xabcd)}; @@ -176,12 +276,151 @@ int testStorageWriterTestsSupportedTypes() ML_ASSERT(IsCorrectData(filePath, offset, expectedSpan)); } + // Writing int4 values + { + std::vector val = {0xFA, 0x17, 0xD2}; + auto expectedSpanInt8 = Util::MakeSpan(val); + const auto expectedSpan = MILBlob::Util::CastToBitSpan(expectedSpanInt8, 6); + + uint64_t offset = 0; + { + StorageWriter writer(tempfile.GetFilename(), /* truncateFile */ false); + offset = writer.WriteData(expectedSpan); + } + + ML_ASSERT_EQ(offset % DefaultStorageAlignment, uint64_t(0)); + ML_ASSERT(IsCorrectHeader(filePath, ++headerCount /*count*/)); + ML_ASSERT(IsCorrectMetadata(filePath, offset, 6, BlobDataType::Int4)); + ML_ASSERT(IsCorrectData(filePath, offset, expectedSpan)); + } + + // Writing uint4 values + { + std::vector val = {0xFA, 0x17, 0xD2}; + auto expectedSpanUInt8 = Util::MakeSpan(val); + const auto expectedSpan = MILBlob::Util::CastToBitSpan(expectedSpanUInt8, 5); + + uint64_t offset = 0; + { + StorageWriter writer(tempfile.GetFilename(), /* truncateFile */ false); + offset = writer.WriteData(expectedSpan); + } + + ML_ASSERT_EQ(offset % DefaultStorageAlignment, uint64_t(0)); + ML_ASSERT(IsCorrectHeader(filePath, ++headerCount /*count*/)); + ML_ASSERT(IsCorrectMetadata(filePath, offset, 5, BlobDataType::UInt4)); + ML_ASSERT(IsCorrectData(filePath, offset, expectedSpan)); + } + + // Writing uint2 values + { + std::vector val = {0b11001010, 0b11110101}; + auto expectedSpanUInt8 = Util::MakeSpan(val); + const auto expectedSpan = MILBlob::Util::CastToBitSpan(expectedSpanUInt8, 5); + + uint64_t offset = 0; + { + StorageWriter writer(tempfile.GetFilename(), /* truncateFile */ false); + offset = writer.WriteData(expectedSpan); + } + + ML_ASSERT_EQ(offset % DefaultStorageAlignment, uint64_t(0)); + ML_ASSERT(IsCorrectHeader(filePath, ++headerCount /*count*/)); + ML_ASSERT(IsCorrectMetadata(filePath, offset, 5, BlobDataType::UInt2)); + ML_ASSERT(IsCorrectData(filePath, offset, expectedSpan)); + } + + // Writing uint1 values + { + std::vector val = {0b11001010, 0b11110101}; + auto expectedSpanUInt8 = Util::MakeSpan(val); + const auto expectedSpan = MILBlob::Util::CastToBitSpan(expectedSpanUInt8, 13); + + uint64_t offset = 0; + { + StorageWriter writer(tempfile.GetFilename(), /* truncateFile */ false); + offset = writer.WriteData(expectedSpan); + } + + ML_ASSERT_EQ(offset % DefaultStorageAlignment, uint64_t(0)); + ML_ASSERT(IsCorrectHeader(filePath, ++headerCount /*count*/)); + ML_ASSERT(IsCorrectMetadata(filePath, offset, 13, BlobDataType::UInt1)); + ML_ASSERT(IsCorrectData(filePath, offset, expectedSpan)); + } + + // Writing uint3 values + { + std::vector val = {0b11001010, 0b11110101}; + auto expectedSpanUInt8 = Util::MakeSpan(val); + const auto expectedSpan = MILBlob::Util::CastToBitSpan(expectedSpanUInt8, 4); + + uint64_t offset = 0; + { + StorageWriter writer(tempfile.GetFilename(), /* truncateFile */ false); + offset = writer.WriteData(expectedSpan); + } + + ML_ASSERT_EQ(offset % DefaultStorageAlignment, uint64_t(0)); + ML_ASSERT(IsCorrectHeader(filePath, ++headerCount /*count*/)); + ML_ASSERT(IsCorrectMetadata(filePath, offset, 4, BlobDataType::UInt3)); + ML_ASSERT(IsCorrectData(filePath, offset, expectedSpan)); + } + + // Writing uint6 values + { + std::vector val = {0b11001010, 0b11110101, 0b00000100}; + auto expectedSpanUInt8 = Util::MakeSpan(val); + const auto expectedSpan = MILBlob::Util::CastToBitSpan(expectedSpanUInt8, 3); + + uint64_t offset = 0; + { + StorageWriter writer(tempfile.GetFilename(), /* truncateFile */ false); + offset = writer.WriteData(expectedSpan); + } + + ML_ASSERT_EQ(offset % DefaultStorageAlignment, uint64_t(0)); + ML_ASSERT(IsCorrectHeader(filePath, ++headerCount /*count*/)); + ML_ASSERT(IsCorrectMetadata(filePath, offset, 3, BlobDataType::UInt6)); + ML_ASSERT(IsCorrectData(filePath, offset, expectedSpan)); + } + + // Writing int32 values + { + const std::vector val = {0xFFC2, 0x7FFF}; + auto expectedSpan = Util::MakeSpan(val); + uint64_t offset = 0; + { + StorageWriter writer(tempfile.GetFilename(), /* truncateFile */ false); + offset = writer.WriteData(expectedSpan); + } + + ML_ASSERT_EQ(offset % DefaultStorageAlignment, uint64_t(0)); + ML_ASSERT(IsCorrectHeader(filePath, ++headerCount)); + ML_ASSERT(IsCorrectMetadata(filePath, offset, 2, BlobDataType::Int32)); + ML_ASSERT(IsCorrectData(filePath, offset, expectedSpan)); + } + // Writing uint32 values + { + const std::vector val = {0xFFC2, 0x7FFF, 0xDEAD, 0XCAFE}; + auto expectedSpan = Util::MakeSpan(val); + uint64_t offset = 0; + { + StorageWriter writer(tempfile.GetFilename(), /* truncateFile */ false); + offset = writer.WriteData(expectedSpan); + } + + ML_ASSERT_EQ(offset % DefaultStorageAlignment, uint64_t(0)); + ML_ASSERT(IsCorrectHeader(filePath, ++headerCount)); + ML_ASSERT(IsCorrectMetadata(filePath, offset, 4, BlobDataType::UInt32)); + ML_ASSERT(IsCorrectData(filePath, offset, expectedSpan)); + } + return 0; } int testStorageWriterTestsAppendToExistingFile() { - // File does not exists, creates one + // File does not exist, creates one { AutoDeleteTempFile tempfile; StorageWriter(tempfile.GetFilename(), /* truncateFile */ false); diff --git a/mlmodel/tests/MLModelTests.hpp b/mlmodel/tests/MLModelTests.hpp index 1a935772c..fd6f12261 100644 --- a/mlmodel/tests/MLModelTests.hpp +++ b/mlmodel/tests/MLModelTests.hpp @@ -19,6 +19,7 @@ MLMODEL_TEST(testLargeModel) MLMODEL_TEST(testVeryLargeModel) MLMODEL_TEST(testOptionalInputs) MLMODEL_TEST(testFeatureDescriptions) + MLMODEL_TEST(testNNValidatorLoop) MLMODEL_TEST(testNNValidatorMissingInput) MLMODEL_TEST(testNNValidatorSimple) @@ -106,6 +107,8 @@ MLMODEL_TEST(testSpecDowngradeFlexibleShapes) MLMODEL_TEST(testSpecDowngradeFlexibleShapes2) MLMODEL_TEST(testSpecDowngradePipeline) MLMODEL_TEST(testWordTaggerTransferLearningSpecIOS14) +MLMODEL_TEST(testEmptyInputModel_downgradeToIOS18) +MLMODEL_TEST(testMultiFunctionModel_downgradeToIOS18) MLMODEL_TEST(testBayesianProbitRegressionValidationBasic) MLMODEL_TEST(testRangeVal) MLMODEL_TEST(testRangeValDivide) @@ -195,6 +198,23 @@ MLMODEL_TEST(testInvalidLayerNormalizationWrongGammaOrBeta) MLMODEL_TEST(testInvalidConstantPad) MLMODEL_TEST(testInvalidArgsortWrongAxis) +// multi-function tests +MLMODEL_TEST(testMultiFunctionSpecificationVersion) +MLMODEL_TEST(testMultiFunctionDefaultFunctionName) +MLMODEL_TEST(testMultiFunctionTopLevelFeatureDescriptionsMustBeEmpty) +MLMODEL_TEST(testMultiFunctionEmptyInput) +MLMODEL_TEST(testMultiFunctionAllowed) + +// stateful prediction tests +MLMODEL_TEST(testStateSpecificationVersion) +MLMODEL_TEST(testStateFeatureDescriptionInInputs) +MLMODEL_TEST(testStateFeatureIsNotFP16_shouldFail) +MLMODEL_TEST(testStateFeatureIsOptional_shouldFail) +MLMODEL_TEST(testStateFeatureHasNoDefaultShape_shouldFail) +MLMODEL_TEST(testStateFeatureHasNoArrayType_shouldFail) +MLMODEL_TEST(testStateFeature_ArrayUsesRangeFlexibleShape_shouldFail) +MLMODEL_TEST(testStateFeature_ArrayUsesEnumeratedFlexibleShape_shouldFail) + // Updatable model tests MLMODEL_TEST(testUpdatableModelSpecVersion) MLMODEL_TEST(testInvalidUpdatableModelQuantizedWeights) @@ -258,12 +278,17 @@ MLMODEL_TEST(testMMapFileReaderTestsFileErrorNotFound) MLMODEL_TEST(testMMapFileReaderTestsReadData) MLMODEL_TEST(testMMapFileReaderTestsReadStruct) MLMODEL_TEST(testSpanCastTestsBasics) +MLMODEL_TEST(testSpanCastTestsFromInt4) +MLMODEL_TEST(testSpanCastTestsToInt4) MLMODEL_TEST(testSpanTestsAccessImmutable) MLMODEL_TEST(testSpanTestsAccessMutable) +MLMODEL_TEST(testSpanTestsConstInt4) +MLMODEL_TEST(testSpanTestsConstUInt4) MLMODEL_TEST(testSpanTestsCopyAndAssignment) MLMODEL_TEST(testSpanTestsDefaultConstructor) MLMODEL_TEST(testSpanTestsEmpty) MLMODEL_TEST(testSpanTestsImplicitConstCopyCtor) +MLMODEL_TEST(testSpanTestsInt4) MLMODEL_TEST(testSpanTestsIterationDynamicSlices) MLMODEL_TEST(testSpanTestsIterationIllegal) MLMODEL_TEST(testSpanTestsIterationMultipleDims) @@ -293,8 +318,11 @@ MLMODEL_TEST(testSpanTestsSlicingIllegalBounds) MLMODEL_TEST(testSpanTestsSlicingUnbounded) MLMODEL_TEST(testSpanTestsSlicingUnboundedEdge) MLMODEL_TEST(testSpanTestsSlicingZeroLength) +MLMODEL_TEST(testSpanTestsSpanOverload) MLMODEL_TEST(testSpanTestsStaticSizedAccessImmutable) MLMODEL_TEST(testSpanTestsStaticSizedAccessMutable) +MLMODEL_TEST(testSpanTestsSubByteUIntValueAt) +MLMODEL_TEST(testSpanTestsSubbyteIntValueAt) MLMODEL_TEST(testStorageIntegrationTestsReadDataWithIncorrectOffset) MLMODEL_TEST(testStorageIntegrationTestsReadDataWithIncorrectType) MLMODEL_TEST(testStorageIntegrationTestsWriteAndReadValues) diff --git a/mlmodel/tests/NNValidatorTests.cpp b/mlmodel/tests/NNValidatorTests.cpp index 42069c7dd..5216a0282 100644 --- a/mlmodel/tests/NNValidatorTests.cpp +++ b/mlmodel/tests/NNValidatorTests.cpp @@ -240,7 +240,7 @@ int testInvalidDefaultOptionalValue() { // axis should be in range [-(rank + 1), rank + 1) Result res = Model::validate(m); ML_ASSERT_BAD(res); - ML_ASSERT(res.message().find("mismatch between dataType and the type") != std::string::npos); + ML_ASSERT(res.message().find("mistmatch between dataType and the type") != std::string::npos); return 0; } @@ -1163,7 +1163,7 @@ int testValidPooling3d() { pooling3dLayer->add_input("input"); pooling3dLayer->add_output("probs"); auto *mutablePooling3d = pooling3dLayer->mutable_pooling3d(); - + // Add Kernel sizes mutablePooling3d->set_kerneldepth(2); mutablePooling3d->set_kernelheight(2); @@ -1182,10 +1182,10 @@ int testValidPooling3d() { mutablePooling3d->set_custompaddingbottom(7); mutablePooling3d->set_custompaddingleft(7); mutablePooling3d->set_custompaddingright(7); - + Result res = validate(m1); ML_ASSERT_GOOD(res); - + return 0; } @@ -1213,7 +1213,7 @@ int testInvalidPooling3dNegativeKernelSize() { pooling3dLayer->add_input("input"); pooling3dLayer->add_output("probs"); auto *mutablePooling3d = pooling3dLayer->mutable_pooling3d(); - + // Add Kernel sizes mutablePooling3d->set_kerneldepth(2); mutablePooling3d->set_kernelheight(2); @@ -1232,10 +1232,10 @@ int testInvalidPooling3dNegativeKernelSize() { mutablePooling3d->set_custompaddingbottom(7); mutablePooling3d->set_custompaddingleft(7); mutablePooling3d->set_custompaddingright(7); - + Result res = validate(m1); ML_ASSERT_BAD(res); - + return 0; } @@ -1264,7 +1264,7 @@ int testInvalidPooling3dCostumPaddingSetForNonCustomPaddingType() { pooling3dLayer->add_input("input"); pooling3dLayer->add_output("probs"); auto *mutablePooling3d = pooling3dLayer->mutable_pooling3d(); - + // Add Kernel sizes mutablePooling3d->set_kerneldepth(2); mutablePooling3d->set_kernelheight(2); @@ -1283,10 +1283,10 @@ int testInvalidPooling3dCostumPaddingSetForNonCustomPaddingType() { mutablePooling3d->set_custompaddingbottom(7); mutablePooling3d->set_custompaddingleft(7); mutablePooling3d->set_custompaddingright(7); - + Result res = validate(m1); ML_ASSERT_BAD(res); - + return 0; } @@ -3039,13 +3039,13 @@ int testInvalidUpsampleNearestNeighborsModeWithAlignCorners() { params->set_mode(Specification::UpsampleLayerParams_InterpolationMode::UpsampleLayerParams_InterpolationMode_NN); params->set_linearupsamplemode(Specification::UpsampleLayerParams_LinearUpsampleMode_ALIGN_CORNERS_FALSE); - + Result res = validate(m1); ML_ASSERT_BAD(res); params->set_mode(Specification::UpsampleLayerParams_InterpolationMode::UpsampleLayerParams_InterpolationMode_NN); params->set_linearupsamplemode(Specification::UpsampleLayerParams_LinearUpsampleMode_ALIGN_CORNERS_TRUE); - + res = validate(m1); ML_ASSERT_BAD(res); return 0; @@ -3111,7 +3111,7 @@ int testFractionalUpsample() { // Fractional scaling factor valid params->add_fractionalscalingfactor(2.5); params->add_fractionalscalingfactor(3.5); - + // Requires "align corners" bilinear mode params->set_mode(Specification::UpsampleLayerParams_InterpolationMode_NN); res = validate(m1); @@ -3167,12 +3167,12 @@ int testValidUpsampleAlignCorners() { params->set_mode(Specification::UpsampleLayerParams_InterpolationMode_BILINEAR); params->set_linearupsamplemode(Specification::UpsampleLayerParams_LinearUpsampleMode_ALIGN_CORNERS_FALSE); - + Result res = validate(m1); ML_ASSERT_GOOD(res); params->set_linearupsamplemode(Specification::UpsampleLayerParams_LinearUpsampleMode_ALIGN_CORNERS_TRUE); - + res = validate(m1); ML_ASSERT_GOOD(res); @@ -3217,7 +3217,7 @@ int testUpsampleArgsortSpec() { upsampleParams->add_scalingfactor(1.0); upsampleParams->set_mode(Specification::UpsampleLayerParams_InterpolationMode_BILINEAR); - + auto *argsortLayer = nn->add_layers(); argsortLayer->set_name("argsort"); argsortLayer->add_input("A"); @@ -3278,7 +3278,7 @@ int testInvalidSoftmax() { const auto nn = m1.mutable_neuralnetwork(); nn->set_arrayinputshapemapping(Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); - + Specification::NeuralNetworkLayer *layer = nn->add_layers(); layer->set_name("softmax"); layer->add_input("input"); @@ -3292,29 +3292,29 @@ int testInvalidSoftmax() { } int testInvalidSoftmax2() { - + Specification::Model m1; - + auto *topIn = m1.mutable_description()->add_input(); topIn->set_name("input"); auto *shape = topIn->mutable_type()->mutable_multiarraytype(); // rank must be at least length 3 shape->add_shape(5); shape->add_shape(5); - + auto *out3 = m1.mutable_description()->add_output(); out3->set_name("probs"); out3->mutable_type()->mutable_multiarraytype(); - + const auto nn = m1.mutable_neuralnetwork(); nn->set_arrayinputshapemapping(Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); - + Specification::NeuralNetworkLayer *layer = nn->add_layers(); layer->set_name("softmax"); layer->add_input("input"); layer->add_output("probs"); (void) layer->mutable_softmax(); - + Result res = validate(m1); ML_ASSERT_BAD(res); return 0; @@ -4285,11 +4285,11 @@ int testValidBranch() { topIn->mutable_type()->mutable_multiarraytype(); auto *shape = topIn->mutable_type()->mutable_multiarraytype(); shape->add_shape(1); - + auto *out = m.mutable_description()->add_output(); out->set_name("B"); out->mutable_type()->mutable_multiarraytype(); - + // "If" net Specification::NeuralNetwork nnIf; auto *l1 = nnIf.add_layers(); @@ -4297,7 +4297,7 @@ int testValidBranch() { l1->set_name("if_relu"); l1->add_input("A"); l1->add_output("B"); - + // "else" net Specification::NeuralNetwork nnElse; auto *l2 = nnElse.add_layers(); @@ -4305,7 +4305,7 @@ int testValidBranch() { l2->set_name("else_relu"); l2->add_input("A"); l2->add_output("B"); - + // Main network auto *nnMain = m.mutable_neuralnetwork(); nnMain->set_arrayinputshapemapping(Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); @@ -4314,14 +4314,14 @@ int testValidBranch() { l3->set_name("condition_producing_layer"); l3->add_input("A"); l3->add_output("cond"); - + auto *l4 = nnMain->add_layers(); auto *branch_layer = l4->mutable_branch(); l4->set_name("branch_layer"); l4->add_input("cond"); branch_layer->mutable_ifbranch()->CopyFrom(nnIf); branch_layer->mutable_elsebranch()->CopyFrom(nnElse); - + Result res = validate(m); ML_ASSERT_GOOD(res); return 0; @@ -4334,11 +4334,11 @@ int testInvalidBranchOutputNotProduced1() { topIn->mutable_type()->mutable_multiarraytype(); auto *shape = topIn->mutable_type()->mutable_multiarraytype(); shape->add_shape(1); - + auto *out = m.mutable_description()->add_output(); out->set_name("B"); out->mutable_type()->mutable_multiarraytype(); - + // "If" net Specification::NeuralNetwork nnIf; auto *l1 = nnIf.add_layers(); @@ -4346,7 +4346,7 @@ int testInvalidBranchOutputNotProduced1() { l1->set_name("if_relu"); l1->add_input("A"); l1->add_output("B"); - + // "else" net Specification::NeuralNetwork nnElse; auto *l2 = nnElse.add_layers(); @@ -4354,7 +4354,7 @@ int testInvalidBranchOutputNotProduced1() { l2->set_name("else_relu"); l2->add_input("A"); l2->add_output("B2"); - + // Main network auto *nnMain = m.mutable_neuralnetwork(); nnMain->set_arrayinputshapemapping(Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); @@ -4363,14 +4363,14 @@ int testInvalidBranchOutputNotProduced1() { l3->set_name("condition_producing_layer"); l3->add_input("A"); l3->add_output("cond"); - + auto *l4 = nnMain->add_layers(); auto *branch_layer = l4->mutable_branch(); l4->set_name("branch_layer"); l4->add_input("cond"); branch_layer->mutable_ifbranch()->CopyFrom(nnIf); branch_layer->mutable_elsebranch()->CopyFrom(nnElse); - + Result res = validate(m); ML_ASSERT_BAD(res); return 0; @@ -4383,11 +4383,11 @@ int testInvalidBranchOutputNotProduced2() { topIn->mutable_type()->mutable_multiarraytype(); auto *shape = topIn->mutable_type()->mutable_multiarraytype(); shape->add_shape(1); - + auto *out = m.mutable_description()->add_output(); out->set_name("B"); out->mutable_type()->mutable_multiarraytype(); - + // "If" net Specification::NeuralNetwork nnIf; auto *l1 = nnIf.add_layers(); @@ -4395,7 +4395,7 @@ int testInvalidBranchOutputNotProduced2() { l1->set_name("if_relu"); l1->add_input("A"); l1->add_output("B"); - + // Main network auto *nnMain = m.mutable_neuralnetwork(); nnMain->set_arrayinputshapemapping(Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); @@ -4404,13 +4404,13 @@ int testInvalidBranchOutputNotProduced2() { l3->set_name("condition_producing_layer"); l3->add_input("A"); l3->add_output("cond"); - + auto *l4 = nnMain->add_layers(); auto *branch_layer = l4->mutable_branch(); l4->set_name("branch_layer"); l4->add_input("cond"); branch_layer->mutable_ifbranch()->CopyFrom(nnIf); - + Result res = validate(m); ML_ASSERT_BAD(res); return 0; @@ -4423,11 +4423,11 @@ int testInvalidBranchBlobOverwrite() { topIn->mutable_type()->mutable_multiarraytype(); auto *shape = topIn->mutable_type()->mutable_multiarraytype(); shape->add_shape(1); - + auto *out = m.mutable_description()->add_output(); out->set_name("B"); out->mutable_type()->mutable_multiarraytype(); - + // "If" net Specification::NeuralNetwork nnIf; auto *l1 = nnIf.add_layers(); @@ -4435,7 +4435,7 @@ int testInvalidBranchBlobOverwrite() { l1->set_name("if_relu"); l1->add_input("A"); l1->add_output("cond"); - + // "else" net Specification::NeuralNetwork nnElse; auto *l2 = nnElse.add_layers(); @@ -4443,7 +4443,7 @@ int testInvalidBranchBlobOverwrite() { l2->set_name("else_relu"); l2->add_input("A"); l2->add_output("B"); - + // Main network auto *nnMain = m.mutable_neuralnetwork(); nnMain->set_arrayinputshapemapping(Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); @@ -4452,14 +4452,14 @@ int testInvalidBranchBlobOverwrite() { l3->set_name("condition_producing_layer"); l3->add_input("A"); l3->add_output("cond"); - + auto *l4 = nnMain->add_layers(); auto *branch_layer = l4->mutable_branch(); l4->set_name("branch_layer"); l4->add_input("cond"); branch_layer->mutable_ifbranch()->CopyFrom(nnIf); branch_layer->mutable_elsebranch()->CopyFrom(nnElse); - + Result res = validate(m); ML_ASSERT_BAD(res); return 0; @@ -4472,11 +4472,11 @@ int testInvalidCopy() { topIn->mutable_type()->mutable_multiarraytype(); auto *shape = topIn->mutable_type()->mutable_multiarraytype(); shape->add_shape(1); - + auto *out = m.mutable_description()->add_output(); out->set_name("B"); out->mutable_type()->mutable_multiarraytype(); - + auto *nn = m.mutable_neuralnetwork(); nn->set_arrayinputshapemapping(Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); auto *l = nn->add_layers(); @@ -4484,7 +4484,7 @@ int testInvalidCopy() { l->set_name("copy"); l->add_input("A"); l->add_output("A"); - + Result res = validate(m); ML_ASSERT_BAD(res); return 0; @@ -4503,28 +4503,28 @@ int testInvalidLoop1() { auto *out = m.mutable_description()->add_output(); out->set_name("B"); out->mutable_type()->mutable_multiarraytype(); - - + + Specification::NeuralNetwork nnBody; auto *l1 = nnBody.add_layers(); (void)l1->mutable_activation()->mutable_relu(); l1->set_name("relu"); l1->add_input("A"); l1->add_output("B"); - + auto *nnMain = m.mutable_neuralnetwork(); nnMain->set_arrayinputshapemapping(Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); auto *l2 = nnMain->add_layers(); l2->set_name("for_loop"); auto *loop_params = l2->mutable_loop(); loop_params->mutable_bodynetwork()->CopyFrom(nnBody); - + auto *l3 = nnMain->add_layers(); l3->set_name("copy"); l3->add_input("A"); l3->add_output("B"); (void) l3->mutable_copy(); - + Result res = validate(m); ML_ASSERT_BAD(res); return 0; @@ -4543,21 +4543,21 @@ int testInvalidLoop2() { auto *out = m.mutable_description()->add_output(); out->set_name("B"); out->mutable_type()->mutable_multiarraytype(); - + Specification::NeuralNetwork nnCondition; auto *l1 = nnCondition.add_layers(); l1->mutable_greaterthan()->set_alpha(1.0); l1->set_name("cond"); l1->add_input("A"); l1->add_output("cond"); - + Specification::NeuralNetwork nnBody; auto *l2 = nnBody.add_layers(); (void)l2->mutable_activation()->mutable_relu(); l2->set_name("relu"); l2->add_input("A"); l2->add_output("B"); - + auto *nnMain = m.mutable_neuralnetwork(); nnMain->set_arrayinputshapemapping(Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); auto *l3 = nnMain->add_layers(); @@ -4565,13 +4565,13 @@ int testInvalidLoop2() { auto *loop_params = l3->mutable_loop(); loop_params->mutable_bodynetwork()->CopyFrom(nnBody); loop_params->mutable_conditionnetwork()->CopyFrom(nnCondition); - + auto *l4 = nnMain->add_layers(); l4->set_name("copy"); l4->add_input("A"); l4->add_output("B"); (void) l4->mutable_copy(); - + Result res = validate(m); ML_ASSERT_BAD(res); return 0; @@ -4590,15 +4590,15 @@ int testInvalidLoop3() { auto *out = m.mutable_description()->add_output(); out->set_name("B"); out->mutable_type()->mutable_multiarraytype(); - + Specification::NeuralNetwork nnBody; - + auto *l2 = nnBody.add_layers(); (void)l2->mutable_activation()->mutable_relu(); l2->set_name("relu"); l2->add_input("A"); l2->add_output("B"); - + auto *nnMain = m.mutable_neuralnetwork(); nnMain->set_arrayinputshapemapping(Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); auto *l3 = nnMain->add_layers(); @@ -4606,13 +4606,13 @@ int testInvalidLoop3() { auto *loop_params = l3->mutable_loop(); loop_params->mutable_bodynetwork()->CopyFrom(nnBody); loop_params->set_conditionvar("cond"); - + auto *l4 = nnMain->add_layers(); l4->set_name("copy"); l4->add_input("A"); l4->add_output("B"); (void) l4->mutable_copy(); - + Result res = validate(m); ML_ASSERT_BAD(res); return 0; @@ -4631,37 +4631,37 @@ int testInvalidLoop4() { auto *out = m.mutable_description()->add_output(); out->set_name("B"); out->mutable_type()->mutable_multiarraytype(); - + Specification::NeuralNetwork nnCondition; auto *l1 = nnCondition.add_layers(); l1->mutable_greaterthan()->set_alpha(1.0); l1->set_name("cond2"); l1->add_input("A"); l1->add_output("cond2"); - + Specification::NeuralNetwork nnBody; auto *l2 = nnBody.add_layers(); (void)l2->mutable_activation()->mutable_relu(); l2->set_name("relu"); l2->add_input("A"); l2->add_output("B"); - + auto *nnMain = m.mutable_neuralnetwork(); nnMain->set_arrayinputshapemapping(Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); - + auto *l3 = nnMain->add_layers(); l3->set_name("for_loop"); auto *loop_params = l3->mutable_loop(); loop_params->mutable_bodynetwork()->CopyFrom(nnBody); loop_params->mutable_conditionnetwork()->CopyFrom(nnCondition); loop_params->set_conditionvar("cond"); - + auto *l4 = nnMain->add_layers(); l4->set_name("copy"); l4->add_input("A"); l4->add_output("B"); (void) l4->mutable_copy(); - + Result res = validate(m); ML_ASSERT_BAD(res); return 0; @@ -4680,31 +4680,31 @@ int testInvalidLoop5() { auto *out = m.mutable_description()->add_output(); out->set_name("B"); out->mutable_type()->mutable_multiarraytype(); - + Specification::NeuralNetwork nnCondition; auto *l1 = nnCondition.add_layers(); l1->mutable_greaterthan()->set_alpha(1.0); l1->set_name("cond"); l1->add_input("A"); l1->add_output("cond"); - + Specification::NeuralNetwork nnBody; auto *l2 = nnBody.add_layers(); (void)l2->mutable_activation()->mutable_relu(); l2->set_name("relu"); l2->add_input("A"); l2->add_output("B"); - + auto *nnMain = m.mutable_neuralnetwork(); nnMain->set_arrayinputshapemapping(Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); - + auto *l3 = nnMain->add_layers(); l3->set_name("for_loop"); auto *loop_params = l3->mutable_loop(); loop_params->mutable_bodynetwork()->CopyFrom(nnBody); loop_params->mutable_conditionnetwork()->CopyFrom(nnCondition); loop_params->set_conditionvar("cond"); - + Result res = validate(m); ML_ASSERT_BAD(res); return 0; @@ -4723,20 +4723,20 @@ int testInvalidLoopBreak() { auto *out = m.mutable_description()->add_output(); out->set_name("B"); out->mutable_type()->mutable_multiarraytype(); - + auto *nnMain = m.mutable_neuralnetwork(); nnMain->set_arrayinputshapemapping(Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); - + auto *l1 = nnMain->add_layers(); l1->set_name("copy"); l1->add_input("A"); l1->add_output("B"); (void) l1->mutable_copy(); - + auto *l2 = nnMain->add_layers(); l2->set_name("break"); (void) l2->mutable_loopbreak(); - + Result res = validate(m); ML_ASSERT_BAD(res); return 0; @@ -4755,20 +4755,20 @@ int testInvalidLoopContinue() { auto *out = m.mutable_description()->add_output(); out->set_name("B"); out->mutable_type()->mutable_multiarraytype(); - + auto *nnMain = m.mutable_neuralnetwork(); nnMain->set_arrayinputshapemapping(Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); - + auto *l1 = nnMain->add_layers(); l1->set_name("copy"); l1->add_input("A"); l1->add_output("B"); (void) l1->mutable_copy(); - + auto *l2 = nnMain->add_layers(); l2->set_name("continue"); (void) l2->mutable_loopcontinue(); - + Result res = validate(m); ML_ASSERT_BAD(res); return 0; @@ -4780,36 +4780,36 @@ int testInvalidRankInconsistency() { rank of B when output of relu1 : 1 rank of B when input of relu2: 2 (makes the model invalid) */ - + Specification::Model m; auto *topIn = m.mutable_description()->add_input(); topIn->set_name("A"); topIn->mutable_type()->mutable_multiarraytype(); auto *shape = topIn->mutable_type()->mutable_multiarraytype(); shape->add_shape(1); - + auto *out = m.mutable_description()->add_output(); out->set_name("C"); out->mutable_type()->mutable_multiarraytype(); - + auto *nn = m.mutable_neuralnetwork(); nn->set_arrayinputshapemapping(Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); - + auto *l = nn->add_layers(); (void)l->mutable_activation()->mutable_relu(); l->set_name("relu1"); l->add_input("A"); l->add_output("B"); l->add_outputtensor()->set_rank(1); - + auto *l2 = nn->add_layers(); (void)l2->mutable_activation()->mutable_relu(); l2->set_name("relu2"); l2->add_input("B"); l2->add_output("C"); l2->add_inputtensor()->set_rank(2); - - + + Result res = validate(m); ML_ASSERT_BAD(res); return 0; @@ -4828,17 +4828,17 @@ int testInvalidExpandDims1() { topIn->mutable_type()->mutable_multiarraytype(); auto *shape = topIn->mutable_type()->mutable_multiarraytype(); shape->add_shape(2); - + auto *out = m.mutable_description()->add_output(); out->set_name("B"); auto *shape_out = out->mutable_type()->mutable_multiarraytype(); shape_out->add_shape(2); shape_out->add_shape(1); shape_out->add_shape(1); - + auto *nn = m.mutable_neuralnetwork(); nn->set_arrayinputshapemapping(Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); - + auto *l = nn->add_layers(); l->set_name("ED"); l->add_input("A"); @@ -4847,7 +4847,7 @@ int testInvalidExpandDims1() { l->add_outputtensor()->set_rank(3); auto *params = l->mutable_expanddims(); params->add_axes(-1); - + Result res = validate(m); ML_ASSERT_BAD(res); return 0; @@ -4866,14 +4866,14 @@ int testInvalidExpandDims2() { topIn->mutable_type()->mutable_multiarraytype(); auto *shape = topIn->mutable_type()->mutable_multiarraytype(); shape->add_shape(2); - + auto *out = m.mutable_description()->add_output(); out->set_name("B"); auto *shape_out = out->mutable_type()->mutable_multiarraytype(); shape_out->add_shape(2); shape_out->add_shape(1); shape_out->add_shape(1); - + auto *nn = m.mutable_neuralnetwork(); nn->set_arrayinputshapemapping(Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); auto *l = nn->add_layers(); @@ -4885,7 +4885,7 @@ int testInvalidExpandDims2() { auto *params = l->mutable_expanddims(); params->add_axes(2); params->add_axes(-4); - + Result res = validate(m); ML_ASSERT_BAD(res); return 0; @@ -4906,12 +4906,12 @@ int testInvalidSqueeze1() { shape->add_shape(2); shape->add_shape(1); shape->add_shape(1); - + auto *out = m.mutable_description()->add_output(); out->set_name("B"); auto *shape_out = out->mutable_type()->mutable_multiarraytype(); shape_out->add_shape(2); - + auto *nn = m.mutable_neuralnetwork(); nn->set_arrayinputshapemapping(Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); auto *l = nn->add_layers(); @@ -4923,7 +4923,7 @@ int testInvalidSqueeze1() { auto *params = l->mutable_squeeze(); params->add_axes(1); params->add_axes(1); - + Result res = validate(m); ML_ASSERT_BAD(res); return 0; @@ -4942,11 +4942,11 @@ int testInvalidPoolingRank1() { shape->add_shape(2); shape->add_shape(1); shape->add_shape(1); - + auto *out = m.mutable_description()->add_output(); out->set_name("B"); out->mutable_type()->mutable_multiarraytype(); - + auto *nn = m.mutable_neuralnetwork(); nn->set_arrayinputshapemapping(Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); auto *l = nn->add_layers(); @@ -4958,7 +4958,7 @@ int testInvalidPoolingRank1() { params->set_type(::Specification::PoolingLayerParams::AVERAGE); params->set_globalpooling(true); params->mutable_valid(); - + Result res = validate(m); ML_ASSERT_BAD(res); return 0; @@ -4978,11 +4978,11 @@ int testInvalidPoolingRank2() { shape->add_shape(2); shape->add_shape(1); shape->add_shape(1); - + auto *out = m.mutable_description()->add_output(); out->set_name("B"); out->mutable_type()->mutable_multiarraytype(); - + auto *nn = m.mutable_neuralnetwork(); nn->set_arrayinputshapemapping(Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); auto *l = nn->add_layers(); @@ -4995,7 +4995,7 @@ int testInvalidPoolingRank2() { params->set_type(::Specification::PoolingLayerParams::AVERAGE); params->set_globalpooling(true); params->mutable_valid(); - + Result res = validate(m); ML_ASSERT_BAD(res); return 0; @@ -5014,19 +5014,19 @@ int testInvalidIOS13LayerOldRank() { shape->add_shape(2); shape->add_shape(1); shape->add_shape(1); - + auto *out = m.mutable_description()->add_output(); out->set_name("B"); out->mutable_type()->mutable_multiarraytype(); - + auto *nn = m.mutable_neuralnetwork(); - + auto *l = nn->add_layers(); l->set_name("erf"); l->add_input("A"); l->add_output("B"); l->mutable_erf(); - + Result res = validate(m); ML_ASSERT_BAD(res); return 0; diff --git a/mlmodel/tests/UtilsTests.cpp b/mlmodel/tests/UtilsTests.cpp index 557549b86..bdc3908e9 100644 --- a/mlmodel/tests/UtilsTests.cpp +++ b/mlmodel/tests/UtilsTests.cpp @@ -138,3 +138,87 @@ int testWordTaggerTransferLearningSpecIOS14() { return 0; } + +int testEmptyInputModel_downgradeToIOS18() { + // This model doesn't have input. Such empty input models are supported in iOS18 and later. + auto spec = Specification::Model(); + spec.set_specificationversion(MLMODEL_SPECIFICATION_VERSION_NEWEST); + auto* modelDescription = spec.mutable_description(); + + auto output = modelDescription->add_output(); + output->set_name("output"); + output->mutable_type()->mutable_multiarraytype()->set_datatype(Specification::ArrayFeatureType_ArrayDataType::ArrayFeatureType_ArrayDataType_FLOAT32); + output->mutable_type()->mutable_multiarraytype()->add_shape(1); + + auto* net = spec.mutable_neuralnetwork(); + auto* layer = net->add_layers(); + layer->set_name("load_constantND_layer"); + layer->add_output("output"); + layer->mutable_loadconstantnd()->add_shape(1); + layer->mutable_loadconstantnd()->mutable_data()->add_floatvalue(0.1f); + + // The model uses empty input, which was introduced in iOS18. + Model emptyInputModel(spec); + ML_ASSERT_EQ(emptyInputModel.getProto().specificationversion(), MLMODEL_SPECIFICATION_VERSION_IOS18); + + // Now, add an input. + auto input = modelDescription->add_input(); + input->set_name("input"); + input->mutable_type()->mutable_multiarraytype()->set_datatype(Specification::ArrayFeatureType_ArrayDataType::ArrayFeatureType_ArrayDataType_FLOAT32); + input->mutable_type()->mutable_multiarraytype()->add_shape(1); + + // The model uses EXACT_ARRAY_MAPPING, which was introduced in iOS13. Other than that, + // there is nothing special. We expect the downgrade utility sets it to iOS13. + Model modelWithInput(spec); + ML_ASSERT_EQ(modelWithInput.getProto().specificationversion(), MLMODEL_SPECIFICATION_VERSION_IOS13); + + return 0; +} + +int testMultiFunctionModel_downgradeToIOS18() { + // This model doesn't have input. Such empty input models are supported in iOS18 and later. + auto spec = Specification::Model(); + spec.set_specificationversion(MLMODEL_SPECIFICATION_VERSION_NEWEST); + auto* modelDescription = spec.mutable_description(); + + auto *function = modelDescription->add_functions(); + function->set_name("f"); + + modelDescription->set_defaultfunctionname("f"); + + auto input = function->add_input(); + input->set_name("input"); + input->mutable_type()->mutable_multiarraytype()->set_datatype(Specification::ArrayFeatureType_ArrayDataType::ArrayFeatureType_ArrayDataType_FLOAT32); + input->mutable_type()->mutable_multiarraytype()->add_shape(1); + + auto output = function->add_output(); + output->set_name("output"); + output->mutable_type()->mutable_multiarraytype()->set_datatype(Specification::ArrayFeatureType_ArrayDataType::ArrayFeatureType_ArrayDataType_FLOAT32); + output->mutable_type()->mutable_multiarraytype()->add_shape(1); + + spec.mutable_mlprogram(); + + // The model uses multi-function description syntax, which was introduced in iOS18. + Model multiFunctionModel(spec); + ML_ASSERT_EQ(multiFunctionModel.getProto().specificationversion(), MLMODEL_SPECIFICATION_VERSION_IOS18); + + // Let's remove multi-function syntax and use the good old syntax. + modelDescription->clear_functions(); + modelDescription->clear_defaultfunctionname(); + + input = modelDescription->add_input(); + input->set_name("input"); + input->mutable_type()->mutable_multiarraytype()->set_datatype(Specification::ArrayFeatureType_ArrayDataType::ArrayFeatureType_ArrayDataType_FLOAT32); + input->mutable_type()->mutable_multiarraytype()->add_shape(1); + + output = modelDescription->add_output(); + output->set_name("output"); + output->mutable_type()->mutable_multiarraytype()->set_datatype(Specification::ArrayFeatureType_ArrayDataType::ArrayFeatureType_ArrayDataType_FLOAT32); + output->mutable_type()->mutable_multiarraytype()->add_shape(1); + + // Now, the model is nothing special ML Program, which was introduced in iOS15. + Model notMultiFunctionModel(spec); + ML_ASSERT_EQ(notMultiFunctionModel.getProto().specificationversion(), MLMODEL_SPECIFICATION_VERSION_IOS15); + + return 0; +} diff --git a/mlmodel/tests/framework/TestUtils.hpp b/mlmodel/tests/framework/TestUtils.hpp index 539e533a6..fb70022e4 100644 --- a/mlmodel/tests/framework/TestUtils.hpp +++ b/mlmodel/tests/framework/TestUtils.hpp @@ -24,7 +24,7 @@ bool caughtCorrectException = false; \ try { expr; } \ catch (const exType&) { caughtCorrectException = true; } \ - catch (...) { std::clog << __FILE__ << ":" << __LINE__ << ": error: caught unexpected exception type.\n"; return 1;} \ + catch (...) { std::clog << __FILE__ << ":" << __LINE__ << ": error: caught unexpected exeception type.\n"; return 1;} \ if (!caughtCorrectException) { std::clog << __FILE__ << ":" << __LINE__ << ": expected exception, but none thrown.\n"; return 1; } } #define ML_ASSERT_THROWS_WITH_MESSAGE(expr, exType, message) \ @@ -40,7 +40,7 @@ return 1; \ } \ } catch (...) { \ - std::clog << __FILE__ << ":" << __LINE__ << ": error: caught unexpected exception type.\n"; \ + std::clog << __FILE__ << ":" << __LINE__ << ": error: caught unexpected exeception type.\n"; \ return 1; \ } \ if (!caughtCorrectException) { \ diff --git a/reqs/test.pip b/reqs/test.pip index 86cb7a346..a6b789ce3 100644 --- a/reqs/test.pip +++ b/reqs/test.pip @@ -24,16 +24,23 @@ scipy==1.9.2; python_version == '3.11' six sympy > 1.6 gast==0.4.0 -torch==2.2.0 -torchaudio==2.2.0 -torchvision==0.17.0 +torch==2.2.0; platform_machine != "arm64" +torch==2.3.0; platform_machine == "arm64" +executorch==0.2.0; platform_machine == "arm64" and python_version >= '3.10' and python_version <= '3.11' +torchaudio==2.2.0; platform_machine != "arm64" +torchaudio==2.3.0; platform_machine == "arm64" +torchvision==0.17.0; platform_machine != "arm64" +torchvision==0.18.0; platform_machine == "arm64" +torchsr==1.0.4; platform_machine == "arm64" and python_version >= '3.10' and python_version <= '3.11' +timm==0.6.13; platform_machine == "arm64" and python_version >= '3.10' and python_version <= '3.11' xgboost==1.4.2; platform_machine != "arm64" mock wrapt tqdm pytest-timeout -transformers==4.26.0 +transformers==4.26.0; platform_machine != "arm64" +transformers==4.38.2; platform_machine == "arm64" # coremltools.optimize.torch filelock==3.6.0 diff --git a/scripts/test.sh b/scripts/test.sh index 7265b05c3..a94f723fa 100755 --- a/scripts/test.sh +++ b/scripts/test.sh @@ -127,6 +127,12 @@ if [[ $COV != "" ]]; then fi echo ${TEST_CMD} -eval ${TEST_CMD} +eval ${TEST_CMD}" &" +init_pid="$!" +init_exit_code=0 +wait ${init_pid} || init_exit_code=$? +if [[ "${init_exit_code}" != "0" ]]; then + eval ${TEST_CMD}" --last-failed" +fi pip uninstall -y coremltools