Skip to content

Commit

Permalink
Add support for Neon mixed-sign dot product instructions
Browse files Browse the repository at this point in the history
Add support for Neon vector usdot, indexed-vector usdot and sudot instructions,
available in the I8MM architecture extension.
  • Loading branch information
mmc28a committed Oct 1, 2021
1 parent 876aace commit 09925d4
Show file tree
Hide file tree
Showing 12 changed files with 258 additions and 124 deletions.
35 changes: 35 additions & 0 deletions src/aarch64/assembler-aarch64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3852,6 +3852,15 @@ void Assembler::udot(const VRegister& vd,
Emit(VFormat(vd) | NEON_UDOT | Rm(vm) | Rn(vn) | Rd(vd));
}

void Assembler::usdot(const VRegister& vd,
const VRegister& vn,
const VRegister& vm) {
VIXL_ASSERT(CPUHas(CPUFeatures::kNEON, CPUFeatures::kI8MM));
VIXL_ASSERT(AreSameFormat(vn, vm));
VIXL_ASSERT((vd.Is2S() && vn.Is8B()) || (vd.Is4S() && vn.Is16B()));

Emit(VFormat(vd) | 0x0e809c00 | Rm(vm) | Rn(vn) | Rd(vd));
}

void Assembler::faddp(const VRegister& vd, const VRegister& vn) {
VIXL_ASSERT(CPUHas(CPUFeatures::kFP, CPUFeatures::kNEON));
Expand Down Expand Up @@ -4166,6 +4175,32 @@ void Assembler::udot(const VRegister& vd,
ImmNEONHLM(vm_index, index_num_bits) | Rm(vm) | Rn(vn) | Rd(vd));
}

void Assembler::sudot(const VRegister& vd,
const VRegister& vn,
const VRegister& vm,
int vm_index) {
VIXL_ASSERT(CPUHas(CPUFeatures::kNEON, CPUFeatures::kI8MM));
VIXL_ASSERT((vd.Is2S() && vn.Is8B() && vm.Is1S4B()) ||
(vd.Is4S() && vn.Is16B() && vm.Is1S4B()));
int q = vd.Is4S() ? (1U << NEONQ_offset) : 0;
int index_num_bits = 2;
Emit(q | 0x0f00f000 | ImmNEONHLM(vm_index, index_num_bits) | Rm(vm) | Rn(vn) |
Rd(vd));
}


void Assembler::usdot(const VRegister& vd,
const VRegister& vn,
const VRegister& vm,
int vm_index) {
VIXL_ASSERT(CPUHas(CPUFeatures::kNEON, CPUFeatures::kI8MM));
VIXL_ASSERT((vd.Is2S() && vn.Is8B() && vm.Is1S4B()) ||
(vd.Is4S() && vn.Is16B() && vm.Is1S4B()));
int q = vd.Is4S() ? (1U << NEONQ_offset) : 0;
int index_num_bits = 2;
Emit(q | 0x0f80f000 | ImmNEONHLM(vm_index, index_num_bits) | Rm(vm) | Rn(vn) |
Rd(vd));
}

// clang-format off
#define NEON_BYELEMENT_LIST(V) \
Expand Down
15 changes: 15 additions & 0 deletions src/aarch64/assembler-aarch64.h
Original file line number Diff line number Diff line change
Expand Up @@ -3367,6 +3367,21 @@ class Assembler : public vixl::internal::AssemblerBase {
// Unsigned dot product [Armv8.2].
void udot(const VRegister& vd, const VRegister& vn, const VRegister& vm);

// Dot Product with unsigned and signed integers (vector).
void usdot(const VRegister& vd, const VRegister& vn, const VRegister& vm);

// Dot product with signed and unsigned integers (vector, by element).
void sudot(const VRegister& vd,
const VRegister& vn,
const VRegister& vm,
int vm_index);

// Dot product with unsigned and signed integers (vector, by element).
void usdot(const VRegister& vd,
const VRegister& vn,
const VRegister& vm,
int vm_index);

// Signed saturating rounding doubling multiply subtract returning high half
// [Armv8.1].
void sqrdmlsh(const VRegister& vd, const VRegister& vn, const VRegister& vm);
Expand Down
6 changes: 6 additions & 0 deletions src/aarch64/cpu-features-auditor-aarch64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1701,6 +1701,12 @@ void CPUFeaturesAuditor::Visit(Metadata* metadata, const Instruction* instr) {
CPUFeatures(CPUFeatures::kSVE, CPUFeatures::kSVEF64MM)},
{"ld1roh_z_p_br_contiguous",
CPUFeatures(CPUFeatures::kSVE, CPUFeatures::kSVEF64MM)},
{"usdot_asimdsame2_d",
CPUFeatures(CPUFeatures::kNEON, CPUFeatures::kI8MM)},
{"sudot_asimdelem_d",
CPUFeatures(CPUFeatures::kNEON, CPUFeatures::kI8MM)},
{"usdot_asimdelem_d",
CPUFeatures(CPUFeatures::kNEON, CPUFeatures::kI8MM)},
{"usdot_z_zzz_s",
CPUFeatures(CPUFeatures::kSVE, CPUFeatures::kSVEI8MM)},
{"usdot_z_zzzi_s",
Expand Down
3 changes: 0 additions & 3 deletions src/aarch64/decoder-visitor-map-aarch64.h
Original file line number Diff line number Diff line change
Expand Up @@ -2861,12 +2861,9 @@
{"subg_64_addsub_immtags", &VISITORCLASS::VisitUnimplemented}, \
{"subps_64s_dp_2src", &VISITORCLASS::VisitUnimplemented}, \
{"subp_64s_dp_2src", &VISITORCLASS::VisitUnimplemented}, \
{"sudot_asimdelem_d", &VISITORCLASS::VisitUnimplemented}, \
{"tcancel_ex_exception", &VISITORCLASS::VisitUnimplemented}, \
{"tstart_br_systemresult", &VISITORCLASS::VisitUnimplemented}, \
{"ttest_br_systemresult", &VISITORCLASS::VisitUnimplemented}, \
{"usdot_asimdelem_d", &VISITORCLASS::VisitUnimplemented}, \
{"usdot_asimdsame2_d", &VISITORCLASS::VisitUnimplemented}, \
{"wfet_only_systeminstrswithreg", &VISITORCLASS::VisitUnimplemented}, \
{"wfit_only_systeminstrswithreg", &VISITORCLASS::VisitUnimplemented}, \
{"xar_vvv2_crypto3_imm6", &VISITORCLASS::VisitUnimplemented}, \
Expand Down
68 changes: 25 additions & 43 deletions src/aarch64/disasm-aarch64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ Disassembler::FormToVisitorFnMap Disassembler::form_to_visitor_ = {
{"sqdmlsl_asimdelem_l", &Disassembler::DisassembleNEONMulByElementLong},
{"sdot_asimdelem_d", &Disassembler::DisassembleNEONDotProdByElement},
{"udot_asimdelem_d", &Disassembler::DisassembleNEONDotProdByElement},
{"usdot_asimdelem_d", &Disassembler::DisassembleNEONDotProdByElement},
{"sudot_asimdelem_d", &Disassembler::DisassembleNEONDotProdByElement},
{"fmlal2_asimdelem_lh", &Disassembler::DisassembleNEONFPMulByElementLong},
{"fmlal_asimdelem_lh", &Disassembler::DisassembleNEONFPMulByElementLong},
{"fmlsl2_asimdelem_lh", &Disassembler::DisassembleNEONFPMulByElementLong},
Expand Down Expand Up @@ -376,6 +378,7 @@ Disassembler::FormToVisitorFnMap Disassembler::form_to_visitor_ = {
&Disassembler::VisitSVELoadAndBroadcastQOWord_ScalarPlusScalar},
{"usdot_z_zzzi_s", &Disassembler::VisitSVEMulIndex},
{"sudot_z_zzzi_s", &Disassembler::VisitSVEMulIndex},
{"usdot_asimdsame2_d", &Disassembler::VisitNEON3SameExtra},
};

Disassembler::Disassembler() {
Expand Down Expand Up @@ -3309,40 +3312,32 @@ void Disassembler::VisitNEON3SameFP16(const Instruction *instr) {
void Disassembler::VisitNEON3SameExtra(const Instruction *instr) {
static const NEONFormatMap map_usdot = {{30}, {NF_8B, NF_16B}};

const char *mnemonic = "unallocated";
const char *form = "(NEON3SameExtra)";
const char *mnemonic = mnemonic_.c_str();
const char *form = "'Vd.%s, 'Vn.%s, 'Vm.%s";
const char *suffix = NULL;

NEONFormatDecoder nfd(instr);

if (instr->Mask(NEON3SameExtraFCMLAMask) == NEON_FCMLA) {
mnemonic = "fcmla";
form = "'Vd.%s, 'Vn.%s, 'Vm.%s, 'IVFCNM";
} else if (instr->Mask(NEON3SameExtraFCADDMask) == NEON_FCADD) {
mnemonic = "fcadd";
form = "'Vd.%s, 'Vn.%s, 'Vm.%s, 'IVFCNA";
} else {
form = "'Vd.%s, 'Vn.%s, 'Vm.%s";
switch (instr->Mask(NEON3SameExtraMask)) {
case NEON_SDOT:
mnemonic = "sdot";
nfd.SetFormatMap(1, &map_usdot);
nfd.SetFormatMap(2, &map_usdot);
break;
case NEON_SQRDMLAH:
mnemonic = "sqrdmlah";
break;
case NEON_UDOT:
mnemonic = "udot";
nfd.SetFormatMap(1, &map_usdot);
nfd.SetFormatMap(2, &map_usdot);
break;
case NEON_SQRDMLSH:
mnemonic = "sqrdmlsh";
break;
}
switch (form_hash_) {
case Hash("fcmla_asimdsame2_c"):
suffix = ", #'u1211*90";
break;
case Hash("fcadd_asimdsame2_c"):
// Bit 10 is always set, so this gives 90 * 1 or 3.
suffix = ", #'u1212:1010*90";
break;
case Hash("sdot_asimdsame2_d"):
case Hash("udot_asimdsame2_d"):
case Hash("usdot_asimdsame2_d"):
nfd.SetFormatMap(1, &map_usdot);
nfd.SetFormatMap(2, &map_usdot);
break;
default:
// sqrdml[as]h - nothing to do.
break;
}

Format(instr, mnemonic, nfd.Substitute(form));
Format(instr, mnemonic, nfd.Substitute(form), suffix);
}


Expand Down Expand Up @@ -3566,7 +3561,7 @@ void Disassembler::DisassembleNEONMulByElementLong(const Instruction *instr) {

void Disassembler::DisassembleNEONDotProdByElement(const Instruction *instr) {
const char *form = instr->ExtractBit(30) ? "'Vd.4s, 'Vn.16" : "'Vd.2s, 'Vn.8";
const char *suffix = "b, 'Ve.4b['IVByElemIndex]";
const char *suffix = "b, 'Vm.4b['u1111:2121]";
Format(instr, mnemonic_.c_str(), form, suffix);
}

Expand Down Expand Up @@ -10790,19 +10785,6 @@ int Disassembler::SubstituteImmediateField(const Instruction *instr,
}
case 'V': { // Immediate Vector.
switch (format[2]) {
case 'F': {
switch (format[5]) {
// Convert 'rot' bit encodings into equivalent angle rotation
case 'A':
AppendToOutput("#%" PRId32,
instr->GetImmRotFcadd() == 1 ? 270 : 90);
break;
case 'M':
AppendToOutput("#%" PRId32, instr->GetImmRotFcmlaVec() * 90);
break;
}
return strlen("IVFCN") + 1;
}
case 'E': { // IVExtract.
AppendToOutput("#%" PRId32, instr->GetImmNEONExt());
return 9;
Expand Down
34 changes: 0 additions & 34 deletions src/aarch64/logic-aarch64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -859,23 +859,6 @@ LogicVRegister Simulator::sqrdmulh(VectorFormat vform,
}


LogicVRegister Simulator::sdot(VectorFormat vform,
LogicVRegister dst,
const LogicVRegister& src1,
const LogicVRegister& src2,
int index) {
SimVRegister temp;
// NEON indexed `dot` allows the index value exceed the register size.
// Promote the format to Q-sized vector format before the duplication.
dup_elements_to_segments(IsSVEFormat(vform) ? vform
: VectorFormatFillQ(vform),
temp,
src2,
index);
return sdot(vform, dst, src1, temp);
}


LogicVRegister Simulator::sqrdmlah(VectorFormat vform,
LogicVRegister dst,
const LogicVRegister& src1,
Expand All @@ -887,23 +870,6 @@ LogicVRegister Simulator::sqrdmlah(VectorFormat vform,
}


LogicVRegister Simulator::udot(VectorFormat vform,
LogicVRegister dst,
const LogicVRegister& src1,
const LogicVRegister& src2,
int index) {
SimVRegister temp;
// NEON indexed `dot` allows the index value exceed the register size.
// Promote the format to Q-sized vector format before the duplication.
dup_elements_to_segments(IsSVEFormat(vform) ? vform
: VectorFormatFillQ(vform),
temp,
src2,
index);
return udot(vform, dst, src1, temp);
}


LogicVRegister Simulator::sqrdmlsh(VectorFormat vform,
LogicVRegister dst,
const LogicVRegister& src1,
Expand Down
8 changes: 6 additions & 2 deletions src/aarch64/macro-assembler-aarch64.h
Original file line number Diff line number Diff line change
Expand Up @@ -2819,7 +2819,8 @@ class MacroAssembler : public Assembler, public MacroAssemblerInterface {
V(zip2, Zip2) \
V(smmla, Smmla) \
V(ummla, Ummla) \
V(usmmla, Usmmla)
V(usmmla, Usmmla) \
V(usdot, Usdot)

#define DEFINE_MACRO_ASM_FUNC(ASM, MASM) \
void MASM(const VRegister& vd, const VRegister& vn, const VRegister& vm) { \
Expand Down Expand Up @@ -2971,7 +2972,10 @@ class MacroAssembler : public Assembler, public MacroAssemblerInterface {
V(umlal, Umlal) \
V(umlal2, Umlal2) \
V(umlsl, Umlsl) \
V(umlsl2, Umlsl2)
V(umlsl2, Umlsl2) \
V(sudot, Sudot) \
V(usdot, Usdot)


#define DEFINE_MACRO_ASM_FUNC(ASM, MASM) \
void MASM(const VRegister& vd, \
Expand Down
Loading

0 comments on commit 09925d4

Please sign in to comment.