Skip to content

Commit

Permalink
Allow combining jump and select with simd conditional
Browse files Browse the repository at this point in the history
Signed-off-by: Zoltan Herczeg [email protected]
  • Loading branch information
Zoltan Herczeg authored and ksh8281 committed Nov 15, 2023
1 parent df1d9f4 commit a26da93
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 19 deletions.
10 changes: 9 additions & 1 deletion src/jit/Analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,16 @@ void JITCompiler::buildParamDependencies(uint32_t requiredStackSize, size_t next
end = operand + instr->resultCount();

// Only certain instructions are collected
if (instr->group() != Instruction::Immediate && instr->group() != Instruction::Compare && instr->group() != Instruction::CompareFloat) {
switch (instr->group()) {
case Instruction::Immediate:
case Instruction::Compare:
case Instruction::CompareFloat:
case Instruction::UnaryCondSIMD:
break;

default:
instr = nullptr;
break;
}

while (operand < end) {
Expand Down
4 changes: 3 additions & 1 deletion src/jit/Backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,9 @@ void JITCompiler::compile()
break;
}
case Instruction::UnaryCondSIMD: {
emitUnaryCondSIMD(m_compiler, item->asInstruction());
if (emitUnaryCondSIMD(m_compiler, item->asInstruction())) {
item = item->next();
}
break;
}
case Instruction::LoadLaneSIMD: {
Expand Down
35 changes: 34 additions & 1 deletion src/jit/SimdArm32Inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,7 @@ static void emitUnarySIMD(sljit_compiler* compiler, Instruction* instr)
}
}

static void emitUnaryCondSIMD(sljit_compiler* compiler, Instruction* instr)
static bool emitUnaryCondSIMD(sljit_compiler* compiler, Instruction* instr)
{
Operand* operands = instr->operands();
JITArg args[2];
Expand Down Expand Up @@ -825,11 +825,44 @@ static void emitUnaryCondSIMD(sljit_compiler* compiler, Instruction* instr)

sljit_emit_simd_lane_mov(compiler, SLJIT_SIMD_STORE | SLJIT_SIMD_REG_128 | SLJIT_SIMD_LANE_SIGNED | SLJIT_32 | srcType, tmpReg, 0, dst, 0);
sljit_emit_op2u(compiler, SLJIT_SUB | SLJIT_SET_Z, dst, 0, SLJIT_IMM, 0);

ASSERT(instr->next() != nullptr);

if (instr->next()->isInstruction()) {
Instruction* nextInstr = instr->next()->asInstruction();

switch (nextInstr->opcode()) {
case ByteCode::JumpIfTrueOpcode:
case ByteCode::JumpIfFalseOpcode:
if (nextInstr->getParam(0)->item == instr) {
sljit_s32 type = SLJIT_NOT_EQUAL;

if (nextInstr->opcode() == ByteCode::JumpIfFalseOpcode) {
type ^= 0x1;
}

nextInstr->asExtended()->value().targetLabel->jumpFrom(sljit_emit_jump(compiler, type));
return true;
}
break;
case ByteCode::SelectOpcode:
if (nextInstr->getParam(2)->item == instr) {
emitSelect(compiler, nextInstr, SLJIT_NOT_EQUAL);
return true;
}
break;
default:
break;
}
}

sljit_emit_op_flags(compiler, SLJIT_MOV, dst, 0, SLJIT_NOT_EQUAL);

if (SLJIT_IS_MEM(args[1].arg)) {
sljit_emit_op1(compiler, SLJIT_MOV32, args[1].arg, args[1].argw, dst, 0);
}

return false;
}

static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
Expand Down
63 changes: 50 additions & 13 deletions src/jit/SimdArm64Inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -407,8 +407,10 @@ static void emitUnarySIMD(sljit_compiler* compiler, Instruction* instr)

static void simdEmitAllTrue(sljit_compiler* compiler, sljit_s32 rd, sljit_s32 rn, SimdOp::IntSizeType size)
{
sljit_s32 tmpReg = SLJIT_FR1;

ASSERT(size != SimdOp::D2);
simdEmitOp(compiler, SimdOp::uminv | size, rn, rn, 0);
simdEmitOp(compiler, SimdOp::uminv | size, tmpReg, rn, 0);

auto type = SLJIT_SIMD_ELEM_8;

Expand All @@ -418,33 +420,35 @@ static void simdEmitAllTrue(sljit_compiler* compiler, sljit_s32 rd, sljit_s32 rn
type = SLJIT_SIMD_ELEM_32;
}

sljit_emit_simd_lane_mov(compiler, SLJIT_SIMD_REG_128 | SLJIT_SIMD_STORE | type, rd, 0, rn, 0);
sljit_emit_simd_lane_mov(compiler, SLJIT_SIMD_REG_128 | SLJIT_SIMD_STORE | type, tmpReg, 0, rd, 0);
sljit_emit_op2u(compiler, SLJIT_SUB | SLJIT_SET_Z, rd, 0, SLJIT_IMM, 0);
sljit_emit_op_flags(compiler, SLJIT_MOV, rd, 0, SLJIT_NOT_EQUAL);
}

static void simdEmitI64x2AllTrue(sljit_compiler* compiler, sljit_s32 rd, sljit_s32 rn)
static void simdEmitI64x2AllTrue(sljit_compiler* compiler, sljit_s32 rn)
{
simdEmitOp(compiler, SimdOp::cmeqz | SimdOp::D2, rn, rn, rn);
simdEmitOp(compiler, SimdOp::addp | SimdOp::D2, rn, rn, rn);
sljit_emit_fop1(compiler, SLJIT_CMP_F64 | SLJIT_SET_ORDERED_EQUAL, rn, 0, rn, 0);
sljit_emit_op_flags(compiler, SLJIT_MOV, rd, 0, SLJIT_ORDERED_EQUAL);
sljit_s32 tmpReg = SLJIT_FR1;

simdEmitOp(compiler, SimdOp::cmeqz | SimdOp::D2, tmpReg, rn, rn);
simdEmitOp(compiler, SimdOp::addp | SimdOp::D2, tmpReg, tmpReg, tmpReg);
sljit_emit_fop1(compiler, SLJIT_CMP_F64 | SLJIT_SET_ORDERED_EQUAL, tmpReg, 0, tmpReg, 0);
}

static void simdEmitV128AnyTrue(sljit_compiler* compiler, sljit_s32 rd, sljit_s32 rn)
{
simdEmitOp(compiler, SimdOp::umaxp | SimdOp::S4, rn, rn, rn);
sljit_emit_simd_lane_mov(compiler, SLJIT_SIMD_REG_128 | SLJIT_SIMD_ELEM_64 | SLJIT_SIMD_STORE, rd, 1, rn, 0);
sljit_s32 tmpReg = SLJIT_FR1;

simdEmitOp(compiler, SimdOp::umaxp | SimdOp::S4, tmpReg, rn, rn);
sljit_emit_simd_lane_mov(compiler, SLJIT_SIMD_REG_128 | SLJIT_SIMD_ELEM_64 | SLJIT_SIMD_STORE, tmpReg, 0, rd, 0);
sljit_emit_op2u(compiler, SLJIT_SUB | SLJIT_SET_Z, rd, 0, SLJIT_IMM, 0);
sljit_emit_op_flags(compiler, SLJIT_MOV, rd, 0, SLJIT_NOT_EQUAL);
}

static void emitUnaryCondSIMD(sljit_compiler* compiler, Instruction* instr)
static bool emitUnaryCondSIMD(sljit_compiler* compiler, Instruction* instr)
{
Operand* operands = instr->operands();
JITArg args[2];

sljit_s32 srcType = SLJIT_SIMD_ELEM_128;
sljit_s32 type = SLJIT_NOT_EQUAL;

switch (instr->opcode()) {
case ByteCode::I8X16AllTrueOpcode:
Expand All @@ -458,6 +462,7 @@ static void emitUnaryCondSIMD(sljit_compiler* compiler, Instruction* instr)
break;
case ByteCode::I64X2AllTrueOpcode:
srcType = SLJIT_SIMD_ELEM_64;
type = SLJIT_ORDERED_EQUAL;
break;
default:
ASSERT(instr->opcode() == ByteCode::V128AnyTrueOpcode);
Expand All @@ -481,17 +486,49 @@ static void emitUnaryCondSIMD(sljit_compiler* compiler, Instruction* instr)
simdEmitAllTrue(compiler, dst, args[0].arg, SimdOp::S4);
break;
case ByteCode::I64X2AllTrueOpcode:
simdEmitI64x2AllTrue(compiler, dst, args[0].arg);
simdEmitI64x2AllTrue(compiler, args[0].arg);
break;
default:
ASSERT(instr->opcode() == ByteCode::V128AnyTrueOpcode);
simdEmitV128AnyTrue(compiler, dst, args[0].arg);
break;
}

ASSERT(instr->next() != nullptr);

if (instr->next()->isInstruction()) {
Instruction* nextInstr = instr->next()->asInstruction();

switch (nextInstr->opcode()) {
case ByteCode::JumpIfTrueOpcode:
case ByteCode::JumpIfFalseOpcode:
if (nextInstr->getParam(0)->item == instr) {
if (nextInstr->opcode() == ByteCode::JumpIfFalseOpcode) {
type ^= 0x1;
}

nextInstr->asExtended()->value().targetLabel->jumpFrom(sljit_emit_jump(compiler, type));
return true;
}
break;
case ByteCode::SelectOpcode:
if (nextInstr->getParam(2)->item == instr) {
emitSelect(compiler, nextInstr, type);
return true;
}
break;
default:
break;
}
}

sljit_emit_op_flags(compiler, SLJIT_MOV32, dst, 0, type);

if (SLJIT_IS_MEM(args[1].arg)) {
sljit_emit_op1(compiler, SLJIT_MOV32, args[1].arg, args[1].argw, dst, 0);
}

return false;
}

static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
Expand Down
38 changes: 35 additions & 3 deletions src/jit/SimdX86Inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -891,15 +891,15 @@ static void simdEmitAllTrue(sljit_compiler* compiler, uint32_t opcode, sljit_s32
simdEmitSSEOp(compiler, opcode, tmp, rn);
simdEmitSSEOp(compiler, SimdOp::ptest, tmp, tmp);
sljit_set_current_flags(compiler, SLJIT_SET_Z);
sljit_emit_op_flags(compiler, SLJIT_MOV32, rd, 0, SLJIT_ZERO);
}

static void emitUnaryCondSIMD(sljit_compiler* compiler, Instruction* instr)
static bool emitUnaryCondSIMD(sljit_compiler* compiler, Instruction* instr)
{
Operand* operands = instr->operands();
JITArg args[2];

sljit_s32 srcType = SLJIT_SIMD_ELEM_128;
sljit_s32 type = SLJIT_ZERO;

switch (instr->opcode()) {
case ByteCode::I8X16AllTrueOpcode:
Expand All @@ -917,6 +917,7 @@ static void emitUnaryCondSIMD(sljit_compiler* compiler, Instruction* instr)
default:
ASSERT(instr->opcode() == ByteCode::V128AnyTrueOpcode);
srcType = SLJIT_SIMD_ELEM_128;
type = SLJIT_NOT_ZERO;
break;
}

Expand All @@ -942,13 +943,44 @@ static void emitUnaryCondSIMD(sljit_compiler* compiler, Instruction* instr)
ASSERT(instr->opcode() == ByteCode::V128AnyTrueOpcode);
simdEmitSSEOp(compiler, SimdOp::ptest, args[0].arg, args[0].arg);
sljit_set_current_flags(compiler, SLJIT_SET_Z);
sljit_emit_op_flags(compiler, SLJIT_MOV32, dst, 0, SLJIT_NOT_ZERO);
break;
}

ASSERT(instr->next() != nullptr);

if (instr->next()->isInstruction()) {
Instruction* nextInstr = instr->next()->asInstruction();

switch (nextInstr->opcode()) {
case ByteCode::JumpIfTrueOpcode:
case ByteCode::JumpIfFalseOpcode:
if (nextInstr->getParam(0)->item == instr) {
if (nextInstr->opcode() == ByteCode::JumpIfFalseOpcode) {
type ^= 0x1;
}

nextInstr->asExtended()->value().targetLabel->jumpFrom(sljit_emit_jump(compiler, type));
return true;
}
break;
case ByteCode::SelectOpcode:
if (nextInstr->getParam(2)->item == instr) {
emitSelect(compiler, nextInstr, type);
return true;
}
break;
default:
break;
}
}

sljit_emit_op_flags(compiler, SLJIT_MOV32, dst, 0, type);

if (SLJIT_IS_MEM(args[1].arg)) {
sljit_emit_op1(compiler, SLJIT_MOV32, args[1].arg, args[1].argw, dst, 0);
}

return false;
}

static void simdEmitPMinMax(sljit_compiler* compiler, uint32_t operation, sljit_s32 rd, sljit_s32 rn, sljit_s32 rm)
Expand Down
100 changes: 100 additions & 0 deletions test/jit/compare-simd.wast
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
(module
(func (export "test1") (param v128) (result i32 i32 i32 i32 i32)
block (result i32)
i32.const 1
local.get 0
v128.any_true
br_if 0
drop
i32.const 0
end

block (result i32)
i32.const 1
local.get 0
i8x16.all_true
br_if 0
drop
i32.const 0
end

block (result i32)
i32.const 1
local.get 0
i16x8.all_true
br_if 0
drop
i32.const 0
end

block (result i32)
i32.const 1
local.get 0
i32x4.all_true
br_if 0
drop
i32.const 0
end

block (result i32)
i32.const 1
local.get 0
i64x2.all_true
br_if 0
drop
i32.const 0
end
)

(func (export "test2") (param v128) (result i32 i32 i32 i32 i32)
block (result i32)
i32.const 1
i32.const 0
local.get 0
v128.any_true
select
end

block (result i32)
i32.const 1
i32.const 0
local.get 0
i8x16.all_true
select
end

block (result i32)
i32.const 1
i32.const 0
local.get 0
i16x8.all_true
select
end

block (result i32)
i32.const 1
i32.const 0
local.get 0
i32x4.all_true
select
end

block (result i32)
i32.const 1
i32.const 0
local.get 0
i64x2.all_true
select
end
)
)

(assert_return (invoke "test1" (v128.const i64x2 -1 -1))
(i32.const 1) (i32.const 1) (i32.const 1) (i32.const 1) (i32.const 1))
(assert_return (invoke "test1" (v128.const i64x2 0 0))
(i32.const 0) (i32.const 0) (i32.const 0) (i32.const 0) (i32.const 0))

(assert_return (invoke "test2" (v128.const i64x2 -1 -1))
(i32.const 1) (i32.const 1) (i32.const 1) (i32.const 1) (i32.const 1))
(assert_return (invoke "test2" (v128.const i64x2 0 0))
(i32.const 0) (i32.const 0) (i32.const 0) (i32.const 0) (i32.const 0))

0 comments on commit a26da93

Please sign in to comment.