Skip to content

Commit

Permalink
Adds AccessMode Write and Pwrite to platform.File (#1438)
Browse files Browse the repository at this point in the history
Signed-off-by: Adrian Cole <[email protected]>
  • Loading branch information
codefromthecrypt authored May 5, 2023
1 parent e46e803 commit b5198a4
Show file tree
Hide file tree
Showing 17 changed files with 238 additions and 302 deletions.
33 changes: 24 additions & 9 deletions imports/assemblyscript/assemblyscript.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@ import (
"io"
"strconv"
"strings"
"syscall"
"unicode/utf16"

"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
. "github.com/tetratelabs/wazero/internal/assemblyscript"
"github.com/tetratelabs/wazero/internal/platform"
internalsys "github.com/tetratelabs/wazero/internal/sys"
"github.com/tetratelabs/wazero/internal/wasm"
"github.com/tetratelabs/wazero/sys"
Expand Down Expand Up @@ -163,10 +165,13 @@ func abortWithMessage(ctx context.Context, mod api.Module, stack []uint64) {
columnNumber := uint32(stack[3])

// Don't panic if there was a problem reading the message
stderr := internalsys.WriterForFile(fsc, internalsys.FdStderr)
if msg, msgOk := readAssemblyScriptString(mem, message); msgOk && stderr != nil {
if fn, fnOk := readAssemblyScriptString(mem, fileName); fnOk {
_, _ = fmt.Fprintf(stderr, "%s at %s:%d:%d\n", msg, fn, lineNumber, columnNumber)
stderr := stdioFile(fsc, internalsys.FdStderr)
if stderr != nil {
if msg, msgOk := readAssemblyScriptString(mem, message); msgOk && stderr != nil {
if fn, fnOk := readAssemblyScriptString(mem, fileName); fnOk {
s := fmt.Sprintf("%s at %s:%d:%d\n", msg, fn, lineNumber, columnNumber)
_, _ = stderr.Write([]byte(s))
}
}
}
abort(ctx, mod, stack)
Expand Down Expand Up @@ -197,15 +202,15 @@ var traceStdout = &wasm.HostFunc{
Code: wasm.Code{
GoFunc: api.GoModuleFunc(func(_ context.Context, mod api.Module, stack []uint64) {
fsc := mod.(*wasm.ModuleInstance).Sys.FS()
traceTo(mod, stack, internalsys.WriterForFile(fsc, internalsys.FdStdout))
traceTo(mod, stack, stdioFile(fsc, internalsys.FdStdout))
}),
},
}

// traceStderr implements trace to the configured Stderr.
var traceStderr = traceStdout.WithGoModuleFunc(func(_ context.Context, mod api.Module, stack []uint64) {
fsc := mod.(*wasm.ModuleInstance).Sys.FS()
traceTo(mod, stack, internalsys.WriterForFile(fsc, internalsys.FdStderr))
traceTo(mod, stack, stdioFile(fsc, internalsys.FdStderr))
})

// traceTo implements the function "trace" in AssemblyScript. e.g.
Expand All @@ -218,8 +223,8 @@ var traceStderr = traceStdout.WithGoModuleFunc(func(_ context.Context, mod api.M
// (import "env" "trace" (func $~lib/builtins/trace (param i32 i32 f64 f64 f64 f64 f64)))
//
// See https://github.com/AssemblyScript/assemblyscript/blob/fa14b3b03bd4607efa52aaff3132bea0c03a7989/std/assembly/wasi/index.ts#L61
func traceTo(mod api.Module, params []uint64, writer io.Writer) {
if writer == nil {
func traceTo(mod api.Module, params []uint64, file platform.File) {
if file == nil {
return // closed
}
message := uint32(params[0])
Expand Down Expand Up @@ -258,7 +263,7 @@ func traceTo(mod api.Module, params []uint64, writer io.Writer) {
ret.WriteString(formatFloat(arg4))
}
ret.WriteByte('\n')
_, _ = writer.Write([]byte(ret.String())) // don't crash if trace logging fails
_, _ = file.Write([]byte(ret.String())) // don't crash if trace logging fails
}

func formatFloat(f float64) string {
Expand Down Expand Up @@ -318,3 +323,13 @@ func decodeUTF16(b []byte) string {

return string(utf16.Decode(u16s))
}

func stdioFile(fsc *internalsys.FSContext, fd int32) platform.File {
if f, ok := fsc.LookupFile(fd); !ok {
return nil
} else if f.File.AccessMode() == syscall.O_RDONLY {
return nil
} else {
return f.File
}
}
49 changes: 30 additions & 19 deletions imports/wasi_snapshot_preview1/fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,7 @@ func fdFdstatGetFn(_ context.Context, mod api.Module, params []uint64) syscall.E
return syscall.EBADF
} else if st, errno = f.Stat(); errno != 0 {
return errno
} else if _, ok := f.File.File().(io.Writer); ok {
// TODO: maybe cache flags to open instead
} else if f.File.AccessMode() != syscall.O_RDONLY {
fdflags = wasip1.FD_APPEND
}

Expand Down Expand Up @@ -1287,6 +1286,23 @@ func fdWriteFn(_ context.Context, mod api.Module, params []uint64) syscall.Errno
return fdWriteOrPwrite(mod, params, false)
}

// pwriter tracks an offset across multiple writes.
type pwriter struct {
f platform.File
offset int64
}

// Write implements the same function as documented on platform.File.
func (w *pwriter) Write(p []byte) (n int, errno syscall.Errno) {
if len(p) == 0 {
return 0, 0 // less overhead on zero-length writes.
}

n, err := w.f.Pwrite(p, w.offset)
w.offset += int64(n)
return n, err
}

func fdWriteOrPwrite(mod api.Module, params []uint64, isPwrite bool) syscall.Errno {
mem := mod.Memory()
fsc := mod.(*wasm.ModuleInstance).Sys.FS()
Expand All @@ -1296,20 +1312,20 @@ func fdWriteOrPwrite(mod api.Module, params []uint64, isPwrite bool) syscall.Err
iovsCount := uint32(params[2])

var resultNwritten uint32
var writer io.Writer
var writer func(p []byte) (n int, errno syscall.Errno)
if f, ok := fsc.LookupFile(fd); !ok {
return syscall.EBADF
} else if f.File.AccessMode() == syscall.O_RDONLY {
return syscall.EBADF
} else if isPwrite {
offset := int64(params[3])
writer = sysfs.WriterAtOffset(f.File.File(), offset)
writer = (&pwriter{f: f.File, offset: offset}).Write
resultNwritten = uint32(params[4])
} else if writer, ok = f.File.File().(io.Writer); !ok {
return syscall.EBADF
} else {
writer = f.File.Write
resultNwritten = uint32(params[3])
}

var err error
var nwritten uint32
iovsStop := iovsCount << 3 // iovsCount * 8
iovsBuf, ok := mem.Read(iovs, iovsStop)
Expand All @@ -1321,20 +1337,15 @@ func fdWriteOrPwrite(mod api.Module, params []uint64, isPwrite bool) syscall.Err
offset := le.Uint32(iovsBuf[iovsPos:])
l := le.Uint32(iovsBuf[iovsPos+4:])

var n int
if writer == io.Discard { // special-case default
n = int(l)
} else {
b, ok := mem.Read(offset, l)
if !ok {
return syscall.EFAULT
}
n, err = writer.Write(b)
if err != nil {
return platform.UnwrapOSError(err)
}
b, ok := mem.Read(offset, l)
if !ok {
return syscall.EFAULT
}
n, errno := writer(b)
nwritten += uint32(n)
if errno != 0 {
return errno
}
}

if !mod.Memory().WriteUint32Le(resultNwritten, nwritten) {
Expand Down
38 changes: 19 additions & 19 deletions imports/wasi_snapshot_preview1/fs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2884,7 +2884,7 @@ func Test_fdWrite_discard(t *testing.T) {
func Test_fdWrite_Errors(t *testing.T) {
tmpDir := t.TempDir() // open before loop to ensure no locking problems.
pathName := "test_path"
mod, fd, log, r := requireOpenFile(t, tmpDir, pathName, nil, false)
mod, fd, log, r := requireOpenFile(t, tmpDir, pathName, []byte{1, 2, 3, 4}, false)
defer r.Close(testCtx)

// Setup valid test memory
Expand Down Expand Up @@ -3789,10 +3789,7 @@ func Test_pathOpen(t *testing.T) {
path: func(t *testing.T) (file string) { return appendName },
fdflags: wasip1.FD_APPEND,
expected: func(t *testing.T, fsc *sys.FSContext) {
contents := []byte("hello")
_, err := sys.WriterForFile(fsc, expectedOpenedFd).Write(contents)
require.NoError(t, err)
require.Zero(t, fsc.CloseFile(expectedOpenedFd))
contents := writeAndCloseFile(t, fsc, expectedOpenedFd)

// verify the contents were appended
b := readFile(t, dir, appendName)
Expand Down Expand Up @@ -3821,10 +3818,7 @@ func Test_pathOpen(t *testing.T) {
oflags: wasip1.O_CREAT,
expected: func(t *testing.T, fsc *sys.FSContext) {
// expect to create a new file
contents := []byte("hello")
_, err := sys.WriterForFile(fsc, expectedOpenedFd).Write(contents)
require.NoError(t, err)
require.Zero(t, fsc.CloseFile(expectedOpenedFd))
contents := writeAndCloseFile(t, fsc, expectedOpenedFd)

// verify the contents were written
b := readFile(t, dir, "creat")
Expand Down Expand Up @@ -3853,10 +3847,7 @@ func Test_pathOpen(t *testing.T) {
oflags: wasip1.O_CREAT | wasip1.O_TRUNC,
expected: func(t *testing.T, fsc *sys.FSContext) {
// expect to create a new file
contents := []byte("hello")
_, err := sys.WriterForFile(fsc, expectedOpenedFd).Write(contents)
require.NoError(t, err)
require.Zero(t, fsc.CloseFile(expectedOpenedFd))
contents := writeAndCloseFile(t, fsc, expectedOpenedFd)

// verify the contents were written
b := readFile(t, dir, joinPath(dirName, "O_CREAT-O_TRUNC"))
Expand Down Expand Up @@ -3918,10 +3909,7 @@ func Test_pathOpen(t *testing.T) {
path: func(t *testing.T) (file string) { return "trunc" },
oflags: wasip1.O_TRUNC,
expected: func(t *testing.T, fsc *sys.FSContext) {
contents := []byte("hello")
_, err := sys.WriterForFile(fsc, expectedOpenedFd).Write(contents)
require.NoError(t, err)
require.Zero(t, fsc.CloseFile(expectedOpenedFd))
contents := writeAndCloseFile(t, fsc, expectedOpenedFd)

// verify the contents were truncated
b := readFile(t, dir, "trunc")
Expand Down Expand Up @@ -3985,6 +3973,16 @@ func Test_pathOpen(t *testing.T) {
}
}

func writeAndCloseFile(t *testing.T, fsc *sys.FSContext, fd int32) []byte {
contents := []byte("hello")
f, ok := fsc.LookupFile(fd)
require.True(t, ok)
_, errno := f.File.Write([]byte("hello"))
require.EqualErrno(t, 0, errno)
require.EqualErrno(t, 0, fsc.CloseFile(fd))
return contents
}

func requireOpenFD(t *testing.T, mod api.Module, path string) int32 {
fsc := mod.(*wasm.ModuleInstance).Sys.FS()
preopen := fsc.RootFS()
Expand Down Expand Up @@ -4911,6 +4909,9 @@ func Test_pathUnlinkFile_Errors(t *testing.T) {

func requireOpenFile(t *testing.T, tmpDir string, pathName string, data []byte, readOnly bool) (api.Module, int32, *bytes.Buffer, api.Closer) {
oflags := os.O_RDWR
if readOnly {
oflags = os.O_RDONLY
}

realPath := joinPath(tmpDir, pathName)
if data == nil {
Expand All @@ -4923,7 +4924,6 @@ func requireOpenFile(t *testing.T, tmpDir string, pathName string, data []byte,
fsConfig := wazero.NewFSConfig()

if readOnly {
oflags = os.O_RDONLY
fsConfig = fsConfig.WithReadOnlyDirMount(tmpDir, "/")
} else {
fsConfig = fsConfig.WithDirMount(tmpDir, "/")
Expand Down Expand Up @@ -5076,5 +5076,5 @@ func joinPath(dirName, baseName string) string {
func openFsFile(t *testing.T, path string, flag int, perm fs.FileMode) platform.File {
f, errno := platform.OpenFile(path, flag, perm)
require.EqualErrno(t, 0, errno)
return platform.NewFsFile(path, f)
return platform.NewFsFile(path, flag, f)
}
8 changes: 5 additions & 3 deletions imports/wasi_snapshot_preview1/poll.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,12 @@ func processFDEventRead(fsc *internalsys.FSContext, fd int32) wasip1.Errno {

// processFDEventWrite returns ErrnoNotsup if the file exists and ErrnoBadf otherwise.
func processFDEventWrite(fsc *internalsys.FSContext, fd int32) wasip1.Errno {
if internalsys.WriterForFile(fsc, fd) == nil {
return wasip1.ErrnoBadf
if f, ok := fsc.LookupFile(fd); ok {
if f.File.AccessMode() != syscall.O_RDONLY {
return wasip1.ErrnoNotsup
}
}
return wasip1.ErrnoNotsup
return wasip1.ErrnoBadf
}

// writeEvent writes the event corresponding to the processed subscription.
Expand Down
33 changes: 13 additions & 20 deletions internal/gojs/fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,35 +281,28 @@ func (jsfsWrite) invoke(ctx context.Context, mod api.Module, args ...interface{}
callback := args[5].(funcWrapper)

if byteCount > 0 { // empty is possible on EOF
n, err := syscallWrite(mod, fd, fOffset, buf.Unwrap()[offset:offset+byteCount])
return callback.invoke(ctx, mod, goos.RefJsfs, err, n) // note: error first
n, errno := syscallWrite(mod, fd, fOffset, buf.Unwrap()[offset:offset+byteCount])
var err error
if errno != 0 {
err = errno
}
// It is safe to cast to uint32 because n <= uint32(byteCount).
return callback.invoke(ctx, mod, goos.RefJsfs, err, uint32(n)) // note: error first
}
return callback.invoke(ctx, mod, goos.RefJsfs, nil, goos.RefValueZero)
}

// syscallWrite is like syscall.Write
func syscallWrite(mod api.Module, fd int32, offset interface{}, p []byte) (n uint32, err error) {
func syscallWrite(mod api.Module, fd int32, offset interface{}, p []byte) (n int, errno syscall.Errno) {
fsc := mod.(*wasm.ModuleInstance).Sys.FS()

var writer io.Writer
if f, ok := fsc.LookupFile(fd); !ok {
err = syscall.EBADF
errno = syscall.EBADF
} else if f.File.AccessMode() == syscall.O_RDONLY {
errno = syscall.EBADF
} else if offset != nil {
writer = sysfs.WriterAtOffset(f.File.File(), toInt64(offset))
} else if writer, ok = f.File.File().(io.Writer); !ok {
err = syscall.EBADF
}

if err != nil {
return
}

if nWritten, e := writer.Write(p); e == nil || e == io.EOF {
// fs_js.go cannot parse io.EOF so coerce it to nil.
// See https://github.com/golang/go/issues/43913
n = uint32(nWritten)
n, errno = f.File.Pwrite(p, toInt64(offset))
} else {
err = e
n, errno = f.File.Write(p)
}
return
}
Expand Down
10 changes: 6 additions & 4 deletions internal/gojs/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ package gojs
import (
"context"
"fmt"
"syscall"
"time"

"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/internal/gojs/custom"
"github.com/tetratelabs/wazero/internal/gojs/goarch"
internalsys "github.com/tetratelabs/wazero/internal/sys"
"github.com/tetratelabs/wazero/internal/wasm"
)

Expand Down Expand Up @@ -42,10 +42,12 @@ func wasmWrite(_ context.Context, mod api.Module, stack goarch.Stack) {
p := stack.ParamBytes(mod.Memory(), 1 /*, 2 */)

fsc := mod.(*wasm.ModuleInstance).Sys.FS()
if writer := internalsys.WriterForFile(fsc, fd); writer == nil {
if f, ok := fsc.LookupFile(fd); ok && f.File.AccessMode() != syscall.O_RDONLY {
if _, err := f.File.Write(p); err != 0 {
panic(fmt.Errorf("error writing p: %w", err))
}
} else {
panic(fmt.Errorf("fd %d invalid", fd))
} else if _, err := writer.Write(p); err != nil {
panic(fmt.Errorf("error writing p: %w", err))
}
}

Expand Down
Loading

0 comments on commit b5198a4

Please sign in to comment.