Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AMDGPU] selecting v_sat_pk instruction, version 2 #123297

Open
wants to merge 18 commits into
base: main
Choose a base branch
from

Conversation

Shoreshen
Copy link
Contributor

This PR uses TRUNCATE_SSAT_U node to select v_sat_pk instruction.

Compare to previous #121124 , this PR put most of pattern match task to combiner, instead of instruction selection.

@llvmbot
Copy link
Member

llvmbot commented Jan 17, 2025

@llvm/pr-subscribers-backend-amdgpu

Author: None (Shoreshen)

Changes

This PR uses TRUNCATE_SSAT_U node to select v_sat_pk instruction.

Compare to previous #121124 , this PR put most of pattern match task to combiner, instead of instruction selection.


Full diff: https://github.com/llvm/llvm-project/pull/123297.diff

6 Files Affected:

  • (modified) llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp (+1)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h (+1)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td (+2)
  • (modified) llvm/lib/Target/AMDGPU/SIISelLowering.cpp (+31)
  • (modified) llvm/lib/Target/AMDGPU/SIInstructions.td (+15)
  • (modified) llvm/test/CodeGen/AMDGPU/v_sat_pk_u8_i16.ll (+92-69)
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index cca9fa72d0ca53..da9fe7e15e6620 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -5498,6 +5498,7 @@ const char* AMDGPUTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(UMIN3)
   NODE_NAME_CASE(FMED3)
   NODE_NAME_CASE(SMED3)
+  NODE_NAME_CASE(SAT_PK_CAST)
   NODE_NAME_CASE(UMED3)
   NODE_NAME_CASE(FMAXIMUM3)
   NODE_NAME_CASE(FMINIMUM3)
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
index c74dc7942f52c0..6df4066c0fe6bc 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
@@ -461,6 +461,7 @@ enum NodeType : unsigned {
   FMED3,
   SMED3,
   UMED3,
+  SAT_PK_CAST,
   FMAXIMUM3,
   FMINIMUM3,
   FDOT2,
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td b/llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td
index bec294a945d2fe..2c4c9025134015 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td
@@ -332,6 +332,8 @@ def AMDGPUumed3 : SDNode<"AMDGPUISD::UMED3", AMDGPUDTIntTernaryOp,
   []
 >;
 
+def AMDGPUsat_pk_cast : SDNode<"AMDGPUISD::SAT_PK_CAST", SDTUnaryOp, []>;
+
 def AMDGPUfmed3_impl : SDNode<"AMDGPUISD::FMED3", SDTFPTernaryOp, []>;
 
 def AMDGPUfdot2_impl : SDNode<"AMDGPUISD::FDOT2",
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index e068b5f0b8769b..58361b1e633039 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -865,6 +865,15 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
       setOperationAction({ISD::FMAXIMUM, ISD::FMINIMUM}, MVT::v2f16, Legal);
   }
 
+  // special dealing for v_sat_pk instruction
+  if (AMDGPU::isGFX9(STI) || AMDGPU::isGFX11(STI) || AMDGPU::isGFX12(STI)) {
+    // In foldToSaturated during DAG combine
+    // 1. isOperationLegalOrCustom(Opc, SrcVT) getOperationAction(Op, SrcVT) == Custom
+    // 2. isTypeDesirableForOp checks regclass for v2i8 (hooked now checking DstVT == v2i8)
+    // In CustomLowerNode during legalizing, checks getOperationAction(Op, DstVT) == Custom
+    setOperationAction(ISD::TRUNCATE_SSAT_U, {MVT::v2i16, MVT::v2i8}, Custom);
+  }
+  
   setOperationAction(ISD::INTRINSIC_WO_CHAIN,
                      {MVT::Other, MVT::f32, MVT::v4f32, MVT::i16, MVT::f16,
                       MVT::bf16, MVT::v2i16, MVT::v2f16, MVT::v2bf16, MVT::i128,
@@ -1974,6 +1983,12 @@ bool SITargetLowering::isTypeDesirableForOp(unsigned Op, EVT VT) const {
   // create setcc with i1 operands.  We don't have instructions for i1 setcc.
   if (VT == MVT::i1 && Op == ISD::SETCC)
     return false;
+  
+  // Avoiding legality check for reg type of v2i8 
+  // (do not need to addRegisterClass for v2i8)
+  // VT is result type, ensure the result type is v2i8
+  if (VT == MVT::v2i8 && Op == ISD::TRUNCATE_SSAT_U)
+    return true;
 
   return TargetLowering::isTypeDesirableForOp(Op, VT);
 }
@@ -6605,6 +6620,12 @@ void SITargetLowering::ReplaceNodeResults(SDNode *N,
     Results.push_back(lowerFSQRTF16(SDValue(N, 0), DAG));
     break;
   }
+  case ISD::TRUNCATE_SSAT_U: {
+    SDLoc SL(N);
+    SDValue Op = DAG.getNode(AMDGPUISD::SAT_PK_CAST, SL, MVT::i16, N->getOperand(0));
+    Results.push_back(Op);
+    break;
+  }
   default:
     AMDGPUTargetLowering::ReplaceNodeResults(N, Results, DAG);
     break;
@@ -15184,6 +15205,16 @@ SDValue SITargetLowering::PerformDAGCombine(SDNode *N,
       return Widened;
     [[fallthrough]];
   }
+  case ISD::BITCAST: {
+    // This is possible beause for (i16 bitcase (v2i8 trunc ...))
+    // It may be replaced bu (i16 bitcase (v2i8 truncssat_u ...))
+    // And then (i16 bitcase (i16 AMDGPUsat_pk_cast ...))
+    // There is no instruction of casting to the same type
+    SDValue Src = N->getOperand(0);
+    if (N->getValueType(0) == Src.getValueType()) {
+      return Src;
+    }
+  }
   default: {
     if (!DCI.isBeforeLegalize()) {
       if (MemSDNode *MemNode = dyn_cast<MemSDNode>(N))
diff --git a/llvm/lib/Target/AMDGPU/SIInstructions.td b/llvm/lib/Target/AMDGPU/SIInstructions.td
index 1abbf4c217a697..7f098e37b893bf 100644
--- a/llvm/lib/Target/AMDGPU/SIInstructions.td
+++ b/llvm/lib/Target/AMDGPU/SIInstructions.td
@@ -3309,6 +3309,21 @@ def : GCNPat <
   (v2i16 (V_LSHL_OR_B32_e64 $src1, (i32 16), (i32 (V_AND_B32_e64 (i32 (V_MOV_B32_e32 (i32 0xffff))), $src0))))
 >;
 
+multiclass V_SAT_PK_Pat<Instruction inst> {
+  def : GCNPat<
+    (i16 (AMDGPUsat_pk_cast v2i16:$src)),
+    (inst VRegSrc_32:$src)
+  >;
+}
+
+let OtherPredicates = [NotHasTrue16BitInsts] in {
+  defm : V_SAT_PK_Pat<V_SAT_PK_U8_I16_e64>;
+} // End OtherPredicates = [NotHasTrue16BitInsts]
+
+let True16Predicate = UseFakeTrue16Insts in {
+  defm : V_SAT_PK_Pat<V_SAT_PK_U8_I16_fake16_e64>;
+} // End True16Predicate = UseFakeTrue16Insts
+
 // With multiple uses of the shift, this will duplicate the shift and
 // increase register pressure.
 def : GCNPat <
diff --git a/llvm/test/CodeGen/AMDGPU/v_sat_pk_u8_i16.ll b/llvm/test/CodeGen/AMDGPU/v_sat_pk_u8_i16.ll
index 2d84e877229515..695c8e1c680eef 100644
--- a/llvm/test/CodeGen/AMDGPU/v_sat_pk_u8_i16.ll
+++ b/llvm/test/CodeGen/AMDGPU/v_sat_pk_u8_i16.ll
@@ -1,12 +1,12 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc -mtriple=amdgcn -mcpu=fiji -verify-machineinstrs < %s | FileCheck -check-prefixes=SDAG-VI %s
 ; RUN: llc -mtriple=amdgcn -mcpu=gfx900 -verify-machineinstrs < %s | FileCheck -check-prefixes=SDAG-GFX9 %s
-; RUN: llc -mtriple=amdgcn -mcpu=gfx1101 -verify-machineinstrs < %s | FileCheck -check-prefixes=GFX11,SDAG-GFX11 %s
+; RUN: llc -mtriple=amdgcn -mcpu=gfx1101 -verify-machineinstrs < %s | FileCheck -check-prefixes=SDAG-GFX11 %s
 ; RUN: llc -mtriple=amdgcn -mcpu=gfx1200 -verify-machineinstrs < %s | FileCheck -check-prefixes=SDAG-GFX12 %s
 
 ; RUN: llc -mtriple=amdgcn -mcpu=fiji -verify-machineinstrs -global-isel < %s | FileCheck -check-prefixes=GISEL-VI %s
 ; RUN: llc -mtriple=amdgcn -mcpu=gfx900 -verify-machineinstrs -global-isel < %s | FileCheck -check-prefixes=GISEL-GFX9 %s
-; RUN: llc -mtriple=amdgcn -mcpu=gfx1101 -verify-machineinstrs -global-isel < %s | FileCheck -check-prefixes=GFX11,GISEL-GFX11 %s
+; RUN: llc -mtriple=amdgcn -mcpu=gfx1101 -verify-machineinstrs -global-isel < %s | FileCheck -check-prefixes=GISEL-GFX11 %s
 ; RUN: llc -mtriple=amdgcn -mcpu=gfx1200 -verify-machineinstrs -global-isel < %s | FileCheck -check-prefixes=GISEL-GFX12 %s
 
 ; <GFX9 has no V_SAT_PK, GFX9+ has V_SAT_PK, GFX11 has V_SAT_PK with t16
@@ -815,15 +815,15 @@ define i16 @basic_smax_smin_bit_or(i16 %src0, i16 %src1) {
 ; SDAG-GFX9-NEXT:    v_or_b32_e32 v0, v0, v1
 ; SDAG-GFX9-NEXT:    s_setpc_b64 s[30:31]
 ;
-; GFX11-LABEL: basic_smax_smin_bit_or:
-; GFX11:       ; %bb.0:
-; GFX11-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; GFX11-NEXT:    v_med3_i16 v1, v1, 0, 0xff
-; GFX11-NEXT:    v_med3_i16 v0, v0, 0, 0xff
-; GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_2) | instskip(NEXT) | instid1(VALU_DEP_1)
-; GFX11-NEXT:    v_lshlrev_b16 v1, 8, v1
-; GFX11-NEXT:    v_or_b32_e32 v0, v0, v1
-; GFX11-NEXT:    s_setpc_b64 s[30:31]
+; SDAG-GFX11-LABEL: basic_smax_smin_bit_or:
+; SDAG-GFX11:       ; %bb.0:
+; SDAG-GFX11-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; SDAG-GFX11-NEXT:    v_med3_i16 v1, v1, 0, 0xff
+; SDAG-GFX11-NEXT:    v_med3_i16 v0, v0, 0, 0xff
+; SDAG-GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_2) | instskip(NEXT) | instid1(VALU_DEP_1)
+; SDAG-GFX11-NEXT:    v_lshlrev_b16 v1, 8, v1
+; SDAG-GFX11-NEXT:    v_or_b32_e32 v0, v0, v1
+; SDAG-GFX11-NEXT:    s_setpc_b64 s[30:31]
 ;
 ; SDAG-GFX12-LABEL: basic_smax_smin_bit_or:
 ; SDAG-GFX12:       ; %bb.0:
@@ -860,6 +860,16 @@ define i16 @basic_smax_smin_bit_or(i16 %src0, i16 %src1) {
 ; GISEL-GFX9-NEXT:    v_or_b32_e32 v0, v0, v1
 ; GISEL-GFX9-NEXT:    s_setpc_b64 s[30:31]
 ;
+; GISEL-GFX11-LABEL: basic_smax_smin_bit_or:
+; GISEL-GFX11:       ; %bb.0:
+; GISEL-GFX11-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GISEL-GFX11-NEXT:    v_med3_i16 v1, v1, 0, 0xff
+; GISEL-GFX11-NEXT:    v_med3_i16 v0, v0, 0, 0xff
+; GISEL-GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_2) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GISEL-GFX11-NEXT:    v_lshlrev_b16 v1, 8, v1
+; GISEL-GFX11-NEXT:    v_or_b32_e32 v0, v0, v1
+; GISEL-GFX11-NEXT:    s_setpc_b64 s[30:31]
+;
 ; GISEL-GFX12-LABEL: basic_smax_smin_bit_or:
 ; GISEL-GFX12:       ; %bb.0:
 ; GISEL-GFX12-NEXT:    s_wait_loadcnt_dscnt 0x0
@@ -873,6 +883,15 @@ define i16 @basic_smax_smin_bit_or(i16 %src0, i16 %src1) {
 ; GISEL-GFX12-NEXT:    v_lshlrev_b16 v1, 8, v1
 ; GISEL-GFX12-NEXT:    v_or_b32_e32 v0, v0, v1
 ; GISEL-GFX12-NEXT:    s_setpc_b64 s[30:31]
+; GFX11-LABEL: basic_smax_smin_bit_or:
+; GFX11:       ; %bb.0:
+; GFX11-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX11-NEXT:    v_med3_i16 v1, v1, 0, 0xff
+; GFX11-NEXT:    v_med3_i16 v0, v0, 0, 0xff
+; GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_2) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX11-NEXT:    v_lshlrev_b16 v1, 8, v1
+; GFX11-NEXT:    v_or_b32_e32 v0, v0, v1
+; GFX11-NEXT:    s_setpc_b64 s[30:31]
 
   %src0.max = call i16 @llvm.smax.i16(i16 %src0, i16 0)
   %src0.clamp = call i16 @llvm.smin.i16(i16 %src0.max, i16 255)
@@ -902,15 +921,15 @@ define i16 @basic_umax_umin_bit_or(i16 %src0, i16 %src1) {
 ; SDAG-GFX9-NEXT:    v_or_b32_e32 v0, v0, v1
 ; SDAG-GFX9-NEXT:    s_setpc_b64 s[30:31]
 ;
-; GFX11-LABEL: basic_umax_umin_bit_or:
-; GFX11:       ; %bb.0:
-; GFX11-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; GFX11-NEXT:    v_min_u16 v1, 0xff, v1
-; GFX11-NEXT:    v_min_u16 v0, 0xff, v0
-; GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_2) | instskip(NEXT) | instid1(VALU_DEP_1)
-; GFX11-NEXT:    v_lshlrev_b16 v1, 8, v1
-; GFX11-NEXT:    v_or_b32_e32 v0, v0, v1
-; GFX11-NEXT:    s_setpc_b64 s[30:31]
+; SDAG-GFX11-LABEL: basic_umax_umin_bit_or:
+; SDAG-GFX11:       ; %bb.0:
+; SDAG-GFX11-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; SDAG-GFX11-NEXT:    v_min_u16 v1, 0xff, v1
+; SDAG-GFX11-NEXT:    v_min_u16 v0, 0xff, v0
+; SDAG-GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_2) | instskip(NEXT) | instid1(VALU_DEP_1)
+; SDAG-GFX11-NEXT:    v_lshlrev_b16 v1, 8, v1
+; SDAG-GFX11-NEXT:    v_or_b32_e32 v0, v0, v1
+; SDAG-GFX11-NEXT:    s_setpc_b64 s[30:31]
 ;
 ; SDAG-GFX12-LABEL: basic_umax_umin_bit_or:
 ; SDAG-GFX12:       ; %bb.0:
@@ -944,6 +963,16 @@ define i16 @basic_umax_umin_bit_or(i16 %src0, i16 %src1) {
 ; GISEL-GFX9-NEXT:    v_or_b32_e32 v0, v0, v1
 ; GISEL-GFX9-NEXT:    s_setpc_b64 s[30:31]
 ;
+; GISEL-GFX11-LABEL: basic_umax_umin_bit_or:
+; GISEL-GFX11:       ; %bb.0:
+; GISEL-GFX11-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GISEL-GFX11-NEXT:    v_min_u16 v1, 0xff, v1
+; GISEL-GFX11-NEXT:    v_min_u16 v0, 0xff, v0
+; GISEL-GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_2) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GISEL-GFX11-NEXT:    v_lshlrev_b16 v1, 8, v1
+; GISEL-GFX11-NEXT:    v_or_b32_e32 v0, v0, v1
+; GISEL-GFX11-NEXT:    s_setpc_b64 s[30:31]
+;
 ; GISEL-GFX12-LABEL: basic_umax_umin_bit_or:
 ; GISEL-GFX12:       ; %bb.0:
 ; GISEL-GFX12-NEXT:    s_wait_loadcnt_dscnt 0x0
@@ -957,6 +986,15 @@ define i16 @basic_umax_umin_bit_or(i16 %src0, i16 %src1) {
 ; GISEL-GFX12-NEXT:    v_lshlrev_b16 v1, 8, v1
 ; GISEL-GFX12-NEXT:    v_or_b32_e32 v0, v0, v1
 ; GISEL-GFX12-NEXT:    s_setpc_b64 s[30:31]
+; GFX11-LABEL: basic_umax_umin_bit_or:
+; GFX11:       ; %bb.0:
+; GFX11-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX11-NEXT:    v_min_u16 v1, 0xff, v1
+; GFX11-NEXT:    v_min_u16 v0, 0xff, v0
+; GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_2) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX11-NEXT:    v_lshlrev_b16 v1, 8, v1
+; GFX11-NEXT:    v_or_b32_e32 v0, v0, v1
+; GFX11-NEXT:    s_setpc_b64 s[30:31]
 
   %src0.max = call i16 @llvm.umax.i16(i16 %src0, i16 0)
   %src0.clamp = call i16 @llvm.umin.i16(i16 %src0.max, i16 255)
@@ -1093,15 +1131,15 @@ define i16 @basic_smax_smin_bit_shl(i16 %src0, i16 %src1) {
 ; SDAG-GFX9-NEXT:    v_or_b32_e32 v0, v0, v1
 ; SDAG-GFX9-NEXT:    s_setpc_b64 s[30:31]
 ;
-; GFX11-LABEL: basic_smax_smin_bit_shl:
-; GFX11:       ; %bb.0:
-; GFX11-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; GFX11-NEXT:    v_max_i16 v1, v1, 0
-; GFX11-NEXT:    v_med3_i16 v0, v0, 0, 0xff
-; GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_2) | instskip(NEXT) | instid1(VALU_DEP_1)
-; GFX11-NEXT:    v_lshlrev_b16 v1, 8, v1
-; GFX11-NEXT:    v_or_b32_e32 v0, v0, v1
-; GFX11-NEXT:    s_setpc_b64 s[30:31]
+; SDAG-GFX11-LABEL: basic_smax_smin_bit_shl:
+; SDAG-GFX11:       ; %bb.0:
+; SDAG-GFX11-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; SDAG-GFX11-NEXT:    v_max_i16 v1, v1, 0
+; SDAG-GFX11-NEXT:    v_med3_i16 v0, v0, 0, 0xff
+; SDAG-GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_2) | instskip(NEXT) | instid1(VALU_DEP_1)
+; SDAG-GFX11-NEXT:    v_lshlrev_b16 v1, 8, v1
+; SDAG-GFX11-NEXT:    v_or_b32_e32 v0, v0, v1
+; SDAG-GFX11-NEXT:    s_setpc_b64 s[30:31]
 ;
 ; SDAG-GFX12-LABEL: basic_smax_smin_bit_shl:
 ; SDAG-GFX12:       ; %bb.0:
@@ -1137,6 +1175,16 @@ define i16 @basic_smax_smin_bit_shl(i16 %src0, i16 %src1) {
 ; GISEL-GFX9-NEXT:    v_or_b32_e32 v0, v0, v1
 ; GISEL-GFX9-NEXT:    s_setpc_b64 s[30:31]
 ;
+; GISEL-GFX11-LABEL: basic_smax_smin_bit_shl:
+; GISEL-GFX11:       ; %bb.0:
+; GISEL-GFX11-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GISEL-GFX11-NEXT:    v_max_i16 v1, v1, 0
+; GISEL-GFX11-NEXT:    v_med3_i16 v0, v0, 0, 0xff
+; GISEL-GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_2) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GISEL-GFX11-NEXT:    v_lshlrev_b16 v1, 8, v1
+; GISEL-GFX11-NEXT:    v_or_b32_e32 v0, v0, v1
+; GISEL-GFX11-NEXT:    s_setpc_b64 s[30:31]
+;
 ; GISEL-GFX12-LABEL: basic_smax_smin_bit_shl:
 ; GISEL-GFX12:       ; %bb.0:
 ; GISEL-GFX12-NEXT:    s_wait_loadcnt_dscnt 0x0
@@ -1150,6 +1198,15 @@ define i16 @basic_smax_smin_bit_shl(i16 %src0, i16 %src1) {
 ; GISEL-GFX12-NEXT:    v_lshlrev_b16 v1, 8, v1
 ; GISEL-GFX12-NEXT:    v_or_b32_e32 v0, v0, v1
 ; GISEL-GFX12-NEXT:    s_setpc_b64 s[30:31]
+; GFX11-LABEL: basic_smax_smin_bit_shl:
+; GFX11:       ; %bb.0:
+; GFX11-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX11-NEXT:    v_max_i16 v1, v1, 0
+; GFX11-NEXT:    v_med3_i16 v0, v0, 0, 0xff
+; GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_2) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX11-NEXT:    v_lshlrev_b16 v1, 8, v1
+; GFX11-NEXT:    v_or_b32_e32 v0, v0, v1
+; GFX11-NEXT:    s_setpc_b64 s[30:31]
 
   %src0.max = call i16 @llvm.smax.i16(i16 %src0, i16 0)
   %src0.clamp = call i16 @llvm.smin.i16(i16 %src0.max, i16 255)
@@ -1174,24 +1231,13 @@ define i16 @basic_smax_smin_vec_input(<2 x i16> %src) {
 ; SDAG-GFX9-LABEL: basic_smax_smin_vec_input:
 ; SDAG-GFX9:       ; %bb.0:
 ; SDAG-GFX9-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; SDAG-GFX9-NEXT:    s_movk_i32 s4, 0xff
-; SDAG-GFX9-NEXT:    v_pk_min_i16 v0, v0, s4 op_sel_hi:[1,0]
-; SDAG-GFX9-NEXT:    v_pk_max_i16 v0, v0, 0
-; SDAG-GFX9-NEXT:    v_lshrrev_b32_e32 v1, 16, v0
-; SDAG-GFX9-NEXT:    v_lshlrev_b16_e32 v1, 8, v1
-; SDAG-GFX9-NEXT:    v_or_b32_e32 v0, v0, v1
+; SDAG-GFX9-NEXT:    v_sat_pk_u8_i16_e32 v0, v0
 ; SDAG-GFX9-NEXT:    s_setpc_b64 s[30:31]
 ;
 ; SDAG-GFX11-LABEL: basic_smax_smin_vec_input:
 ; SDAG-GFX11:       ; %bb.0:
 ; SDAG-GFX11-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; SDAG-GFX11-NEXT:    v_pk_min_i16 v0, 0xff, v0 op_sel_hi:[0,1]
-; SDAG-GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
-; SDAG-GFX11-NEXT:    v_pk_max_i16 v0, v0, 0
-; SDAG-GFX11-NEXT:    v_lshrrev_b32_e32 v1, 16, v0
-; SDAG-GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
-; SDAG-GFX11-NEXT:    v_lshlrev_b16 v1, 8, v1
-; SDAG-GFX11-NEXT:    v_or_b32_e32 v0, v0, v1
+; SDAG-GFX11-NEXT:    v_sat_pk_u8_i16_e32 v0, v0
 ; SDAG-GFX11-NEXT:    s_setpc_b64 s[30:31]
 ;
 ; SDAG-GFX12-LABEL: basic_smax_smin_vec_input:
@@ -1201,13 +1247,7 @@ define i16 @basic_smax_smin_vec_input(<2 x i16> %src) {
 ; SDAG-GFX12-NEXT:    s_wait_samplecnt 0x0
 ; SDAG-GFX12-NEXT:    s_wait_bvhcnt 0x0
 ; SDAG-GFX12-NEXT:    s_wait_kmcnt 0x0
-; SDAG-GFX12-NEXT:    v_pk_min_i16 v0, 0xff, v0 op_sel_hi:[0,1]
-; SDAG-GFX12-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
-; SDAG-GFX12-NEXT:    v_pk_max_i16 v0, v0, 0
-; SDAG-GFX12-NEXT:    v_lshrrev_b32_e32 v1, 16, v0
-; SDAG-GFX12-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
-; SDAG-GFX12-NEXT:    v_lshlrev_b16 v1, 8, v1
-; SDAG-GFX12-NEXT:    v_or_b32_e32 v0, v0, v1
+; SDAG-GFX12-NEXT:    v_sat_pk_u8_i16_e32 v0, v0
 ; SDAG-GFX12-NEXT:    s_setpc_b64 s[30:31]
 ;
 ; GISEL-VI-LABEL: basic_smax_smin_vec_input:
@@ -1290,24 +1330,13 @@ define i16 @basic_smax_smin_vec_input_rev(<2 x i16> %src) {
 ; SDAG-GFX9-LABEL: basic_smax_smin_vec_input_rev:
 ; SDAG-GFX9:       ; %bb.0:
 ; SDAG-GFX9-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; SDAG-GFX9-NEXT:    v_pk_max_i16 v0, v0, 0
-; SDAG-GFX9-NEXT:    s_movk_i32 s4, 0xff
-; SDAG-GFX9-NEXT:    v_pk_min_i16 v0, v0, s4 op_sel_hi:[1,0]
-; SDAG-GFX9-NEXT:    v_lshrrev_b32_e32 v1, 16, v0
-; SDAG-GFX9-NEXT:    v_lshlrev_b16_e32 v1, 8, v1
-; SDAG-GFX9-NEXT:    v_or_b32_e32 v0, v0, v1
+; SDAG-GFX9-NEXT:    v_sat_pk_u8_i16_e32 v0, v0
 ; SDAG-GFX9-NEXT:    s_setpc_b64 s[30:31]
 ;
 ; SDAG-GFX11-LABEL: basic_smax_smin_vec_input_rev:
 ; SDAG-GFX11:       ; %bb.0:
 ; SDAG-GFX11-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; SDAG-GFX11-NEXT:    v_pk_max_i16 v0, v0, 0
-; SDAG-GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
-; SDAG-GFX11-NEXT:    v_pk_min_i16 v0, 0xff, v0 op_sel_hi:[0,1]
-; SDAG-GFX11-NEXT:    v_lshrrev_b32_e32 v1, 16, v0
-; SDAG-GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
-; SDAG-GFX11-NEXT:    v_lshlrev_b16 v1, 8, v1
-; SDAG-GFX11-NEXT:    v_or_b32_e32 v0, v0, v1
+; SDAG-GFX11-NEXT:    v_sat_pk_u8_i16_e32 v0, v0
 ; SDAG-GFX11-NEXT:    s_setpc_b64 s[30:31]
 ;
 ; SDAG-GFX12-LABEL: basic_smax_smin_vec_input_rev:
@@ -1317,13 +1346,7 @@ define i16 @basic_smax_smin_vec_input_rev(<2 x i16> %src) {
 ; SDAG-GFX12-NEXT:    s_wait_samplecnt 0x0
 ; SDAG-GFX12-NEXT:    s_wait_bvhcnt 0x0
 ; SDAG-GFX12-NEXT:    s_wait_kmcnt 0x0
-; SDAG-GFX12-NEXT:    v_pk_max_i16 v0, v0, 0
-; SDAG-GFX12-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
-; SDAG-GFX12-NEXT:    v_pk_min_i16 v0, 0xff, v0 op_sel_hi:[0,1]
-; SDAG-GFX12-NEXT:    v_lshrrev_b32_e32 v1, 16, v0
-; SDAG-GFX12-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
-; SDAG-GFX12-NEXT:    v_lshlrev_b16 v1, 8, v1
-; SDAG-GFX12-NEXT:    v_or_b32_e32 v0, v0, v1
+; SDAG-GFX12-NEXT:    v_sat_pk_u8_i16_e32 v0, v0
 ; SDAG-GFX12-NEXT:    s_setpc_b64 s[30:31]
 ;
 ; GISEL-VI-LABEL: basic_smax_smin_vec_input_rev:

Copy link

github-actions bot commented Jan 17, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@Pierre-vh Pierre-vh requested review from arsenm and Pierre-vh January 17, 2025 08:05
Copy link
Contributor

@Pierre-vh Pierre-vh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the new approach, seems more stable :)

Please add the [AMDGPU] prefix to the PR title

@@ -865,6 +865,15 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
setOperationAction({ISD::FMAXIMUM, ISD::FMINIMUM}, MVT::v2f16, Legal);
}

// special dealing for v_sat_pk instruction
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Special case for v_sat_pk" fits more

// In foldToSaturated during DAG combine
// 1. isOperationLegalOrCustom(Opc, SrcVT) getOperationAction(Op, SrcVT) == Custom
// 2. isTypeDesirableForOp checks regclass for v2i8 (hooked now checking DstVT == v2i8)
// In CustomLowerNode during legalizing, checks getOperationAction(Op, DstVT) == Custom
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leftover debug comments?

@@ -5498,6 +5498,7 @@ const char* AMDGPUTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(UMIN3)
NODE_NAME_CASE(FMED3)
NODE_NAME_CASE(SMED3)
NODE_NAME_CASE(SAT_PK_CAST)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this extra node, can't we just select from TRUNC_SSAT_U directly?
Is it because it gets transformed/lost otherwise?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for printing and dumping, without this the debug dump will show unknown node

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because the type signature is different. This is forcing the pack to use a legal integer type instead of v2i8

Comment on lines 1987 to 1989
// Avoiding legality check for reg type of v2i8
// (do not need to addRegisterClass for v2i8)
// VT is result type, ensure the result type is v2i8
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can shorten into something like "v2i8 is illegal and only allowed in specific cases" ?

Comment on lines 15214 to 15216
if (N->getValueType(0) == Src.getValueType()) {
return Src;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (N->getValueType(0) == Src.getValueType()) {
return Src;
}
if (N->getValueType(0) == Src.getValueType())
return Src;

// This is possible beause for (i16 bitcase (v2i8 trunc ...))
// It may be replaced bu (i16 bitcase (v2i8 truncssat_u ...))
// And then (i16 bitcase (i16 AMDGPUsat_pk_cast ...))
// There is no instruction of casting to the same type
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand that comment, can you elaborate?

@Shoreshen Shoreshen changed the title selecting v_sat_pk instruction, version 2 [AMDGPU] selecting v_sat_pk instruction, version 2 Jan 17, 2025
@@ -865,6 +865,19 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
setOperationAction({ISD::FMAXIMUM, ISD::FMINIMUM}, MVT::v2f16, Legal);
}

// special case for v_sat_pk
if (AMDGPU::isGFX9(STI) || AMDGPU::isGFX11(STI) || AMDGPU::isGFX12(STI)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do not do generation checks, and especially don't do them directly here. This should be a subtarget feature (especially since it was removed in gfx10 and re-added based on this logic?)

Comment on lines 871 to 877
// 1. In foldToSaturated during DAG combine
// a. isOperationLegalOrCustom(Opc, SrcVT)
// will check getOperationAction(Op, SrcVT) == Custom
// b. isTypeDesirableForOp checks regclass for v2i8
// (hooked now checking DstVT == v2i8)
// 2. In CustomLowerNode during legalizing, checks
// getOperationAction(Op, DstVT) == Custom
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is too specifically describing implementation details in the combiner instead of why these cases are relevant

SDLoc SL(N);
SDValue Op =
DAG.getNode(AMDGPUISD::SAT_PK_CAST, SL, MVT::i16, N->getOperand(0));
Results.push_back(Op);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to insert a cast to the original type, does this not assert as-is?

DL, MVT::i8)); // swz
Ops.push_back(M0Val.getValue(0)); // Chain
Ops.push_back(M0Val.getValue(1)); // Glue
DL, MVT::i8)); // swz
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated formatting change

@@ -15184,6 +15208,20 @@ SDValue SITargetLowering::PerformDAGCombine(SDNode *N,
return Widened;
[[fallthrough]];
}
case ISD::BITCAST: {
// If src.VT == dst.VT, there is no instruction can be select
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

None of this should be necessary, you shouldn't have to touch anything related to bitcast handling

@@ -3309,6 +3309,21 @@ def : GCNPat <
(v2i16 (V_LSHL_OR_B32_e64 $src1, (i32 16), (i32 (V_AND_B32_e64 (i32 (V_MOV_B32_e32 (i32 0xffff))), $src0))))
>;

multiclass V_SAT_PK_Pat<Instruction inst> {
def : GCNPat<
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This pattern isn't doing much, you should be able to pass the node to the SDNodeOperator argument to the instruction definition

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @arsenm , could you be more specific? Should I use other type of pattern?

Copy link
Contributor

@arsenm arsenm Jan 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The node is basically the same as the instruction definition, so you should be able to use the built-in pattern attached to the instruction def.
Something like

in VOP1Instructions.td:
defm V_SAT_PK_U8_I16 : VOP1Inst_t16<"v_sat_pk_u8_i16", VOP_I16_I32, AMDGPUsat_pk_cast>;

Copy link
Contributor Author

@Shoreshen Shoreshen Jan 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @arsenm , by adding the node I got the following:

def V_SAT_PK_U8_I16_e64: list<dag> Pattern = [(set i16:$vdst, (AMDGPUsat_pk_cast (i32 (VOP3Mods0 i32:$src0))))];
def V_SAT_PK_U8_I16_fake16_e64: list<dag> Pattern = [(set i16:$vdst, (AMDGPUsat_pk_cast (i32 (VOP3Mods0 i32:$src0))))];
def V_SAT_PK_U8_I16_t16_e64: list<dag> Pattern = [(set i16:$vdst, (AMDGPUsat_pk_cast (i32 (VOP3OpSelMods i32:$src0, i32:$src0_modifiers))))];

I think there are 2 problems:

  1. The source is i32, instead of v2i16
  2. It requires the operand of AMDGPUsat_pk_cast be complex pattern of VOP3Mods0 or VOP3OpSelMods

If the instruction cannot cover any type of (i16 (AMDGPUsat_pk_cast v2i8)), we gain risk of failing in selection.

I also tried to create a new VOP_I16_V2I16 type, but it makes V_SAT_PK_U8_I16_e64 and V_SAT_PK_U8_I16_fake16_e64 4 operands instructions (with modifier, clamp and opsel)

I think in order to make the passing node work, I need to modify related complex pattern functions and replace (v2i8 (truncssat_u v2i16)) with some patterns that can fit the complex pattern functions

@@ -332,6 +332,8 @@ def AMDGPUumed3 : SDNode<"AMDGPUISD::UMED3", AMDGPUDTIntTernaryOp,
[]
>;

def AMDGPUsat_pk_cast : SDNode<"AMDGPUISD::SAT_PK_CAST", SDTUnaryOp, []>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to document what this is

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Explain what the node is, not just to avoid v2i8. It's to pack a v2i18 into i16

@@ -816,6 +816,14 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
{MVT::v4f32, MVT::v8f32, MVT::v16f32, MVT::v32f32},
Custom);
}

// true 16 currently unsupported
if (!Subtarget->hasTrue16BitInsts() || (!Subtarget->useRealTrue16Insts() ||
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be a single predicate for the instruction, and not related to the true16 support. This is likely a new subtarget feature

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @arsenm , the true 16 is not supported since it has 4 operands. If a sub target uses true 16 and we convert the truncssat_u, the selection will fail since no pattern will fit it.

@Shoreshen Shoreshen requested review from Pierre-vh and arsenm January 20, 2025 15:58
@@ -816,6 +816,13 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
{MVT::v4f32, MVT::v8f32, MVT::v16f32, MVT::v32f32},
Custom);
}

// Avoid true 16 instruction
if (!Subtarget->hasTrue16BitInsts() || !Subtarget->useRealTrue16Insts()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs predicate for has the instruction

@@ -1975,6 +1982,10 @@ bool SITargetLowering::isTypeDesirableForOp(unsigned Op, EVT VT) const {
if (VT == MVT::i1 && Op == ISD::SETCC)
return false;

// v2i8 is illegal and only allowed in specific cases
if (VT == MVT::v2i8 && Op == ISD::TRUNCATE_SSAT_U)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wouldn't expect to need this, but this also should check the operation is actually available if it is actually needed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @arsenm , I think it is needed, because the dst type has to be v2i8, otherwise we do not have instructions to select (since this is specially made for v_sat_pk)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The combiner checks isOperationLegalOrCustom though

Plus this wouldn't be a check specific to this vector size. We would want wider vectors split into pieces

Copy link
Contributor Author

@Shoreshen Shoreshen Jan 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @arsenm , could you be more specific??

Do you mean that we should remove the check here, and try to handle different dst type of truncate_ssat_u latter in ReplaceNodeResults??

Or should I accept any vNi8, and:

  1. try to pack then into pairs to select v_sat_pk
  2. extract elements from the result
  3. build_vector back to vNi8

while N is any even integer??

SDLoc SL(N);
SDValue Op =
DAG.getNode(AMDGPUISD::SAT_PK_CAST, SL, MVT::i16, N->getOperand(0));
Op = DAG.getNode(ISD::BITCAST, SL, MVT::v2i8, Op);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't hardcode the v2i8, use the type from N

; GFX11-NEXT: s_delay_alu instid0(VALU_DEP_2) | instskip(NEXT) | instid1(VALU_DEP_1)
; GFX11-NEXT: v_lshlrev_b16 v1, 8, v1
; GFX11-NEXT: v_or_b32_e32 v0, v0, v1
; GFX11-NEXT: s_setpc_b64 s[30:31]

%src0.max = call i16 @llvm.umax.i16(i16 %src0, i16 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should test these patterns in more vector types, at least 2 x, 3 x and 4 x cases

Copy link
Contributor Author

@Shoreshen Shoreshen Jan 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @arsenm , currently the pattern will only work for v2i8 result type.

Or maybe for N x i8 case we could make it to N/2 of v2i8??

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should still have baseline tests with the multiples. Yes, ideally we would take multiples and split them

// MVT::vNi8 for dst type check in CustomLowerNode
setOperationAction(ISD::TRUNCATE_SSAT_U,
{
MVT::v2i16,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You shouldn't have to override every single type that could decompose. Ideally the combiner should be able to figure it out based on the legalizer rules

Copy link
Contributor Author

@Shoreshen Shoreshen Jan 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @arsenm , for truncate_ssat_u to be folded TLI.isOperationLegalOrCustom function has pass.

We didn't hook this function, so it goes default and will check getOperationAction(Op, SrcVT) == Custom, which will look up OpActions[(unsigned)VT.getSimpleVT().SimpleTy][Op], and this is set here.

If we do not set every vNi16 (source type), the related truncat_ssat_u will not be created.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For step 1 I wouldn't do this. The fix for this kind of issue is in the combiner forming them, not the legalizer rules for a specific operation

Copy link
Contributor Author

@Shoreshen Shoreshen Jan 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @arsenm , I agree, but changing the combiner will relate to changing llvm side code. currently I'm not planning on changing llvm side's code in this PR.

So I think we may either do this, or hook the TLI.isOperationLegalOrCustom function??

But I think hooking TLI.isOperationLegalOrCustom will make it strange, since the input variable is op code and SrcVT....

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handle the basic case in the first step, and don't worry about handling every vector perfectly right away.

@@ -1982,8 +2004,10 @@ bool SITargetLowering::isTypeDesirableForOp(unsigned Op, EVT VT) const {
if (VT == MVT::i1 && Op == ISD::SETCC)
return false;

// v2i8 is illegal and only allowed in specific cases
if (VT == MVT::v2i8 && Op == ISD::TRUNCATE_SSAT_U)
// Special case for vNi8 handling where N is even
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still don't think you should need anything in isTypeDesirableForOp

Copy link
Contributor Author

@Shoreshen Shoreshen Jan 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @arsenm , this function checks the destination type, while TLI.isOperationLegalOrCustom checks the source type.

The Dst type are vNi8, if we didn't return true here, it goes to the default function to check isTypeLegal(DstVT)

The backend haven't add register class for vNi8. We maybe can add the relevant register class, but makeing vNi8 legal for register class may cause unpredictable result.

Personally I think we could add the register class for relevant type when it is formally legal in the backend. So I decide to handle it here for special case.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @arsenm , this function checks the destination type, while TLI.isOperationLegalOrCustom checks the source type.

This just sounds buggy. The interpretation of which type is the one that matters for the opcode needs to be globally consistent

The backend haven't add register class for vNi8. We maybe can add the relevant register class,

Please no, this is a huge amount of work.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @arsenm , yeah it is kind of weird logic. And what made it more strange is that to get into the ReplaceNodeResults function, it will check TLI.getOperationAction(Opc, DstVT) == Custom........

But if we want to change this, I think we also need to modify AArch64 backend....

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, bugs cause other bugs and all the use points need to be fixed

@@ -16,6 +16,10 @@ declare i16 @llvm.smax.i16(i16, i16)

declare <2 x i16> @llvm.smin.v2i16(<2 x i16>, <2 x i16>)
declare <2 x i16> @llvm.smax.v2i16(<2 x i16>, <2 x i16>)
declare <4 x i16> @llvm.smin.v4i16(<4 x i16>, <4 x i16>)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test the 3 x case

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @arsenm , the case is added~

But currently the v2i16 case will not use saturation pack, only even number of elements are supported. And for all MVT, only v3i16 is odd vector.

@Shoreshen Shoreshen requested a review from arsenm January 23, 2025 01:00

if (EleNo == 2) {
SDValue Op =
DAG.getNode(AMDGPUISD::SAT_PK_CAST, SL, MVT::i16, N->getOperand(0));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use Src from above

EVT DstEleVT = DstVT.getVectorElementType();
EVT SrcPairVT = EVT::getVectorVT(*DAG.getContext(), SrcEleVT, 2);
EVT DstPairVT = EVT::getVectorVT(*DAG.getContext(), DstEleVT, 2);
for (unsigned i = 0; i + 1 < EleNo; i = i + 2) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for (unsigned i = 0; i + 1 < EleNo; i = i + 2) {
for (unsigned I = 0; I != EleNo; I += 2) {

EVT SrcVT = Src.getValueType();
EVT DstVT = N->getValueType(0);

assert(SrcVT.isVector() && DstVT.isVector());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also should assert the element type is i8


// True 16 instruction is current not supported
// FIXME: Add support for true 16 when supported
if (!(Subtarget->hasTrue16BitInsts() && Subtarget->useRealTrue16Insts())) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see how this is checking for the existence of the underlying instruction. Plus the negation should be pushed through the condition

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants