Skip to content

Commit

Permalink
fix(input): use cancel io to cancel windows read events (#360)
Browse files Browse the repository at this point in the history
This gets rid of the overlapping I/O and replace it with Windows
Cancel I/O.

Related: charmbracelet/bubbletea#1167
Related: charmbracelet/bubbletea@920d07b
  • Loading branch information
aymanbagabas authored Feb 4, 2025
1 parent 9ed0fca commit fe292ba
Showing 1 changed file with 9 additions and 100 deletions.
109 changes: 9 additions & 100 deletions input/cancelreader_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"io"
"os"
"sync"
"time"

xwindows "github.com/charmbracelet/x/windows"
"github.com/muesli/cancelreader"
Expand All @@ -17,14 +16,8 @@ import (

type conInputReader struct {
cancelMixin

conin windows.Handle
cancelEvent windows.Handle

conin windows.Handle
originalMode uint32

// blockingReadSignal is used to signal that a blocking read is in progress.
blockingReadSignal chan struct{}
}

var _ cancelreader.CancelReader = &conInputReader{}
Expand Down Expand Up @@ -62,47 +55,21 @@ func newCancelreader(r io.Reader) (cancelreader.CancelReader, error) {
return nil, fmt.Errorf("failed to prepare console input: %w", err)
}

cancelEvent, err := windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
return nil, fmt.Errorf("create stop event: %w", err)
}

return &conInputReader{
conin: conin,
cancelEvent: cancelEvent,
originalMode: originalMode,
blockingReadSignal: make(chan struct{}, 1),
conin: conin,
originalMode: originalMode,
}, nil
}

// Cancel implements cancelreader.CancelReader.
func (r *conInputReader) Cancel() bool {
r.setCanceled()

select {
case r.blockingReadSignal <- struct{}{}:
err := windows.SetEvent(r.cancelEvent)
if err != nil {
return false
}
<-r.blockingReadSignal
case <-time.After(100 * time.Millisecond):
// Read() hangs in a GetOverlappedResult which is likely due to
// WaitForMultipleObjects returning without input being available
// so we cannot cancel this ongoing read.
return false
}

return true
return windows.CancelIoEx(r.conin, nil) == nil || windows.CancelIo(r.conin) == nil
}

// Close implements cancelreader.CancelReader.
func (r *conInputReader) Close() error {
err := windows.CloseHandle(r.cancelEvent)
if err != nil {
return fmt.Errorf("closing cancel event handle: %w", err)
}

if r.originalMode != 0 {
err := windows.SetConsoleMode(r.conin, r.originalMode)
if err != nil {
Expand All @@ -114,25 +81,17 @@ func (r *conInputReader) Close() error {
}

// Read implements cancelreader.CancelReader.
func (r *conInputReader) Read(data []byte) (n int, err error) {
func (r *conInputReader) Read(data []byte) (int, error) {
if r.isCanceled() {
return 0, cancelreader.ErrCanceled
}

err = waitForInput(r.conin, r.cancelEvent)
if err != nil {
return 0, err
}

if r.isCanceled() {
return 0, cancelreader.ErrCanceled
var n uint32
if err := windows.ReadFile(r.conin, data, &n, nil); err != nil {
return int(n), fmt.Errorf("read console input: %w", err)
}

r.blockingReadSignal <- struct{}{}
n, err = overlappedReader(r.conin).Read(data)
<-r.blockingReadSignal

return
return int(n), nil
}

func prepareConsole(input windows.Handle, modes ...uint32) (originalMode uint32, err error) {
Expand All @@ -154,30 +113,6 @@ func prepareConsole(input windows.Handle, modes ...uint32) (originalMode uint32,
return originalMode, nil
}

func waitForInput(conin, cancel windows.Handle) error {
event, err := windows.WaitForMultipleObjects([]windows.Handle{conin, cancel}, false, windows.INFINITE)
switch {
case windows.WAIT_OBJECT_0 <= event && event < windows.WAIT_OBJECT_0+2:
if event == windows.WAIT_OBJECT_0+1 {
return cancelreader.ErrCanceled
}

if event == windows.WAIT_OBJECT_0 {
return nil
}

return fmt.Errorf("unexpected wait object is ready: %d", event-windows.WAIT_OBJECT_0)
case windows.WAIT_ABANDONED <= event && event < windows.WAIT_ABANDONED+2:
return fmt.Errorf("abandoned")
case event == uint32(windows.WAIT_TIMEOUT):
return fmt.Errorf("timeout")
case event == windows.WAIT_FAILED:
return fmt.Errorf("failed")
default:
return fmt.Errorf("unexpected error: %w", err)
}
}

// cancelMixin represents a goroutine-safe cancelation status.
type cancelMixin struct {
unsafeCanceled bool
Expand All @@ -197,29 +132,3 @@ func (c *cancelMixin) isCanceled() bool {

return c.unsafeCanceled
}

type overlappedReader windows.Handle

// Read performs an overlapping read fom a windows.Handle.
func (r overlappedReader) Read(data []byte) (int, error) {
hevent, err := windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
return 0, fmt.Errorf("create event: %w", err)
}

overlapped := windows.Overlapped{HEvent: hevent}

var n uint32

err = windows.ReadFile(windows.Handle(r), data, &n, &overlapped)
if err != nil && err != windows.ERROR_IO_PENDING {
return int(n), err
}

err = windows.GetOverlappedResult(windows.Handle(r), &overlapped, &n, true)
if err != nil {
return int(n), err
}

return int(n), nil
}

0 comments on commit fe292ba

Please sign in to comment.