Skip to content

Commit

Permalink
wazevo(frontend): fast pass for static local searches (#2116)
Browse files Browse the repository at this point in the history
Signed-off-by: Takeshi Yoneda <[email protected]>
  • Loading branch information
mathetake authored Mar 5, 2024
1 parent 027e68b commit cb0c399
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 35 deletions.
66 changes: 33 additions & 33 deletions internal/engine/wazevo/frontend/frontend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -444,10 +444,10 @@ blk3: (v5:i32) <-- (blk1,blk2)
exp: `
blk0: (exec_ctx:i64, module_ctx:i64, v2:i32)
v3:i32 = Iconst_32 0x0
Jump blk1, v2
Jump blk1
blk1: (v4:i32) <-- (blk0)
Return v4
blk1: () <-- (blk0)
Return v2
blk2: ()
`,
Expand All @@ -464,15 +464,15 @@ blk1: () <-- (blk0)
m: testcases.ReferenceValueFromUnsealedBlock2.Module,
exp: `
blk0: (exec_ctx:i64, module_ctx:i64, v2:i32)
Jump blk1, v2
Jump blk1
blk1: (v3:i32) <-- (blk0,blk1)
Brnz v3, blk1, v3
blk1: () <-- (blk0,blk1)
Brnz v2, blk1
Jump blk4
blk2: () <-- (blk3)
v4:i32 = Iconst_32 0x0
Jump blk_ret, v4
v3:i32 = Iconst_32 0x0
Jump blk_ret, v3
blk3: () <-- (blk4)
Jump blk2
Expand All @@ -498,8 +498,8 @@ blk3: () <-- (blk4)
Jump fallthrough
blk2: () <-- (blk3)
v4:i32 = Iconst_32 0x0
Jump blk_ret, v4
v3:i32 = Iconst_32 0x0
Jump blk_ret, v3
`,
},
{
Expand Down Expand Up @@ -1060,22 +1060,22 @@ blk1: () <-- (blk0)
Call f1:sig1, exec_ctx, module_ctx
v5:i64 = Load module_ctx, 0x8
v6:i64 = Uload32 module_ctx, 0x10
Jump blk3, v2
Jump blk3
blk2: () <-- (blk0)
Jump blk3, v2
Jump blk3
blk3: (v7:i32) <-- (blk1,blk2)
v8:i64 = Iconst_64 0x4
v9:i64 = UExtend v7, 32->64
v10:i64 = Uload32 module_ctx, 0x10
v11:i64 = Iadd v9, v8
v12:i32 = Icmp lt_u, v10, v11
ExitIfTrue v12, exec_ctx, memory_out_of_bounds
v13:i64 = Load module_ctx, 0x8
v14:i64 = Iadd v13, v9
v15:i32 = Load v14, 0x0
Jump blk_ret, v15
blk3: () <-- (blk1,blk2)
v7:i64 = Iconst_64 0x4
v8:i64 = UExtend v2, 32->64
v9:i64 = Uload32 module_ctx, 0x10
v10:i64 = Iadd v8, v7
v11:i32 = Icmp lt_u, v9, v10
ExitIfTrue v11, exec_ctx, memory_out_of_bounds
v12:i64 = Load module_ctx, 0x8
v13:i64 = Iadd v12, v8
v14:i32 = Load v13, 0x0
Jump blk_ret, v14
`,
expAfterPasses: `
signatures:
Expand All @@ -1095,16 +1095,16 @@ blk2: () <-- (blk0)
Jump fallthrough
blk3: () <-- (blk1,blk2)
v8:i64 = Iconst_64 0x4
v9:i64 = UExtend v2, 32->64
v10:i64 = Uload32 module_ctx, 0x10
v11:i64 = Iadd v9, v8
v12:i32 = Icmp lt_u, v10, v11
ExitIfTrue v12, exec_ctx, memory_out_of_bounds
v13:i64 = Load module_ctx, 0x8
v14:i64 = Iadd v13, v9
v15:i32 = Load v14, 0x0
Jump blk_ret, v15
v7:i64 = Iconst_64 0x4
v8:i64 = UExtend v2, 32->64
v9:i64 = Uload32 module_ctx, 0x10
v10:i64 = Iadd v8, v7
v11:i32 = Icmp lt_u, v9, v10
ExitIfTrue v11, exec_ctx, memory_out_of_bounds
v12:i64 = Load module_ctx, 0x8
v13:i64 = Iadd v12, v8
v14:i32 = Load v13, 0x0
Jump blk_ret, v14
`,
},
{
Expand Down
12 changes: 10 additions & 2 deletions internal/engine/wazevo/frontend/lower.go
Original file line number Diff line number Diff line change
Expand Up @@ -1078,8 +1078,16 @@ func (c *Compiler) lowerCurrentOpcode() {
break
}
variable := c.localVariable(index)
v := builder.MustFindValue(variable)
state.push(v)
if _, ok := c.m.NonStaticLocals[c.wasmLocalFunctionIndex][index]; ok {
state.push(builder.MustFindValue(variable))
} else {
// If a local is static, we can simply find it in the entry block which is either a function param
// or a zero value. This fast pass helps to avoid the overhead of searching the entire function plus
// avoid adding unnecessary block arguments.
// TODO: I think this optimization should be done in a SSA pass like passRedundantPhiEliminationOpt,
// but somehow there's some corner cases that it fails to optimize.
state.push(builder.MustFindValueInBlk(variable, c.ssaBuilder.EntryBlock()))
}
case wasm.OpcodeLocalSet:
index := c.readI32u()
if state.unreachable {
Expand Down
8 changes: 8 additions & 0 deletions internal/engine/wazevo/ssa/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ type Builder interface {
// MustFindValue searches the latest definition of the given Variable and returns the result.
MustFindValue(variable Variable) Value

// MustFindValueInBlk is the same as MustFindValue except it searches the latest definition from the given BasicBlock.
MustFindValueInBlk(variable Variable, blk BasicBlock) Value

// FindValueInLinearPath tries to find the latest definition of the given Variable in the linear path to the current BasicBlock.
// If it cannot find the definition, or it's not sealed yet, it returns ValueInvalid.
FindValueInLinearPath(variable Variable) Value
Expand Down Expand Up @@ -445,6 +448,11 @@ func (b *builder) findValueInLinearPath(variable Variable, blk *basicBlock) Valu
return ValueInvalid
}

func (b *builder) MustFindValueInBlk(variable Variable, blk BasicBlock) Value {
typ := b.definedVariableType(variable)
return b.findValue(typ, variable, blk.(*basicBlock))
}

// MustFindValue implements Builder.MustFindValue.
func (b *builder) MustFindValue(variable Variable) Value {
typ := b.definedVariableType(variable)
Expand Down
7 changes: 7 additions & 0 deletions internal/wasm/func_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ func (m *Module) validateFunctionWithMaxStackValues(
declaredFunctionIndexes map[Index]struct{},
br *bytes.Reader,
) error {
nonStaticLocals := make(map[Index]struct{})
if len(m.NonStaticLocals) > 0 {
m.NonStaticLocals[idx] = nonStaticLocals
}

functionType := &m.TypeSection[m.FunctionSection[idx]]
code := &m.CodeSection[idx]
body := code.Body
Expand Down Expand Up @@ -352,6 +357,7 @@ func (m *Module) validateFunctionWithMaxStackValues(
return fmt.Errorf("invalid local index for %s %d >= %d(=len(locals)+len(parameters))",
OpcodeLocalSetName, index, l)
}
nonStaticLocals[index] = struct{}{}
var expType ValueType
if index < inputLen {
expType = functionType.Params[index]
Expand All @@ -367,6 +373,7 @@ func (m *Module) validateFunctionWithMaxStackValues(
return fmt.Errorf("invalid local index for %s %d >= %d(=len(locals)+len(parameters))",
OpcodeLocalTeeName, index, l)
}
nonStaticLocals[index] = struct{}{}
var expType ValueType
if index < inputLen {
expType = functionType.Params[index]
Expand Down
5 changes: 5 additions & 0 deletions internal/wasm/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ type Module struct {
// as described in https://yurydelendik.github.io/webassembly-dwarf/, though it is not specified in the Wasm
// specification: https://github.com/WebAssembly/debugging/issues/1
DWARFLines *wasmdebug.DWARFLines

// NonStaticLocals collects the local indexes that will change its value through either local.get or local.tee.
NonStaticLocals []map[Index]struct{}
}

// ModuleID represents sha256 hash value uniquely assigned to Module.
Expand Down Expand Up @@ -363,6 +366,8 @@ func (m *Module) validateFunctions(enabledFeatures api.CoreFeatures, functions [
br := bytes.NewReader(nil)
// Also, we reuse the stacks across multiple function validations to reduce allocations.
vs := &stacks{}
// Non-static locals are gathered during validation and used in the down-stream compilation.
m.NonStaticLocals = make([]map[Index]struct{}, len(m.FunctionSection))
for idx, typeIndex := range m.FunctionSection {
if typeIndex >= typeCount {
return fmt.Errorf("invalid %s: type section index %d out of range", m.funcDesc(SectionIDFunction, Index(idx)), typeIndex)
Expand Down

0 comments on commit cb0c399

Please sign in to comment.