diff --git a/internal/engine/wazevo/frontend/frontend.go b/internal/engine/wazevo/frontend/frontend.go index 4e466d5fe8..6e2d40b1b6 100644 --- a/internal/engine/wazevo/frontend/frontend.go +++ b/internal/engine/wazevo/frontend/frontend.go @@ -53,29 +53,43 @@ type Compiler struct { br *bytes.Reader loweringState loweringState - knownSafeBounds []knownSafeBound + knownSafeBounds [] /* ssa.ValueID to */ knownSafeBound knownSafeBoundsSet []ssa.ValueID + knownSafeBoundsAtTheEndOfBlocks [] /* ssa.BlockID to */ knownSafeBoundsAtTheEndOfBlock + varLengthKnownSafeBoundWithIDPool wazevoapi.VarLengthPool[knownSafeBoundWithID] + execCtxPtrValue, moduleCtxPtrValue ssa.Value } -// knownSafeBound represents a known safe bound for a value. -type knownSafeBound struct { - // bound is a constant upper bound for the value. - bound uint64 - // absoluteAddr is the absolute address of the value. - absoluteAddr ssa.Value -} +type ( + // knownSafeBound represents a known safe bound for a value. + knownSafeBound struct { + // bound is a constant upper bound for the value. + bound uint64 + // absoluteAddr is the absolute address of the value. + absoluteAddr ssa.Value + } + // knownSafeBoundWithID is a knownSafeBound with the ID of the value. + knownSafeBoundWithID struct { + knownSafeBound + id ssa.ValueID + } + knownSafeBoundsAtTheEndOfBlock = wazevoapi.VarLength[knownSafeBoundWithID] +) + +var knownSafeBoundsAtTheEndOfBlockNil = wazevoapi.NewNilVarLength[knownSafeBoundWithID]() // NewFrontendCompiler returns a frontend Compiler. func NewFrontendCompiler(m *wasm.Module, ssaBuilder ssa.Builder, offset *wazevoapi.ModuleContextOffsetData, ensureTermination bool, listenerOn bool, sourceInfo bool) *Compiler { c := &Compiler{ - m: m, - ssaBuilder: ssaBuilder, - br: bytes.NewReader(nil), - offset: offset, - ensureTermination: ensureTermination, - needSourceOffsetInfo: sourceInfo, + m: m, + ssaBuilder: ssaBuilder, + br: bytes.NewReader(nil), + offset: offset, + ensureTermination: ensureTermination, + needSourceOffsetInfo: sourceInfo, + varLengthKnownSafeBoundWithIDPool: wazevoapi.NewVarLengthPool[knownSafeBoundWithID](), } c.declareSignatures(listenerOn) return c @@ -207,6 +221,8 @@ func (c *Compiler) Init(idx, typIndex wasm.Index, typ *wasm.FunctionType, localT c.wasmFunctionBodyOffsetInCodeSection = bodyOffsetInCodeSection c.needListener = needListener c.clearSafeBounds() + c.varLengthKnownSafeBoundWithIDPool.Reset() + c.knownSafeBoundsAtTheEndOfBlocks = c.knownSafeBoundsAtTheEndOfBlocks[:0] } // Note: this assumes 64-bit platform (I believe we won't have 32-bit backend ;)). @@ -445,6 +461,7 @@ func (c *Compiler) clearSafeBounds() { for _, v := range c.knownSafeBoundsSet { ptr := &c.knownSafeBounds[v] ptr.bound = 0 + ptr.absoluteAddr = ssa.ValueInvalid } c.knownSafeBoundsSet = c.knownSafeBoundsSet[:0] } @@ -470,3 +487,84 @@ func (c *Compiler) allocateVarLengthValues(vs ...ssa.Value) ssa.Values { } return args } + +func (c *Compiler) finalizeKnownSafeBoundsAtTheEndOfBlock(bID ssa.BasicBlockID) { + _bID := int(bID) + if l := len(c.knownSafeBoundsAtTheEndOfBlocks); _bID >= l { + c.knownSafeBoundsAtTheEndOfBlocks = append(c.knownSafeBoundsAtTheEndOfBlocks, + make([]knownSafeBoundsAtTheEndOfBlock, _bID+1-len(c.knownSafeBoundsAtTheEndOfBlocks))...) + for i := l; i < len(c.knownSafeBoundsAtTheEndOfBlocks); i++ { + c.knownSafeBoundsAtTheEndOfBlocks[i] = knownSafeBoundsAtTheEndOfBlockNil + } + } + p := &c.varLengthKnownSafeBoundWithIDPool + size := len(c.knownSafeBoundsSet) + allocated := c.varLengthKnownSafeBoundWithIDPool.Allocate(size) + for _, vID := range c.knownSafeBoundsSet { + kb := c.knownSafeBounds[vID] + allocated = allocated.Append(p, knownSafeBoundWithID{ + knownSafeBound: kb, + id: vID, + }) + } + c.knownSafeBoundsAtTheEndOfBlocks[bID] = allocated + c.clearSafeBounds() +} + +func (c *Compiler) initializeCurrentBlockKnownBounds() { + currentBlk := c.ssaBuilder.CurrentBlock() + switch preds := currentBlk.Preds(); preds { + case 0: + case 1: + pred := currentBlk.Pred(0).ID() + for _, kb := range c.getKnownSafeBoundsAtTheEndOfBlocks(pred).View() { + c.recordKnownSafeBound(kb.id, kb.bound, kb.absoluteAddr) + } + default: + primary := currentBlk.Pred(0).ID() + type mapVal struct { + kb knownSafeBoundWithID + count int + } + set := map[ssa.ValueID]mapVal{} + for _, kb := range c.getKnownSafeBoundsAtTheEndOfBlocks(primary).View() { + if kb.valid() { + set[kb.id] = mapVal{kb, 1} + } + } + + // If there are more than one predecessor, we need to find the intersection of the known safe bounds. + for i := 1; i < preds; i++ { + pred := currentBlk.Pred(i).ID() + for _, kb := range c.getKnownSafeBoundsAtTheEndOfBlocks(pred).View() { + if !kb.valid() { + continue + } + mv, ok := set[kb.id] + if !ok { + continue + } + mv.count++ + // Choose the lower bound. + if kb.bound < mv.kb.bound { + mv.kb = kb + } + set[kb.id] = mv + } + } + for _, mv := range set { + if mv.count == preds { + kb := mv.kb + // Absolute address cannot be used in the intersection since the value might be only defined in one of the predecessors. + c.recordKnownSafeBound(kb.id, kb.bound, ssa.ValueInvalid) + } + } + } +} + +func (c *Compiler) getKnownSafeBoundsAtTheEndOfBlocks(id ssa.BasicBlockID) knownSafeBoundsAtTheEndOfBlock { + if int(id) >= len(c.knownSafeBoundsAtTheEndOfBlocks) { + return knownSafeBoundsAtTheEndOfBlockNil + } + return c.knownSafeBoundsAtTheEndOfBlocks[id] +} diff --git a/internal/engine/wazevo/frontend/frontend_test.go b/internal/engine/wazevo/frontend/frontend_test.go index 51fbc422f7..39d0be5ba4 100644 --- a/internal/engine/wazevo/frontend/frontend_test.go +++ b/internal/engine/wazevo/frontend/frontend_test.go @@ -1,6 +1,7 @@ package frontend import ( + "fmt" "testing" "github.com/tetratelabs/wazero/api" @@ -2810,6 +2811,86 @@ blk0: (exec_ctx:i64, module_ctx:i64) exp: ` blk0: (exec_ctx:i64, module_ctx:i64) Jump blk_ret +`, + }, + { + name: "bounds check in if-else", + m: &wasm.Module{ + TypeSection: []wasm.FunctionType{{Params: []wasm.ValueType{wasm.ValueTypeI32}}}, + FunctionSection: []wasm.Index{0}, + CodeSection: []wasm.Code{{ + Body: []byte{ + wasm.OpcodeLocalGet, 0, + wasm.OpcodeI32Load, 0x2, 0x20, // alignment=2 (natural alignment) staticOffset=0x20 + wasm.OpcodeDrop, + + wasm.OpcodeLocalGet, 0, + wasm.OpcodeIf, 0x40, // blockSignature:vv. + /* */ wasm.OpcodeLocalGet, 0, + /* */ // This bound check should be removed since it's known to be in bounds. + /* */ wasm.OpcodeI32Load, 0x2, 0x10, // alignment=2 (natural alignment) staticOffset=0x20 + /* */ wasm.OpcodeDrop, + wasm.OpcodeElse, + /* */ wasm.OpcodeLocalGet, 0, + /* */ // This bound check should be removed since it's known to be in bounds. + /* */ wasm.OpcodeI32Load, 0x2, 0x10, // alignment=2 (natural alignment) staticOffset=0x20 + /* */ wasm.OpcodeDrop, + /* */ // This shouldn't be removed since it's not known to be in bounds. + /* */ wasm.OpcodeLocalGet, 0, + /* */ wasm.OpcodeI32Load, 0x2, 0x30, // alignment=2 (natural alignment) staticOffset=0x20 + /* */ wasm.OpcodeDrop, + /* */ // But this is known to be in bounds. + /* */ wasm.OpcodeLocalGet, 0, + /* */ wasm.OpcodeI32Load, 0x2, 0x25, // alignment=2 (natural alignment) staticOffset=0x20 + /* */ wasm.OpcodeDrop, + wasm.OpcodeEnd, + + // At this point, the known bound 0x20 should be retained. + wasm.OpcodeLocalGet, 0, + wasm.OpcodeI32Load, 0x2, 0x15, // alignment=2 (natural alignment) staticOffset=0x20 + wasm.OpcodeDrop, + + wasm.OpcodeEnd, + }, + }}, + MemorySection: &wasm.Memory{Min: 1}, + }, + features: api.CoreFeaturesV2, + exp: ` +blk0: (exec_ctx:i64, module_ctx:i64, v2:i32) + v3:i64 = Iconst_64 0x24 + v4:i64 = UExtend v2, 32->64 + v5:i64 = Uload32 module_ctx, 0x10 + v6:i64 = Iadd v4, v3 + v7:i32 = Icmp lt_u, v5, v6 + ExitIfTrue v7, exec_ctx, memory_out_of_bounds + v8:i64 = Load module_ctx, 0x8 + v9:i64 = Iadd v8, v4 + v10:i32 = Load v9, 0x20 + Brz v2, blk2 + Jump blk1 + +blk1: () <-- (blk0) + v11:i32 = Load v9, 0x10 + Jump blk3 + +blk2: () <-- (blk0) + v12:i32 = Load v9, 0x10 + v13:i64 = Iconst_64 0x34 + v14:i64 = UExtend v2, 32->64 + v15:i64 = Iadd v14, v13 + v16:i32 = Icmp lt_u, v5, v15 + ExitIfTrue v16, exec_ctx, memory_out_of_bounds + v17:i32 = Load v9, 0x30 + v18:i32 = Load v9, 0x25 + Jump blk3 + +blk3: () <-- (blk1,blk2) + v19:i64 = Load module_ctx, 0x8 + v20:i64 = UExtend v2, 32->64 + v21:i64 = Iadd v19, v20 + v22:i32 = Load v21, 0x15 + Jump blk_ret `, }, } { @@ -3001,7 +3082,7 @@ func TestCompiler_clearSafeBounds(t *testing.T) { c.knownSafeBoundsSet = []ssa.ValueID{0, 2, 5} c.clearSafeBounds() require.Equal(t, 0, len(c.knownSafeBoundsSet)) - require.Equal(t, []knownSafeBound{{}, {}, {}, {}, {}, {}}, c.knownSafeBounds) + require.Equal(t, []knownSafeBound{{absoluteAddr: ssa.ValueInvalid}, {}, {absoluteAddr: ssa.ValueInvalid}, {}, {}, {absoluteAddr: ssa.ValueInvalid}}, c.knownSafeBounds) } func TestCompiler_resetAbsoluteAddressInSafeBounds(t *testing.T) { @@ -3033,3 +3114,120 @@ func TestKnownSafeBound_valid(t *testing.T) { k.bound = 0 require.False(t, k.valid()) } + +func TestCompiler_finalizeKnownSafeBoundsAtTheEndOoBlock(t *testing.T) { + c := NewFrontendCompiler(&wasm.Module{}, ssa.NewBuilder(), nil, false, false, false) + blk := c.ssaBuilder.AllocateBasicBlock() + require.True(t, len(c.getKnownSafeBoundsAtTheEndOfBlocks(blk.ID()).View()) == 0) + c.ssaBuilder.SetCurrentBlock(blk) + c.knownSafeBoundsSet = []ssa.ValueID{0, 2, 5} + c.knownSafeBounds = []knownSafeBound{ + {bound: 1, absoluteAddr: ssa.Value(1)}, + {}, + {bound: 2, absoluteAddr: ssa.Value(2)}, + {}, + {}, + {bound: 3, absoluteAddr: ssa.Value(3)}, + } + c.finalizeKnownSafeBoundsAtTheEndOfBlock(blk.ID()) + require.True(t, len(c.knownSafeBoundsSet) == 0) + finalized := c.getKnownSafeBoundsAtTheEndOfBlocks(blk.ID()) + require.Equal(t, 3, len(finalized.View())) +} + +func TestCompiler_initializeCurrentBlockKnownBounds(t *testing.T) { + t.Run("single", func(t *testing.T) { + c := NewFrontendCompiler(&wasm.Module{}, ssa.NewBuilder(), nil, false, false, false) + builder := c.ssaBuilder + child := builder.AllocateBasicBlock() + { + parent := builder.AllocateBasicBlock() + builder.SetCurrentBlock(parent) + c.recordKnownSafeBound(1, 99, 9999) + c.recordKnownSafeBound(2, 150, 9999) + c.recordKnownSafeBound(5, 666, 54321) + builder.AllocateInstruction().AsJump(ssa.ValuesNil, child).Insert(builder) + c.finalizeKnownSafeBoundsAtTheEndOfBlock(parent.ID()) + } + + builder.SetCurrentBlock(child) + c.initializeCurrentBlockKnownBounds() + kb := c.getKnownSafeBound(1) + require.True(t, kb.valid()) + require.Equal(t, uint64(99), kb.bound) + require.Equal(t, ssa.Value(9999), kb.absoluteAddr) + kb = c.getKnownSafeBound(2) + require.True(t, kb.valid()) + require.Equal(t, uint64(150), kb.bound) + require.Equal(t, ssa.Value(9999), kb.absoluteAddr) + kb = c.getKnownSafeBound(5) + require.True(t, kb.valid()) + require.Equal(t, uint64(666), kb.bound) + require.Equal(t, ssa.Value(54321), kb.absoluteAddr) + }) + t.Run("multiple predecessors", func(t *testing.T) { + c := NewFrontendCompiler(&wasm.Module{}, ssa.NewBuilder(), nil, false, false, false) + builder := c.ssaBuilder + child := builder.AllocateBasicBlock() + { + p1 := builder.AllocateBasicBlock() + builder.SetCurrentBlock(p1) + c.recordKnownSafeBound(1, 99, 9999) + c.recordKnownSafeBound(2, 150, 9999) + c.recordKnownSafeBound(5, 666, 54321) + c.recordKnownSafeBound(592131, 666, 54321) + builder.AllocateInstruction().AsJump(ssa.ValuesNil, child).Insert(builder) + c.finalizeKnownSafeBoundsAtTheEndOfBlock(p1.ID()) + } + { + p2 := builder.AllocateBasicBlock() + builder.SetCurrentBlock(p2) + c.recordKnownSafeBound(1, 100, 999419) + c.recordKnownSafeBound(2, 4, 9991239) + c.recordKnownSafeBound(5, 555, 54341221) + c.recordKnownSafeBound(6, 666, 54321) + c.recordKnownSafeBound(7, 666, 54321) + builder.AllocateInstruction().AsJump(ssa.ValuesNil, child).Insert(builder) + c.finalizeKnownSafeBoundsAtTheEndOfBlock(p2.ID()) + } + { + p3 := builder.AllocateBasicBlock() + builder.SetCurrentBlock(p3) + c.recordKnownSafeBound(1, 1, 999419) + c.recordKnownSafeBound(2, 11111, 9991239) + c.recordKnownSafeBound(5, 5551231, 54341221) + c.recordKnownSafeBound(7, 666, 54321) + c.recordKnownSafeBound(60, 666, 54321) + builder.AllocateInstruction().AsJump(ssa.ValuesNil, child).Insert(builder) + c.finalizeKnownSafeBoundsAtTheEndOfBlock(p3.ID()) + } + + builder.SetCurrentBlock(child) + c.initializeCurrentBlockKnownBounds() + for _, tc := range []struct { + id ssa.ValueID + valid bool + bound uint64 + }{ + {id: 0, valid: false}, + {id: 1, valid: true, bound: 1}, + {id: 2, valid: true, bound: 4}, + {id: 3, valid: false}, + {id: 4, valid: false}, + {id: 5, valid: true, bound: 555}, + {id: 6, valid: false}, + {id: 7, valid: false}, + {id: 60, valid: false}, + {id: 592131, valid: false}, + } { + t.Run(fmt.Sprintf("id=%d", tc.id), func(t *testing.T) { + kb := c.getKnownSafeBound(tc.id) + require.Equal(t, tc.valid, kb.valid()) + if kb.valid() { + require.Equal(t, tc.bound, kb.bound) + require.Equal(t, ssa.ValueInvalid, kb.absoluteAddr) + } + }) + } + }) +} diff --git a/internal/engine/wazevo/frontend/lower.go b/internal/engine/wazevo/frontend/lower.go index 2e8ce6ae30..0c3a02b13f 100644 --- a/internal/engine/wazevo/frontend/lower.go +++ b/internal/engine/wazevo/frontend/lower.go @@ -161,7 +161,16 @@ func (c *Compiler) lowerBody(entryBlk ssa.BasicBlock) { }) for c.loweringState.pc < len(c.wasmFunctionBody) { + blkBeforeLowering := c.ssaBuilder.CurrentBlock() c.lowerCurrentOpcode() + blkAfterLowering := c.ssaBuilder.CurrentBlock() + if blkBeforeLowering != blkAfterLowering { + // In Wasm, once a block exits, that means we've done compiling the block. + // Therefore, we finalize the known bounds at the end of the block for the exiting block. + c.finalizeKnownSafeBoundsAtTheEndOfBlock(blkBeforeLowering.ID()) + // After that, we initialize the known bounds for the new compilation target block. + c.initializeCurrentBlockKnownBounds() + } } } @@ -1415,11 +1424,6 @@ func (c *Compiler) lowerCurrentOpcode() { builder.Seal(thenBlk) builder.Seal(elseBlk) case wasm.OpcodeElse: - // Reset the safe bounds since we are entering the Else block. - // TODO: we should be able to inherit the safe bounds from the parent block. So, right now, this means that - // else block is a little bit more slow than the then block. - c.clearSafeBounds() - ifctrl := state.ctrlPeekAt(0) if unreachable := state.unreachable; unreachable && state.unreachableDepth > 0 { // If it is currently in unreachable and is a nested if, @@ -1482,13 +1486,6 @@ func (c *Compiler) lowerCurrentOpcode() { builder.Seal(followingBlk) - if unreachable || followingBlk.Preds() != 1 { - // If we can reach this block without being unreachable, and it has only one predecessor, - // this means that we get here from the unique block contiguously. Therefore, we can - // keep using the same safe bounds information. Otherwise, we need to reset it. - c.clearSafeBounds() - } - // Ready to start translating the following block. c.switchTo(ctrl.originalStackLenWithoutParam, followingBlk) diff --git a/internal/engine/wazevo/ssa/instructions.go b/internal/engine/wazevo/ssa/instructions.go index 0c81429d5f..5d2ddad0c5 100644 --- a/internal/engine/wazevo/ssa/instructions.go +++ b/internal/engine/wazevo/ssa/instructions.go @@ -2119,10 +2119,11 @@ func (i *Instruction) BrTableData() (index Value, targets []BasicBlock) { } // AsJump initializes this instruction as a jump instruction with OpcodeJump. -func (i *Instruction) AsJump(vs Values, target BasicBlock) { +func (i *Instruction) AsJump(vs Values, target BasicBlock) *Instruction { i.opcode = OpcodeJump i.vs = vs i.blk = target + return i } // IsFallthroughJump returns true if this instruction is a fallthrough jump.