Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement pairwise add and extmul for RISCV #323

Merged
merged 1 commit into from
Jan 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/jit/ByteCodeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ static bool isFloatGlobal(uint32_t globalIndex, Module* module)
#define OTOp3DotAddV128 OTOp3V128

#elif (defined SLJIT_CONFIG_RISCV && SLJIT_CONFIG_RISCV)

#define OPERAND_TYPE_LIST_SIMD_ARCH \
OL2(OTOp1V128CB, /* SD */ V128 | NOTMP, V128 | NOTMP) \
OL3(OTOp2V128, /* SSD */ V128 | TMP, V128 | TMP, V128 | TMP | S0 | S1) \
Expand Down
116 changes: 79 additions & 37 deletions src/jit/SimdRiscvInl.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ enum TypeOpcode : uint32_t {
vand_vv = InstructionType::opivv | OPCODE(0x9),
vcompress_vm = InstructionType::opmvv | OPCODE(0x17),
#if defined(__riscv_zvbb)
vcpop_v = InstructionType::opmvv | OPCODE(0x12) | (0xE << 15),
vcpop_v = InstructionType::opmvv | OPCODE(0x12) | (0xe << 15),
#endif
vfadd_vf = InstructionType::opfvf | OPCODE(0x0),
vfadd_vv = InstructionType::opfvv | OPCODE(0x0),
Expand All @@ -56,7 +56,7 @@ enum TypeOpcode : uint32_t {
vfmul_vv = InstructionType::opfvv | OPCODE(0x24),
vfsgnj_vv = InstructionType::opfvv | OPCODE(0x8),
vfsgnjn_vv = InstructionType::opfvv | OPCODE(0x9),
vfsgnjx_vv = InstructionType::opfvv | OPCODE(0xA),
vfsgnjx_vv = InstructionType::opfvv | OPCODE(0xa),
vfsqrt_v = InstructionType::opfvv | OPCODE(0x13),
vfsub_vv = InstructionType::opfvv | OPCODE(0x2),
vmax_vv = InstructionType::opivv | OPCODE(0x7),
Expand All @@ -65,16 +65,16 @@ enum TypeOpcode : uint32_t {
vmerge_vv = (InstructionType::opivv ^ InstructionType::vm) | OPCODE(0x17),
vmfeq_vv = InstructionType::opfvv | OPCODE(0x18),
vmfle_vv = InstructionType::opfvv | OPCODE(0x19),
vmflt_vv = InstructionType::opfvv | OPCODE(0x1B),
vmfne_vv = InstructionType::opfvv | OPCODE(0x1C),
vmflt_vv = InstructionType::opfvv | OPCODE(0x1b),
vmfne_vv = InstructionType::opfvv | OPCODE(0x1c),
vmin_vv = InstructionType::opivv | OPCODE(0x5),
vminu_vv = InstructionType::opivv | OPCODE(0x4),
vmseq_vv = InstructionType::opivv | OPCODE(0x18),
vmsle_vv = InstructionType::opivv | OPCODE(0x1D),
vmsleu_vv = InstructionType::opivv | OPCODE(0x1C),
vmslt_vv = InstructionType::opivv | OPCODE(0x1B),
vmslt_vx = InstructionType::opivx | OPCODE(0x1B),
vmsltu_vv = InstructionType::opivv | OPCODE(0x1A),
vmsle_vv = InstructionType::opivv | OPCODE(0x1d),
vmsleu_vv = InstructionType::opivv | OPCODE(0x1c),
vmslt_vv = InstructionType::opivv | OPCODE(0x1b),
vmslt_vx = InstructionType::opivx | OPCODE(0x1b),
vmsltu_vv = InstructionType::opivv | OPCODE(0x1a),
vmsne_vi = InstructionType::opivi | OPCODE(0x19),
vmsne_vv = InstructionType::opivv | OPCODE(0x19),
vmul_vv = InstructionType::opmvv | OPCODE(0x25),
Expand All @@ -83,16 +83,16 @@ enum TypeOpcode : uint32_t {
vmv_vv = InstructionType::opivv | OPCODE(0x17),
vmv_vx = InstructionType::opivx | OPCODE(0x17),
vmv_xs = InstructionType::opmvv | OPCODE(0x10),
vor_vv = InstructionType::opivv | OPCODE(0xA),
vor_vv = InstructionType::opivv | OPCODE(0xa),
vredmaxu_vs = InstructionType::opmvv | OPCODE(0x6),
vredminu_vs = InstructionType::opmvv | OPCODE(0x4),
vredsum_vs = InstructionType::opmvv | OPCODE(0x0),
vrgather_vv = InstructionType::opivv | OPCODE(0xC),
vrgather_vv = InstructionType::opivv | OPCODE(0xc),
vrsub_vi = InstructionType::opivi | OPCODE(0x3),
vsadd_vv = InstructionType::opivv | OPCODE(0x21),
vsaddu_vv = InstructionType::opivv | OPCODE(0x20),
vsext_vf2 = InstructionType::opmvv | OPCODE(0x12) | (0x7 << 15),
vslidedown_vi = InstructionType::opivi | OPCODE(0xF),
vslidedown_vi = InstructionType::opivi | OPCODE(0xf),
vsll_vi = InstructionType::opivi | OPCODE(0x25),
vsll_vx = InstructionType::opivx | OPCODE(0x25),
vsra_vi = InstructionType::opivi | OPCODE(0x29),
Expand All @@ -102,9 +102,10 @@ enum TypeOpcode : uint32_t {
vssub_vv = InstructionType::opivv | OPCODE(0x23),
vssubu_vv = InstructionType::opivv | OPCODE(0x22),
vsub_vv = InstructionType::opivv | OPCODE(0x2),
vwmul_vv = InstructionType::opmvv | OPCODE(0x3B),
vxor_vi = InstructionType::opivi | OPCODE(0xB),
vxor_vv = InstructionType::opivv | OPCODE(0xB),
vwmul_vv = InstructionType::opmvv | OPCODE(0x3b),
vwmulu_vv = InstructionType::opmvv | OPCODE(0x38),
vxor_vi = InstructionType::opivi | OPCODE(0xb),
vxor_vv = InstructionType::opivv | OPCODE(0xb),
vzext_vf2 = InstructionType::opmvv | OPCODE(0x12) | (0x6 << 15),
};

Expand All @@ -115,9 +116,20 @@ enum OperandTypes : uint32_t {
rmIsGpr = 1 << 4,
rdIsGpr = 1 << 5
};

enum VectorLengthMultiplyTypes : uint32_t {
vlMul1 = 0,
vlMul2 = 1,
vlMul4 = 2,
vlMul8 = 3,
vlMulF2 = 7,
vlMulF4 = 6,
vlMulF8 = 5,
};

}; // namespace SimdOp

static void simdEmitVsetivli(struct sljit_compiler* compiler, sljit_s32 type, sljit_ins vlmul)
static void simdEmitVsetivli(struct sljit_compiler* compiler, sljit_s32 type, uint32_t vlmul)
{
uint32_t elem_size = (uint32_t)(((type) >> 18) & 0x3f);
uint32_t avl = (uint32_t)1 << (4 - elem_size);
Expand Down Expand Up @@ -151,7 +163,7 @@ static void simdEmitOp(sljit_compiler* compiler, uint32_t opcode, sljit_s32 rd,
sljit_emit_op_custom(compiler, &opcode, sizeof(uint32_t));
}

static void simdEmitTypedOp(sljit_compiler* compiler, sljit_s32 type, uint32_t opcode, sljit_s32 rd, sljit_s32 rn, sljit_s32 rm, uint32_t optype = 0, sljit_ins vlmul = 0)
static void simdEmitTypedOp(sljit_compiler* compiler, sljit_s32 type, uint32_t opcode, sljit_s32 rd, sljit_s32 rn, sljit_s32 rm, uint32_t optype = 0, uint32_t vlmul = 0)
{
simdEmitVsetivli(compiler, type, vlmul);
simdEmitOp(compiler, opcode, rd, rn, rm, optype);
Expand Down Expand Up @@ -186,6 +198,17 @@ static void simdEmitAbs(sljit_compiler* compiler, sljit_s32 type, sljit_s32 rd,
}
}

static void simdEmitPairwiseAdd(sljit_compiler* compiler, sljit_s32 type, bool isSigned, sljit_s32 rd, sljit_s32 rn)
{
sljit_s32 tmp = SLJIT_TMP_DEST_VREG;
sljit_s32 shift = (type == SLJIT_SIMD_ELEM_16) ? 8 : 16;

simdEmitTypedOp(compiler, type, isSigned ? SimdOp::vsra_vi : SimdOp::vsrl_vi, tmp, rn, shift, SimdOp::rmIsImm);
simdEmitOp(compiler, SimdOp::vsll_vi, rd, rn, shift, SimdOp::rmIsImm);
simdEmitOp(compiler, isSigned ? SimdOp::vsra_vi : SimdOp::vsrl_vi, rd, rd, shift, SimdOp::rmIsImm);
simdEmitOp(compiler, SimdOp::vadd_vv, rd, rd, tmp);
}

static void simdEmitAllTrue(sljit_compiler* compiler, sljit_s32 type, sljit_s32 rd, sljit_s32 rn)
{
sljit_s32 tmp = SLJIT_TMP_DEST_VREG;
Expand Down Expand Up @@ -530,8 +553,12 @@ static void emitUnarySIMD(sljit_compiler* compiler, Instruction* instr)
simdEmitPopcnt(compiler, srcType, dst, args[0].arg, instr->requiredReg(1));
break;
case ByteCode::I16X8ExtaddPairwiseI8X16SOpcode:
case ByteCode::I32X4ExtaddPairwiseI16X8SOpcode:
simdEmitPairwiseAdd(compiler, dstType, true, dst, args[0].arg);
break;
case ByteCode::I16X8ExtaddPairwiseI8X16UOpcode:
case ByteCode::I32X4ExtaddPairwiseI16X8UOpcode:
simdEmitPairwiseAdd(compiler, dstType, false, dst, args[0].arg);
break;
case ByteCode::I16X8ExtendLowI8X16SOpcode:
case ByteCode::I32X4ExtendLowI16X8SOpcode:
Expand All @@ -553,10 +580,6 @@ static void emitUnarySIMD(sljit_compiler* compiler, Instruction* instr)
case ByteCode::I64X2ExtendHighI32X4UOpcode:
simdEmitExtend(compiler, srcType, false, false, dst, args[0].arg);
break;
case ByteCode::I32X4ExtaddPairwiseI16X8SOpcode:
break;
case ByteCode::I32X4ExtaddPairwiseI16X8UOpcode:
break;
case ByteCode::I32X4TruncSatF32X4SOpcode:
simdEmitTruncSat(compiler, srcType, SimdOp::vfcvt_rtz_x_f_v, dst, args[0].arg);
break;
Expand Down Expand Up @@ -666,6 +689,29 @@ static bool emitUnaryCondSIMD(sljit_compiler* compiler, Instruction* instr)
return false;
}

static void simdEmitExtmul(sljit_compiler* compiler, sljit_s32 type, uint32_t opcode, sljit_s32 rd, sljit_s32 rn, sljit_s32 rm)
{
sljit_s32 tmp = SLJIT_TMP_DEST_VREG;
bool useTmp = (rd == rn || rd == rm);

simdEmitTypedOp(compiler, type, opcode, useTmp ? tmp : rd, rn, rm, 0, SimdOp::vlMulF2);

if (useTmp) {
simdEmitTypedOp(compiler, type, SimdOp::vmv_vv, rd, 0, tmp, SimdOp::rnIsImm);
}
}

static void simdEmitExtmulHigh(sljit_compiler* compiler, sljit_s32 type, uint32_t opcode, sljit_s32 rd, sljit_s32 rn, sljit_s32 rm)
{
sljit_s32 tmp1 = SLJIT_TMP_DEST_VREG;
sljit_s32 tmp2 = SLJIT_VR0;

simdEmitTypedOp(compiler, SLJIT_SIMD_ELEM_8, SimdOp::vslidedown_vi, tmp1, rn, 8, SimdOp::rmIsImm);
simdEmitOp(compiler, SimdOp::vslidedown_vi, tmp2, rm, 8, SimdOp::rmIsImm);

simdEmitTypedOp(compiler, type, opcode, rd, tmp1, tmp2, 0, SimdOp::vlMulF2);
}

static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
{
Operand* operands = instr->operands();
Expand Down Expand Up @@ -959,27 +1005,31 @@ static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
simdEmitTypedOp(compiler, srcType, SimdOp::vmul_vv, dst, args[0].arg, args[1].arg);
break;
case ByteCode::I16X8ExtmulLowI8X16SOpcode:
case ByteCode::I32X4ExtmulLowI16X8SOpcode:
case ByteCode::I64X2ExtmulLowI32X4SOpcode:
simdEmitExtmul(compiler, srcType, SimdOp::vwmul_vv, dst, args[0].arg, args[1].arg);
break;
case ByteCode::I16X8ExtmulHighI8X16SOpcode:
case ByteCode::I32X4ExtmulHighI16X8SOpcode:
case ByteCode::I64X2ExtmulHighI32X4SOpcode:
simdEmitExtmulHigh(compiler, srcType, SimdOp::vwmul_vv, dst, args[0].arg, args[1].arg);
break;
case ByteCode::I16X8ExtmulLowI8X16UOpcode:
case ByteCode::I32X4ExtmulLowI16X8UOpcode:
case ByteCode::I64X2ExtmulLowI32X4UOpcode:
simdEmitExtmul(compiler, srcType, SimdOp::vwmulu_vv, dst, args[0].arg, args[1].arg);
break;
case ByteCode::I16X8ExtmulHighI8X16UOpcode:
case ByteCode::I32X4ExtmulHighI16X8UOpcode:
case ByteCode::I64X2ExtmulHighI32X4UOpcode:
simdEmitExtmulHigh(compiler, srcType, SimdOp::vwmulu_vv, dst, args[0].arg, args[1].arg);
break;
case ByteCode::I16X8NarrowI32X4SOpcode:
break;
case ByteCode::I16X8NarrowI32X4UOpcode:
break;
case ByteCode::I16X8Q15mulrSatSOpcode:
break;
case ByteCode::I32X4ExtmulLowI16X8SOpcode:
break;
case ByteCode::I32X4ExtmulHighI16X8SOpcode:
break;
case ByteCode::I32X4ExtmulLowI16X8UOpcode:
break;
case ByteCode::I32X4ExtmulHighI16X8UOpcode:
break;
case ByteCode::I32X4DotI16X8SOpcode:
simdEmitI32x4DotI16x8(compiler, srcType, dst, args[0].arg, args[1].arg);
break;
Expand Down Expand Up @@ -1039,14 +1089,6 @@ static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
case ByteCode::F64X2GeOpcode:
simdEmitCompare(compiler, srcType, SimdOp::vmfle_vv, dst, args[1].arg, args[0].arg);
break;
case ByteCode::I64X2ExtmulLowI32X4SOpcode:
break;
case ByteCode::I64X2ExtmulHighI32X4SOpcode:
break;
case ByteCode::I64X2ExtmulLowI32X4UOpcode:
break;
case ByteCode::I64X2ExtmulHighI32X4UOpcode:
break;
case ByteCode::V128AndOpcode:
simdEmitTypedOp(compiler, srcType, SimdOp::vand_vv, dst, args[0].arg, args[1].arg);
break;
Expand Down
Loading