diff --git a/internal/engine/wazevo/backend/backend_test.go b/internal/engine/wazevo/backend/backend_test.go index 80e18ac8d1..7153e24256 100644 --- a/internal/engine/wazevo/backend/backend_test.go +++ b/internal/engine/wazevo/backend/backend_test.go @@ -54,11 +54,11 @@ func TestE2E(t *testing.T) { { name: "empty", m: testcases.Empty.Module, afterLoweringARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str xzr, [sp, #-0x10]! add sp, sp, #0x10 @@ -69,7 +69,7 @@ L1 (SSA Block: blk0): { name: "selects", m: testcases.Selects.Module, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str xzr, [sp, #-0x10]! subs xzr, x4, x5 @@ -89,7 +89,7 @@ L1 (SSA Block: blk0): { name: "consts", m: testcases.Constants.Module, afterLoweringARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): ldr d133?, #8; b 16; data.f64 64.000000 mov v1.8b, v133?.8b ldr s132?, #8; b 8; data.f32 32.000000 @@ -101,7 +101,7 @@ L1 (SSA Block: blk0): ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str xzr, [sp, #-0x10]! ldr d1, #8; b 16; data.f64 64.000000 @@ -116,7 +116,7 @@ L1 (SSA Block: blk0): { name: "add sub params return", m: testcases.AddSubParamsReturn.Module, afterLoweringARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): mov x130?, x2 mov x131?, x3 add w132?, w130?, w131? @@ -125,7 +125,7 @@ L1 (SSA Block: blk0): ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str xzr, [sp, #-0x10]! add w8, w2, w3 @@ -138,7 +138,7 @@ L1 (SSA Block: blk0): { name: "locals params", m: testcases.LocalsParams.Module, afterLoweringARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): mov x130?, x2 mov v131?.8b, v0.8b mov v132?.8b, v1.8b @@ -162,7 +162,7 @@ L1 (SSA Block: blk0): ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str xzr, [sp, #-0x10]! add x8, x2, x2 @@ -187,7 +187,7 @@ L1 (SSA Block: blk0): { name: "local_param_return", m: testcases.LocalParamReturn.Module, afterLoweringARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): mov x130?, x2 mov x131?, xzr mov x1, x131? @@ -195,7 +195,7 @@ L1 (SSA Block: blk0): ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str xzr, [sp, #-0x10]! mov x1, xzr @@ -208,7 +208,7 @@ L1 (SSA Block: blk0): { name: "swap_param_and_return", m: testcases.SwapParamAndReturn.Module, afterLoweringARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): mov x130?, x2 mov x131?, x3 mov x1, x130? @@ -216,7 +216,7 @@ L1 (SSA Block: blk0): ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str xzr, [sp, #-0x10]! mov x1, x2 @@ -229,19 +229,19 @@ L1 (SSA Block: blk0): { name: "swap_params_and_return", m: testcases.SwapParamsAndReturn.Module, afterLoweringARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): mov x130?, x2 mov x131?, x3 -L2 (SSA Block: blk1): +L1 (SSA Block: blk1): mov x1, x130? mov x0, x131? ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str xzr, [sp, #-0x10]! 
-L2 (SSA Block: blk1): +L1 (SSA Block: blk1): mov x1, x2 mov x0, x3 add sp, sp, #0x10 @@ -252,15 +252,15 @@ L2 (SSA Block: blk1): { name: "block_br", m: testcases.BlockBr.Module, afterLoweringARM64: ` -L1 (SSA Block: blk0): -L2 (SSA Block: blk1): +L0 (SSA Block: blk0): +L1 (SSA Block: blk1): ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str xzr, [sp, #-0x10]! -L2 (SSA Block: blk1): +L1 (SSA Block: blk1): add sp, sp, #0x10 ldr x30, [sp], #0x10 ret @@ -269,11 +269,11 @@ L2 (SSA Block: blk1): { name: "block_br_if", m: testcases.BlockBrIf.Module, afterLoweringARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): mov x128?, x0 mov x131?, xzr - cbnz w131?, L2 -L3 (SSA Block: blk2): + cbnz w131?, L1 +L2 (SSA Block: blk2): movz x132?, #0x3, lsl 0 str w132?, [x128?] mov x133?, sp @@ -281,16 +281,16 @@ L3 (SSA Block: blk2): adr x134?, #0x0 str x134?, [x128?, #0x30] exit_sequence x128? -L2 (SSA Block: blk1): +L1 (SSA Block: blk1): ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str xzr, [sp, #-0x10]! mov x8, xzr - cbnz w8, #0x34 (L2) -L3 (SSA Block: blk2): + cbnz w8, #0x34 (L1) +L2 (SSA Block: blk2): movz x8, #0x3, lsl 0 str w8, [x0] mov x8, sp @@ -298,7 +298,7 @@ L3 (SSA Block: blk2): adr x8, #0x0 str x8, [x0, #0x30] exit_sequence x0 -L2 (SSA Block: blk1): +L1 (SSA Block: blk1): add sp, sp, #0x10 ldr x30, [sp], #0x10 ret @@ -307,44 +307,44 @@ L2 (SSA Block: blk1): { name: "loop_br", m: testcases.LoopBr.Module, afterLoweringARM64: ` -L1 (SSA Block: blk0): -L2 (SSA Block: blk1): - b L2 +L0 (SSA Block: blk0): +L1 (SSA Block: blk1): + b L1 `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str xzr, [sp, #-0x10]! -L2 (SSA Block: blk1): - b #0x0 (L2) +L1 (SSA Block: blk1): + b #0x0 (L1) `, }, { name: "loop_with_param_results", m: testcases.LoopBrWithParamResults.Module, afterLoweringARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): mov x130?, x2 -L2 (SSA Block: blk1): +L1 (SSA Block: blk1): orr w133?, wzr, #0x1 cbz w133?, (L3) L4 (SSA Block: blk4): - b L2 + b L1 L3 (SSA Block: blk3): -L5 (SSA Block: blk2): +L2 (SSA Block: blk2): mov x0, x130? ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str xzr, [sp, #-0x10]! -L2 (SSA Block: blk1): +L1 (SSA Block: blk1): orr w8, wzr, #0x1 cbz w8, #0x8 L3 L4 (SSA Block: blk4): - b #-0x8 (L2) + b #-0x8 (L1) L3 (SSA Block: blk3): -L5 (SSA Block: blk2): +L2 (SSA Block: blk2): mov x0, x2 add sp, sp, #0x10 ldr x30, [sp], #0x10 @@ -354,24 +354,24 @@ L5 (SSA Block: blk2): { name: "loop_br_if", m: testcases.LoopBrIf.Module, afterLoweringARM64: ` -L1 (SSA Block: blk0): -L2 (SSA Block: blk1): +L0 (SSA Block: blk0): +L1 (SSA Block: blk1): orr w131?, wzr, #0x1 cbz w131?, (L3) L4 (SSA Block: blk4): - b L2 + b L1 L3 (SSA Block: blk3): ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str xzr, [sp, #-0x10]! -L2 (SSA Block: blk1): +L1 (SSA Block: blk1): orr w8, wzr, #0x1 cbz w8, #0x8 L3 L4 (SSA Block: blk4): - b #-0x8 (L2) + b #-0x8 (L1) L3 (SSA Block: blk3): add sp, sp, #0x10 ldr x30, [sp], #0x10 @@ -381,15 +381,15 @@ L3 (SSA Block: blk3): { name: "block_block_br", m: testcases.BlockBlockBr.Module, afterLoweringARM64: ` -L1 (SSA Block: blk0): -L2 (SSA Block: blk1): +L0 (SSA Block: blk0): +L1 (SSA Block: blk1): ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! 
str xzr, [sp, #-0x10]! -L2 (SSA Block: blk1): +L1 (SSA Block: blk1): add sp, sp, #0x10 ldr x30, [sp], #0x10 ret @@ -401,25 +401,25 @@ L2 (SSA Block: blk1): // So we cannot have the general optimization on this kind of redundant branch elimination before register allocations. // Instead, we can do it during the code generation phase where we actually resolve the label offsets. afterLoweringARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): mov x131?, xzr cbz w131?, (L2) -L3 (SSA Block: blk1): - b L4 +L1 (SSA Block: blk1): + b L3 L2 (SSA Block: blk2): -L4 (SSA Block: blk3): +L3 (SSA Block: blk3): ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str xzr, [sp, #-0x10]! mov x8, xzr cbz w8, #0x8 L2 -L3 (SSA Block: blk1): - b #0x4 (L4) +L1 (SSA Block: blk1): + b #0x4 (L3) L2 (SSA Block: blk2): -L4 (SSA Block: blk3): +L3 (SSA Block: blk3): add sp, sp, #0x10 ldr x30, [sp], #0x10 ret @@ -428,23 +428,23 @@ L4 (SSA Block: blk3): { name: "if_else", m: testcases.IfElse.Module, afterLoweringARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): mov x131?, xzr cbz w131?, (L2) -L3 (SSA Block: blk1): -L4 (SSA Block: blk3): +L1 (SSA Block: blk1): +L3 (SSA Block: blk3): ret L2 (SSA Block: blk2): ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str xzr, [sp, #-0x10]! mov x8, xzr cbz w8, #0x10 L2 -L3 (SSA Block: blk1): -L4 (SSA Block: blk3): +L1 (SSA Block: blk1): +L3 (SSA Block: blk3): add sp, sp, #0x10 ldr x30, [sp], #0x10 ret @@ -457,32 +457,32 @@ L2 (SSA Block: blk2): { name: "single_predecessor_local_refs", m: testcases.SinglePredecessorLocalRefs.Module, afterLoweringARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): mov x131?, xzr cbz w131?, (L2) -L3 (SSA Block: blk1): +L1 (SSA Block: blk1): mov x130?, xzr mov x0, x130? ret L2 (SSA Block: blk2): -L4 (SSA Block: blk3): +L3 (SSA Block: blk3): mov x130?, xzr mov x0, x130? ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str xzr, [sp, #-0x10]! mov x8, xzr cbz w8, #0x14 L2 -L3 (SSA Block: blk1): +L1 (SSA Block: blk1): mov x0, xzr add sp, sp, #0x10 ldr x30, [sp], #0x10 ret L2 (SSA Block: blk2): -L4 (SSA Block: blk3): +L3 (SSA Block: blk3): mov x0, xzr add sp, sp, #0x10 ldr x30, [sp], #0x10 @@ -492,29 +492,29 @@ L4 (SSA Block: blk3): { name: "multi_predecessor_local_ref", m: testcases.MultiPredecessorLocalRef.Module, afterLoweringARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): mov x130?, x2 mov x131?, x3 cbz w130?, (L2) -L3 (SSA Block: blk1): +L1 (SSA Block: blk1): mov x132?, x130? - b L4 + b L3 L2 (SSA Block: blk2): mov x132?, x131? -L4 (SSA Block: blk3): +L3 (SSA Block: blk3): mov x0, x132? ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str xzr, [sp, #-0x10]! cbz w2, #0x8 L2 -L3 (SSA Block: blk1): - b #0x8 (L4) +L1 (SSA Block: blk1): + b #0x8 (L3) L2 (SSA Block: blk2): mov x2, x3 -L4 (SSA Block: blk3): +L3 (SSA Block: blk3): mov x0, x2 add sp, sp, #0x10 ldr x30, [sp], #0x10 @@ -524,17 +524,17 @@ L4 (SSA Block: blk3): { name: "reference_value_from_unsealed_block", m: testcases.ReferenceValueFromUnsealedBlock.Module, afterLoweringARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): mov x130?, x2 -L2 (SSA Block: blk1): +L1 (SSA Block: blk1): mov x0, x130? ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str xzr, [sp, #-0x10]! 
-L2 (SSA Block: blk1): +L1 (SSA Block: blk1): mov x0, x2 add sp, sp, #0x10 ldr x30, [sp], #0x10 @@ -544,30 +544,30 @@ L2 (SSA Block: blk1): { name: "reference_value_from_unsealed_block2", m: testcases.ReferenceValueFromUnsealedBlock2.Module, afterLoweringARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): mov x130?, x2 -L2 (SSA Block: blk1): - cbz w130?, (L3) -L4 (SSA Block: blk5): - b L2 -L3 (SSA Block: blk4): -L5 (SSA Block: blk3): -L6 (SSA Block: blk2): +L1 (SSA Block: blk1): + cbz w130?, (L4) +L5 (SSA Block: blk5): + b L1 +L4 (SSA Block: blk4): +L3 (SSA Block: blk3): +L2 (SSA Block: blk2): mov x131?, xzr mov x0, x131? ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str xzr, [sp, #-0x10]! -L2 (SSA Block: blk1): - cbz w2, #0x8 L3 -L4 (SSA Block: blk5): - b #-0x4 (L2) -L3 (SSA Block: blk4): -L5 (SSA Block: blk3): -L6 (SSA Block: blk2): +L1 (SSA Block: blk1): + cbz w2, #0x8 L4 +L5 (SSA Block: blk5): + b #-0x4 (L1) +L4 (SSA Block: blk4): +L3 (SSA Block: blk3): +L2 (SSA Block: blk2): mov x0, xzr add sp, sp, #0x10 ldr x30, [sp], #0x10 @@ -578,41 +578,41 @@ L6 (SSA Block: blk2): name: "reference_value_from_unsealed_block3", m: testcases.ReferenceValueFromUnsealedBlock3.Module, // TODO: we should be able to invert cbnz in so that L2 can end with fallthrough. investigate builder.LayoutBlocks function. afterLoweringARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): mov x130?, x2 mov x131?, x130? -L2 (SSA Block: blk1): - cbnz w131?, L4 - b L3 -L4 (SSA Block: blk5): +L1 (SSA Block: blk1): + cbnz w131?, L5 + b L4 +L5 (SSA Block: blk5): ret -L3 (SSA Block: blk4): -L5 (SSA Block: blk3): +L4 (SSA Block: blk4): +L3 (SSA Block: blk3): load_const_block_arg x131?, 0x1 - b L2 + b L1 `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str xzr, [sp, #-0x10]! -L2 (SSA Block: blk1): - cbnz w2, #0x8 (L4) - b #0x10 (L3) -L4 (SSA Block: blk5): +L1 (SSA Block: blk1): + cbnz w2, #0x8 (L5) + b #0x10 (L4) +L5 (SSA Block: blk5): add sp, sp, #0x10 ldr x30, [sp], #0x10 ret -L3 (SSA Block: blk4): -L5 (SSA Block: blk3): +L4 (SSA Block: blk4): +L3 (SSA Block: blk3): load_const_block_arg x2, 0x1 orr w2, wzr, #0x1 - b #-0x18 (L2) + b #-0x18 (L1) `, }, { name: "call", m: testcases.Call.Module, afterLoweringARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): mov x128?, x0 mov x129?, x1 mov x0, x128? @@ -637,7 +637,7 @@ L1 (SSA Block: blk0): ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! sub sp, sp, #0x10 orr x27, xzr, #0x10 @@ -667,7 +667,7 @@ L1 (SSA Block: blk0): { name: "call_many_params", m: testcases.CallManyParams.Module, afterLoweringARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): mov x128?, x0 mov x129?, x1 mov x130?, x2 @@ -720,7 +720,7 @@ L1 (SSA Block: blk0): ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! sub sp, sp, #0x20 orr x27, xzr, #0x20 @@ -779,7 +779,7 @@ L1 (SSA Block: blk0): { name: "call_many_returns", m: testcases.CallManyReturns.Module, afterLoweringARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): mov x128?, x0 mov x129?, x1 mov x130?, x2 @@ -876,7 +876,7 @@ L1 (SSA Block: blk0): ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): orr x27, xzr, #0xc0 sub sp, sp, x27 stp x30, x27, [sp, #-0x10]! 
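// Why every expected label above shifts down by one (L1 -> L0, L2 -> L1, ...): with this
// change a basic block's label is simply its ssa.BasicBlockID, so the entry block blk0 now
// prints as L0, and freshly allocated labels (branch targets, constant islands) start at
// maxSSABlockID+1. A minimal, self-contained sketch of that mapping (the tiny block type
// and the sample maxSSABlockID below are illustrative, not part of the patch):
package main

import (
	"fmt"
	"math"
)

type label uint32

const labelReturn label = math.MaxUint32

type block struct {
	id       uint32
	isReturn bool
}

// ssaBlockLabel mirrors the new scheme: a block's ID doubles as its label.
func ssaBlockLabel(b block) label {
	if b.isReturn {
		return labelReturn
	}
	return label(b.id)
}

func main() {
	const maxSSABlockID = 3
	nextLabel := label(maxSSABlockID) + 1            // first free label comes after the block IDs
	fmt.Printf("L%d\n", ssaBlockLabel(block{id: 0})) // entry block blk0 -> "L0"
	fmt.Printf("L%d\n", nextLabel)                   // first non-block label -> "L4"
}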
@@ -949,7 +949,7 @@ L1 (SSA Block: blk0): name: "integer_extensions", m: testcases.IntegerExtensions.Module, afterLoweringARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): mov x130?, x2 mov x131?, x3 sxtw x132?, w130? @@ -969,7 +969,7 @@ L1 (SSA Block: blk0): ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str xzr, [sp, #-0x10]! sxtw x0, w2 @@ -989,7 +989,7 @@ L1 (SSA Block: blk0): { name: "integer bit counts", m: testcases.IntegerBitCounts.Module, afterLoweringARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): mov x130?, x2 mov x131?, x3 clz w132?, w130? @@ -1015,7 +1015,7 @@ L1 (SSA Block: blk0): ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str xzr, [sp, #-0x10]! clz w0, w2 @@ -1042,7 +1042,7 @@ L1 (SSA Block: blk0): name: "float_comparisons", m: testcases.FloatComparisons.Module, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): orr x27, xzr, #0x20 sub sp, sp, x27 stp x30, x27, [sp, #-0x10]! @@ -1085,7 +1085,7 @@ L1 (SSA Block: blk0): name: "float_conversions", m: testcases.FloatConversions.Module, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str xzr, [sp, #-0x10]! msr fpsr, xzr @@ -1307,7 +1307,7 @@ L3: name: "nontrapping_float_conversions", m: testcases.NonTrappingFloatConversions.Module, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str xzr, [sp, #-0x10]! fcvtzs x8, d0 @@ -1328,7 +1328,7 @@ L1 (SSA Block: blk0): name: "many_middle_values", m: testcases.ManyMiddleValues.Module, afterLoweringARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): mov x130?, x2 mov v131?.8b, v0.8b orr w289?, wzr, #0x1 @@ -1454,7 +1454,7 @@ L1 (SSA Block: blk0): ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str x19, [sp, #-0x10]! str x20, [sp, #-0x10]! @@ -1620,7 +1620,7 @@ L1 (SSA Block: blk0): { name: "imported_function_call", m: testcases.ImportedFunctionCall.Module, afterLoweringARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): mov x128?, x0 mov x129?, x1 mov x130?, x2 @@ -1637,7 +1637,7 @@ L1 (SSA Block: blk0): ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! sub sp, sp, #0x10 orr x27, xzr, #0x10 @@ -1658,7 +1658,7 @@ L1 (SSA Block: blk0): { name: "memory_load_basic", m: testcases.MemoryLoadBasic.Module, afterLoweringARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): mov x128?, x0 mov x129?, x1 mov x130?, x2 @@ -1683,7 +1683,7 @@ L2: ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str xzr, [sp, #-0x10]! uxtw x8, w2 @@ -1710,7 +1710,7 @@ L2: { name: "memory_stores", m: testcases.MemoryStores.Module, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str xzr, [sp, #-0x10]! mov x8, xzr @@ -1867,7 +1867,7 @@ L2: name: "globals_mutable[0]", m: testcases.GlobalsMutable.Module, afterLoweringARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): mov x128?, x0 mov x129?, x1 ldr w130?, [x129?, #0x10] @@ -1892,7 +1892,7 @@ L1 (SSA Block: blk0): ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! 
sub sp, sp, #0x20 orr x27, xzr, #0x20 @@ -1927,7 +1927,7 @@ L1 (SSA Block: blk0): m: testcases.GlobalsMutable.Module, targetIndex: 1, afterLoweringARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): mov x129?, x1 orr w137?, wzr, #0x1 str w137?, [x129?, #0x10] @@ -1940,7 +1940,7 @@ L1 (SSA Block: blk0): ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str xzr, [sp, #-0x10]! orr w8, wzr, #0x1 @@ -1961,87 +1961,87 @@ L1 (SSA Block: blk0): m: testcases.BrTable.Module, targetIndex: 0, afterLoweringARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): mov x130?, x2 orr w137?, wzr, #0x6 subs wzr, w130?, w137? csel w138?, w137?, w130?, hs br_table_sequence x138?, table_index=0 -L2 (SSA Block: blk7): - b L9 -L3 (SSA Block: blk8): -L10 (SSA Block: blk5): +L7 (SSA Block: blk7): + b L6 +L8 (SSA Block: blk8): +L5 (SSA Block: blk5): orr w131?, wzr, #0xc mov x0, x131? ret -L4 (SSA Block: blk9): -L11 (SSA Block: blk4): +L9 (SSA Block: blk9): +L4 (SSA Block: blk4): movz w132?, #0xd, lsl 0 mov x0, x132? ret -L5 (SSA Block: blk10): -L12 (SSA Block: blk3): +L10 (SSA Block: blk10): +L3 (SSA Block: blk3): orr w133?, wzr, #0xe mov x0, x133? ret -L6 (SSA Block: blk11): -L13 (SSA Block: blk2): +L11 (SSA Block: blk11): +L2 (SSA Block: blk2): orr w134?, wzr, #0xf mov x0, x134? ret -L7 (SSA Block: blk12): -L14 (SSA Block: blk1): +L12 (SSA Block: blk12): +L1 (SSA Block: blk1): orr w135?, wzr, #0x10 mov x0, x135? ret -L8 (SSA Block: blk13): -L9 (SSA Block: blk6): +L13 (SSA Block: blk13): +L6 (SSA Block: blk6): movz w136?, #0xb, lsl 0 mov x0, x136? ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str xzr, [sp, #-0x10]! orr w8, wzr, #0x6 subs wzr, w2, w8 csel w8, w8, w2, hs br_table_sequence x8, table_index=0 -L2 (SSA Block: blk7): - b #0x54 (L9) -L3 (SSA Block: blk8): -L10 (SSA Block: blk5): +L7 (SSA Block: blk7): + b #0x54 (L6) +L8 (SSA Block: blk8): +L5 (SSA Block: blk5): orr w0, wzr, #0xc add sp, sp, #0x10 ldr x30, [sp], #0x10 ret -L4 (SSA Block: blk9): -L11 (SSA Block: blk4): +L9 (SSA Block: blk9): +L4 (SSA Block: blk4): movz w0, #0xd, lsl 0 add sp, sp, #0x10 ldr x30, [sp], #0x10 ret -L5 (SSA Block: blk10): -L12 (SSA Block: blk3): +L10 (SSA Block: blk10): +L3 (SSA Block: blk3): orr w0, wzr, #0xe add sp, sp, #0x10 ldr x30, [sp], #0x10 ret -L6 (SSA Block: blk11): -L13 (SSA Block: blk2): +L11 (SSA Block: blk11): +L2 (SSA Block: blk2): orr w0, wzr, #0xf add sp, sp, #0x10 ldr x30, [sp], #0x10 ret -L7 (SSA Block: blk12): -L14 (SSA Block: blk1): +L12 (SSA Block: blk12): +L1 (SSA Block: blk1): orr w0, wzr, #0x10 add sp, sp, #0x10 ldr x30, [sp], #0x10 ret -L8 (SSA Block: blk13): -L9 (SSA Block: blk6): +L13 (SSA Block: blk13): +L6 (SSA Block: blk6): movz w0, #0xb, lsl 0 add sp, sp, #0x10 ldr x30, [sp], #0x10 @@ -2052,7 +2052,7 @@ L9 (SSA Block: blk6): name: "VecShuffle", m: testcases.VecShuffle.Module, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str q29, [sp, #-0x10]! str q30, [sp, #-0x10]! @@ -2073,7 +2073,7 @@ L1 (SSA Block: blk0): name: "AtomicRmwAdd", m: testcases.AtomicRmwAdd.Module, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): orr x27, xzr, #0x10 sub sp, sp, x27 stp x30, x27, [sp, #-0x10]! 
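// The relative offsets printed in the expectations above (for example "b #-0x8 (L1)" or
// "cbz w8, #0x8 L3") are PC-relative distances that can only be filled in once every
// label's binary offset is resolved. The amd64 side of this patch does that resolution in
// Encode further down as "targetOffset - (p.imm32Offset + 4)". A rough, self-contained
// sketch of that arithmetic, with an illustrative helper name and made-up offsets:
package main

import "fmt"

// relDisp32 returns the 32-bit displacement stored in an amd64 branch: the target label's
// resolved offset minus the end of the 4-byte immediate, since RIP points past it.
func relDisp32(targetOffset, imm32Offset int64) int32 {
	return int32(targetOffset - (imm32Offset + 4))
}

func main() {
	// A jump whose 4-byte immediate begins at offset 0x10, targeting a label resolved at
	// offset 0x34, encodes 0x34 - (0x10 + 4) = 0x20.
	fmt.Printf("%#x\n", relDisp32(0x34, 0x10))
}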
@@ -2272,12 +2272,12 @@ L2: name: "icmp_and_zero", m: testcases.IcmpAndZero.Module, afterFinalizeAMD64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): pushq %rbp movq %rsp, %rbp testl %edi, %ecx jnz L2 -L3 (SSA Block: blk1): +L1 (SSA Block: blk1): movl $1, %eax movq %rbp, %rsp popq %rbp @@ -2289,12 +2289,12 @@ L2 (SSA Block: blk2): ret `, afterFinalizeARM64: ` -L1 (SSA Block: blk0): +L0 (SSA Block: blk0): stp x30, xzr, [sp, #-0x10]! str xzr, [sp, #-0x10]! ands wzr, w2, w3 b.ne #0x14, (L2) -L3 (SSA Block: blk1): +L1 (SSA Block: blk1): orr w0, wzr, #0x1 add sp, sp, #0x10 ldr x30, [sp], #0x10 diff --git a/internal/engine/wazevo/backend/compiler_lower.go b/internal/engine/wazevo/backend/compiler_lower.go index a58e71dd9d..9a9414aeaa 100644 --- a/internal/engine/wazevo/backend/compiler_lower.go +++ b/internal/engine/wazevo/backend/compiler_lower.go @@ -9,7 +9,7 @@ import ( func (c *compiler) Lower() { c.assignVirtualRegisters() c.mach.SetCurrentABI(c.GetFunctionABI(c.ssaBuilder.Signature())) - c.mach.ExecutableContext().StartLoweringFunction(c.ssaBuilder.BlockIDMax()) + c.mach.StartLoweringFunction(c.ssaBuilder.BlockIDMax()) c.lowerBlocks() } @@ -20,12 +20,11 @@ func (c *compiler) lowerBlocks() { c.lowerBlock(blk) } - ectx := c.mach.ExecutableContext() // After lowering all blocks, we need to link adjacent blocks to layout one single instruction list. var prev ssa.BasicBlock for next := builder.BlockIteratorReversePostOrderBegin(); next != nil; next = builder.BlockIteratorReversePostOrderNext() { if prev != nil { - ectx.LinkAdjacentBlocks(prev, next) + c.mach.LinkAdjacentBlocks(prev, next) } prev = next } @@ -33,8 +32,7 @@ func (c *compiler) lowerBlocks() { func (c *compiler) lowerBlock(blk ssa.BasicBlock) { mach := c.mach - ectx := mach.ExecutableContext() - ectx.StartBlock(blk) + mach.StartBlock(blk) // We traverse the instructions in reverse order because we might want to lower multiple // instructions together. @@ -76,7 +74,7 @@ func (c *compiler) lowerBlock(blk ssa.BasicBlock) { default: mach.LowerInstr(cur) } - ectx.FlushPendingInstructions() + mach.FlushPendingInstructions() } // Finally, if this is the entry block, we have to insert copies of arguments from the real location to the VReg. @@ -84,7 +82,7 @@ func (c *compiler) lowerBlock(blk ssa.BasicBlock) { c.lowerFunctionArguments(blk) } - ectx.EndBlock() + mach.EndBlock() } // lowerBranches is called right after StartBlock and before any LowerInstr call if @@ -93,15 +91,15 @@ func (c *compiler) lowerBlock(blk ssa.BasicBlock) { // // See ssa.Instruction IsBranching, and the comment on ssa.BasicBlock. 
func (c *compiler) lowerBranches(br0, br1 *ssa.Instruction) { - ectx := c.mach.ExecutableContext() + mach := c.mach c.setCurrentGroupID(br0.GroupID()) c.mach.LowerSingleBranch(br0) - ectx.FlushPendingInstructions() + mach.FlushPendingInstructions() if br1 != nil { c.setCurrentGroupID(br1.GroupID()) c.mach.LowerConditionalBranch(br1) - ectx.FlushPendingInstructions() + mach.FlushPendingInstructions() } if br0.Opcode() == ssa.OpcodeJump { @@ -119,11 +117,11 @@ func (c *compiler) lowerBranches(br0, br1 *ssa.Instruction) { c.lowerBlockArguments(args, target) } } - ectx.FlushPendingInstructions() + mach.FlushPendingInstructions() } func (c *compiler) lowerFunctionArguments(entry ssa.BasicBlock) { - ectx := c.mach.ExecutableContext() + mach := c.mach c.tmpVals = c.tmpVals[:0] for i := 0; i < entry.Params(); i++ { @@ -135,8 +133,8 @@ func (c *compiler) lowerFunctionArguments(entry ssa.BasicBlock) { c.tmpVals = append(c.tmpVals, ssa.ValueInvalid) } } - c.mach.LowerParams(c.tmpVals) - ectx.FlushPendingInstructions() + mach.LowerParams(c.tmpVals) + mach.FlushPendingInstructions() } // lowerBlockArguments lowers how to pass arguments to the given successor block. diff --git a/internal/engine/wazevo/backend/executable_context.go b/internal/engine/wazevo/backend/executable_context.go deleted file mode 100644 index 8e9571b202..0000000000 --- a/internal/engine/wazevo/backend/executable_context.go +++ /dev/null @@ -1,221 +0,0 @@ -package backend - -import ( - "fmt" - "math" - - "github.com/tetratelabs/wazero/internal/engine/wazevo/ssa" - "github.com/tetratelabs/wazero/internal/engine/wazevo/wazevoapi" -) - -type ExecutableContext interface { - // StartLoweringFunction is called when the lowering of the given function is started. - // maximumBlockID is the maximum value of ssa.BasicBlockID existing in the function. - StartLoweringFunction(maximumBlockID ssa.BasicBlockID) - - // LinkAdjacentBlocks is called after finished lowering all blocks in order to create one single instruction list. - LinkAdjacentBlocks(prev, next ssa.BasicBlock) - - // StartBlock is called when the compilation of the given block is started. - // The order of this being called is the reverse post order of the ssa.BasicBlock(s) as we iterate with - // ssa.Builder BlockIteratorReversePostOrderBegin and BlockIteratorReversePostOrderEnd. - StartBlock(ssa.BasicBlock) - - // EndBlock is called when the compilation of the current block is finished. - EndBlock() - - // FlushPendingInstructions flushes the pending instructions to the buffer. - // This will be called after the lowering of each SSA Instruction. - FlushPendingInstructions() -} - -type ExecutableContextT[Instr any] struct { - CurrentSSABlk ssa.BasicBlock - - // InstrPool is the InstructionPool of instructions. - InstructionPool wazevoapi.Pool[Instr] - asNop func(*Instr) - setNext func(*Instr, *Instr) - setPrev func(*Instr, *Instr) - - // RootInstr is the root instruction of the executable. - RootInstr *Instr - labelPositionPool wazevoapi.Pool[LabelPosition[Instr]] - NextLabel Label - // LabelPositions maps a label to the instructions of the region which the label represents. - LabelPositions []*LabelPosition[Instr] - OrderedBlockLabels []*LabelPosition[Instr] - - // PerBlockHead and PerBlockEnd are the head and tail of the instruction list per currently-compiled ssa.BasicBlock. - PerBlockHead, PerBlockEnd *Instr - // PendingInstructions are the instructions which are not yet emitted into the instruction list. 
- PendingInstructions []*Instr - - // SsaBlockIDToLabels maps an SSA block ID to the label. - SsaBlockIDToLabels []Label -} - -func NewExecutableContextT[Instr any]( - resetInstruction func(*Instr), - setNext func(*Instr, *Instr), - setPrev func(*Instr, *Instr), - asNop func(*Instr), -) *ExecutableContextT[Instr] { - return &ExecutableContextT[Instr]{ - InstructionPool: wazevoapi.NewPool[Instr](resetInstruction), - asNop: asNop, - setNext: setNext, - setPrev: setPrev, - labelPositionPool: wazevoapi.NewPool[LabelPosition[Instr]](resetLabelPosition[Instr]), - NextLabel: LabelInvalid, - } -} - -func resetLabelPosition[T any](l *LabelPosition[T]) { - *l = LabelPosition[T]{} -} - -// StartLoweringFunction implements ExecutableContext. -func (e *ExecutableContextT[Instr]) StartLoweringFunction(max ssa.BasicBlockID) { - imax := int(max) - if len(e.SsaBlockIDToLabels) <= imax { - // Eagerly allocate labels for the blocks since the underlying slice will be used for the next iteration. - e.SsaBlockIDToLabels = append(e.SsaBlockIDToLabels, make([]Label, imax+1)...) - } -} - -func (e *ExecutableContextT[Instr]) StartBlock(blk ssa.BasicBlock) { - e.CurrentSSABlk = blk - - l := e.SsaBlockIDToLabels[e.CurrentSSABlk.ID()] - if l == LabelInvalid { - l = e.AllocateLabel() - e.SsaBlockIDToLabels[blk.ID()] = l - } - - end := e.allocateNop0() - e.PerBlockHead, e.PerBlockEnd = end, end - - labelPos := e.GetOrAllocateLabelPosition(l) - e.OrderedBlockLabels = append(e.OrderedBlockLabels, labelPos) - labelPos.Begin, labelPos.End = end, end - labelPos.SB = blk -} - -// EndBlock implements ExecutableContext. -func (e *ExecutableContextT[T]) EndBlock() { - // Insert nop0 as the head of the block for convenience to simplify the logic of inserting instructions. - e.insertAtPerBlockHead(e.allocateNop0()) - - l := e.SsaBlockIDToLabels[e.CurrentSSABlk.ID()] - e.LabelPositions[l].Begin = e.PerBlockHead - - if e.CurrentSSABlk.EntryBlock() { - e.RootInstr = e.PerBlockHead - } -} - -func (e *ExecutableContextT[T]) insertAtPerBlockHead(i *T) { - if e.PerBlockHead == nil { - e.PerBlockHead = i - e.PerBlockEnd = i - return - } - e.setNext(i, e.PerBlockHead) - e.setPrev(e.PerBlockHead, i) - e.PerBlockHead = i -} - -// FlushPendingInstructions implements ExecutableContext. -func (e *ExecutableContextT[T]) FlushPendingInstructions() { - l := len(e.PendingInstructions) - if l == 0 { - return - } - for i := l - 1; i >= 0; i-- { // reverse because we lower instructions in reverse order. - e.insertAtPerBlockHead(e.PendingInstructions[i]) - } - e.PendingInstructions = e.PendingInstructions[:0] -} - -func (e *ExecutableContextT[T]) Reset() { - e.labelPositionPool.Reset() - e.InstructionPool.Reset() - for i := range e.LabelPositions { - e.LabelPositions[i] = nil - } - e.PendingInstructions = e.PendingInstructions[:0] - e.OrderedBlockLabels = e.OrderedBlockLabels[:0] - e.RootInstr = nil - e.SsaBlockIDToLabels = e.SsaBlockIDToLabels[:0] - e.PerBlockHead, e.PerBlockEnd = nil, nil - e.NextLabel = LabelInvalid -} - -// AllocateLabel allocates an unused label. -func (e *ExecutableContextT[T]) AllocateLabel() Label { - e.NextLabel++ - return e.NextLabel -} - -func (e *ExecutableContextT[T]) GetOrAllocateLabelPosition(l Label) *LabelPosition[T] { - if len(e.LabelPositions) <= int(l) { - e.LabelPositions = append(e.LabelPositions, make([]*LabelPosition[T], int(l)+1-len(e.LabelPositions))...) 
- } - ret := e.LabelPositions[l] - if ret == nil { - ret = e.labelPositionPool.Allocate() - ret.L = l - e.LabelPositions[l] = ret - } - return ret -} - -func (e *ExecutableContextT[T]) GetOrAllocateSSABlockLabel(blk ssa.BasicBlock) Label { - if blk.ReturnBlock() { - return LabelReturn - } - l := e.SsaBlockIDToLabels[blk.ID()] - if l == LabelInvalid { - l = e.AllocateLabel() - e.SsaBlockIDToLabels[blk.ID()] = l - } - return l -} - -func (e *ExecutableContextT[T]) allocateNop0() *T { - i := e.InstructionPool.Allocate() - e.asNop(i) - return i -} - -// LinkAdjacentBlocks implements backend.Machine. -func (e *ExecutableContextT[T]) LinkAdjacentBlocks(prev, next ssa.BasicBlock) { - prevLabelPos := e.LabelPositions[e.GetOrAllocateSSABlockLabel(prev)] - nextLabelPos := e.LabelPositions[e.GetOrAllocateSSABlockLabel(next)] - e.setNext(prevLabelPos.End, nextLabelPos.Begin) -} - -// LabelPosition represents the regions of the generated code which the label represents. -type LabelPosition[Instr any] struct { - SB ssa.BasicBlock - L Label - Begin, End *Instr - BinaryOffset int64 -} - -// Label represents a position in the generated code which is either -// a real instruction or the constant InstructionPool (e.g. jump tables). -// -// This is exactly the same as the traditional "label" in assembly code. -type Label uint32 - -const ( - LabelInvalid Label = 0 - LabelReturn Label = math.MaxUint32 -) - -// String implements backend.Machine. -func (l Label) String() string { - return fmt.Sprintf("L%d", l) -} diff --git a/internal/engine/wazevo/backend/isa/amd64/abi_entry_preamble_test.go b/internal/engine/wazevo/backend/isa/amd64/abi_entry_preamble_test.go index 2dea21eea4..baf6d1561c 100644 --- a/internal/engine/wazevo/backend/isa/amd64/abi_entry_preamble_test.go +++ b/internal/engine/wazevo/backend/isa/amd64/abi_entry_preamble_test.go @@ -307,7 +307,7 @@ func TestMachineCompileEntryPreamble(t *testing.T) { } { t.Run(tc.name, func(t *testing.T) { _, _, m := newSetupWithMockContext() - m.ectx.RootInstr = m.compileEntryPreamble(tc.sig) + m.rootInstr = m.compileEntryPreamble(tc.sig) require.Equal(t, tc.exp, m.Format()) }) } diff --git a/internal/engine/wazevo/backend/isa/amd64/abi_go_call.go b/internal/engine/wazevo/backend/isa/amd64/abi_go_call.go index 751050aff0..96f035e582 100644 --- a/internal/engine/wazevo/backend/isa/amd64/abi_go_call.go +++ b/internal/engine/wazevo/backend/isa/amd64/abi_go_call.go @@ -14,7 +14,6 @@ var calleeSavedVRegs = []regalloc.VReg{ // CompileGoFunctionTrampoline implements backend.Machine. func (m *machine) CompileGoFunctionTrampoline(exitCode wazevoapi.ExitCode, sig *ssa.Signature, needModuleContextPtr bool) []byte { - ectx := m.ectx argBegin := 1 // Skips exec context by default. if needModuleContextPtr { argBegin++ @@ -25,7 +24,7 @@ func (m *machine) CompileGoFunctionTrampoline(exitCode wazevoapi.ExitCode, sig * m.currentABI = abi cur := m.allocateNop() - ectx.RootInstr = cur + m.rootInstr = cur // Execution context is always the first argument. execCtrPtr := raxVReg @@ -272,7 +271,7 @@ func (m *machine) CompileGoFunctionTrampoline(exitCode wazevoapi.ExitCode, sig * cur = m.revertRBPRSP(cur) linkInstr(cur, m.allocateInstr().asRet()) - m.encodeWithoutSSA(ectx.RootInstr) + m.encodeWithoutSSA(m.rootInstr) return m.c.Buf() } @@ -347,10 +346,8 @@ var stackGrowSaveVRegs = []regalloc.VReg{ // CompileStackGrowCallSequence implements backend.Machine. 
func (m *machine) CompileStackGrowCallSequence() []byte { - ectx := m.ectx - cur := m.allocateNop() - ectx.RootInstr = cur + m.rootInstr = cur cur = m.setupRBPRSP(cur) @@ -379,7 +376,7 @@ func (m *machine) CompileStackGrowCallSequence() []byte { cur = m.revertRBPRSP(cur) linkInstr(cur, m.allocateInstr().asRet()) - m.encodeWithoutSSA(ectx.RootInstr) + m.encodeWithoutSSA(m.rootInstr) return m.c.Buf() } diff --git a/internal/engine/wazevo/backend/isa/amd64/abi_go_call_test.go b/internal/engine/wazevo/backend/isa/amd64/abi_go_call_test.go index 78c0c721a4..581a1b9b09 100644 --- a/internal/engine/wazevo/backend/isa/amd64/abi_go_call_test.go +++ b/internal/engine/wazevo/backend/isa/amd64/abi_go_call_test.go @@ -369,6 +369,7 @@ L3: } { t.Run(tc.name, func(t *testing.T) { _, _, m := newSetupWithMockContext() + m.nextLabel = 1 m.CompileGoFunctionTrampoline(tc.exitCode, tc.sig, tc.needModuleContextPtr) require.Equal(t, tc.exp, m.Format()) err := m.Encode(context.Background()) @@ -396,6 +397,8 @@ func Test_stackGrowSaveVRegs(t *testing.T) { func TestMachine_CompileStackGrowCallSequence(t *testing.T) { _, _, m := newSetupWithMockContext() + m.nextLabel = 1 + _ = m.CompileStackGrowCallSequence() require.Equal(t, ` @@ -518,8 +521,9 @@ L2: tc := tc t.Run(tc.exp, func(t *testing.T) { _, _, m := newSetupWithMockContext() - m.ectx.RootInstr = m.allocateNop() - m.insertStackBoundsCheck(tc.requiredStackSize, m.ectx.RootInstr) + m.nextLabel = 1 + m.rootInstr = m.allocateNop() + m.insertStackBoundsCheck(tc.requiredStackSize, m.rootInstr) err := m.Encode(context.Background()) require.NoError(t, err) require.Equal(t, tc.exp, m.Format()) diff --git a/internal/engine/wazevo/backend/isa/amd64/instr.go b/internal/engine/wazevo/backend/isa/amd64/instr.go index d27e79c0e5..9635049035 100644 --- a/internal/engine/wazevo/backend/isa/amd64/instr.go +++ b/internal/engine/wazevo/backend/isa/amd64/instr.go @@ -17,16 +17,6 @@ type instruction struct { kind instructionKind } -// Next implements regalloc.Instr. -func (i *instruction) Next() regalloc.Instr { - return i.next -} - -// Prev implements regalloc.Instr. -func (i *instruction) Prev() regalloc.Instr { - return i.prev -} - // IsCall implements regalloc.Instr. 
func (i *instruction) IsCall() bool { return i.kind == call } @@ -651,26 +641,14 @@ func resetInstruction(i *instruction) { *i = instruction{} } -func setNext(i *instruction, next *instruction) { - i.next = next -} - -func setPrev(i *instruction, prev *instruction) { - i.prev = prev -} - -func asNop(i *instruction) { - i.kind = nop0 -} - -func (i *instruction) asNop0WithLabel(label backend.Label) *instruction { //nolint +func (i *instruction) asNop0WithLabel(label label) *instruction { //nolint i.kind = nop0 i.u1 = uint64(label) return i } -func (i *instruction) nop0Label() backend.Label { - return backend.Label(i.u1) +func (i *instruction) nop0Label() label { + return label(i.u1) } type instructionKind byte @@ -1161,7 +1139,7 @@ func (i *instruction) asJmp(target operand) *instruction { return i } -func (i *instruction) jmpLabel() backend.Label { +func (i *instruction) jmpLabel() label { switch i.kind { case jmp, jmpIf, lea, xmmUnaryRmR: return i.op1.label() diff --git a/internal/engine/wazevo/backend/isa/amd64/instr_encoding_test.go b/internal/engine/wazevo/backend/isa/amd64/instr_encoding_test.go index d8bb29d6ea..c02973a912 100644 --- a/internal/engine/wazevo/backend/isa/amd64/instr_encoding_test.go +++ b/internal/engine/wazevo/backend/isa/amd64/instr_encoding_test.go @@ -5,7 +5,6 @@ import ( "runtime" "testing" - "github.com/tetratelabs/wazero/internal/engine/wazevo/backend" "github.com/tetratelabs/wazero/internal/testing/require" ) @@ -2652,7 +2651,7 @@ func TestInstruction_format_encode(t *testing.T) { }, { setup: func(i *instruction) { - i.asLEA(newOperandLabel(backend.Label(1234)), r11VReg) + i.asLEA(newOperandLabel(label(1234)), r11VReg) }, want: "4c8d1dffffffff", wantFormat: "lea L1234, %r11", diff --git a/internal/engine/wazevo/backend/isa/amd64/machine.go b/internal/engine/wazevo/backend/isa/amd64/machine.go index 1d118d1578..ea64642e4b 100644 --- a/internal/engine/wazevo/backend/isa/amd64/machine.go +++ b/internal/engine/wazevo/backend/isa/amd64/machine.go @@ -16,18 +16,13 @@ import ( // NewBackend returns a new backend for arm64. func NewBackend() backend.Machine { - ectx := backend.NewExecutableContextT[instruction]( - resetInstruction, - setNext, - setPrev, - asNop, - ) - return &machine{ - ectx: ectx, + m := &machine{ cpuFeatures: platform.CpuFeatures, - regAlloc: regalloc.NewAllocator(regInfo), + regAlloc: regalloc.NewAllocator[*instruction, *labelPosition](regInfo), spillSlots: map[regalloc.VRegID]int64{}, amodePool: wazevoapi.NewPool[amode](nil), + labelPositionPool: wazevoapi.NewIDedPool[labelPosition](resetLabelPosition), + instrPool: wazevoapi.NewPool[instruction](resetInstruction), constSwizzleMaskConstIndex: -1, constSqmulRoundSatIndex: -1, constI8x16SHLMaskTableIndex: -1, @@ -41,23 +36,46 @@ func NewBackend() backend.Machine { constExtAddPairwiseI16x8uMask1Index: -1, constExtAddPairwiseI16x8uMask2Index: -1, } + m.regAllocFn.m = m + return m } type ( // machine implements backend.Machine for amd64. machine struct { c backend.Compiler - ectx *backend.ExecutableContextT[instruction] stackBoundsCheckDisabled bool + instrPool wazevoapi.Pool[instruction] amodePool wazevoapi.Pool[amode] cpuFeatures platform.CpuFeatureFlags - regAlloc regalloc.Allocator - regAllocFn *backend.RegAllocFunction[*instruction, *machine] + regAlloc regalloc.Allocator[*instruction, *labelPosition] + regAllocFn regAllocFn regAllocStarted bool + // labelPositionPool is the pool of labelPosition. The id is the label where + // if the label is less than the maxSSABlockID, it's the ssa.BasicBlockID. 
+ labelPositionPool wazevoapi.IDedPool[labelPosition] + // nextLabel is the next label to be allocated. The first free label comes after maxSSABlockID + // so that we can have an identical label for the SSA block ID, which is useful for debugging. + nextLabel label + // rootInstr is the first instruction of the function. + rootInstr *instruction + // currentLabelPos is the currently-compiled ssa.BasicBlock's labelPosition. + currentLabelPos *labelPosition + // orderedSSABlockLabelPos is the ordered list of labelPosition in the generated code for each ssa.BasicBlock. + orderedSSABlockLabelPos []*labelPosition + // returnLabelPos is the labelPosition for the return block. + returnLabelPos labelPosition + // perBlockHead and perBlockEnd are the head and tail of the instruction list per currently-compiled ssa.BasicBlock. + perBlockHead, perBlockEnd *instruction + // pendingInstructions are the instructions which are not yet emitted into the instruction list. + pendingInstructions []*instruction + // maxSSABlockID is the maximum ssa.BasicBlockID in the current function. + maxSSABlockID label + spillSlotSize int64 spillSlots map[regalloc.VRegID]int64 currentABI *backend.FunctionABI @@ -82,9 +100,10 @@ type ( } _const struct { - lo, hi uint64 - _var []byte - label *labelPosition + lo, hi uint64 + _var []byte + label label + labelPos *labelPosition } labelResolutionPend struct { @@ -93,22 +112,73 @@ type ( // imm32Offset is the offset of the last 4 bytes of the instruction. imm32Offset int64 } +) - labelPosition = backend.LabelPosition[instruction] +type ( + // label represents a position in the generated code which is either + // a real instruction or the constant InstructionPool (e.g. jump tables). + // + // This is exactly the same as the traditional "label" in assembly code. + label uint32 + + // labelPosition represents the regions of the generated code which the label represents. + // This implements regalloc.Block. + labelPosition struct { + // sb is not nil if this corresponds to a ssa.BasicBlock. + sb ssa.BasicBlock + // cur is used to walk through the instructions in the block during the register allocation. + cur, + // begin and end are the first and last instructions of the block. + begin, end *instruction + // binaryOffset is the offset in the binary where the label is located. + binaryOffset int64 + } ) -func (m *machine) getOrAllocateConstLabel(i *int, _var []byte) backend.Label { +// String implements backend.Machine. +func (l label) String() string { + return fmt.Sprintf("L%d", l) +} + +func resetLabelPosition(l *labelPosition) { + *l = labelPosition{} +} + +const labelReturn = math.MaxUint32 + +func ssaBlockLabel(sb ssa.BasicBlock) label { + if sb.ReturnBlock() { + return labelReturn + } + return label(sb.ID()) +} + +// getOrAllocateSSABlockLabelPosition returns the labelPosition for the given basic block. 
+func (m *machine) getOrAllocateSSABlockLabelPosition(sb ssa.BasicBlock) *labelPosition { + if sb.ReturnBlock() { + m.returnLabelPos.sb = sb + return &m.returnLabelPos + } + + l := ssaBlockLabel(sb) + pos := m.labelPositionPool.GetOrAllocate(int(l)) + pos.sb = sb + return pos +} + +func (m *machine) getOrAllocateConstLabel(i *int, _var []byte) label { index := *i if index == -1 { - label := m.allocateLabel() + l, pos := m.allocateLabel() index = len(m.consts) m.consts = append(m.consts, _const{ - _var: _var, - label: label, + _var: _var, + label: l, + labelPos: pos, }) *i = index } - return m.consts[index].label.L + return m.consts[index].label } // Reset implements backend.Machine. @@ -123,15 +193,17 @@ func (m *machine) Reset() { } m.stackBoundsCheckDisabled = false - m.ectx.Reset() - - m.regAllocFn.Reset() m.regAlloc.Reset() + m.labelPositionPool.Reset() + m.instrPool.Reset() m.regAllocStarted = false m.clobberedRegs = m.clobberedRegs[:0] m.spillSlotSize = 0 m.maxRequiredStackSizeForCalls = 0 + m.perBlockHead, m.perBlockEnd, m.rootInstr = nil, nil, nil + m.pendingInstructions = m.pendingInstructions[:0] + m.orderedSSABlockLabelPos = m.orderedSSABlockLabelPos[:0] m.amodePool.Reset() m.jmpTableTargetsNext = 0 @@ -149,8 +221,63 @@ func (m *machine) Reset() { m.constExtAddPairwiseI16x8uMask2Index = -1 } -// ExecutableContext implements backend.Machine. -func (m *machine) ExecutableContext() backend.ExecutableContext { return m.ectx } +// StartLoweringFunction implements backend.Machine StartLoweringFunction. +func (m *machine) StartLoweringFunction(maxBlockID ssa.BasicBlockID) { + m.maxSSABlockID = label(maxBlockID) + m.nextLabel = label(maxBlockID) + 1 +} + +// LinkAdjacentBlocks implements backend.Machine. +func (m *machine) LinkAdjacentBlocks(prev, next ssa.BasicBlock) { + prevPos, nextPos := m.getOrAllocateSSABlockLabelPosition(prev), m.getOrAllocateSSABlockLabelPosition(next) + prevPos.end.next = nextPos.begin +} + +// StartBlock implements backend.Machine. +func (m *machine) StartBlock(blk ssa.BasicBlock) { + m.currentLabelPos = m.getOrAllocateSSABlockLabelPosition(blk) + labelPos := m.currentLabelPos + end := m.allocateNop() + m.perBlockHead, m.perBlockEnd = end, end + labelPos.begin, labelPos.end = end, end + m.orderedSSABlockLabelPos = append(m.orderedSSABlockLabelPos, labelPos) +} + +// EndBlock implements ExecutableContext. +func (m *machine) EndBlock() { + // Insert nop0 as the head of the block for convenience to simplify the logic of inserting instructions. + m.insertAtPerBlockHead(m.allocateNop()) + + m.currentLabelPos.begin = m.perBlockHead + + if m.currentLabelPos.sb.EntryBlock() { + m.rootInstr = m.perBlockHead + } +} + +func (m *machine) insertAtPerBlockHead(i *instruction) { + if m.perBlockHead == nil { + m.perBlockHead = i + m.perBlockEnd = i + return + } + + i.next = m.perBlockHead + m.perBlockHead.prev = i + m.perBlockHead = i +} + +// FlushPendingInstructions implements backend.Machine. +func (m *machine) FlushPendingInstructions() { + l := len(m.pendingInstructions) + if l == 0 { + return + } + for i := l - 1; i >= 0; i-- { // reverse because we lower instructions in reverse order. + m.insertAtPerBlockHead(m.pendingInstructions[i]) + } + m.pendingInstructions = m.pendingInstructions[:0] +} // DisableStackCheck implements backend.Machine. func (m *machine) DisableStackCheck() { m.stackBoundsCheckDisabled = true } @@ -158,23 +285,17 @@ func (m *machine) DisableStackCheck() { m.stackBoundsCheckDisabled = true } // SetCompiler implements backend.Machine. 
func (m *machine) SetCompiler(c backend.Compiler) { m.c = c - m.regAllocFn = backend.NewRegAllocFunction[*instruction, *machine](m, c.SSABuilder(), c) + m.regAllocFn.ssaB = c.SSABuilder() } // SetCurrentABI implements backend.Machine. -func (m *machine) SetCurrentABI(abi *backend.FunctionABI) { - m.currentABI = abi -} +func (m *machine) SetCurrentABI(abi *backend.FunctionABI) { m.currentABI = abi } // RegAlloc implements backend.Machine. func (m *machine) RegAlloc() { rf := m.regAllocFn - for _, pos := range m.ectx.OrderedBlockLabels { - rf.AddBlock(pos.SB, pos.L, pos.Begin, pos.End) - } - m.regAllocStarted = true - m.regAlloc.DoAllocation(rf) + m.regAlloc.DoAllocation(&rf) // Now that we know the final spill slot size, we must align spillSlotSize to 16 bytes. m.spillSlotSize = (m.spillSlotSize + 15) &^ 15 } @@ -187,7 +308,6 @@ func (m *machine) InsertReturn() { // LowerSingleBranch implements backend.Machine. func (m *machine) LowerSingleBranch(b *ssa.Instruction) { - ectx := m.ectx switch b.Opcode() { case ssa.OpcodeJump: _, _, targetBlkID := b.BranchData() @@ -195,8 +315,8 @@ func (m *machine) LowerSingleBranch(b *ssa.Instruction) { return } jmp := m.allocateInstr() - target := ectx.GetOrAllocateSSABlockLabel(m.c.SSABuilder().BasicBlock(targetBlkID)) - if target == backend.LabelReturn { + target := ssaBlockLabel(m.c.SSABuilder().BasicBlock(targetBlkID)) + if target == labelReturn { jmp.asRet() } else { jmp.asJmp(newOperandLabel(target)) @@ -220,8 +340,7 @@ func (m *machine) addJmpTableTarget(targets ssa.Values) (index int) { m.jmpTableTargets[index] = m.jmpTableTargets[index][:0] for _, targetBlockID := range targets.View() { target := m.c.SSABuilder().BasicBlock(ssa.BasicBlockID(targetBlockID)) - m.jmpTableTargets[index] = append(m.jmpTableTargets[index], - uint32(m.ectx.GetOrAllocateSSABlockLabel(target))) + m.jmpTableTargets[index] = append(m.jmpTableTargets[index], uint32(ssaBlockLabel(target))) } return } @@ -271,17 +390,16 @@ func (m *machine) lowerBrTable(index ssa.Value, targets ssa.Values) { // LowerConditionalBranch implements backend.Machine. 
func (m *machine) LowerConditionalBranch(b *ssa.Instruction) { - exctx := m.ectx cval, args, targetBlkID := b.BranchData() if len(args) > 0 { panic(fmt.Sprintf( "conditional branch shouldn't have args; likely a bug in critical edge splitting: from %s to %s", - exctx.CurrentSSABlk, + m.currentLabelPos.sb, targetBlkID, )) } - target := exctx.GetOrAllocateSSABlockLabel(m.c.SSABuilder().BasicBlock(targetBlkID)) + target := ssaBlockLabel(m.c.SSABuilder().BasicBlock(targetBlkID)) cvalDef := m.c.ValueDefinition(cval) switch m.c.MatchInstrOneOf(cvalDef, condBranchMatches[:]) { @@ -1282,9 +1400,9 @@ func (m *machine) lowerVconst(dst regalloc.VReg, lo, hi uint64) { } load := m.allocateInstr() - constLabel := m.allocateLabel() - m.consts = append(m.consts, _const{label: constLabel, lo: lo, hi: hi}) - load.asXmmUnaryRmR(sseOpcodeMovdqu, newOperandMem(m.newAmodeRipRel(constLabel.L)), dst) + l, pos := m.allocateLabel() + m.consts = append(m.consts, _const{label: l, labelPos: pos, lo: lo, hi: hi}) + load.asXmmUnaryRmR(sseOpcodeMovdqu, newOperandMem(m.newAmodeRipRel(l)), dst) m.insert(load) } @@ -1532,7 +1650,7 @@ func (m *machine) allocateExitInstructions(execCtx, exitCodeReg regalloc.VReg) ( return } -func (m *machine) lowerExitWithCode(execCtx regalloc.VReg, code wazevoapi.ExitCode) (afterLabel backend.Label) { +func (m *machine) lowerExitWithCode(execCtx regalloc.VReg, code wazevoapi.ExitCode) (afterLabel label) { exitCodeReg := rbpVReg saveRsp, saveRbp, setExitCode := m.allocateExitInstructions(execCtx, exitCodeReg) @@ -1914,25 +2032,20 @@ func (m *machine) InsertMove(dst, src regalloc.VReg, typ ssa.Type) { // Format implements backend.Machine. func (m *machine) Format() string { - ectx := m.ectx - begins := map[*instruction]backend.Label{} - for _, pos := range ectx.LabelPositions { + begins := map[*instruction]label{} + for l := label(0); l < m.nextLabel; l++ { + pos := m.labelPositionPool.Get(int(l)) if pos != nil { - begins[pos.Begin] = pos.L + begins[pos.begin] = l } } - irBlocks := map[backend.Label]ssa.BasicBlockID{} - for i, l := range ectx.SsaBlockIDToLabels { - irBlocks[l] = ssa.BasicBlockID(i) - } - var lines []string - for cur := ectx.RootInstr; cur != nil; cur = cur.next { + for cur := m.rootInstr; cur != nil; cur = cur.next { if l, ok := begins[cur]; ok { var labelStr string - if blkID, ok := irBlocks[l]; ok { - labelStr = fmt.Sprintf("%s (SSA Block: %s):", l, blkID) + if l <= m.maxSSABlockID { + labelStr = fmt.Sprintf("%s (SSA Block: blk%d):", l, l) } else { labelStr = fmt.Sprintf("%s:", l) } @@ -1945,9 +2058,9 @@ func (m *machine) Format() string { } for _, vc := range m.consts { if vc._var == nil { - lines = append(lines, fmt.Sprintf("%s: const [%d %d]", vc.label.L, vc.lo, vc.hi)) + lines = append(lines, fmt.Sprintf("%s: const [%d %d]", vc.label, vc.lo, vc.hi)) } else { - lines = append(lines, fmt.Sprintf("%s: const %#x", vc.label.L, vc._var)) + lines = append(lines, fmt.Sprintf("%s: const %#x", vc.label, vc._var)) } } return "\n" + strings.Join(lines, "\n") + "\n" @@ -1955,18 +2068,14 @@ func (m *machine) Format() string { func (m *machine) encodeWithoutSSA(root *instruction) { m.labelResolutionPends = m.labelResolutionPends[:0] - ectx := m.ectx - bufPtr := m.c.BufPtr() for cur := root; cur != nil; cur = cur.next { offset := int64(len(*bufPtr)) if cur.kind == nop0 { l := cur.nop0Label() - if int(l) >= len(ectx.LabelPositions) { - continue - } - if pos := ectx.LabelPositions[l]; pos != nil { - pos.BinaryOffset = offset + pos := m.labelPositionPool.Get(int(l)) + if pos != nil { + 
pos.binaryOffset = offset } } @@ -1983,7 +2092,7 @@ func (m *machine) encodeWithoutSSA(root *instruction) { switch p.instr.kind { case jmp, jmpIf, lea: target := p.instr.jmpLabel() - targetOffset := ectx.LabelPositions[target].BinaryOffset + targetOffset := m.labelPositionPool.Get(int(target)).binaryOffset imm32Offset := p.imm32Offset jmpOffset := int32(targetOffset - (p.imm32Offset + 4)) // +4 because RIP points to the next instruction. binary.LittleEndian.PutUint32((*bufPtr)[imm32Offset:], uint32(jmpOffset)) @@ -1995,33 +2104,33 @@ func (m *machine) encodeWithoutSSA(root *instruction) { // Encode implements backend.Machine Encode. func (m *machine) Encode(ctx context.Context) (err error) { - ectx := m.ectx bufPtr := m.c.BufPtr() var fn string var fnIndex int - var labelToSSABlockID map[backend.Label]ssa.BasicBlockID + var labelPosToLabel map[*labelPosition]label if wazevoapi.PerfMapEnabled { fn = wazevoapi.GetCurrentFunctionName(ctx) - labelToSSABlockID = make(map[backend.Label]ssa.BasicBlockID) - for i, l := range ectx.SsaBlockIDToLabels { - labelToSSABlockID[l] = ssa.BasicBlockID(i) + labelPosToLabel = make(map[*labelPosition]label) + for i := 0; i <= m.labelPositionPool.MaxIDEncountered(); i++ { + pos := m.labelPositionPool.Get(i) + labelPosToLabel[pos] = label(i) } fnIndex = wazevoapi.GetCurrentFunctionIndex(ctx) } m.labelResolutionPends = m.labelResolutionPends[:0] - for _, pos := range ectx.OrderedBlockLabels { + for _, pos := range m.orderedSSABlockLabelPos { offset := int64(len(*bufPtr)) - pos.BinaryOffset = offset - for cur := pos.Begin; cur != pos.End.next; cur = cur.next { + pos.binaryOffset = offset + for cur := pos.begin; cur != pos.end.next; cur = cur.next { offset := int64(len(*bufPtr)) switch cur.kind { case nop0: l := cur.nop0Label() - if pos := ectx.LabelPositions[l]; pos != nil { - pos.BinaryOffset = offset + if pos := m.labelPositionPool.Get(int(l)); pos != nil { + pos.binaryOffset = offset } case sourceOffsetInfo: m.c.AddSourceOffsetInfo(offset, cur.sourceOffsetInfo()) @@ -2036,22 +2145,16 @@ func (m *machine) Encode(ctx context.Context) (err error) { } if wazevoapi.PerfMapEnabled { - l := pos.L - var labelStr string - if blkID, ok := labelToSSABlockID[l]; ok { - labelStr = fmt.Sprintf("%s::SSA_Block[%s]", l, blkID) - } else { - labelStr = l.String() - } + l := labelPosToLabel[pos] size := int64(len(*bufPtr)) - offset - wazevoapi.PerfMap.AddModuleEntry(fnIndex, offset, uint64(size), fmt.Sprintf("%s:::::%s", fn, labelStr)) + wazevoapi.PerfMap.AddModuleEntry(fnIndex, offset, uint64(size), fmt.Sprintf("%s:::::%s", fn, l)) } } for i := range m.consts { offset := int64(len(*bufPtr)) vc := &m.consts[i] - vc.label.BinaryOffset = offset + vc.labelPos.binaryOffset = offset if vc._var == nil { lo, hi := vc.lo, vc.hi m.c.Emit8Bytes(lo) @@ -2069,7 +2172,7 @@ func (m *machine) Encode(ctx context.Context) (err error) { switch p.instr.kind { case jmp, jmpIf, lea, xmmUnaryRmR: target := p.instr.jmpLabel() - targetOffset := ectx.LabelPositions[target].BinaryOffset + targetOffset := m.labelPositionPool.Get(int(target)).binaryOffset imm32Offset := p.imm32Offset jmpOffset := int32(targetOffset - (p.imm32Offset + 4)) // +4 because RIP points to the next instruction. binary.LittleEndian.PutUint32(buf[imm32Offset:], uint32(jmpOffset)) @@ -2078,7 +2181,7 @@ func (m *machine) Encode(ctx context.Context) (err error) { // Each entry is the offset from the beginning of the jmpTableIsland instruction in 8 bytes. 
targets := m.jmpTableTargets[p.instr.u1] for i, l := range targets { - targetOffset := ectx.LabelPositions[backend.Label(l)].BinaryOffset + targetOffset := m.labelPositionPool.Get(int(l)).binaryOffset jmpOffset := targetOffset - tableBegin binary.LittleEndian.PutUint64(buf[tableBegin+int64(i)*8:], uint64(jmpOffset)) } @@ -2150,7 +2253,7 @@ func (m *machine) lowerFcmpToFlags(instr *ssa.Instruction) (f1, f2 cond, and boo // allocateInstr allocates an instruction. func (m *machine) allocateInstr() *instruction { - instr := m.ectx.InstructionPool.Allocate() + instr := m.instrPool.Allocate() if !m.regAllocStarted { instr.addedBeforeRegAlloc = true } @@ -2164,24 +2267,22 @@ func (m *machine) allocateNop() *instruction { } func (m *machine) insert(i *instruction) { - ectx := m.ectx - ectx.PendingInstructions = append(ectx.PendingInstructions, i) + m.pendingInstructions = append(m.pendingInstructions, i) } -func (m *machine) allocateBrTarget() (nop *instruction, l backend.Label) { //nolint - pos := m.allocateLabel() - l = pos.L +func (m *machine) allocateBrTarget() (nop *instruction, l label) { //nolint + l, pos := m.allocateLabel() nop = m.allocateInstr() nop.asNop0WithLabel(l) - pos.Begin, pos.End = nop, nop + pos.begin, pos.end = nop, nop return } -func (m *machine) allocateLabel() *labelPosition { - ectx := m.ectx - l := ectx.AllocateLabel() - pos := ectx.GetOrAllocateLabelPosition(l) - return pos +func (m *machine) allocateLabel() (label, *labelPosition) { + l := m.nextLabel + pos := m.labelPositionPool.GetOrAllocate(int(l)) + m.nextLabel++ + return l, pos } func (m *machine) getVRegSpillSlotOffsetFromSP(id regalloc.VRegID, size byte) int64 { @@ -3195,22 +3296,22 @@ func (m *machine) lowerShuffle(x, y ssa.Value, lo, hi uint64, ret ssa.Value) { } } - xmaskLabel := m.allocateLabel() - m.consts = append(m.consts, _const{lo: xMask[0], hi: xMask[1], label: xmaskLabel}) - ymaskLabel := m.allocateLabel() - m.consts = append(m.consts, _const{lo: yMask[0], hi: yMask[1], label: ymaskLabel}) + xl, xmaskPos := m.allocateLabel() + m.consts = append(m.consts, _const{lo: xMask[0], hi: xMask[1], label: xl, labelPos: xmaskPos}) + yl, ymaskPos := m.allocateLabel() + m.consts = append(m.consts, _const{lo: yMask[0], hi: yMask[1], label: yl, labelPos: ymaskPos}) xx, yy := m.getOperand_Reg(m.c.ValueDefinition(x)), m.getOperand_Reg(m.c.ValueDefinition(y)) tmpX, tmpY := m.copyToTmp(xx.reg()), m.copyToTmp(yy.reg()) // Apply mask to X. tmp := m.c.AllocateVReg(ssa.TypeV128) - loadMaskLo := m.allocateInstr().asXmmUnaryRmR(sseOpcodeMovdqu, newOperandMem(m.newAmodeRipRel(xmaskLabel.L)), tmp) + loadMaskLo := m.allocateInstr().asXmmUnaryRmR(sseOpcodeMovdqu, newOperandMem(m.newAmodeRipRel(xl)), tmp) m.insert(loadMaskLo) m.insert(m.allocateInstr().asXmmRmR(sseOpcodePshufb, newOperandReg(tmp), tmpX)) // Apply mask to Y. 
- loadMaskHi := m.allocateInstr().asXmmUnaryRmR(sseOpcodeMovdqu, newOperandMem(m.newAmodeRipRel(ymaskLabel.L)), tmp) + loadMaskHi := m.allocateInstr().asXmmUnaryRmR(sseOpcodeMovdqu, newOperandMem(m.newAmodeRipRel(yl)), tmp) m.insert(loadMaskHi) m.insert(m.allocateInstr().asXmmRmR(sseOpcodePshufb, newOperandReg(tmp), tmpY)) diff --git a/internal/engine/wazevo/backend/isa/amd64/machine_pro_epi_logue.go b/internal/engine/wazevo/backend/isa/amd64/machine_pro_epi_logue.go index 8fa974c661..e53729860d 100644 --- a/internal/engine/wazevo/backend/isa/amd64/machine_pro_epi_logue.go +++ b/internal/engine/wazevo/backend/isa/amd64/machine_pro_epi_logue.go @@ -12,7 +12,7 @@ func (m *machine) PostRegAlloc() { } func (m *machine) setupPrologue() { - cur := m.ectx.RootInstr + cur := m.rootInstr prevInitInst := cur.next // At this point, we have the stack layout as follows: @@ -130,14 +130,13 @@ func (m *machine) setupPrologue() { // 3. Inserts the dec/inc RSP instruction right before/after the call instruction. // 4. Lowering that is supposed to be done after regalloc. func (m *machine) postRegAlloc() { - ectx := m.ectx - for cur := ectx.RootInstr; cur != nil; cur = cur.next { + for cur := m.rootInstr; cur != nil; cur = cur.next { switch k := cur.kind; k { case ret: m.setupEpilogueAfter(cur.prev) continue case fcvtToSintSequence, fcvtToUintSequence: - m.ectx.PendingInstructions = m.ectx.PendingInstructions[:0] + m.pendingInstructions = m.pendingInstructions[:0] if k == fcvtToSintSequence { m.lowerFcvtToSintSequenceAfterRegalloc(cur) } else { @@ -146,29 +145,29 @@ func (m *machine) postRegAlloc() { prev := cur.prev next := cur.next cur := prev - for _, instr := range m.ectx.PendingInstructions { + for _, instr := range m.pendingInstructions { cur = linkInstr(cur, instr) } linkInstr(cur, next) continue case xmmCMov: - m.ectx.PendingInstructions = m.ectx.PendingInstructions[:0] + m.pendingInstructions = m.pendingInstructions[:0] m.lowerXmmCmovAfterRegAlloc(cur) prev := cur.prev next := cur.next cur := prev - for _, instr := range m.ectx.PendingInstructions { + for _, instr := range m.pendingInstructions { cur = linkInstr(cur, instr) } linkInstr(cur, next) continue case idivRemSequence: - m.ectx.PendingInstructions = m.ectx.PendingInstructions[:0] + m.pendingInstructions = m.pendingInstructions[:0] m.lowerIDivRemSequenceAfterRegAlloc(cur) prev := cur.prev next := cur.next cur := prev - for _, instr := range m.ectx.PendingInstructions { + for _, instr := range m.pendingInstructions { cur = linkInstr(cur, instr) } linkInstr(cur, next) diff --git a/internal/engine/wazevo/backend/isa/amd64/machine_pro_epi_logue_test.go b/internal/engine/wazevo/backend/isa/amd64/machine_pro_epi_logue_test.go index 1a86a44589..9155cf2f96 100644 --- a/internal/engine/wazevo/backend/isa/amd64/machine_pro_epi_logue_test.go +++ b/internal/engine/wazevo/backend/isa/amd64/machine_pro_epi_logue_test.go @@ -69,14 +69,14 @@ func TestMachine_setupPrologue(t *testing.T) { m.currentABI = &tc.abi root := m.allocateNop() - m.ectx.RootInstr = root + m.rootInstr = root udf := m.allocateInstr() udf.asUD2() root.next = udf udf.prev = root m.setupPrologue() - require.Equal(t, root, m.ectx.RootInstr) + require.Equal(t, root, m.rootInstr) err := m.Encode(context.Background()) require.NoError(t, err) require.Equal(t, tc.exp, m.Format()) @@ -144,14 +144,14 @@ func TestMachine_postRegAlloc(t *testing.T) { m.currentABI = &tc.abi root := m.allocateNop() - m.ectx.RootInstr = root + m.rootInstr = root ret := m.allocateInstr() ret.asRet() root.next = ret 
ret.prev = root m.postRegAlloc() - require.Equal(t, root, m.ectx.RootInstr) + require.Equal(t, root, m.rootInstr) err := m.Encode(context.Background()) require.NoError(t, err) require.Equal(t, tc.exp, m.Format()) diff --git a/internal/engine/wazevo/backend/isa/amd64/machine_regalloc.go b/internal/engine/wazevo/backend/isa/amd64/machine_regalloc.go index 0bb28ee9e7..1750c89281 100644 --- a/internal/engine/wazevo/backend/isa/amd64/machine_regalloc.go +++ b/internal/engine/wazevo/backend/isa/amd64/machine_regalloc.go @@ -1,13 +1,226 @@ package amd64 import ( - "github.com/tetratelabs/wazero/internal/engine/wazevo/backend" "github.com/tetratelabs/wazero/internal/engine/wazevo/backend/regalloc" "github.com/tetratelabs/wazero/internal/engine/wazevo/ssa" ) -// InsertMoveBefore implements backend.RegAllocFunctionMachine. -func (m *machine) InsertMoveBefore(dst, src regalloc.VReg, instr *instruction) { +// regAllocFn implements regalloc.Function. +type regAllocFn struct { + ssaB ssa.Builder + m *machine + loopNestingForestRoots []ssa.BasicBlock + blockIter int +} + +// PostOrderBlockIteratorBegin implements regalloc.Function. +func (f *regAllocFn) PostOrderBlockIteratorBegin() *labelPosition { + f.blockIter = len(f.m.orderedSSABlockLabelPos) - 1 + return f.PostOrderBlockIteratorNext() +} + +// PostOrderBlockIteratorNext implements regalloc.Function. +func (f *regAllocFn) PostOrderBlockIteratorNext() *labelPosition { + if f.blockIter < 0 { + return nil + } + b := f.m.orderedSSABlockLabelPos[f.blockIter] + f.blockIter-- + return b +} + +// ReversePostOrderBlockIteratorBegin implements regalloc.Function. +func (f *regAllocFn) ReversePostOrderBlockIteratorBegin() *labelPosition { + f.blockIter = 0 + return f.ReversePostOrderBlockIteratorNext() +} + +// ReversePostOrderBlockIteratorNext implements regalloc.Function. +func (f *regAllocFn) ReversePostOrderBlockIteratorNext() *labelPosition { + if f.blockIter >= len(f.m.orderedSSABlockLabelPos) { + return nil + } + b := f.m.orderedSSABlockLabelPos[f.blockIter] + f.blockIter++ + return b +} + +// ClobberedRegisters implements regalloc.Function. +func (f *regAllocFn) ClobberedRegisters(regs []regalloc.VReg) { + f.m.clobberedRegs = append(f.m.clobberedRegs[:0], regs...) +} + +// LoopNestingForestRoots implements regalloc.Function. +func (f *regAllocFn) LoopNestingForestRoots() int { + f.loopNestingForestRoots = f.ssaB.LoopNestingForestRoots() + return len(f.loopNestingForestRoots) +} + +// LoopNestingForestRoot implements regalloc.Function. +func (f *regAllocFn) LoopNestingForestRoot(i int) *labelPosition { + root := f.loopNestingForestRoots[i] + pos := f.m.getOrAllocateSSABlockLabelPosition(root) + return pos +} + +// LowestCommonAncestor implements regalloc.Function. +func (f *regAllocFn) LowestCommonAncestor(blk1, blk2 *labelPosition) *labelPosition { + sb := f.ssaB.LowestCommonAncestor(blk1.sb, blk2.sb) + pos := f.m.getOrAllocateSSABlockLabelPosition(sb) + return pos +} + +// Idom implements regalloc.Function. +func (f *regAllocFn) Idom(blk *labelPosition) *labelPosition { + sb := f.ssaB.Idom(blk.sb) + pos := f.m.getOrAllocateSSABlockLabelPosition(sb) + return pos +} + +// SwapBefore implements regalloc.Function. +func (f *regAllocFn) SwapBefore(x1, x2, tmp regalloc.VReg, instr *instruction) { + f.m.swap(instr.prev, x1, x2, tmp) +} + +// StoreRegisterBefore implements regalloc.Function. 
+func (f *regAllocFn) StoreRegisterBefore(v regalloc.VReg, instr *instruction) { + m := f.m + m.insertStoreRegisterAt(v, instr, false) +} + +// StoreRegisterAfter implements regalloc.Function. +func (f *regAllocFn) StoreRegisterAfter(v regalloc.VReg, instr *instruction) { + m := f.m + m.insertStoreRegisterAt(v, instr, true) +} + +// ReloadRegisterBefore implements regalloc.Function. +func (f *regAllocFn) ReloadRegisterBefore(v regalloc.VReg, instr *instruction) { + m := f.m + m.insertReloadRegisterAt(v, instr, false) +} + +// ReloadRegisterAfter implements regalloc.Function. +func (f *regAllocFn) ReloadRegisterAfter(v regalloc.VReg, instr *instruction) { + m := f.m + m.insertReloadRegisterAt(v, instr, true) +} + +// InsertMoveBefore implements regalloc.Function. +func (f *regAllocFn) InsertMoveBefore(dst, src regalloc.VReg, instr *instruction) { + f.m.insertMoveBefore(dst, src, instr) +} + +// LoopNestingForestChild implements regalloc.Function. +func (f *regAllocFn) LoopNestingForestChild(pos *labelPosition, i int) *labelPosition { + childSB := pos.sb.LoopNestingForestChildren()[i] + return f.m.getOrAllocateSSABlockLabelPosition(childSB) +} + +// Succ implements regalloc.Block. +func (f *regAllocFn) Succ(pos *labelPosition, i int) *labelPosition { + succSB := pos.sb.Succ(i) + if succSB.ReturnBlock() { + return nil + } + return f.m.getOrAllocateSSABlockLabelPosition(succSB) +} + +// Pred implements regalloc.Block. +func (f *regAllocFn) Pred(pos *labelPosition, i int) *labelPosition { + predSB := pos.sb.Pred(i) + return f.m.getOrAllocateSSABlockLabelPosition(predSB) +} + +// BlockParams implements regalloc.Function. +func (f *regAllocFn) BlockParams(pos *labelPosition, regs *[]regalloc.VReg) []regalloc.VReg { + c := f.m.c + *regs = (*regs)[:0] + for i := 0; i < pos.sb.Params(); i++ { + v := c.VRegOf(pos.sb.Param(i)) + *regs = append(*regs, v) + } + return *regs +} + +// ID implements regalloc.Block. +func (pos *labelPosition) ID() int32 { + return int32(pos.sb.ID()) +} + +// InstrIteratorBegin implements regalloc.Block. +func (pos *labelPosition) InstrIteratorBegin() *instruction { + ret := pos.begin + pos.cur = ret + return ret +} + +// InstrIteratorNext implements regalloc.Block. +func (pos *labelPosition) InstrIteratorNext() *instruction { + for { + if pos.cur == pos.end { + return nil + } + instr := pos.cur.next + pos.cur = instr + if instr == nil { + return nil + } else if instr.AddedBeforeRegAlloc() { + // Only concerned about the instruction added before regalloc. + return instr + } + } +} + +// InstrRevIteratorBegin implements regalloc.Block. +func (pos *labelPosition) InstrRevIteratorBegin() *instruction { + pos.cur = pos.end + return pos.cur +} + +// InstrRevIteratorNext implements regalloc.Block. +func (pos *labelPosition) InstrRevIteratorNext() *instruction { + for { + if pos.cur == pos.begin { + return nil + } + instr := pos.cur.prev + pos.cur = instr + if instr == nil { + return nil + } else if instr.AddedBeforeRegAlloc() { + // Only concerned about the instruction added before regalloc. + return instr + } + } +} + +// FirstInstr implements regalloc.Block. +func (pos *labelPosition) FirstInstr() *instruction { return pos.begin } + +// LastInstrForInsertion implements regalloc.Block. +func (pos *labelPosition) LastInstrForInsertion() *instruction { + return lastInstrForInsertion(pos.begin, pos.end) +} + +// Preds implements regalloc.Block. +func (pos *labelPosition) Preds() int { return pos.sb.Preds() } + +// Entry implements regalloc.Block. 
+func (pos *labelPosition) Entry() bool { return pos.sb.EntryBlock() } + +// Succs implements regalloc.Block. +func (pos *labelPosition) Succs() int { return pos.sb.Succs() } + +// LoopHeader implements regalloc.Block. +func (pos *labelPosition) LoopHeader() bool { return pos.sb.LoopHeader() } + +// LoopNestingForestChildren implements regalloc.Block. +func (pos *labelPosition) LoopNestingForestChildren() int { + return len(pos.sb.LoopNestingForestChildren()) +} + +func (m *machine) insertMoveBefore(dst, src regalloc.VReg, instr *instruction) { typ := src.RegType() if typ != dst.RegType() { panic("BUG: src and dst must have the same type") @@ -26,8 +239,7 @@ func (m *machine) InsertMoveBefore(dst, src regalloc.VReg, instr *instruction) { linkInstr(cur, prevNext) } -// InsertStoreRegisterAt implements backend.RegAllocFunctionMachine. -func (m *machine) InsertStoreRegisterAt(v regalloc.VReg, instr *instruction, after bool) *instruction { +func (m *machine) insertStoreRegisterAt(v regalloc.VReg, instr *instruction, after bool) *instruction { if !v.IsRealReg() { panic("BUG: VReg must be backed by real reg to be stored") } @@ -61,8 +273,7 @@ func (m *machine) InsertStoreRegisterAt(v regalloc.VReg, instr *instruction, aft return linkInstr(cur, prevNext) } -// InsertReloadRegisterAt implements backend.RegAllocFunctionMachine. -func (m *machine) InsertReloadRegisterAt(v regalloc.VReg, instr *instruction, after bool) *instruction { +func (m *machine) insertReloadRegisterAt(v regalloc.VReg, instr *instruction, after bool) *instruction { if !v.IsRealReg() { panic("BUG: VReg must be backed by real reg to be stored") } @@ -98,13 +309,7 @@ func (m *machine) InsertReloadRegisterAt(v regalloc.VReg, instr *instruction, af return linkInstr(cur, prevNext) } -// ClobberedRegisters implements backend.RegAllocFunctionMachine. -func (m *machine) ClobberedRegisters(regs []regalloc.VReg) { - m.clobberedRegs = append(m.clobberedRegs[:0], regs...) -} - -// Swap implements backend.RegAllocFunctionMachine. -func (m *machine) Swap(cur *instruction, x1, x2, tmp regalloc.VReg) { +func (m *machine) swap(cur *instruction, x1, x2, tmp regalloc.VReg) { if x1.RegType() == regalloc.RegTypeInt { prevNext := cur.next xc := m.allocateInstr().asXCHG(x1, newOperandReg(x2), 8) @@ -113,25 +318,24 @@ func (m *machine) Swap(cur *instruction, x1, x2, tmp regalloc.VReg) { } else { if tmp.Valid() { prevNext := cur.next - m.InsertMoveBefore(tmp, x1, prevNext) - m.InsertMoveBefore(x1, x2, prevNext) - m.InsertMoveBefore(x2, tmp, prevNext) + m.insertMoveBefore(tmp, x1, prevNext) + m.insertMoveBefore(x1, x2, prevNext) + m.insertMoveBefore(x2, tmp, prevNext) } else { prevNext := cur.next r2 := x2.RealReg() // Temporarily spill x1 to stack. - cur = m.InsertStoreRegisterAt(x1, cur, true).prev + cur = m.insertStoreRegisterAt(x1, cur, true).prev // Then move x2 to x1. cur = linkInstr(cur, m.allocateInstr().asXmmUnaryRmR(sseOpcodeMovdqa, newOperandReg(x2), x1)) linkInstr(cur, prevNext) // Then reload the original value on x1 from stack to r2. - m.InsertReloadRegisterAt(x1.SetRealReg(r2), cur, true) + m.insertReloadRegisterAt(x1.SetRealReg(r2), cur, true) } } } -// LastInstrForInsertion implements backend.RegAllocFunctionMachine. 
-func (m *machine) LastInstrForInsertion(begin, end *instruction) *instruction { +func lastInstrForInsertion(begin, end *instruction) *instruction { cur := end for cur.kind == nop0 { cur = cur.prev @@ -146,8 +350,3 @@ func (m *machine) LastInstrForInsertion(begin, end *instruction) *instruction { return end } } - -// SSABlockLabel implements backend.RegAllocFunctionMachine. -func (m *machine) SSABlockLabel(id ssa.BasicBlockID) backend.Label { - return m.ectx.SsaBlockIDToLabels[id] -} diff --git a/internal/engine/wazevo/backend/isa/amd64/machine_regalloc_test.go b/internal/engine/wazevo/backend/isa/amd64/machine_regalloc_test.go index 929e14c3b1..f1834960f5 100644 --- a/internal/engine/wazevo/backend/isa/amd64/machine_regalloc_test.go +++ b/internal/engine/wazevo/backend/isa/amd64/machine_regalloc_test.go @@ -8,7 +8,7 @@ import ( "github.com/tetratelabs/wazero/internal/testing/require" ) -func TestMachine_InsertStoreRegisterAt(t *testing.T) { +func TestMachine_insertStoreRegisterAt(t *testing.T) { for _, tc := range []struct { spillSlotSize int64 expected string @@ -52,13 +52,13 @@ func TestMachine_InsertStoreRegisterAt(t *testing.T) { i2.prev = i1 if after { - m.InsertStoreRegisterAt(raxVReg, i1, after) - m.InsertStoreRegisterAt(xmm1VReg, i1, after) + m.insertStoreRegisterAt(raxVReg, i1, after) + m.insertStoreRegisterAt(xmm1VReg, i1, after) } else { - m.InsertStoreRegisterAt(xmm1VReg, i2, after) - m.InsertStoreRegisterAt(raxVReg, i2, after) + m.insertStoreRegisterAt(xmm1VReg, i2, after) + m.insertStoreRegisterAt(raxVReg, i2, after) } - m.ectx.RootInstr = i1 + m.rootInstr = i1 require.Equal(t, tc.expected, m.Format()) }) } @@ -66,7 +66,7 @@ func TestMachine_InsertStoreRegisterAt(t *testing.T) { } } -func TestMachine_InsertReloadRegisterAt(t *testing.T) { +func TestMachine_insertReloadRegisterAt(t *testing.T) { for _, tc := range []struct { spillSlotSize int64 expected string @@ -110,13 +110,13 @@ func TestMachine_InsertReloadRegisterAt(t *testing.T) { i2.prev = i1 if after { - m.InsertReloadRegisterAt(xmm1VReg, i1, after) - m.InsertReloadRegisterAt(raxVReg, i1, after) + m.insertReloadRegisterAt(xmm1VReg, i1, after) + m.insertReloadRegisterAt(raxVReg, i1, after) } else { - m.InsertReloadRegisterAt(raxVReg, i2, after) - m.InsertReloadRegisterAt(xmm1VReg, i2, after) + m.insertReloadRegisterAt(raxVReg, i2, after) + m.insertReloadRegisterAt(xmm1VReg, i2, after) } - m.ectx.RootInstr = i1 + m.rootInstr = i1 require.Equal(t, tc.expected, m.Format()) }) } @@ -154,8 +154,8 @@ func TestMachine_InsertMoveBefore(t *testing.T) { i1.next = i2 i2.prev = i1 - m.InsertMoveBefore(tc.dst, tc.src, i2) - m.ectx.RootInstr = i1 + m.insertMoveBefore(tc.dst, tc.src, i2) + m.rootInstr = i1 require.Equal(t, tc.expected, m.Format()) }) } @@ -221,8 +221,8 @@ func TestMachineSwap(t *testing.T) { cur.next = i2 i2.prev = cur - m.Swap(cur, tc.x1, tc.x2, tc.tmp) - m.ectx.RootInstr = cur + m.swap(cur, tc.x1, tc.x2, tc.tmp) + m.rootInstr = cur require.Equal(t, tc.expected, m.Format()) diff --git a/internal/engine/wazevo/backend/isa/amd64/machine_test.go b/internal/engine/wazevo/backend/isa/amd64/machine_test.go index b870ae4b6d..5614b45cbf 100644 --- a/internal/engine/wazevo/backend/isa/amd64/machine_test.go +++ b/internal/engine/wazevo/backend/isa/amd64/machine_test.go @@ -279,10 +279,11 @@ func Test_machine_getOperand_Mem_Imm32_Reg(t *testing.T) { func TestMachine_lowerExitWithCode(t *testing.T) { _, _, m := newSetupWithMockContext() + m.nextLabel = 1 m.lowerExitWithCode(r15VReg, wazevoapi.ExitCodeUnreachable) 
m.insert(m.allocateInstr().asUD2()) - m.ectx.FlushPendingInstructions() - m.ectx.RootInstr = m.ectx.PerBlockHead + m.FlushPendingInstructions() + m.rootInstr = m.perBlockHead require.Equal(t, ` mov.q %rsp, 56(%r15) mov.q %rbp, 1152(%r15) @@ -356,6 +357,7 @@ L2: } { t.Run(tc.name, func(t *testing.T) { ctx, b, m := newSetupWithMockContext() + m.nextLabel = 1 p := b.CurrentBlock().AddParam(b, tc.typ) m.cpuFeatures = tc.cpuFlags @@ -364,8 +366,8 @@ L2: instr := &ssa.Instruction{} instr.AsClz(p) m.lowerClz(instr) - m.ectx.FlushPendingInstructions() - m.ectx.RootInstr = m.ectx.PerBlockHead + m.FlushPendingInstructions() + m.rootInstr = m.perBlockHead require.Equal(t, tc.exp, m.Format()) }) } @@ -428,6 +430,7 @@ L2: } { t.Run(tc.name, func(t *testing.T) { ctx, b, m := newSetupWithMockContext() + m.nextLabel = 1 p := b.CurrentBlock().AddParam(b, tc.typ) m.cpuFeatures = tc.cpuFlags @@ -436,8 +439,8 @@ L2: instr := &ssa.Instruction{} instr.AsCtz(p) m.lowerCtz(instr) - m.ectx.FlushPendingInstructions() - m.ectx.RootInstr = m.ectx.PerBlockHead + m.FlushPendingInstructions() + m.rootInstr = m.perBlockHead require.Equal(t, tc.exp, m.Format()) }) } diff --git a/internal/engine/wazevo/backend/isa/amd64/operands.go b/internal/engine/wazevo/backend/isa/amd64/operands.go index c6fcb86731..10b727e7c9 100644 --- a/internal/engine/wazevo/backend/isa/amd64/operands.go +++ b/internal/engine/wazevo/backend/isa/amd64/operands.go @@ -59,7 +59,7 @@ func (o *operand) format(_64 bool) string { case operandKindImm32: return fmt.Sprintf("$%d", int32(o.imm32())) case operandKindLabel: - return backend.Label(o.imm32()).String() + return label(o.imm32()).String() default: panic(fmt.Sprintf("BUG: invalid operand: %s", o.kind)) } @@ -85,22 +85,22 @@ func (o *operand) imm32() uint32 { return uint32(o.data) } -func (o *operand) label() backend.Label { +func (o *operand) label() label { switch o.kind { case operandKindLabel: - return backend.Label(o.data) + return label(o.data) case operandKindMem: mem := o.addressMode() if mem.kind() != amodeRipRel { panic("BUG: invalid label") } - return backend.Label(mem.imm32) + return label(mem.imm32) default: panic("BUG: invalid operand kind") } } -func newOperandLabel(label backend.Label) operand { +func newOperandLabel(label label) operand { return operand{kind: operandKindLabel, data: uint64(label)} } @@ -221,7 +221,7 @@ func (m *machine) newAmodeRegRegShift(imm32 uint32, base, index regalloc.VReg, s return ret } -func (m *machine) newAmodeRipRel(label backend.Label) *amode { +func (m *machine) newAmodeRipRel(label label) *amode { ret := m.amodePool.Allocate() *ret = amode{kindWithShift: uint32(amodeRipRel), imm32: uint32(label)} return ret @@ -246,7 +246,7 @@ func (a *amode) String() string { "%d(%s,%s,%d)", int32(a.imm32), formatVRegSized(a.base, true), formatVRegSized(a.index, true), shift) case amodeRipRel: - return fmt.Sprintf("%s(%%rip)", backend.Label(a.imm32)) + return fmt.Sprintf("%s(%%rip)", label(a.imm32)) default: panic("BUG: invalid amode kind") } diff --git a/internal/engine/wazevo/backend/isa/amd64/util_test.go b/internal/engine/wazevo/backend/isa/amd64/util_test.go index 29b2c6c15e..5344974a58 100644 --- a/internal/engine/wazevo/backend/isa/amd64/util_test.go +++ b/internal/engine/wazevo/backend/isa/amd64/util_test.go @@ -138,9 +138,9 @@ func (m *mockCompiler) Compile(context.Context) (_ []byte, _ []backend.Relocatio } func formatEmittedInstructionsInCurrentBlock(m *machine) string { - m.ectx.FlushPendingInstructions() + m.FlushPendingInstructions() var strs []string - 
for cur := m.ectx.PerBlockHead; cur != nil; cur = cur.next { + for cur := m.perBlockHead; cur != nil; cur = cur.next { strs = append(strs, cur.String()) } return strings.Join(strs, "\n") diff --git a/internal/engine/wazevo/backend/isa/arm64/abi.go b/internal/engine/wazevo/backend/isa/arm64/abi.go index 4eaa13ce1c..c5e24a5773 100644 --- a/internal/engine/wazevo/backend/isa/arm64/abi.go +++ b/internal/engine/wazevo/backend/isa/arm64/abi.go @@ -228,10 +228,9 @@ func (m *machine) callerGenFunctionReturnVReg(a *backend.FunctionABI, retIndex i } func (m *machine) resolveAddressModeForOffsetAndInsert(cur *instruction, offset int64, dstBits byte, rn regalloc.VReg, allowTmpRegUse bool) (*instruction, *addressMode) { - exct := m.executableContext - exct.PendingInstructions = exct.PendingInstructions[:0] + m.pendingInstructions = m.pendingInstructions[:0] mode := m.resolveAddressModeForOffset(offset, dstBits, rn, allowTmpRegUse) - for _, instr := range exct.PendingInstructions { + for _, instr := range m.pendingInstructions { cur = linkInstr(cur, instr) } return cur, mode diff --git a/internal/engine/wazevo/backend/isa/arm64/abi_entry_preamble_test.go b/internal/engine/wazevo/backend/isa/arm64/abi_entry_preamble_test.go index beca3f0290..3dd6811cd4 100644 --- a/internal/engine/wazevo/backend/isa/arm64/abi_entry_preamble_test.go +++ b/internal/engine/wazevo/backend/isa/arm64/abi_entry_preamble_test.go @@ -397,7 +397,7 @@ func TestAbiImpl_constructEntryPreamble(t *testing.T) { } { t.Run(tc.name, func(t *testing.T) { _, _, m := newSetupWithMockContext() - m.executableContext.RootInstr = m.constructEntryPreamble(tc.sig) + m.rootInstr = m.constructEntryPreamble(tc.sig) require.Equal(t, tc.exp, m.Format()) }) } @@ -532,7 +532,7 @@ func TestMachine_goEntryPreamblePassArg(t *testing.T) { t.Run(tc.exp, func(t *testing.T) { _, _, m := newSetupWithMockContext() cur := m.allocateNop() - m.executableContext.RootInstr = cur + m.rootInstr = cur m.goEntryPreamblePassArg(cur, paramSlicePtr, &tc.arg, tc.argSlotBeginOffsetFromSP) require.Equal(t, tc.exp, m.Format()) err := m.Encode(context.Background()) @@ -686,7 +686,7 @@ func TestMachine_goEntryPreamblePassResult(t *testing.T) { t.Run(tc.exp, func(t *testing.T) { _, _, m := newSetupWithMockContext() cur := m.allocateNop() - m.executableContext.RootInstr = cur + m.rootInstr = cur m.goEntryPreamblePassResult(cur, paramSlicePtr, &tc.arg, tc.retStart) require.Equal(t, tc.exp, m.Format()) err := m.Encode(context.Background()) diff --git a/internal/engine/wazevo/backend/isa/arm64/abi_go_call.go b/internal/engine/wazevo/backend/isa/arm64/abi_go_call.go index 99e6bb482d..06f8a4a053 100644 --- a/internal/engine/wazevo/backend/isa/arm64/abi_go_call.go +++ b/internal/engine/wazevo/backend/isa/arm64/abi_go_call.go @@ -14,7 +14,6 @@ var calleeSavedRegistersSorted = []regalloc.VReg{ // CompileGoFunctionTrampoline implements backend.Machine. func (m *machine) CompileGoFunctionTrampoline(exitCode wazevoapi.ExitCode, sig *ssa.Signature, needModuleContextPtr bool) []byte { - exct := m.executableContext argBegin := 1 // Skips exec context by default. if needModuleContextPtr { argBegin++ @@ -26,7 +25,7 @@ func (m *machine) CompileGoFunctionTrampoline(exitCode wazevoapi.ExitCode, sig * cur := m.allocateInstr() cur.asNop0() - exct.RootInstr = cur + m.rootInstr = cur // Execution context is always the first argument. 
execCtrPtr := x0VReg @@ -244,7 +243,7 @@ func (m *machine) CompileGoFunctionTrampoline(exitCode wazevoapi.ExitCode, sig * ret.asRet() linkInstr(cur, ret) - m.encode(m.executableContext.RootInstr) + m.encode(m.rootInstr) return m.compiler.Buf() } @@ -302,20 +301,18 @@ func (m *machine) restoreRegistersInExecutionContext(cur *instruction, regs []re } func (m *machine) lowerConstantI64AndInsert(cur *instruction, dst regalloc.VReg, v int64) *instruction { - exct := m.executableContext - exct.PendingInstructions = exct.PendingInstructions[:0] + m.pendingInstructions = m.pendingInstructions[:0] m.lowerConstantI64(dst, v) - for _, instr := range exct.PendingInstructions { + for _, instr := range m.pendingInstructions { cur = linkInstr(cur, instr) } return cur } func (m *machine) lowerConstantI32AndInsert(cur *instruction, dst regalloc.VReg, v int32) *instruction { - exct := m.executableContext - exct.PendingInstructions = exct.PendingInstructions[:0] + m.pendingInstructions = m.pendingInstructions[:0] m.lowerConstantI32(dst, v) - for _, instr := range exct.PendingInstructions { + for _, instr := range m.pendingInstructions { cur = linkInstr(cur, instr) } return cur diff --git a/internal/engine/wazevo/backend/isa/arm64/abi_go_call_test.go b/internal/engine/wazevo/backend/isa/arm64/abi_go_call_test.go index 5fef0cfbe5..ad428d70a8 100644 --- a/internal/engine/wazevo/backend/isa/arm64/abi_go_call_test.go +++ b/internal/engine/wazevo/backend/isa/arm64/abi_go_call_test.go @@ -517,7 +517,7 @@ func Test_goFunctionCallLoadStackArg(t *testing.T) { _, result := m.goFunctionCallLoadStackArg(nop, originalArg0Reg, tc.arg, intVReg, floatVReg) require.Equal(t, tc.expResultReg, result) - m.executableContext.RootInstr = nop + m.rootInstr = nop require.Equal(t, tc.exp, m.Format()) err := m.Encode(context.Background()) @@ -584,7 +584,7 @@ func Test_goFunctionCallStoreStackResult(t *testing.T) { nop := m.allocateNop() m.goFunctionCallStoreStackResult(nop, spVReg, tc.result, tc.resultReg) - m.executableContext.RootInstr = nop + m.rootInstr = nop require.Equal(t, tc.exp, m.Format()) err := m.Encode(context.Background()) diff --git a/internal/engine/wazevo/backend/isa/arm64/instr.go b/internal/engine/wazevo/backend/isa/arm64/instr.go index 7121cb5382..08fb84d037 100644 --- a/internal/engine/wazevo/backend/isa/arm64/instr.go +++ b/internal/engine/wazevo/backend/isa/arm64/instr.go @@ -36,18 +36,6 @@ type ( instructionKind byte ) -func asNop0(i *instruction) { - i.kind = nop0 -} - -func setNext(i, next *instruction) { - i.next = next -} - -func setPrev(i, prev *instruction) { - i.prev = prev -} - // IsCall implements regalloc.Instr IsCall. func (i *instruction) IsCall() bool { return i.kind == call @@ -63,16 +51,6 @@ func (i *instruction) IsReturn() bool { return i.kind == ret } -// Next implements regalloc.Instr Next. -func (i *instruction) Next() regalloc.Instr { - return i.next -} - -// Prev implements regalloc.Instr Prev. -func (i *instruction) Prev() regalloc.Instr { - return i.prev -} - // AddedBeforeRegAlloc implements regalloc.Instr AddedBeforeRegAlloc. func (i *instruction) AddedBeforeRegAlloc() bool { return i.addedBeforeRegAlloc diff --git a/internal/engine/wazevo/backend/isa/arm64/instr_encoding.go b/internal/engine/wazevo/backend/isa/arm64/instr_encoding.go index f0ede2d6aa..21be9b71e7 100644 --- a/internal/engine/wazevo/backend/isa/arm64/instr_encoding.go +++ b/internal/engine/wazevo/backend/isa/arm64/instr_encoding.go @@ -12,7 +12,7 @@ import ( // Encode implements backend.Machine Encode. 
func (m *machine) Encode(ctx context.Context) error { m.resolveRelativeAddresses(ctx) - m.encode(m.executableContext.RootInstr) + m.encode(m.rootInstr) if l := len(m.compiler.Buf()); l > maxFunctionExecutableSize { return fmt.Errorf("function size exceeds the limit: %d > %d", l, maxFunctionExecutableSize) } diff --git a/internal/engine/wazevo/backend/isa/arm64/lower_instr.go b/internal/engine/wazevo/backend/isa/arm64/lower_instr.go index 87652569ae..f9df356c0e 100644 --- a/internal/engine/wazevo/backend/isa/arm64/lower_instr.go +++ b/internal/engine/wazevo/backend/isa/arm64/lower_instr.go @@ -17,7 +17,6 @@ import ( // LowerSingleBranch implements backend.Machine. func (m *machine) LowerSingleBranch(br *ssa.Instruction) { - ectx := m.executableContext switch br.Opcode() { case ssa.OpcodeJump: _, _, targetBlkID := br.BranchData() @@ -26,11 +25,10 @@ func (m *machine) LowerSingleBranch(br *ssa.Instruction) { } b := m.allocateInstr() targetBlk := m.compiler.SSABuilder().BasicBlock(targetBlkID) - target := ectx.GetOrAllocateSSABlockLabel(targetBlk) - if target == labelReturn { + if targetBlk.ReturnBlock() { b.asRet() } else { - b.asBr(target) + b.asBr(ssaBlockLabel(targetBlk)) } m.insert(b) case ssa.OpcodeBrTable: @@ -70,18 +68,17 @@ func (m *machine) lowerBrTable(i *ssa.Instruction) { // LowerConditionalBranch implements backend.Machine. func (m *machine) LowerConditionalBranch(b *ssa.Instruction) { - exctx := m.executableContext cval, args, targetBlkID := b.BranchData() if len(args) > 0 { panic(fmt.Sprintf( "conditional branch shouldn't have args; likely a bug in critical edge splitting: from %s to %s", - exctx.CurrentSSABlk, + m.currentLabelPos.sb, targetBlkID, )) } targetBlk := m.compiler.SSABuilder().BasicBlock(targetBlkID) - target := exctx.GetOrAllocateSSABlockLabel(targetBlk) + target := ssaBlockLabel(targetBlk) cvalDef := m.compiler.ValueDefinition(cval) switch { @@ -794,7 +791,7 @@ func (m *machine) LowerInstr(instr *ssa.Instruction) { default: panic("TODO: lowering " + op.String()) } - m.executableContext.FlushPendingInstructions() + m.FlushPendingInstructions() } func (m *machine) lowerShuffle(rd regalloc.VReg, rn, rm operand, lane1, lane2 uint64) { diff --git a/internal/engine/wazevo/backend/isa/arm64/lower_instr_test.go b/internal/engine/wazevo/backend/isa/arm64/lower_instr_test.go index 6a367944e4..524811d780 100644 --- a/internal/engine/wazevo/backend/isa/arm64/lower_instr_test.go +++ b/internal/engine/wazevo/backend/isa/arm64/lower_instr_test.go @@ -18,7 +18,6 @@ func TestMachine_LowerConditionalBranch(t *testing.T) { brz bool, intCond ssa.IntegerCmpCond, floatCond ssa.FloatCmpCond, ctx *mockCompiler, builder ssa.Builder, m *machine, ) (instr *ssa.Instruction, verify func(t *testing.T)) { - m.executableContext.StartLoweringFunction(10) entry := builder.CurrentBlock() isInt := intCond != ssa.IntegerCmpCondInvalid @@ -63,7 +62,6 @@ func TestMachine_LowerConditionalBranch(t *testing.T) { } icmpInSameGroupFromParamAndImm12 := func(brz bool, ctx *mockCompiler, builder ssa.Builder, m *machine) (instr *ssa.Instruction, verify func(t *testing.T)) { - m.executableContext.StartLoweringFunction(10) entry := builder.CurrentBlock() v1 := entry.AddParam(builder, ssa.TypeI32) @@ -72,9 +70,6 @@ func TestMachine_LowerConditionalBranch(t *testing.T) { builder.InsertInstruction(iconst) v2 := iconst.Return() - // Constant can be referenced from different groups because we inline it. 
- builder.SetCurrentBlock(builder.AllocateBasicBlock()) - icmp := builder.AllocateInstruction() icmp.AsIcmp(v1, v2, ssa.IntegerCmpCondEqual) builder.InsertInstruction(icmp) @@ -103,7 +98,6 @@ func TestMachine_LowerConditionalBranch(t *testing.T) { { name: "icmp in different group", setup: func(ctx *mockCompiler, builder ssa.Builder, m *machine) (instr *ssa.Instruction, verify func(t *testing.T)) { - m.executableContext.StartLoweringFunction(10) entry := builder.CurrentBlock() v1, v2 := entry.AddParam(builder, ssa.TypeI64), entry.AddParam(builder, ssa.TypeI64) @@ -218,7 +212,6 @@ func TestMachine_LowerSingleBranch(t *testing.T) { { name: "b", setup: func(ctx *mockCompiler, builder ssa.Builder, m *machine) (instr *ssa.Instruction) { - m.executableContext.StartLoweringFunction(10) jump := builder.AllocateInstruction() jump.AsJump(ssa.ValuesNil, builder.AllocateBasicBlock()) builder.InsertInstruction(jump) @@ -229,7 +222,6 @@ func TestMachine_LowerSingleBranch(t *testing.T) { { name: "ret", setup: func(ctx *mockCompiler, builder ssa.Builder, m *machine) (instr *ssa.Instruction) { - m.executableContext.StartLoweringFunction(10) jump := builder.AllocateInstruction() jump.AsJump(ssa.ValuesNil, builder.ReturnBlock()) builder.InsertInstruction(jump) @@ -379,6 +371,7 @@ L2: regalloc.VReg(2).SetRegType(regalloc.RegTypeInt), regalloc.VReg(3).SetRegType(regalloc.RegTypeInt) mc, _, m := newSetupWithMockContext() + m.maxSSABlockID, m.nextLabel = 1, 1 mc.typeOf = map[regalloc.VRegID]ssa.Type{execCtx.ID(): ssa.TypeI64, 2: ssa.TypeI64, 3: ssa.TypeI64} m.lowerIDiv(execCtx, rd, operandNR(rn), operandNR(rm), tc._64bit, tc.signed) require.Equal(t, tc.exp, "\n"+formatEmittedInstructionsInCurrentBlock(m)+"\n") @@ -389,8 +382,8 @@ L2: func TestMachine_exitWithCode(t *testing.T) { _, _, m := newSetupWithMockContext() m.lowerExitWithCode(x1VReg, wazevoapi.ExitCodeGrowStack) - m.executableContext.FlushPendingInstructions() - m.encode(m.executableContext.PerBlockHead) + m.FlushPendingInstructions() + m.encode(m.perBlockHead) require.Equal(t, ` movz x1?, #0x1, lsl 0 str w1?, [x1] @@ -450,12 +443,13 @@ fcvtzu w1, s2 } { t.Run(tc.name, func(t *testing.T) { mc, _, m := newSetupWithMockContext() + m.maxSSABlockID, m.nextLabel = 1, 1 mc.typeOf = map[regalloc.VRegID]ssa.Type{v2VReg.ID(): ssa.TypeI64, x15VReg.ID(): ssa.TypeI64} m.lowerFpuToInt(x1VReg, operandNR(v2VReg), x15VReg, false, false, false, tc.nontrapping) require.Equal(t, tc.expectedAsm, "\n"+formatEmittedInstructionsInCurrentBlock(m)+"\n") - m.executableContext.FlushPendingInstructions() - m.encode(m.executableContext.PerBlockHead) + m.FlushPendingInstructions() + m.encode(m.perBlockHead) }) } } @@ -536,8 +530,8 @@ mul x1.4s, x2.4s, x15.4s m.lowerVIMul(x1VReg, operandNR(x2VReg), operandNR(x15VReg), tc.arrangement) require.Equal(t, tc.expectedAsm, "\n"+formatEmittedInstructionsInCurrentBlock(m)+"\n") - m.executableContext.FlushPendingInstructions() - m.encode(m.executableContext.PerBlockHead) + m.FlushPendingInstructions() + m.encode(m.perBlockHead) buf := m.compiler.Buf() require.Equal(t, tc.expectedBytes, hex.EncodeToString(buf)) }) @@ -641,8 +635,8 @@ cset x15, ne m.lowerVcheckTrue(tc.op, operandNR(x1VReg), x15VReg, tc.arrangement) require.Equal(t, tc.expectedAsm, "\n"+formatEmittedInstructionsInCurrentBlock(m)+"\n") - m.executableContext.FlushPendingInstructions() - m.encode(m.executableContext.PerBlockHead) + m.FlushPendingInstructions() + m.encode(m.perBlockHead) buf := m.compiler.Buf() require.Equal(t, tc.expectedBytes, hex.EncodeToString(buf)) }) @@ -726,8 
+720,8 @@ add w15, w15, w1?, lsl #1 m.lowerVhighBits(operandNR(x1VReg), x15VReg, tc.arrangement) require.Equal(t, tc.expectedAsm, "\n"+formatEmittedInstructionsInCurrentBlock(m)+"\n") - m.executableContext.FlushPendingInstructions() - m.encode(m.executableContext.PerBlockHead) + m.FlushPendingInstructions() + m.encode(m.perBlockHead) buf := m.compiler.Buf() require.Equal(t, tc.expectedBytes, hex.EncodeToString(buf)) }) @@ -775,8 +769,8 @@ tbl x1.16b, { v29.16b, v30.16b }, v1?.16b m.lowerShuffle(x1VReg, operandNR(x2VReg), operandNR(x15VReg), lane1, lane2) require.Equal(t, tc.expectedAsm, "\n"+formatEmittedInstructionsInCurrentBlock(m)+"\n") - m.executableContext.FlushPendingInstructions() - m.encode(m.executableContext.PerBlockHead) + m.FlushPendingInstructions() + m.encode(m.perBlockHead) buf := m.compiler.Buf() require.Equal(t, tc.expectedBytes, hex.EncodeToString(buf)) }) @@ -832,8 +826,8 @@ ushl x1.16b, x2.16b, v2?.16b m.lowerVShift(tc.op, x1VReg, operandNR(x2VReg), operandNR(x15VReg), tc.arrangement) require.Equal(t, tc.expectedAsm, "\n"+formatEmittedInstructionsInCurrentBlock(m)+"\n") - m.executableContext.FlushPendingInstructions() - m.encode(m.executableContext.PerBlockHead) + m.FlushPendingInstructions() + m.encode(m.perBlockHead) buf := m.compiler.Buf() require.Equal(t, tc.expectedBytes, hex.EncodeToString(buf)) }) diff --git a/internal/engine/wazevo/backend/isa/arm64/machine.go b/internal/engine/wazevo/backend/isa/arm64/machine.go index 506c263936..1806afd263 100644 --- a/internal/engine/wazevo/backend/isa/arm64/machine.go +++ b/internal/engine/wazevo/backend/isa/arm64/machine.go @@ -3,6 +3,7 @@ package arm64 import ( "context" "fmt" + "math" "strings" "github.com/tetratelabs/wazero/internal/engine/wazevo/backend" @@ -14,12 +15,33 @@ import ( type ( // machine implements backend.Machine. machine struct { - compiler backend.Compiler - executableContext *backend.ExecutableContextT[instruction] - currentABI *backend.FunctionABI - - regAlloc regalloc.Allocator - regAllocFn *backend.RegAllocFunction[*instruction, *machine] + compiler backend.Compiler + currentABI *backend.FunctionABI + instrPool wazevoapi.Pool[instruction] + // labelPositionPool is the pool of labelPosition. The id is the label where + // if the label is less than the maxSSABlockID, it's the ssa.BasicBlockID. + labelPositionPool wazevoapi.IDedPool[labelPosition] + + // nextLabel is the next label to be allocated. The first free label comes after maxSSABlockID + // so that we can have an identical label for the SSA block ID, which is useful for debugging. + nextLabel label + // rootInstr is the first instruction of the function. + rootInstr *instruction + // currentLabelPos is the currently-compiled ssa.BasicBlock's labelPosition. + currentLabelPos *labelPosition + // orderedSSABlockLabelPos is the ordered list of labelPosition in the generated code for each ssa.BasicBlock. + orderedSSABlockLabelPos []*labelPosition + // returnLabelPos is the labelPosition for the return block. + returnLabelPos labelPosition + // perBlockHead and perBlockEnd are the head and tail of the instruction list per currently-compiled ssa.BasicBlock. + perBlockHead, perBlockEnd *instruction + // pendingInstructions are the instructions which are not yet emitted into the instruction list. + pendingInstructions []*instruction + // maxSSABlockID is the maximum ssa.BasicBlockID in the current function. 
+ maxSSABlockID label + + regAlloc regalloc.Allocator[*instruction, *labelPosition] + regAllocFn regAllocFn amodePool wazevoapi.Pool[addressMode] @@ -93,45 +115,132 @@ type ( nextLabel label offset int64 } +) - labelPosition = backend.LabelPosition[instruction] - label = backend.Label +type ( + // label represents a position in the generated code which is either + // a real instruction or the constant InstructionPool (e.g. jump tables). + // + // This is exactly the same as the traditional "label" in assembly code. + label uint32 + + // labelPosition represents the regions of the generated code which the label represents. + // This implements regalloc.Block. + labelPosition struct { + // sb is not nil if this corresponds to a ssa.BasicBlock. + sb ssa.BasicBlock + // cur is used to walk through the instructions in the block during the register allocation. + cur, + // begin and end are the first and last instructions of the block. + begin, end *instruction + // binaryOffset is the offset in the binary where the label is located. + binaryOffset int64 + } ) const ( - labelReturn = backend.LabelReturn - labelInvalid = backend.LabelInvalid + labelReturn label = math.MaxUint32 + labelInvalid = labelReturn - 1 ) +// String implements backend.Machine. +func (l label) String() string { + return fmt.Sprintf("L%d", l) +} + +func resetLabelPosition(l *labelPosition) { + *l = labelPosition{} +} + // NewBackend returns a new backend for arm64. func NewBackend() backend.Machine { m := &machine{ spillSlots: make(map[regalloc.VRegID]int64), - executableContext: newExecutableContext(), - regAlloc: regalloc.NewAllocator(regInfo), + regAlloc: regalloc.NewAllocator[*instruction, *labelPosition](regInfo), amodePool: wazevoapi.NewPool[addressMode](resetAddressMode), + instrPool: wazevoapi.NewPool[instruction](resetInstruction), + labelPositionPool: wazevoapi.NewIDedPool[labelPosition](resetLabelPosition), } + m.regAllocFn.m = m return m } -func newExecutableContext() *backend.ExecutableContextT[instruction] { - return backend.NewExecutableContextT[instruction](resetInstruction, setNext, setPrev, asNop0) +func ssaBlockLabel(sb ssa.BasicBlock) label { + if sb.ReturnBlock() { + return labelReturn + } + return label(sb.ID()) } -// ExecutableContext implements backend.Machine. -func (m *machine) ExecutableContext() backend.ExecutableContext { - return m.executableContext +// getOrAllocateSSABlockLabelPosition returns the labelPosition for the given basic block. +func (m *machine) getOrAllocateSSABlockLabelPosition(sb ssa.BasicBlock) *labelPosition { + if sb.ReturnBlock() { + m.returnLabelPos.sb = sb + return &m.returnLabelPos + } + + l := ssaBlockLabel(sb) + pos := m.labelPositionPool.GetOrAllocate(int(l)) + pos.sb = sb + return pos } -// RegAlloc implements backend.Machine Function. -func (m *machine) RegAlloc() { - rf := m.regAllocFn - for _, pos := range m.executableContext.OrderedBlockLabels { - rf.AddBlock(pos.SB, pos.L, pos.Begin, pos.End) +// LinkAdjacentBlocks implements backend.Machine. +func (m *machine) LinkAdjacentBlocks(prev, next ssa.BasicBlock) { + prevPos, nextPos := m.getOrAllocateSSABlockLabelPosition(prev), m.getOrAllocateSSABlockLabelPosition(next) + prevPos.end.next = nextPos.begin +} + +// StartBlock implements backend.Machine. 
+func (m *machine) StartBlock(blk ssa.BasicBlock) { + m.currentLabelPos = m.getOrAllocateSSABlockLabelPosition(blk) + labelPos := m.currentLabelPos + end := m.allocateNop() + m.perBlockHead, m.perBlockEnd = end, end + labelPos.begin, labelPos.end = end, end + m.orderedSSABlockLabelPos = append(m.orderedSSABlockLabelPos, labelPos) +} + +// EndBlock implements ExecutableContext. +func (m *machine) EndBlock() { + // Insert nop0 as the head of the block for convenience to simplify the logic of inserting instructions. + m.insertAtPerBlockHead(m.allocateNop()) + + m.currentLabelPos.begin = m.perBlockHead + + if m.currentLabelPos.sb.EntryBlock() { + m.rootInstr = m.perBlockHead + } +} + +func (m *machine) insertAtPerBlockHead(i *instruction) { + if m.perBlockHead == nil { + m.perBlockHead = i + m.perBlockEnd = i + return + } + + i.next = m.perBlockHead + m.perBlockHead.prev = i + m.perBlockHead = i +} + +// FlushPendingInstructions implements backend.Machine. +func (m *machine) FlushPendingInstructions() { + l := len(m.pendingInstructions) + if l == 0 { + return } + for i := l - 1; i >= 0; i-- { // reverse because we lower instructions in reverse order. + m.insertAtPerBlockHead(m.pendingInstructions[i]) + } + m.pendingInstructions = m.pendingInstructions[:0] +} +// RegAlloc implements backend.Machine Function. +func (m *machine) RegAlloc() { m.regAllocStarted = true - m.regAlloc.DoAllocation(rf) + m.regAlloc.DoAllocation(&m.regAllocFn) // Now that we know the final spill slot size, we must align spillSlotSize to 16 bytes. m.spillSlotSize = (m.spillSlotSize + 15) &^ 15 } @@ -148,13 +257,22 @@ func (m *machine) Reset() { m.clobberedRegs = m.clobberedRegs[:0] m.regAllocStarted = false m.regAlloc.Reset() - m.regAllocFn.Reset() m.spillSlotSize = 0 m.unresolvedAddressModes = m.unresolvedAddressModes[:0] m.maxRequiredStackSizeForCalls = 0 - m.executableContext.Reset() m.jmpTableTargetsNext = 0 m.amodePool.Reset() + m.instrPool.Reset() + m.labelPositionPool.Reset() + m.pendingInstructions = m.pendingInstructions[:0] + m.perBlockHead, m.perBlockEnd, m.rootInstr = nil, nil, nil + m.orderedSSABlockLabelPos = m.orderedSSABlockLabelPos[:0] +} + +// StartLoweringFunction implements backend.Machine StartLoweringFunction. +func (m *machine) StartLoweringFunction(maxBlockID ssa.BasicBlockID) { + m.maxSSABlockID = label(maxBlockID) + m.nextLabel = label(maxBlockID) + 1 } // SetCurrentABI implements backend.Machine SetCurrentABI. @@ -170,12 +288,11 @@ func (m *machine) DisableStackCheck() { // SetCompiler implements backend.Machine. func (m *machine) SetCompiler(ctx backend.Compiler) { m.compiler = ctx - m.regAllocFn = backend.NewRegAllocFunction[*instruction, *machine](m, ctx.SSABuilder(), ctx) + m.regAllocFn.ssaB = ctx.SSABuilder() } func (m *machine) insert(i *instruction) { - ectx := m.executableContext - ectx.PendingInstructions = append(ectx.PendingInstructions, i) + m.pendingInstructions = append(m.pendingInstructions, i) } func (m *machine) insertBrTargetLabel() label { @@ -185,18 +302,18 @@ func (m *machine) insertBrTargetLabel() label { } func (m *machine) allocateBrTarget() (nop *instruction, l label) { - ectx := m.executableContext - l = ectx.AllocateLabel() + l = m.nextLabel + m.nextLabel++ nop = m.allocateInstr() nop.asNop0WithLabel(l) - pos := ectx.GetOrAllocateLabelPosition(l) - pos.Begin, pos.End = nop, nop + pos := m.labelPositionPool.GetOrAllocate(int(l)) + pos.begin, pos.end = nop, nop return } // allocateInstr allocates an instruction. 
func (m *machine) allocateInstr() *instruction { - instr := m.executableContext.InstructionPool.Allocate() + instr := m.instrPool.Allocate() if !m.regAllocStarted { instr.addedBeforeRegAlloc = true } @@ -253,7 +370,6 @@ func (m *machine) resolveAddressingMode(arg0offset, ret0offset int64, i *instruc // resolveRelativeAddresses resolves the relative addresses before encoding. func (m *machine) resolveRelativeAddresses(ctx context.Context) { - ectx := m.executableContext for { if len(m.unresolvedAddressModes) > 0 { arg0offset, ret0offset := m.arg0OffsetFromSP(), m.ret0OffsetFromSP() @@ -267,35 +383,36 @@ func (m *machine) resolveRelativeAddresses(ctx context.Context) { var fn string var fnIndex int - var labelToSSABlockID map[label]ssa.BasicBlockID + var labelPosToLabel map[*labelPosition]label if wazevoapi.PerfMapEnabled { - fn = wazevoapi.GetCurrentFunctionName(ctx) - labelToSSABlockID = make(map[label]ssa.BasicBlockID) - for i, l := range ectx.SsaBlockIDToLabels { - labelToSSABlockID[l] = ssa.BasicBlockID(i) + labelPosToLabel = make(map[*labelPosition]label) + for i := 0; i <= m.labelPositionPool.MaxIDEncountered(); i++ { + labelPosToLabel[m.labelPositionPool.Get(i)] = label(i) } + + fn = wazevoapi.GetCurrentFunctionName(ctx) fnIndex = wazevoapi.GetCurrentFunctionIndex(ctx) } // Next, in order to determine the offsets of relative jumps, we have to calculate the size of each label. var offset int64 - for i, pos := range ectx.OrderedBlockLabels { - pos.BinaryOffset = offset + for i, pos := range m.orderedSSABlockLabelPos { + pos.binaryOffset = offset var size int64 - for cur := pos.Begin; ; cur = cur.next { + for cur := pos.begin; ; cur = cur.next { switch cur.kind { case nop0: l := cur.nop0Label() - if pos := ectx.LabelPositions[l]; pos != nil { - pos.BinaryOffset = offset + size + if pos := m.labelPositionPool.Get(int(l)); pos != nil { + pos.binaryOffset = offset + size } case condBr: if !cur.condBrOffsetResolved() { var nextLabel label - if i < len(ectx.OrderedBlockLabels)-1 { + if i < len(m.orderedSSABlockLabelPos)-1 { // Note: this is only used when the block ends with fallthrough, // therefore can be safely assumed that the next block exists when it's needed. - nextLabel = ectx.OrderedBlockLabels[i+1].L + nextLabel = ssaBlockLabel(m.orderedSSABlockLabelPos[i+1].sb) } m.condBrRelocs = append(m.condBrRelocs, condBrReloc{ cbr: cur, currentLabelPos: pos, offset: offset + size, @@ -304,21 +421,14 @@ func (m *machine) resolveRelativeAddresses(ctx context.Context) { } } size += cur.size() - if cur == pos.End { + if cur == pos.end { break } } if wazevoapi.PerfMapEnabled { if size > 0 { - l := pos.L - var labelStr string - if blkID, ok := labelToSSABlockID[l]; ok { - labelStr = fmt.Sprintf("%s::SSA_Block[%s]", l, blkID) - } else { - labelStr = l.String() - } - wazevoapi.PerfMap.AddModuleEntry(fnIndex, offset, uint64(size), fmt.Sprintf("%s:::::%s", fn, labelStr)) + wazevoapi.PerfMap.AddModuleEntry(fnIndex, offset, uint64(size), fmt.Sprintf("%s:::::%s", fn, labelPosToLabel[pos])) } } offset += size @@ -332,7 +442,7 @@ func (m *machine) resolveRelativeAddresses(ctx context.Context) { offset := reloc.offset target := cbr.condBrLabel() - offsetOfTarget := ectx.LabelPositions[target].BinaryOffset + offsetOfTarget := m.labelPositionPool.Get(int(target)).binaryOffset diff := offsetOfTarget - offset if divided := diff >> 2; divided < minSignedInt19 || divided > maxSignedInt19 { // This case the conditional branch is too huge. 
We place the trampoline instructions at the end of the current block, @@ -353,11 +463,11 @@ func (m *machine) resolveRelativeAddresses(ctx context.Context) { } var currentOffset int64 - for cur := ectx.RootInstr; cur != nil; cur = cur.next { + for cur := m.rootInstr; cur != nil; cur = cur.next { switch cur.kind { case br: target := cur.brLabel() - offsetOfTarget := ectx.LabelPositions[target].BinaryOffset + offsetOfTarget := m.labelPositionPool.Get(int(target)).binaryOffset diff := offsetOfTarget - currentOffset divided := diff >> 2 if divided < minSignedInt26 || divided > maxSignedInt26 { @@ -368,7 +478,7 @@ func (m *machine) resolveRelativeAddresses(ctx context.Context) { case condBr: if !cur.condBrOffsetResolved() { target := cur.condBrLabel() - offsetOfTarget := ectx.LabelPositions[target].BinaryOffset + offsetOfTarget := m.labelPositionPool.Get(int(target)).binaryOffset diff := offsetOfTarget - currentOffset if divided := diff >> 2; divided < minSignedInt19 || divided > maxSignedInt19 { panic("BUG: branch relocation for large conditional branch larger than 19-bit range must be handled properly") @@ -380,7 +490,7 @@ func (m *machine) resolveRelativeAddresses(ctx context.Context) { targets := m.jmpTableTargets[tableIndex] for i := range targets { l := label(targets[i]) - offsetOfTarget := ectx.LabelPositions[l].BinaryOffset + offsetOfTarget := m.labelPositionPool.Get(int(l)).binaryOffset diff := offsetOfTarget - (currentOffset + brTableSequenceOffsetTableBegin) targets[i] = uint32(diff) } @@ -401,7 +511,7 @@ const ( ) func (m *machine) insertConditionalJumpTrampoline(cbr *instruction, currentBlk *labelPosition, nextLabel label) { - cur := currentBlk.End + cur := currentBlk.end originalTarget := cbr.condBrLabel() endNext := cur.next @@ -424,32 +534,27 @@ func (m *machine) insertConditionalJumpTrampoline(cbr *instruction, currentBlk * cur = linkInstr(cur, br) // Update the end of the current block. - currentBlk.End = cur + currentBlk.end = cur linkInstr(cur, endNext) } // Format implements backend.Machine. 
func (m *machine) Format() string { - ectx := m.executableContext begins := map[*instruction]label{} - for _, pos := range ectx.LabelPositions { + for l := label(0); l < m.nextLabel; l++ { + pos := m.labelPositionPool.Get(int(l)) if pos != nil { - begins[pos.Begin] = pos.L + begins[pos.begin] = l } } - irBlocks := map[label]ssa.BasicBlockID{} - for i, l := range ectx.SsaBlockIDToLabels { - irBlocks[l] = ssa.BasicBlockID(i) - } - var lines []string - for cur := ectx.RootInstr; cur != nil; cur = cur.next { + for cur := m.rootInstr; cur != nil; cur = cur.next { if l, ok := begins[cur]; ok { var labelStr string - if blkID, ok := irBlocks[l]; ok { - labelStr = fmt.Sprintf("%s (SSA Block: %s):", l, blkID) + if l <= m.maxSSABlockID { + labelStr = fmt.Sprintf("%s (SSA Block: blk%d):", l, int(l)) } else { labelStr = fmt.Sprintf("%s:", l) } @@ -520,8 +625,7 @@ func (m *machine) addJmpTableTarget(targets ssa.Values) (index int) { m.jmpTableTargets[index] = m.jmpTableTargets[index][:0] for _, targetBlockID := range targets.View() { target := m.compiler.SSABuilder().BasicBlock(ssa.BasicBlockID(targetBlockID)) - m.jmpTableTargets[index] = append(m.jmpTableTargets[index], - uint32(m.executableContext.GetOrAllocateSSABlockLabel(target))) + m.jmpTableTargets[index] = append(m.jmpTableTargets[index], uint32(target.ID())) } return } diff --git a/internal/engine/wazevo/backend/isa/arm64/machine_pro_epi_logue.go b/internal/engine/wazevo/backend/isa/arm64/machine_pro_epi_logue.go index d9032f9218..c646a8fab0 100644 --- a/internal/engine/wazevo/backend/isa/arm64/machine_pro_epi_logue.go +++ b/internal/engine/wazevo/backend/isa/arm64/machine_pro_epi_logue.go @@ -15,9 +15,7 @@ func (m *machine) PostRegAlloc() { // setupPrologue initializes the prologue of the function. func (m *machine) setupPrologue() { - ectx := m.executableContext - - cur := ectx.RootInstr + cur := m.rootInstr prevInitInst := cur.next // @@ -196,21 +194,20 @@ func (m *machine) createFrameSizeSlot(cur *instruction, s int64) *instruction { // 1. Removes the redundant copy instruction. // 2. Inserts the epilogue. func (m *machine) postRegAlloc() { - ectx := m.executableContext - for cur := ectx.RootInstr; cur != nil; cur = cur.next { + for cur := m.rootInstr; cur != nil; cur = cur.next { switch cur.kind { case ret: m.setupEpilogueAfter(cur.prev) case loadConstBlockArg: lc := cur next := lc.next - m.executableContext.PendingInstructions = m.executableContext.PendingInstructions[:0] + m.pendingInstructions = m.pendingInstructions[:0] m.lowerLoadConstantBlockArgAfterRegAlloc(lc) - for _, instr := range m.executableContext.PendingInstructions { + for _, instr := range m.pendingInstructions { cur = linkInstr(cur, instr) } linkInstr(cur, next) - m.executableContext.PendingInstructions = m.executableContext.PendingInstructions[:0] + m.pendingInstructions = m.pendingInstructions[:0] default: // Removes the redundant copy instruction. if cur.IsCopy() && cur.rn.realReg() == cur.rd.RealReg() { @@ -432,11 +429,9 @@ func (m *machine) insertStackBoundsCheck(requiredStackSize int64, cur *instructi // CompileStackGrowCallSequence implements backend.Machine. func (m *machine) CompileStackGrowCallSequence() []byte { - ectx := m.executableContext - cur := m.allocateInstr() cur.asNop0() - ectx.RootInstr = cur + m.rootInstr = cur // Save the callee saved and argument registers. 
cur = m.saveRegistersInExecutionContext(cur, saveRequiredRegs) @@ -458,16 +453,14 @@ func (m *machine) CompileStackGrowCallSequence() []byte { ret.asRet() linkInstr(cur, ret) - m.encode(ectx.RootInstr) + m.encode(m.rootInstr) return m.compiler.Buf() } func (m *machine) addsAddOrSubStackPointer(cur *instruction, rd regalloc.VReg, diff int64, add bool) *instruction { - ectx := m.executableContext - - ectx.PendingInstructions = ectx.PendingInstructions[:0] + m.pendingInstructions = m.pendingInstructions[:0] m.insertAddOrSubStackPointer(rd, diff, add) - for _, inserted := range ectx.PendingInstructions { + for _, inserted := range m.pendingInstructions { cur = linkInstr(cur, inserted) } return cur diff --git a/internal/engine/wazevo/backend/isa/arm64/machine_pro_epi_logue_test.go b/internal/engine/wazevo/backend/isa/arm64/machine_pro_epi_logue_test.go index 919260110d..f5bcd8ebf2 100644 --- a/internal/engine/wazevo/backend/isa/arm64/machine_pro_epi_logue_test.go +++ b/internal/engine/wazevo/backend/isa/arm64/machine_pro_epi_logue_test.go @@ -102,14 +102,14 @@ func TestMachine_setupPrologue(t *testing.T) { m.currentABI = &tc.abi root := m.allocateNop() - m.executableContext.RootInstr = root + m.rootInstr = root udf := m.allocateInstr() udf.asUDF() root.next = udf udf.prev = root m.setupPrologue() - require.Equal(t, root, m.executableContext.RootInstr) + require.Equal(t, root, m.rootInstr) err := m.Encode(context.Background()) require.NoError(t, err) require.Equal(t, tc.exp, m.Format()) @@ -216,14 +216,14 @@ func TestMachine_postRegAlloc(t *testing.T) { m.currentABI = &tc.abi root := m.allocateNop() - m.executableContext.RootInstr = root + m.rootInstr = root ret := m.allocateInstr() ret.asRet() root.next = ret ret.prev = root m.postRegAlloc() - require.Equal(t, root, m.executableContext.RootInstr) + require.Equal(t, root, m.rootInstr) err := m.Encode(context.Background()) require.NoError(t, err) require.Equal(t, tc.exp, m.Format()) @@ -267,9 +267,9 @@ func TestMachine_insertStackBoundsCheck(t *testing.T) { tc := tc t.Run(tc.exp, func(t *testing.T) { _, _, m := newSetupWithMockContext() - m.executableContext.RootInstr = m.allocateInstr() - m.executableContext.RootInstr.asNop0() - m.insertStackBoundsCheck(tc.requiredStackSize, m.executableContext.RootInstr) + m.rootInstr = m.allocateInstr() + m.rootInstr.asNop0() + m.insertStackBoundsCheck(tc.requiredStackSize, m.rootInstr) err := m.Encode(context.Background()) require.NoError(t, err) require.Equal(t, tc.exp, m.Format()) diff --git a/internal/engine/wazevo/backend/isa/arm64/machine_regalloc.go b/internal/engine/wazevo/backend/isa/arm64/machine_regalloc.go index c7eb92cc20..c82abbdc0b 100644 --- a/internal/engine/wazevo/backend/isa/arm64/machine_regalloc.go +++ b/internal/engine/wazevo/backend/isa/arm64/machine_regalloc.go @@ -3,18 +3,226 @@ package arm64 // This file implements the interfaces required for register allocations. See backend.RegAllocFunctionMachine. import ( - "github.com/tetratelabs/wazero/internal/engine/wazevo/backend" "github.com/tetratelabs/wazero/internal/engine/wazevo/backend/regalloc" "github.com/tetratelabs/wazero/internal/engine/wazevo/ssa" ) -// ClobberedRegisters implements backend.RegAllocFunctionMachine. -func (m *machine) ClobberedRegisters(regs []regalloc.VReg) { - m.clobberedRegs = append(m.clobberedRegs[:0], regs...) +// regAllocFn implements regalloc.Function. 
+type regAllocFn struct { + ssaB ssa.Builder + m *machine + loopNestingForestRoots []ssa.BasicBlock + blockIter int } -// Swap implements backend.RegAllocFunctionMachine. -func (m *machine) Swap(cur *instruction, x1, x2, tmp regalloc.VReg) { +// PostOrderBlockIteratorBegin implements regalloc.Function. +func (f *regAllocFn) PostOrderBlockIteratorBegin() *labelPosition { + f.blockIter = len(f.m.orderedSSABlockLabelPos) - 1 + return f.PostOrderBlockIteratorNext() +} + +// PostOrderBlockIteratorNext implements regalloc.Function. +func (f *regAllocFn) PostOrderBlockIteratorNext() *labelPosition { + if f.blockIter < 0 { + return nil + } + b := f.m.orderedSSABlockLabelPos[f.blockIter] + f.blockIter-- + return b +} + +// ReversePostOrderBlockIteratorBegin implements regalloc.Function. +func (f *regAllocFn) ReversePostOrderBlockIteratorBegin() *labelPosition { + f.blockIter = 0 + return f.ReversePostOrderBlockIteratorNext() +} + +// ReversePostOrderBlockIteratorNext implements regalloc.Function. +func (f *regAllocFn) ReversePostOrderBlockIteratorNext() *labelPosition { + if f.blockIter >= len(f.m.orderedSSABlockLabelPos) { + return nil + } + b := f.m.orderedSSABlockLabelPos[f.blockIter] + f.blockIter++ + return b +} + +// ClobberedRegisters implements regalloc.Function. +func (f *regAllocFn) ClobberedRegisters(regs []regalloc.VReg) { + f.m.clobberedRegs = append(f.m.clobberedRegs[:0], regs...) +} + +// LoopNestingForestRoots implements regalloc.Function. +func (f *regAllocFn) LoopNestingForestRoots() int { + f.loopNestingForestRoots = f.ssaB.LoopNestingForestRoots() + return len(f.loopNestingForestRoots) +} + +// LoopNestingForestRoot implements regalloc.Function. +func (f *regAllocFn) LoopNestingForestRoot(i int) *labelPosition { + root := f.loopNestingForestRoots[i] + pos := f.m.getOrAllocateSSABlockLabelPosition(root) + return pos +} + +// LowestCommonAncestor implements regalloc.Function. +func (f *regAllocFn) LowestCommonAncestor(blk1, blk2 *labelPosition) *labelPosition { + sb := f.ssaB.LowestCommonAncestor(blk1.sb, blk2.sb) + pos := f.m.getOrAllocateSSABlockLabelPosition(sb) + return pos +} + +// Idom implements regalloc.Function. +func (f *regAllocFn) Idom(blk *labelPosition) *labelPosition { + sb := f.ssaB.Idom(blk.sb) + pos := f.m.getOrAllocateSSABlockLabelPosition(sb) + return pos +} + +// SwapBefore implements regalloc.Function. +func (f *regAllocFn) SwapBefore(x1, x2, tmp regalloc.VReg, instr *instruction) { + f.m.swap(instr.prev, x1, x2, tmp) +} + +// StoreRegisterBefore implements regalloc.Function. +func (f *regAllocFn) StoreRegisterBefore(v regalloc.VReg, instr *instruction) { + m := f.m + m.insertStoreRegisterAt(v, instr, false) +} + +// StoreRegisterAfter implements regalloc.Function. +func (f *regAllocFn) StoreRegisterAfter(v regalloc.VReg, instr *instruction) { + m := f.m + m.insertStoreRegisterAt(v, instr, true) +} + +// ReloadRegisterBefore implements regalloc.Function. +func (f *regAllocFn) ReloadRegisterBefore(v regalloc.VReg, instr *instruction) { + m := f.m + m.insertReloadRegisterAt(v, instr, false) +} + +// ReloadRegisterAfter implements regalloc.Function. +func (f *regAllocFn) ReloadRegisterAfter(v regalloc.VReg, instr *instruction) { + m := f.m + m.insertReloadRegisterAt(v, instr, true) +} + +// InsertMoveBefore implements regalloc.Function. +func (f *regAllocFn) InsertMoveBefore(dst, src regalloc.VReg, instr *instruction) { + f.m.insertMoveBefore(dst, src, instr) +} + +// LoopNestingForestChild implements regalloc.Function. 
+func (f *regAllocFn) LoopNestingForestChild(pos *labelPosition, i int) *labelPosition { + childSB := pos.sb.LoopNestingForestChildren()[i] + return f.m.getOrAllocateSSABlockLabelPosition(childSB) +} + +// Succ implements regalloc.Block. +func (f *regAllocFn) Succ(pos *labelPosition, i int) *labelPosition { + succSB := pos.sb.Succ(i) + if succSB.ReturnBlock() { + return nil + } + return f.m.getOrAllocateSSABlockLabelPosition(succSB) +} + +// Pred implements regalloc.Block. +func (f *regAllocFn) Pred(pos *labelPosition, i int) *labelPosition { + predSB := pos.sb.Pred(i) + return f.m.getOrAllocateSSABlockLabelPosition(predSB) +} + +// BlockParams implements regalloc.Function. +func (f *regAllocFn) BlockParams(pos *labelPosition, regs *[]regalloc.VReg) []regalloc.VReg { + c := f.m.compiler + *regs = (*regs)[:0] + for i := 0; i < pos.sb.Params(); i++ { + v := c.VRegOf(pos.sb.Param(i)) + *regs = append(*regs, v) + } + return *regs +} + +// ID implements regalloc.Block. +func (pos *labelPosition) ID() int32 { + return int32(pos.sb.ID()) +} + +// InstrIteratorBegin implements regalloc.Block. +func (pos *labelPosition) InstrIteratorBegin() *instruction { + ret := pos.begin + pos.cur = ret + return ret +} + +// InstrIteratorNext implements regalloc.Block. +func (pos *labelPosition) InstrIteratorNext() *instruction { + for { + if pos.cur == pos.end { + return nil + } + instr := pos.cur.next + pos.cur = instr + if instr == nil { + return nil + } else if instr.AddedBeforeRegAlloc() { + // Only concerned about the instruction added before regalloc. + return instr + } + } +} + +// InstrRevIteratorBegin implements regalloc.Block. +func (pos *labelPosition) InstrRevIteratorBegin() *instruction { + pos.cur = pos.end + return pos.cur +} + +// InstrRevIteratorNext implements regalloc.Block. +func (pos *labelPosition) InstrRevIteratorNext() *instruction { + for { + if pos.cur == pos.begin { + return nil + } + instr := pos.cur.prev + pos.cur = instr + if instr == nil { + return nil + } else if instr.AddedBeforeRegAlloc() { + // Only concerned about the instruction added before regalloc. + return instr + } + } +} + +// FirstInstr implements regalloc.Block. +func (pos *labelPosition) FirstInstr() *instruction { return pos.begin } + +// LastInstrForInsertion implements regalloc.Block. +func (pos *labelPosition) LastInstrForInsertion() *instruction { + return lastInstrForInsertion(pos.begin, pos.end) +} + +// Preds implements regalloc.Block. +func (pos *labelPosition) Preds() int { return pos.sb.Preds() } + +// Entry implements regalloc.Block. +func (pos *labelPosition) Entry() bool { return pos.sb.EntryBlock() } + +// Succs implements regalloc.Block. +func (pos *labelPosition) Succs() int { return pos.sb.Succs() } + +// LoopHeader implements regalloc.Block. +func (pos *labelPosition) LoopHeader() bool { return pos.sb.LoopHeader() } + +// LoopNestingForestChildren implements regalloc.Block. +func (pos *labelPosition) LoopNestingForestChildren() int { + return len(pos.sb.LoopNestingForestChildren()) +} + +func (m *machine) swap(cur *instruction, x1, x2, tmp regalloc.VReg) { prevNext := cur.next var mov1, mov2, mov3 *instruction if x1.RegType() == regalloc.RegTypeInt { @@ -32,12 +240,12 @@ func (m *machine) Swap(cur *instruction, x1, x2, tmp regalloc.VReg) { if !tmp.Valid() { r2 := x2.RealReg() // Temporarily spill x1 to stack. - cur = m.InsertStoreRegisterAt(x1, cur, true).prev + cur = m.insertStoreRegisterAt(x1, cur, true).prev // Then move x2 to x1. 
cur = linkInstr(cur, m.allocateInstr().asFpuMov128(x1, x2)) linkInstr(cur, prevNext) // Then reload the original value on x1 from stack to r2. - m.InsertReloadRegisterAt(x1.SetRealReg(r2), cur, true) + m.insertReloadRegisterAt(x1.SetRealReg(r2), cur, true) } else { mov1 = m.allocateInstr().asFpuMov128(tmp, x1) mov2 = m.allocateInstr().asFpuMov128(x1, x2) @@ -50,8 +258,7 @@ func (m *machine) Swap(cur *instruction, x1, x2, tmp regalloc.VReg) { } } -// InsertMoveBefore implements backend.RegAllocFunctionMachine. -func (m *machine) InsertMoveBefore(dst, src regalloc.VReg, instr *instruction) { +func (m *machine) insertMoveBefore(dst, src regalloc.VReg, instr *instruction) { typ := src.RegType() if typ != dst.RegType() { panic("BUG: src and dst must have the same type") @@ -70,13 +277,7 @@ func (m *machine) InsertMoveBefore(dst, src regalloc.VReg, instr *instruction) { linkInstr(cur, prevNext) } -// SSABlockLabel implements backend.RegAllocFunctionMachine. -func (m *machine) SSABlockLabel(id ssa.BasicBlockID) backend.Label { - return m.executableContext.SsaBlockIDToLabels[id] -} - -// InsertStoreRegisterAt implements backend.RegAllocFunctionMachine. -func (m *machine) InsertStoreRegisterAt(v regalloc.VReg, instr *instruction, after bool) *instruction { +func (m *machine) insertStoreRegisterAt(v regalloc.VReg, instr *instruction, after bool) *instruction { if !v.IsRealReg() { panic("BUG: VReg must be backed by real reg to be stored") } @@ -100,8 +301,7 @@ func (m *machine) InsertStoreRegisterAt(v regalloc.VReg, instr *instruction, aft return linkInstr(cur, prevNext) } -// InsertReloadRegisterAt implements backend.RegAllocFunctionMachine. -func (m *machine) InsertReloadRegisterAt(v regalloc.VReg, instr *instruction, after bool) *instruction { +func (m *machine) insertReloadRegisterAt(v regalloc.VReg, instr *instruction, after bool) *instruction { if !v.IsRealReg() { panic("BUG: VReg must be backed by real reg to be stored") } @@ -134,8 +334,7 @@ func (m *machine) InsertReloadRegisterAt(v regalloc.VReg, instr *instruction, af return linkInstr(cur, prevNext) } -// LastInstrForInsertion implements backend.RegAllocFunctionMachine. 
-func (m *machine) LastInstrForInsertion(begin, end *instruction) *instruction { +func lastInstrForInsertion(begin, end *instruction) *instruction { cur := end for cur.kind == nop0 { cur = cur.prev diff --git a/internal/engine/wazevo/backend/isa/arm64/machine_regalloc_test.go b/internal/engine/wazevo/backend/isa/arm64/machine_regalloc_test.go index 91c0768d5c..9398876803 100644 --- a/internal/engine/wazevo/backend/isa/arm64/machine_regalloc_test.go +++ b/internal/engine/wazevo/backend/isa/arm64/machine_regalloc_test.go @@ -16,8 +16,8 @@ func TestRegAllocFunctionImpl_ReloadRegisterAfter(t *testing.T) { i1.next = i2 i2.prev = i1 - m.InsertReloadRegisterAt(x1VReg, i1, true) - m.InsertReloadRegisterAt(v1VReg, i1, true) + m.insertReloadRegisterAt(x1VReg, i1, true) + m.insertReloadRegisterAt(v1VReg, i1, true) require.NotEqual(t, i1, i2.prev) require.NotEqual(t, i1.next, i2) @@ -30,7 +30,7 @@ func TestRegAllocFunctionImpl_ReloadRegisterAfter(t *testing.T) { require.Equal(t, iload.kind, uLoad64) require.Equal(t, fload.kind, fpuLoad64) - m.executableContext.RootInstr = i1 + m.rootInstr = i1 require.Equal(t, ` ldr d1, [sp, #0x18] ldr x1, [sp, #0x10] @@ -45,8 +45,8 @@ func TestRegAllocFunctionImpl_StoreRegisterBefore(t *testing.T) { i1.next = i2 i2.prev = i1 - m.InsertStoreRegisterAt(x1VReg, i2, false) - m.InsertStoreRegisterAt(v1VReg, i2, false) + m.insertStoreRegisterAt(x1VReg, i2, false) + m.insertStoreRegisterAt(v1VReg, i2, false) require.NotEqual(t, i1, i2.prev) require.NotEqual(t, i1.next, i2) @@ -59,7 +59,7 @@ func TestRegAllocFunctionImpl_StoreRegisterBefore(t *testing.T) { require.Equal(t, iload.kind, store64) require.Equal(t, fload.kind, fpuStore64) - m.executableContext.RootInstr = i1 + m.rootInstr = i1 require.Equal(t, ` str x1, [sp, #0x10] str d1, [sp, #0x18] @@ -125,13 +125,13 @@ func TestMachine_insertStoreRegisterAt(t *testing.T) { i2.prev = i1 if after { - m.InsertStoreRegisterAt(v1VReg, i1, after) - m.InsertStoreRegisterAt(x1VReg, i1, after) + m.insertStoreRegisterAt(v1VReg, i1, after) + m.insertStoreRegisterAt(x1VReg, i1, after) } else { - m.InsertStoreRegisterAt(x1VReg, i2, after) - m.InsertStoreRegisterAt(v1VReg, i2, after) + m.insertStoreRegisterAt(x1VReg, i2, after) + m.insertStoreRegisterAt(v1VReg, i2, after) } - m.executableContext.RootInstr = i1 + m.rootInstr = i1 require.Equal(t, tc.expected, m.Format()) }) } @@ -198,13 +198,13 @@ func TestMachine_insertReloadRegisterAt(t *testing.T) { i2.prev = i1 if after { - m.InsertReloadRegisterAt(v1VReg, i1, after) - m.InsertReloadRegisterAt(x1VReg, i1, after) + m.insertReloadRegisterAt(v1VReg, i1, after) + m.insertReloadRegisterAt(x1VReg, i1, after) } else { - m.InsertReloadRegisterAt(x1VReg, i2, after) - m.InsertReloadRegisterAt(v1VReg, i2, after) + m.insertReloadRegisterAt(x1VReg, i2, after) + m.insertReloadRegisterAt(v1VReg, i2, after) } - m.executableContext.RootInstr = i1 + m.rootInstr = i1 require.Equal(t, tc.expected, m.Format()) }) @@ -215,7 +215,7 @@ func TestMachine_insertReloadRegisterAt(t *testing.T) { func TestRegMachine_ClobberedRegisters(t *testing.T) { _, _, m := newSetupWithMockContext() - m.ClobberedRegisters([]regalloc.VReg{v19VReg, v19VReg, v19VReg, v19VReg}) + m.regAllocFn.ClobberedRegisters([]regalloc.VReg{v19VReg, v19VReg, v19VReg, v19VReg}) require.Equal(t, []regalloc.VReg{v19VReg, v19VReg, v19VReg, v19VReg}, m.clobberedRegs) } @@ -284,8 +284,8 @@ func TestMachineMachineswap(t *testing.T) { cur.next = i2 i2.prev = cur - m.Swap(cur, tc.x1, tc.x2, tc.tmp) - m.executableContext.RootInstr = cur + m.swap(cur, 
tc.x1, tc.x2, tc.tmp) + m.rootInstr = cur require.Equal(t, tc.expected, m.Format()) }) diff --git a/internal/engine/wazevo/backend/isa/arm64/machine_test.go b/internal/engine/wazevo/backend/isa/arm64/machine_test.go index 943acff83e..b3849f22f2 100644 --- a/internal/engine/wazevo/backend/isa/arm64/machine_test.go +++ b/internal/engine/wazevo/backend/isa/arm64/machine_test.go @@ -48,7 +48,7 @@ func TestMachine_resolveAddressingMode(t *testing.T) { i.asULoad(x17VReg, amode, 64) m.resolveAddressingMode(0, 0x40000001, i) - m.executableContext.RootInstr = root + m.rootInstr = root require.Equal(t, ` udf movz x27, #0x1, lsl 0 @@ -160,6 +160,7 @@ L200: t.Run(name, func(t *testing.T) { m := NewBackend().(*machine) + m.maxSSABlockID, m.nextLabel = 0, 10000000 const ( originLabel = 100 originLabelNext = 200 @@ -179,28 +180,22 @@ L200: originalEndNext := m.allocateInstr() originalEndNext.asExitSequence(x0VReg) - ectx := m.executableContext + originLabelPos := m.labelPositionPool.GetOrAllocate(originLabel) + originLabelPos.begin = cbr + originLabelPos.end = linkInstr(cbr, end) + originNextLabelPos := m.labelPositionPool.GetOrAllocate(originLabelNext) + originNextLabelPos.begin = originalEndNext + linkInstr(originLabelPos.end, originalEndNext) - originLabelPos := ectx.GetOrAllocateLabelPosition(originLabel) - originLabelPos.Begin = cbr - originLabelPos.End = linkInstr(cbr, end) - originNextLabelPos := ectx.GetOrAllocateLabelPosition(originLabelNext) - originNextLabelPos.Begin = originalEndNext - linkInstr(originLabelPos.End, originalEndNext) - - ectx.LabelPositions[originLabel] = originLabelPos - ectx.LabelPositions[originLabelNext] = originNextLabelPos - - ectx.RootInstr = cbr + m.rootInstr = cbr require.Equal(t, tc.expBefore, m.Format()) - ectx.NextLabel = 9999999 m.insertConditionalJumpTrampoline(cbr, originLabelPos, originLabelNext) require.Equal(t, tc.expAfter, m.Format()) // The original label position should be updated to the unconditional jump to the original target destination. - require.Equal(t, "b L12345", originLabelPos.End.String()) + require.Equal(t, "b L12345", originLabelPos.end.String()) }) } } diff --git a/internal/engine/wazevo/backend/isa/arm64/util_test.go b/internal/engine/wazevo/backend/isa/arm64/util_test.go index 1ab27aa6f9..bac898854e 100644 --- a/internal/engine/wazevo/backend/isa/arm64/util_test.go +++ b/internal/engine/wazevo/backend/isa/arm64/util_test.go @@ -10,13 +10,13 @@ import ( ) func getPendingInstr(m *machine) *instruction { - return m.executableContext.PendingInstructions[0] + return m.pendingInstructions[0] } func formatEmittedInstructionsInCurrentBlock(m *machine) string { - m.executableContext.FlushPendingInstructions() + m.FlushPendingInstructions() var strs []string - for cur := m.executableContext.PerBlockHead; cur != nil; cur = cur.next { + for cur := m.perBlockHead; cur != nil; cur = cur.next { strs = append(strs, cur.String()) } return strings.Join(strs, "\n") diff --git a/internal/engine/wazevo/backend/machine.go b/internal/engine/wazevo/backend/machine.go index 54ce89e468..9044a9e4bc 100644 --- a/internal/engine/wazevo/backend/machine.go +++ b/internal/engine/wazevo/backend/machine.go @@ -11,7 +11,24 @@ import ( type ( // Machine is a backend for a specific ISA machine. Machine interface { - ExecutableContext() ExecutableContext + // StartLoweringFunction is called when the compilation of the given function is started. + // The maxBlockID is the maximum ssa.BasicBlockID in the function. 
+ StartLoweringFunction(maxBlockID ssa.BasicBlockID) + + // LinkAdjacentBlocks is called after finished lowering all blocks in order to create one single instruction list. + LinkAdjacentBlocks(prev, next ssa.BasicBlock) + + // StartBlock is called when the compilation of the given block is started. + // The order of this being called is the reverse post order of the ssa.BasicBlock(s) as we iterate with + // ssa.Builder BlockIteratorReversePostOrderBegin and BlockIteratorReversePostOrderEnd. + StartBlock(ssa.BasicBlock) + + // EndBlock is called when the compilation of the current block is finished. + EndBlock() + + // FlushPendingInstructions flushes the pending instructions to the buffer. + // This will be called after the lowering of each SSA Instruction. + FlushPendingInstructions() // DisableStackCheck disables the stack check for the current compilation for debugging/testing. DisableStackCheck() diff --git a/internal/engine/wazevo/backend/machine_test.go b/internal/engine/wazevo/backend/machine_test.go index 2f378d2a06..aaf14d0848 100644 --- a/internal/engine/wazevo/backend/machine_test.go +++ b/internal/engine/wazevo/backend/machine_test.go @@ -11,7 +11,6 @@ import ( // mockMachine implements Machine for testing. type mockMachine struct { argResultInts, argResultFloats []regalloc.RealReg - startLoweringFunction func(id ssa.BasicBlockID) startBlock func(block ssa.BasicBlock) lowerSingleBranch func(b *ssa.Instruction) lowerConditionalBranch func(b *ssa.Instruction) @@ -25,6 +24,8 @@ type mockMachine struct { linkAdjacentBlocks func(prev, next ssa.BasicBlock) } +func (m mockMachine) StartLoweringFunction(maxBlockID ssa.BasicBlockID) { panic("implement me") } + func (m mockMachine) CallTrampolineIslandInfo(_ int) (_, _ int, _ error) { panic("implement me") } func (m mockMachine) ArgsResultsRegs() (argResultInts, argResultFloats []regalloc.RealReg) { @@ -37,8 +38,6 @@ func (m mockMachine) LowerParams(params []ssa.Value) { panic("implement me") } func (m mockMachine) LowerReturns(returns []ssa.Value) { panic("implement me") } -func (m mockMachine) ExecutableContext() ExecutableContext { panic("implement me") } - func (m mockMachine) CompileEntryPreamble(signature *ssa.Signature) []byte { panic("TODO") } @@ -73,11 +72,6 @@ func (m mockMachine) SetCurrentABI(*FunctionABI) {} // SetCompiler implements Machine.SetCompiler. func (m mockMachine) SetCompiler(Compiler) {} -// StartLoweringFunction implements Machine.StartLoweringFunction. -func (m mockMachine) StartLoweringFunction(id ssa.BasicBlockID) { - m.startLoweringFunction(id) -} - // StartBlock implements Machine.StartBlock. func (m mockMachine) StartBlock(block ssa.BasicBlock) { m.startBlock(block) diff --git a/internal/engine/wazevo/backend/regalloc.go b/internal/engine/wazevo/backend/regalloc.go deleted file mode 100644 index 6553707860..0000000000 --- a/internal/engine/wazevo/backend/regalloc.go +++ /dev/null @@ -1,321 +0,0 @@ -package backend - -import ( - "github.com/tetratelabs/wazero/internal/engine/wazevo/backend/regalloc" - "github.com/tetratelabs/wazero/internal/engine/wazevo/ssa" -) - -// RegAllocFunctionMachine is the interface for the machine specific logic that will be used in RegAllocFunction. -type RegAllocFunctionMachine[I regalloc.InstrConstraint] interface { - // InsertMoveBefore inserts the move instruction from src to dst before the given instruction. - InsertMoveBefore(dst, src regalloc.VReg, instr I) - // InsertStoreRegisterAt inserts the instruction(s) to store the given virtual register at the given instruction. 
- // If after is true, the instruction(s) will be inserted after the given instruction, otherwise before. - InsertStoreRegisterAt(v regalloc.VReg, instr I, after bool) I - // InsertReloadRegisterAt inserts the instruction(s) to reload the given virtual register at the given instruction. - // If after is true, the instruction(s) will be inserted after the given instruction, otherwise before. - InsertReloadRegisterAt(v regalloc.VReg, instr I, after bool) I - // ClobberedRegisters is called when the register allocation is done and the clobbered registers are known. - ClobberedRegisters(regs []regalloc.VReg) - // Swap swaps the two virtual registers after the given instruction. - Swap(cur I, x1, x2, tmp regalloc.VReg) - // LastInstrForInsertion implements LastInstrForInsertion of regalloc.Function. See its comment for details. - LastInstrForInsertion(begin, end I) I - // SSABlockLabel returns the label of the given ssa.BasicBlockID. - SSABlockLabel(id ssa.BasicBlockID) Label -} - -type ( - // RegAllocFunction implements regalloc.Function. - RegAllocFunction[I regalloc.InstrConstraint, m RegAllocFunctionMachine[I]] struct { - m m - ssb ssa.Builder - c Compiler - // iter is the iterator for reversePostOrderBlocks - iter int - reversePostOrderBlocks []RegAllocBlock[I, m] - // labelToRegAllocBlockIndex maps label to the index of reversePostOrderBlocks. - labelToRegAllocBlockIndex [] /* Label to */ int - loopNestingForestRoots []ssa.BasicBlock - } - - // RegAllocBlock implements regalloc.Block. - RegAllocBlock[I regalloc.InstrConstraint, m RegAllocFunctionMachine[I]] struct { - // f is the function this instruction belongs to. Used to reuse the regAllocFunctionImpl.predsSlice slice for Defs() and Uses(). - f *RegAllocFunction[I, m] - sb ssa.BasicBlock - l Label - begin, end I - loopNestingForestChildren []ssa.BasicBlock - cur I - id int - cachedLastInstrForInsertion I - } -) - -// NewRegAllocFunction returns a new RegAllocFunction. -func NewRegAllocFunction[I regalloc.InstrConstraint, M RegAllocFunctionMachine[I]](m M, ssb ssa.Builder, c Compiler) *RegAllocFunction[I, M] { - return &RegAllocFunction[I, M]{ - m: m, - ssb: ssb, - c: c, - } -} - -// AddBlock adds a new block to the function. -func (f *RegAllocFunction[I, M]) AddBlock(sb ssa.BasicBlock, l Label, begin, end I) { - i := len(f.reversePostOrderBlocks) - f.reversePostOrderBlocks = append(f.reversePostOrderBlocks, RegAllocBlock[I, M]{ - f: f, - sb: sb, - l: l, - begin: begin, - end: end, - id: int(sb.ID()), - }) - if len(f.labelToRegAllocBlockIndex) <= int(l) { - f.labelToRegAllocBlockIndex = append(f.labelToRegAllocBlockIndex, make([]int, int(l)-len(f.labelToRegAllocBlockIndex)+1)...) - } - f.labelToRegAllocBlockIndex[l] = i -} - -// Reset resets the function for the next compilation. -func (f *RegAllocFunction[I, M]) Reset() { - f.reversePostOrderBlocks = f.reversePostOrderBlocks[:0] - f.iter = 0 -} - -// StoreRegisterAfter implements regalloc.Function StoreRegisterAfter. -func (f *RegAllocFunction[I, M]) StoreRegisterAfter(v regalloc.VReg, instr regalloc.Instr) { - m := f.m - m.InsertStoreRegisterAt(v, instr.(I), true) -} - -// ReloadRegisterBefore implements regalloc.Function ReloadRegisterBefore. -func (f *RegAllocFunction[I, M]) ReloadRegisterBefore(v regalloc.VReg, instr regalloc.Instr) { - m := f.m - m.InsertReloadRegisterAt(v, instr.(I), false) -} - -// ReloadRegisterAfter implements regalloc.Function ReloadRegisterAfter. 
-func (f *RegAllocFunction[I, M]) ReloadRegisterAfter(v regalloc.VReg, instr regalloc.Instr) { - m := f.m - m.InsertReloadRegisterAt(v, instr.(I), true) -} - -// StoreRegisterBefore implements regalloc.Function StoreRegisterBefore. -func (f *RegAllocFunction[I, M]) StoreRegisterBefore(v regalloc.VReg, instr regalloc.Instr) { - m := f.m - m.InsertStoreRegisterAt(v, instr.(I), false) -} - -// ClobberedRegisters implements regalloc.Function ClobberedRegisters. -func (f *RegAllocFunction[I, M]) ClobberedRegisters(regs []regalloc.VReg) { - f.m.ClobberedRegisters(regs) -} - -// SwapBefore implements regalloc.Function SwapBefore. -func (f *RegAllocFunction[I, M]) SwapBefore(x1, x2, tmp regalloc.VReg, instr regalloc.Instr) { - f.m.Swap(instr.Prev().(I), x1, x2, tmp) -} - -// PostOrderBlockIteratorBegin implements regalloc.Function PostOrderBlockIteratorBegin. -func (f *RegAllocFunction[I, M]) PostOrderBlockIteratorBegin() regalloc.Block { - f.iter = len(f.reversePostOrderBlocks) - 1 - return f.PostOrderBlockIteratorNext() -} - -// PostOrderBlockIteratorNext implements regalloc.Function PostOrderBlockIteratorNext. -func (f *RegAllocFunction[I, M]) PostOrderBlockIteratorNext() regalloc.Block { - if f.iter < 0 { - return nil - } - b := &f.reversePostOrderBlocks[f.iter] - f.iter-- - return b -} - -// ReversePostOrderBlockIteratorBegin implements regalloc.Function ReversePostOrderBlockIteratorBegin. -func (f *RegAllocFunction[I, M]) ReversePostOrderBlockIteratorBegin() regalloc.Block { - f.iter = 0 - return f.ReversePostOrderBlockIteratorNext() -} - -// ReversePostOrderBlockIteratorNext implements regalloc.Function ReversePostOrderBlockIteratorNext. -func (f *RegAllocFunction[I, M]) ReversePostOrderBlockIteratorNext() regalloc.Block { - if f.iter >= len(f.reversePostOrderBlocks) { - return nil - } - b := &f.reversePostOrderBlocks[f.iter] - f.iter++ - return b -} - -// LoopNestingForestRoots implements regalloc.Function LoopNestingForestRoots. -func (f *RegAllocFunction[I, M]) LoopNestingForestRoots() int { - f.loopNestingForestRoots = f.ssb.LoopNestingForestRoots() - return len(f.loopNestingForestRoots) -} - -// LoopNestingForestRoot implements regalloc.Function LoopNestingForestRoot. -func (f *RegAllocFunction[I, M]) LoopNestingForestRoot(i int) regalloc.Block { - blk := f.loopNestingForestRoots[i] - l := f.m.SSABlockLabel(blk.ID()) - index := f.labelToRegAllocBlockIndex[l] - return &f.reversePostOrderBlocks[index] -} - -// InsertMoveBefore implements regalloc.Function InsertMoveBefore. -func (f *RegAllocFunction[I, M]) InsertMoveBefore(dst, src regalloc.VReg, instr regalloc.Instr) { - f.m.InsertMoveBefore(dst, src, instr.(I)) -} - -// LowestCommonAncestor implements regalloc.Function LowestCommonAncestor. -func (f *RegAllocFunction[I, M]) LowestCommonAncestor(blk1, blk2 regalloc.Block) regalloc.Block { - ret := f.ssb.LowestCommonAncestor(blk1.(*RegAllocBlock[I, M]).sb, blk2.(*RegAllocBlock[I, M]).sb) - l := f.m.SSABlockLabel(ret.ID()) - index := f.labelToRegAllocBlockIndex[l] - return &f.reversePostOrderBlocks[index] -} - -// Idom implements regalloc.Function Idom. -func (f *RegAllocFunction[I, M]) Idom(blk regalloc.Block) regalloc.Block { - builder := f.ssb - idom := builder.Idom(blk.(*RegAllocBlock[I, M]).sb) - if idom == nil { - panic("BUG: idom must not be nil") - } - l := f.m.SSABlockLabel(idom.ID()) - index := f.labelToRegAllocBlockIndex[l] - return &f.reversePostOrderBlocks[index] -} - -// ID implements regalloc.Block. 
-func (r *RegAllocBlock[I, m]) ID() int32 { return int32(r.id) } - -// BlockParams implements regalloc.Block. -func (r *RegAllocBlock[I, m]) BlockParams(regs *[]regalloc.VReg) []regalloc.VReg { - c := r.f.c - *regs = (*regs)[:0] - for i := 0; i < r.sb.Params(); i++ { - v := c.VRegOf(r.sb.Param(i)) - *regs = append(*regs, v) - } - return *regs -} - -// InstrIteratorBegin implements regalloc.Block. -func (r *RegAllocBlock[I, m]) InstrIteratorBegin() regalloc.Instr { - r.cur = r.begin - return r.cur -} - -// InstrIteratorNext implements regalloc.Block. -func (r *RegAllocBlock[I, m]) InstrIteratorNext() regalloc.Instr { - for { - if r.cur == r.end { - return nil - } - instr := r.cur.Next() - r.cur = instr.(I) - if instr == nil { - return nil - } else if instr.AddedBeforeRegAlloc() { - // Only concerned about the instruction added before regalloc. - return instr - } - } -} - -// InstrRevIteratorBegin implements regalloc.Block. -func (r *RegAllocBlock[I, m]) InstrRevIteratorBegin() regalloc.Instr { - r.cur = r.end - return r.cur -} - -// InstrRevIteratorNext implements regalloc.Block. -func (r *RegAllocBlock[I, m]) InstrRevIteratorNext() regalloc.Instr { - for { - if r.cur == r.begin { - return nil - } - instr := r.cur.Prev() - r.cur = instr.(I) - if instr == nil { - return nil - } else if instr.AddedBeforeRegAlloc() { - // Only concerned about the instruction added before regalloc. - return instr - } - } -} - -// FirstInstr implements regalloc.Block. -func (r *RegAllocBlock[I, m]) FirstInstr() regalloc.Instr { - return r.begin -} - -// EndInstr implements regalloc.Block. -func (r *RegAllocBlock[I, m]) EndInstr() regalloc.Instr { - return r.end -} - -// LastInstrForInsertion implements regalloc.Block. -func (r *RegAllocBlock[I, m]) LastInstrForInsertion() regalloc.Instr { - var nil I - if r.cachedLastInstrForInsertion == nil { - r.cachedLastInstrForInsertion = r.f.m.LastInstrForInsertion(r.begin, r.end) - } - return r.cachedLastInstrForInsertion -} - -// Preds implements regalloc.Block. -func (r *RegAllocBlock[I, m]) Preds() int { return r.sb.Preds() } - -// Pred implements regalloc.Block. -func (r *RegAllocBlock[I, m]) Pred(i int) regalloc.Block { - sb := r.sb - pred := sb.Pred(i) - l := r.f.m.SSABlockLabel(pred.ID()) - index := r.f.labelToRegAllocBlockIndex[l] - return &r.f.reversePostOrderBlocks[index] -} - -// Entry implements regalloc.Block. -func (r *RegAllocBlock[I, m]) Entry() bool { return r.sb.EntryBlock() } - -// Succs implements regalloc.Block. -func (r *RegAllocBlock[I, m]) Succs() int { - return r.sb.Succs() -} - -// Succ implements regalloc.Block. -func (r *RegAllocBlock[I, m]) Succ(i int) regalloc.Block { - sb := r.sb - succ := sb.Succ(i) - if succ.ReturnBlock() { - return nil - } - l := r.f.m.SSABlockLabel(succ.ID()) - index := r.f.labelToRegAllocBlockIndex[l] - return &r.f.reversePostOrderBlocks[index] -} - -// LoopHeader implements regalloc.Block. -func (r *RegAllocBlock[I, m]) LoopHeader() bool { - return r.sb.LoopHeader() -} - -// LoopNestingForestChildren implements regalloc.Block. -func (r *RegAllocBlock[I, m]) LoopNestingForestChildren() int { - r.loopNestingForestChildren = r.sb.LoopNestingForestChildren() - return len(r.loopNestingForestChildren) -} - -// LoopNestingForestChild implements regalloc.Block. 
-func (r *RegAllocBlock[I, m]) LoopNestingForestChild(i int) regalloc.Block { - blk := r.loopNestingForestChildren[i] - l := r.f.m.SSABlockLabel(blk.ID()) - index := r.f.labelToRegAllocBlockIndex[l] - return &r.f.reversePostOrderBlocks[index] -} diff --git a/internal/engine/wazevo/backend/regalloc/api.go b/internal/engine/wazevo/backend/regalloc/api.go index 23157b4782..5d961614bb 100644 --- a/internal/engine/wazevo/backend/regalloc/api.go +++ b/internal/engine/wazevo/backend/regalloc/api.go @@ -4,104 +4,100 @@ import "fmt" // These interfaces are implemented by ISA-specific backends to abstract away the details, and allow the register // allocators to work on any ISA. -// -// TODO: the interfaces are not stabilized yet, especially x64 will need some changes. E.g. x64 has an addressing mode -// where index can be in memory. That kind of info will be useful to reduce the register pressure, and should be leveraged -// by the register allocators, like https://docs.rs/regalloc2/latest/regalloc2/enum.OperandConstraint.html type ( // Function is the top-level interface to do register allocation, which corresponds to a CFG containing // Blocks(s). - Function interface { + // + // I is the type of the instruction, and B is the type of the basic block. + Function[I Instr, B Block[I]] interface { // PostOrderBlockIteratorBegin returns the first block in the post-order traversal of the CFG. // In other words, the last blocks in the CFG will be returned first. - PostOrderBlockIteratorBegin() Block + PostOrderBlockIteratorBegin() B // PostOrderBlockIteratorNext returns the next block in the post-order traversal of the CFG. - PostOrderBlockIteratorNext() Block + PostOrderBlockIteratorNext() B // ReversePostOrderBlockIteratorBegin returns the first block in the reverse post-order traversal of the CFG. // In other words, the first blocks in the CFG will be returned first. - ReversePostOrderBlockIteratorBegin() Block + ReversePostOrderBlockIteratorBegin() B // ReversePostOrderBlockIteratorNext returns the next block in the reverse post-order traversal of the CFG. - ReversePostOrderBlockIteratorNext() Block + ReversePostOrderBlockIteratorNext() B // ClobberedRegisters tell the clobbered registers by this function. ClobberedRegisters([]VReg) // LoopNestingForestRoots returns the number of roots of the loop nesting forest in a function. LoopNestingForestRoots() int // LoopNestingForestRoot returns the i-th root of the loop nesting forest in a function. - LoopNestingForestRoot(i int) Block + LoopNestingForestRoot(i int) B // LowestCommonAncestor returns the lowest common ancestor of two blocks in the dominator tree. - LowestCommonAncestor(blk1, blk2 Block) Block + LowestCommonAncestor(blk1, blk2 B) B // Idom returns the immediate dominator of the given block. - Idom(blk Block) Block + Idom(blk B) B + + // LoopNestingForestChild returns the i-th child of the block in the loop nesting forest. + LoopNestingForestChild(b B, i int) B + // Pred returns the i-th predecessor of the block in the CFG. + Pred(b B, i int) B + // Succ returns the i-th successor of the block in the CFG. + Succ(b B, i int) B + // BlockParams returns the virtual registers used as the parameters of this block. + BlockParams(B, *[]VReg) []VReg // Followings are for rewriting the function. - // SwapAtEndOfBlock swaps the two virtual registers at the end of the given block. - SwapBefore(x1, x2, tmp VReg, instr Instr) + // SwapBefore swaps the two virtual registers at the end of the given block. 
+ SwapBefore(x1, x2, tmp VReg, instr I) // StoreRegisterBefore inserts store instruction(s) before the given instruction for the given virtual register. - StoreRegisterBefore(v VReg, instr Instr) + StoreRegisterBefore(v VReg, instr I) // StoreRegisterAfter inserts store instruction(s) after the given instruction for the given virtual register. - StoreRegisterAfter(v VReg, instr Instr) + StoreRegisterAfter(v VReg, instr I) // ReloadRegisterBefore inserts reload instruction(s) before the given instruction for the given virtual register. - ReloadRegisterBefore(v VReg, instr Instr) + ReloadRegisterBefore(v VReg, instr I) // ReloadRegisterAfter inserts reload instruction(s) after the given instruction for the given virtual register. - ReloadRegisterAfter(v VReg, instr Instr) + ReloadRegisterAfter(v VReg, instr I) // InsertMoveBefore inserts move instruction(s) before the given instruction for the given virtual registers. - InsertMoveBefore(dst, src VReg, instr Instr) + InsertMoveBefore(dst, src VReg, instr I) } // Block is a basic block in the CFG of a function, and it consists of multiple instructions, and predecessor Block(s). - Block interface { + // Right now, this corresponds to a ssa.BasicBlock lowered to the machine level. + Block[I Instr] interface { + comparable // ID returns the unique identifier of this block which is ordered in the reverse post-order traversal of the CFG. ID() int32 - // BlockParams returns the virtual registers used as the parameters of this block. - BlockParams(*[]VReg) []VReg // InstrIteratorBegin returns the first instruction in this block. Instructions added after lowering must be skipped. // Note: multiple Instr(s) will not be held at the same time, so it's safe to use the same impl for the return Instr. - InstrIteratorBegin() Instr + InstrIteratorBegin() I // InstrIteratorNext returns the next instruction in this block. Instructions added after lowering must be skipped. // Note: multiple Instr(s) will not be held at the same time, so it's safe to use the same impl for the return Instr. - InstrIteratorNext() Instr + InstrIteratorNext() I // InstrRevIteratorBegin is the same as InstrIteratorBegin, but in the reverse order. - InstrRevIteratorBegin() Instr + InstrRevIteratorBegin() I // InstrRevIteratorNext is the same as InstrIteratorNext, but in the reverse order. - InstrRevIteratorNext() Instr + InstrRevIteratorNext() I // FirstInstr returns the fist instruction in this block where instructions will be inserted after it. - FirstInstr() Instr - // EndInstr returns the end instruction in this block. - EndInstr() Instr + FirstInstr() I // LastInstrForInsertion returns the last instruction in this block where instructions will be inserted before it. // Such insertions only happen when we need to insert spill/reload instructions to adjust the merge edges. // At the time of register allocation, all the critical edges are already split, so there is no need // to worry about the case where branching instruction has multiple successors. // Therefore, usually, it is the nop instruction, but if the block ends with an unconditional branching, then it returns // the unconditional branch, not the nop. In other words it is either nop or unconditional branch. - LastInstrForInsertion() Instr + LastInstrForInsertion() I // Preds returns the number of predecessors of this block in the CFG. Preds() int - // Pred returns the i-th predecessor of this block in the CFG. - Pred(i int) Block // Entry returns true if the block is for the entry block. 
Entry() bool // Succs returns the number of successors of this block in the CFG. Succs() int - // Succ returns the i-th successor of this block in the CFG. - Succ(i int) Block // LoopHeader returns true if this block is a loop header. LoopHeader() bool // LoopNestingForestChildren returns the number of children of this block in the loop nesting forest. LoopNestingForestChildren() int - // LoopNestingForestChild returns the i-th child of this block in the loop nesting forest. - LoopNestingForestChild(i int) Block } // Instr is an instruction in a block, abstracting away the underlying ISA. Instr interface { + comparable fmt.Stringer - // Next returns the next instruction in the same block. - Next() Instr - // Prev returns the previous instruction in the same block. - Prev() Instr // Defs returns the virtual registers defined by this instruction. Defs(*[]VReg) []VReg // Uses returns the virtual registers used by this instruction. @@ -127,10 +123,4 @@ type ( // AddedBeforeRegAlloc returns true if this instruction is added before register allocation. AddedBeforeRegAlloc() bool } - - // InstrConstraint is an interface for arch-specific instruction constraints. - InstrConstraint interface { - comparable - Instr - } ) diff --git a/internal/engine/wazevo/backend/regalloc/api_test.go b/internal/engine/wazevo/backend/regalloc/api_test.go index 12ad1bf6d1..d3d9b662f0 100644 --- a/internal/engine/wazevo/backend/regalloc/api_test.go +++ b/internal/engine/wazevo/backend/regalloc/api_test.go @@ -18,7 +18,7 @@ type ( storeOrReloadInfo struct { reload bool v VReg - instr Instr + instr *mockInstr } // mockBlock implements Block. @@ -26,7 +26,7 @@ type ( id int32 instructions []*mockInstr preds, succs []*mockBlock - _preds, _succs []Block + _preds, _succs []*mockBlock iter int _entry bool _loop bool @@ -42,13 +42,13 @@ type ( } ) -func (m *mockFunction) LowestCommonAncestor(blk1, blk2 Block) Block { panic("TODO") } +func (m *mockFunction) LowestCommonAncestor(blk1, blk2 *mockBlock) *mockBlock { panic("TODO") } -func (m *mockFunction) Idom(blk Block) Block { panic("TODO") } +func (m *mockFunction) Idom(blk *mockBlock) *mockBlock { panic("TODO") } -func (m *mockFunction) SwapBefore(x1, x2, tmp VReg, instr Instr) { panic("TODO") } +func (m *mockFunction) SwapBefore(x1, x2, tmp VReg, instr *mockInstr) { panic("TODO") } -func (m *mockFunction) InsertMoveBefore(dst, src VReg, instr Instr) { panic("TODO") } +func (m *mockFunction) InsertMoveBefore(dst, src VReg, instr *mockInstr) { panic("TODO") } func newMockFunction(blocks ...*mockBlock) *mockFunction { return &mockFunction{blocks: blocks} @@ -141,22 +141,22 @@ func (m *mockInstr) asIndirectCall() *mockInstr { //nolint:unused } // StoreRegisterAfter implements Function.StoreRegisterAfter. -func (m *mockFunction) StoreRegisterAfter(v VReg, instr Instr) { +func (m *mockFunction) StoreRegisterAfter(v VReg, instr *mockInstr) { m.afters = append(m.afters, storeOrReloadInfo{false, v, instr}) } // ReloadRegisterBefore implements Function.ReloadRegisterBefore. -func (m *mockFunction) ReloadRegisterBefore(v VReg, instr Instr) { +func (m *mockFunction) ReloadRegisterBefore(v VReg, instr *mockInstr) { m.befores = append(m.befores, storeOrReloadInfo{true, v, instr}) } // StoreRegisterBefore implements Function.StoreRegisterBefore. 
-func (m *mockFunction) StoreRegisterBefore(v VReg, instr Instr) { +func (m *mockFunction) StoreRegisterBefore(v VReg, instr *mockInstr) { m.befores = append(m.befores, storeOrReloadInfo{false, v, instr}) } // ReloadRegisterAfter implements Function.ReloadRegisterAfter. -func (m *mockFunction) ReloadRegisterAfter(v VReg, instr Instr) { +func (m *mockFunction) ReloadRegisterAfter(v VReg, instr *mockInstr) { m.afters = append(m.afters, storeOrReloadInfo{true, v, instr}) } @@ -170,14 +170,14 @@ func (m *mockFunction) ClobberedRegisters(regs []VReg) { func (m *mockFunction) Done() {} // PostOrderBlockIteratorBegin implements Block. -func (m *mockFunction) PostOrderBlockIteratorBegin() Block { +func (m *mockFunction) PostOrderBlockIteratorBegin() *mockBlock { m.iter = 1 l := len(m.blocks) return m.blocks[l-1] } // PostOrderBlockIteratorNext implements Block. -func (m *mockFunction) PostOrderBlockIteratorNext() Block { +func (m *mockFunction) PostOrderBlockIteratorNext() *mockBlock { if m.iter == len(m.blocks) { return nil } @@ -188,13 +188,13 @@ func (m *mockFunction) PostOrderBlockIteratorNext() Block { } // ReversePostOrderBlockIteratorBegin implements Block. -func (m *mockFunction) ReversePostOrderBlockIteratorBegin() Block { +func (m *mockFunction) ReversePostOrderBlockIteratorBegin() *mockBlock { m.iter = 1 return m.blocks[0] } // ReversePostOrderBlockIteratorNext implements Block. -func (m *mockFunction) ReversePostOrderBlockIteratorNext() Block { +func (m *mockFunction) ReversePostOrderBlockIteratorNext() *mockBlock { if m.iter == len(m.blocks) { return nil } @@ -209,7 +209,7 @@ func (m *mockBlock) ID() int32 { } // InstrIteratorBegin implements Block. -func (m *mockBlock) InstrIteratorBegin() Instr { +func (m *mockBlock) InstrIteratorBegin() *mockInstr { if len(m.instructions) == 0 { return nil } @@ -218,7 +218,7 @@ func (m *mockBlock) InstrIteratorBegin() Instr { } // InstrIteratorNext implements Block. -func (m *mockBlock) InstrIteratorNext() Instr { +func (m *mockBlock) InstrIteratorNext() *mockInstr { if m.iter == len(m.instructions) { return nil } @@ -228,7 +228,7 @@ func (m *mockBlock) InstrIteratorNext() Instr { } // InstrRevIteratorBegin implements Block. -func (m *mockBlock) InstrRevIteratorBegin() Instr { +func (m *mockBlock) InstrRevIteratorBegin() *mockInstr { if len(m.instructions) == 0 { return nil } @@ -237,7 +237,7 @@ func (m *mockBlock) InstrRevIteratorBegin() Instr { } // InstrRevIteratorNext implements Block. -func (m *mockBlock) InstrRevIteratorNext() Instr { +func (m *mockBlock) InstrRevIteratorNext() *mockInstr { m.iter-- if m.iter < 0 { return nil @@ -251,8 +251,8 @@ func (m *mockBlock) Preds() int { } // BlockParams implements Block. -func (m *mockBlock) BlockParams(ret *[]VReg) []VReg { - *ret = append((*ret)[:0], m.blockParams...) +func (m *mockFunction) BlockParams(blk *mockBlock, ret *[]VReg) []VReg { + *ret = append((*ret)[:0], blk.blockParams...) return *ret } @@ -260,9 +260,6 @@ func (m *mockBlock) blockParam(v VReg) { m.blockParams = append(m.blockParams, v) } -// Pred implements Instr. -func (m *mockBlock) Pred(i int) Block { return m._preds[i] } - // Defs implements Instr. func (m *mockInstr) Defs(ret *[]VReg) []VReg { *ret = append((*ret)[:0], m.defs...) @@ -291,10 +288,10 @@ func (m *mockInstr) IsIndirectCall() bool { return m.isIndirect } func (m *mockInstr) IsReturn() bool { return false } // Next implements Instr. -func (m *mockInstr) Next() Instr { return m.next } +func (m *mockInstr) Next() *mockInstr { return m.next } // Prev implements Instr. 
-func (m *mockInstr) Prev() Instr { return m.prev } +func (m *mockInstr) Prev() *mockInstr { return m.prev } // Entry implements Entry. func (m *mockBlock) Entry() bool { return m._entry } @@ -312,20 +309,30 @@ func (m *mockInstr) AssignDef(reg VReg) { m.defs = []VReg{reg} } -var ( - _ Function = (*mockFunction)(nil) - _ Block = (*mockBlock)(nil) - _ Instr = (*mockInstr)(nil) -) +var _ Function[*mockInstr, *mockBlock] = (*mockFunction)(nil) func (m *mockFunction) LoopNestingForestRoots() int { return len(m.lnfRoots) } -func (m *mockFunction) LoopNestingForestRoot(i int) Block { +// LoopNestingForestRoot implements Function. +func (m *mockFunction) LoopNestingForestRoot(i int) *mockBlock { return m.lnfRoots[i] } +// LoopNestingForestChild implements Function. +func (m *mockFunction) LoopNestingForestChild(b *mockBlock, i int) *mockBlock { + return b.lnfChildren[i] +} + +// Succ implements Function. +func (m *mockFunction) Succ(b *mockBlock, i int) *mockBlock { + return b.succs[i] +} + +// Pred implements Function. +func (m *mockFunction) Pred(b *mockBlock, i int) *mockBlock { return b._preds[i] } + func (m *mockBlock) LoopHeader() bool { return m._loop } @@ -334,39 +341,17 @@ func (m *mockBlock) Succs() int { return len(m.succs) } -func (m *mockBlock) Succ(i int) Block { - return m.succs[i] -} - func (m *mockBlock) LoopNestingForestChildren() int { return len(m.lnfChildren) } -func (m *mockBlock) LoopNestingForestChild(i int) Block { - return m.lnfChildren[i] -} - -func (m *mockBlock) BeginInstr() Instr { - if len(m.instructions) == 0 { - return nil - } - return m.instructions[0] -} - -func (m *mockBlock) EndInstr() Instr { - if len(m.instructions) == 0 { - return nil - } - return m.instructions[len(m.instructions)-1] -} - -func (m *mockBlock) LastInstrForInsertion() Instr { +func (m *mockBlock) LastInstrForInsertion() *mockInstr { if len(m.instructions) == 0 { return nil } return m.instructions[len(m.instructions)-1] } -func (m *mockBlock) FirstInstr() Instr { +func (m *mockBlock) FirstInstr() *mockInstr { return m.instructions[0] } diff --git a/internal/engine/wazevo/backend/regalloc/regalloc.go b/internal/engine/wazevo/backend/regalloc/regalloc.go index a13fad8ef9..d02961a8a2 100644 --- a/internal/engine/wazevo/backend/regalloc/regalloc.go +++ b/internal/engine/wazevo/backend/regalloc/regalloc.go @@ -18,13 +18,13 @@ import ( ) // NewAllocator returns a new Allocator. -func NewAllocator(allocatableRegs *RegisterInfo) Allocator { - a := Allocator{ +func NewAllocator[I Instr, B Block[I]](allocatableRegs *RegisterInfo) Allocator[I, B] { + a := Allocator[I, B]{ regInfo: allocatableRegs, - phiDefInstListPool: wazevoapi.NewPool[phiDefInstList](resetPhiDefInstList), - blockStates: wazevoapi.NewIDedPool[blockState](resetBlockState), + phiDefInstListPool: wazevoapi.NewPool[phiDefInstList[I]](resetPhiDefInstList[I]), + blockStates: wazevoapi.NewIDedPool[blockState[I, B]](resetBlockState[I, B]), } - a.state.vrStates = wazevoapi.NewIDedPool[vrState](resetVrState) + a.state.vrStates = wazevoapi.NewIDedPool[vrState[I, B]](resetVrState[I, B]) a.state.reset() for _, regs := range allocatableRegs.AllocatableRegisters { for _, r := range regs { @@ -49,39 +49,39 @@ type ( } // Allocator is a register allocator. - Allocator struct { + Allocator[I Instr, B Block[I]] struct { // regInfo is static per ABI/ISA, and is initialized by the machine during Machine.PrepareRegisterAllocator. regInfo *RegisterInfo // allocatableSet is a set of allocatable RealReg derived from regInfo. Static per ABI/ISA. 
allocatableSet RegSet allocatedCalleeSavedRegs []VReg vs []VReg - ss []*vrState - copies []_copy - phiDefInstListPool wazevoapi.Pool[phiDefInstList] + ss []*vrState[I, B] + copies []_copy[I, B] + phiDefInstListPool wazevoapi.Pool[phiDefInstList[I]] // Followings are re-used during various places. - blks []Block + blks []B reals []RealReg // Following two fields are updated while iterating the blocks in the reverse postorder. - state state - blockStates wazevoapi.IDedPool[blockState] + state state[I, B] + blockStates wazevoapi.IDedPool[blockState[I, B]] } // _copy represents a source and destination pair of a copy instruction. - _copy struct { - src *vrState + _copy[I Instr, B Block[I]] struct { + src *vrState[I, B] dstID VRegID } // programCounter represents an opaque index into the program which is used to represents a LiveInterval of a VReg. programCounter int32 - state struct { + state[I Instr, B Block[I]] struct { argRealRegs []VReg - regsInUse regInUseSet - vrStates wazevoapi.IDedPool[vrState] + regsInUse regInUseSet[I, B] + vrStates wazevoapi.IDedPool[vrState[I, B]] currentBlockID int32 @@ -89,30 +89,30 @@ type ( allocatedRegSet RegSet } - blockState struct { + blockState[I Instr, B Block[I]] struct { // liveIns is a list of VReg that are live at the beginning of the block. - liveIns []*vrState + liveIns []*vrState[I, B] // seen is true if the block is visited during the liveness analysis. seen bool // visited is true if the block is visited during the allocation phase. visited bool startFromPredIndex int // startRegs is a list of RealReg that are used at the beginning of the block. This is used to fix the merge edges. - startRegs regInUseSet + startRegs regInUseSet[I, B] // endRegs is a list of RealReg that are used at the end of the block. This is used to fix the merge edges. - endRegs regInUseSet + endRegs regInUseSet[I, B] } - vrState struct { + vrState[I Instr, B Block[I]] struct { v VReg r RealReg // defInstr is the instruction that defines this value. If this is the phi value and not the entry block, this is nil. - defInstr Instr + defInstr I // defBlk is the block that defines this value. If this is the phi value, this is the block whose arguments contain this value. - defBlk Block + defBlk B // lca = lowest common ancestor. This is the block that is the lowest common ancestor of all the blocks that // reloads this value. This is used to determine the spill location. Only valid if spilled=true. - lca Block + lca B // lastUse is the program counter of the last use of this value. This changes while iterating the block, and // should not be used across the blocks as it becomes invalid. To check the validity, use lastUseUpdatedAtBlockID. lastUse programCounter @@ -127,14 +127,14 @@ type ( desiredLoc desiredLoc // phiDefInstList is a list of instructions that defines this phi value. // This is used to determine the spill location, and only valid if isPhi=true. - *phiDefInstList + *phiDefInstList[I] } // phiDefInstList is a linked list of instructions that defines a phi value. - phiDefInstList struct { - instr Instr + phiDefInstList[I Instr] struct { + instr I v VReg - next *phiDefInstList + next *phiDefInstList[I] } // desiredLoc represents a desired location for a VReg. 
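// Illustrative sketch only, not part of this patch: with the type parameters
// introduced above, an ISA backend instantiates the allocator with its own
// concrete instruction and block types; in the arm64 backend earlier in this
// diff those are *instruction and *labelPosition. The helper name
// allocateRegistersSketch and its arguments are assumptions for illustration;
// only the NewAllocator and DoAllocation signatures come from this change.
func allocateRegistersSketch[I Instr, B Block[I]](regInfo *RegisterInfo, fn Function[I, B]) {
	// NewAllocator derives the allocatable register set from regInfo.
	a := NewAllocator[I, B](regInfo)
	// DoAllocation runs liveness analysis, per-block allocation in reverse
	// post-order, merge-edge fixup and spill scheduling over fn, and records
	// the clobbered callee-saved registers via fn.ClobberedRegisters.
	a.DoAllocation(fn)
}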
@@ -166,13 +166,14 @@ func (d desiredLoc) stack() bool { return d&3 == desiredLoc(desiredLocKindStack) } -func resetPhiDefInstList(l *phiDefInstList) { - l.instr = nil +func resetPhiDefInstList[I Instr](l *phiDefInstList[I]) { + var nilInstr I + l.instr = nilInstr l.next = nil l.v = VRegInvalid } -func (s *state) dump(info *RegisterInfo) { //nolint:unused +func (s *state[I, B]) dump(info *RegisterInfo) { //nolint:unused fmt.Println("\t\tstate:") fmt.Println("\t\t\targRealRegs:", s.argRealRegs) fmt.Println("\t\t\tregsInUse", s.regsInUse.format(info)) @@ -191,7 +192,7 @@ func (s *state) dump(info *RegisterInfo) { //nolint:unused fmt.Println("\t\t\tvrStates:", strings.Join(strs, ", ")) } -func (s *state) reset() { +func (s *state[I, B]) reset() { s.argRealRegs = s.argRealRegs[:0] s.vrStates.Reset() s.allocatedRegSet = RegSet(0) @@ -199,21 +200,23 @@ func (s *state) reset() { s.currentBlockID = -1 } -func resetVrState(vs *vrState) { +func resetVrState[I Instr, B Block[I]](vs *vrState[I, B]) { vs.v = VRegInvalid vs.r = RealRegInvalid - vs.defInstr = nil - vs.defBlk = nil + var nilInstr I + vs.defInstr = nilInstr + var nilBlk B + vs.defBlk = nilBlk vs.spilled = false vs.lastUse = -1 vs.lastUseUpdatedAtBlockID = -1 - vs.lca = nil + vs.lca = nilBlk vs.isPhi = false vs.phiDefInstList = nil vs.desiredLoc = desiredLocUnspecified } -func (s *state) getOrAllocateVRegState(v VReg) *vrState { +func (s *state[I, B]) getOrAllocateVRegState(v VReg) *vrState[I, B] { st := s.vrStates.GetOrAllocate(int(v.ID())) if st.v == VRegInvalid { st.v = v @@ -221,11 +224,11 @@ func (s *state) getOrAllocateVRegState(v VReg) *vrState { return st } -func (s *state) getVRegState(v VRegID) *vrState { +func (s *state[I, B]) getVRegState(v VRegID) *vrState[I, B] { return s.vrStates.Get(int(v)) } -func (s *state) useRealReg(r RealReg, vr *vrState) { +func (s *state[I, B]) useRealReg(r RealReg, vr *vrState[I, B]) { if s.regsInUse.has(r) { panic("BUG: useRealReg: the given real register is already used") } @@ -234,7 +237,7 @@ func (s *state) useRealReg(r RealReg, vr *vrState) { s.allocatedRegSet = s.allocatedRegSet.add(r) } -func (s *state) releaseRealReg(r RealReg) { +func (s *state[I, B]) releaseRealReg(r RealReg) { current := s.regsInUse.get(r) if current != nil { s.regsInUse.remove(r) @@ -244,9 +247,10 @@ func (s *state) releaseRealReg(r RealReg) { // recordReload records that the given VReg is reloaded in the given block. // This is used to determine the spill location by tracking the lowest common ancestor of all the blocks that reloads the value. -func (vs *vrState) recordReload(f Function, blk Block) { +func (vs *vrState[I, B]) recordReload(f Function[I, B], blk B) { vs.spilled = true - if vs.lca == nil { + var nilBlk B + if vs.lca == nilBlk { if wazevoapi.RegAllocLoggingEnabled { fmt.Printf("\t\tv%d is reloaded in blk%d,\n", vs.v.ID(), blk.ID()) } @@ -262,7 +266,7 @@ func (vs *vrState) recordReload(f Function, blk Block) { } } -func (s *state) findOrSpillAllocatable(a *Allocator, allocatable []RealReg, forbiddenMask RegSet, preferred RealReg) (r RealReg) { +func (a *Allocator[I, B]) findOrSpillAllocatable(s *state[I, B], allocatable []RealReg, forbiddenMask RegSet, preferred RealReg) (r RealReg) { r = RealRegInvalid // First, check if the preferredMask has any allocatable register. 
if preferred != RealRegInvalid && !forbiddenMask.has(preferred) && !s.regsInUse.has(preferred) { @@ -320,7 +324,7 @@ func (s *state) findOrSpillAllocatable(a *Allocator, allocatable []RealReg, forb return r } -func (s *state) findAllocatable(allocatable []RealReg, forbiddenMask RegSet) RealReg { +func (s *state[I, B]) findAllocatable(allocatable []RealReg, forbiddenMask RegSet) RealReg { for _, r := range allocatable { if !s.regsInUse.has(r) && !forbiddenMask.has(r) { return r @@ -329,12 +333,12 @@ func (s *state) findAllocatable(allocatable []RealReg, forbiddenMask RegSet) Rea return RealRegInvalid } -func (s *state) resetAt(bs *blockState) { - s.regsInUse.range_(func(_ RealReg, vs *vrState) { +func (s *state[I, B]) resetAt(bs *blockState[I, B]) { + s.regsInUse.range_(func(_ RealReg, vs *vrState[I, B]) { vs.r = RealRegInvalid }) s.regsInUse.reset() - bs.endRegs.range_(func(r RealReg, vs *vrState) { + bs.endRegs.range_(func(r RealReg, vs *vrState[I, B]) { if vs.lastUseUpdatedAtBlockID == s.currentBlockID && vs.lastUse == programCounterLiveIn { s.regsInUse.add(r, vs) vs.r = r @@ -342,7 +346,7 @@ func (s *state) resetAt(bs *blockState) { }) } -func resetBlockState(b *blockState) { +func resetBlockState[I Instr, B Block[I]](b *blockState[I, B]) { b.seen = false b.visited = false b.endRegs.reset() @@ -351,7 +355,7 @@ func resetBlockState(b *blockState) { b.liveIns = b.liveIns[:0] } -func (b *blockState) dump(a *RegisterInfo) { +func (b *blockState[I, B]) dump(a *RegisterInfo) { fmt.Println("\t\tblockState:") fmt.Println("\t\t\tstartRegs:", b.startRegs.format(a)) fmt.Println("\t\t\tendRegs:", b.endRegs.format(a)) @@ -360,13 +364,13 @@ func (b *blockState) dump(a *RegisterInfo) { } // DoAllocation performs register allocation on the given Function. -func (a *Allocator) DoAllocation(f Function) { +func (a *Allocator[I, B]) DoAllocation(f Function[I, B]) { a.livenessAnalysis(f) a.alloc(f) a.determineCalleeSavedRealRegs(f) } -func (a *Allocator) determineCalleeSavedRealRegs(f Function) { +func (a *Allocator[I, B]) determineCalleeSavedRealRegs(f Function[I, B]) { a.allocatedCalleeSavedRegs = a.allocatedCalleeSavedRegs[:0] a.state.allocatedRegSet.Range(func(allocatedRealReg RealReg) { if a.regInfo.CalleeSavedRegisters.has(allocatedRealReg) { @@ -376,16 +380,17 @@ func (a *Allocator) determineCalleeSavedRealRegs(f Function) { f.ClobberedRegisters(a.allocatedCalleeSavedRegs) } -func (a *Allocator) getOrAllocateBlockState(blockID int32) *blockState { +func (a *Allocator[I, B]) getOrAllocateBlockState(blockID int32) *blockState[I, B] { return a.blockStates.GetOrAllocate(int(blockID)) } // phiBlk returns the block that defines the given phi value, nil otherwise. -func (vs *vrState) phiBlk() Block { +func (vs *vrState[I, B]) phiBlk() B { if vs.isPhi { return vs.defBlk } - return nil + var nilBlk B + return nilBlk } const ( @@ -395,16 +400,18 @@ const ( // liveAnalysis constructs Allocator.blockLivenessData. // The algorithm here is described in https://pfalcon.github.io/ssabook/latest/book-full.pdf Chapter 9.2. 
-func (a *Allocator) livenessAnalysis(f Function) { +func (a *Allocator[I, B]) livenessAnalysis(f Function[I, B]) { s := &a.state for i := VRegID(0); i < vRegIDReservedForRealNum; i++ { s.getOrAllocateVRegState(VReg(i).SetRealReg(RealReg(i))) } - for blk := f.PostOrderBlockIteratorBegin(); blk != nil; blk = f.PostOrderBlockIteratorNext() { + var nilBlk B + var nilInstr I + for blk := f.PostOrderBlockIteratorBegin(); blk != nilBlk; blk = f.PostOrderBlockIteratorNext() { // We should gather phi value data. - for _, p := range blk.BlockParams(&a.vs) { + for _, p := range f.BlockParams(blk, &a.vs) { vs := s.getOrAllocateVRegState(p) vs.isPhi = true vs.defBlk = blk @@ -420,8 +427,8 @@ func (a *Allocator) livenessAnalysis(f Function) { ) ns := blk.Succs() for i := 0; i < ns; i++ { - succ := blk.Succ(i) - if succ == nil { + succ := f.Succ(blk, i) + if succ == nilBlk { continue } @@ -440,7 +447,7 @@ func (a *Allocator) livenessAnalysis(f Function) { } } - for instr := blk.InstrRevIteratorBegin(); instr != nil; instr = blk.InstrRevIteratorNext() { + for instr := blk.InstrRevIteratorBegin(); instr != nilInstr; instr = blk.InstrRevIteratorNext() { var use, def VReg var defIsPhi bool @@ -485,12 +492,12 @@ func (a *Allocator) livenessAnalysis(f Function) { nrs := f.LoopNestingForestRoots() for i := 0; i < nrs; i++ { root := f.LoopNestingForestRoot(i) - a.loopTreeDFS(root) + a.loopTreeDFS(f, root) } } // loopTreeDFS implements the Algorithm 9.3 in the book in an iterative way. -func (a *Allocator) loopTreeDFS(entry Block) { +func (a *Allocator[I, B]) loopTreeDFS(f Function[I, B], entry B) { a.blks = a.blks[:0] a.blks = append(a.blks, entry) @@ -512,10 +519,10 @@ func (a *Allocator) loopTreeDFS(entry Block) { } } - var siblingAddedView []*vrState + var siblingAddedView []*vrState[I, B] cn := loop.LoopNestingForestChildren() for i := 0; i < cn; i++ { - child := loop.LoopNestingForestChild(i) + child := f.LoopNestingForestChild(loop, i) childID := child.ID() childInfo := a.getOrAllocateBlockState(childID) @@ -559,26 +566,27 @@ func (a *Allocator) loopTreeDFS(entry Block) { // the spill happens in the block that is the lowest common ancestor of all the blocks that reloads the value. // // All of these logics are almost the same as Go's compiler which has a dedicated description in the source file ^^. -func (a *Allocator) alloc(f Function) { +func (a *Allocator[I, B]) alloc(f Function[I, B]) { // First we allocate each block in the reverse postorder (at least one predecessor should be allocated for each block). - for blk := f.ReversePostOrderBlockIteratorBegin(); blk != nil; blk = f.ReversePostOrderBlockIteratorNext() { + var nilBlk B + for blk := f.ReversePostOrderBlockIteratorBegin(); blk != nilBlk; blk = f.ReversePostOrderBlockIteratorNext() { if wazevoapi.RegAllocLoggingEnabled { fmt.Printf("========== allocating blk%d ========\n", blk.ID()) } if blk.Entry() { - a.finalizeStartReg(blk) + a.finalizeStartReg(f, blk) } a.allocBlock(f, blk) } // After the allocation, we all know the start and end state of each block. So we can fix the merge states. - for blk := f.ReversePostOrderBlockIteratorBegin(); blk != nil; blk = f.ReversePostOrderBlockIteratorNext() { + for blk := f.ReversePostOrderBlockIteratorBegin(); blk != nilBlk; blk = f.ReversePostOrderBlockIteratorNext() { a.fixMergeState(f, blk) } // Finally, we insert the spill instructions as we know all the places where the reloads happen. 
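One structural theme worth calling out: block navigation that used to hang off Block (blk.Succ(i), blk.Pred(i), blk.BlockParams(...)) is now requested from Function[I, B] (f.Succ(blk, i), f.Pred(blk, i), f.BlockParams(blk, ...)), so only the function needs to know the concrete block type. A hedged sketch of a traversal under the new shape (forEachSucc is hypothetical, not part of the patch):

func forEachSucc[I Instr, B Block[I]](f Function[I, B], blk B, visit func(succ B)) {
	var nilBlk B // the zero value stands in for nil, as elsewhere in this change
	for i, n := 0, blk.Succs(); i < n; i++ {
		if succ := f.Succ(blk, i); succ != nilBlk {
			visit(succ)
		}
	}
}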
a.scheduleSpills(f) } -func (a *Allocator) updateLiveInVRState(liveness *blockState) { +func (a *Allocator[I, B]) updateLiveInVRState(liveness *blockState[I, B]) { currentBlockID := a.state.currentBlockID for _, vs := range liveness.liveIns { vs.lastUse = programCounterLiveIn @@ -586,7 +594,7 @@ func (a *Allocator) updateLiveInVRState(liveness *blockState) { } } -func (a *Allocator) finalizeStartReg(blk Block) { +func (a *Allocator[I, B]) finalizeStartReg(f Function[I, B], blk B) { bID := blk.ID() s := &a.state currentBlkState := a.getOrAllocateBlockState(bID) @@ -598,17 +606,17 @@ func (a *Allocator) finalizeStartReg(blk Block) { a.updateLiveInVRState(currentBlkState) preds := blk.Preds() - var predState *blockState + var predState *blockState[I, B] switch preds { case 0: // This is the entry block. case 1: - predID := blk.Pred(0).ID() + predID := f.Pred(blk, 0).ID() predState = a.getOrAllocateBlockState(predID) currentBlkState.startFromPredIndex = 0 default: // TODO: there should be some better heuristic to choose the predecessor. for i := 0; i < preds; i++ { - predID := blk.Pred(i).ID() + predID := f.Pred(blk, i).ID() if _predState := a.getOrAllocateBlockState(predID); _predState.visited { predState = _predState currentBlkState.startFromPredIndex = i @@ -627,12 +635,12 @@ func (a *Allocator) finalizeStartReg(blk Block) { } else { if wazevoapi.RegAllocLoggingEnabled { fmt.Printf("allocating blk%d starting from blk%d (on index=%d) \n", - bID, blk.Pred(currentBlkState.startFromPredIndex).ID(), currentBlkState.startFromPredIndex) + bID, f.Pred(blk, currentBlkState.startFromPredIndex).ID(), currentBlkState.startFromPredIndex) } s.resetAt(predState) } - s.regsInUse.range_(func(allocated RealReg, v *vrState) { + s.regsInUse.range_(func(allocated RealReg, v *vrState[I, B]) { currentBlkState.startRegs.add(allocated, v) }) if wazevoapi.RegAllocLoggingEnabled { @@ -640,7 +648,7 @@ func (a *Allocator) finalizeStartReg(blk Block) { } } -func (a *Allocator) allocBlock(f Function, blk Block) { +func (a *Allocator[I, B]) allocBlock(f Function[I, B], blk B) { bID := blk.ID() s := &a.state currentBlkState := a.getOrAllocateBlockState(bID) @@ -651,18 +659,19 @@ func (a *Allocator) allocBlock(f Function, blk Block) { } // Clears the previous state. - s.regsInUse.range_(func(allocatedRealReg RealReg, vr *vrState) { vr.r = RealRegInvalid }) + s.regsInUse.range_(func(allocatedRealReg RealReg, vr *vrState[I, B]) { vr.r = RealRegInvalid }) s.regsInUse.reset() // Then set the start state. - currentBlkState.startRegs.range_(func(allocatedRealReg RealReg, vr *vrState) { s.useRealReg(allocatedRealReg, vr) }) + currentBlkState.startRegs.range_(func(allocatedRealReg RealReg, vr *vrState[I, B]) { s.useRealReg(allocatedRealReg, vr) }) desiredUpdated := a.ss[:0] // Update the last use of each VReg. a.copies = a.copies[:0] // Stores the copy instructions. 
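A related receiver flip shows up around register selection: findOrSpillAllocatable moves from *state onto *Allocator[I, B] and takes the state explicitly, which is what the updated unit tests below exercise. A hedged sketch mirroring that test (the wrapper function name is hypothetical):

func findOrSpillAllocatableSketch() RealReg {
	a := &Allocator[*mockInstr, *mockBlock]{}
	s := &state[*mockInstr, *mockBlock]{regsInUse: newRegInUseSet[*mockInstr, *mockBlock]()}
	s.regsInUse.add(RealReg(1), &vrState[*mockInstr, *mockBlock]{v: VReg(2222222)})
	// Register 3 is allocatable, free, and preferred, so it is returned.
	return a.findOrSpillAllocatable(s, []RealReg{3}, 0, 3)
}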
var pc programCounter - for instr := blk.InstrIteratorBegin(); instr != nil; instr = blk.InstrIteratorNext() { - var useState *vrState + var nilInstr I + for instr := blk.InstrIteratorBegin(); instr != nilInstr; instr = blk.InstrIteratorNext() { + var useState *vrState[I, B] for _, use := range instr.Uses(&a.vs) { useState = s.getVRegState(use.ID()) if !use.IsRealReg() { @@ -672,7 +681,7 @@ func (a *Allocator) allocBlock(f Function, blk Block) { if instr.IsCopy() { def := instr.Defs(&a.vs)[0] - a.copies = append(a.copies, _copy{src: useState, dstID: def.ID()}) + a.copies = append(a.copies, _copy[I, B]{src: useState, dstID: def.ID()}) r := def.RealReg() if r != RealRegInvalid { if !useState.isPhi { // TODO: no idea why do we need this. @@ -686,10 +695,11 @@ func (a *Allocator) allocBlock(f Function, blk Block) { // Mark all live-out values by checking live-in of the successors. // While doing so, we also update the desired register values. - var succ Block + var succ B + var nilBlk B for i, ns := 0, blk.Succs(); i < ns; i++ { - succ = blk.Succ(i) - if succ == nil { + succ = f.Succ(blk, i) + if succ == nilBlk { continue } @@ -705,11 +715,11 @@ func (a *Allocator) allocBlock(f Function, blk Block) { if wazevoapi.RegAllocLoggingEnabled { fmt.Printf("blk%d -> blk%d: start_regs: %s\n", bID, succID, succState.startRegs.format(a.regInfo)) } - succState.startRegs.range_(func(allocatedRealReg RealReg, vs *vrState) { + succState.startRegs.range_(func(allocatedRealReg RealReg, vs *vrState[I, B]) { vs.desiredLoc = newDesiredLocReg(allocatedRealReg) desiredUpdated = append(desiredUpdated, vs) }) - for _, p := range succ.BlockParams(&a.vs) { + for _, p := range f.BlockParams(succ, &a.vs) { vs := s.getVRegState(p.ID()) if vs.desiredLoc.realReg() == RealRegInvalid { vs.desiredLoc = desiredLocStack @@ -731,7 +741,7 @@ func (a *Allocator) allocBlock(f Function, blk Block) { } pc = 0 - for instr := blk.InstrIteratorBegin(); instr != nil; instr = blk.InstrIteratorNext() { + for instr := blk.InstrIteratorBegin(); instr != nilInstr; instr = blk.InstrIteratorNext() { if wazevoapi.RegAllocLoggingEnabled { fmt.Println(instr) } @@ -763,7 +773,7 @@ func (a *Allocator) allocBlock(f Function, blk Block) { r := vs.r if r == RealRegInvalid { - r = s.findOrSpillAllocatable(a, a.regInfo.AllocatableRegisters[use.RegType()], currentUsedSet, + r = a.findOrSpillAllocatable(s, a.regInfo.AllocatableRegisters[use.RegType()], currentUsedSet, // Prefer the desired register if it's available. 
vs.desiredLoc.realReg()) vs.recordReload(f, blk) @@ -866,7 +876,7 @@ func (a *Allocator) allocBlock(f Function, blk Block) { } if r == RealRegInvalid { typ := def.RegType() - r = s.findOrSpillAllocatable(a, a.regInfo.AllocatableRegisters[typ], RegSet(0), RealRegInvalid) + r = a.findOrSpillAllocatable(s, a.regInfo.AllocatableRegisters[typ], RegSet(0), RealRegInvalid) } s.useRealReg(r, vState) } @@ -901,7 +911,7 @@ func (a *Allocator) allocBlock(f Function, blk Block) { pc++ } - s.regsInUse.range_(func(allocated RealReg, v *vrState) { currentBlkState.endRegs.add(allocated, v) }) + s.regsInUse.range_(func(allocated RealReg, v *vrState[I, B]) { currentBlkState.endRegs.add(allocated, v) }) currentBlkState.visited = true if wazevoapi.RegAllocLoggingEnabled { @@ -915,16 +925,16 @@ func (a *Allocator) allocBlock(f Function, blk Block) { a.ss = desiredUpdated[:0] for i := 0; i < blk.Succs(); i++ { - succ := blk.Succ(i) - if succ == nil { + succ := f.Succ(blk, i) + if succ == nilBlk { continue } // If the successor is not visited yet, finalize the start state. - a.finalizeStartReg(succ) + a.finalizeStartReg(f, succ) } } -func (a *Allocator) releaseCallerSavedRegs(addrReg RealReg) { +func (a *Allocator[I, B]) releaseCallerSavedRegs(addrReg RealReg) { s := &a.state for allocated := RealReg(0); allocated < 64; allocated++ { @@ -944,7 +954,7 @@ func (a *Allocator) releaseCallerSavedRegs(addrReg RealReg) { } } -func (a *Allocator) fixMergeState(f Function, blk Block) { +func (a *Allocator[I, B]) fixMergeState(f Function[I, B], blk B) { preds := blk.Preds() if preds <= 1 { return @@ -975,7 +985,7 @@ func (a *Allocator) fixMergeState(f Function, blk Block) { continue } - pred := blk.Pred(i) + pred := f.Pred(blk, i) predSt := a.getOrAllocateBlockState(pred.ID()) s.resetAt(predSt) @@ -1022,10 +1032,10 @@ func (a *Allocator) fixMergeState(f Function, blk Block) { // - desiredVReg is the desired VReg value that should be on the register `r`. // - freeReg is the temporary register that can be used to swap the values, which may or may not be used. // - typ is the register type of the `r`. -func (a *Allocator) reconcileEdge(f Function, +func (a *Allocator[I, B]) reconcileEdge(f Function[I, B], r RealReg, - pred Block, - currentState, desiredState *vrState, + pred B, + currentState, desiredState *vrState[I, B], freeReg VReg, typ RegType, ) { @@ -1107,7 +1117,7 @@ func (a *Allocator) reconcileEdge(f Function, } } -func (a *Allocator) scheduleSpills(f Function) { +func (a *Allocator[I, B]) scheduleSpills(f Function[I, B]) { states := a.state.vrStates for i := 0; i <= states.MaxIDEncountered(); i++ { vs := states.Get(i) @@ -1120,7 +1130,7 @@ func (a *Allocator) scheduleSpills(f Function) { } } -func (a *Allocator) scheduleSpill(f Function, vs *vrState) { +func (a *Allocator[I, B]) scheduleSpill(f Function[I, B], vs *vrState[I, B]) { v := vs.v // If the value is the phi value, we need to insert a spill after each phi definition. if vs.isPhi { @@ -1133,10 +1143,11 @@ func (a *Allocator) scheduleSpill(f Function, vs *vrState) { pos := vs.lca definingBlk := vs.defBlk r := RealRegInvalid - if definingBlk == nil { + var nilBlk B + if definingBlk == nilBlk { panic(fmt.Sprintf("BUG: definingBlk should not be nil for %s. This is likley a bug in backend lowering logic", vs.v.String())) } - if pos == nil { + if pos == nilBlk { panic(fmt.Sprintf("BUG: pos should not be nil for %s. 
This is likley a bug in backend lowering logic", vs.v.String())) } @@ -1179,7 +1190,7 @@ func (a *Allocator) scheduleSpill(f Function, vs *vrState) { } // Reset resets the allocator's internal state so that it can be reused. -func (a *Allocator) Reset() { +func (a *Allocator[I, B]) Reset() { a.state.reset() a.blockStates.Reset() a.phiDefInstListPool.Reset() diff --git a/internal/engine/wazevo/backend/regalloc/regalloc_test.go b/internal/engine/wazevo/backend/regalloc/regalloc_test.go index 818b59b953..04433a8822 100644 --- a/internal/engine/wazevo/backend/regalloc/regalloc_test.go +++ b/internal/engine/wazevo/backend/regalloc/regalloc_test.go @@ -10,6 +10,8 @@ import ( ) func TestAllocator_livenessAnalysis(t *testing.T) { + type _function = Function[*mockInstr, *mockBlock] + const realRegID, realRegID2 = 50, 100 realReg, realReg2 := FromRealReg(realRegID, RegTypeInt), FromRealReg(realRegID2, RegTypeInt) phiVReg := VReg(12345).SetRegType(RegTypeInt) @@ -20,12 +22,12 @@ func TestAllocator_livenessAnalysis(t *testing.T) { for _, tc := range []struct { name string - setup func() Function + setup func() _function exps map[int]*exp }{ { name: "single block", - setup: func() Function { + setup: func() _function { return newMockFunction( newMockBlock(0, newMockInstr().def(1), @@ -39,7 +41,7 @@ func TestAllocator_livenessAnalysis(t *testing.T) { }, { name: "single block with real reg", - setup: func() Function { + setup: func() _function { realVReg := FromRealReg(10, RegTypeInt) param := VReg(1) ret := VReg(2) @@ -58,7 +60,7 @@ func TestAllocator_livenessAnalysis(t *testing.T) { { name: "straight", // b0 -> b1 -> b2 - setup: func() Function { + setup: func() _function { b0 := newMockBlock(0, newMockInstr().def(1000, 1, 2), newMockInstr().use(1000), @@ -91,7 +93,7 @@ func TestAllocator_livenessAnalysis(t *testing.T) { // 1 2 // \ / // 3 - setup: func() Function { + setup: func() _function { b0 := newMockBlock(0, newMockInstr().def(1000), newMockInstr().def(1), @@ -138,7 +140,7 @@ func TestAllocator_livenessAnalysis(t *testing.T) { // 2 3 // \ / // 4 use v5 (phi node) defined at both 1 and 3. - setup: func() Function { + setup: func() _function { b0 := newMockBlock(0, newMockInstr().def(1000, 2000, 3000), ).entry() @@ -182,7 +184,7 @@ func TestAllocator_livenessAnalysis(t *testing.T) { // ^ | // | v // 4 <- 3 -> 5 - setup: func() Function { + setup: func() _function { b0 := newMockBlock(0, newMockInstr().def(1), newMockInstr().def(phiVReg).use(1), @@ -238,7 +240,7 @@ func TestAllocator_livenessAnalysis(t *testing.T) { }, { name: "multiple pass alive", - setup: func() Function { + setup: func() _function { v := VReg(9999) b0 := newMockBlock(0, newMockInstr().def(v)).entry() @@ -285,7 +287,7 @@ func TestAllocator_livenessAnalysis(t *testing.T) { // ^ | // +----+ name: "Fig. 
9.2 in paper", - setup: func() Function { + setup: func() _function { b0 := newMockBlock(0, newMockInstr().def(99999), newMockInstr().def(phiVReg).use(111).asCopy(), @@ -328,7 +330,7 @@ func TestAllocator_livenessAnalysis(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { f := tc.setup() - a := NewAllocator(&RegisterInfo{ + a := NewAllocator[*mockInstr, *mockBlock](&RegisterInfo{ RealRegName: func(r RealReg) string { return fmt.Sprintf("r%d", r) }, @@ -369,26 +371,29 @@ func TestAllocator_livenessAnalysis_copy(t *testing.T) { newMockInstr().use(1).def(2).asCopy(), ).entry(), ) - a := NewAllocator(&RegisterInfo{}) + a := NewAllocator[*mockInstr, *mockBlock](&RegisterInfo{}) a.livenessAnalysis(f) } func Test_findOrSpillAllocatable_prefersSpill(t *testing.T) { t.Run("ok", func(t *testing.T) { - s := &state{regsInUse: newRegInUseSet()} - s.regsInUse.add(RealReg(1), &vrState{v: VReg(2222222)}) - got := s.findOrSpillAllocatable(&Allocator{}, []RealReg{3}, 0, 3) + a := &Allocator[*mockInstr, *mockBlock]{} + s := &state[*mockInstr, *mockBlock]{regsInUse: newRegInUseSet[*mockInstr, *mockBlock]()} + s.regsInUse.add(RealReg(1), &vrState[*mockInstr, *mockBlock]{v: VReg(2222222)}) + got := a.findOrSpillAllocatable(s, []RealReg{3}, 0, 3) require.Equal(t, RealReg(3), got) }) t.Run("preferred but in use", func(t *testing.T) { - s := &state{vrStates: wazevoapi.NewIDedPool[vrState](resetVrState)} - s.regsInUse.add(RealReg(3), &vrState{v: VReg(1).SetRealReg(3)}) - got := s.findOrSpillAllocatable(&Allocator{}, []RealReg{3, 4}, 0, 3) + a := &Allocator[*mockInstr, *mockBlock]{} + s := &state[*mockInstr, *mockBlock]{vrStates: wazevoapi.NewIDedPool[vrState[*mockInstr, *mockBlock]](resetVrState[*mockInstr, *mockBlock])} + s.regsInUse.add(RealReg(3), &vrState[*mockInstr, *mockBlock]{v: VReg(1).SetRealReg(3)}) + got := a.findOrSpillAllocatable(s, []RealReg{3, 4}, 0, 3) require.Equal(t, RealReg(4), got) }) t.Run("preferred but forbidden", func(t *testing.T) { - s := &state{vrStates: wazevoapi.NewIDedPool[vrState](resetVrState)} - got := s.findOrSpillAllocatable(&Allocator{}, []RealReg{3, 4}, RegSet(0).add(3), 3) + a := &Allocator[*mockInstr, *mockBlock]{} + s := &state[*mockInstr, *mockBlock]{vrStates: wazevoapi.NewIDedPool[vrState[*mockInstr, *mockBlock]](resetVrState[*mockInstr, *mockBlock])} + got := a.findOrSpillAllocatable(s, []RealReg{3, 4}, RegSet(0).add(3), 3) require.Equal(t, RealReg(4), got) }) } diff --git a/internal/engine/wazevo/backend/regalloc/regset.go b/internal/engine/wazevo/backend/regalloc/regset.go index 1624fc96f9..5453a61c04 100644 --- a/internal/engine/wazevo/backend/regalloc/regset.go +++ b/internal/engine/wazevo/backend/regalloc/regset.go @@ -46,21 +46,21 @@ func (rs RegSet) Range(f func(allocatedRealReg RealReg)) { } } -type regInUseSet [64]*vrState +type regInUseSet[I Instr, B Block[I]] [64]*vrState[I, B] -func newRegInUseSet() regInUseSet { - var ret regInUseSet +func newRegInUseSet[I Instr, B Block[I]]() regInUseSet[I, B] { + var ret regInUseSet[I, B] ret.reset() return ret } -func (rs *regInUseSet) reset() { +func (rs *regInUseSet[I, B]) reset() { for i := range rs { rs[i] = nil } } -func (rs *regInUseSet) format(info *RegisterInfo) string { //nolint:unused +func (rs *regInUseSet[I, B]) format(info *RegisterInfo) string { //nolint:unused var ret []string for i, vr := range rs { if vr != nil { @@ -70,26 +70,26 @@ func (rs *regInUseSet) format(info *RegisterInfo) string { //nolint:unused return strings.Join(ret, ", ") } -func (rs *regInUseSet) has(r RealReg) bool { +func (rs 
*regInUseSet[I, B]) has(r RealReg) bool { return r < 64 && rs[r] != nil } -func (rs *regInUseSet) get(r RealReg) *vrState { +func (rs *regInUseSet[I, B]) get(r RealReg) *vrState[I, B] { return rs[r] } -func (rs *regInUseSet) remove(r RealReg) { +func (rs *regInUseSet[I, B]) remove(r RealReg) { rs[r] = nil } -func (rs *regInUseSet) add(r RealReg, vr *vrState) { +func (rs *regInUseSet[I, B]) add(r RealReg, vr *vrState[I, B]) { if r >= 64 { return } rs[r] = vr } -func (rs *regInUseSet) range_(f func(allocatedRealReg RealReg, vr *vrState)) { +func (rs *regInUseSet[I, B]) range_(f func(allocatedRealReg RealReg, vr *vrState[I, B])) { for i, vr := range rs { if vr != nil { f(RealReg(i), vr)
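To close the loop on the container change above: regInUseSet stays a fixed 64-slot table indexed directly by RealReg; only the element type becomes the generic *vrState[I, B]. A small hedged usage sketch, instantiated with the mock types from the tests (the wrapper function is hypothetical):

func regInUseSetSketch() {
	rs := newRegInUseSet[*mockInstr, *mockBlock]()
	vr := &vrState[*mockInstr, *mockBlock]{v: VReg(42).SetRealReg(RealReg(3))}
	rs.add(RealReg(3), vr) // add silently ignores registers >= 64
	_ = rs.has(RealReg(3)) // true while the slot is occupied
	_ = rs.get(RealReg(3)) // returns vr
	rs.range_(func(r RealReg, v *vrState[*mockInstr, *mockBlock]) {
		// visits only occupied slots
	})
	rs.remove(RealReg(3))
}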