Skip to content

Commit

Permalink
Merge pull request #9361 from starius/optimize-context-guard
Browse files Browse the repository at this point in the history
fn: optimize context guard
  • Loading branch information
guggero authored Jan 10, 2025
2 parents 70e7b56 + 07c4668 commit dd25e6e
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 26 deletions.
49 changes: 23 additions & 26 deletions fn/context_guard.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ func (g *ContextGuard) Quit() {
cancel()
}

// Clear cancelFns. It is safe to use nil, because no write
// operations to it can happen after g.quit is closed.
g.cancelFns = nil

close(g.quit)
})
}
Expand Down Expand Up @@ -149,7 +153,7 @@ func (g *ContextGuard) Create(ctx context.Context,
}

if opts.blocking {
g.ctxBlocking(ctx, cancel)
g.ctxBlocking(ctx)

return ctx, cancel
}
Expand All @@ -169,9 +173,10 @@ func (g *ContextGuard) Create(ctx context.Context,
return ctx, cancel
}

// ctxQuitUnsafe spins off a goroutine that will block until the passed context
// is cancelled or until the quit channel has been signaled after which it will
// call the passed cancel function and decrement the wait group.
// ctxQuitUnsafe increases the wait group counter, waits until the context is
// cancelled and decreases the wait group counter. It stores the passed cancel
// function and returns a wrapped version, which removed the stored one and
// calls it. The Quit method calls all the stored cancel functions.
//
// NOTE: the caller must hold the ContextGuard's mutex before calling this
// function.
Expand All @@ -181,35 +186,27 @@ func (g *ContextGuard) ctxQuitUnsafe(ctx context.Context,
cancel = g.addCancelFnUnsafe(cancel)

g.wg.Add(1)
go func() {
defer cancel()
defer g.wg.Done()

select {
case <-g.quit:

case <-ctx.Done():
}
}()
// We don't have to wait on g.quit here: g.quit can be closed only in
// the Quit method, which also closes the context we are waiting for.
context.AfterFunc(ctx, func() {
g.wg.Done()
})

return cancel
}

// ctxBlocking spins off a goroutine that will block until the passed context
// is cancelled after which it will call the passed cancel function and
// decrement the wait group.
func (g *ContextGuard) ctxBlocking(ctx context.Context,
cancel context.CancelFunc) {

// ctxBlocking increases the wait group counter, waits until the context is
// cancelled and decreases the wait group counter.
//
// NOTE: the caller must hold the ContextGuard's mutex before calling this
// function.
func (g *ContextGuard) ctxBlocking(ctx context.Context) {
g.wg.Add(1)
go func() {
defer cancel()
defer g.wg.Done()

select {
case <-ctx.Done():
}
}()
context.AfterFunc(ctx, func() {
g.wg.Done()
})
}

// addCancelFnUnsafe adds a context cancel function to the manager and returns a
Expand Down
42 changes: 42 additions & 0 deletions fn/context_guard_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package fn

import (
"context"
"runtime"
"testing"
"time"

"github.com/stretchr/testify/require"
)

// TestContextGuard tests the behaviour of the ContextGuard.
Expand Down Expand Up @@ -298,6 +301,12 @@ func TestContextGuard(t *testing.T) {
case <-time.After(time.Second):
t.Fatalf("timeout")
}

// Cancel the context.
cancel()

// Make sure wg's counter gets to 0 eventually.
g.WgWait()
})

// Test that if we add the CustomTimeoutCGOpt option, then the context
Expand Down Expand Up @@ -433,3 +442,36 @@ func TestContextGuard(t *testing.T) {
}
})
}

// TestContextGuardCountGoroutines makes sure that ContextGuard doesn't create
// any goroutines while waiting for contexts.
func TestContextGuardCountGoroutines(t *testing.T) {
// NOTE: t.Parallel() is not called in this test because it relies on an
// accurate count of active goroutines. Running other tests in parallel
// would introduce additional goroutines, leading to unreliable results.

g := NewContextGuard()

ctx, cancel := context.WithCancel(context.Background())

// Count goroutines before contexts are created.
count1 := runtime.NumGoroutine()

// Create 1000 contexts of each type.
for i := 0; i < 1000; i++ {
_, _ = g.Create(ctx)
_, _ = g.Create(ctx, WithBlockingCG())
_, _ = g.Create(ctx, WithTimeoutCG())
_, _ = g.Create(ctx, WithBlockingCG(), WithTimeoutCG())
}

// Make sure no new goroutine was launched.
count2 := runtime.NumGoroutine()
require.LessOrEqual(t, count2, count1)

// Cancel root context.
cancel()

// Make sure wg's counter gets to 0 eventually.
g.WgWait()
}

0 comments on commit dd25e6e

Please sign in to comment.