diff --git a/internal/engine/compiler/engine_cache.go b/internal/engine/compiler/engine_cache.go index 9540dc2773..d971975120 100644 --- a/internal/engine/compiler/engine_cache.go +++ b/internal/engine/compiler/engine_cache.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "fmt" "io" + "runtime" "github.com/tetratelabs/wazero/internal/platform" "github.com/tetratelabs/wazero/internal/u32" @@ -163,12 +164,27 @@ func deserializeCodes(wazeroVersion string, reader io.ReadCloser) (codes []*code break } - if c.codeSegment, err = platform.MmapCodeSegment(reader, int(nativeCodeLen)); err != nil { + if c.codeSegment, err = platform.MmapCodeSegment(int(nativeCodeLen)); err != nil { err = fmt.Errorf("compilationcache: error mmapping func[%d] code (len=%d): %v", i, nativeCodeLen, err) break } codes = append(codes, c) + + _, err = io.ReadFull(reader, c.codeSegment) + if err != nil { + err = fmt.Errorf("compilationcache: error reading func[%d] code (len=%d): %v", i, nativeCodeLen, err) + break + } + + if runtime.GOARCH == "arm64" { + // On arm64, we cannot give all of rwx at the same time, so we change it to exec. + err = platform.MprotectRX(c.codeSegment) + if err != nil { + break + } + } + } if err != nil { diff --git a/internal/engine/compiler/engine_cache_test.go b/internal/engine/compiler/engine_cache_test.go index 44e0956220..513badab6a 100644 --- a/internal/engine/compiler/engine_cache_test.go +++ b/internal/engine/compiler/engine_cache_test.go @@ -192,7 +192,7 @@ func TestDeserializeCodes(t *testing.T) { u64.LeBytes(5), // length of code. // Lack of code here. 
), - expErr: "compilationcache: error mmapping func[1] code (len=5): EOF", + expErr: "compilationcache: error reading func[1] code (len=5): EOF", }, } diff --git a/internal/engine/compiler/impl_amd64.go b/internal/engine/compiler/impl_amd64.go index 9926072ac7..e3fb5a3650 100644 --- a/internal/engine/compiler/impl_amd64.go +++ b/internal/engine/compiler/impl_amd64.go @@ -5,7 +5,6 @@ package compiler // if unfamiliar with amd64 instructions used here. import ( - "bytes" "fmt" "math" @@ -96,7 +95,6 @@ type amd64Compiler struct { assignStackPointerCeilNeeded asm.Node withListener bool typ *wasm.FunctionType - br *bytes.Reader // locationStackForEntrypoint is the initial location stack for all functions. To reuse the allocated stack, // we cache it here, and reset and set to .locationStack in the Init method. locationStackForEntrypoint runtimeValueLocationStack @@ -110,7 +108,6 @@ func newAmd64Compiler() compiler { assembler: amd64.NewAssembler(), locationStackForEntrypoint: newRuntimeValueLocationStack(), cpuFeatures: platform.CpuFeatures, - br: bytes.NewReader(nil), } return c } @@ -127,7 +124,6 @@ func (c *amd64Compiler) Init(typ *wasm.FunctionType, ir *wazeroir.CompilationRes assembler: c.assembler, cpuFeatures: c.cpuFeatures, labels: c.labels, - br: c.br, locationStackForEntrypoint: c.locationStackForEntrypoint, brTableTmp: c.brTableTmp, } @@ -283,13 +279,17 @@ func (c *amd64Compiler) compile() (code []byte, stackPointerCeil uint64, err err // Note this MUST be called before Assemble() below. 
c.assignStackPointerCeil(stackPointerCeil) - code, err = c.assembler.Assemble() + var original []byte + original, err = c.assembler.Assemble() if err != nil { return } - c.br.Reset(code) - code, err = platform.MmapCodeSegment(c.br, len(code)) + code, err = platform.MmapCodeSegment(len(original)) + if err != nil { + return + } + copy(code, original) return } diff --git a/internal/engine/compiler/impl_arm64.go b/internal/engine/compiler/impl_arm64.go index 7e4782b427..72546a3722 100644 --- a/internal/engine/compiler/impl_arm64.go +++ b/internal/engine/compiler/impl_arm64.go @@ -160,8 +160,13 @@ func (c *arm64Compiler) compile() (code []byte, stackPointerCeil uint64, err err return } - c.br.Reset(original) - code, err = platform.MmapCodeSegment(c.br, len(original)) + code, err = platform.MmapCodeSegment(len(original)) + if err != nil { + return + } + copy(code, original) + // On arm64, we cannot give all of rwx at the same time, so we change it to exec. + err = platform.MprotectRX(code) return } diff --git a/internal/platform/buf_writer.go b/internal/platform/buf_writer.go deleted file mode 100644 index 5d53d16788..0000000000 --- a/internal/platform/buf_writer.go +++ /dev/null @@ -1,20 +0,0 @@ -package platform - -// bufWriter implements io.Writer. -// -// This is implemented because bytes.Buffer cannot write from the beginning of the underlying buffer -// without changing the memory location. In this case, the underlying buffer is memory-mapped region, -// and we have to write into that region via io.Copy since sometimes the original native code exists -// as a file for external-cached cases. -type bufWriter struct { - underlying []byte - pos int -} - -// Write implements io.Writer Write. 
-func (b *bufWriter) Write(p []byte) (n int, err error) { - copy(b.underlying[b.pos:], p) - n = len(p) - b.pos += n - return -} diff --git a/internal/platform/mmap.go b/internal/platform/mmap.go index ce53dfc58e..870528dbd7 100644 --- a/internal/platform/mmap.go +++ b/internal/platform/mmap.go @@ -4,7 +4,6 @@ package platform import ( - "io" "syscall" "unsafe" ) @@ -15,8 +14,8 @@ func munmapCodeSegment(code []byte) error { // mmapCodeSegmentAMD64 gives all read-write-exec permission to the mmap region // to enter the function. Otherwise, segmentation fault exception is raised. -func mmapCodeSegmentAMD64(code io.Reader, size int) ([]byte, error) { - mmapFunc, err := syscall.Mmap( +func mmapCodeSegmentAMD64(size int) ([]byte, error) { + buf, err := syscall.Mmap( -1, 0, size, @@ -29,18 +28,15 @@ func mmapCodeSegmentAMD64(code io.Reader, size int) ([]byte, error) { if err != nil { return nil, err } - - w := &bufWriter{underlying: mmapFunc} - _, err = io.CopyN(w, code, int64(size)) - return mmapFunc, err + return buf, err } // mmapCodeSegmentARM64 cannot give all read-write-exec permission to the mmap region. // Otherwise, the mmap systemcall would raise an error. Here we give read-write -// to the region at first, write the native code and then change the perm to -// read-exec, so we can execute the native code. -func mmapCodeSegmentARM64(code io.Reader, size int) ([]byte, error) { - mmapFunc, err := syscall.Mmap( +// to the region so that we can write contents at call-sites. Callers are responsible for +// calling MprotectRX on the returned buffer. +func mmapCodeSegmentARM64(size int) ([]byte, error) { + buf, err := syscall.Mmap( -1, 0, size, @@ -53,26 +49,18 @@ func mmapCodeSegmentARM64(code io.Reader, size int) ([]byte, error) { if err != nil { return nil, err } - - w := &bufWriter{underlying: mmapFunc} - _, err = io.CopyN(w, code, int64(size)) - if err != nil { - return nil, err - } - - // Then we're done with writing code, change the permission to RX. 
- err = mprotect(mmapFunc, syscall.PROT_READ|syscall.PROT_EXEC) - return mmapFunc, err + return buf, err } -// mprotect is like syscall.Mprotect, defined locally so that freebsd compiles. -func mprotect(b []byte, prot int) (err error) { +// MprotectRX is like syscall.Mprotect with the protection fixed to read-exec (PROT_READ|PROT_EXEC), defined locally so that freebsd compiles. +func MprotectRX(b []byte) (err error) { var _p0 unsafe.Pointer if len(b) > 0 { _p0 = unsafe.Pointer(&b[0]) } else { _p0 = unsafe.Pointer(&_zero) } + const prot = syscall.PROT_READ | syscall.PROT_EXEC _, _, e1 := syscall.Syscall(syscall.SYS_MPROTECT, uintptr(_p0), uintptr(len(b)), uintptr(prot)) if e1 != 0 { err = syscall.Errno(e1) diff --git a/internal/platform/mmap_test.go b/internal/platform/mmap_test.go index 1747748ad1..74496083ae 100644 --- a/internal/platform/mmap_test.go +++ b/internal/platform/mmap_test.go @@ -1,31 +1,21 @@ package platform import ( - "bytes" - "crypto/rand" - "io" "testing" "github.com/tetratelabs/wazero/internal/testing/require" ) -var testCodeBuf, _ = io.ReadAll(io.LimitReader(rand.Reader, 8*1024)) - func Test_MmapCodeSegment(t *testing.T) { if !CompilerSupported() { t.Skip() } - testCodeReader := bytes.NewReader(testCodeBuf) - newCode, err := MmapCodeSegment(testCodeReader, testCodeReader.Len()) + _, err := MmapCodeSegment(1234) require.NoError(t, err) - // Verify that the mmap is the same as the original. - require.Equal(t, testCodeBuf, newCode) - // TODO: test newCode can executed. 
- t.Run("panic on zero length", func(t *testing.T) { captured := require.CapturePanic(func() { - _, _ = MmapCodeSegment(bytes.NewBuffer(make([]byte, 0)), 0) + _, _ = MmapCodeSegment(0) }) require.EqualError(t, captured, "BUG: MmapCodeSegment with zero length") }) @@ -37,10 +27,9 @@ func Test_MunmapCodeSegment(t *testing.T) { } // Errors if never mapped - require.Error(t, MunmapCodeSegment(testCodeBuf)) + require.Error(t, MunmapCodeSegment([]byte{1, 2, 3, 5})) - testCodeReader := bytes.NewReader(testCodeBuf) - newCode, err := MmapCodeSegment(testCodeReader, testCodeReader.Len()) + newCode, err := MmapCodeSegment(100) require.NoError(t, err) // First munmap should succeed. require.NoError(t, MunmapCodeSegment(newCode)) diff --git a/internal/platform/mmap_unsupported.go b/internal/platform/mmap_unsupported.go index 2db2baf0d7..b9de9a1f74 100644 --- a/internal/platform/mmap_unsupported.go +++ b/internal/platform/mmap_unsupported.go @@ -4,7 +4,6 @@ package platform import ( "fmt" - "io" "runtime" ) @@ -14,10 +13,14 @@ func munmapCodeSegment(code []byte) error { panic(errUnsupported) } -func mmapCodeSegmentAMD64(code io.Reader, size int) ([]byte, error) { +func mmapCodeSegmentAMD64(size int) ([]byte, error) { panic(errUnsupported) } -func mmapCodeSegmentARM64(code io.Reader, size int) ([]byte, error) { +func mmapCodeSegmentARM64(size int) ([]byte, error) { + panic(errUnsupported) +} + +func MprotectRX(b []byte) (err error) { panic(errUnsupported) } diff --git a/internal/platform/mmap_windows.go b/internal/platform/mmap_windows.go index cf6d419b9e..efa98a85bc 100644 --- a/internal/platform/mmap_windows.go +++ b/internal/platform/mmap_windows.go @@ -2,7 +2,6 @@ package platform import ( "fmt" - "io" "reflect" "syscall" "unsafe" @@ -58,7 +57,7 @@ func virtualProtect(address, size, newprotect uintptr, oldprotect *uint32) error return nil } -func mmapCodeSegmentAMD64(code io.Reader, size int) ([]byte, error) { +func mmapCodeSegmentAMD64(size int) ([]byte, error) { p, err := 
allocateMemory(uintptr(size), windows_PAGE_EXECUTE_READWRITE) if err != nil { return nil, err @@ -69,13 +68,10 @@ func mmapCodeSegmentAMD64(code io.Reader, size int) ([]byte, error) { sh.Data = p sh.Len = size sh.Cap = size - - w := &bufWriter{underlying: mem} - _, err = io.CopyN(w, code, int64(size)) return mem, err } -func mmapCodeSegmentARM64(code io.Reader, size int) ([]byte, error) { +func mmapCodeSegmentARM64(size int) ([]byte, error) { p, err := allocateMemory(uintptr(size), windows_PAGE_READWRITE) if err != nil { return nil, err @@ -86,20 +82,16 @@ func mmapCodeSegmentARM64(code io.Reader, size int) ([]byte, error) { sh.Data = p sh.Len = size sh.Cap = size - w := &bufWriter{underlying: mem} - _, err = io.CopyN(w, code, int64(size)) - if err != nil { - return nil, err - } - - old := uint32(windows_PAGE_READWRITE) - err = virtualProtect(p, uintptr(size), windows_PAGE_EXECUTE_READ, &old) - if err != nil { - return nil, err - } return mem, nil } +var old = uint32(windows_PAGE_READWRITE) + +func MprotectRX(b []byte) (err error) { + err = virtualProtect(uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)), windows_PAGE_EXECUTE_READ, &old) + return +} + // ensureErr returns syscall.EINVAL when the input error is nil. // // We are supposed to use "GetLastError" which is more precise, but it is not safe to execute in goroutines. While diff --git a/internal/platform/platform.go b/internal/platform/platform.go index 7cf6230cd8..518e11e413 100644 --- a/internal/platform/platform.go +++ b/internal/platform/platform.go @@ -6,7 +6,6 @@ package platform import ( "errors" - "io" "runtime" "strings" ) @@ -33,14 +32,14 @@ func CompilerSupported() bool { // MmapCodeSegment copies the code into the executable region and returns the byte slice of the region. // // See https://man7.org/linux/man-pages/man2/mmap.2.html for mmap API and flags. 
-func MmapCodeSegment(code io.Reader, size int) ([]byte, error) { +func MmapCodeSegment(size int) ([]byte, error) { if size == 0 { panic(errors.New("BUG: MmapCodeSegment with zero length")) } if runtime.GOARCH == "amd64" { - return mmapCodeSegmentAMD64(code, size) + return mmapCodeSegmentAMD64(size) } else { - return mmapCodeSegmentARM64(code, size) + return mmapCodeSegmentARM64(size) } }