From cb0c399445686391f51fbf16bea2709c2dd9bd6d Mon Sep 17 00:00:00 2001 From: Takeshi Yoneda Date: Tue, 5 Mar 2024 10:02:34 +0900 Subject: [PATCH] wazevo(frontend): fast pass for static local searches (#2116) Signed-off-by: Takeshi Yoneda --- .../engine/wazevo/frontend/frontend_test.go | 66 +++++++++---------- internal/engine/wazevo/frontend/lower.go | 12 +++- internal/engine/wazevo/ssa/builder.go | 8 +++ internal/wasm/func_validation.go | 7 ++ internal/wasm/module.go | 5 ++ 5 files changed, 63 insertions(+), 35 deletions(-) diff --git a/internal/engine/wazevo/frontend/frontend_test.go b/internal/engine/wazevo/frontend/frontend_test.go index 6f3dc309dc..9341bd9b45 100644 --- a/internal/engine/wazevo/frontend/frontend_test.go +++ b/internal/engine/wazevo/frontend/frontend_test.go @@ -444,10 +444,10 @@ blk3: (v5:i32) <-- (blk1,blk2) exp: ` blk0: (exec_ctx:i64, module_ctx:i64, v2:i32) v3:i32 = Iconst_32 0x0 - Jump blk1, v2 + Jump blk1 -blk1: (v4:i32) <-- (blk0) - Return v4 +blk1: () <-- (blk0) + Return v2 blk2: () `, @@ -464,15 +464,15 @@ blk1: () <-- (blk0) m: testcases.ReferenceValueFromUnsealedBlock2.Module, exp: ` blk0: (exec_ctx:i64, module_ctx:i64, v2:i32) - Jump blk1, v2 + Jump blk1 -blk1: (v3:i32) <-- (blk0,blk1) - Brnz v3, blk1, v3 +blk1: () <-- (blk0,blk1) + Brnz v2, blk1 Jump blk4 blk2: () <-- (blk3) - v4:i32 = Iconst_32 0x0 - Jump blk_ret, v4 + v3:i32 = Iconst_32 0x0 + Jump blk_ret, v3 blk3: () <-- (blk4) Jump blk2 @@ -498,8 +498,8 @@ blk3: () <-- (blk4) Jump fallthrough blk2: () <-- (blk3) - v4:i32 = Iconst_32 0x0 - Jump blk_ret, v4 + v3:i32 = Iconst_32 0x0 + Jump blk_ret, v3 `, }, { @@ -1060,22 +1060,22 @@ blk1: () <-- (blk0) Call f1:sig1, exec_ctx, module_ctx v5:i64 = Load module_ctx, 0x8 v6:i64 = Uload32 module_ctx, 0x10 - Jump blk3, v2 + Jump blk3 blk2: () <-- (blk0) - Jump blk3, v2 + Jump blk3 -blk3: (v7:i32) <-- (blk1,blk2) - v8:i64 = Iconst_64 0x4 - v9:i64 = UExtend v7, 32->64 - v10:i64 = Uload32 module_ctx, 0x10 - v11:i64 = Iadd v9, v8 - v12:i32 = 
Icmp lt_u, v10, v11 - ExitIfTrue v12, exec_ctx, memory_out_of_bounds - v13:i64 = Load module_ctx, 0x8 - v14:i64 = Iadd v13, v9 - v15:i32 = Load v14, 0x0 - Jump blk_ret, v15 +blk3: () <-- (blk1,blk2) + v7:i64 = Iconst_64 0x4 + v8:i64 = UExtend v2, 32->64 + v9:i64 = Uload32 module_ctx, 0x10 + v10:i64 = Iadd v8, v7 + v11:i32 = Icmp lt_u, v9, v10 + ExitIfTrue v11, exec_ctx, memory_out_of_bounds + v12:i64 = Load module_ctx, 0x8 + v13:i64 = Iadd v12, v8 + v14:i32 = Load v13, 0x0 + Jump blk_ret, v14 `, expAfterPasses: ` signatures: @@ -1095,16 +1095,16 @@ blk2: () <-- (blk0) Jump fallthrough blk3: () <-- (blk1,blk2) - v8:i64 = Iconst_64 0x4 - v9:i64 = UExtend v2, 32->64 - v10:i64 = Uload32 module_ctx, 0x10 - v11:i64 = Iadd v9, v8 - v12:i32 = Icmp lt_u, v10, v11 - ExitIfTrue v12, exec_ctx, memory_out_of_bounds - v13:i64 = Load module_ctx, 0x8 - v14:i64 = Iadd v13, v9 - v15:i32 = Load v14, 0x0 - Jump blk_ret, v15 + v7:i64 = Iconst_64 0x4 + v8:i64 = UExtend v2, 32->64 + v9:i64 = Uload32 module_ctx, 0x10 + v10:i64 = Iadd v8, v7 + v11:i32 = Icmp lt_u, v9, v10 + ExitIfTrue v11, exec_ctx, memory_out_of_bounds + v12:i64 = Load module_ctx, 0x8 + v13:i64 = Iadd v12, v8 + v14:i32 = Load v13, 0x0 + Jump blk_ret, v14 `, }, { diff --git a/internal/engine/wazevo/frontend/lower.go b/internal/engine/wazevo/frontend/lower.go index 963c2643cf..0674ff9017 100644 --- a/internal/engine/wazevo/frontend/lower.go +++ b/internal/engine/wazevo/frontend/lower.go @@ -1078,8 +1078,16 @@ func (c *Compiler) lowerCurrentOpcode() { break } variable := c.localVariable(index) - v := builder.MustFindValue(variable) - state.push(v) + if _, ok := c.m.NonStaticLocals[c.wasmLocalFunctionIndex][index]; ok { + state.push(builder.MustFindValue(variable)) + } else { + // If a local is static, we can simply find it in the entry block which is either a function param + // or a zero value. This fast path helps to avoid the overhead of searching the entire function plus + // avoid adding unnecessary block arguments. 
+ // TODO: I think this optimization should be done in an SSA pass like passRedundantPhiEliminationOpt, + // but somehow there are some corner cases that it fails to optimize. + state.push(builder.MustFindValueInBlk(variable, c.ssaBuilder.EntryBlock())) + } case wasm.OpcodeLocalSet: index := c.readI32u() if state.unreachable { diff --git a/internal/engine/wazevo/ssa/builder.go b/internal/engine/wazevo/ssa/builder.go index 34368ac60d..1fc84d2eaf 100644 --- a/internal/engine/wazevo/ssa/builder.go +++ b/internal/engine/wazevo/ssa/builder.go @@ -54,6 +54,9 @@ type Builder interface { // MustFindValue searches the latest definition of the given Variable and returns the result. MustFindValue(variable Variable) Value + // MustFindValueInBlk is the same as MustFindValue except it searches the latest definition from the given BasicBlock. + MustFindValueInBlk(variable Variable, blk BasicBlock) Value + // FindValueInLinearPath tries to find the latest definition of the given Variable in the linear path to the current BasicBlock. // If it cannot find the definition, or it's not sealed yet, it returns ValueInvalid. FindValueInLinearPath(variable Variable) Value @@ -445,6 +448,11 @@ func (b *builder) findValueInLinearPath(variable Variable, blk *basicBlock) Valu return ValueInvalid } +func (b *builder) MustFindValueInBlk(variable Variable, blk BasicBlock) Value { + typ := b.definedVariableType(variable) + return b.findValue(typ, variable, blk.(*basicBlock)) +} + // MustFindValue implements Builder.MustFindValue. 
func (b *builder) MustFindValue(variable Variable) Value { typ := b.definedVariableType(variable) diff --git a/internal/wasm/func_validation.go b/internal/wasm/func_validation.go index 673d576393..df811bc5aa 100644 --- a/internal/wasm/func_validation.go +++ b/internal/wasm/func_validation.go @@ -67,6 +67,11 @@ func (m *Module) validateFunctionWithMaxStackValues( declaredFunctionIndexes map[Index]struct{}, br *bytes.Reader, ) error { + nonStaticLocals := make(map[Index]struct{}) + if len(m.NonStaticLocals) > 0 { + m.NonStaticLocals[idx] = nonStaticLocals + } + functionType := &m.TypeSection[m.FunctionSection[idx]] code := &m.CodeSection[idx] body := code.Body @@ -352,6 +357,7 @@ func (m *Module) validateFunctionWithMaxStackValues( return fmt.Errorf("invalid local index for %s %d >= %d(=len(locals)+len(parameters))", OpcodeLocalSetName, index, l) } + nonStaticLocals[index] = struct{}{} var expType ValueType if index < inputLen { expType = functionType.Params[index] @@ -367,6 +373,7 @@ func (m *Module) validateFunctionWithMaxStackValues( return fmt.Errorf("invalid local index for %s %d >= %d(=len(locals)+len(parameters))", OpcodeLocalTeeName, index, l) } + nonStaticLocals[index] = struct{}{} var expType ValueType if index < inputLen { expType = functionType.Params[index] diff --git a/internal/wasm/module.go b/internal/wasm/module.go index c019730e04..6ef68d508b 100644 --- a/internal/wasm/module.go +++ b/internal/wasm/module.go @@ -185,6 +185,9 @@ type Module struct { // as described in https://yurydelendik.github.io/webassembly-dwarf/, though it is not specified in the Wasm // specification: https://github.com/WebAssembly/debugging/issues/1 DWARFLines *wasmdebug.DWARFLines + + // NonStaticLocals collects the local indexes that will change their value through either local.set or local.tee. + NonStaticLocals []map[Index]struct{} } // ModuleID represents sha256 hash value uniquely assigned to Module. 
@@ -363,6 +366,8 @@ func (m *Module) validateFunctions(enabledFeatures api.CoreFeatures, functions [ br := bytes.NewReader(nil) // Also, we reuse the stacks across multiple function validations to reduce allocations. vs := &stacks{} + // Non-static locals are gathered during validation and used in the down-stream compilation. + m.NonStaticLocals = make([]map[Index]struct{}, len(m.FunctionSection)) for idx, typeIndex := range m.FunctionSection { if typeIndex >= typeCount { return fmt.Errorf("invalid %s: type section index %d out of range", m.funcDesc(SectionIDFunction, Index(idx)), typeIndex)