diff --git a/RATIONALE.md b/RATIONALE.md index d8a1d12e2b..43daeeb7d4 100644 --- a/RATIONALE.md +++ b/RATIONALE.md @@ -1376,9 +1376,7 @@ as socket handles. ### Clock Subscriptions As detailed above in [sys.Nanosleep](#sysnanosleep), `poll_oneoff` handles -relative clock subscriptions. In our implementation we use `sys.Nanosleep()` -for this purpose in most cases, except when polling for interactive input -from `os.Stdin` (see more details below). +relative clock subscriptions. In our implementation we use `sys.Nanosleep()`. ### FdRead and FdWrite Subscriptions @@ -1386,23 +1384,26 @@ When subscribing a file descriptor (except `Stdin`) for reads or writes, the implementation will generally return immediately with success, unless the file descriptor is unknown. The file descriptor is not checked further for new incoming data. Any timeout is cancelled, and the API call is able -to return, unless there are subscriptions to `Stdin`: these are handled -separately. +to return, unless there are subscriptions to blocking file descriptors: +these are handled separately. -### FdRead and FdWrite Subscription to Stdin +### FdRead and FdWrite Subscription to Blocking File Descriptors -Subscribing `Stdin` for reads (writes make no sense and cause an error), -requires extra care: wazero allows to configure a custom reader for `Stdin`. +Subscribing a file descriptor for reads requires extra care: +wazero allows to plug an entire custom virtual file system, +and it also allows to configure custom readers and writers for standard I/O +descriptors. -In general, if a custom reader is found, the behavior will be the same -as for regular file descriptors: data is assumed to be present and -a success is written back to the result buffer. +In general, if the file reports to be in non-blocking mode, +the behavior will be the same as for regular file descriptors: +data is assumed to be present and a success is written back to the result buffer. -However, if the reader is detected to read from `os.Stdin`, -a special code path is followed, invoking `sysfs.poll()`. +However, if the file is reported to be in blocking mode (the default), +the `fsapi.File.Poll()` method is invoked. -`sysfs.poll()` is a wrapper for `poll(2)` on POSIX systems, -and it is emulated on Windows. +For regular files, stdin, pipes and sockets, `fsapi.File.Poll()` +is a wrapper for `poll(2)` on POSIX systems, and it is emulated on Windows. +Virtual file systems may provide their own custom implementation. ### Poll on POSIX @@ -1410,24 +1411,19 @@ On POSIX systems, `poll(2)` allows to wait for incoming data on a file descriptor, and block until either data becomes available or the timeout expires. -Usage of `syfs.poll()` is currently only reserved for standard input, because - -1. it is really only necessary to handle interactive input: otherwise, - there is no way in Go to peek from Standard Input without actually - reading (and thus consuming) from it; - -2. if `Stdin` is connected to a pipe, it is ok in most cases to return - with success immediately; - -3. `syfs.poll()` is currently a blocking call, irrespective of goroutines, - because the underlying syscall is; thus, it is better to limit its usage. +Usage of `syfs.poll()` is reserved to blocking I/O. In particular, +it is used most often with pipes (such as `os.Stdin`) and TCP sockets. So, if the subscription is for `os.Stdin` and the handle is detected to correspond to an interactive session, then `sysfs.poll()` will be -invoked with a the `Stdin` handle *and* the timeout. +invoked with the `Stdin` file descriptor. -This also means that in this specific case, the timeout is uninterruptible, -unless data becomes available on `Stdin` itself. +In order to avoid a blocking call, the underlying `sysfs.poll()` call +is repeatedly invoked with a 0 timeout at given intervals (currently 100 ms, +until the given timeout expires). + +The timeout and the tick both honor the settings for `sys.Nanosleep()`. +This also implies that `sys.Nanosleep()` has to be properly configured. ### Select on Windows @@ -1457,15 +1453,18 @@ which plays nicely with the rest of the Go runtime. ### Impact of blocking -Because this is a blocking syscall, it will also block the carrier thread of -the goroutine, preventing any means to support context cancellation directly. - -There are ways to obviate this issue. We outline here one idea, that is however -not currently implemented. A common approach to support context cancellation is -to add a signal file descriptor to the set, e.g. the read-end of a pipe or an -eventfd on Linux. When the context is canceled, we may unblock a Select call by -writing to the fd, causing it to return immediately. This however requires to -do a bit of housekeeping to hide the "special" FD from the end-user. +Because this is a blocking syscall, invoking it with a nonzero timeout will also +block the carrier thread of the goroutine, preventing any means +to support context cancellation directly. + +We obviate this by invoking `poll` with a 0 timeout repeatedly, +at given intervals (currently, 100 ms). We outline here another idea: +a common approach to support context cancellation is to add a signal +file descriptor to the set, e.g. the read-end of a pipe or an +eventfd on Linux. When the context is canceled, we may unblock a Select +call by writing to the fd, causing it to return immediately. +This however requires to do a bit of housekeeping to hide the "special" FD +from the end-user. [poll_oneoff]: https://github.com/WebAssembly/wasi-poll#why-is-the-function-called-poll_oneoff [async-io-windows]: https://tinyclouds.org/iocp_links diff --git a/imports/wasi_snapshot_preview1/poll.go b/imports/wasi_snapshot_preview1/poll.go index 0119b5410f..4248e81f33 100644 --- a/imports/wasi_snapshot_preview1/poll.go +++ b/imports/wasi_snapshot_preview1/poll.go @@ -28,6 +28,7 @@ import ( // - sys.ENOTSUP: a parameters is valid, but not yet supported. // - sys.EFAULT: there is not enough memory to read the subscriptions or // write results. +// - sys.EINTR: an OS interrupt has occurred while invoking the syscall. // // # Notes // @@ -42,11 +43,15 @@ var pollOneoff = newHostFunc( "in", "out", "nsubscriptions", "result.nevents", ) -type event struct { +type pollEvent struct { eventType byte userData []byte errno wasip1.Errno - outOffset uint32 +} + +type filePollEvent struct { + f *internalsys.FileEntry + e *pollEvent } func pollOneoffFn(_ context.Context, mod api.Module, params []uint64) sys.Errno { @@ -86,36 +91,34 @@ func pollOneoffFn(_ context.Context, mod api.Module, params []uint64) sys.Errno // Extract FS context, used in the body of the for loop for FS access. fsc := mod.(*wasm.ModuleInstance).Sys.FS() - // Slice of events that are processed out of the loop (blocking stdin subscribers). - var blockingStdinSubs []*event + // Slice of events that are processed out of the loop (blocking subscribers). + var blockingSubs []*filePollEvent // The timeout is initialized at max Duration, the loop will find the minimum. var timeout time.Duration = 1<<63 - 1 - // Count of all the clock subscribers that have been already written back to outBuf. - clockEvents := uint32(0) - // Count of all the non-clock subscribers that have been already written back to outBuf. - readySubs := uint32(0) + // Count of all the subscriptions that have been already written back to outBuf. + // nevents*32 returns at all times the offset where the next event should be written: + // this way we ensure that there are no gaps between records. + nevents := uint32(0) // Layout is subscription_u: Union // https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#subscription_u for i := uint32(0); i < nsubscriptions; i++ { inOffset := i * 48 - outOffset := i * 32 + outOffset := nevents * 32 eventType := inBuf[inOffset+8] // +8 past userdata // +8 past userdata +8 contents_offset argBuf := inBuf[inOffset+8+8:] userData := inBuf[inOffset : inOffset+8] - evt := &event{ + evt := &pollEvent{ eventType: eventType, userData: userData, errno: wasip1.ErrnoSuccess, - outOffset: outOffset, } switch eventType { case wasip1.EventTypeClock: // handle later - clockEvents++ newTimeout, err := processClockEvent(argBuf) if err != 0 { return err @@ -125,7 +128,8 @@ func pollOneoffFn(_ context.Context, mod api.Module, params []uint64) sys.Errno timeout = newTimeout } // Ack the clock event to the outBuf. - writeEvent(outBuf, evt) + writeEvent(outBuf[outOffset:], evt) + nevents++ case wasip1.EventTypeFdRead: fd := int32(le.Uint32(argBuf)) if fd < 0 { @@ -133,16 +137,16 @@ func pollOneoffFn(_ context.Context, mod api.Module, params []uint64) sys.Errno } if file, ok := fsc.LookupFile(fd); !ok { evt.errno = wasip1.ErrnoBadf - writeEvent(outBuf, evt) - readySubs++ - continue - } else if fd == internalsys.FdStdin && !file.File.IsNonblock() { - // if the fd is Stdin, and it is in non-blocking mode, - // do not ack yet, append to a slice for delayed evaluation. - blockingStdinSubs = append(blockingStdinSubs, evt) + writeEvent(outBuf[outOffset:], evt) + nevents++ + } else if file.File.IsNonblock() { + writeEvent(outBuf[outOffset:], evt) + nevents++ } else { - writeEvent(outBuf, evt) - readySubs++ + // If the fd is blocking, do not ack yet, + // append to a slice for delayed evaluation. + fe := &filePollEvent{f: file, e: evt} + blockingSubs = append(blockingSubs, fe) } case wasip1.EventTypeFdWrite: fd := int32(le.Uint32(argBuf)) @@ -154,47 +158,46 @@ func pollOneoffFn(_ context.Context, mod api.Module, params []uint64) sys.Errno } else { evt.errno = wasip1.ErrnoBadf } - readySubs++ - writeEvent(outBuf, evt) + nevents++ + writeEvent(outBuf[outOffset:], evt) default: return sys.EINVAL } } - // If there are subscribers with data ready, we have already written them to outBuf, - // and we don't need to wait for the timeout: clear it. - if readySubs != 0 { - timeout = 0 + sysCtx := mod.(*wasm.ModuleInstance).Sys + if nevents == nsubscriptions { + // We already wrote back all the results. We already wrote this number + // earlier to offset `resultNevents`. + // We only need to observe the timeout (nonzero if there are clock subscriptions) + // and return. + if timeout > 0 { + sysCtx.Nanosleep(int64(timeout)) + } + return 0 + } + + // If nevents != nsubscriptions, then there are blocking subscribers. + // We check these fds once using poll. + n, errno := pollFileEventsOnce(blockingSubs, outBuf[nevents*32:]) + if errno != 0 { + return errno } + nevents += n - // If there are blocking stdin subscribers, check for data with given timeout. - if len(blockingStdinSubs) > 0 { - stdin, ok := fsc.LookupFile(internalsys.FdStdin) - if !ok { - return sys.EBADF - } - // Wait for the timeout to expire, or for some data to become available on Stdin. - stdinReady, errno := stdin.File.Poll(sys.POLLIN, int32(timeout.Milliseconds())) + // If the previous poll returned n == 0 (no data) but the timeout is nonzero + // (i.e. there are clock subscriptions), we poll until either the timeout expires + // or any File.Poll() returns true ("ready"); otherwise we are done. + if n == 0 && timeout > 0 { + n, errno = pollFileEventsUntil(sysCtx, timeout, blockingSubs, outBuf[nevents*32:]) if errno != 0 { return errno } - if stdinReady { - // stdin has data ready to for reading, write back all the events - for i := range blockingStdinSubs { - readySubs++ - evt := blockingStdinSubs[i] - evt.errno = 0 - writeEvent(outBuf, evt) - } - } - } else { - // No subscribers, just wait for the given timeout. - sysCtx := mod.(*wasm.ModuleInstance).Sys - sysCtx.Nanosleep(int64(timeout)) + nevents += n } - if readySubs != nsubscriptions { - if !mod.Memory().WriteUint32Le(resultNevents, readySubs+clockEvents) { + if nevents != nsubscriptions { + if !mod.Memory().WriteUint32Le(resultNevents, nevents) { return sys.EFAULT } } @@ -233,10 +236,60 @@ func processClockEvent(inBuf []byte) (time.Duration, sys.Errno) { // writeEvent writes the event corresponding to the processed subscription. // https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#-event-struct -func writeEvent(outBuf []byte, evt *event) { - copy(outBuf[evt.outOffset:], evt.userData) // userdata - outBuf[evt.outOffset+8] = byte(evt.errno) // uint16, but safe as < 255 - outBuf[evt.outOffset+9] = 0 - le.PutUint32(outBuf[evt.outOffset+10:], uint32(evt.eventType)) +func writeEvent(outBuf []byte, evt *pollEvent) { + copy(outBuf, evt.userData) // userdata + outBuf[8] = byte(evt.errno) // uint16, but safe as < 255 + outBuf[9] = 0 + le.PutUint32(outBuf[10:], uint32(evt.eventType)) // TODO: When FD events are supported, write outOffset+16 } + +// closeChAfter closes a channel after the given timeout. +// It is similar to time.After but it uses sysCtx.Nanosleep. +func closeChAfter(sysCtx *internalsys.Context, timeout time.Duration, timeoutCh chan struct{}) { + sysCtx.Nanosleep(int64(timeout)) + close(timeoutCh) +} + +// pollFileEventsOnce invokes Poll on each sys.FileEntry in the given slice +// and writes back the result to outBuf for each file reported "ready"; +// i.e., when Poll() returns true, and no error. +func pollFileEventsOnce(evts []*filePollEvent, outBuf []byte) (n uint32, errno sys.Errno) { + // For simplicity, we assume that there are no multiple subscriptions for the same file. + for _, e := range evts { + isReady, errno := e.f.File.Poll(sys.POLLIN, 0) + if errno != 0 { + return 0, errno + } + if isReady { + e.e.errno = 0 + writeEvent(outBuf[n*32:], e.e) + n++ + } + } + return +} + +// pollFileEventsUntil repeatedly invokes pollFileEventsOnce until the given timeout is reached. +// The poll interval is currently fixed at 100 millis. +func pollFileEventsUntil(sysCtx *internalsys.Context, timeout time.Duration, blockingSubs []*filePollEvent, outBuf []byte) (n uint32, errno sys.Errno) { + timeoutCh := make(chan struct{}, 1) + go closeChAfter(sysCtx, timeout, timeoutCh) + + pollInterval := 100 * time.Millisecond + ticker := time.NewTicker(pollInterval) + defer ticker.Stop() + + for { + select { + case <-timeoutCh: + // Give one last chance before returning. + return pollFileEventsOnce(blockingSubs, outBuf) + case <-ticker.C: + n, errno = pollFileEventsOnce(blockingSubs, outBuf) + if errno != 0 || n > 0 { + return + } + } + } +} diff --git a/imports/wasi_snapshot_preview1/poll_test.go b/imports/wasi_snapshot_preview1/poll_test.go index 6db7d6926d..a2d5291da7 100644 --- a/imports/wasi_snapshot_preview1/poll_test.go +++ b/imports/wasi_snapshot_preview1/poll_test.go @@ -2,12 +2,15 @@ package wasi_snapshot_preview1_test import ( "io/fs" + "net" + "os" "strings" "testing" "time" "github.com/tetratelabs/wazero" "github.com/tetratelabs/wazero/api" + "github.com/tetratelabs/wazero/experimental/sock" experimentalsys "github.com/tetratelabs/wazero/experimental/sys" "github.com/tetratelabs/wazero/internal/sys" "github.com/tetratelabs/wazero/internal/testing/require" @@ -401,6 +404,291 @@ func setStdin(t *testing.T, mod api.Module, stdin experimentalsys.File) { f.File = stdin } +func Test_pollOneoff_Mixed(t *testing.T) { + // Test stdin (pipes) mixed with sockets. + + const listenFd = 3 + const acceptFd = 4 + + type addr interface { + Addr() *net.TCPAddr + } + + tests := []struct { + name string + in, out, nsubscriptions, resultNevents uint32 + connected, nonblocking bool + mem []byte // at offset in + files []experimentalsys.File + expectedErrno wasip1.Errno + expectedMem []byte // at offset out + expectedLog string + expectedNevents uint32 + }{ + { + name: "Read from sock (not connected)", + nsubscriptions: 1, + expectedNevents: 0, + mem: fdReadSubFd(listenFd), // assume sock at fd 3 + expectedErrno: wasip1.ErrnoSuccess, + out: 128, // past in + resultNevents: 512, // past out + expectedMem: []byte{ + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + + '?', // stopped after encoding + }, + expectedLog: ` +==> wasi_snapshot_preview1.poll_oneoff(in=0,out=128,nsubscriptions=1) +<== (nevents=0,errno=ESUCCESS) +`, + }, + { + name: "Read from sock (connected)", + connected: true, + nsubscriptions: 2, + expectedNevents: 1, + mem: append(fdReadSubFd(listenFd), fdReadSubFd(acceptFd)...), // assume sock at fd 3 + expectedErrno: wasip1.ErrnoSuccess, + out: 128, // past in + resultNevents: 512, // past out + expectedMem: []byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, // userdata + byte(wasip1.ErrnoSuccess), 0x0, // errno is 16 bit + wasip1.EventTypeFdRead, 0x0, 0x0, 0x0, // 4 bytes for type enum + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, + + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + + '?', // stopped after encoding + }, + expectedLog: ` +==> wasi_snapshot_preview1.poll_oneoff(in=0,out=128,nsubscriptions=2) +<== (nevents=1,errno=ESUCCESS) +`, + }, + + { + name: "Read from sock (connected+nonblocking)", + connected: true, + nonblocking: true, + nsubscriptions: 2, + expectedNevents: 2, + mem: append(fdReadSubFd(listenFd), fdReadSubFd(acceptFd)...), // assume sock at fd 3 + expectedErrno: wasip1.ErrnoSuccess, + out: 128, // past in + resultNevents: 512, // past out + expectedMem: []byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, // userdata + byte(wasip1.ErrnoSuccess), 0x0, // errno is 16 bit + wasip1.EventTypeFdRead, 0x0, 0x0, 0x0, // 4 bytes for type enum + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, + + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, // userdata + byte(wasip1.ErrnoSuccess), 0x0, // errno is 16 bit + wasip1.EventTypeFdRead, 0x0, 0x0, 0x0, // 4 bytes for type enum + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, + + '?', // stopped after encoding + }, + expectedLog: ` +==> wasi_snapshot_preview1.poll_oneoff(in=0,out=128,nsubscriptions=2) +<== (nevents=2,errno=ESUCCESS) +`, + }, + + { + name: "Read from sock (not connected) and stdin", + nsubscriptions: 2, + expectedNevents: 1, + mem: append(fdReadSubFd(listenFd), fdReadSub...), // assume sock at fd 3 + expectedErrno: wasip1.ErrnoSuccess, + out: 128, // past in + resultNevents: 512, // past out + expectedMem: []byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, // userdata + byte(wasip1.ErrnoSuccess), 0x0, // errno is 16 bit + wasip1.EventTypeFdRead, 0x0, 0x0, 0x0, // 4 bytes for type enum + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, + + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + + '?', // stopped after encoding + }, + expectedLog: ` +==> wasi_snapshot_preview1.poll_oneoff(in=0,out=128,nsubscriptions=2) +<== (nevents=1,errno=ESUCCESS) +`, + }, + + { + name: "Read from sock (connected) and stdin (ready)", + connected: true, + nsubscriptions: 3, + expectedNevents: 2, + mem: append(append(fdReadSubFd(listenFd), fdReadSubFd(acceptFd)...), fdReadSub...), // assume sock at fd 3 + expectedErrno: wasip1.ErrnoSuccess, + out: 128, // past in + resultNevents: 512, // past out + expectedMem: []byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, // userdata + byte(wasip1.ErrnoSuccess), 0x0, // errno is 16 bit + wasip1.EventTypeFdRead, 0x0, 0x0, 0x0, // 4 bytes for type enum + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, + + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, // userdata + byte(wasip1.ErrnoSuccess), 0x0, // errno is 16 bit + wasip1.EventTypeFdRead, 0x0, 0x0, 0x0, // 4 bytes for type enum + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, + + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + + '?', // stopped after encoding + }, + expectedLog: ` +==> wasi_snapshot_preview1.poll_oneoff(in=0,out=128,nsubscriptions=3) +<== (nevents=2,errno=ESUCCESS) +`, + }, + { + name: "Read from sock (connected+nonblocking) and stdin (ready)", + connected: true, + nonblocking: true, + nsubscriptions: 3, + expectedNevents: 3, + mem: append(append(fdReadSubFd(listenFd), fdReadSubFd(acceptFd)...), fdReadSub...), // assume sock at fd 3 + expectedErrno: wasip1.ErrnoSuccess, + out: 128, // past in + resultNevents: 512, // past out + expectedMem: []byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, // userdata + byte(wasip1.ErrnoSuccess), 0x0, // errno is 16 bit + wasip1.EventTypeFdRead, 0x0, 0x0, 0x0, // 4 bytes for type enum + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, + + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, // userdata + byte(wasip1.ErrnoSuccess), 0x0, // errno is 16 bit + wasip1.EventTypeFdRead, 0x0, 0x0, 0x0, // 4 bytes for type enum + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, + + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, // userdata + byte(wasip1.ErrnoSuccess), 0x0, // errno is 16 bit + wasip1.EventTypeFdRead, 0x0, 0x0, 0x0, // 4 bytes for type enum + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, + + '?', // stopped after encoding + }, + expectedLog: ` +==> wasi_snapshot_preview1.poll_oneoff(in=0,out=128,nsubscriptions=3) +<== (nevents=3,errno=ESUCCESS) +`, + }, + } + + for _, tt := range tests { + tc := tt + + t.Run(tc.name, func(t *testing.T) { + ctx := sock.WithConfig(testCtx, + sock.NewConfig().WithTCPListener("127.0.0.1", 0)) + + stdinReader, stdinWriter, err := os.Pipe() + require.NoError(t, err) + defer stdinReader.Close() + defer stdinWriter.Close() + + mod, r, log := requireProxyModuleWithContext(ctx, t, wazero.NewModuleConfig().WithStdin(stdinReader)) + _, _ = stdinWriter.Write([]byte("wazero")) + + defer r.Close(ctx) + defer log.Reset() + + maskMemory(t, mod, 1024) + if tc.mem != nil { + mod.Memory().Write(tc.in, tc.mem) + } + + if tc.connected { + fsc := mod.(*wasm.ModuleInstance).Sys.FS() + ch := make(chan struct{}, 1) + file, _ := fsc.LookupFile(listenFd) + if tc.nonblocking { + _ = file.File.SetNonblock(true) + } + + go func() { + for { + _, errno := fsc.SockAccept(listenFd, false) + if errno == experimentalsys.EAGAIN { + continue + } + require.EqualErrno(t, 0, errno) + close(ch) + return + } + }() + + // Wait for the socket to accept. + sleepALittle() + + addr := file.File.(addr) + c, err := net.DialTCP("tcp", nil, addr.Addr()) + + <-ch + + require.NoError(t, err) + _, _ = c.Write([]byte("wazero")) + } + + requireErrnoResult(t, tc.expectedErrno, mod, wasip1.PollOneoffName, uint64(tc.in), uint64(tc.out), + uint64(tc.nsubscriptions), uint64(tc.resultNevents)) + require.Equal(t, tc.expectedLog, "\n"+log.String()) + + out, ok := mod.Memory().Read(tc.out, uint32(len(tc.expectedMem))) + require.True(t, ok) + require.Equal(t, tc.expectedMem, out) + + // Events should be written on success regardless of nested failure. + if tc.expectedErrno == wasip1.ErrnoSuccess { + nevents, ok := mod.Memory().ReadUint32Le(tc.resultNevents) + require.True(t, ok) + require.Equal(t, tc.expectedNevents, nevents) + _ = nevents + } + }) + } +} + func Test_pollOneoff_Zero(t *testing.T) { poller := &pollStdinFile{StdinFile: sys.StdinFile{Reader: strings.NewReader("test")}, ready: true} @@ -522,6 +810,10 @@ func fdReadSubFd(fd byte) []byte { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, // userdata wasip1.EventTypeFdRead, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, fd, 0x0, 0x0, 0x0, // valid readable FD + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, } } diff --git a/imports/wasi_snapshot_preview1/sock_test.go b/imports/wasi_snapshot_preview1/sock_test.go index 2bafbd374d..d1b8df9a24 100644 --- a/imports/wasi_snapshot_preview1/sock_test.go +++ b/imports/wasi_snapshot_preview1/sock_test.go @@ -47,7 +47,7 @@ func Test_sockAccept(t *testing.T) { t.Run(tc.name, func(t *testing.T) { ctx := experimentalsock.WithConfig(testCtx, experimentalsock.NewConfig().WithTCPListener("127.0.0.1", 0)) - mod, r, log := requireProxyModuleWithContext(ctx, t, wazero.NewModuleConfig()) + mod, r, log := requireProxyModuleWithContext(ctx, t, wazero.NewModuleConfig().WithSysNanosleep()) defer r.Close(testCtx) // Dial the socket so that a call to accept doesn't hang. @@ -100,7 +100,7 @@ func Test_sockShutdown(t *testing.T) { t.Run(tc.name, func(t *testing.T) { ctx := experimentalsock.WithConfig(testCtx, experimentalsock.NewConfig().WithTCPListener("127.0.0.1", 0)) - mod, r, log := requireProxyModuleWithContext(ctx, t, wazero.NewModuleConfig()) + mod, r, log := requireProxyModuleWithContext(ctx, t, wazero.NewModuleConfig().WithSysNanosleep()) defer r.Close(testCtx) // Dial the socket so that a call to accept doesn't hang. @@ -326,7 +326,7 @@ func Test_sockSend(t *testing.T) { t.Run(tc.name, func(t *testing.T) { ctx := experimentalsock.WithConfig(testCtx, experimentalsock.NewConfig().WithTCPListener("127.0.0.1", 0)) - mod, r, log := requireProxyModuleWithContext(ctx, t, wazero.NewModuleConfig()) + mod, r, log := requireProxyModuleWithContext(ctx, t, wazero.NewModuleConfig().WithSysNanosleep()) defer r.Close(testCtx) // Dial the socket so that a call to accept doesn't hang. diff --git a/imports/wasi_snapshot_preview1/testdata/cargo-wasi/Cargo.toml b/imports/wasi_snapshot_preview1/testdata/cargo-wasi/Cargo.toml index f1e93d7b4e..8282d20797 100644 --- a/imports/wasi_snapshot_preview1/testdata/cargo-wasi/Cargo.toml +++ b/imports/wasi_snapshot_preview1/testdata/cargo-wasi/Cargo.toml @@ -9,3 +9,4 @@ path = "wasi.rs" [dependencies] libc = "0.2" +mio = {version="0.8", features=["os-poll", "net"]} diff --git a/imports/wasi_snapshot_preview1/testdata/cargo-wasi/wasi.rs b/imports/wasi_snapshot_preview1/testdata/cargo-wasi/wasi.rs index 3641e53949..d0345483ca 100644 --- a/imports/wasi_snapshot_preview1/testdata/cargo-wasi/wasi.rs +++ b/imports/wasi_snapshot_preview1/testdata/cargo-wasi/wasi.rs @@ -6,6 +6,10 @@ use std::net::{TcpListener}; use std::os::wasi::io::FromRawFd; use std::process::exit; use std::str::from_utf8; +use std::error::Error; + +use std::collections::HashMap; +use std::time::Duration; // Until NotADirectory is implemented, read the underlying error raised by // wasi-libc. See https://github.com/rust-lang/rust/issues/86442 @@ -23,6 +27,7 @@ fn main() { } "stat" => main_stat(), "sock" => main_sock(), + "mixed" => main_mixed().unwrap(), _ => { writeln!(io::stderr(), "unknown command: {}", args[1]).unwrap(); exit(1); @@ -87,3 +92,103 @@ fn main_sock() { } } } + +fn main_mixed() -> Result<(), Box> { + + use mio::net::{TcpListener, TcpStream}; + use mio::{Events, Interest, Poll, Token}; + + // Some tokens to allow us to identify which event is for which socket. + const SERVER: Token = Token(0); + const STDIN: Token = Token(1); + + // Create a poll instance. + let mut poll = Poll::new()?; + // Create storage for events. + let mut events = Events::with_capacity(128); + + let mut server = unsafe { TcpListener::from_raw_fd(3) }; + let mut stdin = unsafe { TcpStream::from_raw_fd(0) }; + + + // Start listening for incoming connections. + poll.registry() + .register(&mut server, SERVER, Interest::READABLE)?; + + // Keep track of incoming connections. + let mut m: HashMap = HashMap::new(); + + let mut count = 2; + + // Start an event loop. + loop { + // Poll Mio for events, blocking until we get an event. + if let Err(e) = poll.poll(&mut events, Some(Duration::from_nanos(0))) { + // Ignore EINTR. + if e.kind() == std::io::ErrorKind::Interrupted { + continue; + } + return Err(Box::from(e)) + } + + // Process each event. + for event in events.iter() { + // We can use the token we previously provided to `register` to + // determine for which socket the event is. + match event.token() { + SERVER => { + // If this is an event for the server, it means a connection + // is ready to be accepted. + // + // Accept the connection and add it to the map. + match server.accept() { + Ok((mut connection, _addr)) => { + let tok = Token(count); + _ = poll.registry() + .register(&mut connection, tok, Interest::READABLE); + m.insert(tok, connection); + // drop(connection); + count+=1; + }, + Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => { + // ignore + }, + Err(err) => panic!("ERROR! {}", err), + } + }, + + STDIN => { + // There is for reading on one of our connections, read it and echo. + let mut buf = [0u8; 32]; + match stdin.read(&mut buf) { + Ok(n) if n>0 => + println!("{}", String::from_utf8_lossy(&buf[0..n])), + _ => {} // ignore error. + } + }, + + + conn_id => { + // There is for reading on one of our connections, read it and echo. + let mut buf = [0u8; 32]; + let mut el = m.get(&conn_id).unwrap(); + match el.read(&mut buf) { + Ok(n) if n>0 => { + let s = String::from_utf8_lossy(&buf[0..n]); + println!("{}", s); + // Quit when the socket contains the string wazero. + if s.contains("wazero") { + return Ok(()); + } + }, + + _ => {} // ignore error. + } + } + } + } + + } +} + + diff --git a/imports/wasi_snapshot_preview1/testdata/cargo-wasi/wasi.wasm b/imports/wasi_snapshot_preview1/testdata/cargo-wasi/wasi.wasm index 2ddde1064d..8cafb21ed2 100644 Binary files a/imports/wasi_snapshot_preview1/testdata/cargo-wasi/wasi.wasm and b/imports/wasi_snapshot_preview1/testdata/cargo-wasi/wasi.wasm differ diff --git a/imports/wasi_snapshot_preview1/testdata/gotip/wasi.go b/imports/wasi_snapshot_preview1/testdata/gotip/wasi.go index 78fd70ce21..993b5ea753 100644 --- a/imports/wasi_snapshot_preview1/testdata/gotip/wasi.go +++ b/imports/wasi_snapshot_preview1/testdata/gotip/wasi.go @@ -38,6 +38,10 @@ func main() { if err := mainSock(); err != nil { panic(err) } + case "mixed": + if err := mainMixed(); err != nil { + panic(err) + } case "nonblock": if err := mainNonblock(os.Args[2], os.Args[3:]); err != nil { panic(err) @@ -142,6 +146,67 @@ func mainSock() error { return nil } +// mainMixed is an explicit test of a blocking socket + stdin pipe. +// It exercises poll_oneoff by setting read deadlines. +func mainMixed() error { + // Get a listener from the pre-opened file descriptor. + // The listener is the first pre-open, with a file-descriptor of 3. + f := os.NewFile(3, "") + l, err := net.FileListener(f) + defer f.Close() + if err != nil { + return err + } + defer l.Close() + + ch1 := make(chan error) + ch2 := make(chan error) + + go func() { + // Accept a connection + conn, err := l.Accept() + if err != nil { + ch1 <- err + return + } + defer conn.Close() + + // Do a blocking read of up to 32 bytes. + // Note: the test should write: "wazero", so that's all we should read. + var buf [32]byte + // Force a deadline to involve netpoll. + _ = conn.SetReadDeadline(time.Now().Add(time.Second)) + n, err := conn.Read(buf[:]) + if err != nil { + ch1 <- err + return + } + fmt.Println(string(buf[:n])) + close(ch1) + }() + + go func() { + // Force a deadline to involve netpoll. + _ = os.Stdin.SetReadDeadline(time.Now().Add(time.Second)) + b, err := io.ReadAll(os.Stdin) + if err != nil { + ch2 <- err + return + } + os.Stdout.Write(b) + close(ch2) + }() + err1 := <-ch1 + err2 := <-ch2 + if err1 != nil { + return err1 + } + if err2 != nil { + return err2 + } + return nil +} + // Adapted from nonblock.go // https://github.com/golang/go/blob/0fcc70ecd56e3b5c214ddaee4065ea1139ae16b5/src/runtime/internal/wasitest/testdata/nonblock.go func mainNonblock(mode string, files []string) error { diff --git a/imports/wasi_snapshot_preview1/testdata/zig-cc/wasi.c b/imports/wasi_snapshot_preview1/testdata/zig-cc/wasi.c index 818b8c330a..f8afbc743c 100644 --- a/imports/wasi_snapshot_preview1/testdata/zig-cc/wasi.c +++ b/imports/wasi_snapshot_preview1/testdata/zig-cc/wasi.c @@ -54,7 +54,7 @@ void main_poll(int timeout, int millis) { tv.tv_usec = millis*1000; ret = select(1, &rfds, NULL, NULL, &tv); if ((ret > 0) && FD_ISSET(0, &rfds)) { - printf("STDIN\n"); + printf("STDIN\n", ret); } else { printf("NOINPUT\n"); } @@ -121,7 +121,7 @@ void main_open_wronly() { unlink(path); } -void main_sock() { +void main_sock_mixed(bool checkStdin) { // Get a listener from the pre-opened file descriptor. // The listener is the first pre-open, with a file-descriptor of 3. int listener_fd = 3; @@ -148,7 +148,13 @@ void main_sock() { struct timeval tv = {1, 0}; fd_set set; FD_ZERO(&set); - FD_SET(nfd, &set); + if (checkStdin) { + FD_SET(0, &set); + FD_SET(nfd, &set); + FD_SET(listener_fd, &set); + } else { + FD_SET(nfd, &set); + } int ret = select(nfd+1, &set, NULL, NULL, &tv); // If some data is available, read it. @@ -162,6 +168,14 @@ void main_sock() { } } +void main_sock() { + main_sock_mixed(false); +} + +void main_mixed() { + main_sock_mixed(true); +} + void main_nonblock(char* fpath) { struct timespec tim, tim2; tim.tv_sec = 0; @@ -212,6 +226,8 @@ int main(int argc, char** argv) { main_open_wronly(); } else if (strcmp(argv[1],"sock")==0) { main_sock(); + } else if (strcmp(argv[1],"mixed")==0) { + main_mixed(); } else if (strcmp(argv[1],"nonblock")==0) { main_nonblock(argv[2]); } else { diff --git a/imports/wasi_snapshot_preview1/testdata/zig-cc/wasi.wasm b/imports/wasi_snapshot_preview1/testdata/zig-cc/wasi.wasm index 255a198f09..e75963d132 100755 Binary files a/imports/wasi_snapshot_preview1/testdata/zig-cc/wasi.wasm and b/imports/wasi_snapshot_preview1/testdata/zig-cc/wasi.wasm differ diff --git a/imports/wasi_snapshot_preview1/wasi_stdlib_test.go b/imports/wasi_snapshot_preview1/wasi_stdlib_test.go index 3e4933ff37..9007ff2434 100644 --- a/imports/wasi_snapshot_preview1/wasi_stdlib_test.go +++ b/imports/wasi_snapshot_preview1/wasi_stdlib_test.go @@ -19,6 +19,8 @@ import ( "github.com/tetratelabs/wazero" "github.com/tetratelabs/wazero/api" + experimentalapi "github.com/tetratelabs/wazero/experimental" + "github.com/tetratelabs/wazero/experimental/logging" experimentalsock "github.com/tetratelabs/wazero/experimental/sock" experimentalsys "github.com/tetratelabs/wazero/experimental/sys" "github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1" @@ -374,12 +376,12 @@ func Test_Poll(t *testing.T) { tc := tt t.Run(tc.name, func(t *testing.T) { start := time.Now() - console := compileAndRunWithPreStart(t, testCtx, wazero.NewModuleConfig().WithArgs(tc.args...), wasmZigCc, + console := compileAndRunWithPreStart(t, testCtx, wazero.NewModuleConfig().WithSysNanosleep().WithArgs(tc.args...), wasmZigCc, func(t *testing.T, mod api.Module) { setStdin(t, mod, tc.stdin) }) elapsed := time.Since(start) - require.True(t, elapsed >= tc.expectedTimeout) + require.True(t, elapsed >= tc.expectedTimeout, "Elapsed %d < expected %d", elapsed, tc.expectedTimeout) require.Equal(t, tc.expectedOutput+"\n", console) }) } @@ -398,7 +400,8 @@ func Test_Sleep(t *testing.T) { moduleConfig := wazero.NewModuleConfig().WithArgs("wasi", "sleepmillis", "100").WithSysNanosleep() start := time.Now() console := compileAndRun(t, testCtx, moduleConfig, wasmZigCc) - require.True(t, time.Since(start) >= 100*time.Millisecond) + elapsed := time.Since(start) + require.True(t, elapsed >= 100*time.Millisecond, "elapsed %d ns < 100 ms", elapsed) require.Equal(t, "OK\n", console) } @@ -455,7 +458,7 @@ func Test_Sock(t *testing.T) { func testSock(t *testing.T, bin []byte) { sockCfg := experimentalsock.NewConfig().WithTCPListener("127.0.0.1", 0) ctx := experimentalsock.WithConfig(testCtx, sockCfg) - moduleConfig := wazero.NewModuleConfig().WithArgs("wasi", "sock") + moduleConfig := wazero.NewModuleConfig().WithArgs("wasi", "sock").WithSysNanosleep() tcpAddrCh := make(chan *net.TCPAddr, 1) ch := make(chan string, 1) go func() { @@ -634,3 +637,72 @@ func testLargeStdout(t *testing.T, tname string, bin []byte) { require.NoError(t, err, string(output)) } } + +func Test_Mixed(t *testing.T) { + toolchains := map[string][]byte{ + "cargo-wasi": wasmCargoWasi, + "zig-cc": wasmZigCc, + } + if wasmGotip != nil { + toolchains["gotip"] = wasmGotip + } + + for toolchain, bin := range toolchains { + toolchain := toolchain + bin := bin + t.Run(toolchain, func(t *testing.T) { + testMixed(t, bin) + }) + } +} + +func testMixed(t *testing.T, bin []byte) { + // This is almost identical to testSock, except we also hook a pipe to stdin + // We expect poll_oneoff to be invoked successfully. + sockCfg := experimentalsock.NewConfig().WithTCPListener("127.0.0.1", 0) + + var logBuf bytes.Buffer + ctx := experimentalsock.WithConfig(testCtx, sockCfg) + ctx = context.WithValue(ctx, experimentalapi.FunctionListenerFactoryKey{}, + logging.NewHostLoggingListenerFactory(&logBuf, logging.LogScopePoll)) + + r, w, err := os.Pipe() + require.NoError(t, err) + defer r.Close() + defer w.Close() + + require.NoError(t, err) + moduleConfig := wazero.NewModuleConfig().WithArgs("wasi", "mixed").WithSysNanosleep().WithStdin(r) + tcpAddrCh := make(chan *net.TCPAddr, 1) + ch := make(chan string, 1) + go func() { + ch <- compileAndRunWithPreStart(t, ctx, moduleConfig, bin, func(t *testing.T, mod api.Module) { + tcpAddrCh <- requireTCPListenerAddr(t, mod) + }) + }() + tcpAddr := <-tcpAddrCh + + // Give a little time for _start to complete + sleepALittle() + + // Now dial to the initial address, which should be now held by wazero. + conn, err := net.Dial("tcp", tcpAddr.String()) + require.NoError(t, err) + defer conn.Close() + + go func() { + _, err = w.Write([]byte("wazero")) + require.NoError(t, err) + err = w.Close() + require.NoError(t, err) + n, err := conn.Write([]byte("wazero")) + require.NotEqual(t, 0, n) + require.NoError(t, err) + }() + console := <-ch + + // The log should print poll_oneoff at least once. + require.Contains(t, logBuf.String(), "poll_oneoff", logBuf.String()) + // Nonblocking connections may contain error logging/whitespace, we ignore those. + require.Equal(t, "wazero", strings.TrimSpace(console[len(console)-7:])) +} diff --git a/internal/sysfs/sock_unix.go b/internal/sysfs/sock_unix.go index ef67952091..1bc2803f1c 100644 --- a/internal/sysfs/sock_unix.go +++ b/internal/sysfs/sock_unix.go @@ -41,8 +41,9 @@ var _ socketapi.TCPSock = (*tcpListenerFile)(nil) type tcpListenerFile struct { baseSockFile - fd uintptr - addr *net.TCPAddr + fd uintptr + addr *net.TCPAddr + nonblock bool } // Accept implements the same method as documented on socketapi.TCPSock @@ -55,11 +56,22 @@ func (f *tcpListenerFile) Accept() (socketapi.TCPConn, sys.Errno) { return &tcpConnFile{fd: uintptr(nfd)}, 0 } +// Poll implements the same method as documented on sys.File +func (f *tcpListenerFile) Poll(flag sys.Pflag, timeoutMillis int32) (ready bool, errno sys.Errno) { + return poll(f.fd, flag, timeoutMillis) +} + // SetNonblock implements the same method as documented on sys.File func (f *tcpListenerFile) SetNonblock(enabled bool) sys.Errno { + f.nonblock = enabled return sys.UnwrapOSError(setNonblock(f.fd, enabled)) } +// IsNonblock implements the same method as documented on sys.File +func (f *tcpListenerFile) IsNonblock() bool { + return f.nonblock +} + // Close implements the same method as documented on sys.File func (f *tcpListenerFile) Close() sys.Errno { return sys.UnwrapOSError(syscall.Close(int(f.fd))) @@ -75,7 +87,8 @@ var _ socketapi.TCPConn = (*tcpConnFile)(nil) type tcpConnFile struct { baseSockFile - fd uintptr + fd uintptr + nonblock bool // closed is true when closed was called. This ensures proper sys.EBADF closed bool @@ -91,9 +104,20 @@ func newTcpConn(tc *net.TCPConn) socketapi.TCPConn { // SetNonblock implements the same method as documented on sys.File func (f *tcpConnFile) SetNonblock(enabled bool) (errno sys.Errno) { + f.nonblock = enabled return sys.UnwrapOSError(setNonblock(f.fd, enabled)) } +// IsNonblock implements the same method as documented on sys.File +func (f *tcpConnFile) IsNonblock() bool { + return f.nonblock +} + +// Poll implements the same method as documented on sys.File +func (f *tcpConnFile) Poll(flag sys.Pflag, timeoutMillis int32) (ready bool, errno sys.Errno) { + return poll(f.fd, flag, timeoutMillis) +} + // Read implements the same method as documented on sys.File func (f *tcpConnFile) Read(buf []byte) (n int, errno sys.Errno) { n, err := syscall.Read(int(f.fd), buf) diff --git a/internal/sysfs/sock_windows.go b/internal/sysfs/sock_windows.go index 325d739f92..3c5a1b9339 100644 --- a/internal/sysfs/sock_windows.go +++ b/internal/sysfs/sock_windows.go @@ -100,7 +100,6 @@ func syscallConnControl(conn syscall.Conn, fn func(fd uintptr) (int, sys.Errno)) // because they are sensibly different from Unix's. func newTCPListenerFile(tl *net.TCPListener) socketapi.TCPSock { w := &winTcpListenerFile{tl: tl} - _ = w.SetNonblock(true) return w } @@ -116,14 +115,11 @@ type winTcpListenerFile struct { // Accept implements the same method as documented on socketapi.TCPSock func (f *winTcpListenerFile) Accept() (socketapi.TCPConn, sys.Errno) { - // Ensure we have an incoming connection using winsock_select. - n, errno := syscallConnControl(f.tl, func(fd uintptr) (int, sys.Errno) { - return _poll([]pollFd{newPollFd(fd, _POLLIN, 0)}, 0) - }) - - // Otherwise return immediately. - if n == 0 || errno != 0 { - return nil, sys.EAGAIN + // Ensure we have an incoming connection using winsock_select, otherwise return immediately. + if f.nonblock { + if ready, errno := f.Poll(sys.POLLIN, 0); !ready || errno != 0 { + return nil, sys.EAGAIN + } } // Accept normally blocks goroutines, but we @@ -136,6 +132,11 @@ func (f *winTcpListenerFile) Accept() (socketapi.TCPConn, sys.Errno) { } } +// Poll implements the same method as documented on sys.File +func (f *winTcpListenerFile) Poll(flag sys.Pflag, timeoutMillis int32) (ready bool, errno sys.Errno) { + return _pollSock(f.tl, flag, timeoutMillis) +} + // IsNonblock implements File.IsNonblock func (f *winTcpListenerFile) IsNonblock() bool { return f.nonblock @@ -197,6 +198,11 @@ func (f *winTcpConnFile) IsNonblock() bool { return f.nonblock } +// Poll implements the same method as documented on sys.File +func (f *winTcpConnFile) Poll(flag sys.Pflag, timeoutMillis int32) (ready bool, errno sys.Errno) { + return _pollSock(f.tc, flag, timeoutMillis) +} + // Read implements the same method as documented on sys.File func (f *winTcpConnFile) Read(buf []byte) (n int, errno sys.Errno) { if len(buf) == 0 { @@ -272,3 +278,13 @@ func (f *winTcpConnFile) close() sys.Errno { f.closed = true return f.Shutdown(syscall.SHUT_RDWR) } + +func _pollSock(conn syscall.Conn, flag sys.Pflag, timeoutMillis int32) (bool, sys.Errno) { + if flag != sys.POLLIN { + return false, sys.ENOTSUP + } + n, errno := syscallConnControl(conn, func(fd uintptr) (int, sys.Errno) { + return _poll([]pollFd{newPollFd(fd, _POLLIN, 0)}, timeoutMillis) + }) + return n > 0, errno +}