diff --git a/coremltools/converters/_converters_entry.py b/coremltools/converters/_converters_entry.py
index 05c090092..ce7d60cac 100644
--- a/coremltools/converters/_converters_entry.py
+++ b/coremltools/converters/_converters_entry.py
@@ -42,6 +42,7 @@ def convert(
     minimum_deployment_target=None,
     convert_to=None,
     compute_precision=None,
+    skip_model_load=False,
     **kwargs
 ):
     """
@@ -203,6 +204,18 @@ def convert(
         - Before coremltools 5.0 release, change the default to coremltools.precision.FLOAT16 when convert_to="mlprogram"

+    skip_model_load : bool
+        Set to True to prevent coremltools from calling into the Core ML framework
+        to compile and load the model post-conversion. In that case, the returned
+        model object cannot be used to make predictions, but it can still be saved
+        via "model.save()". This flag may be used to convert to a newer model type
+        on an older Mac, which would otherwise raise a runtime warning.
+        Example: use this flag to suppress the runtime warning when converting to
+        the ML program model type on macOS 11, since ML programs can only be
+        compiled and loaded on macOS 12+.
+        Defaults to False.
+
     Returns
     -------
     model : ``coremltools.models.MLModel`` or ``coremltools.converters.mil.Program``
@@ -283,6 +296,7 @@ def convert(
         outputs=outputs,
         classifier_config=classifier_config,
         transforms=transforms,
+        skip_model_load=skip_model_load,
         **kwargs
     )
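For illustration, a minimal sketch of how the new flag is used (the source model,
tf_model, and the output path are hypothetical, not part of this patch):

    import coremltools as ct

    # Convert on a machine that cannot compile/load ML programs (e.g. macOS 11).
    # With skip_model_load=True the returned model cannot predict, but it can be saved.
    mlmodel = ct.convert(tf_model, convert_to="mlprogram", skip_model_load=True)
    mlmodel.save("model.mlpackage")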
diff --git a/coremltools/converters/mil/backend/mil/helper.py b/coremltools/converters/mil/backend/mil/helper.py
index ae67dbf6c..7336f38c5 100644
--- a/coremltools/converters/mil/backend/mil/helper.py
+++ b/coremltools/converters/mil/backend/mil/helper.py
@@ -7,6 +7,7 @@ from coremltools.converters.mil.mil import types
 from coremltools.converters.mil.mil.types import builtin_to_proto_types
 from coremltools.models.model import _WEIGHTS_DIR_NAME, _WEIGHTS_FILE_NAME
+import coremltools.proto.FeatureTypes_pb2 as ft

 from coremltools.converters.mil.mil.types import (
     type_to_builtin_type,
@@ -353,3 +354,14 @@ def create_immediate_value(var):
             raise NotImplementedError("List element type, {}, not supported yet.".format(var.sym_type.__type_info__()))
     else:
         return create_scalar_value(var.val)
+
+def cast_to_framework_io_dtype(var, is_output):
+    if var.dtype == types.fp32:
+        return ft.ArrayFeatureType.ArrayDataType.FLOAT32
+    elif var.dtype == types.int32:
+        return ft.ArrayFeatureType.ArrayDataType.INT32
+    else:
+        ioname = "Output " if is_output else "Input "
+        ioname2 = "outputs" if is_output else "inputs"
+        raise NotImplementedError(ioname + var.name + " has data type " + builtin_to_string(var.dtype) + \
+                                  ". ML Program models only support fp32 and int32 " + ioname2 + ".")
diff --git a/coremltools/converters/mil/backend/mil/load.py b/coremltools/converters/mil/backend/mil/load.py
index d9d256376..315ff6071 100644
--- a/coremltools/converters/mil/backend/mil/load.py
+++ b/coremltools/converters/mil/backend/mil/load.py
@@ -248,7 +248,7 @@ def _add_classify_op(prog, classifier_config):
 def load(prog, weights_dir, resume_on_errors=False, **kwargs):
     if "main" not in prog.functions:
         raise ValueError("main function not found in program")
-
+
     mil_backend_passes(prog)

     # if user has specified "ClassifierConfig", then add the "classify" op to the prog
@@ -314,9 +314,7 @@ def load(prog, weights_dir, resume_on_errors=False, **kwargs):

         if name not in image_input_names:
             # make a feature type of Type "multiArrayType"
-            array_type = ft.ArrayFeatureType(
-                shape=shape, dataType=ft.ArrayFeatureType.ArrayDataType.FLOAT32
-            )
+            array_type = ft.ArrayFeatureType(shape=shape, dataType=cast_to_framework_io_dtype(var, False))
             input_feature_type.multiArrayType.CopyFrom(array_type)
         else:
             if len(shape) < 3:
@@ -344,7 +342,7 @@ def load(prog, weights_dir, resume_on_errors=False, **kwargs):
                 ml.FeatureDescription(name=name, type=input_feature_type)
             )
         elif types.is_scalar(var.sym_type):
-            array_type = ft.ArrayFeatureType(shape=[1], dataType=ft.ArrayFeatureType.ArrayDataType.FLOAT32)
+            array_type = ft.ArrayFeatureType(shape=[1], dataType=cast_to_framework_io_dtype(var, False))
             input_feature_type.multiArrayType.CopyFrom(array_type)
             input_features.append(ml.FeatureDescription(name=var.name, type=input_feature_type))
         else:
@@ -353,8 +351,15 @@ def load(prog, weights_dir, resume_on_errors=False, **kwargs):
     for var in prog.functions["main"].outputs:
         output_feature_type = ft.FeatureType()
         if types.is_tensor(var.sym_type) or types.is_primitive(var.sym_type):
-            # Ignore output type; always set to ArrayFeatureType(shape=None, dataType=FLOAT32)
-            array_type = ft.ArrayFeatureType(shape=None, dataType=ft.ArrayFeatureType.ArrayDataType.FLOAT32)
+            dataType = None
+            if classifier_config is None or var.name != predicted_feature_name:
+                # Not a classifier output, so make sure the model output type matches the ML Program type.
+                dataType = cast_to_framework_io_dtype(var, True)
+            else:
+                # Classifier outputs are set up separately, so default to fp32 for now.
+                dataType = ft.ArrayFeatureType.ArrayDataType.FLOAT32
+
+            array_type = ft.ArrayFeatureType(shape=None, dataType=dataType)
             output_feature_type.multiArrayType.CopyFrom(array_type)
             output_features.append(ml.FeatureDescription(name=var.name, type=output_feature_type))
         elif (types.is_dict(var.sym_type)):
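A sketch of the effect of the load.py change above, assuming a hypothetical
mlmodel converted from a program whose "main" returns an int32 tensor
(previously the output feature type was unconditionally FLOAT32):

    from coremltools.proto import FeatureTypes_pb2 as ft

    spec = mlmodel.get_spec()
    out_type = spec.description.output[0].type.multiArrayType.dataType
    assert out_type == ft.ArrayFeatureType.ArrayDataType.INT32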
diff --git a/coremltools/converters/mil/backend/mil/passes/adjust_io_to_supported_types.py b/coremltools/converters/mil/backend/mil/passes/adjust_io_to_supported_types.py
new file mode 100644
index 000000000..887cd4378
--- /dev/null
+++ b/coremltools/converters/mil/backend/mil/passes/adjust_io_to_supported_types.py
@@ -0,0 +1,234 @@
+# -*- coding: utf-8 -*-
+
+# Copyright (c) 2021, Apple Inc. All rights reserved.
+#
+# Use of this source code is governed by a BSD-3-clause license that can be
+# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
+
+from coremltools.converters.mil.mil import Builder as _mb
+from coremltools.converters.mil.mil import types as _types
+from coremltools.converters.mil.mil.ops import defs as _ops
+from coremltools.converters.mil.mil.passes.pass_registry import register_pass as _register_pass
+
+import warnings as _warnings
+
+@_register_pass(namespace="mil_backend")
+def adjust_io_to_supported_types(prog):
+    """
+    Converts all dtypes to types that are supported by the CoreML runtime.
+    The runtime supports only fp16, fp32, int32, str, and bool variables.
+
+    General rules:
+    * Integer vars that are not 32 bit are replaced with int32 types.
+    * All other types not in the list of runtime-supported types are replaced with the fp32 dtype.
+      No casts are inserted; the previous type is replaced. The assumption is that all remaining
+      types are numerical and can be reasonably replaced with 32 bit float types.
+
+    The "main" function has additional rules since its I/O is mapped to CoreML model I/O:
+    * Fp16 I/O is replaced with fp32 I/O.
+      Casts (fp32 input -> fp16) are inserted at the beginning of the program to preserve 16 bit inputs.
+      Casts (fp16 -> fp32 output) are inserted at the end of the program to preserve 16 bit computations.
+
+    * All non-integer I/O that is not fp32 is replaced with fp32 I/O.
+      A cast (prev input type -> fp32) is inserted at the beginning of the program to preserve non-fp32 inputs.
+      A cast (prev type -> fp32 out) is inserted at the end of the program to preserve non-fp32 computations.
+      The assumption is that all remaining types are numerical and it is valid to cast them to/from fp32.
+
+    * The only exception: int64 outputs are allowed for the classify op. This is to keep consistency with
+      the CoreML API, which uses 64 bit integers to represent classifier labels.
+
+    ------
+
+    func main(bool x, int32 y, fp32 z) {
+        bool  out = logical_not(x)
+    } -> (out, y, z)
+
+    becomes
+
+    func main(fp32 x, int32 y, fp32 z) {
+        bool  x_casted = cast(x)
+        bool  out__pre__output__fp32__cast = logical_not(x_casted)
+        fp32  out = cast(out__pre__output__fp32__cast)
+    } -> (out, y, z)
+
+    ------
+
+    func not_main(bool x, int32 y, fp32 z) {
+        bool  out = logical_not(x)
+    } -> (out, y, z)
+
+    is unchanged.
+    """
+    for name, func in prog.functions.items():
+        _adjust_io_to_supported_types(func, name == "main")
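+
+# Illustrative usage (a sketch mirroring the unit tests added in this patch;
+# the program below is hypothetical, and mb/types are the public Builder and
+# types imports used in the tests). The pass is looked up via the registry and
+# mutates the program in place:
+#
+#     from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY
+#
+#     @mb.program(input_specs=[mb.TensorSpec(shape=(1,), dtype=types.fp64)])
+#     def prog(x):
+#         return x
+#
+#     PASS_REGISTRY["mil_backend::adjust_io_to_supported_types"](prog)
+#     # "main" now takes and returns fp32 instead of fp64.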
+ _warnings.warn("Input" + input_var.name + " is of dtype fp64. 64 bit float inputs are " +\ + "not supported by ML program models. This input will be assigned a dType " +\ + "of fp32. No cast will be inserted; the previous dtype will be replaced.") + input_var._sym_type = _types.tensor(_types.fp32, input_var.sym_type.get_shape()) + else: + # This is some other dType. Change the type to fp32 and add a cast. + # This is only a limitation of main--other functions do not represent CoreML model inputs + # and do not have the same limitation on input types. + _warnings.warn("Input" + input_var.name + " is of dType " + input_dtype_str + ". The " +\ + "CoreML runtime does not support inputs with this dType (only fp32 and " +\ + "int32 inputs are supported). This input will be assigned a dType of " +\ + "fp32. A cast will be inserted at the beginning of the program to " +\ + "convert the input to the originally defined dType.") + with func: + casted_input_var = _mb.cast(x=input_var, dtype=input_dtype_str, before_op=first_op) + func.replace_uses_of_var_after_op(anchor_op=casted_input_var.op, old_var=input_var, new_var=casted_input_var) + input_var._sym_type = _types.tensor(_types.fp32, input_var.sym_type.get_shape()) + + +def _adjust_main_outputs(func): + new_outputs = [] + for output_var in func.outputs: + output_type = output_var.sym_type + if (_types.is_tensor(output_type) or _types.is_scalar(output_type)) \ + and output_var.dtype != _types.fp32 \ + and output_var.dtype != _types.int32: + output_dtype_str = _types.builtin_to_string(output_var.dtype) + _warnings.warn("Output" + output_var.name + " is of dType " + output_dtype_str + ". The " +\ + "CoreML runtime does not support outputs with this dType (only int32 and " +\ + "fp32 are supported for outputs). This output will be assigned a dType " +\ + "of fp32. A cast will be inserted at the end of the program to convert" +\ + "the original output dType to the dType supported by the CoreML runtime.") + + output_var_name = output_var.name + output_var.set_name(output_var_name + "__pre__output__fp32__cast") + # Convert the output to fp32, and add a cast. + with func: + output_var = _mb.cast(x=output_var, dtype="fp32") + output_var.set_name(output_var_name) + new_outputs.append(output_var) + func.set_outputs(new_outputs) + + +##### +# General Functions and Blocks +##### +def _adjust_var(var): + """ + Changes the dtype of the provided variable according + to the rules outlined in the top level pass comment + (see adjust_io_to_supported_types). + """ + if (_types.is_tensor(var.sym_type) or _types.is_scalar(var.sym_type)) \ + and var.dtype not in __RUNTIME_SUPPORTED_TYPES: + dtype_str = _types.builtin_to_string(var.dtype) + if _types.is_int(var.dtype): + # Replace non-int32 input type with int32. + _warnings.warn("Input" + var.name + " is of dType " + dtype_str +\ + ". Only integer variables of bit width 32 are supported by the CoreML runtime. " +\ + "This input will be assigned a dType of int32. " +\ + "No cast will be inserted; the previous dtype will be replaced.") + var._sym_type = _types.tensor(_types.int32, var.sym_type.get_shape()) + else: + # This is some other unsupported dType. Change the input type to fp32. + _warnings.warn("Var " + var.name + " is of dType " + dtype_str + ". The CoreML runtime " +\ + "does not support this dType (only fp16, fp32, bool, and int32 are supported). " +\ + "This input will be assigned a dType of fp32. 
No cast will be inserted; " +\ + "the previous dtype will be replaced.") + var._sym_type = _types.tensor(_types.fp32, var.sym_type.get_shape()) + + +def _adjust_func_inputs(func): + for input_name, input_var in func.inputs.items(): + _adjust_var(input_var) + + +def _adjust_block_inputs(block): + for input_var in block.inputs: + _adjust_var(input_var) + + +def _adjust_ops(block): + len_block = len(block.operations) + i = 0 + while i < len_block: + op = block.operations[i] + + # Classifier is a special exception to this rule. It can output 64 bit integer labels. + # Classifier should be inserted after running this pass. + if op.op_type == "classify": + raise ValueError("ML Program backend pass adjust_to_supported_types does not support programs" +\ + " that have already added a classify op.") + + for subblock in op.blocks: + _adjust_block_inputs(subblock) + _adjust_ops(subblock) + + for var in op.outputs: + _adjust_var(var) + + # Cast ops have a param (dtype) that should match the output dtype. + # If the output dtype or input dtype was previously adjusted, + # the cast op must change or be removed in kind. + if op.op_type == "cast": + output_type_str = _types.builtin_to_string(op.outputs[0].dtype) + if op.outputs[0].dtype == op.x.dtype: + # The type of the input or output of this cast op was changed per the rules + # defined in the top level comment for adjust_io_to_supported_types. + # + # That changed output type is the same type as the input to the cast + # op. Therefore, regardless of whether the user created this cast or + # not, it is now redundant (noop), and should be removed. + # + # The removal isn't covered by the main cast + # optimization pass since that pass runs before this pass. + block.replace_uses_of_var_after_op( + anchor_op=op, old_var=op.outputs[0], new_var=op.x + ) + block.remove_ops([op]) + len_block = len(block.operations) + i -= 1 + elif output_type_str != op.dtype.val: + # The type of the output of this cast op was changed per the rules + # defined in the top level comment for adjust_io_to_supported_types. + # + # This cast is meaningful, and the "dtype" param now differs from the output + # type. Replace the dtype cast with a new cast op with a matching dtype param. + with block: + new_cast_out = _mb.cast(x=op.x, dtype=output_type_str, before_op=op) + block.replace_uses_of_var_after_op( + anchor_op=op, old_var=op.outputs[0], new_var=new_cast_out + ) + block.remove_ops([op]) + len_block = len(block.operations) + i = i + 1 + return block + +##### +# The Pass +##### +def _adjust_io_to_supported_types(func, is_main): + if is_main: + _adjust_main_inputs(func) + _adjust_ops(func) + _adjust_main_outputs(func) + else: + _adjust_func_inputs(func) + _adjust_ops(func) diff --git a/coremltools/converters/mil/backend/mil/passes/fuse_activation_silu.py b/coremltools/converters/mil/backend/mil/passes/fuse_activation_silu.py new file mode 100644 index 000000000..b22518bb4 --- /dev/null +++ b/coremltools/converters/mil/backend/mil/passes/fuse_activation_silu.py @@ -0,0 +1,78 @@ +# Copyright (c) 2020, Apple Inc. All rights reserved. 
+# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +from coremltools.converters.mil.mil.passes.pass_registry import register_pass +from coremltools.converters.mil.mil import Builder as mb + +def match_pattern(op): + if op.op_type == "sigmoid": + # abort fusion if op output is also a block output + if op.outputs[0] in op.enclosing_block.outputs: + return None + # find following op + child_ops = op.outputs[0].child_ops + if len(child_ops) == 1: + mul_op_candidate = list(child_ops)[0] + if mul_op_candidate.op_type != "mul": + return None + mul_inputs_actual = {mul_op_candidate.x.name, mul_op_candidate.y.name} + mul_inputs_expect = {op.x.name, op.outputs[0].name} + if mul_inputs_actual != mul_inputs_expect: + return None + return mul_op_candidate + + return None + + +def try_to_transform(sigmoid_op, mul_op, block): + out_name = mul_op.outputs[0].name + # create a new silu op + x = mb.silu(x=sigmoid_op.x, name=out_name, before_op=sigmoid_op) + mul_op.enclosing_block.replace_uses_of_var_after_op( + anchor_op=mul_op, old_var=mul_op.outputs[0], new_var=x + ) + # Remove all the ops at once + block.remove_ops([sigmoid_op, mul_op]) + return True + + +def fuse_activation_silu_block(block): + fusion_status = False + for op in list(block.operations): + for b in op.blocks: + block_changed = True + while block_changed: + block_changed = fuse_activation_silu_block(b) + if len(op.blocks) > 0: + continue + + mul_op = match_pattern(op) + if mul_op is not None: + with block: + fusion_status = try_to_transform(op, mul_op, block) + # has to break as the downstream iterator is affected. + if fusion_status: + return fusion_status + return fusion_status + + +@register_pass(namespace="mil_backend") +def fuse_activation_silu(prog): + """ + Fold x * sigmoid(x) into silu(x) + + Given: + %1 = sigmoid(x=%0) + %2 = mul(x=%0, y=%1) or mul(x=%1, y=%0) + ... + + Result: + %3 = silu(%0) + ... + """ + for f_name, f in prog.functions.items(): + block_changed = True + while block_changed: + block_changed = fuse_activation_silu_block(f) diff --git a/coremltools/converters/mil/backend/mil/passes/mil_passes.py b/coremltools/converters/mil/backend/mil/passes/mil_passes.py index 35153fd70..c18486500 100644 --- a/coremltools/converters/mil/backend/mil/passes/mil_passes.py +++ b/coremltools/converters/mil/backend/mil/passes/mil_passes.py @@ -11,7 +11,9 @@ def mil_backend_passes(prog): passes = [ "common::const_elimination", + "mil_backend::adjust_io_to_supported_types", "mil_backend::insert_image_preprocessing_ops", + "mil_backend::fuse_activation_silu", # TODO: Right now, "const elimination" pass CANNOT be done after the "homogenize_input_dtypes" pass. # Remove this requirement in rdar://76032946. # Right now due to a bug in the PYMIL const op, which is that it can only produce FP32 and INT32 types tensors (e.g. it can't produce int64), diff --git a/coremltools/converters/mil/backend/mil/passes/test_passes.py b/coremltools/converters/mil/backend/mil/passes/test_passes.py new file mode 100644 index 000000000..a1691c7a1 --- /dev/null +++ b/coremltools/converters/mil/backend/mil/passes/test_passes.py @@ -0,0 +1,794 @@ +# Copyright (c) 2020, Apple Inc. All rights reserved. 
+# +# Use of this source code is governed by a BSD-3-clause license that can be +# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause + +import pytest +import itertools +import numpy as np +import copy + +# import mil internal ops to add it to the builder +import coremltools as ct +from coremltools.converters.mil.mil import Builder as mb +from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY +from coremltools.converters.mil import types +from coremltools.converters.mil.mil.types import string_to_builtin, builtin_to_string, promote_types + +# Set the testing backend +import coremltools.converters.mil.testing_reqs as testing_reqs + +from coremltools.converters.mil.testing_utils import ( + get_op_types_in_program, + apply_pass_and_basic_check, + assert_model_is_valid, +) + + +class TestAdjustToSupportedTypes: + + def test_basic(self): + @mb.program(input_specs=[mb.TensorSpec(shape=(1, 1, 1, 1), dtype=types.bool), + mb.TensorSpec(shape=(1, 1, 1, 1), dtype=types.int32), + mb.TensorSpec(shape=(1, 1, 1, 1), dtype=types.fp32)]) + def prog(x, y, z): + out = mb.logical_not(x=x) + return (out, y, z) + prog.functions['not_main'] = copy.deepcopy(prog.functions['main']) + + prev_prog, prev_block, block = apply_pass_and_basic_check( + prog, "mil_backend::adjust_io_to_supported_types" + ) + + """ + Input graph: + + func main(bool x, int32 y, fp32 z) { + bool out = logical_not(x) + } -> (out, y, z) + + becomes + + func main(fp32 x, int32 y, fp32 z) { + bool x_casted = cast(x) + bool out__pre__output__fp32__cast = logical_not(x_casted) + fp32 out = cast(out__pre__output__fp32__cast) + } -> (out, y, z) + """ + assert get_op_types_in_program(prev_prog) == ['logical_not'] + assert get_op_types_in_program(prog) == ['cast', 'logical_not', 'cast'] + + prev_inputs = list(prev_prog.functions['main'].inputs.items()) + inputs = list(prog.functions['main'].inputs.items()) + assert prev_inputs[0][1].name == inputs[0][1].name + assert inputs[0][1].dtype == types.fp32 + for i in range(1, len(inputs)): + assert prev_inputs[i][1].name == inputs[i][1].name + assert prev_inputs[i][1].dtype == inputs[i][1].dtype + + prev_outputs = prev_prog.functions['main'].outputs + outputs = prog.functions['main'].outputs + assert prev_outputs[0].name == outputs[0].name + assert outputs[0].dtype == types.fp32 + for i in range(1, len(outputs)): + assert prev_outputs[i].name == outputs[i].name + assert prev_outputs[i].dtype == outputs[i].dtype + + """ + Input graph: + + func not_main(bool x, int32 y, fp32 z) { + bool out = logical_not(x) + } -> (out, y, z) + + is identical after the pass. 
+ """ + assert get_op_types_in_program(prev_prog, 'not_main') == ['logical_not'] + assert get_op_types_in_program(prog, 'not_main') == ['logical_not'] + + prev_inputs = list(prev_prog.functions['not_main'].inputs.items()) + inputs = list(prog.functions['not_main'].inputs.items()) + for i in range(0, len(inputs)): + assert prev_inputs[i][1].name == inputs[i][1].name + assert prev_inputs[i][1].dtype == inputs[i][1].dtype + + prev_outputs = prev_prog.functions['not_main'].outputs + outputs = prog.functions['not_main'].outputs + for i in range(0, len(outputs)): + assert prev_outputs[i].name == outputs[i].name + assert prev_outputs[i].dtype == outputs[i].dtype + + def test_int64_input(self): + """ + Input graph: + + func main(int64 x) { + } -> (x) + + becomes + + func main(int32 x) { + } -> (x) + """ + @mb.program(input_specs=[mb.TensorSpec(shape=(1, 1, 1, 1), dtype=types.int64)]) + def prog(x): + return x + + prev_prog, prev_block, block = apply_pass_and_basic_check( + prog, "mil_backend::adjust_io_to_supported_types" + ) + + prev_inputs = list(prev_prog.functions['main'].inputs.items()) + inputs = list(prog.functions['main'].inputs.items()) + assert prev_inputs[0][1].name == inputs[0][1].name + assert inputs[0][1].dtype == types.int32 + + def test_float64_input(self): + """ + Input graph: + + func main(float64 x) { + } -> (x) + + becomes + + func main(float32 x) { + } -> (x) + """ + @mb.program(input_specs=[mb.TensorSpec(shape=(1, 1, 1, 1), dtype=types.fp64)]) + def prog(x): + return x + + prev_prog, prev_block, block = apply_pass_and_basic_check( + prog, "mil_backend::adjust_io_to_supported_types" + ) + + prev_inputs = list(prev_prog.functions['main'].inputs.items()) + inputs = list(prog.functions['main'].inputs.items()) + assert prev_inputs[0][1].name == inputs[0][1].name + assert inputs[0][1].dtype == types.fp32 + + def test_int8_input(self): + """ + Input graph: + + func main(int8 x) { + } -> (x) + + becomes + + func main(int32 x) { + } -> (x) + """ + @mb.program(input_specs=[mb.TensorSpec(shape=(1, 1, 1, 1), dtype=types.int8)]) + def prog(x): + return x + + prev_prog, prev_block, block = apply_pass_and_basic_check( + prog, "mil_backend::adjust_io_to_supported_types" + ) + + prev_inputs = list(prev_prog.functions['main'].inputs.items()) + inputs = list(prog.functions['main'].inputs.items()) + assert prev_inputs[0][1].name == inputs[0][1].name + assert inputs[0][1].dtype == types.int32 + + def test_subblock(self): + """ + Input graph: + + func main(float64 a, float32 b) { + float64 out_0, float32 out_1 = while_loop(a, b, + (float64 a, float32 b) { + bool cond = less(a, b) + } -> (cond) + (float64 a, float32 b) { + float64 temp = const(1) + float64 out = add(a, b) + } -> (out, b) + ); + } -> (out_0, out_1) + + becomes + + func main(float32 a, float32 b) { + float32 out_0, float32 out_1 = while_loop(a, b, + (float32 a, float32 b) { + bool cond = less(a, b) + } -> (cond) + (float32 a, float32 b) { + float32 temp = const(1) + float32 out = add(a, b) + } -> (out, b) + ); + } -> (out_0, out_1) + """ + def body(a, b): + return mb.add(x=a, y=np.float64(1)), b + + def cond(a, b): + return mb.less(x=a, y=b) + + @mb.program(input_specs=[mb.TensorSpec(shape=(1,), dtype=types.fp64), + mb.TensorSpec(shape=(1,), dtype=types.fp32)]) + def prog(a, b): + return mb.while_loop(_cond=cond, _body=body, loop_vars=(a, b)) + + prev_prog, prev_block, block = apply_pass_and_basic_check( + prog, "mil_backend::adjust_io_to_supported_types" + ) + + prev_inputs = list(prev_prog.functions['main'].inputs.items()) + inputs 
= list(prog.functions['main'].inputs.items()) + for i in range(0, len(prev_inputs)): + assert prev_inputs[i][1].name == inputs[i][1].name + assert inputs[i][1].dtype == types.fp32 + + assert get_op_types_in_program(prev_prog) == ['while_loop'] + assert get_op_types_in_program(prog) == ['while_loop'] + + def assert_block_inputs(prev_inputs, inputs): + for i in range(0, len(prev_inputs)): + assert prev_inputs[i].name == inputs[i].name + assert inputs[i].dtype == types.fp32 + + subblocks = prog.functions['main'].operations[0].blocks + prev_subblocks = prev_prog.functions['main'].operations[0].blocks + for i in range(0, len(subblocks)): + assert_block_inputs(prev_subblocks[i].inputs, subblocks[i].inputs) + + def test_adjust_cast(self): + """ + Input graph: + + func main(int32 x) { + fp64 y = cast(x=x, dtype="fp64") + } -> (y) + + becomes + + func main(int32 x) { + fp32 y = cast(x=x, dtype="fp32") + } -> (y) + """ + @mb.program(input_specs=[mb.TensorSpec(shape=(1, 1, 1, 1), dtype=types.int32)]) + def prog(x): + y = mb.cast(x=x, dtype="fp64") + return y + + prev_prog, prev_block, block = apply_pass_and_basic_check( + prog, "mil_backend::adjust_io_to_supported_types" + ) + + assert get_op_types_in_program(prev_prog) == ['cast'] + assert get_op_types_in_program(prog) == ['cast'] + + prev_cast = prev_prog.functions['main'].operations[1] + cast = prog.functions['main'].operations[2] + + assert prev_cast.dtype.val == "fp64" + assert prev_cast.outputs[0].dtype == types.fp64 + + assert cast.dtype.val == "fp32" + assert cast.outputs[0].dtype == types.fp32 + + def test_adjust_redundant_cast(self): + """ + Input graph: + + func main(int32 x) { + int64 y = cast(x=x, dtype="int64") + } -> (y) + + becomes + + func main(int32 x) { + } -> (x) + """ + @mb.program(input_specs=[mb.TensorSpec(shape=(1, 1, 1, 1), dtype=types.int32)]) + def prog(x): + y = mb.cast(x=x, dtype="int64") + return y + + prev_prog, prev_block, block = apply_pass_and_basic_check( + prog, "mil_backend::adjust_io_to_supported_types" + ) + + assert get_op_types_in_program(prev_prog) == ['cast'] + assert get_op_types_in_program(prog) == [] + +class TestImagePreprocessingPass: + + def test_program_grayscale(self): + """ + Input graph: + + main(x: ImageType(color_layout="G", channel_first=True)) { + y1 = relu(x) + y2 = relu(x) + output = add(y1, y2) + } [output] + + Output graph: + + main(x: ImageType(channel_first=True)) { + y1 = relu(x) + y2 = relu(x) + output = add(y1, y2) + } [output] + """ + + @mb.program(input_specs=[mb.TensorSpec(shape=(1, 1, 20, 20))]) + def prog(x): + y1 = mb.relu(x=x) + y2 = mb.relu(x=x) + z = mb.add(x=y1, y=y2) + return z + + prog.main_input_types = (ct.ImageType(name='x', + shape=[1, 1, 20, 20], + color_layout="G", + channel_first=True),) + + prev_prog, prev_block, block = apply_pass_and_basic_check( + prog, "mil_backend::insert_image_preprocessing_ops" + ) + assert get_op_types_in_program(prev_prog) == ["relu", "relu", "add"] + assert get_op_types_in_program(prog) == ["relu", "relu", "add"] + + def test_program_grayscale_with_scale(self): + """ + Input graph: + + main(x: ImageType(scale=2.0, color_layout="G", channel_first=True)) { + y1 = relu(x) + y2 = relu(x) + output = add(y1, y2) + } [output] + + Output graph: + + main(x: ImageType(channel_first=True)) { + y = mul(x, 2) + y1 = relu(y) + y2 = relu(y) + output = add(y1, y2) + } [output] + """ + + @mb.program(input_specs=[mb.TensorSpec(shape=(1, 1, 20, 20))]) + def prog(x): + y1 = mb.relu(x=x) + y2 = mb.relu(x=x) + z = mb.add(x=y1, y=y2) + return z + + 
prog.main_input_types = (ct.ImageType(name='x', + shape=[1, 1, 20, 20], + scale=2.0, + color_layout="G", + channel_first=True),) + + prev_prog, prev_block, block = apply_pass_and_basic_check( + prog, "mil_backend::insert_image_preprocessing_ops" + ) + assert get_op_types_in_program(prev_prog) == ["relu", "relu", "add"] + assert get_op_types_in_program(prog) == ["mul", "relu", "relu", "add"] + scale_op = prog.find_ops(op_type="mul", exactly_one=True)[0] + assert scale_op.y.val == 2.0 + + def test_program_grayscale_with_bias(self): + """ + Input graph: + + main(x: ImageType(bias=2.0, color_layout="G", channel_first=True)) { + y1 = relu(x) + y2 = relu(x) + output = add(y1, y2) + } [output] + + Output graph: + + main(x: ImageType(channel_first=True)) { + y = add(x, 2) + y1 = relu(y) + y2 = relu(y) + output = add(y1, y2) + } [output] + """ + + @mb.program(input_specs=[mb.TensorSpec(shape=(1, 1, 20, 20))]) + def prog(x): + y1 = mb.relu(x=x) + y2 = mb.relu(x=x) + z = mb.add(x=y1, y=y2) + return z + + prog.main_input_types = (ct.ImageType(name='x', + shape=[1, 1, 20, 20], + bias=2.0, + color_layout="G", + channel_first=True),) + + prev_prog, prev_block, block = apply_pass_and_basic_check( + prog, "mil_backend::insert_image_preprocessing_ops" + ) + assert get_op_types_in_program(prev_prog) == ["relu", "relu", "add"] + assert get_op_types_in_program(prog) == ["add", "relu", "relu", "add"] + add_op = prog.find_ops(op_type="add", exactly_one=False)[0] + assert add_op.y.val == 2.0 + + def test_program_grayscale_with_scale_bias(self): + """ + Input graph: + + main(x: ImageType(scale=2.0, bias=2.0, color_layout="G", channel_first=True)) { + y1 = relu(x) + y2 = relu(x) + output = add(y1, y2) + } [output] + + Output graph: + + main(x: ImageType(channel_first=True)) { + y_scaled = mul(x, 2) + y = add(y_scaled, 2) + y1 = relu(y) + y2 = relu(y) + output = add(y1, y2) + } [output] + """ + + @mb.program(input_specs=[mb.TensorSpec(shape=(1, 1, 20, 20))]) + def prog(x): + y1 = mb.relu(x=x) + y2 = mb.relu(x=x) + z = mb.add(x=y1, y=y2) + return z + + prog.main_input_types = (ct.ImageType(name='x', + shape=[1, 1, 20, 20], + scale=2.0, + bias=2.0, + color_layout="G", + channel_first=True),) + + prev_prog, prev_block, block = apply_pass_and_basic_check( + prog, "mil_backend::insert_image_preprocessing_ops" + ) + assert get_op_types_in_program(prev_prog) == ["relu", "relu", "add"] + assert get_op_types_in_program(prog) == ["mul", "add", "relu", "relu", "add"] + scale_op = prog.find_ops(op_type="mul", exactly_one=True)[0] + assert scale_op.y.val == 2.0 + add_op = prog.find_ops(op_type="add", exactly_one=False)[0] + assert add_op.y.val == 2.0 + + def test_program_rgb(self): + """ + Input graph: + + main(x: ImageType(color_layout="RGB", channel_first=True)) { + y1 = relu(x) + y2 = relu(x) + output = add(y1, y2) + } [output] + + Output graph: + + main(x: ImageType(channel_first=True)) { + y1 = relu(x) + y2 = relu(x) + output = add(y1, y2) + } [output] + """ + + @mb.program(input_specs=[mb.TensorSpec(shape=(1, 3, 20, 20))]) + def prog(x): + y1 = mb.relu(x=x) + y2 = mb.relu(x=x) + z = mb.add(x=y1, y=y2) + return z + + prog.main_input_types = (ct.ImageType(name='x', + shape=[1, 3, 20, 20], + color_layout="RGB", + channel_first=True),) + + prev_prog, prev_block, block = apply_pass_and_basic_check( + prog, "mil_backend::insert_image_preprocessing_ops" + ) + assert get_op_types_in_program(prev_prog) == ["relu", "relu", "add"] + assert get_op_types_in_program(prog) == ["relu", "relu", "add"] + + def 
test_program_rgb_scale_bias(self): + """ + Input graph: + + main(x: ImageType(color_layout="RGB", scale=2.0, bias=[1.0, 2.0, 3.0], channel_first=True)) { + y1 = relu(x) + y2 = relu(x) + output = add(y1, y2) + } [output] + + Output graph: + + main(x: ImageType(channel_first=True)) { + y = mul(x, scale) + y_bias = add(y, bias) + y1 = relu(y_bias) + y2 = relu(y_bias) + output = add(y1, y2) + } [output] + """ + + @mb.program(input_specs=[mb.TensorSpec(shape=(1, 3, 20, 20))]) + def prog(x): + y1 = mb.relu(x=x) + y2 = mb.relu(x=x) + z = mb.add(x=y1, y=y2) + return z + + prog.main_input_types = (ct.ImageType(name='x', + shape=[1, 3, 20, 20], + scale=2.0, + bias=[1.0, 2.0, 3.0], + color_layout="RGB", + channel_first=True),) + + prev_prog, prev_block, block = apply_pass_and_basic_check( + prog, "mil_backend::insert_image_preprocessing_ops" + ) + assert get_op_types_in_program(prev_prog) == ["relu", "relu", "add"] + assert get_op_types_in_program(prog) == ["mul", "add", "relu", "relu", "add"] + scale_op = prog.find_ops(op_type="mul", exactly_one=True)[0] + assert scale_op.y.val == 2.0 + add_op = prog.find_ops(op_type="add", exactly_one=False)[0] + assert np.all(add_op.y.val == np.array([1.0, 2.0, 3.0]).reshape([1, 3, 1, 1])) + + def test_program_bgr(self): + """ + Input graph: + + main(x: ImageType(color_layout="BGR", channel_first=True)) { + y1 = relu(x) + y2 = relu(x) + output = add(y1, y2) + } [output] + + Output graph: + + main(x: ImageType(channel_first=True)) { + y1 = relu(x) + y2 = relu(x) + output = add(y1, y2) + } [output] + """ + + @mb.program(input_specs=[mb.TensorSpec(shape=(1, 3, 20, 20))]) + def prog(x): + y1 = mb.relu(x=x) + y2 = mb.relu(x=x) + z = mb.add(x=y1, y=y2) + return z + + prog.main_input_types = (ct.ImageType(name='x', + shape=[1, 3, 20, 20], + color_layout="BGR", + channel_first=True),) + + prev_prog, prev_block, block = apply_pass_and_basic_check( + prog, "mil_backend::insert_image_preprocessing_ops" + ) + assert get_op_types_in_program(prev_prog) == ["relu", "relu", "add"] + assert get_op_types_in_program(prog) == ["relu", "relu", "add"] + + def test_program_bgr_scale_bias(self): + """ + Input graph: + + main(x: ImageType(color_layout="BGR", scale=2.0, bias=[1.0, 2.0, 3.0], channel_first=True)) { + y1 = relu(x) + y2 = relu(x) + output = add(y1, y2) + } [output] + + Output graph: + + main(x: ImageType(channel_first=True)) { + y = mul(x, scale) + y_bias = add(y, bias) + y1 = relu(y_bias) + y2 = relu(y_bias) + output = add(y1, y2) + } [output] + """ + + @mb.program(input_specs=[mb.TensorSpec(shape=(1, 3, 20, 20))]) + def prog(x): + y1 = mb.relu(x=x) + y2 = mb.relu(x=x) + z = mb.add(x=y1, y=y2) + return z + + prog.main_input_types = (ct.ImageType(name='x', + shape=[1, 3, 20, 20], + scale=2.0, + bias=[1.0, 2.0, 3.0], + color_layout="BGR", + channel_first=True),) + + prev_prog, prev_block, block = apply_pass_and_basic_check( + prog, "mil_backend::insert_image_preprocessing_ops" + ) + assert get_op_types_in_program(prev_prog) == ["relu", "relu", "add"] + assert get_op_types_in_program(prog) == ["mul", "add", "relu", "relu", "add"] + scale_op = prog.find_ops(op_type="mul", exactly_one=True)[0] + assert scale_op.y.val == 2.0 + add_op = prog.find_ops(op_type="add", exactly_one=False)[0] + assert np.all(add_op.y.val == np.array([1.0, 2.0, 3.0]).reshape([1, 3, 1, 1])) + + @pytest.mark.parametrize( + "scale_type, bias_type", itertools.product([np.float, np.int32], [np.float, np.int32]) + ) + def test_scale_bias_types(self, scale_type, bias_type): + """ + Input graph: + + main(x: 
ImageType(color_layout="RGB", scale=2.0, bias=[1.0, 2.0, 3.0], channel_first=True)) { + y1 = relu(x) + y2 = relu(x) + output = add(y1, y2) + } [output] + + Output graph: + + main(x: ImageType(channel_first=True)) { + y = mul(x, scale) + y_bias = add(y, bias) + y1 = relu(y_bias) + y2 = relu(y_bias) + output = add(y1, y2) + } [output] + """ + + @mb.program(input_specs=[mb.TensorSpec(shape=(1, 3, 20, 20))]) + def prog(x): + y1 = mb.relu(x=x) + y2 = mb.relu(x=x) + z = mb.add(x=y1, y=y2) + return z + + prog.main_input_types = (ct.ImageType(name='x', + shape=[1, 3, 20, 20], + scale=scale_type(2.0), + bias=np.array([1, 2, 3]).astype(bias_type), + color_layout="RGB", + channel_first=True),) + + prev_prog, prev_block, block = apply_pass_and_basic_check( + prog, "mil_backend::insert_image_preprocessing_ops" + ) + assert get_op_types_in_program(prev_prog) == ["relu", "relu", "add"] + assert get_op_types_in_program(prog) == ["mul", "add", "relu", "relu", "add"] + scale_op = prog.find_ops(op_type="mul", exactly_one=True)[0] + assert scale_op.y.dtype() == prog.functions["main"].inputs["x"].dtype() + add_op = prog.find_ops(op_type="add", exactly_one=False)[0] + assert add_op.y.dtype() == prog.functions["main"].inputs["x"].dtype() + +class TestSanitizerPass: + + def test_sanitize_numeric_var_names(self): + """ + Input: + main(%x: (1, 3, 20, fp32)(Tensor)) { + block0() { + %var_1: (1, 3, 20, fp32)(Tensor) = relu(x=%x, name="var_1") + %1: (1, 3, 20, fp32)(Tensor) = relu(x=%x, name="1") + %3: (1, 3, 20, fp32)(Tensor) = add(x=%Var_1, y=%1, name="3") + } -> (%3) + } + + Output: + main(%x: (1, 3, 20, fp32)(Tensor)) { + block0() { + %var_1: (1, 3, 20, fp32)(Tensor) = relu(x=%x, name="var_1") + %var_1_0: (1, 3, 20, fp32)(Tensor) = relu(x=%x, name="op_1") + %var_3: (1, 3, 20, fp32)(Tensor) = add(x=%var_1, y=%var_1_0, name="op_3") + } -> (%var_3) + } + + """ + + @mb.program(input_specs=[mb.TensorSpec(shape=(1, 3, 20))]) + def prog(x): + y1 = mb.relu(x=x, name = "var_1") + y2 = mb.relu(x=x, name = "1") + z = mb.add(x=y1, y=y2, name = "3") + return z + + PASS_REGISTRY["mil_backend::sanitize_name_strings"](prog) + block = prog.functions["main"] + assert block.find_ops(op_type="relu")[0].outputs[0].name == "var_1" + assert block.find_ops(op_type="relu")[1].outputs[0].name == "var_1_0" + assert prog["main"].outputs[0].name == "var_3" + assert block.find_ops(op_type="relu")[0].name == "var_1" + assert block.find_ops(op_type="relu")[1].name == "op_1" + assert block.find_ops(op_type="add")[0].name == "op_3" + + +class TestPassFuseActivationSiLU: + """ + Input graph: + input --> sigmoid --> mul --> output + Output graph: + input --> silu --> output + """ + + @pytest.mark.parametrize( + "reverse_order", itertools.product([True, False]), + ) + def test_0(self, reverse_order): + x_shape = tuple(np.random.randint(low=1, high=4, size=5)) + + @mb.program(input_specs=[mb.TensorSpec(shape=x_shape)]) + def program(x): + sigmoid_x = mb.sigmoid(x=x) + if not reverse_order: + x = mb.mul(x=x, y=sigmoid_x) + else: + x = mb.mul(x=sigmoid_x, y=x) + return x + + prev_prog, prev_block, block = apply_pass_and_basic_check( + program, "mil_backend::fuse_activation_silu" + ) + + assert get_op_types_in_program(prev_prog) == ["sigmoid", "mul"] + assert get_op_types_in_program(program) == ["silu"] + + assert_model_is_valid( + program=program, + inputs={"x": x_shape}, + backend="mlprogram", + expected_output_shapes={block.outputs[0].name: tuple(x_shape)}, + ) + +class TestHomogenizeInputDtypes: + + @pytest.mark.parametrize( + ["op", "x_dtype", 
"y_dtype"], + [ + ["add", "int32", "fp32"], + ["mul", "fp32", "int32"], + ["minimum", "int64", "int32"], + ["add", "int32", "fp16"], + ["add", "fp16", "int32"], + ["equal", "bool", "int32"], + ["mod", "int64", "fp16"], + ["not_equal", "fp32", "bool"], + ["pow", "fp16", "fp32"], + ["greater", "fp16", "fp32"], + ["matmul", "fp16", "int32"], + ] + ) + def test_mixed_input_dtypes(self, op, x_dtype, y_dtype): + @mb.program(input_specs=[mb.TensorSpec(shape=(10, 10), dtype=string_to_builtin(x_dtype)), + mb.TensorSpec(shape=(10, 10), dtype=string_to_builtin(y_dtype))]) + def prog(x, y): + x = getattr(mb, op)(x=x, y=y) + return x + + assert get_op_types_in_program(prog) == [op] + + _, _, block = apply_pass_and_basic_check(prog, "mil_backend::homogenize_input_dtypes") + + assert get_op_types_in_program(prog) == ["cast", op] + + promoted_dtype = promote_types(string_to_builtin(x_dtype), string_to_builtin(y_dtype)) + + # Asserting cast configuration + cast = block.find_ops(op_type="cast")[0] + assert cast.dtype.val == builtin_to_string(promoted_dtype) + assert len(cast.outputs) == 1 + assert len(cast.outputs[0].child_ops) == 1 + assert cast.outputs[0].child_ops[0].op_type == op diff --git a/coremltools/converters/mil/backend/nn/load.py b/coremltools/converters/mil/backend/nn/load.py index a95cff6e5..a83473276 100644 --- a/coremltools/converters/mil/backend/nn/load.py +++ b/coremltools/converters/mil/backend/nn/load.py @@ -37,8 +37,8 @@ from coremltools.converters._profile_utils import _profile -def _convert_to_image_input(proto, inputs): - tmp_model = MLModel(proto) +def _convert_to_image_input(proto, inputs, skip_model_load=False): + tmp_model = MLModel(proto, skip_model_load=skip_model_load) for input_type in inputs: if isinstance(input_type, ImageType): if input_type.color_layout == "G": @@ -64,8 +64,8 @@ def _convert_to_image_input(proto, inputs): return tmp_model.get_spec() -def _convert_to_classifier(proto, classifier_config): - tmp_model = MLModel(proto) +def _convert_to_classifier(proto, classifier_config, skip_model_load=False): + tmp_model = MLModel(proto, skip_model_load=skip_model_load) tmp_model = neural_network.utils.make_nn_classifier( tmp_model, classifier_config.class_labels, @@ -284,12 +284,14 @@ def load(prog, **kwargs): # image input has_image_input = any([isinstance(s, ImageType) for s in input_types]) if has_image_input: - proto = _convert_to_image_input(proto, input_types) + proto = _convert_to_image_input(proto, input_types, + skip_model_load=kwargs.get("skip_model_load", False)) # classifier flag classifier_config = kwargs.get("classifier_config", None) if classifier_config is not None: - proto = _convert_to_classifier(proto, classifier_config) + proto = _convert_to_classifier(proto, classifier_config, + skip_model_load=kwargs.get("skip_model_load", False)) _set_user_inputs(proto, input_types) _set_symbolic_inputs(proto, symbolic_inputs) diff --git a/coremltools/converters/mil/backend/nn/op_mapping.py b/coremltools/converters/mil/backend/nn/op_mapping.py index 56f1c8ca9..14c382c86 100644 --- a/coremltools/converters/mil/backend/nn/op_mapping.py +++ b/coremltools/converters/mil/backend/nn/op_mapping.py @@ -38,8 +38,13 @@ def convert_ops(const_context, builder, ops, outputs): elif op.op_type in MIL_TO_NN_MAPPING_REGISTRY: mapper = MIL_TO_NN_MAPPING_REGISTRY[op.op_type] else: - msg = "{} is not implemented for nn backend. block: {}" - raise ValueError(msg.format(op.op_type, op.enclosing_block)) + msg = ("Op {} is used in the source model. 
This op is not supported " + "by the NeuralNetwork (compatibility with MacOS < 12, iOS < 15) model " + "type. To successfully convert this model, convert to the ML Program " + "model type (minimum target MacOS 12, iOS 15 and later).\n" + "Use coremltools.convert(..., convert_to=\"mlprogram\") to convert to ML Program.\n" + "block: {}") + raise NotImplementedError(msg.format(op.op_type, op.enclosing_block)) # const is globally shared in nn. mapper(const_context, builder, op) @@ -423,8 +428,6 @@ def conv_helper(const_context, builder, op): pad["padding_left"] = op.pad.val[4] pad["padding_right"] = op.pad.val[5] - # This doesn't work till builder fills in all optional values - # (rdar://59280101) has_bias = op.bias is not None groups = op.groups.val @@ -440,7 +443,7 @@ def conv_helper(const_context, builder, op): dilations = dilations[:-1] + [1] + dilations[-1:] strides = strides[:-1] + [1] + strides[-1:] - if weights is not None and weights.dtype == 'uint8': + if weights is not None and op.op_type == "conv_quantized": nbits = op.nbits.val weights = _convert_array_to_nbit_quantized_bytes(weights.flatten(), nbits).tobytes() quantization_type = op.quantization_type.val @@ -2422,46 +2425,63 @@ def layer_norm(const_context, builder, op): axes = [axis+rank if axis < 0 else axis for axis in op.axes.val] epsilon = op.epsilon.val - if rank in [2, 3] and len(axes) == 1 and axes[0] == rank - 1 and input_shape.count(-1) < 2 and input_shape[-1] != -1: + # if input shape = (X1, X2) or (X0, X1, X2), axes = [-1], X1 and X2 are known + # then the following operations are performed + # - reshape to (X1, 1, X2) / (X0, X1, 1, X2) + # - apply MVN layer, which normalizes across last 2 dims + # - apply scale layer + # - reshape back to (X1, X2) / (X0, X1, X2) + # Otherwise, we express the layer_norm as primitive operations + if rank in [2, 3] and len(axes) == 1 and axes[0] == rank - 1 and input_shape.count(-1) < 2 \ + and input_shape[-1] != -1 and input_shape[-2] != -1: + + reshaped_shape = input_shape[:] + # Insert a singleton dimension in the 'height' position + reshaped_shape.insert(-1, 1) + + if len(reshaped_shape) == 4: + gamma_shape = reshaped_shape[1:] + else: + gamma_shape = reshaped_shape - normalized_shape = input_shape[-len(axes) :] - gamma = _np.ones(normalized_shape) if op.gamma is None else op.gamma.val - beta = _np.zeros(normalized_shape) if op.beta is None else op.beta.val + gamma = _np.ones(gamma_shape) if op.gamma is None else _np.tile(op.gamma.val, (gamma_shape[0], 1, 1)) + beta = _np.zeros(gamma_shape) if op.beta is None else _np.tile(op.beta.val, (gamma_shape[0], 1, 1)) builder.add_reshape_static( name=op.name + "_reshape", input_name=make_input(const_context, builder, op.x), - output_name=op.x.name + "_reshape", - output_shape=input_shape + [1, 1], + output_name=op.name + "_reshape", + output_shape=reshaped_shape, ) - + builder.add_mvn( - name=op.x.name + "_mvn", - input_name=op.x.name + "_reshape", - output_name=op.x.name + "_mvn", - across_channels=True, + name=op.name + "_mvn", + input_name=op.name + "_reshape", + output_name=op.name + "_mvn", + across_channels=False, normalize_variance=True, epsilon=epsilon, ) builder.add_scale( - name=op.x.name + "_5d", - input_name=op.x.name + "_mvn", - output_name=op.x.name + "_5d", + name=op.name + "_scale", + input_name=op.name + "_mvn", + output_name=op.name + "_scale", W=gamma, b=beta, has_bias=True, - shape_scale=[len(gamma)], - shape_bias=[len(beta)], + shape_scale=_np.shape(gamma), + shape_bias=_np.shape(beta), ) builder.add_reshape_static( 
name=op.name, - input_name=op.x.name + "_5d", + input_name=op.name + "_scale", output_name=op.outputs[0].name, output_shape=input_shape, ) - else: + + else: # We don't meet the conditions for an MVN layer, so we use primitives mean_name = op.name + "_mean" builder.add_reduce_mean( name=mean_name, @@ -2780,11 +2800,7 @@ def shape(const_context, builder, op): ) -@register_mil_to_nn_mapping -def upsample_nearest_neighbor(const_context, builder, op): - scale_factor_h = op.scale_factor_height.val - scale_factor_w = op.scale_factor_width.val - +def add_upsample_nn(const_context, builder, op, scale_factor_h, scale_factor_w): if _np.abs(_np.round(scale_factor_h) - scale_factor_h) < 1e-4 and scale_factor_h >= 1 - 1e-4: scale_factor_h = int(scale_factor_h) else: @@ -2808,6 +2824,26 @@ def upsample_nearest_neighbor(const_context, builder, op): ) +@register_mil_to_nn_mapping +def resize_nearest_neighbor(const_context, builder, op): + Hout, Wout = op.target_size_height.val, op.target_size_width.val + x_shape = op.x.shape + Hin, Win = x_shape[-2], x_shape[-1] + + scale_factor_h = Hout / Hin if Hout % Hin == 0 else (Hout + 1e-4) / Hin + scale_factor_w = Wout / Win if Wout % Win == 0 else (Wout + 1e-4) / Win + + add_upsample_nn(const_context, builder, op, scale_factor_h, scale_factor_w) + + +@register_mil_to_nn_mapping +def upsample_nearest_neighbor(const_context, builder, op): + scale_factor_h = op.scale_factor_height.val + scale_factor_w = op.scale_factor_width.val + + add_upsample_nn(const_context, builder, op, scale_factor_h, scale_factor_w) + + @register_mil_to_nn_mapping def upsample_bilinear(const_context, builder, op): builder.add_upsample( diff --git a/coremltools/converters/mil/converter.py b/coremltools/converters/mil/converter.py index a007d7fe4..311c80fa9 100644 --- a/coremltools/converters/mil/converter.py +++ b/coremltools/converters/mil/converter.py @@ -235,10 +235,14 @@ def _mil_convert( return modelClass(package_path, useCPUOnly=kwargs.get("useCPUOnly", False), # important: keep the default "useCPUOnly" flag to False is_temp_package=not kwargs.get('package_dir'), - mil_program=mil_program) + mil_program=mil_program, + skip_model_load=kwargs.get('skip_model_load', False)) # important: keep the default "useCPUOnly" flag to False - return modelClass(proto, useCPUOnly=kwargs.get("useCPUOnly", False), mil_program=mil_program) + return modelClass(proto, + useCPUOnly=kwargs.get("useCPUOnly", False), + mil_program=mil_program, + skip_model_load=kwargs.get('skip_model_load', False)) def mil_convert_to_proto( model, diff --git a/coremltools/converters/mil/frontend/tensorflow/ops.py b/coremltools/converters/mil/frontend/tensorflow/ops.py index 19504b719..b4c354fe4 100644 --- a/coremltools/converters/mil/frontend/tensorflow/ops.py +++ b/coremltools/converters/mil/frontend/tensorflow/ops.py @@ -1090,6 +1090,51 @@ def Fill(context, node): context.add(node.name, x) +@register_tf_op +def ImageProjectiveTransformV2(context, node): + # Data shape format: [batch, height, width, channels] + x = context[node.inputs[0]] + # Transforms shape format: [batch, 8] or [1, 8] matrix, [a0, a1, a2, b0, b1, b2, c0, c1] + transforms = context[node.inputs[1]] + # 1-D Tensor [new_height, new_width] + output_shape = context[node.inputs[2]] + + if len(node.inputs) > 3: + raise NotImplementedError("'interpolation', 'fill_mode' not supported") + + # Don't allow non-zero c0 or c1, check for each batch + n_batch = transforms.val.shape[0] + transform_matrix = _np.empty((n_batch, 6)) + for b in range(n_batch): + c0 = 
transforms.val[b][6]
+        c1 = transforms.val[b][7]
+        if not (c0 == c1 == 0.0):
+            raise NotImplementedError(
+                "'affine' op does not support 'transforms' with a non-zero "
+                "c0 or c1. Got: {}".format(
+                    transforms
+                )
+            )
+        # drop c0 and c1 values from the transform matrix
+        transform_matrix[b] = _np.delete(transforms.val[b], [6, 7])
+
+    x = _transpose_NHWC_to_NCHW(x)
+    x = mb.affine(
+        x=x,
+        transform_matrix=transform_matrix,
+        output_height=output_shape.val[0],
+        output_width=output_shape.val[1],
+        sampling_mode="bilinear",
+        padding_mode="constant",
+        padding_value=0.0,
+        coordinates_mode="unnormalized",
+        align_corners=True,
+        name=node.name + "_affine",
+    )
+    x = _transpose_NCHW_to_NHWC(x, node.name)
+    context.add(node.name, x)
+
+
 @register_tf_op
 def RealDiv(context, node):
     x = context[node.inputs[0]]
@@ -1098,6 +1143,38 @@ def RealDiv(context, node):
     x = mb.real_div(x=x, y=y, name=node.name)
     context.add(node.name, x)

+@register_tf_op(tf_alias=["Addons>Resampler"])
+def Resampler(context, node):
+    # Data shape format: (Batch, Hin, Win, C)
+    x = context[node.inputs[0]]
+    # Warp shape format: (Batch, Hout, Wout, 2)
+    warp = context[node.inputs[1]]
+
+    # Handle rank-3 warp tensor
+    is_rank3_warp = warp.rank == 3
+    if is_rank3_warp:  # expand spatial dimension
+        warp = mb.expand_dims(x=warp, axes=[1], name=warp.name + "_expand_dims")
+
+    x = _transpose_NHWC_to_NCHW(x)
+    x = mb.resample(
+        x=x,
+        coordinates=warp,
+        sampling_mode="bilinear",
+        padding_mode="constant",
+        padding_value=0.0,
+        coordinates_mode="unnormalized",
+        align_corners=True,
+        name=node.name + "_resample",
+    )
+    x = _transpose_NCHW_to_NHWC(
+        x, node.name + "_transpose" if is_rank3_warp else node.name
+    )
+    if is_rank3_warp:  # squeeze spatial dimension
+        x = mb.squeeze(x=x, axes=[1], name=node.name)
+
+    context.add(node.name, x)
+
+
 @register_tf_op
 def Rsqrt(context, node):
     x = context[node.inputs[0]]
@@ -1797,7 +1874,11 @@ def Select(context, node):
     if rank_cond == 1 and rank_a > 1:
         axes = [-i - 1 for i in range(rank_a - rank_cond)]
         cond = mb.expand_dims(x=cond, axes=axes)
-
+
+    if not types.is_bool(cond.dtype):
+        # cond must be bool type
+        cond = mb.cast(x=cond, dtype="bool")
+
     x = mb.select(cond=cond, a=a, b=b, name=node.name)
     context.add(node.name, x)

@@ -2168,7 +2249,7 @@ def ResizeNearestNeighbor(context, node):
         raise ValueError(
             '"ResizeNearestNeighbor" op: the second input, which is the output size, must have 2 elements'
         )
-
+    Hout, Wout = None, None
     if context[node.inputs[1]].val is None:
         # for the dynamic input shape case,
         # context[node.inputs[1]] is a mul(x=input_shape, y=scaling_factor) op.
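For quick reference, the flag handling added in the hunk below reduces to this
mapping (a summary sketch, not code from the patch):

    # (align_corners, half_pixel_centers) -> chosen lowering
    #   (False, False): upsample_nearest_neighbor with the TF scaling factors
    #   (False, True):  resize_nearest_neighbor when Hout/Wout are known statically,
    #                   else upsample_nearest_neighbor as an approximation
    #   (True,  any):   NotImplementedError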
@@ -2186,16 +2267,48 @@ def ResizeNearestNeighbor(context, node):

     # first transpose from channel last to channel first format for coreml
     x = _transpose_NHWC_to_NCHW(x)
-    # add the upsample layer
-    x = mb.upsample_nearest_neighbor(
-        x=x,
-        scale_factor_height=scaling_factor_h,
-        scale_factor_width=scaling_factor_w,
-        name=node.name + "_channel_first_upsample",
-    )
+
+    align_corners = node.attr.get("align_corners", False)
+    half_pixel_centers = node.attr.get("half_pixel_centers", False)
+
+    # add either the resize or the upsample layer
+    if align_corners is False and half_pixel_centers is False:
+        x = mb.upsample_nearest_neighbor(
+            x=x,
+            scale_factor_height=scaling_factor_h,
+            scale_factor_width=scaling_factor_w,
+            name=node.name + "_channel_first_upsample",
+        )
+    elif align_corners is False and half_pixel_centers is True:
+        # if the output size can be determined at compile time,
+        # we call the core op resize_nearest_neighbor,
+        # otherwise we use upsample_nearest_neighbor for approximation.
+        # rdar://75204549 (resize_nearest_neighbor needs to support dynamic input shape)
+        if Hout is not None and Wout is not None:
+            x = mb.resize_nearest_neighbor(
+                x=x,
+                target_size_height=Hout,
+                target_size_width=Wout,
+                name=node.name + "_channel_first_resize",
+            )
+        else:
+            _logging.warning('Using upsample_nearest_neighbor to approximate resize_nearest_neighbor.')
+            x = mb.upsample_nearest_neighbor(
+                x=x,
+                scale_factor_height=scaling_factor_h,
+                scale_factor_width=scaling_factor_w,
+                name=node.name + "_channel_first_upsample",
+            )
+
+    else:
+        raise NotImplementedError(
+            "ResizeNearestNeighbor op with align_corners={} and half_pixel_centers={} not supported".format(
+                align_corners, half_pixel_centers
+            )
+        )
+
     # transpose again
     x = _transpose_NCHW_to_NHWC(x, node.name)
-
     context.add(node.name, x)

@@ -2485,7 +2598,7 @@ def Split(context, node):
     else:
         x = mb.split(x=x, num_splits=num_splits, axis=axis, name=node.name)
         context.add(node.name, x)
-    # TODO (rdar://60358242) If tf.split output is returned, there's no
+    # TODO: If tf.split output is returned, there's no
     # get_tuple nodes. Some graph pass is needed.
Example: # # x = tf.placeholder(tf.float32, shape=input_shape1) @@ -2522,7 +2635,7 @@ def ScatterNd(context, node): indices = context[node.inputs[0]] updates = context[node.inputs[1]] shape = context[node.inputs[2]] - x = mb.fill(shape=shape, value=0) + x = mb.fill(shape=shape, value=types.nptype_from_builtin(updates.dtype)(0)) x = mb.scatter_nd(data=x, indices=indices, updates=updates, name=node.name) context.add(node.name, x) @@ -2585,6 +2698,8 @@ def CropAndResize(context, node): box_indices = context[node.inputs[2]] boxes = context[node.inputs[1]] box_indices = mb.expand_dims(x=box_indices, axes=[1]) + if box_indices.dtype != boxes.dtype: + box_indices = mb.cast(x=box_indices, dtype=types.builtin_to_string(boxes.dtype)) boxes = mb.concat(values=(box_indices, boxes), axis=1) # TODO: Dynamic rank: Use GetShape and select indices dynamically boxes = mb.reshape(x=boxes, shape=[boxes.shape[0], 1, boxes.shape[1], 1, 1]) diff --git a/coremltools/converters/mil/frontend/tensorflow/ssa_passes/tf_lstm_to_core_lstm.py b/coremltools/converters/mil/frontend/tensorflow/ssa_passes/tf_lstm_to_core_lstm.py index 99cba0616..1228763f7 100644 --- a/coremltools/converters/mil/frontend/tensorflow/ssa_passes/tf_lstm_to_core_lstm.py +++ b/coremltools/converters/mil/frontend/tensorflow/ssa_passes/tf_lstm_to_core_lstm.py @@ -30,7 +30,7 @@ def tf_lstm_to_core_lstm(prog): - If tf_lstm_block_cell: only cs, h output (outputs[1], outputs[6]) are consumed. Similar to above. - - batch size == 1 (due to bugs in core lstm backend impl rdar://62475041) + - batch size == 1 Inputs: @@ -68,15 +68,9 @@ def try_replace_with_core_lstm(op): else: # tf_lstm_block batch = op.x.shape[1] - # Check for unsupported configuration - # 1. Peephole is present - # TODO: rdar://62913058 ([LSTM] Incorrect output when pass peephole values to LSTM/rnn_arch) + # Check for unsupported configuration : When peephole is present if op.use_peephole.val: return False - # 2. 
Clip is provided - # TODO: rdar://62913148 ([LSTM] Incorrect output when clip is used for LSTM/rnn_arch) - if op.cell_clip is not None: - return False # Check if tf_lstm_block_cell can be replaced with lstm op i, cs, f, o, ci, co, h = op.outputs diff --git a/coremltools/converters/mil/frontend/tensorflow/test/test_custom_ops.py b/coremltools/converters/mil/frontend/tensorflow/test/test_custom_ops.py index 37b836586..6955f1e64 100644 --- a/coremltools/converters/mil/frontend/tensorflow/test/test_custom_ops.py +++ b/coremltools/converters/mil/frontend/tensorflow/test/test_custom_ops.py @@ -187,7 +187,6 @@ def type_inference(self): ret_shape[axis] = k return types.tensor(x_type, ret_shape), types.tensor(types.int32, ret_shape) - # TODO: rdar://61241807 ([MIL] [Polish] Custom layer operator documentation) # Following logging is to ensure testing of TopK implemented in tf converter # default path is testing with appropriate conversion function # Log default tf topk diff --git a/coremltools/converters/mil/frontend/tensorflow/test/test_ops.py b/coremltools/converters/mil/frontend/tensorflow/test/test_ops.py index ad4a8d7dc..e5939641b 100644 --- a/coremltools/converters/mil/frontend/tensorflow/test/test_ops.py +++ b/coremltools/converters/mil/frontend/tensorflow/test/test_ops.py @@ -23,6 +23,46 @@ tf = pytest.importorskip("tensorflow") +class TestContribResampler(TensorFlowBaseTest): + @pytest.mark.parametrize( + "use_cpu_only, backend, data_warp_shapes", + itertools.product( + [True, False], + backends, + [ + # Data shape format: (Batch, Hin, Win, C) + # Warp shape format: (Batch, Hout, Wout, 2) + [(1, 3, 3, 1), (1, 3, 3, 2)], # no size change + [(2, 5, 5, 3), (2, 3, 3, 2)], # down-sampling + [(3, 6, 6, 1), (3, 8, 8, 2)], # up-sampling + [(1, 3, 9, 1), (1, 19, 2)], # rank-3 warp tensor + ], + ), + ) + def test( + self, use_cpu_only, backend, data_warp_shapes, + ): + if backend == "neuralnetwork": + pytest.xfail("nn backend not supported") + + data_shape, warp_shape = data_warp_shapes + + @make_tf_graph([data_shape, warp_shape]) + def build_model(x, warp): + return tf.contrib.resampler.resampler(data=x, warp=warp) + + model, inputs, outputs = build_model + # warp exceeding input sizes in order to test more padding modes + input_values = [ + random_gen(data_shape, -100, 100), + random_gen(warp_shape, -15, 15), + ] + input_dict = dict(zip(inputs, input_values)) + self.run_compare_tf( + model, input_dict, outputs, use_cpu_only=use_cpu_only, backend=backend, + ) + + class TestDebugging(TensorFlowBaseTest): """ TF converter does not handling debugging nodes, they are @@ -571,10 +611,6 @@ def test_where(self, use_cpu_only, backend, rank): ) -@pytest.mark.xfail( - condition=backends[0] == "mlprogram", - reason="Investigate failure rdar://78630549" -) class TestCast(TensorFlowBaseTest): @pytest.mark.parametrize('use_cpu_only, backend, rank, dtype', itertools.product( @@ -586,6 +622,9 @@ class TestCast(TensorFlowBaseTest): def test(self, use_cpu_only, backend, rank, dtype): shape = np.random.randint(low=1, high=3, size=rank) + if backend == "mlprogram" and use_cpu_only and dtype == "int32": + pytest.xfail("rdar://78630549") + @make_tf_graph([shape]) def build_model(x): y = tf.cast(x, dtype=dtype) @@ -1250,7 +1289,7 @@ class TestDepthwiseConv(TensorFlowBaseTest): [ (1, 1), (2, 2), - ], # rdar://60668562 (MIL: Conversion for TF op 'SpaceToBatchND' not implemented.) 
+ ], [True, False], [1, 3], ), @@ -1759,7 +1798,7 @@ def test_binary_compare(self, use_cpu_for_conversion, backend, rank, tf_op, use_cpu_only = use_cpu_for_conversion if backend == "mlprogram" and not use_cpu_for_conversion: - pytest.xfail("Error in building plan : MIL GPU backend failure. rdar://77442362") + pytest.xfail("Error in building plan : MIL GPU backend failure. rdar://78218824") x_shape = y_shape = list(np.random.randint(low=2, high=4, size=rank)) @@ -1818,7 +1857,7 @@ def test_binary_logical(self, use_cpu_for_conversion, backend, rank, tf_op, use_cpu_only = use_cpu_for_conversion if backend == "mlprogram" and not use_cpu_for_conversion: - pytest.xfail("Error in building plan : MIL GPU backend failure. rdar://77442362") + pytest.xfail("Error in building plan : MIL GPU backend failure. rdar://78218824") x_shape = y_shape = list(np.random.randint(low=2, high=4, size=rank)) @@ -1872,7 +1911,7 @@ class TestEinsum(TensorFlowBaseTest): ) def test(self, use_cpu_for_conversion, backend, equation, reverse_input_order): if backend == "mlprogram" and equation == "abcd,adce->abce" and not use_cpu_for_conversion: - pytest.xfail("Seg fault on loading MIL model on GPU context. rdar://77442588") + pytest.xfail("Seg fault on loading MIL model on GPU context. rdar://77443711") if equation == "abcd,adce->abce": input_shapes = [[3, 4, 2, 6], [3, 6, 2, 2]] @@ -3042,14 +3081,14 @@ def test_keras_random_uniform( ) -@pytest.mark.skipif(_macos_version() <= (10, 16), +@pytest.mark.skipif(_macos_version() < (10, 16), reason="This only works for 'neuralnetwork' on macOS 11") class TestReduction(TensorFlowBaseTest): @pytest.mark.parametrize( "use_cpu_for_conversion, backend, rank_and_axes, keep_dims, tf_op", itertools.product( - [False], - ["mlprogram"], + [True, False], + backends, [ (1, (-1,)), (2, (0,)), @@ -3088,8 +3127,9 @@ def test_reduction(self, use_cpu_for_conversion, backend, rank_and_axes, keep_di rank, axes = rank_and_axes shape = np.random.randint(low=1, high=4, size=rank) - if backend == 'mlprogram' and not use_cpu_for_conversion and rank_and_axes == (5, None): - pytest.xfail("Seg fault. rdar://77443572") + if backend == 'mlprogram' and not use_cpu_for_conversion: + if rank_and_axes == (5, None) or tf_op in {tf.reduce_logsumexp}: + pytest.xfail("Seg fault. 
rdar://77443572") def parse_axes(axes): if axes is None: @@ -3134,9 +3174,6 @@ def build_model(x): def test_tf_reduction(): if isinstance(axes, list) and axes and len(axes) == rank and not keep_dims: - return # TODO MIL: Add rank 0 and dim size 0 related tests for every op - - if tf_op in {tf.reduce_any, tf.reduce_all, tf.reduce_logsumexp}: # Remove constraint, rdar://66610973 return input_type = list(shape) @@ -3144,7 +3181,7 @@ def test_tf_reduction(): if tf_op in {tf.reduce_all, tf.reduce_any}: input_type += [tf.bool] x_val = np.random.randint(low=0, high=2, size=shape).astype( - np.float32 + np.bool ) elif tf_op in {tf.math.reduce_euclidean_norm}: x_val = random_gen(shape=shape, rand_min=0.0, rand_max=10.0) @@ -3180,7 +3217,6 @@ def build_model(x): test_tf_reduction() class TestGather(TensorFlowBaseTest): - # TODO: [MIL] Gather layer with 0-d indices leads to input shape mismatch @pytest.mark.parametrize( "use_cpu_only, backend, rankX_rankIndices_axis, mode", itertools.product( @@ -3304,7 +3340,7 @@ def test_scatter_nd_with_zeros( indices_shape[-1] = np.random.randint(low=1, high=data_rank + 1) updates_shape = list(indices_shape[:-1]) + list(shape[indices_shape[-1] :]) - updates = np.random.rand(*updates_shape).astype(np.float32) + updates = np.random.rand(*updates_shape).astype(np.int32) indices_list = [] for i in range(indices_shape[-1]): indices_list.append(np.random.randint(0, shape[i], size=indices_shape[:-1])) @@ -3312,7 +3348,7 @@ def test_scatter_nd_with_zeros( indices = np.stack(indices_list, axis=-1).astype(np.int32) @make_tf_graph( - [list(indices.shape) + [tf.int32], updates_shape, [data_rank, tf.int32]] + [list(indices.shape) + [tf.int32], updates_shape + [tf.int32], [data_rank, tf.int32]] ) def build_model(indices, updates, shape): return tf.raw_ops.ScatterNd(indices=indices, updates=updates, shape=shape) @@ -3658,8 +3694,6 @@ class TestSliceBySize(TensorFlowBaseTest): def test_slice_by_size( self, use_cpu_only, backend, rank, single_size, dynamic_size ): - if dynamic_size == False and backend != "neuralnetwork": - pytest.xfail("TODO: activate after rdar://75290346 is fixed and is in the build. Tracked by rdar://75823380") input_shape = np.random.randint(low=2, high=4, size=rank) begin_val = np.array( [np.random.randint(input_shape[i]) for i in range(rank)] @@ -3967,10 +4001,6 @@ def build_model(x): test_tf_static() test_tf_dynamic() -@pytest.mark.xfail( - condition=backends[0] == "mlprogram", - reason="Investigate failure rdar://78080118" -) class TestNonMaximumSuppression(TensorFlowBaseTest): @pytest.mark.parametrize( ",".join( @@ -4379,11 +4409,10 @@ class TestSplit(TensorFlowBaseTest): ) def test_split(self, use_cpu_for_conversion, backend, rank, dynamic): if backend == "mlprogram" and not use_cpu_for_conversion: - pytest.xfail("Seg fault. rdar://77444472") + pytest.xfail("Seg fault. rdar://77443711") input_shape1 = np.random.randint(low=1, high=3, size=rank) for axis in range(-rank, rank, 2): - # FIXME: skip split_num==1 due to: rdar://63030405. Rank 0 tensor for MIL for split_num in range(2, input_shape1[axis] + 1, 2): if input_shape1[axis] % split_num != 0: continue @@ -4395,7 +4424,7 @@ def test_split(self, use_cpu_for_conversion, backend, rank, dynamic): @make_tf_graph([tf_input_shape]) def build_model(x): res = tf.split(x, split_num, axis=axis) - # TODO (rdar://60358242) If tf.split output is returned, there's no + # Comment: If tf.split output is returned, there's no # get_tuple nodes. Some graph pass is needed. 
Example: # # x = tf.placeholder(tf.float32, shape=input_shape1) @@ -4816,7 +4845,6 @@ class TestMatrixDiag(TensorFlowBaseTest): def test(self, use_cpu_only, backend, length, dynamic): if dynamic: - return # FIXME: "rdar://65198011 (Re-enable Conv3dTranspose and DynamicTile unit tests)" input_shape = np.random.randint(low=1, high=4, size=length) a, b = np.prod(input_shape[:2]), np.prod(input_shape[2:]) size = np.array([a,b]).astype(np.int32) @@ -5069,10 +5097,6 @@ def build_model(x, tf_perm): with pytest.raises(ValueError, match=r".*must be const at compile time.*"): dynamic_perm() - @pytest.mark.xfail( - reason="The reduce_transpose graph pass fails on a model with sequence of transpose: ", - run=False, - ) @pytest.mark.parametrize( "use_cpu_only, backend, rank", itertools.product( @@ -5082,7 +5106,11 @@ def build_model(x, tf_perm): ), ) def test_redundant_transpose(self, use_cpu_only, backend, rank): + if rank == 2: + pytest.xfail("shape mismatch on CPU, numerical mismatch on GPU, when rank is 2. rdar://79741438") + import random + random.seed(10) input_shape = np.random.randint(low=1, high=4, size=rank) num_layers = 30 perms = [] @@ -5331,9 +5359,6 @@ def build_model(x): "use_cpu_only, backend", itertools.product([True, False], backends) ) def test_tf_dynamic_elem_shape(self, use_cpu_only, backend): - # Support dynamic elem_shape - if backend != "neuralnetwork": - return # TF1: TensorArrayV3, TensorArrayWriteV3, TensorArrayScatterV3, # TensorArraySizeV3, TensorArrayGatherV3 @@ -5539,7 +5564,7 @@ def test_tf_lstm_block_cell(self, use_cpu_only, backend, batch): class TestVariable(TensorFlowBaseTest): - @pytest.mark.xfail(reason="Investigate get_global ", run=False) + @pytest.mark.xfail(reason="Investigate get_global ", run=False) @pytest.mark.parametrize( "use_cpu_only, backend", itertools.product([True], backends,) ) @@ -5609,15 +5634,18 @@ def build_model(x): class TestIsFinite(TensorFlowBaseTest): @pytest.mark.parametrize( - "use_cpu_only, backend, rank, dynamic", + "use_cpu_for_conversion, backend, rank, dynamic", itertools.product( [True, False], backends, [rank for rank in range(5)], [True, False] ), ) - def test(self, use_cpu_only, backend, rank, dynamic): + def test(self, use_cpu_for_conversion, backend, rank, dynamic): if rank == 0: pytest.xfail('Rank 0 not supported by CoreML runtime') + if backend == "mlprogram" and not use_cpu_for_conversion: + pytest.xfail("rdar://78343191") + def _generate_num_with_inf(input_shape): res = random_gen(input_shape, rand_min=-1, rand_max=1) random_map = np.random.choice([np.inf, -np.inf, 0], size=input_shape) @@ -5663,9 +5691,10 @@ def build_model(x): model, input_dict, outputs, - use_cpu_only=use_cpu_only, frontend_only=False, backend=backend, + use_cpu_for_conversion=use_cpu_for_conversion, + use_cpu_only=use_cpu_for_conversion, ) class TestLogSoftMax(TensorFlowBaseTest): diff --git a/coremltools/converters/mil/frontend/tensorflow/tfssa.py b/coremltools/converters/mil/frontend/tensorflow/tfssa.py index 393c81063..004f42c02 100644 --- a/coremltools/converters/mil/frontend/tensorflow/tfssa.py +++ b/coremltools/converters/mil/frontend/tensorflow/tfssa.py @@ -109,8 +109,6 @@ def find_inputs_and_outputs(self): # we use function entry and exit points if available # otherwise we find graph entry and exit points - # TODO: op name should be fixed here. 
- # Remove wrappers that are used for old tfssa enters = [ n.name for n in self.graph.values() if ("entry" in n.op or "Entry" in n.op) ] diff --git a/coremltools/converters/mil/frontend/tensorflow2/test/test_v2_ops.py b/coremltools/converters/mil/frontend/tensorflow2/test/test_v2_ops.py index 8699833c4..61d711668 100644 --- a/coremltools/converters/mil/frontend/tensorflow2/test/test_v2_ops.py +++ b/coremltools/converters/mil/frontend/tensorflow2/test/test_v2_ops.py @@ -5,6 +5,7 @@ import itertools import numpy as np +from coremltools._deps import version_lt from coremltools.converters.mil import testing_reqs from coremltools.converters.mil.frontend.tensorflow.test import ( testing_utils as tf_testing_utils, @@ -114,6 +115,243 @@ del TestWhileLoop.test_nested_while_body # tf.function() error in TF2 +class TestImageResample(TensorFlow2BaseTest): + @pytest.mark.skipif(condition=version_lt(tf, "2.4"), + reason="tfa.image.resample requires TF 2.4+") + @pytest.mark.parametrize( + "use_cpu_only, backend, data_warp_shapes", + itertools.product( + [True, False], + backends, + [ + # Data shape format: (Batch, Hin, Win, C) + # Warp shape format: (Batch, Hout, Wout, 2) + [(1, 3, 3, 1), (1, 3, 3, 2)], # no size change + [(2, 5, 5, 3), (2, 3, 3, 2)], # down-sampling + [(3, 6, 6, 1), (3, 8, 8, 2)], # up-sampling + ], + ), + ) + def test_resample( + self, use_cpu_only, backend, data_warp_shapes, + ): + if backend == "neuralnetwork": + pytest.xfail("nn backend not supported") + + tfa = pytest.importorskip("tensorflow_addons") + + data_shape, warp_shape = data_warp_shapes + + @make_tf_graph([data_shape, warp_shape]) + def build_model(x, warp): + return tfa.image.resampler(data=x, warp=warp) + + model, inputs, outputs = build_model + # warp exceeding input sizes in order to test more padding modes + input_values = [ + random_gen(data_shape, -100, 100), + random_gen(warp_shape, -15, 15), + ] + input_dict = dict(zip(inputs, input_values)) + self.run_compare_tf2( + model, input_dict, outputs, use_cpu_only=use_cpu_only, backend=backend, + ) + + +class TestImageTransform(TensorFlow2BaseTest): + @pytest.mark.skip( + "TODO: rdar://73165549 (Add other mode in 'affine' to coremltools when backend is ready)" + ) + @pytest.mark.parametrize( + "use_cpu_only, backend, transforms, interpolation, shapes", + itertools.product( + [True], + backends, + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, -250, 0.0, 1.0, 0.0, 0.0, 0.0], + [1.25, -1.75, 25.0, -25.0, 1.5, -1.5, 0.0, 0.0], + ], + ["BILINEAR"], + [ + ((1, 2, 2, 1), None), + ((2, 2, 2, 1), (2, 3)), + ((3, 5, 5, 2), (4, 4)), + ((1, 3, 3, 2), (6, 6)), + ((3, 50, 50, 2), (20, 20)), + ], + ), + ) + def test(self, use_cpu_only, backend, transforms, interpolation, shapes): + x_shape, output_shape = shapes + if backend == "neuralnetwork": + pytest.xfail("nn backend not supported") + + tfa = pytest.importorskip("tensorflow_addons") + + @make_tf_graph([x_shape]) + def build_model(x): + return tfa.image.transform( + x, + transforms=transforms, + interpolation=interpolation, + output_shape=output_shape, + ) + + model, inputs, outputs = build_model + input_values = [ + random_gen(x_shape, -100, 100), + ] + input_dict = dict(zip(inputs, input_values)) + self.run_compare_tf2( + model, input_dict, outputs, use_cpu_only=use_cpu_only, backend=backend, + ) + + +class TestActivationSiLU(TensorFlow2BaseTest): + @pytest.mark.parametrize( + "use_cpu_only, backend, rank, tf_op", + itertools.product( + [True, False], + backends, + list(range(1, 6)), + [ + tf.nn.swish, # TODO(yuduo): 
in TF 2.4.0+, it's renamed to tf.nn.silu, + tf.keras.activations.swish, + ], + ), + ) + def test(self, use_cpu_only, backend, rank, tf_op): + if backend == "neuralnetwork": + pytest.xfail("nn backend not supported") + + x_shape = tuple(np.random.randint(low=1, high=4, size=rank)) + + @make_tf_graph([x_shape]) + def build_model(x): + return tf_op(x) + + model, inputs, outputs = build_model + input_values = [ + random_gen(x_shape, -100, 100), + ] + input_dict = dict(zip(inputs, input_values)) + self.run_compare_tf2( + model, input_dict, outputs, use_cpu_only=use_cpu_only, backend=backend, + ) + + +class TestResizeNearestNeighbor(TensorFlow2BaseTest): + @pytest.mark.parametrize( + "use_cpu_only, backend, input_shape, target_shape, align_corners, half_pixel_centers", + itertools.product( + [True, False], + backends, + [(1, 10, 20, 1), (2, 5, 1, 3)], + [(25, 30), (2, 20)], + [False], + [True, False], + ), + ) + def test_raw_ops( + self, + use_cpu_only, + backend, + input_shape, + target_shape, + align_corners, + half_pixel_centers, + ): + if align_corners is True and half_pixel_centers is True: + return + + if backend == "neuralnetwork": + # neural network backend does not support fractional scale factors for nearest neighbor upsample op + if target_shape[-1] % input_shape[-1] != 0: + return + if target_shape[-2] % input_shape[-2] != 0: + return + + if not use_cpu_only and not half_pixel_centers and backend == "mlprogram": + # use_cpu_only == False & half_pixel_centers == False, & backend == mlprogram + # then there are numerical errors + pytest.xfail("rdar://78321005") + + + @make_tf_graph([input_shape]) + def build_model(x): + return tf.raw_ops.ResizeNearestNeighbor( + images=x, + size=target_shape, + align_corners=align_corners, + half_pixel_centers=half_pixel_centers, + ) + + model, inputs, outputs = build_model + input_values = [random_gen(input_shape, -100, 100)] + input_dict = dict(zip(inputs, input_values)) + self.run_compare_tf2( + model, input_dict, outputs, use_cpu_only=use_cpu_only, backend=backend, + ) + + @pytest.mark.parametrize( + "use_cpu_only, backend, size", + itertools.product([True, False], backends, [(1, 1), (2, 3), (4, 1)]), + ) + def test_keras_layer(self, use_cpu_only, backend, size): + if backend == "neuralnetwork": + pytest.xfail("nn backend not supported") + + x_shape = tuple(np.random.randint(low=1, high=4, size=4)) + + @make_tf_graph([x_shape]) + def build_model(x): + return tf.keras.layers.UpSampling2D( + size=size, interpolation="nearest", + )(x) + + model, inputs, outputs = build_model + input_values = [random_gen(x_shape, -100, 100)] + input_dict = dict(zip(inputs, input_values)) + self.run_compare_tf2( + model, input_dict, outputs, use_cpu_only=use_cpu_only, backend=backend, + ) + + @pytest.mark.parametrize( + "use_cpu_only, backend, size, method", + itertools.product( + [True, False], + backends, + [(1, 1), (2, 3)], + [tf.image.ResizeMethod.NEAREST_NEIGHBOR], + ), + ) + def test_tf_image_resize(self, use_cpu_only, backend, size, method): + if backend == "mlprogram" and not use_cpu_only: + pytest.xfail("rdar://78343225 ((MIL GPU) Core ML Tools Unit Test failures [numerical error])") + + if backend == "mlprogram" and size == (1, 1): + pytest.xfail("rdar://79699954 (Nearest neighbor resize numerical mismatch when output size is (1,1))") + + if backend == "neuralnetwork": + pytest.xfail("nn backend not supported") + + x_shape = tuple(np.random.randint(low=1, high=3, size=4)) + + @make_tf_graph([x_shape]) + def build_model(x): + return tf.image.resize(x, 
size=size, method=method) + + model, inputs, outputs = build_model + input_values = [ + random_gen(x_shape, -100, 100), + ] + input_dict = dict(zip(inputs, input_values)) + self.run_compare_tf2( + model, input_dict, outputs, use_cpu_only=use_cpu_only, backend=backend, + ) + + class TestNormalizationTF2(TensorFlowBaseTest): @pytest.mark.parametrize( "use_cpu_only, func, backend, epsilon", diff --git a/coremltools/converters/mil/frontend/tensorflow2/test/test_v2_ops_tf_keras.py b/coremltools/converters/mil/frontend/tensorflow2/test/test_v2_ops_tf_keras.py index 9c7603d98..453d817ff 100644 --- a/coremltools/converters/mil/frontend/tensorflow2/test/test_v2_ops_tf_keras.py +++ b/coremltools/converters/mil/frontend/tensorflow2/test/test_v2_ops_tf_keras.py @@ -198,8 +198,8 @@ class TestConvolution(TensorFlowBaseTest): [(2, 4, 4, 2, 2, 2), (3, 7, 5, 1, 3, 2)], [(1, 1, 1), (1, 2, 3), (1, 3, 2)], [ - (1, 1, 1) - ], # rdar://62951360 (Enhance SpaceToBatchND op to support more dialation rate of Conv) + (1, 1, 1), (2, 2, 2), + ], [1, 3], [1, 2], ), @@ -221,6 +221,16 @@ def test_conv( if _get_version(_tf.__version__) < _StrictVersion("2.5.0") and groups != 1: return + # TF does not support strides > 1 in conjunction with dilation_rate > 1 + for i, stride in enumerate(strides): + if stride > 1 and dilations[i] > 1: + return + + # Dilations with Conv3D not supported yet, since SpaceToBatchND is only supported for ranks 3 or 4 + for d in dilations: + if d > 1 and op == tf.keras.layers.Conv3D: + return + s1, s2, s3, k1, k2, k3 = spatial_dim_and_ks c_in, c_out = 2, 4 input_shape = None @@ -277,7 +287,7 @@ def test_conv( ), itertools.product( [True, False], - ["neuralnetwork"], # rdar://66998312 ([MIL] concat layer with variable length input support) + backends, [ tf.keras.layers.LocallyConnected1D, tf.keras.layers.LocallyConnected2D, @@ -287,8 +297,8 @@ def test_conv( [(2, 4, 4, 2, 2, 2), (3, 7, 5, 1, 3, 2)], [(1, 1, 1), (1, 2, 3), (1, 3, 2)], [ - (1, 1, 1) - ], # rdar://62951360 (Enhance SpaceToBatchND op to support more dialation rate of Conv) + (1, 1, 1), (2, 2, 2), + ], [1, 3], ), ) @@ -1225,9 +1235,6 @@ def test_lstm_time_distributed_dense(self, use_cpu_only, backend): "use_cpu_only, backend", itertools.product([True, False], backends) ) def test_lstm_dynamic_batch(self, use_cpu_only, backend): - # Support dynamic elem_shape - if backend != "neuralnetwork": - return input_shape = (1, 1280) inp = tf.keras.layers.Input(shape=input_shape) h0 = tf.keras.layers.Input(shape=(512,)) diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py index 3c018eb0d..bb7b7c9d2 100644 --- a/coremltools/converters/mil/frontend/torch/ops.py +++ b/coremltools/converters/mil/frontend/torch/ops.py @@ -216,6 +216,111 @@ def _construct_constant(val, name): return mb.const(val=val, name=name) +@register_torch_op +def affine_grid_generator(context, node): + # rdar://73165386 (Improve error handling of coremltools "affine" op PyTorch conversion.) 
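+    # For orientation: PyTorch always emits `affine_grid` paired with a
+    # `grid_sampler` consumer (e.g. grid = F.affine_grid(theta, size);
+    # out = F.grid_sample(x, grid)), so this function only stashes consts
+    # for `grid_sampler` to pick up later; no compute op is emitted here.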
+
+    affine_op_name = node.name
+    theta, size, align_corners = _get_inputs(context, node, expected=3)
+
+    # note: only add consts here as PyTorch uses affine_grid + grid_sampler together
+    is_theta_const = theta.val is not None
+    if is_theta_const:
+        context.add(mb.const(val=theta.val, name="{}_theta".format(affine_op_name)))
+    else:  # theta is a dynamic input, keep track of its name
+        context.add(mb.const(val=theta.name, name="{}_theta".format(affine_op_name)))
+
+    context.add(mb.const(val=size.val, name="{}_size".format(affine_op_name)))
+    context.add(mb.const(val=align_corners.val, name="{}_align_corners".format(affine_op_name)))
+
+
+@register_torch_op
+def grid_sampler(context, node):
+    affine_op_name = node.inputs[1]
+    # https://github.com/pytorch/pytorch/blob/00d432a1ed179eff52a9d86a0630f623bf20a37a/aten/src/ATen/native/GridSampler.h#L10-L11
+    m_mode = {0: "bilinear", 1: "nearest"}
+    m_padding_mode = {0: "constant", 1: "border", 2: "reflection"}
+
+    # add `resample` if the grid/coordinates input is present; otherwise,
+    # add `affine` to generate the grid from `affine_grid_generator`.
+    if affine_op_name in context:  # add `resample` op
+        inputs = _get_inputs(context, node, expected=5)
+        sampling_mode = m_mode[inputs[2].val]
+        padding_mode = m_padding_mode[inputs[3].val]
+        align_corners = inputs[4].val
+
+        # When align_corners=False, the "reflection" padding_mode corresponds to Core ML's "symmetric"
+        if padding_mode == "reflection" and align_corners is False:
+            padding_mode = "symmetric"
+
+        x = mb.resample(
+            x=inputs[0],
+            coordinates=inputs[1],
+            sampling_mode=sampling_mode,
+            padding_mode=padding_mode,
+            padding_value=0.0,
+            coordinates_mode="normalized_minus_one_to_one",
+            align_corners=align_corners,
+            name=node.name,
+        )
+        context.add(x)
+    else:  # add `affine` op instead
+        x = context[node.inputs[0]]
+        # inputs from `affine_grid_generator`
+        affine_theta = context["{}_theta".format(affine_op_name)]
+        affine_size = context["{}_size".format(affine_op_name)]
+        affine_align_corners = context["{}_align_corners".format(affine_op_name)]
+
+        # affine_theta.val is either a name string (dynamic input) or an np.ndarray (static values);
+        # see `affine_grid_generator` for details.
+        is_theta_const = not isinstance(affine_theta.val, str)
+        if is_theta_const:
+            transform_matrix = _np.reshape(affine_theta.val, (affine_theta.shape[0], 6))
+        else:  # theta is a dynamic input, add a `reshape` op to PyMIL
+            transform_matrix = mb.reshape(
+                x=context[affine_theta.val],
+                shape=(-1, 6),
+                name=node.name + "_theta_reshape",
+            )
+
+        # inputs from `grid_sampler`
+        sampling_mode = m_mode[context[node.inputs[2]].val]
+        padding_mode = m_padding_mode[context[node.inputs[3]].val]
+        align_corners = context[node.inputs[4]].val
+
+        if sampling_mode != "bilinear":
+            raise NotImplementedError("'sampling_mode' not supported.")
+
+        if padding_mode != "constant":
+            raise NotImplementedError("'padding_mode' not supported.")
+
+        if affine_align_corners.val != align_corners:
+            raise ValueError(
+                "Op 'affine_grid_generator' and 'grid_sampler' must agree on 'align_corners'."
+ ) + + x = mb.affine( + x=x, + transform_matrix=transform_matrix, + output_height=affine_size.val[2], + output_width=affine_size.val[3], + sampling_mode=sampling_mode, + padding_mode=padding_mode, + padding_value=0.0, + coordinates_mode="normalized_minus_one_to_one", + align_corners=align_corners, + name=node.name, + ) + context.add(x) + + +@register_torch_op +def silu(context, node): + inputs = _get_inputs(context, node, expected=1) + x = mb.silu(x=inputs[0], name=node.name) + context.add(x) + + @register_torch_op def constant(context, node): assert len(node.inputs) == 0 @@ -380,10 +485,8 @@ def addmm(context, node): # addmm(Tensor input, Tensor mat1, Tensor mat2, Scalar beta=1, Scalar alpha=1) # output = beta * input + alpha * mat1 * mat2 - assert len(node.inputs) == 5 assert len(node.outputs) == 1 - - inputs = [context[name] for name in node.inputs] + inputs = _get_inputs(context, node, expected=5) bias = inputs[0] mat1 = inputs[1] mat2 = inputs[2] @@ -408,6 +511,14 @@ def addmm(context, node): addmm_node = mb.linear(x=mat1, weight=mat2, bias=bias, name=node.name) context.add(addmm_node) +@register_torch_op +def linear(context, node): + inputs = _get_inputs(context, node, expected=[2, 3]) + x = inputs[0] + W = inputs[1] + bias = inputs[2] if len(node.inputs) == 3 else None + res = mb.linear(x=x, weight=W, bias=bias, name=node.name) + context.add(res) @register_torch_op(torch_alias=["conv2d"]) def _convolution(context, node): @@ -486,8 +597,6 @@ def _convolution(context, node): # # For ConvTranspose2d: [bottom, right] -> [0, b, 0, r] output_padding = [0 if i % 2 == 0 else out_pad[i//2] for i in range(len(pad))] - # TODO: rdar://65588783 ([PyTorch] Define and error out on unsupported configuration for output_padding) - # error out here with unsupported configuration along with output padding if sum(pad) == 0 and any(output_padding): raise ValueError("ConvTranspose configuration of padding=0 and output_padding > 0 not supported!") post_crop = pad.copy() @@ -764,9 +873,30 @@ def maximum(context, node): @register_torch_op def div(context, node): - inputs = _get_inputs(context, node, expected=2) - - res = mb.real_div(x=inputs[0], y=inputs[1], name=node.name) + inputs = _get_inputs(context, node, expected=[2,3]) + + if len(inputs) > 2 and inputs[2] is not None: + rounding_mode = inputs[2].val + if rounding_mode == "floor": + # round towards negative infinity + # e.g.: + # values before floor: [2.6, -3.4, -3.6] + # values after floor: [2, -4, -4] + res = mb.floor_div(x=inputs[0], y=inputs[1], name=node.name) + elif rounding_mode == "trunc": + # round towards 0 + # e.g.: + # values before trunc: [2.6, -3.4, -3.6] + # values after trunc: [2, -3, -3] + z = mb.real_div(x=inputs[0], y=inputs[1]) + s = mb.sign(x=z) + all_positive = mb.mul(x=z, y=s) + all_positive_floor = mb.floor(x=all_positive) + res = mb.mul(x=all_positive_floor, y=s, name=node.name) + else: + raise NotImplementedError("rounding mode \"{}\" not supported in the \"div\" op".format(rounding_mode)) + else: + res = mb.real_div(x=inputs[0], y=inputs[1], name=node.name) context.add(res) @@ -1194,7 +1324,7 @@ def _add_batch_norm_3d(): is_batch_norm_2d = (input_rank == 3 or input_rank == 4) is_batch_norm_3d = input_rank == 5 - if training or running_mean.val is None or running_var.val is None: + if training or running_mean.val is None or running_var.val is None or weight is None or bias is None: _add_batch_norm_dynamic() elif is_batch_norm_1d: _add_batch_norm_1d() @@ -2037,7 +2167,7 @@ def lstm(context, node): out1, out2 = 
mb.split(x=out_state_tensors_list[i], num_splits=2, axis=1) # each output of shape [B, H] after the split out = mb.stack(values=[out1, out2], axis=0) # [2, B, H] list_of_tensors_to_stack.append(out) - final_out = mb.concat(values=list_of_tensors_to_stack, axis=0) # output of shape (num_layers * 2, B, H) + final_out = mb.concat(values=list_of_tensors_to_stack, axis=0, name=name) # output of shape (num_layers * 2, B, H) context.add(final_out, name) else: if num_layers == 1: @@ -2672,7 +2802,7 @@ def log_softmax(context, node): res = mb.log(x=res, name=node.name) context.add(res) -@register_torch_op +@register_torch_op(torch_alias=["nll_loss_nd"]) def nll_loss(context, node): inputs = _get_inputs(context, node, expected=5) @@ -2999,12 +3129,16 @@ def masked_fill(context, node): value = inputs[2] # @mb.select does not properly broadcast scalar input, so as a workaround # we create a full sized tensor. - # rdar://61463562 if types.is_int(value.dtype): # @mb.fill cannot handle value with dtype integer # so we cast the value. value = mb.cast(x=value, dtype="fp32") + + if not types.is_bool(mask.dtype): + # cond must be bool type + mask = mb.cast(x=mask, dtype="bool") + shape = mb.shape(x=x, name=node.name + "_shape") value = mb.fill(shape=shape, value=value, name=node.name + "_value") res = mb.select(cond=mask, a=value, b=x, name=node.name) @@ -3384,7 +3518,12 @@ def is_floating_point(context, node): @register_torch_op def where(context, node): inputs = _get_inputs(context, node, expected=3) - context.add(mb.select(cond=inputs[0], a=inputs[1], b=inputs[2], name=node.name)) + cond = inputs[0] + if not types.is_bool(cond.dtype): + # cond must be bool type + cond = mb.cast(x=cond, dtype="bool") + + context.add(mb.select(cond=cond, a=inputs[1], b=inputs[2], name=node.name)) @register_torch_op def neg(context, node): diff --git a/coremltools/converters/mil/frontend/torch/test/test_internal_graph.py b/coremltools/converters/mil/frontend/torch/test/test_internal_graph.py index a8280e2e0..caf69af07 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_internal_graph.py +++ b/coremltools/converters/mil/frontend/torch/test/test_internal_graph.py @@ -1271,6 +1271,11 @@ def test_avg_pool1d( ): if pad > kernel_size / 2: return + + if ceil_mode: + if kernel_size == 1 and stride == 2 and pad == 0 and input_shape[-1] == 10: + pytest.xfail("Torch ceil_mode does not match exactly with CoreML's ceil_mode. rdar://80050546") + test_input = torch.rand(input_shape) expected_result = F.avg_pool1d( test_input, @@ -1305,6 +1310,11 @@ def test_avg_pool2d( ): if pad > kernel_size / 2: return + + if ceil_mode: + if kernel_size == 1 and stride == 2 and pad == 0 and input_shape[-1] == 10: + pytest.xfail("Torch ceil_mode does not match exactly with CoreML's ceil_mode. rdar://80050546") + test_input = torch.rand(input_shape) expected_result = F.avg_pool2d( test_input, @@ -1341,6 +1351,11 @@ def test_max_pool1d( ): if pad > kernel_size / 2: return + + if ceil_mode: + if kernel_size == 1 and stride == 2 and pad == 0 and input_shape[-1] == 10: + pytest.xfail("Torch ceil_mode does not match exactly with CoreML's ceil_mode. rdar://80050546") + test_input = torch.rand(input_shape) expected_result = F.max_pool1d( test_input, @@ -1373,6 +1388,11 @@ def test_max_pool2d( ): if pad > kernel_size / 2: return + + if ceil_mode: + if kernel_size == 1 and stride == 2 and pad == 0 and input_shape[-1] == 10: + pytest.xfail("Torch ceil_mode does not match exactly with CoreML's ceil_mode. 
rdar://80050546") + test_input = torch.rand(input_shape) expected_result = F.max_pool2d( test_input, diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py index c56325626..73c75b2b6 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py @@ -13,7 +13,7 @@ from coremltools.converters.mil.testing_reqs import * from .testing_utils import * from coremltools import TensorType - +from coremltools._deps import version_lt pytestmark = pytest.mark.skipif( sys.version_info >= (3, 8), reason="Segfault with Python 3.8+" @@ -31,6 +31,100 @@ COMMON_SHAPES = [(1, 10), (1, 5, 6), (1, 3, 5, 6), (1, 3, 4, 5, 6)] COMMON_SHAPES_ALL = [(1, )] + COMMON_SHAPES + +class TestAffineGrid(TorchBaseTest): + @pytest.mark.parametrize( + "backend, x_shape_and_target_size, " + "sampling_mode, padding_mode, align_corners", + itertools.product( + backends, + [ + # shape format: (Batch, Channel, Height, Width) + [(1, 1, 3, 3), (1, 1, 3, 3)], # no size change + [(2, 3, 5, 5), (2, 3, 3, 2)], # down-sampling + [(3, 1, 6, 6), (3, 1, 8, 8)], # up-sampling + ], + ["bilinear"], + ["zeros"], + [True], + ), + ) + def test( + self, + backend, + x_shape_and_target_size, + sampling_mode, + padding_mode, + align_corners, + ): + if backend == "neuralnetwork": + pytest.xfail("nn backend not supported") + + x_shape, target_size = x_shape_and_target_size + theta = torch.rand((x_shape[0], 2, 3)) + + class TestModule(torch.nn.Module): + def __init__(self): + super(TestModule, self).__init__() + self.affine_grid = torch.nn.functional.affine_grid + self.grid_sample = torch.nn.functional.grid_sample + + def forward(self, x): + grid = self.affine_grid( + theta=theta, size=target_size, align_corners=align_corners, + ) + x = self.grid_sample( + x, + grid=grid, + mode=sampling_mode, + padding_mode=padding_mode, + align_corners=align_corners, + ) + return x + + model = TestModule() + self.run_compare_torch(x_shape, model, backend=backend) + + +class TestGridSample(TorchBaseTest): + @pytest.mark.parametrize( + "backend, data_grid_shapes, mode, padding_mode, align_corners", + itertools.product( + backends, + [ + # Input shape format: (Batch, C, Hin, Win) + # Grid shape format: (Batch, Hout, Wout, 2) + [(1, 1, 3, 3), (1, 3, 3, 2)], # no size change + [(2, 3, 5, 5), (2, 3, 3, 2)], # down-sampling + [(3, 1, 6, 6), (3, 8, 8, 2)], # up-sampling + ], + ["bilinear", "nearest"], + ["zeros", "border", "reflection"], + [True, False], + ), + ) + def test( + self, + backend, + data_grid_shapes, + mode, + padding_mode, + align_corners, + ): + if backend == "neuralnetwork": + pytest.xfail("nn backend not supported") + + params = { + "mode": mode, + "padding_mode": padding_mode, + "align_corners": align_corners, + } + model = ModuleWrapper( + function=torch.nn.functional.grid_sample, kwargs=params + ) + self.run_compare_torch(data_grid_shapes, model, backend=backend) + + class TestNLLLoss(TorchBaseTest): @pytest.mark.parametrize( "reduction, backend", @@ -84,24 +178,42 @@ def test_argsort(self, shape, axis, descending, backend): class TestBatchNorm(TorchBaseTest): @pytest.mark.parametrize( - "num_features, eps, backend", - itertools.product([5, 3, 1], [0.1, 1e-05], backends), + "num_features, eps, affine, backend", + itertools.product([5, 3, 1], [0.1, 1e-05], [True, False], backends), ) - def test_batchnorm(self, num_features, eps, backend): - model = nn.BatchNorm2d(num_features, eps) + 
def test_batchnorm(self, num_features, eps, affine, backend): + model = nn.BatchNorm2d(num_features, eps, affine=affine) self.run_compare_torch((6, num_features, 5, 5), model, backend=backend) @pytest.mark.parametrize( - "num_features, eps, dynamic_input, backend", - itertools.product([5, 1], [0.1, 1e-05], ["None", "Batch", "Height", "Width", "Depth", "All"], backends), + "affine, backend", + itertools.product([True, False], backends), ) - def test_batchnorm_3d(self, num_features, eps, dynamic_input, backend): - if backend != "neuralnetwork" and dynamic_input == "All" and num_features == 5: + def test_batchnorm_2d_with_conv(self, affine, backend): + class CRNNBase(nn.Module): + def __init__(self, ch_in, ch_out, kernel_size=3): + super(CRNNBase, self).__init__() + self.conv = nn.Conv2d(ch_in, ch_out, kernel_size=kernel_size) + self.norm = nn.BatchNorm2d(ch_out, affine=affine) + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + model = CRNNBase(ch_in=6, ch_out=16) + self.run_compare_torch((1, 6, 15, 30), model, backend=backend) + + @pytest.mark.parametrize( + "num_features, eps, affine, dynamic_input, backend", + itertools.product([5, 1], [0.1, 1e-05], [True, False], ["None", "Batch", "Height", "Width", "Depth", "All"], backends), + ) + def test_batchnorm_3d(self, num_features, eps, affine, dynamic_input, backend): + if backend != "neuralnetwork" and num_features == 5 and (dynamic_input == "All" or affine == False): pytest.xfail("rdar://75770475 ([ActivateMIL] Failure in " "test_ops_public_torch.py::TestBatchNorm::test_batchnorm_3d " "[elementwise_kernel_cpu: Cannot broadcast])") - model = nn.BatchNorm3d(num_features, eps) + model = nn.BatchNorm3d(num_features, eps, affine=affine) input_shape = (6, num_features, 2, 3, 4) if dynamic_input == "None": self.run_compare_torch( @@ -160,13 +272,16 @@ def test_batchnorm_dynamic(self, rank, num_features, eps, training, backend): inputs, model, expected_results, input_as_shape=False, backend=backend, ) - @pytest.mark.parametrize("backend", backends) - def test_batchnorm_1d(self, backend): + @pytest.mark.parametrize( + "affine, backend", + itertools.product([True, False], backends), + ) + def test_batchnorm_1d_with_conv(self, affine, backend): class CRNNBase(nn.Module): - def __init__(self, ch_in, ch_out, kernel_size=3, use_bn=True): + def __init__(self, ch_in, ch_out, kernel_size=3): super(CRNNBase, self).__init__() self.conv = nn.Conv1d(ch_in, ch_out, kernel_size=kernel_size) - self.norm = nn.BatchNorm1d(ch_out) + self.norm = nn.BatchNorm1d(ch_out, affine=affine) def forward(self, x): x = self.conv(x) @@ -176,23 +291,23 @@ def forward(self, x): self.run_compare_torch((1, 6, 15), model, backend=backend) @pytest.mark.parametrize( - "shape, eps, backend", - itertools.product([(1, 10), (4, 6), (10, 1)], [0.1, 1e-05], backends), + "shape, eps, affine, backend", + itertools.product([(1, 10), (4, 6), (10, 1)], [0.1, 1e-05], [True, False], backends), ) - def test_batchnorm1d_rank2(self, shape, eps, backend): + def test_batchnorm1d_rank2(self, shape, eps, affine, backend): N,C = shape - batchnorm = nn.BatchNorm1d(C, eps=eps).eval() + batchnorm = nn.BatchNorm1d(C, eps=eps, affine=affine).eval() self.run_compare_torch( (N, C), batchnorm, backend=backend, ) @pytest.mark.parametrize( - "shape, eps, backend", - itertools.product([(4, 8, 2), (1, 5, 3), (5, 10, 1), (6, 1, 4)], [0.1, 1e-05], backends), + "shape, eps, affine, backend", + itertools.product([(4, 8, 2), (1, 5, 3), (5, 10, 1), (6, 1, 4)], [0.1, 1e-05], [True, False], backends), ) 
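For reference, the new affine=False cases in the tests above and below exercise the torch frontend change earlier in this diff, where a batch_norm whose weight and bias are None now routes to the dynamic decomposition. A minimal conversion sketch (shapes and names are illustrative, not taken from the test suite):

    import torch
    import coremltools as ct

    # affine=False means the BatchNorm module has no learned weight/bias
    norm = torch.nn.BatchNorm1d(8, affine=False).eval()
    traced = torch.jit.trace(norm, torch.rand(4, 8))
    mlmodel = ct.convert(traced, inputs=[ct.TensorType(shape=(4, 8))])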
- def test_batchnorm1d_rank3(self, shape, eps, backend): + def test_batchnorm1d_rank3(self, shape, eps, affine, backend): N,C,L = shape - batchnorm = nn.BatchNorm1d(C, eps=eps).eval() + batchnorm = nn.BatchNorm1d(C, eps=eps, affine=affine).eval() self.run_compare_torch( (N, C, L), batchnorm, backend=backend, ) @@ -224,20 +339,37 @@ def test_groupnorm(self, group_features, eps, affine, backend): class TestLinear(TorchBaseTest): @pytest.mark.parametrize( - "in_features, out_features, backend", - itertools.product([10, 25], [3, 6], backends), + "in_features, out_features, bias, backend", + itertools.product([5], [10], [True, False], backends), ) - def test_addmm(self, in_features, out_features, backend): - model = nn.Linear(in_features, out_features) + def test_linear_rank1_input(self, in_features, out_features, bias, backend): + model = nn.Linear(in_features, out_features, bias=bias) + self.run_compare_torch((in_features,), model, backend=backend) + + @pytest.mark.parametrize( + "in_features, out_features, bias, backend", + itertools.product([10, 25], [3, 6], [True, False], backends), + ) + def test_linear_rank2_input(self, in_features, out_features, bias, backend): + model = nn.Linear(in_features, out_features, bias=bias) self.run_compare_torch((1, in_features), model, backend=backend) @pytest.mark.parametrize( - "in_features, out_features, backend", - itertools.product([5], [10], backends), + "in_features, out_features, bias, backend", + itertools.product([10], [6], [True, False], backends), ) - def test_linear_rank1_input(self, in_features, out_features, backend): - model = nn.Linear(in_features, out_features) - self.run_compare_torch((in_features,), model, backend=backend) + def test_linear_rank3_input(self, in_features, out_features, bias, backend): + model = nn.Linear(in_features, out_features, bias=bias) + self.run_compare_torch((1, 3, in_features), model, backend=backend) + + @pytest.mark.parametrize( + "in_features, out_features, bias, backend", + itertools.product([10], [6], [True, False], backends), + ) + def test_linear_rank4_input(self, in_features, out_features, bias, backend): + model = nn.Linear(in_features, out_features, bias=bias) + self.run_compare_torch((1, 5, 3, in_features), model, backend=backend) + class TestConv(TorchBaseTest): @pytest.mark.parametrize( @@ -984,7 +1116,7 @@ def forward(self, x, y): elif mode == "maximum": return torch.maximum(x, y) else: - raise ValueError(f"Unsupported mode: {mode}") + raise ValueError("Unsupported mode: {mode}".format(mode=mode)) model = TestModel() self.run_compare_torch([input_shape] * 2, model, backend=backend) @@ -1547,7 +1679,6 @@ def test_cumsum(self, backend, axis): class TestReshape(TorchBaseTest): - # TODO: Add dynamic & rank preserving reshape tests for pytorch @pytest.mark.parametrize( "backend, output_shape", itertools.product(backends, [(3, 2), (2, -1), (2, 1, 1, 3),],), @@ -1742,7 +1873,7 @@ def test_sigmoid_hard(self, backend, shape): ) @pytest.mark.skipif( _macos_version() <= (10, 15), - reason="Parametric SoftPlus segfaults on macOS 10.15 and below. 
(rdar://problem/66555235)", + reason="Parametric SoftPlus segfaults on macOS 10.15 and below.", ) def test_softplus(self, backend, beta, threshold): input_shape = (1, 10, 5, 15) @@ -1761,6 +1892,39 @@ def test_softsign(self, backend, shape): shape, model, backend=backend, ) + @pytest.mark.skipif( + condition=version_lt(torch, "1.7.0"), + reason="torch.nn.SiLU available only in PyTorch 1.7.0+", + ) + @pytest.mark.parametrize( + "shape, backend", + itertools.product([(1, 10), (1, 3, 4), (1, 4, 5, 6)], backends), + ) + def test_silu(self, shape, backend): + if backend == "neuralnetwork": + pytest.xfail("nn backend not supported") + + model = ModuleWrapper(function=torch.nn.functional.silu) + self.run_compare_torch([shape], model, backend=backend) + + @pytest.mark.parametrize( + "rounding_mode, backend", + itertools.product([None, "floor", "trunc"], backends), + ) + def test_div(self, rounding_mode, backend): + model = ModuleWrapper(function=torch.div, + kwargs={"rounding_mode": rounding_mode}) + x1 = torch.from_numpy(np.array([2.3, 2.6, -3.6, -3.2], dtype=np.float32)) + x2 = torch.from_numpy(np.array([1.0, 1.0, 1.0, 1.0], dtype=np.float32)) + out = torch.div(x1, x2, rounding_mode=rounding_mode) + self.run_compare_torch( + [x1, x2], + model, + backend=backend, + input_as_shape=False, + expected_results=out, + ) + class TestElementWiseUnary(TorchBaseTest): @pytest.mark.parametrize( @@ -1830,6 +1994,7 @@ def test_threshold(self, backend, shape, threshold): model = torch.nn.Threshold(threshold[0], threshold[1]).eval() self.run_compare_torch( shape, model, backend=backend, + use_cpu_for_conversion=True, # TODO: change this to False (rdar://78343191) ) @pytest.mark.parametrize( @@ -1914,6 +2079,9 @@ def test_cast_bug(self, use_cpu_for_conversion, backend): if backend == "mlprogram" and not use_cpu_for_conversion: pytest.xfail("rdar://78343191 ((MIL GPU) Core ML Tools Unit Test failures [failure to load or Seg fault])") + if backend == "mlprogram" and use_cpu_for_conversion: + pytest.xfail("numerical mismatch : rdar://78952850") + class TestModel(torch.nn.Module): def forward(self, spans, embedding): spans = spans.float().relu().int() @@ -2193,6 +2361,57 @@ def forward(self, x): input_shape, model, backend=backend, ) +class TestWhere(TorchBaseTest): + @pytest.mark.parametrize( + "backend, shape", + itertools.product( + backends, + [(2, 6), (3, 4, 5)] + ), + ) + def test_where_test1(self, backend, shape): + + class WhereModel(nn.Module): + def __init__(self): + super(WhereModel, self).__init__() + + def forward(self, x, y): + return torch.where(x > 0.5, x, y) + + input_shape = [shape, shape] + model = WhereModel() + self.run_compare_torch( + input_shape, model, backend=backend, + ) + + @pytest.mark.parametrize( + "backend, shape", + itertools.product( + backends, + [(2, 6), (3, 4, 5)] + ), + ) + def test_where_test2(self, backend, shape): + + class WhereModel(nn.Module): + def __init__(self): + super(WhereModel, self).__init__() + + def forward(self, cond, x, y): + return torch.where(cond, x, y) + + cond = torch.rand(*shape) > 0.5 + inputs = [cond, torch.rand(*shape), torch.rand(*shape)] + model = WhereModel() + expected_results = model(*inputs) + self.run_compare_torch( + inputs, + model, + backend=backend, + expected_results=expected_results, + input_as_shape=False, + ) + class TestSelect(TorchBaseTest): @pytest.mark.parametrize( "backend, dim_index", diff --git a/coremltools/converters/mil/frontend/torch/torchir_passes.py b/coremltools/converters/mil/frontend/torch/torchir_passes.py index 
d9fd963e2..1c41d5227 100644
--- a/coremltools/converters/mil/frontend/torch/torchir_passes.py
+++ b/coremltools/converters/mil/frontend/torch/torchir_passes.py
@@ -5,9 +5,6 @@
 
 def transform_inplace_ops(graph, name_remap_dict=None):
-    # TODO: one recent 1P model has included the op `copy_`. This is another
-    # in-place op that should be fixed by this pass.
-    # See rdar://64267506
 
     # As we modify ops, we'll need to remap symbols.
     if name_remap_dict is None:
diff --git a/coremltools/converters/mil/input_types.py b/coremltools/converters/mil/input_types.py
index f2b9de2be..1e575564d 100644
--- a/coremltools/converters/mil/input_types.py
+++ b/coremltools/converters/mil/input_types.py
@@ -26,7 +26,7 @@ def __init__(
         Parameters
         ----------
         class_labels: str / list of int / list of str
-            If a ``list`` if given, the ``list`` maps the index of the output of a
+            If a ``list`` is given, the ``list`` maps the index of the output of a
             neural network to labels in a classifier.
 
             If a ``str`` is given, the ``str`` points to a file which maps the index
diff --git a/coremltools/converters/mil/mil/ops/defs/__init__.py b/coremltools/converters/mil/mil/ops/defs/__init__.py
index d25458116..bbc14cba8 100644
--- a/coremltools/converters/mil/mil/ops/defs/__init__.py
+++ b/coremltools/converters/mil/mil/ops/defs/__init__.py
@@ -4,6 +4,7 @@
 # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
 
 from .activation import *
+from .classify import classify
 from .control_flow import *
 from .conv import *
 from .elementwise_binary import *
diff --git a/coremltools/converters/mil/mil/ops/defs/activation.py b/coremltools/converters/mil/mil/ops/defs/activation.py
index cd221589a..66a6cf03c 100644
--- a/coremltools/converters/mil/mil/ops/defs/activation.py
+++ b/coremltools/converters/mil/mil/ops/defs/activation.py
@@ -5,7 +5,7 @@
 import numpy as np
 import scipy
 
-from coremltools.converters.mil.mil import Operation, VALUE
+from coremltools.converters.mil.mil import Operation, types, VALUE
 from coremltools.converters.mil.mil.input_type import (
     DefaultInputs,
     FloatInputType,
@@ -491,6 +491,31 @@ def value_inference(self):
     def type_inference(self):
         return self.x.sym_type
 
+@register_op(doc_str="")
+class silu(Operation):
+    """
+    Sigmoid Linear Unit: applies the SiLU (also known as Swish) operation,
+    ``x * sigmoid(x)``, element-wise.
+
+    Parameters
+    ----------
+    x: tensor<*, T>
+
+    Returns
+    -------
+    tensor<*, T>
+
+    Attributes
+    ----------
+    T: fp32
+    """
+
+    input_spec = InputSpec(x=TensorInputType(),)
+
+    def __init__(self, **kwargs):
+        super(silu, self).__init__(**kwargs)
+
+    def type_inference(self):
+        return types.tensor(self.x.dtype, tuple(self.x.shape))
 
 @register_op(doc_str="")
 class softplus(elementwise_unary):
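For reference, a minimal pymil sketch that exercises the new silu op through the public Builder API (assuming the usual mb.program workflow; this snippet is not part of the change itself):

    from coremltools.converters.mil import Builder as mb

    @mb.program(input_specs=[mb.TensorSpec(shape=(1, 4))])
    def prog(x):
        # a single fused op computing x * sigmoid(x)
        return mb.silu(x=x)

    print(prog)  # the printed MIL program should contain a `silu` op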
diff --git a/coremltools/converters/mil/mil/ops/defs/classify.py b/coremltools/converters/mil/mil/ops/defs/classify.py
new file mode 100644
index 000000000..0428fefae
--- /dev/null
+++ b/coremltools/converters/mil/mil/ops/defs/classify.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2020, Apple Inc. All rights reserved.
+#
+# Use of this source code is governed by a BSD-3-clause license that can be
+# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
+import numpy as np
+
+from coremltools.converters.mil.mil.ops.defs._op_reqs import *
+from coremltools.converters.mil.mil.types.symbolic import any_symbolic
+
+@register_op(doc_str="")
+class classify(Operation):
+    """
+    The presence of this op indicates that the model is a classifier; the model
+    outputs are constructed accordingly, namely the predicted class label and
+    the dictionary of class probabilities. The parameters of this op are set
+    based on the attributes that the user sets on "coremltools.ClassifierConfig".
+    The outputs of this op cannot be consumed by any other op.
+
+    Parameters
+    ----------
+    probabilities: tensor<[*, ProbT]> (Required)
+        * A tensor in the graph, which is used to compute the classifier output(s).
+        * This is the tensor whose values are mapped to the class labels; it is
+          used to construct both the predicted class label and the output
+          dictionary of class names and probabilities.
+    classes: list<*, ClassT> (Required)
+        * List of classes.
+
+    Returns
+    -------
+    ClassT
+        * The predicted class label.
+    Dict[ClassT, ProbT]
+        * The output probability dictionary.
+
+    Attributes
+    ----------
+    ProbT: fp32
+    ClassT: int64, str
+    """
+
+    input_spec = InputSpec(
+        probabilities=TensorInputType(),
+        classes=ListInputType(const=True),
+    )
+
+    def __init__(self, **kwargs):
+        super(classify, self).__init__(**kwargs)
+
+    def type_inference(self):
+        # check the type of "classes"
+        if not types.is_list(self.classes.sym_type):
+            msg = "'classes' in the op 'classify' must be of type list. Instead it is {}."
+            raise ValueError(msg.format(self.classes.sym_type.__type_info__()))
+
+        classes_elem_type = self.classes.elem_type
+        if classes_elem_type not in {types.str, types.int64}:
+            msg = "Type of elements in 'classes' in the op 'classify' must be either str or int64. Instead it is {}."
+            raise ValueError(msg.format(classes_elem_type.__type_info__()))
+
+        # check that the size of "classes" is compatible with the size of "probabilities"
+        if not any_symbolic(self.probabilities.shape):
+            size = np.prod(self.probabilities.shape)
+            if len(self.classes.val) != size:
+                msg = "In op 'classify', number of classes must match the size of the tensor corresponding to 'probabilities'."
+                raise ValueError(msg)
+
+        return classes_elem_type, types.dict(classes_elem_type, types.double)
diff --git a/coremltools/converters/mil/mil/ops/defs/control_flow.py b/coremltools/converters/mil/mil/ops/defs/control_flow.py
index 944c22764..aa51ff509 100644
--- a/coremltools/converters/mil/mil/ops/defs/control_flow.py
+++ b/coremltools/converters/mil/mil/ops/defs/control_flow.py
@@ -22,6 +22,7 @@
 )
 from coremltools.converters.mil.mil.input_type import (
     BoolInputType,
+    BoolTensorInputType,
     DefaultInputs,
     InputSpec,
     InternalScalarOrTensorInputType,
@@ -174,12 +175,18 @@ def _get_type_val(self, value):
             value = np.int32(value)
         elif isinstance(value, (tuple, list, np.ndarray)):
             value = np.array(value)
-            if value.dtype == np.int64:
-                # We use int32 by default.
+
+            # For the int type, we use int32 by default
+            if value.dtype in [np.uint8, np.int8, np.uint16, np.int16, np.uint32, np.uint64, np.int64]:
+                if value.dtype in [np.uint64, np.int64]:
+                    msg = "Downcast const op {} data int64 as int32".format(self.name)
+                    logging.warning(msg)
                 value = value.astype(np.int32)
-            if value.dtype == np.float64:
-                # We use float32 by default.
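+            # NOTE: like the int64 -> int32 branch above, the fp64 -> fp32
+            # downcast below is lossy for values that need the wider type.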
+ # For the float type, we use float32 by default + elif value.dtype == np.float64: + msg = "Downcast const op {} data fp64 as fp32".format(self.name) + logging.warning(msg) value = value.astype(np.float32) elif isinstance(value, mil_list): @@ -229,8 +236,8 @@ class select(Operation): Parameters ---------- - cond: tensor<[\*D1], T> (Required) - * Tensor. When ``True`` (non-zero), select element from ``x``, otherwise, ``y``. + cond: tensor<[\*D1], B> (Required) + * Tensor. When ``True``, select element from ``x``, otherwise, ``y``. a: tensor<[\*D2], T> (Optional) * Values selected at indices where ``cond`` is ``True``. @@ -251,11 +258,12 @@ class select(Operation): Attributes ---------- + B: bool T: fp32 """ - + input_spec = InputSpec( - cond=TensorInputType(), a=TensorInputType(), b=TensorInputType() + cond=BoolTensorInputType(), a=TensorInputType(), b=TensorInputType() ) def __init__(self, **kwargs): @@ -451,8 +459,8 @@ def build_nested_blocks(self): if not is_subtype(v_out.sym_type, v_in.sym_type): msg = 'Block output {}: {} is not a subtype of ' +\ 'block input {}: {} after factoring shape changes' - raise ValueError(msg.format(v_out.name. v.sym_type, - v_in.name, v_in.sym_type)) + raise ValueError(msg.format(v_out.name, v_out.sym_type.__name__, + v_in.name, v_in.sym_type.__name__)) if not while_loop._check_equal_value(v_out.sym_val, v_in.sym_val): msg = 'Block output {}: {} is not equal to ' +\ 'block input {}: {} after value changes' diff --git a/coremltools/converters/mil/mil/ops/defs/conv.py b/coremltools/converters/mil/mil/ops/defs/conv.py index 2d8e92d47..3337b87f6 100644 --- a/coremltools/converters/mil/mil/ops/defs/conv.py +++ b/coremltools/converters/mil/mil/ops/defs/conv.py @@ -62,12 +62,14 @@ class conv(Operation): ``d_out[i] = ceil(d_in[i] / strides[i])``. Specifically, for ``i = 0,..,,len(d_in)-1``, the equivalent paddings are - as follows, when dilated kernel is even (for example, ``(K[i]-1)*dilations[i]+1)``): - - * ``pad[2*i] = ceil[((K[i]-1)*dilations[i]+1)/2]``. - * ``pad[2*i+1] = floor[((K[i]-1)*dilations[i]+1)/2]``. - - Otherwise, ``pad[2*i] = pad[2*i+1] = (K[i]-1) * dilations[i] / 2``. + calculated as follows: + + * ``dilated_kernel = (K[i] - 1) * dilate[i] + 1`` + * if ``dilated_kernel`` is odd, + ``padding[2*i] = padding[2*i+1] = floor(dilated_kernel / 2)`` + * Otherwise: + ``padding[2*i] = ceil((dilated_kernel - 1) / 2)``, + ``padding[2*i+1] = floor((dilated_kernel - 1) / 2)`` pad: const tensor<[P], i32> (Optional. Default to all zeros) diff --git a/coremltools/converters/mil/mil/ops/defs/elementwise_binary.py b/coremltools/converters/mil/mil/ops/defs/elementwise_binary.py index 9f6519512..794221c89 100644 --- a/coremltools/converters/mil/mil/ops/defs/elementwise_binary.py +++ b/coremltools/converters/mil/mil/ops/defs/elementwise_binary.py @@ -597,8 +597,8 @@ class real_div(elementwise_binary): Returns ------- - tensor<\*?, bool> - * A boolean tensor with the same shape as the inputs. + tensor<\*?, T> + * A tensor with the same type and shape as the inputs. 
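+        * Note: as part of this change, integer inputs are promoted to fp32 before
+          the op is constructed (see ``__init__`` below); native int32 division is
+          still tracked by the TODO (rdar://79925291).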
 
     Attributes
     ----------
@@ -606,14 +606,18 @@
     """
 
     def __init__(self, **kwargs):
+        # TODO(rdar://79925291): Allow int32 input to floor_div
+        from coremltools.converters.mil.mil import Builder as mb
+        from coremltools.converters.mil.mil import types
+        accepted_types = [types.fp32, types.fp16]
+        for input_name in ["x", "y"]:
+            if kwargs[input_name].dtype not in accepted_types:
+                kwargs[input_name] = mb.cast(x=kwargs[input_name], dtype="fp32")
         super(real_div, self).__init__(**kwargs)
 
     def get_operator(self):
         return operator.truediv
 
-    def get_dtype(self, promoted_dtype):
-        return types.float
-
 
 @register_op(doc_str="")
 class pow(elementwise_binary):
diff --git a/coremltools/converters/mil/mil/ops/defs/image_resizing.py b/coremltools/converters/mil/mil/ops/defs/image_resizing.py
index be4ac53a5..a01abde7a 100644
--- a/coremltools/converters/mil/mil/ops/defs/image_resizing.py
+++ b/coremltools/converters/mil/mil/ops/defs/image_resizing.py
@@ -10,6 +10,135 @@
 from coremltools.converters.mil.mil import get_new_symbol
 
+@register_op(doc_str="")
+class affine(Operation):
+    """
+    Apply a linear affine transform to the input 2D image tensor. The value at
+    the (x, y), i.e., (w, h), coordinate of the output is computed by first
+    computing the coordinates x' and y' with the following equations, and then
+    sampling the value at the coordinate (x', y') in the input image using either
+    bilinear or nearest neighbor interpolation. If the (x', y') point falls
+    outside the input image, then the padding information is used to compute the
+    value.
+    * x' = a0 * x + a1 * y + a2
+    * y' = b0 * x + b1 * y + b2
+
+    Parameters
+    ----------
+    x: tensor<[B, C, H1, W1], T>
+        * Must be rank ``4``.
+    transform_matrix: tensor<[D, 6], T>
+        * Must be rank ``2``.
+        * ``D`` can be either ``B`` or 1:
+          when ``D == B``, there is a separate transform matrix for each batch;
+          when ``D == 1``, the same matrix is used for all input batches.
+        * For each batch: ``[a0, a1, a2, b0, b1, b2]``.
+    output_height: const
+        * Target output height.
+    output_width: const
+        * Target output width.
+    sampling_mode: const
+        * Allowed values: "bilinear"
+    padding_mode: const
+        * Allowed values: "constant"
+        * Note: the following illustration is the 1D case for brevity; the op only supports 2D image inputs.
+        * If ``padding_mode == "constant"``:
+          the input image is assumed to be padded with the padding_value.
+          E.g., |1, 2, 3| -> |0, 0, 0, 1, 2, 3, 0, 0, 0|
+    padding_value: const
+        * Currently non-zero values are not supported.
+        * To be used only when ``padding_mode == "constant"``; ignored in other cases.
+    coordinates_mode: const
+        * Allowed values: "normalized_minus_one_to_one"
+        * If ``coordinates_mode == "normalized_minus_one_to_one"``, the in-image values are [-1, 1]. That is:
+          * (-1, -1), i.e., (w=-1, h=-1), corresponds to the top-left pixel
+          * (1, -1), i.e., (w=1, h=-1), corresponds to the top-right pixel
+          * (-1, 1), i.e., (w=-1, h=1), corresponds to the bottom-left pixel
+          * (1, 1), i.e., (w=1, h=1), corresponds to the bottom-right pixel
+    align_corners: const
+        * Currently align_corners=False is not supported.
+        * To be used only when ``coordinates_mode != unnormalized``; ignored otherwise.
+        * If ``align_corners == True``, the extrema coordinates correspond
+          to the centers of the first and last corner pixels.
+        * If ``align_corners == False``, the extrema coordinates correspond
+          to the edges of the first and last corner pixels.
+
+    Returns
+    -------
+    tensor<[B, C, output_height, output_width], T>
+
+    Attributes
+    ----------
+    T: fp32
+    """
+
+    input_spec = InputSpec(
+        x=TensorInputType(),
+        transform_matrix=TensorInputType(),
+        output_height=IntInputType(const=True),
+        output_width=IntInputType(const=True),
+        sampling_mode=StringInputType(const=True),
+        padding_mode=StringInputType(const=True),
+        padding_value=FloatInputType(const=True),
+        coordinates_mode=StringInputType(const=True),
+        align_corners=BoolInputType(const=True),
+    )
+
+    def __init__(self, **kwargs):
+        super(affine, self).__init__(**kwargs)
+
+    def type_inference(self):
+        if self.x.rank != 4:
+            raise ValueError(
+                'input "x" to the "affine" op must be a rank 4 tensor. '
+                "Got rank {} tensor of shape {}".format(
+                    self.x.rank, self.x.shape
+                )
+            )
+        if self.transform_matrix.rank != 2:
+            raise ValueError(
+                'input "transform_matrix" to the "affine" op must be a rank 2 tensor. '
+                "Got rank {} tensor of shape {}".format(
+                    self.transform_matrix.rank, self.transform_matrix.shape
+                )
+            )
+        if self.sampling_mode.val.lower() != "bilinear":
+            raise NotImplementedError(
+                'input "sampling_mode" to the "affine" op is not implemented. '
+                'Got "{}"'.format(self.sampling_mode.val)
+            )
+        if self.coordinates_mode.val.lower() != "normalized_minus_one_to_one":
+            raise NotImplementedError(
+                'input "coordinates_mode" to the "affine" op is not implemented. '
+                'Got "{}"'.format(self.coordinates_mode.val)
+            )
+        if self.padding_mode.val.lower() != "constant" or self.padding_value.val != 0.0:
+            raise NotImplementedError(
+                'input "padding_mode" to the "affine" op is not implemented. '
+                'Got "{}" with "padding_value={}"'.format(
+                    self.padding_mode.val, self.padding_value.val
+                )
+            )
+
+        input_shape = self.x.shape
+        transform_matrix_shape = self.transform_matrix.shape
+        if (
+            not is_symbolic(transform_matrix_shape[-1])
+            and transform_matrix_shape[-1] != 6
+        ):
+            raise ValueError(
+                'input "transform_matrix" to the "affine" op must have 6 as its last dimension '
+                "[a0, a1, a2, b0, b1, b2]. "
+                "Got {} for last dimension".format(transform_matrix_shape[-1])
+            )
+
+        ret_shape = list(input_shape)
+        ret_shape[2] = self.output_height.val
+        ret_shape[3] = self.output_width.val
+        return types.tensor(self.x.dtype, tuple(ret_shape))
+
+
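To make the transform_matrix semantics concrete, a pure NumPy sketch mirroring the docstring equations above (an illustration of the coordinate math only, not of the runtime kernel):

    import numpy as np

    # One row per batch: [a0, a1, a2, b0, b1, b2]. The identity row maps each
    # output (x, y) to the same input coordinate; a2 and b2 are translations.
    def source_coord(row, x, y):
        a0, a1, a2, b0, b1, b2 = row
        return (a0 * x + a1 * y + a2, b0 * x + b1 * y + b2)

    identity = np.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0])
    assert source_coord(identity, 0.25, -0.5) == (0.25, -0.5)

    shifted = np.array([1.0, 0.0, 0.5, 0.0, 1.0, 0.0])  # shift x' by +0.5
    assert source_coord(shifted, 0.25, -0.5) == (0.75, -0.5)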
 @register_op(doc_str="TODO")
 class upsample_nearest_neighbor(Operation):
     """
@@ -66,6 +195,187 @@ def type_inference(self):
         return types.tensor(self.x.dtype, ret_shape)
 
+
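Before the resample op defined below: the normalized coordinates_mode conventions shared by affine and resample can be made concrete with a small sketch. The mapping shown follows the common PyTorch-style convention and is an assumption for illustration, not a statement about the Core ML runtime kernel:

    import numpy as np

    def denormalize(x_norm, size, align_corners):
        # Map [-1, 1] normalized coordinates to pixel positions.
        if align_corners:
            return (x_norm + 1.0) / 2.0 * (size - 1)    # extrema -> corner pixel centers
        return ((x_norm + 1.0) * size - 1.0) / 2.0      # extrema -> corner pixel edges

    coords = np.array([-1.0, 0.0, 1.0])
    print(denormalize(coords, 5, True))    # [ 0.   2.   4. ]
    print(denormalize(coords, 5, False))   # [-0.5  2.   4.5]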
+        * if ``padding_mode == "constant"``:
+            the input image is assumed to be padded with the padding_value
+            E.g., |1, 2, 3| -> |0, 0, 0, 1, 2, 3, 0, 0, 0|
+        * if ``padding_mode == "border"``:
+            the input image is assumed to be padded with the values replicated
+            from the values at the edge. This is also referred to as the
+            "clamped" or "replication" mode, since the padded values are
+            clamped to the border values.
+            E.g., |1, 2, 3| -> |1, 1, 1, 1, 2, 3, 3, 3, 3|
+        * if ``padding_mode == "reflection"``:
+            the border values are reflected, *not* including the values at the edge/border
+            E.g., |1, 2, 3| -> |2, 3, 2, 1, 2, 3, 2, 1, 2|
+        * if ``padding_mode == "symmetric"``:
+            values are reflected, including the border/edge values
+            E.g., |1, 2, 3| -> |3, 2, 1, 1, 2, 3, 3, 2, 1|
+    padding_value: const
+        * Used only when ``padding_mode == "constant"``; ignored in other cases.
+    coordinates_mode: const
+        * Allowed values: "unnormalized", "normalized_minus_one_to_one",
+          "normalized_zero_to_one"
+        * if ``coordinates_mode == "unnormalized"``, the coordinates are pixel
+          indices; values in [0, W - 1] / [0, H - 1] correspond to in-image points.
+        * if ``coordinates_mode == "normalized_minus_one_to_one"``, the in-image values are [-1, 1].
+        * if ``coordinates_mode == "normalized_zero_to_one"``, the in-image values are [0, 1].
+        * E.g., if ``coordinates_mode == "normalized_minus_one_to_one"``,
+          the in-range values are [-1, 1]. That is:
+            * (-1, -1), i.e., (w=-1, h=-1), corresponds to the top-left pixel
+            * (1, -1), i.e., (w=1, h=-1), corresponds to the top-right pixel
+            * (-1, 1), i.e., (w=-1, h=1), corresponds to the bottom-left pixel
+            * (1, 1), i.e., (w=1, h=1), corresponds to the bottom-right pixel
+    align_corners: const
+        * if ``align_corners == True``, the extrema coordinates correspond
+          to the center of the first and last corner pixels.
+        * if ``align_corners == False``, the extrema coordinates correspond
+          to the edge of the first and last corner pixels.
+
+    Returns
+    -------
+    tensor<[B, C, H2, W2], T>
+
+    Attributes
+    ----------
+    T: fp32
+    U: fp32, i32, i64
+    """
+
+    input_spec = InputSpec(
+        x=TensorInputType(),
+        coordinates=TensorInputType(),
+        sampling_mode=StringInputType(const=True),
+        padding_mode=StringInputType(const=True),
+        padding_value=FloatInputType(const=True),
+        coordinates_mode=StringInputType(const=True),
+        align_corners=BoolInputType(const=True),
+    )
+
+    def __init__(self, **kwargs):
+        super(resample, self).__init__(**kwargs)
+
+    def type_inference(self):
+        if self.x.rank != 4:
+            raise ValueError(
+                'input "x" to the "resample" op must be a rank 4 tensor. '
+                "Got rank {} tensor of shape {}".format(
+                    self.x.rank, self.x.shape
+                )
+            )
+        if self.coordinates.rank != 4:
+            raise ValueError(
+                'input "coordinates" to the "resample" op must be a rank 4 tensor. '
+                "Got rank {} tensor of shape {}".format(
+                    self.coordinates.rank, self.coordinates.shape
+                )
+            )

+        input_shape = self.x.shape
+        coord_shape = self.coordinates.shape
+        if (
+            not is_symbolic(input_shape[0])
+            and not is_symbolic(coord_shape[0])
+            and input_shape[0] != coord_shape[0]
+        ):
+            raise ValueError(
+                'inputs "x" and "coordinates" to the "resample" op must agree on '
+                "the batch size dimension: {} vs. {}".format(
+                    input_shape[0], coord_shape[0]
+                )
+            )
+        if not is_symbolic(coord_shape[-1]) and coord_shape[-1] != 2:
+            raise ValueError(
+                'the last dimension of input "coordinates" to the "resample" op must be 2. '
+                "Got {} for the last dimension".format(
+                    coord_shape[-1]
+                )
+            )
+
+        ret_shape = list(input_shape)
+        ret_shape[2] = coord_shape[1]  # Output height
+        ret_shape[3] = coord_shape[2]  # Output width
+        return types.tensor(self.x.dtype, tuple(ret_shape))
+
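+# For illustration only (not used by the converter): the four padding modes
+# line up with NumPy's pad modes on the 1-D example from the docstring:
+#
+#   import numpy as np
+#   x = np.array([1, 2, 3])
+#   np.pad(x, 3, mode="constant")   # [0 0 0 1 2 3 0 0 0] -> "constant"
+#   np.pad(x, 3, mode="edge")       # [1 1 1 1 2 3 3 3 3] -> "border"
+#   np.pad(x, 3, mode="reflect")    # [2 3 2 1 2 3 2 1 2] -> "reflection"
+#   np.pad(x, 3, mode="symmetric")  # [3 2 1 1 2 3 3 2 1] -> "symmetric"
+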
' + "Got {} for last dimension".format( + coord_shape[-1] + ) + ) + + ret_shape = list(input_shape) + ret_shape[2] = coord_shape[1] # Output height + ret_shape[3] = coord_shape[2] # Output width + return types.tensor(self.x.dtype, tuple(ret_shape)) + + +@register_op(doc_str="TODO") +class resize_nearest_neighbor(Operation): + """ + Resize the spatial (last two) dimensions to the specified target size + using nearest neighbor interpolation. Although this op is similar to + ``upsample_nearest_neighbor``, ``resize_nearest_neighbor`` works with + a target size rather than with scale factors. + + Parameters + ---------- + x: tensor<[*D, H1, W1], T> (Required) + * Must be at least rank ``3``. + target_size_height: const (Required) + * Target spatial size for the height dimension (``axis=-2``). + target_size_width: const (Required) + * Target spatial size for the width dimension (``axis=-1``). + + Notes + ----- + See ``resize_bilinear`` for examples. + + Returns + ------- + tensor<[*D, H2, W2], T> + * Tensor with same type as the input. + * ``H2`` = ``target_size_height``. + * ``W2`` = ``target_size_width``. + + Attributes + ---------- + T: fp32 + """ + + input_spec = InputSpec( + x=TensorInputType(), + target_size_height=IntInputType(const=True), + target_size_width=IntInputType(const=True), + ) + + def __init__(self, **kwargs): + super(resize_nearest_neighbor, self).__init__(**kwargs) + + def type_inference(self): + if self.x.rank < 3: + raise ValueError( + 'input to the "resize_nearest_neighbor" op must have rank at least 3' + ) + + ret_shape = list(self.x.shape) + ret_shape[-1] = int(self.target_size_width.val) + ret_shape[-2] = int(self.target_size_height.val) + return types.tensor(self.x.dtype, ret_shape) + + @register_op(doc_str="TODO") class upsample_bilinear(Operation): """ diff --git a/coremltools/converters/mil/mil/ops/defs/random.py b/coremltools/converters/mil/mil/ops/defs/random.py index 83ae9825f..3863dacde 100644 --- a/coremltools/converters/mil/mil/ops/defs/random.py +++ b/coremltools/converters/mil/mil/ops/defs/random.py @@ -14,6 +14,7 @@ class RandomDistribution(Operation): input_spec = InputSpec(shape=IntTensorInputType(),) + out_dtype = types.fp32 def __init__(self, **kwargs): super(RandomDistribution, self).__init__(**kwargs) @@ -21,14 +22,14 @@ def __init__(self, **kwargs): def type_inference(self): if any_symbolic(self.shape.shape): # We can't infer any shape if shape has variable length. - return types.tensor(types.fp32, (get_new_variadic_symbol(),)) + return types.tensor(self.out_dtype, (get_new_variadic_symbol(),)) # shape has fixed length here. 
if self.shape.sym_val is None: shape = tuple([get_new_symbol() for _ in range(self.shape.shape[0])]) - return types.tensor(types.fp32, shape) + return types.tensor(self.out_dtype, shape) - return types.tensor(types.fp32, tuple(self.shape.sym_val.tolist())) + return types.tensor(self.out_dtype, tuple(self.shape.sym_val.tolist())) """ @@ -88,6 +89,10 @@ def default_inputs(self): def __init__(self, **kwargs): super(random_bernoulli, self).__init__(**kwargs) + def type_inference(self): + self.out_dtype = self.prob.dtype + return super().type_inference() + @register_op(doc_str="") class random_categorical(Operation): @@ -138,8 +143,9 @@ def __init__(self, **kwargs): super(random_categorical, self).__init__(**kwargs) def type_inference(self): + self.out_dtype = self.x.dtype output_shape = self.x.shape[:-1] + (self.size.val,) - return types.tensor(types.fp32, output_shape) + return types.tensor(self.out_dtype, output_shape) @register_op(doc_str="") @@ -192,6 +198,12 @@ def default_inputs(self): def __init__(self, **kwargs): super(random_normal, self).__init__(**kwargs) + def type_inference(self): + if self.mean.dtype != self.stddev.dtype: + raise ValueError("Incompatible primitive types in random_normal operation") + self.out_dtype = self.mean.dtype + return super().type_inference() + @register_op(doc_str="") class random_uniform(RandomDistribution): @@ -251,3 +263,9 @@ def default_inputs(self): def __init__(self, **kwargs): super(random_uniform, self).__init__(**kwargs) + + def type_inference(self): + if self.low.dtype != self.high.dtype: + raise ValueError("Incompatible primitive types in random_uniform operation") + self.out_dtype = self.low.dtype + return super().type_inference() diff --git a/coremltools/converters/mil/mil/ops/defs/tensor_transformation.py b/coremltools/converters/mil/mil/ops/defs/tensor_transformation.py index a80fc8259..c887fa920 100644 --- a/coremltools/converters/mil/mil/ops/defs/tensor_transformation.py +++ b/coremltools/converters/mil/mil/ops/defs/tensor_transformation.py @@ -4,6 +4,7 @@ # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause import functools +import logging import numpy as np import sympy as sm diff --git a/coremltools/converters/mil/mil/ops/registry.py b/coremltools/converters/mil/mil/ops/registry.py index e260f6a04..d15cbb948 100644 --- a/coremltools/converters/mil/mil/ops/registry.py +++ b/coremltools/converters/mil/mil/ops/registry.py @@ -29,7 +29,6 @@ def register_op(doc_str="", is_custom_op=False, namespace="core"): def class_wrapper(op_cls): op_type = op_cls.__name__ - # op_cls.__doc__ = doc_str # TODO: rdar://58622145 # Operation specific to custom op op_msg = "Custom op" if is_custom_op else "op" diff --git a/coremltools/converters/mil/mil/ops/tests/test_activation.py b/coremltools/converters/mil/mil/ops/tests/test_activation.py index e62d7274d..b9329db5e 100644 --- a/coremltools/converters/mil/mil/ops/tests/test_activation.py +++ b/coremltools/converters/mil/mil/ops/tests/test_activation.py @@ -322,7 +322,6 @@ def build(x): ) -# TODO (rdar://59954690): Broken when there is 1 channel class TestPReLU: @pytest.mark.parametrize( "use_cpu_only, backend", itertools.product([True, False], backends,) @@ -650,6 +649,40 @@ def build(x): ) +class TestSiLU: + @pytest.mark.parametrize( + "use_cpu_only, backend", itertools.product([True, False], backends,) + ) + def test_builder_to_backend_smoke(self, use_cpu_only, backend): + if backend == "neuralnetwork": + pytest.xfail("nn backend not supported") + + x_val = np.array([-1.1, 
2.2, -3.3, 4.4], dtype=np.float32).reshape((1, 2, 1, 2)) + + input_placeholder_dict = { + "x": mb.placeholder(shape=x_val.shape), + } + input_value_dict = {"x": x_val} + expected_output_type = x_val.shape + (types.fp32,) + + def build(x): + return mb.silu(x=x) + + expected_output = np.array( + [-0.2747, 1.9805, -0.1174, 4.3466], dtype=np.float32 + ).reshape(expected_output_type[:-1]) + + run_compare_builder( + build, + input_placeholder_dict, + input_value_dict, + expected_output_type, + expected_output, + use_cpu_only=use_cpu_only, + backend=backend, + ) + + class TestSoftplus: @pytest.mark.parametrize( "use_cpu_only, backend", itertools.product([True, False], backends,) @@ -688,7 +721,6 @@ def test_builder_eval(self): ) -# TODO (rdar://59954690): NN Segfaults when converting from MIL ParametricSoftplus layer # No torch test because there is no direct torch translation to this layer class TestSoftplusParametric: @pytest.mark.parametrize( diff --git a/coremltools/converters/mil/mil/ops/tests/test_const.py b/coremltools/converters/mil/mil/ops/tests/test_const.py index 821630f3d..44610cfbb 100644 --- a/coremltools/converters/mil/mil/ops/tests/test_const.py +++ b/coremltools/converters/mil/mil/ops/tests/test_const.py @@ -16,15 +16,26 @@ class TestConst: "use_cpu_for_conversion, backend, dtype", itertools.product( [True, False], backends, - [np.float32, np.int32] + [ + np.uint8, + np.int8, + np.uint16, + np.int16, + np.uint32, + np.int32, + np.uint64, + np.int64, + np.float32, + np.float64, + ] ) ) def test_builder_to_backend_smoke(self, use_cpu_for_conversion, backend, dtype): if backend == "mlprogram" and not use_cpu_for_conversion: pytest.xfail("rdar://78343191 ((MIL GPU) Core ML Tools Unit Test failures [failure to load or Seg fault])") - t = np.random.randint(0, 100, (100, 2)).astype(np.float32) - constant = np.random.randint(0, 100, (100, 2)).astype(dtype) + t = np.random.randint(0, 5, (4, 2)).astype(np.float32) + constant = np.random.randint(0, 5, (4, 2)).astype(dtype) input_placeholders = { "x": mb.placeholder(shape=t.shape), } @@ -36,7 +47,7 @@ def build(x): z = mb.add(x=x, y=y) return mb.cast(x=z, dtype='fp32') - expected_output_types = (100, 2, types.fp32) + expected_output_types = (4, 2, types.fp32) expected_outputs = t + constant.astype(np.float32) run_compare_builder( diff --git a/coremltools/converters/mil/mil/ops/tests/test_control_flow.py b/coremltools/converters/mil/mil/ops/tests/test_control_flow.py index ea935858c..07e2e863d 100644 --- a/coremltools/converters/mil/mil/ops/tests/test_control_flow.py +++ b/coremltools/converters/mil/mil/ops/tests/test_control_flow.py @@ -26,6 +26,8 @@ def test_builder_to_backend_smoke(self, use_cpu_only, backend): input_values = {"cond": cond_val, "a": a_val, "b": b_val} def build(cond, a, b): + if not types.is_bool(cond.dtype): + cond = mb.cast(x=cond, dtype="bool") return [mb.select(cond=cond, a=a, b=b)] expected_output_types = [(3, 3, types.fp32)] @@ -60,6 +62,8 @@ def test_builder_to_backend_smoke_broadcast(self, use_cpu_only, backend): input_values = {"cond": cond_val, "a": a_val, "b": b_val} def build(cond, a, b): + if not types.is_bool(cond.dtype): + cond = mb.cast(x=cond, dtype="bool") return [mb.select(cond=cond, a=a, b=b)] expected_output_types = [(3, 3, types.fp32)] @@ -81,7 +85,7 @@ def build(cond, a, b): @ssa_fn def test_builder_eval(self): - cond = np.random.randint(low=0, high=2, size=(6, 1, 7)) + cond = np.random.randint(low=0, high=2, size=(6, 1, 7)).astype(np.bool) a = random_gen(shape=(6, 1, 7), rand_min=-1962.0, 
rand_max=0.0) b = random_gen(shape=(6, 1, 7), rand_min=0.0, rand_max=1964.0) res = mb.select(cond=cond, a=a, b=b) @@ -89,7 +93,7 @@ def test_builder_eval(self): @ssa_fn def test_builder_eval_broadcast(self): - cond = np.array([[1], [0], [1]]) + cond = np.array([[True], [False], [True]]) a = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32) b = np.array([[7, 8], [9, 10], [11, 12]], dtype=np.float32) res = mb.select(cond=cond, a=a, b=b) diff --git a/coremltools/converters/mil/mil/ops/tests/test_elementwise_unary.py b/coremltools/converters/mil/mil/ops/tests/test_elementwise_unary.py index 0a46ebcc6..a2fb71950 100644 --- a/coremltools/converters/mil/mil/ops/tests/test_elementwise_unary.py +++ b/coremltools/converters/mil/mil/ops/tests/test_elementwise_unary.py @@ -627,7 +627,7 @@ def build(x): itertools.product( [True, False], backends, - [("fp32", "int32"), ("fp16", "fp32"), ("fp32", "fp16"), ("fp16", "int32")], + [("fp16", "fp32"), ("fp32", "fp16")], ), ) def test_builder_to_backend_stress_cast( @@ -672,4 +672,4 @@ def build(x): frontend_only=False, backend=backend, use_cpu_for_conversion=use_cpu_for_conversion, - ) \ No newline at end of file + ) diff --git a/coremltools/converters/mil/mil/ops/tests/test_image_resizing.py b/coremltools/converters/mil/mil/ops/tests/test_image_resizing.py index 5638afc25..bb423e98a 100644 --- a/coremltools/converters/mil/mil/ops/tests/test_image_resizing.py +++ b/coremltools/converters/mil/mil/ops/tests/test_image_resizing.py @@ -12,6 +12,275 @@ backends = testing_reqs.backends +class TestAffine: + @pytest.mark.parametrize( + "use_cpu_only, backend", itertools.product([True, False], backends) + ) + def test_builder_to_backend_smoke(self, use_cpu_only, backend): + if backend == "neuralnetwork": + pytest.xfail("nn backend not supported") + + x_val = np.array([11.0, 22.0, 33.0, 44.0], dtype=np.float32).reshape( + [1, 1, 2, 2] + ) + transform_matrix_val = np.array( + [-1.0, -2.0, -3.7, -1.0, 3.5, 1.2], dtype=np.float32 + ).reshape([1, 6]) + + input_placeholder_dict = { + "x": mb.placeholder(shape=x_val.shape), + "transform_matrix": mb.placeholder(shape=transform_matrix_val.shape), + } + input_value_dict = {"x": x_val, "transform_matrix": transform_matrix_val} + + def build(x, transform_matrix): + return [ + mb.affine( + x=x, + transform_matrix=transform_matrix, + output_height=3, + output_width=3, + sampling_mode="bilinear", + padding_mode="constant", + padding_value=0.0, + coordinates_mode="normalized_minus_one_to_one", + align_corners=True, + ), + mb.affine( + x=x, + transform_matrix=transform_matrix, + output_height=2, + output_width=5, + sampling_mode="bilinear", + padding_mode="constant", + padding_value=0.0, + coordinates_mode="normalized_minus_one_to_one", + align_corners=True, + ), + ] + + expected_output_types = [ + (1, 1, 3, 3, types.fp32), + (1, 1, 2, 5, types.fp32), + ] + expected_outputs = [ + np.array( + [10.752501, 2.5025, 0.0, 1.9799997, 0.0, 0.0, 0.0, 0.0, 0.0], + dtype=np.float32, + ).reshape([1, 1, 3, 3]), + np.array( + [10.752501, 5.94, 2.5025, 0.44000006, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + dtype=np.float32, + ).reshape([1, 1, 2, 5]), + ] + + run_compare_builder( + build, + input_placeholder_dict, + input_value_dict, + expected_output_types, + expected_outputs, + use_cpu_only=use_cpu_only, + backend=backend, + ) + + +class TestResample: + @pytest.mark.parametrize( + "use_cpu_only, backend", itertools.product([True, False], backends,) + ) + def test_builder_to_backend_smoke(self, use_cpu_only, backend): + if backend == "neuralnetwork": + 
pytest.xfail("nn backend not supported") + + x_ = np.array([11.0, 22.0, 33.0, 44.0], dtype=np.float32).reshape([1, 1, 2, 2]) + coordinates_ = np.array( + [-1.0, -2.0, -3.7, -1.0, 0.0, 0.0, 3.5, 1.2], dtype=np.float32 + ).reshape([1, 2, 2, 2]) + + input_placeholder_dict = { + "x": mb.placeholder(shape=x_.shape), + "coordinates": mb.placeholder(shape=coordinates_.shape), + } + input_value_dict = {"x": x_, "coordinates": coordinates_} + expected_output_type = (1, 1, 2, 2, types.fp32) + + def build_0(x, coordinates): + return mb.resample( + x=x, + coordinates=coordinates, + sampling_mode="bilinear", + padding_mode="constant", + padding_value=6.17, + coordinates_mode="normalized_minus_one_to_one", + align_corners=True, + ) + + expected_output_0 = np.array( + [8.585, 6.17, 27.5, 6.17], dtype=np.float32 + ).reshape(expected_output_type[:-1]) + + def build_1(x, coordinates): + return mb.resample( + x=x, + coordinates=coordinates, + sampling_mode="nearest", + padding_mode="border", + padding_value=-1.0, + coordinates_mode="unnormalized", + align_corners=False, + ) + + expected_output_1 = np.array( + [11.0, 11.0, 11.0, 44.0], dtype=np.float32 + ).reshape(expected_output_type[:-1]) + + def build_2(x, coordinates): + return mb.resample( + x=x, + coordinates=coordinates, + sampling_mode="bilinear", + padding_mode="reflection", + padding_value=-1.0, + coordinates_mode="normalized_zero_to_one", + align_corners=True, + ) + + expected_output_2 = np.array( + [22.0, 36.3, 11.0, 34.1], dtype=np.float32 + ).reshape(expected_output_type[:-1]) + + def build_3(x, coordinates): + return mb.resample( + x=x, + coordinates=coordinates, + sampling_mode="nearest", + padding_mode="symmetric", + padding_value=-1.0, + coordinates_mode="normalized_zero_to_one", + align_corners=False, + ) + + expected_output_3 = np.array( + [22.0, 33.0, 11.0, 33.0], dtype=np.float32 + ).reshape(expected_output_type[:-1]) + + for build, expected_output in zip( + [build_0, build_1, build_2, build_3], + [ + expected_output_0, + expected_output_1, + expected_output_2, + expected_output_3, + ], + ): + run_compare_builder( + build, + input_placeholder_dict, + input_value_dict, + expected_output_type, + expected_output, + use_cpu_only=use_cpu_only, + backend=backend, + ) + + +class TestResizeNearestNeighbor: + @pytest.mark.parametrize( + "use_cpu_only, backend", itertools.product([True, False], backends) + ) + def test_builder_to_backend_smoke(self, use_cpu_only, backend): + x_val = np.array([0.37, 6.17], dtype=np.float32).reshape([1, 1, 2, 1]) + input_placeholder_dict = {"x": mb.placeholder(shape=x_val.shape)} + input_value_dict = {"x": x_val} + + def build_model(x): + return [ + mb.resize_nearest_neighbor( + x=x, target_size_height=2, target_size_width=1, + ), + mb.resize_nearest_neighbor( + x=x, target_size_height=2, target_size_width=3, + ), + ] + + expected_output_types = [ + (1, 1, 2, 1, types.fp32), + (1, 1, 2, 3, types.fp32), + ] + expected_outputs = [ + x_val, + np.array([0.37, 0.37, 0.37, 6.17, 6.17, 6.17], dtype=np.float32).reshape( + [1, 1, 2, 3] + ), + ] + + run_compare_builder( + build_model, + input_placeholder_dict, + input_value_dict, + expected_output_types, + expected_outputs, + use_cpu_only=use_cpu_only, + backend=backend, + ) + + +class TestUpsampleNearestNeighborFractionalScales: + @pytest.mark.parametrize( + "use_cpu_for_conversion, backend", itertools.product([True, False], backends) + ) + def test_builder_to_backend_smoke(self, use_cpu_for_conversion, backend): + if backend == "neuralnetwork": + pytest.xfail("nn backend 
not supported") + + if backend == "mlprogram" and not use_cpu_for_conversion: + pytest.xfail("rdar://78343225 ((MIL GPU) Core ML Tools Unit Test failures [numerical error])") + + x_val = np.array([1.5, -2.5, 3.5], dtype=np.float32).reshape([1, 1, 1, 3]) + input_placeholder_dict = {"x": mb.placeholder(shape=x_val.shape)} + input_value_dict = {"x": x_val} + + def build(x): + return [ + mb.upsample_nearest_neighbor( + x=x, scale_factor_height=1.0, scale_factor_width=1.0, + ), + mb.upsample_nearest_neighbor( + x=x, scale_factor_height=3.17, scale_factor_width=0.67 + ), + mb.upsample_nearest_neighbor( + x=x, scale_factor_height=2.0, scale_factor_width=1.12, + ), + ] + + expected_output_types = [ + (1, 1, 1, 3, types.fp32), + (1, 1, 3, 2, types.fp32), + (1, 1, 2, 3, types.fp32), + ] + expected_outputs = [ + x_val, + np.array([1.5, -2.5, 1.5, -2.5, 1.5, -2.5], dtype=np.float32).reshape( + [1, 1, 3, 2] + ), + np.array([1.5, -2.5, 3.5, 1.5, -2.5, 3.5], dtype=np.float32).reshape( + [1, 1, 2, 3] + ), + ] + + run_compare_builder( + build, + input_placeholder_dict, + input_value_dict, + expected_output_types, + expected_outputs, + use_cpu_only=use_cpu_for_conversion, + backend=backend, + use_cpu_for_conversion=use_cpu_for_conversion, + ) + + class TestResizeBilinear: @pytest.mark.parametrize( "use_cpu_only, backend", itertools.product([True, False], backends,) @@ -151,7 +420,7 @@ def build_upsample_integer(x): def build_upsample_fractional(x): return mb.upsample_bilinear( - x=x, scale_factor_height=1, scale_factor_width=2.6, align_corners=False + x=x, scale_factor_height=1.0, scale_factor_width=2.6, align_corners=False ) expected_output_type = (1, 1, 5, types.fp32) @@ -171,8 +440,7 @@ def build_upsample_fractional(x): ) - # TODO: enable GPU test: rdar://problem/60309338 - @pytest.mark.skip("Broken for mil backend rdar://problem/66964398") + @pytest.mark.xfail(reason="rdar://66964398, failing on both NNv1 and MIL", run=True) @pytest.mark.skipif(not testing_reqs._HAS_TORCH, reason="PyTorch not installed.") @pytest.mark.parametrize( "use_cpu_only, backend, input_shape, scale_factor, align_corners", @@ -180,7 +448,7 @@ def build_upsample_fractional(x): [True], backends, [(2, 5, 10, 22)], - [(3, 4), (2.5, 2), (0.5, 0.75)], + [(3, 4), (2.5, 2.0), (0.5, 0.75)], [True, False], ), ) diff --git a/coremltools/converters/mil/mil/ops/tests/test_normalization.py b/coremltools/converters/mil/mil/ops/tests/test_normalization.py index cbd6a2850..5cbe731aa 100644 --- a/coremltools/converters/mil/mil/ops/tests/test_normalization.py +++ b/coremltools/converters/mil/mil/ops/tests/test_normalization.py @@ -316,25 +316,25 @@ class TestNormalizationLayerNorm: @staticmethod def _keras_layer_norm( x, axes, epsilon): - layer = tf.keras.layers.LayerNormalization(axis=axes, epsilon=epsilon) - data = tf.constant(x, dtype=tf.float32) - output = layer(data) - return output.numpy() + layer = tf.keras.layers.LayerNormalization(axis=axes, epsilon=epsilon) + data = tf.constant(x, dtype=tf.float32) + output = layer(data) + return output.numpy() @staticmethod def _np_layer_norm(x, axes, gamma=None, beta=None, epsilon=1e-5): - rank = len(x.shape) - axes = [axis + rank if axis < 0 else axis for axis in axes] - normalized_shape = [x.shape[i] if i in axes else 1 for i in range(rank)] - gamma = np.ones(shape=normalized_shape) if gamma is None else np.reshape(gamma, normalized_shape) - beta = np.zeros(shape=normalized_shape) if beta is None else np.reshape(beta, normalized_shape) - num = x - np.mean(x, axis=tuple(axes), keepdims=True) - dem = 
np.sqrt(
-            np.sum(np.square(num), axis=tuple(axes), keepdims=True)
-            / np.prod(normalized_shape)
-            + epsilon
-        )
-        return num / dem * gamma + beta
+        rank = len(x.shape)
+        axes = [axis + rank if axis < 0 else axis for axis in axes]
+        normalized_shape = [x.shape[i] if i in axes else 1 for i in range(rank)]
+        gamma = np.ones(shape=normalized_shape) if gamma is None else np.reshape(gamma, normalized_shape)
+        beta = np.zeros(shape=normalized_shape) if beta is None else np.reshape(beta, normalized_shape)
+        num = x - np.mean(x, axis=tuple(axes), keepdims=True)
+        dem = np.sqrt(
+            np.sum(np.square(num), axis=tuple(axes), keepdims=True)
+            / np.prod(normalized_shape)
+            + epsilon
+        )
+        return num / dem * gamma + beta

     @pytest.mark.parametrize(
         "use_cpu_only, backend", itertools.product([True, False], backends,)
@@ -343,6 +343,8 @@ def test_builder_to_backend_smoke(self, use_cpu_only, backend):
         x_val = np.array([[[1.0, -7.0], [5.0, -6.0], [-3.0, -5.0]]], dtype=np.float32)
         input_placeholders = {"x": mb.placeholder(shape=x_val.shape)}
         input_values = {"x": x_val}
+        gamma_val = np.array([1.0, 1.0], dtype=np.float32)
+        beta_val = np.array([1.0, 0.0], dtype=np.float32)

         def build(x):
             return [
@@ -350,9 +352,11 @@ def build(x):
                 mb.layer_norm(x=x, axes=[2], epsilon=1e-4),
                 # V2->V1 lowering (op_mappings.py): else branch
                 mb.layer_norm(x=x, axes=[-2, -1], epsilon=1e-4),
+                # V2->V1 lowering (op_mappings.py): if branch with scale
+                mb.layer_norm(x=x, axes=[2], epsilon=1e-4, gamma=gamma_val, beta=beta_val),
             ]

-        expected_output_types = [(1, 3, 2, types.fp32), (1, 3, 2, types.fp32)]
+        expected_output_types = [(1, 3, 2, types.fp32), (1, 3, 2, types.fp32), (1, 3, 2, types.fp32)]
         expected_outputs = [
             np.array(
                 [
@@ -374,6 +378,67 @@ def build(x):
                 ],
                 dtype=np.float32,
             ),
+            np.array(
+                [
+                    [
+                        [ 1.9999969,  -0.9999969 ],
+                        [ 1.99999833, -0.99999833],
+                        [ 1.99995005, -0.99995005],
+                    ]
+                ],
+                dtype=np.float32,
+            ),
+        ]
+
+        run_compare_builder(
+            build,
+            input_placeholders,
+            input_values,
+            expected_output_types,
+            expected_outputs,
+            use_cpu_only=use_cpu_only,
+            backend=backend,
+        )
+
+    @pytest.mark.parametrize(
+        "use_cpu_only, backend", itertools.product([True, False], backends,)
+    )
+    def test_builder_to_backend_smoke_rank_2(self, use_cpu_only, backend):
+        x_val = np.array([[1.0, -7.0], [5.0, -6.0], [-3.0, -5.0]], dtype=np.float32)
+        gamma_val = np.array([1.0, 1.0], dtype=np.float32)
+        beta_val = np.array([1.0, 0.0], dtype=np.float32)
+        input_placeholders = {"x": mb.placeholder(shape=x_val.shape)}
+        input_values = {"x": x_val}
+
+        def build(x):
+            return [
+                # V2->V1 lowering (op_mappings.py): if branch
+                mb.layer_norm(x=x, axes=[1], epsilon=1e-4),
+                mb.layer_norm(x=x, axes=[1], epsilon=1e-4, gamma=gamma_val, beta=beta_val)
+            ]
+
+        expected_output_types = [(3, 2, types.fp32), (3, 2, types.fp32)]
+        expected_outputs = [
+            np.array(
+                [
+                    [ 0.9999969,  -0.9999969 ],
+                    [ 0.99999833, -0.99999833],
+                    [ 0.99995005, -0.99995005],
+                ],
+                dtype=np.float32,
+            ),
+            np.array(
+                [
+                    [ 1.9999969,  -0.9999969 ],
+                    [ 1.99999833, -0.99999833],
+                    [ 1.99995005, -0.99995005],
+                ],
+                dtype=np.float32,
+            ),
         ]

         run_compare_builder(
diff --git a/coremltools/converters/mil/mil/ops/tests/test_random.py b/coremltools/converters/mil/mil/ops/tests/test_random.py
index 462d084a6..36f78c10e 100644
--- a/coremltools/converters/mil/mil/ops/tests/test_random.py
+++ b/coremltools/converters/mil/mil/ops/tests/test_random.py
@@ -143,9 +143,6 @@ def build(x):
         )

-    @pytest.mark.xfail(
-        reason="rdar://78080118 re-enable test once rdar://78079222
(one-hot) is resolved.", - ) @pytest.mark.parametrize( "use_cpu_only, backend, n_sample, n_class", itertools.product([True, False], backends, [50000], [2, 10, 20]), diff --git a/coremltools/converters/mil/mil/ops/tests/test_recurrent.py b/coremltools/converters/mil/mil/ops/tests/test_recurrent.py index 8c3cab6d6..3fd28efb7 100644 --- a/coremltools/converters/mil/mil/ops/tests/test_recurrent.py +++ b/coremltools/converters/mil/mil/ops/tests/test_recurrent.py @@ -37,7 +37,6 @@ class TestGRU: # output(always 0) for second batch onwards [2, 32], [1, 16], - # rdar://66661491 (GRU with bias fails on NNv1 and MIL backend) [True, False], [True, False], ["forward", "reverse"], @@ -334,7 +333,6 @@ def build(x): use_cpu_only=use_cpu_only, frontend_only=False, backend=backend, - # rdar://63839623 ([GITLAB-CI] precision issue on various tests on gitlab ci) atol=1e-3, rtol=1e-3, ) @@ -379,9 +377,6 @@ def test_builder_to_backend_smoke_unilstm( direction, symbolic, ): - # TODO: [MIL] LSTM layer- Implement eval and tf register routine - # Testing 1. peephole values - # 2. clip values torch.manual_seed(50) rnn = torch.nn.LSTM(input_size, hidden_size, 1, bias=has_bias) @@ -489,8 +484,7 @@ def build(x, initial_h, initial_c): ], argvalues=itertools.product( [True], - # TODO: rdar://66768742 (BiLSTM output numerically mismatch for MIL backend) - ["neuralnetwork"], + backends, [1, 8], [1, 32], [1, 64], diff --git a/coremltools/converters/mil/mil/ops/tests/test_scatter_gather.py b/coremltools/converters/mil/mil/ops/tests/test_scatter_gather.py index eccd8cb8c..fcdcfbde5 100644 --- a/coremltools/converters/mil/mil/ops/tests/test_scatter_gather.py +++ b/coremltools/converters/mil/mil/ops/tests/test_scatter_gather.py @@ -225,7 +225,6 @@ def build(data, indices, updates): class TestScatterNd: - # TODO: [MIL] Scatter and ScatterNd in tensoflow @pytest.mark.parametrize( "use_cpu_only, backend", itertools.product([True, False], backends,) ) diff --git a/coremltools/converters/mil/mil/ops/tests/test_tensor_operation.py b/coremltools/converters/mil/mil/ops/tests/test_tensor_operation.py index 824a0f3c8..26c547209 100644 --- a/coremltools/converters/mil/mil/ops/tests/test_tensor_operation.py +++ b/coremltools/converters/mil/mil/ops/tests/test_tensor_operation.py @@ -292,10 +292,6 @@ def build(shape): class TestNonMaximumSuppression: - @pytest.mark.xfail( - condition=backends[0] == "mlprogram", - reason="Investigate failure rdar://78630549", - ) @pytest.mark.parametrize( "use_cpu_only, backend", itertools.product([True, False], backends,) ) @@ -492,7 +488,6 @@ def _ref_non_maximum_suppression( return out1, out2, out3, out4 - @pytest.mark.xfail(reason="rdar://60390856", run=False) @pytest.mark.parametrize( ",".join( [ @@ -528,6 +523,9 @@ def test_builder_to_backend_stress( n_score, per_class_suppression, ): + if backend == "mlprogram" and iou_threshold_percentile == 0: + pytest.xfail("rdar://78080118") + n_boxes_in, n_boxes_out = n_boxes boxes_val = random_gen((n_batch, n_boxes_in, 4), 0, 100) scores_val = random_gen((n_batch, n_boxes_in, n_score), -100, 100) diff --git a/coremltools/converters/mil/mil/ops/tests/test_tensor_transformation.py b/coremltools/converters/mil/mil/ops/tests/test_tensor_transformation.py index c3269bb7a..444eab709 100644 --- a/coremltools/converters/mil/mil/ops/tests/test_tensor_transformation.py +++ b/coremltools/converters/mil/mil/ops/tests/test_tensor_transformation.py @@ -510,8 +510,6 @@ class TestSliceBySize: "use_cpu_only, backend", itertools.product([True, False], backends,) ) def 
test_builder_to_backend_smoke(self, use_cpu_only, backend):
-        if backend != "nn_poto":
-            pytest.xfail("TODO: activate after rdar://75290346 is fixed and is in the build. Tracked by rdar://75823380")
         x_val = np.array(list(range(24))).reshape((2, 3, 4)).astype(np.float32)
         begin_val = np.array([1, 1, 1], dtype=np.int32)
         input_placeholders = {
diff --git a/coremltools/converters/mil/mil/passes/apply_common_pass_pipeline.py b/coremltools/converters/mil/mil/passes/apply_common_pass_pipeline.py
index 75da73b36..b3e0e9a73 100644
--- a/coremltools/converters/mil/mil/passes/apply_common_pass_pipeline.py
+++ b/coremltools/converters/mil/mil/passes/apply_common_pass_pipeline.py
@@ -74,6 +74,7 @@ def _apply(passes, name="common"):
         "common::noop_elimination",
         "common::dedup_op_and_var_names",
         "common::reduce_transposes",  # fuse_layernorm_or_instancenorm can potentially add transposes
+        "common::topological_reorder",
         "common::dead_code_elimination",  # always end with dce
     ]

diff --git a/coremltools/converters/mil/mil/passes/cast_optimization.py b/coremltools/converters/mil/mil/passes/cast_optimization.py
index d6b22a998..036266e04 100644
--- a/coremltools/converters/mil/mil/passes/cast_optimization.py
+++ b/coremltools/converters/mil/mil/passes/cast_optimization.py
@@ -119,7 +119,7 @@ def try_to_transform(root_op, cached_vars):
     block = root_op.enclosing_block

     # Scenario: Redundant cast when the source and destination dtype are the same.
-    if root_op.op_type == "cast" and root_op.x.is_tensor_of(dtype=root_op.dtype.val):
+    if root_op.op_type == "cast" and root_op.x.is_tensor_or_scalar_of(dtype=root_op.dtype.val):
         block.replace_uses_of_var_after_op(
             anchor_op=root_op,
             old_var=root_op.outputs[0],
@@ -144,7 +144,7 @@ def try_to_transform(root_op, cached_vars):

     fused_output_var_name = cast_1.x.name + "_to_{}".format(cast_2.dtype.val)

-    if cast_1.x.is_tensor_of(dtype=cast_2.dtype.val):
+    if cast_1.x.is_tensor_or_scalar_of(dtype=cast_2.dtype.val):
         # when consecutive casts cancel each other
         # Please check out: test_linear_consecutive_cast_ops_cancellation in test_cast_optimization.py
         new_output_var = cast_1.x
diff --git a/coremltools/converters/mil/mil/passes/divide_to_multiply.py b/coremltools/converters/mil/mil/passes/divide_to_multiply.py
index 65f441cda..1d894142b 100644
--- a/coremltools/converters/mil/mil/passes/divide_to_multiply.py
+++ b/coremltools/converters/mil/mil/passes/divide_to_multiply.py
@@ -8,6 +8,7 @@

 from coremltools.converters.mil.mil.passes.pass_registry import register_pass
 from coremltools.converters.mil.mil import Builder as mb
+from coremltools.converters.mil.mil import types as _types


 def divide_to_multiply_block(block):
@@ -18,7 +19,12 @@
             # This op can't be a divide.
             continue

-        if op.op_type == "real_div" and op.y.val is not None:
+        # If real_div has an integer input, the result is an integer (following the
+        # TensorFlow spec). Hence this pass needs to be disabled when the input is
+        # not float, since it translates y to a floating-point number: if x or y was
+        # originally an integer and y becomes a floating-point number, the original
+        # type signature (with integer output) would not be preserved.
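+        # For example, with int32 inputs 7 / 2 == 3, while the rewritten form
+        # 7 * (1.0 / 2) == 3.5 would change both the value and the output dtype.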
+ if op.op_type == "real_div" and op.y.val is not None and _types.is_float(op.x.dtype): with block: x = mb.mul( x=op.x, y=1.0 / op.y.val, name="_inversed_" + op.name, before_op=op diff --git a/coremltools/converters/mil/mil/passes/noop_elimination.py b/coremltools/converters/mil/mil/passes/noop_elimination.py index 7fd666e2a..40d65469e 100644 --- a/coremltools/converters/mil/mil/passes/noop_elimination.py +++ b/coremltools/converters/mil/mil/passes/noop_elimination.py @@ -79,6 +79,21 @@ def remove_linear(op, block): block.remove_ops([op]) return True +def remove_transpose(op, block): + perm = np.sort(op.perm.val) + if (perm != op.perm.val).any(): + return False + + input_var = op.x + input_op = input_var.op + + op.enclosing_block.replace_uses_of_var_after_op( + anchor_op=input_op, old_var=op.outputs[0], new_var=input_var + ) + + # Remove all the ops at once + block.remove_ops([op]) + return True _SUPPORTED_OPS = { "add", @@ -93,6 +108,7 @@ def remove_linear(op, block): "slice_by_size", "pad", "tile", + "transpose", "upsample_nearest_neighbor", "upsample_bilinear", "resize_bilinear", @@ -112,6 +128,7 @@ def remove_linear(op, block): "slice_by_size": remove_same_shape, "pad": remove_same_shape, "tile": remove_same_shape, + "transpose": remove_transpose, "upsample_nearest_neighbor": remove_same_shape, "upsample_bilinear": remove_same_shape, "resize_bilinear": remove_same_shape, diff --git a/coremltools/converters/mil/mil/passes/quantization_passes.py b/coremltools/converters/mil/mil/passes/quantization_passes.py index 601ac5394..3b940e918 100644 --- a/coremltools/converters/mil/mil/passes/quantization_passes.py +++ b/coremltools/converters/mil/mil/passes/quantization_passes.py @@ -151,17 +151,17 @@ def transform_op(self, op): casted_inputs[param] = list(inputs[:]) for i, var in enumerate(inputs): # Second loop, iterates over all the vars of a python list corresponding to an input parameter. 
-                    if not (var.is_tensor_of(dtype="fp32") or var.is_scalar_of(dtype="fp32")):
+                    if not var.is_tensor_or_scalar_of(dtype="fp32"):
                         continue
                     inputs_modified = True
                     with block:
                         casted_var_name = var.name + "_to_fp16"
-                        if len(var._child_ops) > 1 and casted_var_name in self.cache_vars:
+                        if len(var._child_ops) > 1 and casted_var_name in self.cache_vars and (self.cache_vars[casted_var_name] in block._visible_vars_in_block()[1]):
                             casted_inputs[param][i] = self.cache_vars[casted_var_name]
                         else:
                             x = mb.cast(
                                 x=var, dtype="fp16", name=casted_var_name, before_op=op
                             )
                             casted_inputs[param][i] = x
                             if len(var._child_ops) > 1:
@@ -183,8 +183,8 @@ def transform_op(self, op):
                 quant_output = [quant_output]

             for old_output_var, new_output_var in zip(op.outputs, quant_output):
-                if old_output_var.is_tensor_of(dtype="fp32") and (
-                    not new_output_var.is_tensor_of(dtype="fp32")
+                if old_output_var.is_tensor_or_scalar_of(dtype="fp32") and (
+                    not new_output_var.is_tensor_or_scalar_of(dtype="fp32")
                 ):
                     with block:
                         x = mb.cast(
diff --git a/coremltools/converters/mil/mil/passes/test_noop_elimination.py b/coremltools/converters/mil/mil/passes/test_noop_elimination.py
index 4c2110254..748540b92 100644
--- a/coremltools/converters/mil/mil/passes/test_noop_elimination.py
+++ b/coremltools/converters/mil/mil/passes/test_noop_elimination.py
@@ -45,7 +45,14 @@ def prog(x):
     elif op_type in {'mul'}:
         if val == 1 or val == [1, 1, 1, 1]:
             new_program = ["relu"]
-    elif op_type in {'pow', 'real_div', 'floor_div'}:
+    elif op_type in {'real_div'}:
+        # TODO(rdar://79925291): Remove this branch and add `real_div` to the
+        # following elif once fp32 casts for `real_div` are no longer required.
+        original_program = ["cast"] + original_program
+        new_program = original_program
+        if pos == 'y' and (val == 1 or val == [1, 1, 1, 1]):
+            new_program = ["cast", "relu"]
+    elif op_type in {'pow', 'floor_div'}:
         if pos == 'y' and (val == 1 or val == [1, 1, 1, 1]):
             new_program = ["relu"]
     elif op_type in {'sub'}:
@@ -379,3 +386,21 @@ def prog(x):
     )


+def test_transpose_elimination():
+    @mb.program(input_specs=[mb.TensorSpec(shape=(2, 3, 4))])
+    def prog(x):
+        r1 = mb.transpose(x=x, perm=[0, 1, 2])
+        return mb.relu(x=r1)
+
+    prev_prog, prev_block, block = apply_pass_and_basic_check(
+        prog, "common::noop_elimination"
+    )
+    assert get_op_types_in_program(prev_prog) == ["transpose", "relu"]
+    assert get_op_types_in_program(prog) == ["relu"]
+    assert_model_is_valid(
+        prog,
+        {"x": (2, 3, 4)},
+        expected_output_shapes={block.outputs[0].name: (2, 3, 4)},
+    )
+
diff --git a/coremltools/converters/mil/mil/passes/test_passes.py b/coremltools/converters/mil/mil/passes/test_passes.py
index c61e97c30..5da2f881f 100644
--- a/coremltools/converters/mil/mil/passes/test_passes.py
+++ b/coremltools/converters/mil/mil/passes/test_passes.py
@@ -23,8 +23,6 @@

 validate_model = True

-# TODO: rdar://58993652 (Add recursive block test cases for graph pass tests)
-

 def test_const_elimination():
     @mb.program(input_specs=[mb.TensorSpec(shape=(2, 4))])
     def prog(x):
@@ -445,3 +443,34 @@ def prog(x):
         expected_output_shapes={block.outputs[0].name: (3, 5, 6)},
     )

+class TestTopologicalReorder:
+
+    def test_move_sink_casts_to_the_end(self):
+        @mb.program(input_specs=[mb.TensorSpec(shape=(10, 20))])
+        def prog(x):
+            x = mb.cast(x=x, dtype="fp16")
+            x1 = mb.square(x=x)
+            x2 = mb.cast(x=x1, dtype="fp32")
+            x3 = mb.log(x=x)
+            x4 = mb.cast(x=x3, dtype="fp32")
+            x5 = mb.relu(x=x)
+            x6 = mb.cast(x=x5, dtype="fp32")
+            x7 = mb.relu(x=x6)
+            return x2, x4, x7
+
+        assert get_op_types_in_program(prog) == ['cast', 'square', 'cast', 'log', 'cast', 'relu', 'cast', 'relu']
+
+        apply_pass_and_basic_check(prog, "common::topological_reorder")
+        _, _, block = apply_pass_and_basic_check(prog, "common::dead_code_elimination")
+
+        assert get_op_types_in_program(prog) == ["cast", "square", "log", "relu", "cast", "cast", "cast", "relu"]
+
+        assert_model_is_valid(
+            prog,
+            {"x": (10, 20)},
+            expected_output_shapes={
+                block.outputs[0].name: (10, 20),
+                block.outputs[1].name: (10, 20),
+                block.outputs[2].name: (10, 20),
+            },
+        )
diff --git a/coremltools/converters/mil/mil/passes/topological_reorder.py b/coremltools/converters/mil/mil/passes/topological_reorder.py
new file mode 100644
index 000000000..86deeda24
--- /dev/null
+++ b/coremltools/converters/mil/mil/passes/topological_reorder.py
@@ -0,0 +1,37 @@
+# -*- coding: utf-8 -*-
+
+# Copyright (c) 2021, Apple Inc. All rights reserved.
+#
+# Use of this source code is governed by a BSD-3-clause license that can be
+# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
+
+from coremltools.converters.mil.mil.passes.pass_registry import register_pass
+
+def _is_sink(op):
+    return sum(len(output._child_ops) for output in op.outputs) == 0
+
+def _topological_reorder_block(block):
+    sink_nodes = []
+    other_ordered_operations = []
+    for op in block.operations:
+        for b in op.blocks:
+            _topological_reorder_block(b)
+
+        if _is_sink(op):
+            sink_nodes.append(op)
+        else:
+            other_ordered_operations.append(op)
+
+    block.operations = other_ordered_operations + sink_nodes
+
+@register_pass(namespace="common")
+def topological_reorder(prog):
+    """
+    Topologically reorder the list of operations in a program by moving all
+    sink nodes to the end of that list.
+
+    See test_move_sink_casts_to_the_end in test_passes.py::TestTopologicalReorder.
+    """
+    for f in prog.functions.values():
+        _topological_reorder_block(f)
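+# Illustrative usage (assuming the standard registry lookup used by the pass
+# pipeline); a "sink" here is an op whose outputs feed no other op:
+#
+#   from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY
+#   PASS_REGISTRY["common::topological_reorder"](prog)
+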
diff --git a/coremltools/converters/mil/mil/var.py b/coremltools/converters/mil/mil/var.py
index c3f0a9c36..6494867cf 100644
--- a/coremltools/converters/mil/mil/var.py
+++ b/coremltools/converters/mil/mil/var.py
@@ -184,11 +184,8 @@ def type_str(self):
     def set_name(self, name):
         self.name = name

-    def is_tensor_of(self, dtype: str):
-        return types.is_tensor(self.sym_type) and builtin_to_string(self.dtype) == dtype
-
-    def is_scalar_of(self, dtype: str):
-        return types.is_scalar(self.sym_type) and builtin_to_string(self.dtype) == dtype
+    def is_tensor_or_scalar_of(self, dtype: str):
+        return (types.is_tensor(self.sym_type) or types.is_scalar(self.sym_type)) and builtin_to_string(self.dtype) == dtype

     def __str__(self):
         return "%" + self.name + ": " + self.shape_str() + self.type_str()
diff --git a/coremltools/converters/onnx/_tests/test_pytorch_model.py b/coremltools/converters/onnx/_tests/test_pytorch_model.py
index f0391a3ae..0501f53f3 100644
--- a/coremltools/converters/onnx/_tests/test_pytorch_model.py
+++ b/coremltools/converters/onnx/_tests/test_pytorch_model.py
@@ -282,7 +282,7 @@ def forward(self, x):
         torch_model.train(False)
         _test_torch_model_single_io(torch_model, (3, 2, 3), (3, 2, 3))  # type: ignore

-    @pytest.mark.skip(reason="rdar://64224329")
+    @pytest.mark.skip(reason="")
     @unittest.skipIf(
         _macos_version() < MIN_MACOS_VERSION_10_15,
         "macOS 10.15+ required.
Skipping test.", @@ -301,7 +301,7 @@ def forward(self, x): torch_model.train(False) _test_torch_model_single_io(torch_model, (3, 1, 256), (3, 1, 256), minimum_ios_deployment_target="13") # type: ignore - @pytest.mark.skip(reason="rdar://64224329") + @pytest.mark.skip(reason="") @unittest.skipIf( _macos_version() < MIN_MACOS_VERSION_10_15, "macOS 10.15+ required. Skipping test.", @@ -322,7 +322,7 @@ def forward(self, x): torch_model.train(False) _test_torch_model_single_io(torch_model, (3, 1, 256), (3, 1, 256), minimum_ios_deployment_target="13") # type: ignore - @pytest.mark.skip(reason="rdar://64224329") + @pytest.mark.skip(reason="") @unittest.skipIf( _macos_version() < MIN_MACOS_VERSION_10_15, "macOS 10.15+ required. Skipping test.", @@ -995,7 +995,7 @@ class TransformationTests(unittest.TestCase): _macos_version() < MIN_MACOS_VERSION_10_15, "macOS 10.15+ required. Skipping test.", ) - @pytest.mark.skip(reason="test failure: ") + @pytest.mark.skip(reason="") def test_cast_removal_transformation(self, minimum_ios_deployment_target="13"): torch_model = nn.Upsample(scale_factor=2) torch_model.train(False) diff --git a/coremltools/models/model.py b/coremltools/models/model.py index 65ddeb8cd..ebe938607 100644 --- a/coremltools/models/model.py +++ b/coremltools/models/model.py @@ -12,7 +12,7 @@ from .utils import _has_custom_layer as _has_custom_layer from .utils import load_spec as _load_spec -from .utils import _macos_version as _macos_version +from .utils import _macos_version, _is_macos from .utils import save_spec as _save_spec from ..proto import Model_pb2 as _Model_pb2 @@ -97,7 +97,7 @@ def __iter__(self): yield f.name -def _get_proxy_and_spec(filename, use_cpu_only=False): +def _get_proxy_and_spec(filename, use_cpu_only=False, skip_model_load=False): try: from ..libcoremlpython import _MLModelProxy except Exception: @@ -106,7 +106,7 @@ def _get_proxy_and_spec(filename, use_cpu_only=False): filename = _os.path.expanduser(filename) specification = _load_spec(filename) - if _MLModelProxy: + if _MLModelProxy and not skip_model_load: # check if the version is supported engine_version = _MLModelProxy.maximum_supported_specification_version() @@ -180,7 +180,11 @@ class MLModel(object): predict """ - def __init__(self, model, useCPUOnly=False, is_temp_package=False, mil_program=None): + def __init__(self, model, + useCPUOnly=False, + is_temp_package=False, + mil_program=None, + skip_model_load=False): """ Construct an MLModel from a .mlmodel @@ -203,7 +207,16 @@ def __init__(self, model, useCPUOnly=False, is_temp_package=False, mil_program=N mil_program : coremltools.converters.mil.Program Set to the mil program object, if available. It is avaiable whenever an MLModel object is constructed using - the unified converter API coremltools.convert() + the unified converter API coremltools.convert() + + skip_model_load : bool + Set to True to prevent coremltools from calling into the Core ML framework + to compile and load the model. In that case, the returned model object cannot + be used to make a prediction. This flag may be used to load a newer model + type on an older Mac, to inspect or load/save the spec. + Example: Loading ML Program model type on a macOS 11, since ML Program can only be + compiled and loaded from macOS12+. + Defaults to False. 
        Notes
        -----
@@ -233,13 +246,13 @@
             self.package_path = model
             self.is_temp_package = is_temp_package
             self.__proxy__, self._spec, self._framework_error = _get_proxy_and_spec(
-                model, useCPUOnly
+                model, useCPUOnly, skip_model_load=skip_model_load,
             )
         elif isinstance(model, _Model_pb2.Model):
             filename = _tempfile.mktemp(suffix=_MLMODEL_EXTENSION)
             _save_spec(model, filename)
             self.__proxy__, self._spec, self._framework_error = _get_proxy_and_spec(
-                filename, useCPUOnly
+                filename, useCPUOnly, skip_model_load=skip_model_load,
             )
             try:
                 _os.remove(filename)
@@ -382,6 +395,11 @@ def predict(self, data, useCPUOnly=False, **kwargs):
         >>> predictions = model.predict(data)
         """
+        if self.is_package and _is_macos() and _macos_version() < (12, 0):
+            raise Exception(
+                "predict() for a .mlpackage model is not supported on macOS versions older than 12.0."
+            )
+
         if self.__proxy__:
             return self.__proxy__.predict(data, useCPUOnly)
         else:
diff --git a/coremltools/models/utils.py b/coremltools/models/utils.py
index 6d06f2401..6a061d27d 100644
--- a/coremltools/models/utils.py
+++ b/coremltools/models/utils.py
@@ -17,16 +17,18 @@
 from coremltools.proto import Model_pb2 as _Model_pb2

 from .._deps import _HAS_SKLEARN

-try:
-    from ..libmodelpackage import ModelPackage
-except ModuleNotFoundError:
-    pass

 _MLMODEL_EXTENSION = ".mlmodel"
 _MLPACKAGE_EXTENSION = ".mlpackage"

+try:
+    from ..libmodelpackage import ModelPackage
+except ModuleNotFoundError:
+    pass
+
 if _HAS_SKLEARN:
     import scipy.sparse as _sp
@@ -572,6 +574,33 @@ def rename_feature(
             rename_outputs or (index < len(spec.pipeline.models)),
         )

+    # Rename for mlProgram
+    if spec.HasField("mlProgram"):
+        from coremltools.converters.mil.backend.mil.helper import NameSanitizer
+        new_name_sanitized = NameSanitizer().sanitize_name(new_name)
+        if new_name != new_name_sanitized:
+            raise ValueError("Input/output names for ML Program must be of the format [a-zA-Z_][a-zA-Z0-9_]*. "
+                             "That is, it must start with a letter or underscore and contain only letters, "
+                             "numerals, and underscores. "
+                             "Provided feature name, \"{}\" does not satisfy these requirements.".format(new_name))
+        mil = spec.mlProgram
+        for function in mil.functions.values():
+            for name_value_type in function.inputs:
+                if name_value_type.name == current_name:
+                    name_value_type.name = new_name
+            for block in function.block_specializations.values():
+                for i, out_name in enumerate(block.outputs):
+                    if out_name == current_name:
+                        block.outputs[i] = new_name
+                for op in block.operations:
+                    for argument in op.inputs.values():
+                        for binding in argument.arguments:
+                            if binding.HasField("name"):
+                                if binding.name == current_name:
+                                    binding.name = new_name
+                    for name_value_type in op.outputs:
+                        if name_value_type.name == current_name:
+                            name_value_type.name = new_name
+

 def _sanitize_value(x):
     """
@@ -791,11 +820,12 @@ def _macos_version():
    version comparisons. On non-Macs, it returns an empty tuple.
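    For example, it returns ``(12, 0)`` on macOS 12.0, so callers can write
    version checks such as ``_macos_version() >= (12, 0)``.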
""" if _is_macos(): - import platform - - ver_str = platform.mac_ver()[0] - return tuple([int(v) for v in ver_str.split(".")]) - + try: + import subprocess + ver_str = subprocess.run(["sw_vers", "-productVersion"], stdout=subprocess.PIPE).stdout.decode('utf-8').strip('\n') + return tuple([int(v) for v in ver_str.split(".")]) + except: + raise Exception("Unable to detemine the macOS version") return () diff --git a/coremltools/test/api/test_api_examples.py b/coremltools/test/api/test_api_examples.py index d1d47eef8..87e743add 100644 --- a/coremltools/test/api/test_api_examples.py +++ b/coremltools/test/api/test_api_examples.py @@ -23,6 +23,9 @@ from shutil import rmtree from tempfile import mkdtemp +if _HAS_TORCH: + import torch + ############################################################################### # Note: all tests are also used as examples such as in readme.md as a reference @@ -191,9 +194,8 @@ def test_freeze_and_convert_matmul_graph(): mlmodel = ct.convert(frozen_graph_file) # optionally, you can save model to disk # mlmodel.save(frozen_graph_file.replace("pb", "mlmodel")) - import shutil try: - shutil.rmtree(model_dir) + rmtree(model_dir) except: pass @@ -490,7 +492,6 @@ def test_convert_torch_vision_mobilenet_v2(tmpdir): @staticmethod def test_int64_inputs(): - import torch num_tokens = 3 embedding_size = 5 @@ -569,7 +570,6 @@ def test_fully_dynamic_inputs(): All dims of the inputs are dynamic, and write to slice to one of the inputs. """ - import torch class Model(torch.nn.Module): def __init__(self, index): @@ -598,14 +598,14 @@ def forward(self, x, y): torch_res = model(x, y) results = mlmodel.predict({"x": x.cpu().detach().numpy(), "y": y.cpu().detach().numpy()}) - np.testing.assert_allclose(torch_res[0], results['y.3']) + np.testing.assert_allclose(torch_res[0], results['y.5']) np.testing.assert_allclose(torch_res[1], results['x']) x, y = torch.rand(1, 6), torch.rand(2, 3) torch_res = model(x, y) results = mlmodel.predict({"x": x.cpu().detach().numpy(), "y": y.cpu().detach().numpy()}) - np.testing.assert_allclose(torch_res[0], results['y.3']) + np.testing.assert_allclose(torch_res[0], results['y.5']) np.testing.assert_allclose(torch_res[1], results['x']) @@ -639,6 +639,7 @@ def prog(x): ) assert len(prediction) == 1 +@pytest.mark.skipif(not _HAS_TORCH, reason=MSG_TORCH_NOT_FOUND) class TestInvalidInput: @staticmethod def test_rank0_inputs_mil(): @@ -656,7 +657,6 @@ def test_rank0_inputs_torch(): """Similar to TestPyTorchConverterExamples::test_int64_inputs but using rank-0 int input. 
""" - import torch num_tokens = 3 embedding_size = 5 @@ -798,7 +798,6 @@ def test_tf2keras_outofbound_range_dim(use_symbol): "use_symbol", [True, False]) @pytest.mark.skipif(not _HAS_TORCH, reason=MSG_TORCH_NOT_FOUND) def test_torch_range_dim(use_symbol): - import torch num_tokens = 3 embedding_size = 5 @@ -858,8 +857,6 @@ def test_torch_range_dim_lstm(variable_length): """ This example shows how to run LSTM with previous hidden / cell states """ - import torch - import coremltools as ct input_size = 3 hidden_size = 2 @@ -953,7 +950,6 @@ def forward(self, x, hidden_state, cell_state): "use_symbol", [True, False]) @pytest.mark.skipif(not _HAS_TORCH, reason=MSG_TORCH_NOT_FOUND) def test_torch_outofbound_range_dim(use_symbol): - import torch num_tokens = 3 embedding_size = 5 @@ -1059,7 +1055,6 @@ def test_tf2keras_enumerated_shapes(): @staticmethod @pytest.mark.skipif(not _HAS_TORCH, reason=MSG_TORCH_NOT_FOUND) def test_torch_enumerated_shapes(): - import torch in_channels = 3 out_channels = 2 @@ -1134,7 +1129,6 @@ def test_tf2_image_enumerated_shapes(): @staticmethod @pytest.mark.skipif(not _HAS_TORCH, reason=MSG_TORCH_NOT_FOUND) def test_torch_image_enumerated_shapes(): - import torch import torchvision torch_model = torchvision.models.mobilenet_v2().features torch_model.eval() @@ -1188,7 +1182,6 @@ def test_tf2keras_optional_input(): @staticmethod @pytest.mark.skipif(not _HAS_TORCH, reason=MSG_TORCH_NOT_FOUND) def test_torch_optional_input(): - import torch num_tokens = 3 embedding_size = 5 @@ -1266,7 +1259,6 @@ def test_convert_tf2_keras(tmpdir): @staticmethod @pytest.mark.skipif(not _HAS_TORCH, reason=MSG_TORCH_NOT_FOUND) def test_convert_torch_traced_model(tmpdir): - import torch from torch import nn class Network(nn.Module): def __init__(self): @@ -1320,21 +1312,17 @@ def test_mil_op_names_consistency(tmpdir): # compare op names of the two programs np.testing.assert_array_equal(get_op_types_in_program(mil_prog1), get_op_types_in_program(mil_prog2)) - -@pytest.mark.skipif(ct.utils._macos_version() < (10, 16), reason='Model produces specification 6.') class TestMLProgramConverterExamples: @staticmethod @pytest.mark.skipif(not _HAS_TORCH, reason=MSG_TORCH_NOT_FOUND) @pytest.mark.parametrize( "convert_to", ['neuralnetwork', 'mlprogram']) def test_convert_to_argument(tmpdir, convert_to): - import torch - from torch import nn - class Network(nn.Module): + class Network(torch.nn.Module): def __init__(self): super(Network, self).__init__() - self.hidden = nn.Linear(30, 5) - self.relu = nn.ReLU() + self.hidden = torch.nn.Linear(30, 5) + self.relu = torch.nn.ReLU() def forward(self, x): x = self.hidden(x) @@ -1359,13 +1347,11 @@ def forward(self, x): @staticmethod @pytest.mark.skipif(not _HAS_TORCH, reason=MSG_TORCH_NOT_FOUND) def test_deployment_target_argument(tmpdir): - import torch - from torch import nn - class Network(nn.Module): + class Network(torch.nn.Module): def __init__(self): super(Network, self).__init__() - self.hidden = nn.Linear(30, 5) - self.relu = nn.ReLU() + self.hidden = torch.nn.Linear(30, 5) + self.relu = torch.nn.ReLU() def forward(self, x): x = self.hidden(x) @@ -1456,13 +1442,11 @@ def prog(x): @staticmethod @pytest.mark.skipif(not _HAS_TORCH, reason=MSG_TORCH_NOT_FOUND) def test_get_milprogram_method(tmpdir): - import torch - from torch import nn - class Network(nn.Module): + class Network(torch.nn.Module): def __init__(self): super(Network, self).__init__() - self.hidden = nn.Linear(100, 10) - self.relu = nn.ReLU() + self.hidden = torch.nn.Linear(100, 10) + self.relu 
= torch.nn.ReLU()

             def forward(self, x):
                 x = self.hidden(x)
@@ -1480,7 +1464,55 @@ def forward(self, x):
         )
         assert isinstance(model._get_mil_internal(), ct.converters.mil.Program)

-@pytest.mark.skipif(ct.utils._macos_version() < (10, 16), reason='Model produces specification 6.')
+    @staticmethod
+    @pytest.mark.skipif(not _HAS_TORCH or ct.utils._macos_version() < (12, 0),
+                        reason=MSG_TORCH_NOT_FOUND)
+    def test_classifier():
+        torch_model = torch.nn.ReLU().eval()
+        traced_model = torch.jit.trace(torch_model, torch.rand(3,))
+        model = ct.convert(
+            traced_model,
+            inputs=[ct.TensorType(shape=(3,))],
+            classifier_config = ct.ClassifierConfig(['a', 'b', 'c']),
+            convert_to='mlprogram'
+        )
+        spec = model.get_spec()
+        input_name = spec.description.input[0].name
+        out_dict = model.predict({input_name : np.array([1.0, 2.0, 3.0])})
+        assert 'classLabel' in out_dict
+        assert out_dict['classLabel'] == 'c'
+
+    @pytest.mark.skipif(not ct.utils._is_macos(), reason="Platform is not Mac OS")
+    @pytest.mark.parametrize("skip_model_load", [True, False])
+    def test_model_load_skip_flag(self, skip_model_load):
+        @mb.program(input_specs=[mb.TensorSpec(shape=(3,)), ])
+        def prog(x):
+            return mb.relu(x=x, name='relu')
+
+        if ct.utils._macos_version() < (12, 0) and not skip_model_load:
+            # Converting to an mlprogram on macOS < 12 should emit a RuntimeWarning
+            # when skip_model_load is False.
+            with pytest.warns(RuntimeWarning):
+                model = ct.convert(prog, convert_to='mlprogram',
+                                   skip_model_load=skip_model_load)
+        else:
+            model = ct.convert(prog, convert_to='mlprogram',
+                               skip_model_load=skip_model_load)
+
+        assert model is not None
+        if skip_model_load:
+            assert model.__proxy__ is None
+        model_dir = mkdtemp()
+        filename = os.path.join(model_dir, 'test.mlpackage')
+        model.save(filename)
+        assert os.path.exists(filename)
+        try:
+            rmtree(model_dir)
+        except:
+            pass
+
+
+@pytest.mark.skipif(ct.utils._macos_version() < (12, 0), reason='Model produces specification 6.')
 class TestMLProgramFP16Transform:
     @staticmethod
     def test_compute_precision_api():
diff --git a/coremltools/test/api/test_api_visibilities.py b/coremltools/test/api/test_api_visibilities.py
index 4b07abfd3..492c3a905 100644
--- a/coremltools/test/api/test_api_visibilities.py
+++ b/coremltools/test/api/test_api_visibilities.py
@@ -183,9 +183,6 @@ def test_converters_sklearn(self):
     def test_converters_xgboost(self):
         _check_visible_modules(_get_visible_items(ct.converters.xgboost), ["convert"])

-    def test_converters_mil(self):
-        pass  # TODO: [Create API visibility tests for MIL](rdar://64413959)
-
     def test_models_neural_network_quantization_utils(self):
         expected = [
             "AdvancedQuantizedLayerSelector",
diff --git a/coremltools/test/modelpackage/test_modelpackage.py b/coremltools/test/modelpackage/test_modelpackage.py
index 7ca81e5e5..072898b2f 100644
--- a/coremltools/test/modelpackage/test_modelpackage.py
+++ b/coremltools/test/modelpackage/test_modelpackage.py
@@ -128,7 +128,7 @@ def test_predict_api(self):
         model.save(package.name)
         loaded_model = MLModel(package.name)

-        if utils._macos_version() >= (10, 17):
+        if utils._macos_version() >= (12, 0):
             preds = loaded_model.predict({"feature_1": 1.0, "feature_2": 1.0})
             self.assertIsNotNone(preds)
             self.assertEqual(preds["output"], 3.1)
@@ -146,7 +146,7 @@ def test_rename_input(self):
         model.save(package.name)
         loaded_model = MLModel(package.name)

-        if utils._macos_version() >= (10, 17):
+        if utils._macos_version() >= (12, 0):
             preds = loaded_model.predict({"renamed_feature": 1.0, "feature_2": 1.0})
diff --git a/coremltools/test/api/test_api_visibilities.py b/coremltools/test/api/test_api_visibilities.py
index 4b07abfd3..492c3a905 100644
--- a/coremltools/test/api/test_api_visibilities.py
+++ b/coremltools/test/api/test_api_visibilities.py
@@ -183,9 +183,6 @@ def test_converters_sklearn(self): def test_converters_xgboost(self): _check_visible_modules(_get_visible_items(ct.converters.xgboost), ["convert"]) - def test_converters_mil(self): - pass # TODO: [Create API visibility tests for MIL](rdar://64413959) - def test_models_neural_network_quantization_utils(self): expected = [ "AdvancedQuantizedLayerSelector",
diff --git a/coremltools/test/modelpackage/test_modelpackage.py b/coremltools/test/modelpackage/test_modelpackage.py
index 7ca81e5e5..072898b2f 100644
--- a/coremltools/test/modelpackage/test_modelpackage.py
+++ b/coremltools/test/modelpackage/test_modelpackage.py
@@ -128,7 +128,7 @@ def test_predict_api(self): model.save(package.name) loaded_model = MLModel(package.name) - if utils._macos_version() >= (10, 17): + if utils._macos_version() >= (12, 0): preds = loaded_model.predict({"feature_1": 1.0, "feature_2": 1.0}) self.assertIsNotNone(preds) self.assertEqual(preds["output"], 3.1)
@@ -146,7 +146,7 @@ def test_rename_input(self): model.save(package.name) loaded_model = MLModel(package.name) - if utils._macos_version() >= (10, 17): + if utils._macos_version() >= (12, 0): preds = loaded_model.predict({"renamed_feature": 1.0, "feature_2": 1.0}) self.assertIsNotNone(preds) self.assertEqual(preds["output"], 3.1)
@@ -167,7 +167,7 @@ def test_rename_input_bad(self): model.save(package.name) loaded_model = MLModel(package.name) - if utils._macos_version() >= (10, 17): + if utils._macos_version() >= (12, 0): preds = loaded_model.predict({"feature_1": 1.0, "feature_2": 1.0}) self.assertIsNotNone(preds) self.assertEqual(preds["output"], 3.1)
@@ -187,7 +187,7 @@ def test_save(self): model.save(package.name) loaded_model = MLModel(package.name) - if utils._macos_version() >= (10, 17): + if utils._macos_version() >= (12, 0): preds = loaded_model.predict({"feature_1": 1.0, "feature_2": 1.0}) self.assertIsNotNone(preds) self.assertEqual(preds["output"], 3.1)
@@ -241,7 +241,7 @@ def forward(self, x): # Read back the saved bundle and compile mlmodel2 = MLModel(package_path, useCPUOnly=True) - if utils._macos_version() >= (10, 17): + if utils._macos_version() >= (12, 0): result = mlmodel2.predict( {"input": example_input.cpu().detach().numpy().astype(np.float32)} )
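A note on the version gates above: _macos_version() returns a tuple, and Python tuples compare lexicographically, so a single >= (12, 0) check covers both the major and minor components. The (10, 17) placeholder predates Apple's switch to major-version numbering, under which the release after macOS 11 reports as 12.0. A minimal sketch of the gating pattern (loaded_model is assumed to be an MLModel backed by an .mlpackage, as in the tests above):

from coremltools.models import utils

if utils._macos_version() >= (12, 0):
    # prediction on ML program / .mlpackage models requires macOS 12+
    preds = loaded_model.predict({"feature_1": 1.0, "feature_2": 1.0})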
diff --git a/coremltools/test/neural_network/test_keras2.py b/coremltools/test/neural_network/test_keras2.py
index d7bfbd72c..ecddbe631 100644
--- a/coremltools/test/neural_network/test_keras2.py
+++ b/coremltools/test/neural_network/test_keras2.py
@@ -1374,8 +1374,6 @@ def test_updatable_model_flag_no_loss_optimizer(self): self.assertTrue(layers[1].innerProduct) self.assertTrue(layers[1].isUpdatable) - # - # when loss was specified as a string the converter had failed to work. def test_updatable_model_flag_mse_string_adam(self): """ Tests the 'respect_trainable' flag when used along with string
@@ -1445,7 +1443,6 @@ def test_updatable_model_flag_mse_string_adam(self): self.assertEqual(adopt.beta2.defaultValue, 0.75) self.assertEqual(adopt.eps.defaultValue, 0.25) - # def test_updatable_model_flag_cce_string_sgd(self): """ Tests the 'respect_trainable' flag when used along with string
diff --git a/coremltools/test/neural_network/test_model.py b/coremltools/test/neural_network/test_model.py
index d1cd91cf3..e4ccb0da6 100644
--- a/coremltools/test/neural_network/test_model.py
+++ b/coremltools/test/neural_network/test_model.py
@@ -10,6 +10,7 @@ import tempfile import unittest +from coremltools._deps import _HAS_TORCH from coremltools.proto import Model_pb2 from coremltools.models.utils import ( rename_feature,
@@ -23,6 +24,9 @@ from coremltools.models.neural_network import NeuralNetworkBuilder from coremltools.models.neural_network.utils import make_image_input, make_nn_classifier +if _HAS_TORCH: + import torch as _torch + class MLModelTest(unittest.TestCase): @classmethod
@@ -503,6 +507,53 @@ def test_rename_image_input(self): out = mlmodel.predict({"new_input_name": pil_img}, useCPUOnly=True)['out'] np.testing.assert_equal(out, np.array([8.0, 10.0, 12.0]).reshape(3, 1, 1)) + @unittest.skipUnless( + _is_macos() and _macos_version() >= (12, 0) and _HAS_TORCH, "Only supported on macOS 12+" + ) + def test_rename_feature_mlprogram(self): + torch_model = _torch.nn.ReLU().eval() + model = coremltools.convert( + _torch.jit.trace(torch_model, _torch.rand(3, )), + inputs=[coremltools.TensorType(shape=(3,))], + convert_to='mlprogram' + ) + spec = model.get_spec() + input_name = spec.description.input[0].name + output_name = spec.description.output[0].name + + # rename input + rename_feature(spec, input_name, "new_input_name") + self.assertEqual(spec.description.input[0].name, "new_input_name") + model = coremltools.models.MLModel(spec) + out = model.predict({"new_input_name": np.array([1.0, 2.0, 3.0])})[output_name] + self.assertEqual(out[0], 1.0) + + # rename output + rename_feature(spec, output_name, "new_output_name") + self.assertEqual(spec.description.output[0].name, "new_output_name") + model = coremltools.models.MLModel(spec) + out = model.predict({"new_input_name": np.array([1.0, 2.0, 3.0])})["new_output_name"] + self.assertEqual(out[1], 2.0) + + @unittest.skipUnless( + _is_macos() and _macos_version() >= (12, 0) and _HAS_TORCH, "Only supported on macOS 12+" + ) + def test_rename_feature_classifier_mlprogram(self): + torch_model = _torch.nn.ReLU().eval() + model = coremltools.convert( + _torch.jit.trace(torch_model, _torch.rand(3, )), + inputs=[coremltools.TensorType(shape=(3,))], + classifier_config=coremltools.ClassifierConfig(['a', 'b', 'c']), + convert_to='mlprogram' + ) + spec = model.get_spec() + input_name = spec.description.input[0].name + + rename_feature(spec, 'classLabel', 'highestProbClass') + model = coremltools.models.MLModel(spec) + output_class = model.predict({input_name: np.array([1.0, 2.0, 3.0])})['highestProbClass'] + self.assertEqual(output_class, 'c') + if __name__ == "__main__": unittest.main()
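The rename_feature flow added in the tests above boils down to the following sketch: mutate the spec in place, then rebuild a runnable model from the modified spec. Here model is assumed to be an already-converted MLModel, and the new name is illustrative:

from coremltools.models import MLModel
from coremltools.models.utils import rename_feature

spec = model.get_spec()
old_input = spec.description.input[0].name
rename_feature(spec, old_input, "new_input_name")  # edits the spec in place
model = MLModel(spec)  # rebuild the model from the edited spec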
diff --git a/coremltools/test/neural_network/test_nn_builder.py b/coremltools/test/neural_network/test_nn_builder.py
index 0bd54db1e..c4219f5d9 100644
--- a/coremltools/test/neural_network/test_nn_builder.py
+++ b/coremltools/test/neural_network/test_nn_builder.py
@@ -197,22 +197,21 @@ def test_linear_quant_convolution_8bit_vector_scalebias(self): expected_out = np.reshape(np.array([8, 44]), (2, 1, 1)) self.assertTrue(np.allclose(out, expected_out)) - @unittest.skip(" Investigate numerical discrepancy during quantization in CoreML") def test_linear_quant_convolution_8bit_float_scale_and_bias(self): W = np.array(([[[[1, 248], [248, 248]]]]), dtype=np.uint8) mlmodel = self.build_quant_conv_layer( W=W.flatten().tobytes(), quantization_type="linear", nbits=8, - quant_scale=[15.346457], - quant_bias=[-3913.3464], + quant_scale=[15], + quant_bias=[-3913], output_channels=1, ) data = np.ones((1, 2, 2)) data_dict = {"data": data} out = mlmodel.predict(data_dict, useCPUOnly=True)["out"] # Output should be equal to: (scale*(1+248+248+248)+(4*bias)) - expected_out = np.reshape(np.array([-4220.275]), (1, 1, 1, 1, 1)) + expected_out = np.reshape(np.array([-4477]), (1, 1, 1, 1, 1)) self.assertTrue(np.allclose(out, expected_out))
@@ -362,10 +361,6 @@ def test_linear_quant_batchedmatmul_8bit(self): self.assertTrue(out.shape == expected_out.shape) self.assertTrue(np.allclose(out.flatten(), expected_out.flatten(), atol=0.1)) - @pytest.mark.xfail( - reason="rdar://78057487 (Re-enable tests after fixing regression in embedding layer)", - run=False - ) def test_lut_quant_embedding_nd_2bit(self): embed_size = 2 vocab_size = 3
@@ -408,10 +403,6 @@ def test_lut_quant_embedding_nd_2bit(self): self.assertTrue(np.allclose(out.flatten(), expected_out.flatten())) - @pytest.mark.xfail( - reason="rdar://78057487 (Re-enable tests after fixing regression in embedding layer)", - run=False - ) def test_linear_quant_embedding_7bit(self): embed_size = 2 vocab_size = 3
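The updated expectation in test_linear_quant_convolution_8bit_float_scale_and_bias follows directly from the dequantization formula in the comment: each stored weight w maps to scale*w + bias, and convolving an all-ones 2x2 input sums the four dequantized weights. A quick, illustrative check of the arithmetic (not part of the patch):

scale, bias = 15.0, -3913.0
weights = [1, 248, 248, 248]
expected = scale * sum(weights) + len(weights) * bias  # 15*745 + 4*(-3913)
assert expected == -4477.0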
Skipping tests.", ) class QuantizeWeightsAPI(unittest.TestCase): - @pytest.mark.xfail( - reason="rdar://78057487 (Re-enable tests after fixing regression in embedding layer)", - run=False - ) def test_embeddingND_quantize(self): input_features = [("data", datatypes.Array(10, 1))] output_features = [("output", None)] diff --git a/coremltools/test/xgboost_tests/test_boosted_trees_classifier_numeric.py b/coremltools/test/xgboost_tests/test_boosted_trees_classifier_numeric.py index 5f96fe4c5..32f1eedee 100644 --- a/coremltools/test/xgboost_tests/test_boosted_trees_classifier_numeric.py +++ b/coremltools/test/xgboost_tests/test_boosted_trees_classifier_numeric.py @@ -210,7 +210,7 @@ def _classifier_stress_test(self): self._train_convert_evaluate_assert(**arg) -@unittest.skipIf(_macos_version() >= (10, 16), "rdar://problem/75172473") +@unittest.skipIf(_macos_version() >= (12, 0), "rdar://problem/75172473") @unittest.skipIf(not _HAS_SKLEARN, "Missing sklearn. Skipping tests.") @unittest.skipIf(not _HAS_XGBOOST, "Skipping, no xgboost") class BoostedTreeBinaryClassificationBostonHousingXGboostNumericTest( @@ -241,7 +241,7 @@ def test_binary_classifier_stress_test(self): self._classifier_stress_test() -@unittest.skipIf(_macos_version() >= (10, 16), "rdar://problem/75172473") +@unittest.skipIf(_macos_version() >= (12, 0), "rdar://problem/75172473") @unittest.skipIf(not _HAS_SKLEARN, "Missing sklearn. Skipping tests.") @unittest.skipIf(not _HAS_XGBOOST, "Skipping, no xgboost") class BoostedTreeMultiClassClassificationBostonHousingXGboostNumericTest( diff --git a/coremltools/test/xgboost_tests/test_boosted_trees_regression_numeric.py b/coremltools/test/xgboost_tests/test_boosted_trees_regression_numeric.py index 7b4a91204..de837a9b3 100644 --- a/coremltools/test/xgboost_tests/test_boosted_trees_regression_numeric.py +++ b/coremltools/test/xgboost_tests/test_boosted_trees_regression_numeric.py @@ -88,7 +88,7 @@ def test_boston_housing_parameter_stress_test(self): self._train_convert_evaluate_assert(**arg) -@unittest.skipIf(_macos_version() >= (10, 16), "rdar://problem/75172473") +@unittest.skipIf(_macos_version() >= (12, 0), "rdar://problem/75172473") @unittest.skipIf(not _HAS_XGBOOST, "Missing xgboost. Skipping") @unittest.skipIf(not _HAS_SKLEARN, "Missing scikit-learn. Skipping tests.") class XgboostBoosterBostonHousingNumericTest(unittest.TestCase): @@ -199,7 +199,7 @@ def test_boston_housing_parameter_stress_test(self): self._train_convert_evaluate_assert(arg) -@unittest.skipIf(_macos_version() >= (10, 16), "rdar://problem/75172473") +@unittest.skipIf(_macos_version() >= (12, 0), "rdar://problem/75172473") @unittest.skipIf(not _HAS_XGBOOST, "Missing xgboost. Skipping") @unittest.skipIf(not _HAS_SKLEARN, "Missing sklearn. 
Skipping tests.") class XGboostRegressorBostonHousingNumericTest(unittest.TestCase): diff --git a/coremltools/version.py b/coremltools/version.py index 474b0e412..f8bf6cb1a 100644 --- a/coremltools/version.py +++ b/coremltools/version.py @@ -4,4 +4,4 @@ # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause -__version__ = "5.0b1" # VERSION_STRING +__version__ = "5.0b2" # VERSION_STRING diff --git a/reqs/test.pip b/reqs/test.pip index 40231f64c..4ce664ea3 100644 --- a/reqs/test.pip +++ b/reqs/test.pip @@ -21,9 +21,9 @@ six sympy > 1.6 tensorflow==1.14.0; python_version < '3.8' torch==1.5.0; python_version == '3.5' -torch==1.7.1; python_version > '3.5' +torch==1.9.0; python_version > '3.5' torchvision==0.6.1; python_version == '3.5' -torchvision==0.8.2; python_version > '3.5' +torchvision==0.10.0; python_version > '3.5' xgboost mock wrapt diff --git a/setup.py b/setup.py index c14089aa4..a05932d63 100755 --- a/setup.py +++ b/setup.py @@ -85,6 +85,7 @@ "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", "Topic :: Scientific/Engineering", "Topic :: Software Development", ],