Skip to content

Commit

Permalink
fix arm part for interpolate.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
allnes committed Feb 12, 2025
1 parent 98e04be commit 6efe824
Showing 1 changed file with 47 additions and 29 deletions.
76 changes: 47 additions & 29 deletions src/plugins/intel_cpu/src/nodes/interpolate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,22 @@ using namespace Xbyak;

namespace ov::intel_cpu::node {

/// Tells whether this translation unit was built for an ARM target.
///
/// @return true when OPENVINO_ARCH_ARM or OPENVINO_ARCH_ARM64 is defined at
///         preprocessing time, false otherwise.
static inline bool isARM() {
#if defined(OPENVINO_ARCH_ARM) || defined(OPENVINO_ARCH_ARM64)
    return true;
#else
    return false;
#endif
}

/// Tells whether this translation unit was built with OV_CPU_WITH_ACL defined.
///
/// @return true when the OV_CPU_WITH_ACL macro is defined at preprocessing
///         time, false otherwise.
static inline bool isACL() {
#if defined(OV_CPU_WITH_ACL)
    return true;
#else
    return false;
#endif
}

/// Checks whether a precision is one of the floating-point types listed here.
///
/// @param prc element type to classify.
/// @return the result of matching @p prc against f32, bf16, f16 and f64 via one_of.
static inline bool isFloatCompatible(ov::element::Type prc) {
    const bool matchesFloatType =
        one_of(prc, ov::element::f32, ov::element::bf16, ov::element::f16, ov::element::f64);
    return matchesFloatType;
}
Expand Down Expand Up @@ -2138,11 +2154,13 @@ void Interpolate::initSupportedPrimitiveDescriptors() {

ov::element::Type inputPrecision = getOriginalInputPrecisionAtPort(DATA_ID);

#if defined(OV_CPU_WITH_ACL)
bool isInputPrecisionSupported = one_of(inputPrecision, ov::element::i8, ov::element::u8, ov::element::f16);
#else
bool isInputPrecisionSupported = one_of(inputPrecision, ov::element::i8, ov::element::u8, ov::element::bf16);
#endif
bool isInputPrecisionSupported = false;
if (isACL()) {
isInputPrecisionSupported = one_of(inputPrecision, ov::element::i8, ov::element::u8, ov::element::f16);
} else {
isInputPrecisionSupported = one_of(inputPrecision, ov::element::i8, ov::element::u8, ov::element::bf16);
}

if (!isInputPrecisionSupported) {
inputPrecision = ov::element::f32;
}
Expand All @@ -2162,11 +2180,11 @@ void Interpolate::initSupportedPrimitiveDescriptors() {
outputPrecision = fusedWith[fusedWith.size() - 1]->getOriginalOutputPrecisionAtPort(DATA_ID);
}

#if !defined(OV_CPU_WITH_ACL)
if (!mayiuse(cpu::x64::sse41)) {

if (!isACL() && !mayiuse(cpu::x64::sse41)) {
inputPrecision = outputPrecision = ov::element::f32;
}
#endif


auto targetShapeType = ov::element::i32;
auto scalesType = ov::element::f32;
Expand Down Expand Up @@ -2248,16 +2266,16 @@ void Interpolate::initSupportedPrimitiveDescriptors() {
}
};
if (is_version11) {
#if defined(OV_CPU_WITH_ACL)
interpAttrs.hasPad = hasPad;
pushDesc(LayoutType::nspc, undef, true, true);
pushDesc(LayoutType::ncsp, undef, true, true);
canUseAclExecutor = !supportedPrimitiveDescriptors.empty();
if (canUseAclExecutor)
return;
// fallback to f32 if ref is used
inputPrecision = outputPrecision = ov::element::f32;
#endif
if (isACL()) {
interpAttrs.hasPad = hasPad;
pushDesc(LayoutType::nspc, undef, true, true);
pushDesc(LayoutType::ncsp, undef, true, true);
canUseAclExecutor = !supportedPrimitiveDescriptors.empty();
if (canUseAclExecutor)
return;
// fallback to f32 if ref is used
inputPrecision = outputPrecision = ov::element::f32;
}

if (dataRank == 4) {
if (mayiuse(cpu::x64::avx512_core)) {
Expand Down Expand Up @@ -2285,16 +2303,16 @@ void Interpolate::initSupportedPrimitiveDescriptors() {
const auto& dataMinDims = getInputShapeAtPort(DATA_ID).getMinDims();
bool isBlkApplied = dataRank > 1 && dataMinDims[1] != Shape::UNDEFINED_DIM && dataMinDims[1] > 1;

#if defined(OV_CPU_WITH_ACL)
interpAttrs.hasPad = hasPad;
pushDesc(LayoutType::nspc, undef, false, true);
pushDesc(LayoutType::ncsp, undef, false, true);
canUseAclExecutor = !supportedPrimitiveDescriptors.empty();
if (canUseAclExecutor)
return;
// fallback to f32 if ref is used
inputPrecision = outputPrecision = ov::element::f32;
#endif
if (isACL()) {
interpAttrs.hasPad = hasPad;
pushDesc(LayoutType::nspc, undef, false, true);
pushDesc(LayoutType::ncsp, undef, false, true);
canUseAclExecutor = !supportedPrimitiveDescriptors.empty();
if (canUseAclExecutor)
return;
// fallback to f32 if ref is used
inputPrecision = outputPrecision = ov::element::f32;
}

if (!mayiuse(cpu::x64::sse41) || interpAttrs.mode == InterpolateMode::linear) {
pushDesc(LayoutType::ncsp, ref, false);
Expand Down Expand Up @@ -2462,7 +2480,7 @@ void Interpolate::prepareParams() {

std::vector<float> dataScales =
getScales(getPaddedInputShape(srcDims, interpAttrs.padBegin, interpAttrs.padEnd), dstDims);
if (!NCHWAsNHWC && (getOutputShapeAtPort(0).getRank() > 2 && (dataScales[0] != 1.f || dataScales[1] != 1.f))) {
if (!NCHWAsNHWC && (getOutputShapeAtPort(0).getRank() > 2 && (dataScales[0] != 1.f || dataScales[1] != 1.f)) && !isARM()) {
THROW_CPU_NODE_ERR("only supports resize on spatial dimensions(depth, height and width)");
}

Expand Down

0 comments on commit 6efe824

Please sign in to comment.