Skip to content

Commit

Permalink
Avoids unnecessary allocations during mmap executables (#1366)
Browse files Browse the repository at this point in the history
Signed-off-by: Takeshi Yoneda <[email protected]>
  • Loading branch information
mathetake authored Apr 17, 2023
1 parent 00d9d88 commit d9e5d6b
Show file tree
Hide file tree
Showing 10 changed files with 65 additions and 93 deletions.
18 changes: 17 additions & 1 deletion internal/engine/compiler/engine_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/binary"
"fmt"
"io"
"runtime"

"github.com/tetratelabs/wazero/internal/platform"
"github.com/tetratelabs/wazero/internal/u32"
Expand Down Expand Up @@ -163,12 +164,27 @@ func deserializeCodes(wazeroVersion string, reader io.ReadCloser) (codes []*code
break
}

if c.codeSegment, err = platform.MmapCodeSegment(reader, int(nativeCodeLen)); err != nil {
if c.codeSegment, err = platform.MmapCodeSegment(int(nativeCodeLen)); err != nil {
err = fmt.Errorf("compilationcache: error mmapping func[%d] code (len=%d): %v", i, nativeCodeLen, err)
break
}

codes = append(codes, c)

_, err = io.ReadFull(reader, c.codeSegment)
if err != nil {
err = fmt.Errorf("compilationcache: error reading func[%d] code (len=%d): %v", i, nativeCodeLen, err)
break
}

if runtime.GOARCH == "arm64" {
// On arm64, we cannot give all of rwx at the same time, so now that the code has been read into the segment we switch it to read-exec.
err = platform.MprotectRX(c.codeSegment)
if err != nil {
break
}
}

}

if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/engine/compiler/engine_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ func TestDeserializeCodes(t *testing.T) {
u64.LeBytes(5), // length of code.
// Lack of code here.
),
expErr: "compilationcache: error mmapping func[1] code (len=5): EOF",
expErr: "compilationcache: error reading func[1] code (len=5): EOF",
},
}

Expand Down
14 changes: 7 additions & 7 deletions internal/engine/compiler/impl_amd64.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package compiler
// if unfamiliar with amd64 instructions used here.

import (
"bytes"
"fmt"
"math"

Expand Down Expand Up @@ -96,7 +95,6 @@ type amd64Compiler struct {
assignStackPointerCeilNeeded asm.Node
withListener bool
typ *wasm.FunctionType
br *bytes.Reader
// locationStackForEntrypoint is the initial location stack for all functions. To reuse the allocated stack,
// we cache it here, and reset and set to .locationStack in the Init method.
locationStackForEntrypoint runtimeValueLocationStack
Expand All @@ -110,7 +108,6 @@ func newAmd64Compiler() compiler {
assembler: amd64.NewAssembler(),
locationStackForEntrypoint: newRuntimeValueLocationStack(),
cpuFeatures: platform.CpuFeatures,
br: bytes.NewReader(nil),
}
return c
}
Expand All @@ -127,7 +124,6 @@ func (c *amd64Compiler) Init(typ *wasm.FunctionType, ir *wazeroir.CompilationRes
assembler: c.assembler,
cpuFeatures: c.cpuFeatures,
labels: c.labels,
br: c.br,
locationStackForEntrypoint: c.locationStackForEntrypoint,
brTableTmp: c.brTableTmp,
}
Expand Down Expand Up @@ -283,13 +279,17 @@ func (c *amd64Compiler) compile() (code []byte, stackPointerCeil uint64, err err
// Note this MUST be called before Assemble() below.
c.assignStackPointerCeil(stackPointerCeil)

code, err = c.assembler.Assemble()
var original []byte
original, err = c.assembler.Assemble()
if err != nil {
return
}

c.br.Reset(code)
code, err = platform.MmapCodeSegment(c.br, len(code))
code, err = platform.MmapCodeSegment(len(original))
if err != nil {
return
}
copy(code, original)
return
}

Expand Down
9 changes: 7 additions & 2 deletions internal/engine/compiler/impl_arm64.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,13 @@ func (c *arm64Compiler) compile() (code []byte, stackPointerCeil uint64, err err
return
}

c.br.Reset(original)
code, err = platform.MmapCodeSegment(c.br, len(original))
code, err = platform.MmapCodeSegment(len(original))
if err != nil {
return
}
copy(code, original)
// On arm64, we cannot give all of rwx at the same time, so now that the code has been copied we switch the region to read-exec.
err = platform.MprotectRX(code)
return
}

Expand Down
20 changes: 0 additions & 20 deletions internal/platform/buf_writer.go

This file was deleted.

34 changes: 11 additions & 23 deletions internal/platform/mmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package platform

import (
"io"
"syscall"
"unsafe"
)
Expand All @@ -15,8 +14,8 @@ func munmapCodeSegment(code []byte) error {

// mmapCodeSegmentAMD64 gives all read-write-exec permission to the mmap region
// to enter the function. Otherwise, segmentation fault exception is raised.
func mmapCodeSegmentAMD64(code io.Reader, size int) ([]byte, error) {
mmapFunc, err := syscall.Mmap(
func mmapCodeSegmentAMD64(size int) ([]byte, error) {
buf, err := syscall.Mmap(
-1,
0,
size,
Expand All @@ -29,18 +28,15 @@ func mmapCodeSegmentAMD64(code io.Reader, size int) ([]byte, error) {
if err != nil {
return nil, err
}

w := &bufWriter{underlying: mmapFunc}
_, err = io.CopyN(w, code, int64(size))
return mmapFunc, err
return buf, err
}

// mmapCodeSegmentARM64 cannot give all read-write-exec permission to the mmap region.
// Otherwise, the mmap systemcall would raise an error. Here we give read-write
// to the region at first, write the native code and then change the perm to
// read-exec, so we can execute the native code.
func mmapCodeSegmentARM64(code io.Reader, size int) ([]byte, error) {
mmapFunc, err := syscall.Mmap(
// to the region so that callers can write the native code into it. Callers are responsible for
// calling MprotectRX on the returned buffer afterwards.
func mmapCodeSegmentARM64(size int) ([]byte, error) {
buf, err := syscall.Mmap(
-1,
0,
size,
Expand All @@ -53,26 +49,18 @@ func mmapCodeSegmentARM64(code io.Reader, size int) ([]byte, error) {
if err != nil {
return nil, err
}

w := &bufWriter{underlying: mmapFunc}
_, err = io.CopyN(w, code, int64(size))
if err != nil {
return nil, err
}

// Then we're done with writing code, change the permission to RX.
err = mprotect(mmapFunc, syscall.PROT_READ|syscall.PROT_EXEC)
return mmapFunc, err
return buf, err
}

// mprotect is like syscall.Mprotect, defined locally so that freebsd compiles.
func mprotect(b []byte, prot int) (err error) {
// MprotectRX is like syscall.Mprotect with the protection fixed to PROT_READ|PROT_EXEC, defined locally so that freebsd compiles.
func MprotectRX(b []byte) (err error) {
var _p0 unsafe.Pointer
if len(b) > 0 {
_p0 = unsafe.Pointer(&b[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
const prot = syscall.PROT_READ | syscall.PROT_EXEC
_, _, e1 := syscall.Syscall(syscall.SYS_MPROTECT, uintptr(_p0), uintptr(len(b)), uintptr(prot))
if e1 != 0 {
err = syscall.Errno(e1)
Expand Down
19 changes: 4 additions & 15 deletions internal/platform/mmap_test.go
Original file line number Diff line number Diff line change
@@ -1,31 +1,21 @@
package platform

import (
"bytes"
"crypto/rand"
"io"
"testing"

"github.com/tetratelabs/wazero/internal/testing/require"
)

var testCodeBuf, _ = io.ReadAll(io.LimitReader(rand.Reader, 8*1024))

func Test_MmapCodeSegment(t *testing.T) {
if !CompilerSupported() {
t.Skip()
}

testCodeReader := bytes.NewReader(testCodeBuf)
newCode, err := MmapCodeSegment(testCodeReader, testCodeReader.Len())
_, err := MmapCodeSegment(1234)
require.NoError(t, err)
// Verify that the mmap is the same as the original.
require.Equal(t, testCodeBuf, newCode)
// TODO: test newCode can executed.

t.Run("panic on zero length", func(t *testing.T) {
captured := require.CapturePanic(func() {
_, _ = MmapCodeSegment(bytes.NewBuffer(make([]byte, 0)), 0)
_, _ = MmapCodeSegment(0)
})
require.EqualError(t, captured, "BUG: MmapCodeSegment with zero length")
})
Expand All @@ -37,10 +27,9 @@ func Test_MunmapCodeSegment(t *testing.T) {
}

// Errors if never mapped
require.Error(t, MunmapCodeSegment(testCodeBuf))
require.Error(t, MunmapCodeSegment([]byte{1, 2, 3, 5}))

testCodeReader := bytes.NewReader(testCodeBuf)
newCode, err := MmapCodeSegment(testCodeReader, testCodeReader.Len())
newCode, err := MmapCodeSegment(100)
require.NoError(t, err)
// First munmap should succeed.
require.NoError(t, MunmapCodeSegment(newCode))
Expand Down
9 changes: 6 additions & 3 deletions internal/platform/mmap_unsupported.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package platform

import (
"fmt"
"io"
"runtime"
)

Expand All @@ -14,10 +13,14 @@ func munmapCodeSegment(code []byte) error {
panic(errUnsupported)
}

func mmapCodeSegmentAMD64(code io.Reader, size int) ([]byte, error) {
func mmapCodeSegmentAMD64(size int) ([]byte, error) {
panic(errUnsupported)
}

func mmapCodeSegmentARM64(code io.Reader, size int) ([]byte, error) {
func mmapCodeSegmentARM64(size int) ([]byte, error) {
panic(errUnsupported)
}

// MprotectRX is the stand-in for platforms without mmap support: it never
// changes any protection and always panics with errUnsupported.
func MprotectRX(b []byte) (err error) {
panic(errUnsupported)
}
26 changes: 9 additions & 17 deletions internal/platform/mmap_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package platform

import (
"fmt"
"io"
"reflect"
"syscall"
"unsafe"
Expand Down Expand Up @@ -58,7 +57,7 @@ func virtualProtect(address, size, newprotect uintptr, oldprotect *uint32) error
return nil
}

func mmapCodeSegmentAMD64(code io.Reader, size int) ([]byte, error) {
func mmapCodeSegmentAMD64(size int) ([]byte, error) {
p, err := allocateMemory(uintptr(size), windows_PAGE_EXECUTE_READWRITE)
if err != nil {
return nil, err
Expand All @@ -69,13 +68,10 @@ func mmapCodeSegmentAMD64(code io.Reader, size int) ([]byte, error) {
sh.Data = p
sh.Len = size
sh.Cap = size

w := &bufWriter{underlying: mem}
_, err = io.CopyN(w, code, int64(size))
return mem, err
}

func mmapCodeSegmentARM64(code io.Reader, size int) ([]byte, error) {
func mmapCodeSegmentARM64(size int) ([]byte, error) {
p, err := allocateMemory(uintptr(size), windows_PAGE_READWRITE)
if err != nil {
return nil, err
Expand All @@ -86,20 +82,16 @@ func mmapCodeSegmentARM64(code io.Reader, size int) ([]byte, error) {
sh.Data = p
sh.Len = size
sh.Cap = size
w := &bufWriter{underlying: mem}
_, err = io.CopyN(w, code, int64(size))
if err != nil {
return nil, err
}

old := uint32(windows_PAGE_READWRITE)
err = virtualProtect(p, uintptr(size), windows_PAGE_EXECUTE_READ, &old)
if err != nil {
return nil, err
}
return mem, nil
}

// old is the location passed to virtualProtect as its oldprotect argument.
// It is package-level, so every MprotectRX call shares the same variable.
var old = uint32(windows_PAGE_READWRITE)

// MprotectRX asks virtualProtect to set the protection of the region backing b
// to windows_PAGE_EXECUTE_READ. b must be non-empty because the address of its
// first byte is used as the start of the region.
func MprotectRX(b []byte) (err error) {
	base := uintptr(unsafe.Pointer(&b[0]))
	size := uintptr(len(b))
	err = virtualProtect(base, size, windows_PAGE_EXECUTE_READ, &old)
	return err
}

// ensureErr returns syscall.EINVAL when the input error is nil.
//
// We are supposed to use "GetLastError" which is more precise, but it is not safe to execute in goroutines. While
Expand Down
7 changes: 3 additions & 4 deletions internal/platform/platform.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package platform

import (
"errors"
"io"
"runtime"
"strings"
)
Expand All @@ -33,14 +32,14 @@ func CompilerSupported() bool {
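// MmapCodeSegment allocates a memory region of the given size for native code and returns the byte slice of the region.
// Callers write the code into the returned slice; on arm64 the region is not executable until MprotectRX is called on it.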
//
// See https://man7.org/linux/man-pages/man2/mmap.2.html for mmap API and flags.
func MmapCodeSegment(code io.Reader, size int) ([]byte, error) {
func MmapCodeSegment(size int) ([]byte, error) {
if size == 0 {
panic(errors.New("BUG: MmapCodeSegment with zero length"))
}
if runtime.GOARCH == "amd64" {
return mmapCodeSegmentAMD64(code, size)
return mmapCodeSegmentAMD64(size)
} else {
return mmapCodeSegmentARM64(code, size)
return mmapCodeSegmentARM64(size)
}
}

Expand Down

0 comments on commit d9e5d6b

Please sign in to comment.