diff --git a/internal/watch/watch.go b/internal/watch/watch.go
index 3642ecc..5e61a2a 100644
--- a/internal/watch/watch.go
+++ b/internal/watch/watch.go
@@ -8,12 +8,10 @@ import "sync"
 //
 // The zero value of a Value is valid and stores the zero value of T.
 type Value[T any] struct {
-	// Invariant: Every Watch must receive one update call for every value of the
-	// Value from the time it is added to the watchers set to the time it is
-	// removed.
-	//
-	// mu protects this invariant, and prevents data races on value.
-	mu sync.RWMutex
+	// mu prevents data races on the value, and protects the invariant that every
+	// Watch receives one update call for every value of the Value from the time
+	// it's added to the watchers set to the time it's removed.
+	mu sync.Mutex
 	value    T
 	watchers map[*watch[T]]struct{}
 }
@@ -25,8 +23,8 @@ func NewValue[T any](x T) *Value[T] {
 
 // Get returns the current value stored in v.
 func (v *Value[T]) Get() T {
-	v.mu.RLock()
-	defer v.mu.RUnlock()
+	v.mu.Lock()
+	defer v.mu.Unlock()
 	return v.value
 }
 
@@ -46,31 +44,30 @@ func (v *Value[T]) Set(x T) {
 //
 // Each active watch executes up to one instance of handler at a time in a new
 // goroutine, first with the value stored in v upon creation of the watch, then
-// with subsequent values stored in v by calls to Set. If the value stored in v
-// changes while a handler execution is in flight, handler will be called once
-// more with the latest value stored in v following its current execution.
-// Intermediate updates preceding the latest value will be dropped.
+// with subsequent values stored in v by calls to [Value.Set]. If the value
+// stored in v changes while a handler execution is in flight, handler will be
+// called once more with the latest value stored in v following its current
+// execution. Intermediate updates preceding the latest value will be dropped.
 //
 // Values are not recovered by the garbage collector until all of their
 // associated watches have terminated. A watch is terminated after it has been
-// canceled by a call to Watch.Cancel, and any pending or in-flight handler
+// canceled by a call to [Watch.Cancel], and any pending or in-flight handler
 // execution has finished.
 func (v *Value[T]) Watch(handler func(x T)) Watch {
 	w := newWatch(handler, v.unregisterWatch)
-	v.updateAndRegisterWatch(w)
+	v.registerAndUpdateWatch(w)
 	return w
 }
 
-func (v *Value[T]) updateAndRegisterWatch(w *watch[T]) {
+func (v *Value[T]) registerAndUpdateWatch(w *watch[T]) {
 	v.mu.Lock()
 	defer v.mu.Unlock()
 
-	w.update(v.value)
-
 	if v.watchers == nil {
 		v.watchers = make(map[*watch[T]]struct{})
 	}
 	v.watchers[w] = struct{}{}
+	w.update(v.value)
 }
 
 func (v *Value[T]) unregisterWatch(w *watch[T]) {
@@ -80,7 +77,7 @@ func (v *Value[T]) unregisterWatch(w *watch[T]) {
 	delete(v.watchers, w)
 }
 
-// Watch represents a single watch on a Value. See Value.Watch for details.
+// Watch represents a single watch on a Value. See [Value.Watch] for details.
 type Watch interface {
 	// Cancel requests that this watch be terminated as soon as possible,
 	// potentially after a pending or in-flight handler execution has finished.
@@ -98,61 +95,75 @@ type Watch interface {
 type watch[T any] struct {
 	handler    func(T)
 	unregister func(*watch[T])
-	pending    chan T
-	done       chan struct{}
+
+	mu      sync.Mutex
+	wg      sync.WaitGroup
+	next    T
+	ok      bool // There is a valid value in next.
+	running bool // There is (or will be) a goroutine responsible for handling values.
+	cancel  bool // Cancel was called; wg.Done must be called as soon as running == false.
 }
 
 func newWatch[T any](handler func(T), unregister func(*watch[T])) *watch[T] {
 	w := &watch[T]{
 		handler:    handler,
 		unregister: unregister,
-		pending:    make(chan T, 1),
-		done:       make(chan struct{}),
 	}
-	go w.run()
+	w.wg.Add(1)
 	return w
 }
 
-func (w *watch[T]) run() {
-	var wg sync.WaitGroup
-	defer close(w.done)
-
-	for next := range w.pending {
-		x := next
-		wg.Add(1)
-		// Insulate the handler from the main loop, e.g. if it calls runtime.Goexit
-		// it should not terminate this loop and break the processing of new values.
-		go func() {
-			defer wg.Done()
-			w.handler(x)
-		}()
-		wg.Wait()
+func (w *watch[T]) update(x T) {
+	w.mu.Lock()
+	start := !w.running
+	w.next, w.ok, w.running = x, true, true
+	w.mu.Unlock()
+	if start {
+		go w.run()
 	}
 }
 
-func (w *watch[T]) update(x T) {
-	// It's important that this call not block, so we assume w.pending is buffered
-	// and drop a pending update to free space if necessary.
-	select {
-	case <-w.pending:
-		w.pending <- x
-	case w.pending <- x:
+func (w *watch[T]) run() {
+	var unwind bool
+	defer func() {
+		if unwind {
+			// Only possible if w.running == true, so we must maintain the invariant.
+			go w.run()
+		}
+	}()
+
+	for {
+		w.mu.Lock()
+		next, cancel := w.next, w.cancel
+		stop := !w.ok || cancel
+		w.running = !stop
+		w.next, w.ok = *new(T), false
+		w.mu.Unlock()
+
+		if cancel {
+			w.wg.Done()
+		}
+		if stop {
+			return
+		}
+
+		unwind = true
+		w.handler(next) // May panic or call runtime.Goexit.
+		unwind = false
 	}
 }
 
 func (w *watch[T]) Cancel() {
-	w.unregister(w)
-	w.clearPending()
-	close(w.pending)
-}
-
-func (w *watch[T]) clearPending() {
-	select {
-	case <-w.pending:
-	default:
+	w.unregister(w) // After this, we are guaranteed no new w.update calls.
+	w.mu.Lock()
+	finish := !w.running && !w.cancel
+	w.cancel = true
+	w.mu.Unlock()
+	if finish {
+		w.wg.Done()
 	}
 }
 
 func (w *watch[T]) Wait() {
-	<-w.done
+	w.wg.Wait()
 }
diff --git a/internal/watch/watch_test.go b/internal/watch/watch_test.go
index fbf3856..de53fd5 100644
--- a/internal/watch/watch_test.go
+++ b/internal/watch/watch_test.go
@@ -1,6 +1,7 @@
 package watch
 
 import (
+	"math/rand/v2"
 	"runtime"
 	"sync"
 	"testing"
@@ -9,7 +10,7 @@ import (
 
 const timeout = 2 * time.Second
 
-func TestValue(t *testing.T) {
+func TestValueStress(t *testing.T) {
 	// A stress test meant to be run with the race detector enabled. This test
 	// ensures that all access to a Value is synchronized, that handlers run
 	// serially, and that handlers are properly notified of the most recent state.
@@ -24,11 +25,9 @@ func TestValue(t *testing.T) {
 
 	var handlerGroup sync.WaitGroup
 	handlerGroup.Add(nWatchers)
-	for i := 0; i < nWatchers; i++ {
-		var (
-			sum      int
-			sawFinal bool
-		)
+	for i := range nWatchers {
+		var sum int
+		var sawFinal bool
 		watches[i] = v.Watch(func(x int) {
 			// This will quickly make the race detector complain if more than one
 			// instance of a handler runs at once.
@@ -50,10 +49,7 @@ func TestValue(t *testing.T) {
 	for i := 1; i <= nWrites-1; i++ {
 		// This will quickly make the race detector complain if Set is not properly
 		// synchronized.
-		go func(i int) {
-			defer setGroup.Done()
-			v.Set(i)
-		}(i)
+		go func() { defer setGroup.Done(); v.Set(i) }()
 	}
 	setGroup.Wait()
 
@@ -68,7 +64,7 @@ func TestValue(t *testing.T) {
 	select {
 	case <-done:
 	case <-time.After(timeout):
-		t.Fatalf("reached %v timeout before all watchers saw final state", timeout)
+		t.Fatalf("not all watchers saw final state within %v", timeout)
 	}
 
 	for _, w := range watches {
@@ -97,7 +93,7 @@ func TestWatchZeroValue(t *testing.T) {
 			t.Errorf("watch on zero value of Value got %v; want nil", x)
 		}
 	case <-time.After(timeout):
-		t.Fatalf("reached %v timeout before watcher was notified", timeout)
+		t.Fatalf("watcher not notified within %v", timeout)
 	}
 
 	w.Cancel()
@@ -207,6 +203,40 @@ func TestGoexitFromHandler(t *testing.T) {
 	assertWatchTerminates(t, w)
 }
 
+func TestCancelInactiveHandler(t *testing.T) {
+	// The usual case of canceling a watch, where no handler is active at the time
+	// of cancellation. Once we cancel, no further handler calls should be made.
+
+	v := NewValue("alice")
+	notify := make(chan string, 1)
+	w := v.Watch(func(x string) {
+		select {
+		case notify <- x:
+		default:
+		}
+	})
+
+	assertNextReceive(t, notify, "alice")
+	forceRuntimeProgress() // Try to ensure the handler has fully terminated.
+
+	w.Cancel()
+	v.Set("bob")
+	assertBlocked(t, notify)
+}
+
+func TestDoubleCancelInactiveHandler(t *testing.T) {
+	// A specific test for calling Cancel twice on an inactive handler, and
+	// ensuring we don't panic.
+
+	v := NewValue("alice")
+	w := v.Watch(func(x string) {})
+	forceRuntimeProgress() // Try to ensure the initial handler has fully terminated.
+
+	w.Cancel()
+	w.Cancel()
+	assertWatchTerminates(t, w)
+}
+
 func TestCancelBlockedWatcher(t *testing.T) {
 	// A specific test for canceling a watch while it is handling a notification.
 
@@ -240,9 +270,10 @@ func TestCancelBlockedWatcher(t *testing.T) {
 	assertWatchTerminates(t, w)
 }
 
-func TestCancelFromHandler(t *testing.T) {
+func TestDoubleCancelFromHandler(t *testing.T) {
 	// This is a special case of Cancel being called while a handler is blocked,
-	// as the caller of Cancel is the handler itself.
+	// as the caller of Cancel is the handler itself. We also call Cancel twice,
+	// to make sure multi-cancellation works in the active handler case.
 
 	v := NewValue("alice")
 
@@ -257,6 +288,7 @@ func TestCancelFromHandler(t *testing.T) {
 		v.Set("bob")
 		w := <-watchCh
 		w.Cancel()
+		w.Cancel()
 		canceled = true
 	})
 
@@ -309,7 +341,7 @@ func assertNextReceive[T comparable](t *testing.T, ch chan T, want T) {
 			t.Fatalf("got %v from channel, want %v", got, want)
 		}
 	case <-time.After(timeout):
-		t.Fatalf("reached %v timeout before watcher was notified", timeout)
+		t.Fatalf("watcher not notified within %v", timeout)
 	}
 }
 
@@ -325,25 +357,77 @@ func assertWatchTerminates(t *testing.T, w Watch) {
 	select {
 	case <-done:
 	case <-time.After(timeout):
-		t.Fatalf("watch not terminated after %v", timeout)
+		t.Fatalf("watch still active after %v", timeout)
 	}
 }
 
-func assertBlocked(t *testing.T, ch <-chan struct{}) {
+func assertBlocked[T any](t *testing.T, ch <-chan T) {
 	t.Helper()
 
-	// If any background routines are going to close ch when they should not,
-	// let's make a best effort to help them along.
+	forceRuntimeProgress()
+	select {
+	case <-ch:
+		t.Fatal("progress was not blocked")
+	default:
+	}
+}
+
+// forceRuntimeProgress makes a best-effort attempt to force the Go runtime to
+// make progress on all other goroutines in the system, ideally to the point at
+// which they will next block if not preempted. It works best if no other
+// goroutines are CPU-intensive or change GOMAXPROCS.
+func forceRuntimeProgress() {
 	gomaxprocs := runtime.GOMAXPROCS(1)
 	defer runtime.GOMAXPROCS(gomaxprocs)
-	n := runtime.NumGoroutine()
-	for i := 0; i < n; i++ {
+	for range runtime.NumGoroutine() {
 		runtime.Gosched()
 	}
+}
 
-	select {
-	case <-ch:
-		t.Fatal("progress was not blocked")
-	default:
+func BenchmarkSet1Watcher(b *testing.B) {
+	benchmarkSetWithWatchers(b, 1)
+}
+
+func BenchmarkSet10Watchers(b *testing.B) {
+	benchmarkSetWithWatchers(b, 10)
+}
+
+func BenchmarkSet100Watchers(b *testing.B) {
+	benchmarkSetWithWatchers(b, 100)
+}
+
+func BenchmarkSet1000Watchers(b *testing.B) {
+	benchmarkSetWithWatchers(b, 1000)
+}
+
+func benchmarkSetWithWatchers(b *testing.B, nWatchers int) {
+	v := NewValue(uint64(0))
+	watchers := make([]Watch, nWatchers)
+	for i := range watchers {
+		var sum uint64
+		watchers[i] = v.Watch(func(x uint64) { sum += x })
 	}
+
+	b.Cleanup(func() {
+		for _, w := range watchers {
+			w.Cancel()
+		}
+		for _, w := range watchers {
+			w.Wait()
+		}
+	})
+
+	b.RunParallel(func(pb *testing.PB) {
+		// The choice to set random values is somewhat arbitrary. In practice, the
+		// cost of lock contention probably outweighs any strategy for generating
+		// these values--even setting a constant every time (unless there were ever
+		// an optimization to not trigger watches when the value doesn't change).
+		// Having the setters do work that the handlers can't predict feels vaguely
+		// more realistic, though, and it's not a huge difference either way since
+		// the goal is to compare different watcher implementations (that is, the
+		// work just needs to be the same on both sides of the comparison).
+		for pb.Next() {
+			v.Set(rand.Uint64())
+		}
+	})
 }
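
For context while reviewing, the sketch below shows how the package's public surface (NewValue, Get, Set, Watch, Cancel, Wait) is meant to be used, with semantics taken from the doc comments above. It is an illustration only, not part of the patch; the import path is a placeholder, since the real package lives under this module's internal/ tree.

package main

import (
	"fmt"
	"time"

	"example.com/yourmodule/internal/watch" // placeholder; use this module's internal/watch path
)

func main() {
	v := watch.NewValue("alice")

	// The handler first runs with the value stored at Watch time ("alice"),
	// then with the latest value after each Set. If several Sets land while a
	// handler call is in flight, intermediate values are coalesced and only
	// the newest one is delivered next.
	w := v.Watch(func(x string) {
		fmt.Println("saw:", x)
	})

	v.Set("bob")
	v.Set("carol") // "bob" may be skipped if the handler is still busy with "alice"

	time.Sleep(50 * time.Millisecond) // crude; the tests above synchronize explicitly instead

	w.Cancel() // no further handler calls after any in-flight one completes
	w.Wait()   // block until the watch has fully terminated

	fmt.Println("final:", v.Get())
}

Cancel followed by Wait is the shutdown pattern the new tests and benchmark rely on: Cancel stops future deliveries, and Wait blocks until any in-flight handler call has returned.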