Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ARM] Fix bug in interpolate node in notebook image-to-image #28957

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 49 additions & 29 deletions src/plugins/intel_cpu/src/nodes/interpolate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,22 @@ using namespace Xbyak;

namespace ov::intel_cpu::node {

static inline bool isARM() {
bool isArm = false;
#if defined(OPENVINO_ARCH_ARM) || defined(OPENVINO_ARCH_ARM64)
isArm = true;
#endif
return isArm;
}

static inline bool isACL() {
bool isACL = false;
#if defined(OV_CPU_WITH_ACL)
isACL = true;
#endif
return isACL;
}

static inline bool isFloatCompatible(ov::element::Type prc) {
return one_of(prc, ov::element::f32, ov::element::bf16, ov::element::f16, ov::element::f64);
}
Expand Down Expand Up @@ -2138,11 +2154,13 @@ void Interpolate::initSupportedPrimitiveDescriptors() {

ov::element::Type inputPrecision = getOriginalInputPrecisionAtPort(DATA_ID);

#if defined(OV_CPU_WITH_ACL)
bool isInputPrecisionSupported = one_of(inputPrecision, ov::element::i8, ov::element::u8, ov::element::f16);
#else
bool isInputPrecisionSupported = one_of(inputPrecision, ov::element::i8, ov::element::u8, ov::element::bf16);
#endif
bool isInputPrecisionSupported = false;
if (isACL()) {
isInputPrecisionSupported = one_of(inputPrecision, ov::element::i8, ov::element::u8, ov::element::f16);
} else {
isInputPrecisionSupported = one_of(inputPrecision, ov::element::i8, ov::element::u8, ov::element::bf16);
}

if (!isInputPrecisionSupported) {
inputPrecision = ov::element::f32;
}
Expand All @@ -2162,11 +2180,11 @@ void Interpolate::initSupportedPrimitiveDescriptors() {
outputPrecision = fusedWith[fusedWith.size() - 1]->getOriginalOutputPrecisionAtPort(DATA_ID);
}

#if !defined(OV_CPU_WITH_ACL)
if (!mayiuse(cpu::x64::sse41)) {

if (!isACL() && !mayiuse(cpu::x64::sse41)) {
inputPrecision = outputPrecision = ov::element::f32;
}
#endif


auto targetShapeType = ov::element::i32;
auto scalesType = ov::element::f32;
Expand Down Expand Up @@ -2248,16 +2266,17 @@ void Interpolate::initSupportedPrimitiveDescriptors() {
}
};
if (is_version11) {
#if defined(OV_CPU_WITH_ACL)
interpAttrs.hasPad = hasPad;
pushDesc(LayoutType::nspc, undef, true, true);
pushDesc(LayoutType::ncsp, undef, true, true);
canUseAclExecutor = !supportedPrimitiveDescriptors.empty();
if (canUseAclExecutor)
return;
// fallback to f32 if ref is used
inputPrecision = outputPrecision = ov::element::f32;
#endif
if (isACL()) {
interpAttrs.hasPad = hasPad;
pushDesc(LayoutType::nspc, undef, true, true);
pushDesc(LayoutType::ncsp, undef, true, true);
canUseAclExecutor = !supportedPrimitiveDescriptors.empty();
if (canUseAclExecutor) {
return;
}
// fallback to f32 if ref is used
inputPrecision = outputPrecision = ov::element::f32;
}

if (dataRank == 4) {
if (mayiuse(cpu::x64::avx512_core)) {
Expand Down Expand Up @@ -2285,16 +2304,17 @@ void Interpolate::initSupportedPrimitiveDescriptors() {
const auto& dataMinDims = getInputShapeAtPort(DATA_ID).getMinDims();
bool isBlkApplied = dataRank > 1 && dataMinDims[1] != Shape::UNDEFINED_DIM && dataMinDims[1] > 1;

#if defined(OV_CPU_WITH_ACL)
interpAttrs.hasPad = hasPad;
pushDesc(LayoutType::nspc, undef, false, true);
pushDesc(LayoutType::ncsp, undef, false, true);
canUseAclExecutor = !supportedPrimitiveDescriptors.empty();
if (canUseAclExecutor)
return;
// fallback to f32 if ref is used
inputPrecision = outputPrecision = ov::element::f32;
#endif
if (isACL()) {
interpAttrs.hasPad = hasPad;
pushDesc(LayoutType::nspc, undef, false, true);
pushDesc(LayoutType::ncsp, undef, false, true);
canUseAclExecutor = !supportedPrimitiveDescriptors.empty();
if (canUseAclExecutor) {
return;
}
// fallback to f32 if ref is used
inputPrecision = outputPrecision = ov::element::f32;
}

if (!mayiuse(cpu::x64::sse41) || interpAttrs.mode == InterpolateMode::linear) {
pushDesc(LayoutType::ncsp, ref, false);
Expand Down Expand Up @@ -2462,7 +2482,7 @@ void Interpolate::prepareParams() {

std::vector<float> dataScales =
getScales(getPaddedInputShape(srcDims, interpAttrs.padBegin, interpAttrs.padEnd), dstDims);
if (!NCHWAsNHWC && (getOutputShapeAtPort(0).getRank() > 2 && (dataScales[0] != 1.f || dataScales[1] != 1.f))) {
if (!NCHWAsNHWC && (getOutputShapeAtPort(0).getRank() > 2 && (dataScales[0] != 1.f || dataScales[1] != 1.f)) && !isARM()) {
THROW_CPU_NODE_ERR("only supports resize on spatial dimensions(depth, height and width)");
}

Expand Down
Loading