Skip to content

Commit

Permalink
stopper: Add method for task-tracking
Browse files Browse the repository at this point in the history
This change adds a Call() method to stopper.Context. This allows the stopper to
track tasks that were not started using the Go() method.

Additionally, a Len() method is added to return the number of tasks being
tracked within a stopper hierarchy.
  • Loading branch information
bobvawter committed Aug 29, 2024
1 parent 35aa047 commit c680a70
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 0 deletions.
29 changes: 29 additions & 0 deletions stopper/stopper.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,27 @@ func WithContext(ctx context.Context) *Context {
return s
}

// Call executes the given function within the current goroutine and
// monitors its lifecycle. That is, both Call and Wait will block until
// the function has returned.
//
// Call returns any error from the function with no other side effects.
// Unlike the Go method, Call does not stop the Context if the function
// returns an error. If the Context has already been stopped,
// [ErrStopped] will be returned.
//
// The function passed to Call should prefer the [Context.Stopping]
// channel to return instead of depending on [Context.Done]. This allows
// a soft-stop, rather than waiting for the grace period to expire when
// [Context.Stop] is called.
func (c *Context) Call(fn func(ctx *Context) error) error {
if !c.apply(1) {
return ErrStopped
}
defer c.apply(-1)
return fn(c)
}

// Deadline implements [context.Context].
func (c *Context) Deadline() (deadline time.Time, ok bool) { return c.delegate.Deadline() }

Expand Down Expand Up @@ -163,6 +184,14 @@ func (c *Context) Done() <-chan struct{} { return c.delegate.Done() }
// the context cancellation resulted from a call to Stop.
func (c *Context) Err() error { return c.delegate.Err() }

// Len returns the number of tasks being tracked by the Context. This
// includes tasks started by derived Contexts.
func (c *Context) Len() int {
c.mu.RLock()
defer c.mu.RUnlock()
return c.mu.count
}

// Go spawns a new goroutine to execute the given function and monitors
// its lifecycle.
//
Expand Down
35 changes: 35 additions & 0 deletions stopper/stopper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,28 @@ func TestCancelOuter(t *testing.T) {
a.Nil(s.Wait())
}

func TestCall(t *testing.T) {
a := assert.New(t)

s := WithContext(context.Background())

// Verify that returning an error from the callback does not stop
// the Context.
err := errors.New("BOOM")
a.ErrorIs(s.Call(func(ctx *Context) error {
// The call should increment the wait value.
a.Equal(1, s.Len())
return err
}), err)

a.False(s.IsStopping())

s.Stop(0)
a.ErrorIs(
s.Call(func(ctx *Context) error { return nil }),
ErrStopped)
}

func TestCallbackErrorStops(t *testing.T) {
a := assert.New(t)

Expand All @@ -81,10 +103,16 @@ func TestChainStopper(t *testing.T) {
mid := context.WithValue(parent, parent, parent) // Demonstrate unwrapping.
child := WithContext(mid)
a.Same(parent, child.parent)
a.Zero(parent.Len())
a.Zero(child.Len())

waitFor := make(chan struct{})
child.Go(func(*Context) error { <-waitFor; return nil })

// Task tracking chains.
a.Equal(1, parent.Len())
a.Equal(1, child.Len())

// Verify that stopping the parent propagates to the child.
parent.Stop(0)
select {
Expand All @@ -98,6 +126,10 @@ func TestChainStopper(t *testing.T) {
a.Nil(parent.Err())
a.Nil(child.Err())

// There are still pending tasks.
a.Equal(1, parent.Len())
a.Equal(1, child.Len())

// Allow the work to finish, and verify cancellation.
close(waitFor)

Expand Down Expand Up @@ -132,6 +164,9 @@ func TestChainStopper(t *testing.T) {
a.ErrorIs(parent.Err(), context.Canceled)
a.ErrorIs(context.Cause(parent), ErrStopped)
a.Nil(child.Wait())

a.Zero(parent.Len())
a.Zero(child.Len())
}

func TestDefer(t *testing.T) {
Expand Down

0 comments on commit c680a70

Please sign in to comment.