diff --git a/internal/engine/wazevo/frontend/frontend.go b/internal/engine/wazevo/frontend/frontend.go
index 1c54c8b457..ac92a37607 100644
--- a/internal/engine/wazevo/frontend/frontend.go
+++ b/internal/engine/wazevo/frontend/frontend.go
@@ -50,9 +50,17 @@ type Compiler struct {
 	br            *bytes.Reader
 	loweringState loweringState
+	knownSafeBounds    []knownSafeBound
+	knownSafeBoundsSet []ssa.ValueID
+
 	execCtxPtrValue, moduleCtxPtrValue ssa.Value
 }
 
+type knownSafeBound struct {
+	bound        uint64
+	absoluteAddr ssa.Value
+}
+
 // NewFrontendCompiler returns a frontend Compiler.
 func NewFrontendCompiler(m *wasm.Module, ssaBuilder ssa.Builder, offset *wazevoapi.ModuleContextOffsetData, ensureTermination bool, listenerOn bool, sourceInfo bool) *Compiler {
 	c := &Compiler{
@@ -354,3 +362,42 @@ func SignatureForListener(wasmSig *wasm.FunctionType) (*ssa.Signature, *ssa.Sign
 	}
 	return beforeSig, afterSig
 }
+
+// getKnownSafeBound returns the knownSafeBound recorded for the given value, or nil if none has been recorded.
+func (c *Compiler) getKnownSafeBound(v ssa.ValueID) *knownSafeBound {
+	if int(v) >= len(c.knownSafeBounds) {
+		return nil
+	}
+	return &c.knownSafeBounds[v]
+}
+
+// recordKnownSafeBound records the given safe bound for the given value.
+func (c *Compiler) recordKnownSafeBound(v ssa.ValueID, safeBound uint64, absoluteAddr ssa.Value) {
+	if int(v) >= len(c.knownSafeBounds) {
+		c.knownSafeBounds = append(c.knownSafeBounds, make([]knownSafeBound, v+1)...)
+	}
+
+	if existing := c.knownSafeBounds[v]; existing.bound == 0 {
+		c.knownSafeBounds[v] = knownSafeBound{
+			bound:        safeBound,
+			absoluteAddr: absoluteAddr,
+		}
+		c.knownSafeBoundsSet = append(c.knownSafeBoundsSet, v)
+	} else if safeBound > existing.bound {
+		c.knownSafeBounds[v].bound = safeBound
+	}
+}
+
+// clearSafeBounds clears the known safe bounds. This must be called
+// after the compilation of each block.
+func (c *Compiler) clearSafeBounds() { + for _, v := range c.knownSafeBoundsSet { + ptr := &c.knownSafeBounds[v] + ptr.bound = 0 + } + c.knownSafeBoundsSet = c.knownSafeBoundsSet[:0] +} + +func (k *knownSafeBound) valid() bool { + return k != nil && k.bound > 0 +} diff --git a/internal/engine/wazevo/frontend/frontend_test.go b/internal/engine/wazevo/frontend/frontend_test.go index f50d3218cc..34d17ec400 100644 --- a/internal/engine/wazevo/frontend/frontend_test.go +++ b/internal/engine/wazevo/frontend/frontend_test.go @@ -1019,14 +1019,8 @@ blk0: (exec_ctx:i64, module_ctx:i64, v2:i32, v3:i32) v9:i64 = Load module_ctx, 0x8 v10:i64 = Iadd v9, v5 Store v3, v10, 0x0 - v11:i64 = Iconst_64 0x4 - v12:i64 = UExtend v2, 32->64 - v13:i64 = Iadd v12, v11 - v14:i32 = Icmp lt_u, v6, v13 - ExitIfTrue v14, exec_ctx, memory_out_of_bounds - v15:i64 = Iadd v9, v12 - v16:i32 = Load v15, 0x0 - Jump blk_ret, v16 + v11:i32 = Load v10, 0x0 + Jump blk_ret, v11 `, }, { @@ -1142,191 +1136,44 @@ blk0: (exec_ctx:i64, module_ctx:i64, v2:i32) v13:i64 = Iadd v12, v11 v14:i32 = Icmp lt_u, v5, v13 ExitIfTrue v14, exec_ctx, memory_out_of_bounds - v15:i64 = Iadd v8, v12 - v16:i64 = Load v15, 0x0 - v17:i64 = Iconst_64 0x4 - v18:i64 = UExtend v2, 32->64 - v19:i64 = Iadd v18, v17 - v20:i32 = Icmp lt_u, v5, v19 - ExitIfTrue v20, exec_ctx, memory_out_of_bounds - v21:i64 = Iadd v8, v18 - v22:f32 = Load v21, 0x0 - v23:i64 = Iconst_64 0x8 + v15:i64 = Load v9, 0x0 + v16:f32 = Load v9, 0x0 + v17:f64 = Load v9, 0x0 + v18:i64 = Iconst_64 0x13 + v19:i64 = UExtend v2, 32->64 + v20:i64 = Iadd v19, v18 + v21:i32 = Icmp lt_u, v5, v20 + ExitIfTrue v21, exec_ctx, memory_out_of_bounds + v22:i32 = Load v9, 0xf + v23:i64 = Iconst_64 0x17 v24:i64 = UExtend v2, 32->64 v25:i64 = Iadd v24, v23 v26:i32 = Icmp lt_u, v5, v25 ExitIfTrue v26, exec_ctx, memory_out_of_bounds - v27:i64 = Iadd v8, v24 - v28:f64 = Load v27, 0x0 - v29:i64 = Iconst_64 0x13 - v30:i64 = UExtend v2, 32->64 - v31:i64 = Iadd v30, v29 - v32:i32 = Icmp lt_u, v5, v31 - ExitIfTrue v32, exec_ctx, memory_out_of_bounds - v33:i64 = Iadd v8, v30 - v34:i32 = Load v33, 0xf - v35:i64 = Iconst_64 0x17 - v36:i64 = UExtend v2, 32->64 - v37:i64 = Iadd v36, v35 - v38:i32 = Icmp lt_u, v5, v37 - ExitIfTrue v38, exec_ctx, memory_out_of_bounds - v39:i64 = Iadd v8, v36 - v40:i64 = Load v39, 0xf - v41:i64 = Iconst_64 0x13 - v42:i64 = UExtend v2, 32->64 - v43:i64 = Iadd v42, v41 - v44:i32 = Icmp lt_u, v5, v43 - ExitIfTrue v44, exec_ctx, memory_out_of_bounds - v45:i64 = Iadd v8, v42 - v46:f32 = Load v45, 0xf - v47:i64 = Iconst_64 0x17 - v48:i64 = UExtend v2, 32->64 - v49:i64 = Iadd v48, v47 - v50:i32 = Icmp lt_u, v5, v49 - ExitIfTrue v50, exec_ctx, memory_out_of_bounds - v51:i64 = Iadd v8, v48 - v52:f64 = Load v51, 0xf - v53:i64 = Iconst_64 0x1 - v54:i64 = UExtend v2, 32->64 - v55:i64 = Iadd v54, v53 - v56:i32 = Icmp lt_u, v5, v55 - ExitIfTrue v56, exec_ctx, memory_out_of_bounds - v57:i64 = Iadd v8, v54 - v58:i32 = Sload8 v57, 0x0 - v59:i64 = Iconst_64 0x10 - v60:i64 = UExtend v2, 32->64 - v61:i64 = Iadd v60, v59 - v62:i32 = Icmp lt_u, v5, v61 - ExitIfTrue v62, exec_ctx, memory_out_of_bounds - v63:i64 = Iadd v8, v60 - v64:i32 = Sload8 v63, 0xf - v65:i64 = Iconst_64 0x1 - v66:i64 = UExtend v2, 32->64 - v67:i64 = Iadd v66, v65 - v68:i32 = Icmp lt_u, v5, v67 - ExitIfTrue v68, exec_ctx, memory_out_of_bounds - v69:i64 = Iadd v8, v66 - v70:i32 = Uload8 v69, 0x0 - v71:i64 = Iconst_64 0x10 - v72:i64 = UExtend v2, 32->64 - v73:i64 = Iadd v72, v71 - v74:i32 = Icmp lt_u, v5, v73 - ExitIfTrue v74, exec_ctx, 
memory_out_of_bounds - v75:i64 = Iadd v8, v72 - v76:i32 = Uload8 v75, 0xf - v77:i64 = Iconst_64 0x2 - v78:i64 = UExtend v2, 32->64 - v79:i64 = Iadd v78, v77 - v80:i32 = Icmp lt_u, v5, v79 - ExitIfTrue v80, exec_ctx, memory_out_of_bounds - v81:i64 = Iadd v8, v78 - v82:i32 = Sload16 v81, 0x0 - v83:i64 = Iconst_64 0x11 - v84:i64 = UExtend v2, 32->64 - v85:i64 = Iadd v84, v83 - v86:i32 = Icmp lt_u, v5, v85 - ExitIfTrue v86, exec_ctx, memory_out_of_bounds - v87:i64 = Iadd v8, v84 - v88:i32 = Sload16 v87, 0xf - v89:i64 = Iconst_64 0x2 - v90:i64 = UExtend v2, 32->64 - v91:i64 = Iadd v90, v89 - v92:i32 = Icmp lt_u, v5, v91 - ExitIfTrue v92, exec_ctx, memory_out_of_bounds - v93:i64 = Iadd v8, v90 - v94:i32 = Uload16 v93, 0x0 - v95:i64 = Iconst_64 0x11 - v96:i64 = UExtend v2, 32->64 - v97:i64 = Iadd v96, v95 - v98:i32 = Icmp lt_u, v5, v97 - ExitIfTrue v98, exec_ctx, memory_out_of_bounds - v99:i64 = Iadd v8, v96 - v100:i32 = Uload16 v99, 0xf - v101:i64 = Iconst_64 0x1 - v102:i64 = UExtend v2, 32->64 - v103:i64 = Iadd v102, v101 - v104:i32 = Icmp lt_u, v5, v103 - ExitIfTrue v104, exec_ctx, memory_out_of_bounds - v105:i64 = Iadd v8, v102 - v106:i64 = Sload8 v105, 0x0 - v107:i64 = Iconst_64 0x10 - v108:i64 = UExtend v2, 32->64 - v109:i64 = Iadd v108, v107 - v110:i32 = Icmp lt_u, v5, v109 - ExitIfTrue v110, exec_ctx, memory_out_of_bounds - v111:i64 = Iadd v8, v108 - v112:i64 = Sload8 v111, 0xf - v113:i64 = Iconst_64 0x1 - v114:i64 = UExtend v2, 32->64 - v115:i64 = Iadd v114, v113 - v116:i32 = Icmp lt_u, v5, v115 - ExitIfTrue v116, exec_ctx, memory_out_of_bounds - v117:i64 = Iadd v8, v114 - v118:i64 = Uload8 v117, 0x0 - v119:i64 = Iconst_64 0x10 - v120:i64 = UExtend v2, 32->64 - v121:i64 = Iadd v120, v119 - v122:i32 = Icmp lt_u, v5, v121 - ExitIfTrue v122, exec_ctx, memory_out_of_bounds - v123:i64 = Iadd v8, v120 - v124:i64 = Uload8 v123, 0xf - v125:i64 = Iconst_64 0x2 - v126:i64 = UExtend v2, 32->64 - v127:i64 = Iadd v126, v125 - v128:i32 = Icmp lt_u, v5, v127 - ExitIfTrue v128, exec_ctx, memory_out_of_bounds - v129:i64 = Iadd v8, v126 - v130:i64 = Sload16 v129, 0x0 - v131:i64 = Iconst_64 0x11 - v132:i64 = UExtend v2, 32->64 - v133:i64 = Iadd v132, v131 - v134:i32 = Icmp lt_u, v5, v133 - ExitIfTrue v134, exec_ctx, memory_out_of_bounds - v135:i64 = Iadd v8, v132 - v136:i64 = Sload16 v135, 0xf - v137:i64 = Iconst_64 0x2 - v138:i64 = UExtend v2, 32->64 - v139:i64 = Iadd v138, v137 - v140:i32 = Icmp lt_u, v5, v139 - ExitIfTrue v140, exec_ctx, memory_out_of_bounds - v141:i64 = Iadd v8, v138 - v142:i64 = Uload16 v141, 0x0 - v143:i64 = Iconst_64 0x11 - v144:i64 = UExtend v2, 32->64 - v145:i64 = Iadd v144, v143 - v146:i32 = Icmp lt_u, v5, v145 - ExitIfTrue v146, exec_ctx, memory_out_of_bounds - v147:i64 = Iadd v8, v144 - v148:i64 = Uload16 v147, 0xf - v149:i64 = Iconst_64 0x4 - v150:i64 = UExtend v2, 32->64 - v151:i64 = Iadd v150, v149 - v152:i32 = Icmp lt_u, v5, v151 - ExitIfTrue v152, exec_ctx, memory_out_of_bounds - v153:i64 = Iadd v8, v150 - v154:i64 = Sload32 v153, 0x0 - v155:i64 = Iconst_64 0x13 - v156:i64 = UExtend v2, 32->64 - v157:i64 = Iadd v156, v155 - v158:i32 = Icmp lt_u, v5, v157 - ExitIfTrue v158, exec_ctx, memory_out_of_bounds - v159:i64 = Iadd v8, v156 - v160:i64 = Sload32 v159, 0xf - v161:i64 = Iconst_64 0x4 - v162:i64 = UExtend v2, 32->64 - v163:i64 = Iadd v162, v161 - v164:i32 = Icmp lt_u, v5, v163 - ExitIfTrue v164, exec_ctx, memory_out_of_bounds - v165:i64 = Iadd v8, v162 - v166:i64 = Uload32 v165, 0x0 - v167:i64 = Iconst_64 0x13 - v168:i64 = UExtend v2, 32->64 - v169:i64 = Iadd v168, v167 
- v170:i32 = Icmp lt_u, v5, v169 - ExitIfTrue v170, exec_ctx, memory_out_of_bounds - v171:i64 = Iadd v8, v168 - v172:i64 = Uload32 v171, 0xf - Jump blk_ret, v10, v16, v22, v28, v34, v40, v46, v52, v58, v64, v70, v76, v82, v88, v94, v100, v106, v112, v118, v124, v130, v136, v142, v148, v154, v160, v166, v172 + v27:i64 = Load v9, 0xf + v28:f32 = Load v9, 0xf + v29:f64 = Load v9, 0xf + v30:i32 = Sload8 v9, 0x0 + v31:i32 = Sload8 v9, 0xf + v32:i32 = Uload8 v9, 0x0 + v33:i32 = Uload8 v9, 0xf + v34:i32 = Sload16 v9, 0x0 + v35:i32 = Sload16 v9, 0xf + v36:i32 = Uload16 v9, 0x0 + v37:i32 = Uload16 v9, 0xf + v38:i64 = Sload8 v9, 0x0 + v39:i64 = Sload8 v9, 0xf + v40:i64 = Uload8 v9, 0x0 + v41:i64 = Uload8 v9, 0xf + v42:i64 = Sload16 v9, 0x0 + v43:i64 = Sload16 v9, 0xf + v44:i64 = Uload16 v9, 0x0 + v45:i64 = Uload16 v9, 0xf + v46:i64 = Sload32 v9, 0x0 + v47:i64 = Sload32 v9, 0xf + v48:i64 = Uload32 v9, 0x0 + v49:i64 = Uload32 v9, 0xf + Jump blk_ret, v10, v15, v16, v17, v22, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49 `, }, { @@ -1934,3 +1781,49 @@ func TestCompiler_declareSignatures(t *testing.T) { } }) } + +func TestCompiler_recordKnownSafeBound(t *testing.T) { + c := &Compiler{} + c.recordKnownSafeBound(1, 99, 9999) + require.Equal(t, 1, len(c.knownSafeBoundsSet)) + require.True(t, c.getKnownSafeBound(1).valid()) + require.Equal(t, uint64(99), c.getKnownSafeBound(1).bound) + require.Equal(t, ssa.Value(9999), c.getKnownSafeBound(1).absoluteAddr) + + c.recordKnownSafeBound(1, 150, 9999) + require.Equal(t, 1, len(c.knownSafeBoundsSet)) + require.Equal(t, uint64(150), c.getKnownSafeBound(1).bound) + + c.recordKnownSafeBound(5, 666, 54321) + require.Equal(t, 2, len(c.knownSafeBoundsSet)) + require.Equal(t, uint64(666), c.getKnownSafeBound(5).bound) + require.Equal(t, ssa.Value(54321), c.getKnownSafeBound(5).absoluteAddr) +} + +func TestCompiler_getKnownSafeBound(t *testing.T) { + c := &Compiler{ + knownSafeBounds: []knownSafeBound{ + {}, {bound: 2134}, + }, + } + require.Nil(t, c.getKnownSafeBound(5)) + require.Nil(t, c.getKnownSafeBound(12345)) + require.False(t, c.getKnownSafeBound(0).valid()) + require.True(t, c.getKnownSafeBound(1).valid()) +} + +func TestCompiler_clearSafeBounds(t *testing.T) { + c := &Compiler{} + c.knownSafeBounds = []knownSafeBound{{bound: 1}, {}, {bound: 2}, {}, {}, {bound: 3}} + c.knownSafeBoundsSet = []ssa.ValueID{0, 2, 5} + c.clearSafeBounds() + require.Equal(t, 0, len(c.knownSafeBoundsSet)) + require.Equal(t, []knownSafeBound{{}, {}, {}, {}, {}, {}}, c.knownSafeBounds) +} + +func TestKnownSafeBound_valid(t *testing.T) { + k := &knownSafeBound{bound: 10, absoluteAddr: 12345} + require.True(t, k.valid()) + k.bound = 0 + require.False(t, k.valid()) +} diff --git a/internal/engine/wazevo/frontend/lower.go b/internal/engine/wazevo/frontend/lower.go index 0816910e5c..2b2de518eb 100644 --- a/internal/engine/wazevo/frontend/lower.go +++ b/internal/engine/wazevo/frontend/lower.go @@ -1416,6 +1416,8 @@ func (c *Compiler) lowerCurrentOpcode() { builder.Seal(thenBlk) builder.Seal(elseBlk) case wasm.OpcodeElse: + c.clearSafeBounds() // Reset the safe bounds since we are entering the Else block. 
+
 		ifctrl := state.ctrlPeekAt(0)
 		if unreachable := state.unreachable; unreachable && state.unreachableDepth > 0 {
 			// If it is currently in unreachable and is a nested if,
@@ -1443,6 +1445,8 @@
 		builder.SetCurrentBlock(elseBlk)
 
 	case wasm.OpcodeEnd:
+		c.clearSafeBounds() // Reset the safe bounds since we are exiting the block.
+
 		if state.unreachableDepth > 0 {
 			state.unreachableDepth--
 			break
@@ -3368,24 +3372,35 @@ func (c *Compiler) lowerCallIndirect(typeIndex, tableIndex uint32) {
 
 // memOpSetup inserts the bounds check and calculates the address of the memory operation (loads/stores).
 func (c *Compiler) memOpSetup(baseAddr ssa.Value, constOffset, operationSizeInBytes uint64) (address ssa.Value) {
+	address = ssa.ValueInvalid
 	builder := c.ssaBuilder
 
+	baseAddrID := baseAddr.ID()
 	ceil := constOffset + operationSizeInBytes
+	if known := c.getKnownSafeBound(baseAddrID); known.valid() {
+		// Reuse the already-computed absolute address even when the recorded bound does not cover this access;
+		// in that case only the bounds check below is re-emitted.
+		address = known.absoluteAddr
+		if ceil <= known.bound {
+			return
+		}
+	}
+
 	ceilConst := builder.AllocateInstruction()
 	ceilConst.AsIconst64(ceil)
 	builder.InsertInstruction(ceilConst)
 
 	// We calculate the offset in 64-bit space.
-	extBaseAddr := builder.AllocateInstruction()
-	extBaseAddr.AsUExtend(baseAddr, 32, 64)
-	builder.InsertInstruction(extBaseAddr)
+	extBaseAddr := builder.AllocateInstruction().
+		AsUExtend(baseAddr, 32, 64).
+		Insert(builder).
+		Return()
 
 	// Note: memLen is already zero extended to 64-bit space at the load time.
 	memLen := c.getMemoryLenValue(false)
 
 	// baseAddrPlusCeil = baseAddr + ceil
 	baseAddrPlusCeil := builder.AllocateInstruction()
-	baseAddrPlusCeil.AsIadd(extBaseAddr.Return(), ceilConst.Return())
+	baseAddrPlusCeil.AsIadd(extBaseAddr, ceilConst.Return())
 	builder.InsertInstruction(baseAddrPlusCeil)
 
 	// Check for out of bounds memory access: `memLen >= baseAddrPlusCeil`.
@@ -3397,11 +3412,15 @@ func (c *Compiler) memOpSetup(baseAddr ssa.Value, constOffset, operationSizeInBy
 	builder.InsertInstruction(exitIfNZ)
 
 	// Load the value from memBase + extBaseAddr.
-	memBase := c.getMemoryBaseValue(false)
-	addrCalc := builder.AllocateInstruction()
-	addrCalc.AsIadd(memBase, extBaseAddr.Return())
-	builder.InsertInstruction(addrCalc)
-	return addrCalc.Return()
+	if address == ssa.ValueInvalid { // Compute the absolute address only if it was not already obtained from a known safe bound above.
+		memBase := c.getMemoryBaseValue(false)
+		address = builder.AllocateInstruction().
+			AsIadd(memBase, extBaseAddr).Insert(builder).Return()
+	}
+
+	// Record that accesses on this baseAddr up to ceil are known to be safe for subsequent memory operations in the same block.
+	c.recordKnownSafeBound(baseAddrID, ceil, address)
+	return
 }
 
 func (c *Compiler) callMemmove(dst, src, size ssa.Value) {
@@ -3434,6 +3453,10 @@ func (c *Compiler) reloadAfterCall() {
 func (c *Compiler) reloadMemoryBaseLen() {
 	_ = c.getMemoryBaseValue(true)
 	_ = c.getMemoryLenValue(true)
+
+	// This function being called means that the memory base might have changed.
+	// Therefore, we must clear the known safe bounds, since we cache the absolute address of each memory access per base address.
+	c.clearSafeBounds()
 }
 
 // globalInstanceValueOffset is the offsetOf .Value field of wasm.GlobalInstance.
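
Reviewer note: the bookkeeping this patch adds is small enough to check in isolation. Below is a minimal, self-contained sketch, not code from this patch; `valueID` and `boundsCache` are simplified stand-ins for `ssa.ValueID` and the new `Compiler` fields, and the cached absolute address is omitted. It mirrors the record/clear discipline that lets `memOpSetup` elide the `Icmp`/`ExitIfTrue` pair for a later access whose `ceil` fits under an already-checked bound:

```go
package main

import "fmt"

// valueID stands in for ssa.ValueID.
type valueID int

// safeBound mirrors knownSafeBound, minus the cached absolute address.
type safeBound struct{ bound uint64 }

type boundsCache struct {
	bounds []safeBound // indexed by valueID; bound == 0 means "no entry"
	set    []valueID   // which indices are live, so clearing skips the rest
}

// record mirrors recordKnownSafeBound: grow the table, then insert or raise the bound.
func (c *boundsCache) record(v valueID, bound uint64) {
	if int(v) >= len(c.bounds) {
		c.bounds = append(c.bounds, make([]safeBound, int(v)+1-len(c.bounds))...)
	}
	if e := &c.bounds[v]; e.bound == 0 {
		e.bound = bound
		c.set = append(c.set, v)
	} else if bound > e.bound {
		e.bound = bound
	}
}

// checkNeeded reports whether a bounds check must be emitted for an access
// whose end offset relative to the base address is ceil.
func (c *boundsCache) checkNeeded(v valueID, ceil uint64) bool {
	return int(v) >= len(c.bounds) || ceil > c.bounds[v].bound
}

// clear mirrors clearSafeBounds: called at block boundaries and after calls.
func (c *boundsCache) clear() {
	for _, v := range c.set {
		c.bounds[v].bound = 0
	}
	c.set = c.set[:0]
}

func main() {
	var c boundsCache
	fmt.Println(c.checkNeeded(1, 8)) // true: first i64 load, emit Icmp+ExitIfTrue
	c.record(1, 8)                   // ...then remember base v1 is safe up to 8 bytes
	fmt.Println(c.checkNeeded(1, 4)) // false: i32 load on the same base, check elided
	c.clear()                        // e.g. OpcodeEnd closes the block
	fmt.Println(c.checkNeeded(1, 4)) // true: the guarantee no longer holds
}
```

The side list plays the role of `knownSafeBoundsSet`: clearing touches only the entries actually recorded rather than the whole `ValueID` space, which matters because `clearSafeBounds` runs at every `Else`/`End` and after every call that may grow memory.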