diff --git a/fn/context_guard.go b/fn/context_guard.go index cfe5a32998..ed3b0ba70f 100644 --- a/fn/context_guard.go +++ b/fn/context_guard.go @@ -173,9 +173,10 @@ func (g *ContextGuard) Create(ctx context.Context, return ctx, cancel } -// ctxQuitUnsafe spins off a goroutine that will block until the passed context -// is cancelled or until the quit channel has been signaled after which it will -// call the passed cancel function and decrement the wait group. +// ctxQuitUnsafe increases the wait group counter, waits until the context is +// cancelled and decreases the wait group counter. It stores the passed cancel +// function and returns a wrapped version, which removed the stored one and +// calls it. Quit method calls all the stored cancel functions. // // NOTE: the caller must hold the ContextGuard's mutex before calling this // function. @@ -185,31 +186,27 @@ func (g *ContextGuard) ctxQuitUnsafe(ctx context.Context, cancel = g.addCancelFnUnsafe(cancel) g.wg.Add(1) - go func() { - defer cancel() - defer g.wg.Done() - select { - case <-g.quit: - - case <-ctx.Done(): - } - }() + // We don't have to wait on g.quit here: g.quit can be closed only in + // Quit method, which also closes the context we are waiting for. + context.AfterFunc(ctx, func() { + g.wg.Done() + }) return cancel } -// ctxBlocking spins off a goroutine that will block until the passed context -// is cancelled after which it will decrement the wait group. +// ctxBlocking increases the wait group counter, waits until the context is +// cancelled and decreases the wait group counter. +// +// NOTE: the caller must hold the ContextGuard's mutex before calling this +// function. func (g *ContextGuard) ctxBlocking(ctx context.Context) { g.wg.Add(1) - go func() { - defer g.wg.Done() - select { - case <-ctx.Done(): - } - }() + context.AfterFunc(ctx, func() { + g.wg.Done() + }) } // addCancelFnUnsafe adds a context cancel function to the manager and returns a diff --git a/fn/context_guard_test.go b/fn/context_guard_test.go index 76d63a50f5..87a5cf0cb0 100644 --- a/fn/context_guard_test.go +++ b/fn/context_guard_test.go @@ -2,8 +2,11 @@ package fn import ( "context" + "runtime" "testing" "time" + + "github.com/stretchr/testify/require" ) // TestContextGuard tests the behaviour of the ContextGuard. @@ -439,3 +442,32 @@ func TestContextGuard(t *testing.T) { } }) } + +// TestContextGuardCountGoroutines makes sure that ContextGuard doesn't create +// any goroutines while waiting for contexts. +func TestContextGuardCountGoroutines(t *testing.T) { + g := NewContextGuard() + + ctx, cancel := context.WithCancel(context.Background()) + + // Count goroutines before contexts are created. + count1 := runtime.NumGoroutine() + + // Create 1000 contexts of each type. + for i := 0; i < 1000; i++ { + _, _ = g.Create(ctx) + _, _ = g.Create(ctx, WithBlockingCG()) + _, _ = g.Create(ctx, WithTimeoutCG()) + _, _ = g.Create(ctx, WithBlockingCG(), WithTimeoutCG()) + } + + // Make sure no new goroutine was launched. + count2 := runtime.NumGoroutine() + require.LessOrEqual(t, count2, count1) + + // Cancel root context. + cancel() + + // Make sure wg's counter gets to 0 eventually. + g.WgWait() +}