Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wazevo(frontend): stricter BCE part2 #2119

Merged
merged 2 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 112 additions & 14 deletions internal/engine/wazevo/frontend/frontend.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,29 +53,43 @@ type Compiler struct {
br *bytes.Reader
loweringState loweringState

knownSafeBounds []knownSafeBound
knownSafeBounds [] /* ssa.ValueID to */ knownSafeBound
knownSafeBoundsSet []ssa.ValueID

knownSafeBoundsAtTheEndOfBlocks [] /* ssa.BlockID to */ knownSafeBoundsAtTheEndOfBlock
varLengthKnownSafeBoundWithIDPool wazevoapi.VarLengthPool[knownSafeBoundWithID]

execCtxPtrValue, moduleCtxPtrValue ssa.Value
}

// knownSafeBound represents a known safe bound for a value.
type knownSafeBound struct {
// bound is a constant upper bound for the value.
bound uint64
// absoluteAddr is the absolute address of the value.
absoluteAddr ssa.Value
}
type (
// knownSafeBound represents a known safe bound for a value.
knownSafeBound struct {
// bound is a constant upper bound for the value.
// clearSafeBounds resets it to zero for every recorded value.
bound uint64
// absoluteAddr is the absolute address of the value.
// It is ssa.ValueInvalid when no absolute address is recorded for the value.
absoluteAddr ssa.Value
}
// knownSafeBoundWithID is a knownSafeBound with the ID of the value.
// It is the element type stored per basic block, so that the bound can be
// associated with its value again when a successor block is initialized.
knownSafeBoundWithID struct {
knownSafeBound
// id is the ID of the SSA value that the embedded knownSafeBound applies to.
id ssa.ValueID
}
// knownSafeBoundsAtTheEndOfBlock is the variable-length list of
// knownSafeBoundWithID that finalizeKnownSafeBoundsAtTheEndOfBlock stores for a basic block.
knownSafeBoundsAtTheEndOfBlock = wazevoapi.VarLength[knownSafeBoundWithID]
)

// knownSafeBoundsAtTheEndOfBlockNil is the placeholder list used for basic blocks that have no
// finalized known safe bounds: it fills newly grown slots of Compiler.knownSafeBoundsAtTheEndOfBlocks
// and is returned by getKnownSafeBoundsAtTheEndOfBlocks for block IDs that are out of range.
var knownSafeBoundsAtTheEndOfBlockNil = wazevoapi.NewNilVarLength[knownSafeBoundWithID]()

// NewFrontendCompiler returns a frontend Compiler.
func NewFrontendCompiler(m *wasm.Module, ssaBuilder ssa.Builder, offset *wazevoapi.ModuleContextOffsetData, ensureTermination bool, listenerOn bool, sourceInfo bool) *Compiler {
c := &Compiler{
m: m,
ssaBuilder: ssaBuilder,
br: bytes.NewReader(nil),
offset: offset,
ensureTermination: ensureTermination,
needSourceOffsetInfo: sourceInfo,
m: m,
ssaBuilder: ssaBuilder,
br: bytes.NewReader(nil),
offset: offset,
ensureTermination: ensureTermination,
needSourceOffsetInfo: sourceInfo,
varLengthKnownSafeBoundWithIDPool: wazevoapi.NewVarLengthPool[knownSafeBoundWithID](),
}
c.declareSignatures(listenerOn)
return c
Expand Down Expand Up @@ -207,6 +221,8 @@ func (c *Compiler) Init(idx, typIndex wasm.Index, typ *wasm.FunctionType, localT
c.wasmFunctionBodyOffsetInCodeSection = bodyOffsetInCodeSection
c.needListener = needListener
c.clearSafeBounds()
c.varLengthKnownSafeBoundWithIDPool.Reset()
c.knownSafeBoundsAtTheEndOfBlocks = c.knownSafeBoundsAtTheEndOfBlocks[:0]
}

// Note: this assumes 64-bit platform (I believe we won't have 32-bit backend ;)).
Expand Down Expand Up @@ -445,6 +461,7 @@ func (c *Compiler) clearSafeBounds() {
for _, v := range c.knownSafeBoundsSet {
ptr := &c.knownSafeBounds[v]
ptr.bound = 0
ptr.absoluteAddr = ssa.ValueInvalid
}
c.knownSafeBoundsSet = c.knownSafeBoundsSet[:0]
}
Expand All @@ -470,3 +487,84 @@ func (c *Compiler) allocateVarLengthValues(vs ...ssa.Value) ssa.Values {
}
return args
}

// finalizeKnownSafeBoundsAtTheEndOfBlock snapshots the known safe bounds that are currently
// recorded in the compiler and stores them as the bounds known at the end of the basic block bID.
// After the snapshot is stored, the compiler's current known safe bounds are cleared.
func (c *Compiler) finalizeKnownSafeBoundsAtTheEndOfBlock(bID ssa.BasicBlockID) {
	// Grow the per-block table until bID is addressable, marking every new slot as "nothing known".
	if have, need := len(c.knownSafeBoundsAtTheEndOfBlocks), int(bID)+1; need > have {
		c.knownSafeBoundsAtTheEndOfBlocks = append(c.knownSafeBoundsAtTheEndOfBlocks,
			make([]knownSafeBoundsAtTheEndOfBlock, need-have)...)
		for i := have; i < need; i++ {
			c.knownSafeBoundsAtTheEndOfBlocks[i] = knownSafeBoundsAtTheEndOfBlockNil
		}
	}

	// Copy each recorded bound, tagged with its value ID, into a pool-backed list.
	pool := &c.varLengthKnownSafeBoundWithIDPool
	snapshot := pool.Allocate(len(c.knownSafeBoundsSet))
	for _, id := range c.knownSafeBoundsSet {
		snapshot = snapshot.Append(pool, knownSafeBoundWithID{
			knownSafeBound: c.knownSafeBounds[id],
			id:             id,
		})
	}
	c.knownSafeBoundsAtTheEndOfBlocks[bID] = snapshot

	// The recorded bounds belong to the block that just ended; start the next one from scratch.
	c.clearSafeBounds()
}

// initializeCurrentBlockKnownBounds seeds the known safe bounds of the current block from the
// bounds that were finalized at the end of its predecessors:
//
//   - no predecessor: nothing is recorded.
//   - one predecessor: every bound of the predecessor is recorded as-is, including its absolute address.
//   - several predecessors: only values that have a valid bound in every predecessor are recorded,
//     using the smallest of the bounds and without an absolute address.
func (c *Compiler) initializeCurrentBlockKnownBounds() {
	currentBlk := c.ssaBuilder.CurrentBlock()
	switch preds := currentBlk.Preds(); preds {
	case 0:
	case 1:
		pred := currentBlk.Pred(0).ID()
		for _, kb := range c.getKnownSafeBoundsAtTheEndOfBlocks(pred).View() {
			c.recordKnownSafeBound(kb.id, kb.bound, kb.absoluteAddr)
		}
	default:
		// candidate tracks one value ID that is still eligible for the intersection.
		type candidate struct {
			kb knownSafeBoundWithID
			// count is the number of distinct predecessors that have a valid bound for the value.
			count int
			// lastPred is the index of the predecessor that most recently contributed to count.
			lastPred int
		}

		// The first predecessor defines the candidates: anything it does not know
		// cannot be part of the intersection.
		primary := c.getKnownSafeBoundsAtTheEndOfBlocks(currentBlk.Pred(0).ID()).View()
		set := make(map[ssa.ValueID]candidate, len(primary))
		for _, kb := range primary {
			if kb.valid() {
				set[kb.id] = candidate{kb: kb, count: 1, lastPred: 0}
			}
		}

		// When there is more than one predecessor, we need to find the intersection of the known safe bounds.
		for i := 1; i < preds; i++ {
			pred := currentBlk.Pred(i).ID()
			for _, kb := range c.getKnownSafeBoundsAtTheEndOfBlocks(pred).View() {
				if !kb.valid() {
					continue
				}
				cand, ok := set[kb.id]
				if !ok {
					continue
				}
				// Count each predecessor at most once per value. Otherwise an ID that appears
				// several times in one predecessor's list would inflate the count, which could
				// either hide a legitimate bound or make up for a predecessor lacking it.
				// NOTE(review): duplicates presumably arise when a bound is re-recorded for the same
				// value within a block — confirm against recordKnownSafeBound.
				if cand.lastPred != i {
					cand.lastPred = i
					cand.count++
				}
				// Choose the lower bound: only the smallest one is guaranteed on every incoming path.
				if kb.bound < cand.kb.bound {
					cand.kb = kb
				}
				set[kb.id] = cand
			}
		}

		// Record in the order of the primary predecessor's list rather than ranging over the map,
		// so that the order in which bounds are recorded is deterministic.
		for _, kb := range primary {
			cand, ok := set[kb.id]
			if !ok {
				continue
			}
			// Drop the entry so that a repeated ID in the primary list is recorded only once.
			delete(set, kb.id)
			if cand.count == preds {
				// Absolute address cannot be used in the intersection since the value might be only defined in one of the predecessors.
				c.recordKnownSafeBound(cand.kb.id, cand.kb.bound, ssa.ValueInvalid)
			}
		}
	}
}

// getKnownSafeBoundsAtTheEndOfBlocks returns the known safe bounds stored for the end of the basic
// block with the given id, or knownSafeBoundsAtTheEndOfBlockNil when the id is beyond the table,
// i.e. nothing has been stored for it.
func (c *Compiler) getKnownSafeBoundsAtTheEndOfBlocks(id ssa.BasicBlockID) knownSafeBoundsAtTheEndOfBlock {
	if idx := int(id); idx < len(c.knownSafeBoundsAtTheEndOfBlocks) {
		return c.knownSafeBoundsAtTheEndOfBlocks[idx]
	}
	return knownSafeBoundsAtTheEndOfBlockNil
}
200 changes: 199 additions & 1 deletion internal/engine/wazevo/frontend/frontend_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package frontend

import (
"fmt"
"testing"

"github.com/tetratelabs/wazero/api"
Expand Down Expand Up @@ -2810,6 +2811,86 @@ blk0: (exec_ctx:i64, module_ctx:i64)
exp: `
blk0: (exec_ctx:i64, module_ctx:i64)
Jump blk_ret
`,
},
{
name: "bounds check in if-else",
m: &wasm.Module{
TypeSection: []wasm.FunctionType{{Params: []wasm.ValueType{wasm.ValueTypeI32}}},
FunctionSection: []wasm.Index{0},
CodeSection: []wasm.Code{{
Body: []byte{
wasm.OpcodeLocalGet, 0,
wasm.OpcodeI32Load, 0x2, 0x20, // alignment=2 (natural alignment) staticOffset=0x20
wasm.OpcodeDrop,

wasm.OpcodeLocalGet, 0,
wasm.OpcodeIf, 0x40, // blockSignature:vv.
/* */ wasm.OpcodeLocalGet, 0,
/* */ // This bound check should be removed since it's known to be in bounds.
/* */ wasm.OpcodeI32Load, 0x2, 0x10, // alignment=2 (natural alignment) staticOffset=0x10
/* */ wasm.OpcodeDrop,
wasm.OpcodeElse,
/* */ wasm.OpcodeLocalGet, 0,
/* */ // This bound check should be removed since it's known to be in bounds.
/* */ wasm.OpcodeI32Load, 0x2, 0x10, // alignment=2 (natural alignment) staticOffset=0x10
/* */ wasm.OpcodeDrop,
/* */ // This shouldn't be removed since it's not known to be in bounds.
/* */ wasm.OpcodeLocalGet, 0,
/* */ wasm.OpcodeI32Load, 0x2, 0x30, // alignment=2 (natural alignment) staticOffset=0x30
/* */ wasm.OpcodeDrop,
/* */ // But this is known to be in bounds.
/* */ wasm.OpcodeLocalGet, 0,
/* */ wasm.OpcodeI32Load, 0x2, 0x25, // alignment=2 (natural alignment) staticOffset=0x25
/* */ wasm.OpcodeDrop,
wasm.OpcodeEnd,

// At this point, the known bound 0x20 should be retained.
wasm.OpcodeLocalGet, 0,
wasm.OpcodeI32Load, 0x2, 0x15, // alignment=2 (natural alignment) staticOffset=0x15
wasm.OpcodeDrop,

wasm.OpcodeEnd,
},
}},
MemorySection: &wasm.Memory{Min: 1},
},
features: api.CoreFeaturesV2,
exp: `
blk0: (exec_ctx:i64, module_ctx:i64, v2:i32)
v3:i64 = Iconst_64 0x24
v4:i64 = UExtend v2, 32->64
v5:i64 = Uload32 module_ctx, 0x10
v6:i64 = Iadd v4, v3
v7:i32 = Icmp lt_u, v5, v6
ExitIfTrue v7, exec_ctx, memory_out_of_bounds
v8:i64 = Load module_ctx, 0x8
v9:i64 = Iadd v8, v4
v10:i32 = Load v9, 0x20
Brz v2, blk2
Jump blk1

blk1: () <-- (blk0)
v11:i32 = Load v9, 0x10
Jump blk3

blk2: () <-- (blk0)
v12:i32 = Load v9, 0x10
v13:i64 = Iconst_64 0x34
v14:i64 = UExtend v2, 32->64
v15:i64 = Iadd v14, v13
v16:i32 = Icmp lt_u, v5, v15
ExitIfTrue v16, exec_ctx, memory_out_of_bounds
v17:i32 = Load v9, 0x30
v18:i32 = Load v9, 0x25
Jump blk3

blk3: () <-- (blk1,blk2)
v19:i64 = Load module_ctx, 0x8
v20:i64 = UExtend v2, 32->64
v21:i64 = Iadd v19, v20
v22:i32 = Load v21, 0x15
Jump blk_ret
`,
},
} {
Expand Down Expand Up @@ -3001,7 +3082,7 @@ func TestCompiler_clearSafeBounds(t *testing.T) {
c.knownSafeBoundsSet = []ssa.ValueID{0, 2, 5}
c.clearSafeBounds()
require.Equal(t, 0, len(c.knownSafeBoundsSet))
require.Equal(t, []knownSafeBound{{}, {}, {}, {}, {}, {}}, c.knownSafeBounds)
require.Equal(t, []knownSafeBound{{absoluteAddr: ssa.ValueInvalid}, {}, {absoluteAddr: ssa.ValueInvalid}, {}, {}, {absoluteAddr: ssa.ValueInvalid}}, c.knownSafeBounds)
}

func TestCompiler_resetAbsoluteAddressInSafeBounds(t *testing.T) {
Expand Down Expand Up @@ -3033,3 +3114,120 @@ func TestKnownSafeBound_valid(t *testing.T) {
k.bound = 0
require.False(t, k.valid())
}

// TestCompiler_finalizeKnownSafeBoundsAtTheEndOoBlock verifies that finalizing a block moves the
// currently recorded known safe bounds into the per-block storage and clears the current set.
func TestCompiler_finalizeKnownSafeBoundsAtTheEndOoBlock(t *testing.T) {
	c := NewFrontendCompiler(&wasm.Module{}, ssa.NewBuilder(), nil, false, false, false)
	blk := c.ssaBuilder.AllocateBasicBlock()
	blkID := blk.ID()

	// Nothing has been finalized for the block yet.
	require.Equal(t, 0, len(c.getKnownSafeBoundsAtTheEndOfBlocks(blkID).View()))
	c.ssaBuilder.SetCurrentBlock(blk)

	// Hand-craft the compiler state: values 0, 2 and 5 have bounds, the others are untouched.
	recorded := make([]knownSafeBound, 6)
	recorded[0] = knownSafeBound{bound: 1, absoluteAddr: ssa.Value(1)}
	recorded[2] = knownSafeBound{bound: 2, absoluteAddr: ssa.Value(2)}
	recorded[5] = knownSafeBound{bound: 3, absoluteAddr: ssa.Value(3)}
	c.knownSafeBounds = recorded
	c.knownSafeBoundsSet = []ssa.ValueID{0, 2, 5}

	c.finalizeKnownSafeBoundsAtTheEndOfBlock(blkID)

	// The current set is emptied, and all three bounds are now attached to the block.
	require.Equal(t, 0, len(c.knownSafeBoundsSet))
	stored := c.getKnownSafeBoundsAtTheEndOfBlocks(blkID)
	require.Equal(t, 3, len(stored.View()))
}

func TestCompiler_initializeCurrentBlockKnownBounds(t *testing.T) {
t.Run("single", func(t *testing.T) {
c := NewFrontendCompiler(&wasm.Module{}, ssa.NewBuilder(), nil, false, false, false)
builder := c.ssaBuilder
child := builder.AllocateBasicBlock()
{
parent := builder.AllocateBasicBlock()
builder.SetCurrentBlock(parent)
c.recordKnownSafeBound(1, 99, 9999)
c.recordKnownSafeBound(2, 150, 9999)
c.recordKnownSafeBound(5, 666, 54321)
builder.AllocateInstruction().AsJump(ssa.ValuesNil, child).Insert(builder)
c.finalizeKnownSafeBoundsAtTheEndOfBlock(parent.ID())
}

builder.SetCurrentBlock(child)
c.initializeCurrentBlockKnownBounds()
kb := c.getKnownSafeBound(1)
require.True(t, kb.valid())
require.Equal(t, uint64(99), kb.bound)
require.Equal(t, ssa.Value(9999), kb.absoluteAddr)
kb = c.getKnownSafeBound(2)
require.True(t, kb.valid())
require.Equal(t, uint64(150), kb.bound)
require.Equal(t, ssa.Value(9999), kb.absoluteAddr)
kb = c.getKnownSafeBound(5)
require.True(t, kb.valid())
require.Equal(t, uint64(666), kb.bound)
require.Equal(t, ssa.Value(54321), kb.absoluteAddr)
})
t.Run("multiple predecessors", func(t *testing.T) {
c := NewFrontendCompiler(&wasm.Module{}, ssa.NewBuilder(), nil, false, false, false)
builder := c.ssaBuilder
child := builder.AllocateBasicBlock()
{
p1 := builder.AllocateBasicBlock()
builder.SetCurrentBlock(p1)
c.recordKnownSafeBound(1, 99, 9999)
c.recordKnownSafeBound(2, 150, 9999)
c.recordKnownSafeBound(5, 666, 54321)
c.recordKnownSafeBound(592131, 666, 54321)
builder.AllocateInstruction().AsJump(ssa.ValuesNil, child).Insert(builder)
c.finalizeKnownSafeBoundsAtTheEndOfBlock(p1.ID())
}
{
p2 := builder.AllocateBasicBlock()
builder.SetCurrentBlock(p2)
c.recordKnownSafeBound(1, 100, 999419)
c.recordKnownSafeBound(2, 4, 9991239)
c.recordKnownSafeBound(5, 555, 54341221)
c.recordKnownSafeBound(6, 666, 54321)
c.recordKnownSafeBound(7, 666, 54321)
builder.AllocateInstruction().AsJump(ssa.ValuesNil, child).Insert(builder)
c.finalizeKnownSafeBoundsAtTheEndOfBlock(p2.ID())
}
{
p3 := builder.AllocateBasicBlock()
builder.SetCurrentBlock(p3)
c.recordKnownSafeBound(1, 1, 999419)
c.recordKnownSafeBound(2, 11111, 9991239)
c.recordKnownSafeBound(5, 5551231, 54341221)
c.recordKnownSafeBound(7, 666, 54321)
c.recordKnownSafeBound(60, 666, 54321)
builder.AllocateInstruction().AsJump(ssa.ValuesNil, child).Insert(builder)
c.finalizeKnownSafeBoundsAtTheEndOfBlock(p3.ID())
}

builder.SetCurrentBlock(child)
c.initializeCurrentBlockKnownBounds()
for _, tc := range []struct {
id ssa.ValueID
valid bool
bound uint64
}{
{id: 0, valid: false},
{id: 1, valid: true, bound: 1},
{id: 2, valid: true, bound: 4},
{id: 3, valid: false},
{id: 4, valid: false},
{id: 5, valid: true, bound: 555},
{id: 6, valid: false},
{id: 7, valid: false},
{id: 60, valid: false},
{id: 592131, valid: false},
} {
t.Run(fmt.Sprintf("id=%d", tc.id), func(t *testing.T) {
kb := c.getKnownSafeBound(tc.id)
require.Equal(t, tc.valid, kb.valid())
if kb.valid() {
require.Equal(t, tc.bound, kb.bound)
require.Equal(t, ssa.ValueInvalid, kb.absoluteAddr)
}
})
}
})
}
Loading
Loading