Merge pull request #418 from Xilinx/bump_to_49c5cebb
[AutoBump] Merge with fixes of 49c5ceb (Sep 16) (4)
jorickert authored Dec 13, 2024
2 parents 1700d25 + 9191944 commit 9554f29
Showing 9 changed files with 568 additions and 113 deletions.
233 changes: 142 additions & 91 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -29830,6 +29830,144 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
}
}

// Constant ISD::SRA/SRL/SHL can be performed efficiently on vXi8 vectors by
// using vXi16 vector operations.
if (ConstantAmt &&
(VT == MVT::v16i8 || (VT == MVT::v32i8 && Subtarget.hasInt256()) ||
(VT == MVT::v64i8 && Subtarget.hasBWI())) &&
!Subtarget.hasXOP()) {
int NumElts = VT.getVectorNumElements();
MVT VT16 = MVT::getVectorVT(MVT::i16, NumElts / 2);
// We can do this extra fast if each pair of i8 elements is shifted by the
// same amount by doing this SWAR style: use a shift to move the valid bits
// to the right position, mask out any bits which crossed from one element
// to the other.
APInt UndefElts;
SmallVector<APInt, 64> AmtBits;
// This optimized lowering is only valid if the elements in a pair can
// be treated identically.
bool SameShifts = true;
SmallVector<APInt, 32> AmtBits16(NumElts / 2);
APInt UndefElts16 = APInt::getZero(AmtBits16.size());
if (getTargetConstantBitsFromNode(Amt, /*EltSizeInBits=*/8, UndefElts,
AmtBits, /*AllowWholeUndefs=*/true,
/*AllowPartialUndefs=*/false)) {
// Collect information to construct the BUILD_VECTOR for the i16 version
// of the shift. Conceptually, this is equivalent to:
// 1. Making sure the shift amounts are the same for both the low i8 and
// high i8 corresponding to the i16 lane.
// 2. Extending that shift amount to i16 for a build vector operation.
//
// We want to handle undef shift amounts which requires a little more
// logic (e.g. if one is undef and the other is not, grab the other shift
// amount).
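// (Illustrative example, not from the patch: i8 amounts
// <1, 1, 4, undef, undef, undef, 3, 3> collapse to the i16 amounts
// <1, 4, undef, 3> under the rules below.)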
for (unsigned SrcI = 0, E = AmtBits.size(); SrcI != E; SrcI += 2) {
unsigned DstI = SrcI / 2;
// Both elements are undef? Make a note and keep going.
if (UndefElts[SrcI] && UndefElts[SrcI + 1]) {
AmtBits16[DstI] = APInt::getZero(16);
UndefElts16.setBit(DstI);
continue;
}
// Even element is undef? We will shift it by the same shift amount as
// the odd element.
if (UndefElts[SrcI]) {
AmtBits16[DstI] = AmtBits[SrcI + 1].zext(16);
continue;
}
// Odd element is undef? We will shift it by the same shift amount as
// the even element.
if (UndefElts[SrcI + 1]) {
AmtBits16[DstI] = AmtBits[SrcI].zext(16);
continue;
}
// Both elements are equal.
if (AmtBits[SrcI] == AmtBits[SrcI + 1]) {
AmtBits16[DstI] = AmtBits[SrcI].zext(16);
continue;
}
// One of the provisional i16 elements will not have the same shift
// amount. Let's bail.
SameShifts = false;
break;
}
}
// We are only dealing with identical pairs.
if (SameShifts) {
// Cast the operand to vXi16.
SDValue R16 = DAG.getBitcast(VT16, R);
// Create our new vector of shift amounts.
SDValue Amt16 = getConstVector(AmtBits16, UndefElts16, VT16, DAG, dl);
// Perform the actual shift.
unsigned LogicalOpc = Opc == ISD::SRA ? ISD::SRL : Opc;
SDValue ShiftedR = DAG.getNode(LogicalOpc, dl, VT16, R16, Amt16);
// Now we need to construct a mask which will "drop" bits that get
// shifted past the LSB/MSB. For a logical shift left, it will look
// like:
// MaskLowBits = (0xff << Amt16) & 0xff;
// MaskHighBits = MaskLowBits << 8;
// Mask = MaskLowBits | MaskHighBits;
//
// This masking ensures that bits cannot migrate from one i8 to
// another. The construction of this mask will be constant folded.
// The mask for a logical right shift is nearly identical, the only
// difference is that 0xff is shifted right instead of left.
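// (Worked example, illustration only: for a SHL by 3,
//   MaskLowBits  = (0x00ff << 3) & 0x00ff = 0x00f8
//   MaskHighBits =  0x00f8 << 8           = 0xf800
//   Mask         =  0x00f8 | 0xf800       = 0xf8f8
// which clears bits 8..10 of the i16 lane, i.e. the low-byte bits that
// crossed into the high byte.)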
SDValue Cst255 = DAG.getConstant(0xff, dl, MVT::i16);
SDValue Splat255 = DAG.getSplat(VT16, dl, Cst255);
// The mask for the low bits is most simply expressed as an 8-bit
// field of all ones which is shifted in the exact same way the data
// is shifted but masked with 0xff.
SDValue MaskLowBits = DAG.getNode(LogicalOpc, dl, VT16, Splat255, Amt16);
MaskLowBits = DAG.getNode(ISD::AND, dl, VT16, MaskLowBits, Splat255);
SDValue Cst8 = DAG.getConstant(8, dl, MVT::i16);
SDValue Splat8 = DAG.getSplat(VT16, dl, Cst8);
// The mask for the high bits is the same as the mask for the low bits but
// shifted up by 8.
SDValue MaskHighBits =
DAG.getNode(ISD::SHL, dl, VT16, MaskLowBits, Splat8);
SDValue Mask = DAG.getNode(ISD::OR, dl, VT16, MaskLowBits, MaskHighBits);
// Finally, we mask the shifted vector with the SWAR mask.
SDValue Masked = DAG.getNode(ISD::AND, dl, VT16, ShiftedR, Mask);
Masked = DAG.getBitcast(VT, Masked);
if (Opc != ISD::SRA) {
// Logical shifts are complete at this point.
return Masked;
}
// At this point, we have done a *logical* shift right. We now need to
// sign extend the result so that we get behavior equivalent to an
// arithmetic shift right. Post-shifting by Amt16, our i8 elements are
// `8-Amt16` bits wide.
//
// To convert our `8-Amt16` bit unsigned numbers to 8-bit signed numbers,
// we need to replicate the bit at position `7-Amt16` into the MSBs of
// each i8.
// We can use the following trick to accomplish this:
// SignBitMask = 1 << (7-Amt16)
// (Masked ^ SignBitMask) - SignBitMask
//
// When the sign bit is already clear, this will compute:
// Masked + SignBitMask - SignBitMask
//
// This is equal to Masked which is what we want: the sign bit was clear
// so sign extending should be a no-op.
//
// When the sign bit is set, this will compute:
// Masked - SignBitmask - SignBitMask
//
// This is equal to Masked - 2*SignBitMask which will correctly sign
// extend our result.
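// (Worked example, illustration only, per element: ashr i8 0xf1 (-15) by 3.
// The logical shift and mask leave 0x1e; SignBitMask = 0x80 >> 3 = 0x10;
// (0x1e ^ 0x10) - 0x10 = 0x0e - 0x10 = 0xfe = -2, which matches -15 >> 3.)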
SDValue CstHighBit = DAG.getConstant(0x80, dl, MVT::i8);
SDValue SplatHighBit = DAG.getSplat(VT, dl, CstHighBit);
// This does not induce recursion, all operands are constants.
SDValue SignBitMask = DAG.getNode(LogicalOpc, dl, VT, SplatHighBit, Amt);
SDValue FlippedSignBit =
DAG.getNode(ISD::XOR, dl, VT, Masked, SignBitMask);
SDValue Subtraction =
DAG.getNode(ISD::SUB, dl, VT, FlippedSignBit, SignBitMask);
return Subtraction;
}
}

// If possible, lower this packed shift into a vector multiply instead of
// expanding it into a sequence of scalar shifts.
// For v32i8 cases, it might be quicker to split/extend to vXi16 shifts.
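As a reading aid for the pairs lowering added above, the following is a minimal standalone C++ sketch (illustration only, not part of this commit) of the per-lane arithmetic: one i16 lane packs two i8 elements that share a constant shift amount, and the arithmetic shift is modelled as a logical shift of the lane, a SWAR mask, and a per-byte sign fix-up.

#include <cstdint>
#include <cstdio>

// Scalar model of the vXi8-pairs-as-vXi16 SWAR arithmetic shift right.
// "lo" and "hi" are the two i8 elements of one i16 lane; "amt" is their
// shared constant shift amount (0..7).
static void sra_pair_swar(uint8_t lo, uint8_t hi, unsigned amt,
                          uint8_t &outLo, uint8_t &outHi) {
  // Step 1: one logical shift of the whole 16-bit lane (the vXi16 SRL).
  uint16_t lane = (uint16_t)(lo | (hi << 8));
  uint16_t shifted = (uint16_t)(lane >> amt);

  // Step 2: SWAR mask that drops bits which crossed between the two bytes:
  // MaskLowBits = (0xff >> amt) & 0xff, MaskHighBits = MaskLowBits << 8.
  uint16_t maskLow = (uint16_t)((0xFFu >> amt) & 0xFFu);
  uint16_t mask = (uint16_t)(maskLow | (maskLow << 8));
  uint16_t masked = shifted & mask;

  // Step 3: per-byte sign fix for SRA: (x ^ SignBit) - SignBit, with
  // SignBit = 0x80 >> amt. In the real lowering this runs on the vXi8 type,
  // so the subtraction cannot borrow between elements.
  uint8_t signBit = (uint8_t)(0x80u >> amt);
  outLo = (uint8_t)(((masked & 0xFFu) ^ signBit) - signBit);
  outHi = (uint8_t)((((masked >> 8) & 0xFFu) ^ signBit) - signBit);
}

int main() {
  uint8_t lo, hi;
  sra_pair_swar(0xF0 /* -16 */, 0x10 /* 16 */, 2, lo, hi);
  // Prints "fc 04": -16 >> 2 = -4 and 16 >> 2 = 4.
  printf("%02x %02x\n", (unsigned)lo, (unsigned)hi);
  return 0;
}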
@@ -29950,105 +30088,18 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
DAG.getNode(Opc, dl, ExtVT, R, Amt));
}

// Constant ISD::SRA/SRL can be performed efficiently on vXi8 vectors by using
// vXi16 vector operations.
// Constant ISD::SRA/SRL can be performed efficiently on vXi8 vectors as we
// extend to vXi16 to perform a MUL scale effectively as a MUL_LOHI.
if (ConstantAmt && (Opc == ISD::SRA || Opc == ISD::SRL) &&
(VT == MVT::v16i8 || (VT == MVT::v32i8 && Subtarget.hasInt256()) ||
(VT == MVT::v64i8 && Subtarget.hasBWI())) &&
!Subtarget.hasXOP()) {
int NumElts = VT.getVectorNumElements();
MVT VT16 = MVT::getVectorVT(MVT::i16, NumElts / 2);
// We can do this extra fast if each pair of i8 elements is shifted by the
// same amount by doing this SWAR style: use a shift to move the valid bits
// to the right position, mask out any bits which crossed from one element
// to the other.
if (Opc == ISD::SRL || Opc == ISD::SHL) {
APInt UndefElts;
SmallVector<APInt, 64> AmtBits;
if (getTargetConstantBitsFromNode(Amt, /*EltSizeInBits=*/8, UndefElts,
AmtBits, /*AllowWholeUndefs=*/true,
/*AllowPartialUndefs=*/false)) {
// This optimized lowering is only valid if the elements in a pair can
// be treated identically.
bool SameShifts = true;
SmallVector<APInt, 32> AmtBits16(NumElts / 2);
APInt UndefElts16 = APInt::getZero(AmtBits16.size());
for (unsigned SrcI = 0, E = AmtBits.size(); SrcI != E; SrcI += 2) {
unsigned DstI = SrcI / 2;
// Both elements are undef? Make a note and keep going.
if (UndefElts[SrcI] && UndefElts[SrcI + 1]) {
AmtBits16[DstI] = APInt::getZero(16);
UndefElts16.setBit(DstI);
continue;
}
// Even element is undef? We will shift it by the same shift amount as
// the odd element.
if (UndefElts[SrcI]) {
AmtBits16[DstI] = AmtBits[SrcI + 1].zext(16);
continue;
}
// Odd element is undef? We will shift it by the same shift amount as
// the even element.
if (UndefElts[SrcI + 1]) {
AmtBits16[DstI] = AmtBits[SrcI].zext(16);
continue;
}
// Both elements are equal.
if (AmtBits[SrcI] == AmtBits[SrcI + 1]) {
AmtBits16[DstI] = AmtBits[SrcI].zext(16);
continue;
}
// One of the provisional i16 elements will not have the same shift
// amount. Let's bail.
SameShifts = false;
break;
}

// We are only dealing with identical pairs and the operation is a
// logical shift.
if (SameShifts) {
// Cast the operand to vXi16.
SDValue R16 = DAG.getBitcast(VT16, R);
// Create our new vector of shift amounts.
SDValue Amt16 = getConstVector(AmtBits16, UndefElts16, VT16, DAG, dl);
// Perform the actual shift.
SDValue ShiftedR = DAG.getNode(Opc, dl, VT16, R16, Amt16);
// Now we need to construct a mask which will "drop" bits that get
// shifted past the LSB/MSB. For a logical shift left, it will look
// like:
// MaskLowBits = (0xff << Amt16) & 0xff;
// MaskHighBits = MaskLowBits << 8;
// Mask = MaskLowBits | MaskHighBits;
//
// This masking ensures that bits cannot migrate from one i8 to
// another. The construction of this mask will be constant folded.
// The mask for a logical right shift is nearly identical, the only
// difference is that 0xff is shifted right instead of left.
SDValue Cst255 = DAG.getConstant(0xff, dl, MVT::i16);
SDValue Splat255 = DAG.getSplat(VT16, dl, Cst255);
// The mask for the low bits is most simply expressed as an 8-bit
// field of all ones which is shifted in the exact same way the data
// is shifted but masked with 0xff.
SDValue MaskLowBits = DAG.getNode(Opc, dl, VT16, Splat255, Amt16);
MaskLowBits = DAG.getNode(ISD::AND, dl, VT16, MaskLowBits, Splat255);
SDValue Cst8 = DAG.getConstant(8, dl, MVT::i16);
SDValue Splat8 = DAG.getSplat(VT16, dl, Cst8);
// The mask for the high bits is the same as the mask for the low
// bits but shifted up by 8.
SDValue MaskHighBits =
DAG.getNode(ISD::SHL, dl, VT16, MaskLowBits, Splat8);
SDValue Mask =
DAG.getNode(ISD::OR, dl, VT16, MaskLowBits, MaskHighBits);
// Finally, we mask the shifted vector with the SWAR mask.
SDValue Masked = DAG.getNode(ISD::AND, dl, VT16, ShiftedR, Mask);
return DAG.getBitcast(VT, Masked);
}
}
}
SDValue Cst8 = DAG.getTargetConstant(8, dl, MVT::i8);

// Extend to vXi16 to perform a MUL scale effectively as a MUL_LOHI (it
// doesn't matter if the type isn't legal).
// Extend constant shift amount to vXi16 (it doesn't matter if the type
// isn't legal).
MVT ExVT = MVT::getVectorVT(MVT::i16, NumElts);
Amt = DAG.getZExtOrTrunc(Amt, dl, ExVT);
Amt = DAG.getNode(ISD::SUB, dl, ExVT, DAG.getConstant(8, dl, ExVT), Amt);
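The reworded comment above frames the remaining constant SRL/SRA path as a multiply-high ("perform a MUL scale effectively as a MUL_LOHI"): a constant logical right shift of an i8 by k is the high byte of its 16-bit product with 2^(8-k), hence the 8 - Amt computed above. A minimal sketch of that identity, as an illustration only and not the patch's code:

#include <cassert>
#include <cstdint>

// For 0 <= k <= 7:  x >> k  ==  high byte of (zext(x) to i16) * (1 << (8 - k)).
static uint8_t srl_via_mulhi(uint8_t x, unsigned k) {
  uint16_t product = (uint16_t)(x * (1u << (8 - k)));
  return (uint8_t)(product >> 8);
}

int main() {
  // Exhaustively check the identity for every i8 value and shift amount.
  for (unsigned k = 0; k <= 7; ++k)
    for (unsigned x = 0; x <= 0xFF; ++x)
      assert(srl_via_mulhi((uint8_t)x, k) == (uint8_t)(x >> k));
  return 0;
}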
93 changes: 93 additions & 0 deletions llvm/test/CodeGen/X86/vector-shift-ashr-128.ll
@@ -1586,6 +1586,99 @@ define <16 x i8> @constant_shift_v16i8(<16 x i8> %a) nounwind {
ret <16 x i8> %shift
}

define <16 x i8> @constant_shift_v16i8_pairs(<16 x i8> %a) nounwind {
; SSE2-LABEL: constant_shift_v16i8_pairs:
; SSE2: # %bb.0:
; SSE2-NEXT: movdqa {{.*#+}} xmm1 = [65535,65535,65535,65535,65535,0,65535,65535]
; SSE2-NEXT: pandn %xmm0, %xmm1
; SSE2-NEXT: pmulhuw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
; SSE2-NEXT: por %xmm1, %xmm0
; SSE2-NEXT: pand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
; SSE2-NEXT: movdqa {{.*#+}} xmm1 = [64,64,8,8,1,1,16,16,32,32,128,128,4,4,2,2]
; SSE2-NEXT: pxor %xmm1, %xmm0
; SSE2-NEXT: psubb %xmm1, %xmm0
; SSE2-NEXT: retq
;
; SSE41-LABEL: constant_shift_v16i8_pairs:
; SSE41: # %bb.0:
; SSE41-NEXT: movdqa {{.*#+}} xmm1 = [32768,4096,512,8192,16384,u,2048,1024]
; SSE41-NEXT: pmulhuw %xmm0, %xmm1
; SSE41-NEXT: pblendw {{.*#+}} xmm0 = xmm1[0,1,2,3,4],xmm0[5],xmm1[6,7]
; SSE41-NEXT: pand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
; SSE41-NEXT: movdqa {{.*#+}} xmm1 = [64,64,8,8,1,1,16,16,32,32,128,128,4,4,2,2]
; SSE41-NEXT: pxor %xmm1, %xmm0
; SSE41-NEXT: psubb %xmm1, %xmm0
; SSE41-NEXT: retq
;
; AVX-LABEL: constant_shift_v16i8_pairs:
; AVX: # %bb.0:
; AVX-NEXT: vpmulhuw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm1 # [32768,4096,512,8192,16384,u,2048,1024]
; AVX-NEXT: vpblendw {{.*#+}} xmm0 = xmm1[0,1,2,3,4],xmm0[5],xmm1[6,7]
; AVX-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
; AVX-NEXT: vmovdqa {{.*#+}} xmm1 = [64,64,8,8,1,1,16,16,32,32,128,128,4,4,2,2]
; AVX-NEXT: vpxor %xmm1, %xmm0, %xmm0
; AVX-NEXT: vpsubb %xmm1, %xmm0, %xmm0
; AVX-NEXT: retq
;
; XOP-LABEL: constant_shift_v16i8_pairs:
; XOP: # %bb.0:
; XOP-NEXT: vpshab {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
; XOP-NEXT: retq
;
; AVX512DQ-LABEL: constant_shift_v16i8_pairs:
; AVX512DQ: # %bb.0:
; AVX512DQ-NEXT: vpmulhuw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm1 # [32768,4096,512,8192,16384,u,2048,1024]
; AVX512DQ-NEXT: vpblendw {{.*#+}} xmm0 = xmm1[0,1,2,3,4],xmm0[5],xmm1[6,7]
; AVX512DQ-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
; AVX512DQ-NEXT: vmovdqa {{.*#+}} xmm1 = [64,64,8,8,1,1,16,16,32,32,128,128,4,4,2,2]
; AVX512DQ-NEXT: vpxor %xmm1, %xmm0, %xmm0
; AVX512DQ-NEXT: vpsubb %xmm1, %xmm0, %xmm0
; AVX512DQ-NEXT: retq
;
; AVX512BW-LABEL: constant_shift_v16i8_pairs:
; AVX512BW: # %bb.0:
; AVX512BW-NEXT: # kill: def $xmm0 killed $xmm0 def $zmm0
; AVX512BW-NEXT: vpmovsxbw {{.*#+}} xmm1 = [1,4,7,3,2,0,5,6]
; AVX512BW-NEXT: vpsrlvw %zmm1, %zmm0, %zmm0
; AVX512BW-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
; AVX512BW-NEXT: vmovdqa {{.*#+}} xmm1 = [64,64,8,8,1,1,16,16,32,32,128,128,4,4,2,2]
; AVX512BW-NEXT: vpxor %xmm1, %xmm0, %xmm0
; AVX512BW-NEXT: vpsubb %xmm1, %xmm0, %xmm0
; AVX512BW-NEXT: vzeroupper
; AVX512BW-NEXT: retq
;
; AVX512DQVL-LABEL: constant_shift_v16i8_pairs:
; AVX512DQVL: # %bb.0:
; AVX512DQVL-NEXT: vpmulhuw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm1 # [32768,4096,512,8192,16384,u,2048,1024]
; AVX512DQVL-NEXT: vpblendw {{.*#+}} xmm0 = xmm1[0,1,2,3,4],xmm0[5],xmm1[6,7]
; AVX512DQVL-NEXT: vmovdqa {{.*#+}} xmm1 = [64,64,8,8,1,1,16,16,32,32,128,128,4,4,2,2]
; AVX512DQVL-NEXT: vpternlogq $108, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
; AVX512DQVL-NEXT: vpsubb %xmm1, %xmm0, %xmm0
; AVX512DQVL-NEXT: retq
;
; AVX512BWVL-LABEL: constant_shift_v16i8_pairs:
; AVX512BWVL: # %bb.0:
; AVX512BWVL-NEXT: vpsrlvw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
; AVX512BWVL-NEXT: vmovdqa {{.*#+}} xmm1 = [64,64,8,8,1,1,16,16,32,32,128,128,4,4,2,2]
; AVX512BWVL-NEXT: vpternlogq $108, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
; AVX512BWVL-NEXT: vpsubb %xmm1, %xmm0, %xmm0
; AVX512BWVL-NEXT: retq
;
; X86-SSE-LABEL: constant_shift_v16i8_pairs:
; X86-SSE: # %bb.0:
; X86-SSE-NEXT: movdqa {{.*#+}} xmm1 = [65535,65535,65535,65535,65535,0,65535,65535]
; X86-SSE-NEXT: pandn %xmm0, %xmm1
; X86-SSE-NEXT: pmulhuw {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0
; X86-SSE-NEXT: por %xmm1, %xmm0
; X86-SSE-NEXT: pand {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0
; X86-SSE-NEXT: movdqa {{.*#+}} xmm1 = [64,64,8,8,1,1,16,16,32,32,128,128,4,4,2,2]
; X86-SSE-NEXT: pxor %xmm1, %xmm0
; X86-SSE-NEXT: psubb %xmm1, %xmm0
; X86-SSE-NEXT: retl
%shift = ashr <16 x i8> %a, <i8 1, i8 1, i8 4, i8 4, i8 7, i8 7, i8 3, i8 3, i8 2, i8 2, i8 0, i8 0, i8 5, i8 5, i8 6, i8 6>
ret <16 x i8> %shift
}
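Reading the SSE41/AVX output against the IR above: the per-pair shift amounts <1,4,7,3,2,0,5,6> become pmulhuw multipliers 2^(16-k), i.e. [32768,4096,512,8192,16384,u,2048,1024]; the k = 0 lane cannot be expressed as a multiply-high (2^16 does not fit in an i16), so the pblendw re-inserts the original word element 5. The [64,64,8,8,1,1,16,16,32,32,128,128,4,4,2,2] constant fed to pxor/psubb is the per-byte sign-bit mask 0x80 >> k from the sign-extension fix-up in the new lowering.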

;
; Uniform Constant Shifts
;
(The remaining 7 changed files are not shown.)
