diff --git a/fn/goroutine_manager.go b/fn/goroutine_manager.go index 81c538ea01..4b7843f02a 100644 --- a/fn/goroutine_manager.go +++ b/fn/goroutine_manager.go @@ -3,51 +3,123 @@ package fn import ( "context" "sync" + "sync/atomic" ) // GoroutineManager is used to launch goroutines until context expires or the // manager is stopped. The Stop method blocks until all started goroutines stop. type GoroutineManager struct { - wg sync.WaitGroup - mu sync.Mutex - ctx context.Context - cancel func() + // id is used to generate unique ids for each goroutine. + id atomic.Uint32 + + // cancelFns is a map of cancel functions that can be used to cancel the + // context of a goroutine. The mutex must be held when accessing this + // map. The key is the id of the goroutine. + cancelFns map[uint32]context.CancelFunc + + mu sync.Mutex + + stopped sync.Once + quit chan struct{} + wg sync.WaitGroup } // NewGoroutineManager constructs and returns a new instance of // GoroutineManager. -func NewGoroutineManager(ctx context.Context) *GoroutineManager { - ctx, cancel := context.WithCancel(ctx) - +func NewGoroutineManager() *GoroutineManager { return &GoroutineManager{ - ctx: ctx, - cancel: cancel, + cancelFns: make(map[uint32]context.CancelFunc), + quit: make(chan struct{}), + } +} + +// addCancelFn adds a context cancel function to the manager and returns an id +// that can can be used to cancel the context later on when the goroutine is +// done. +func (g *GoroutineManager) addCancelFn(cancel context.CancelFunc) uint32 { + g.mu.Lock() + defer g.mu.Unlock() + + id := g.id.Add(1) + g.cancelFns[id] = cancel + + return id +} + +// cancel cancels the context associated with the passed id. +func (g *GoroutineManager) cancel(id uint32) { + g.mu.Lock() + defer g.mu.Unlock() + + g.cancelUnsafe(id) +} + +// cancelUnsafe cancels the context associated with the passed id without +// acquiring the mutex. +func (g *GoroutineManager) cancelUnsafe(id uint32) { + fn, ok := g.cancelFns[id] + if !ok { + return } + + fn() + + delete(g.cancelFns, id) } // Go tries to start a new goroutine and returns a boolean indicating its -// success. It fails iff the goroutine manager is stopping or its context passed -// to NewGoroutineManager has expired. -func (g *GoroutineManager) Go(f func(ctx context.Context)) bool { - // Calling wg.Add(1) and wg.Wait() when wg's counter is 0 is a race - // condition, since it is not clear should Wait() block or not. This +// success. It returns true if the goroutine was successfully created and false +// otherwise. A goroutine will fail to be created iff the goroutine manager is +// stopping or the passed context has already expired. The passed call-back +// function must exit if the passed context expires. +func (g *GoroutineManager) Go(ctx context.Context, + f func(ctx context.Context)) bool { + + // Derive a cancellable context from the passed context and store its + // cancel function in the manager. The context will be cancelled when + // either the parent context is cancelled or the quit channel is closed + // which will call the stored cancel function. + ctx, cancel := context.WithCancel(ctx) + id := g.addCancelFn(cancel) + + // Calling wg.Add(1) and wg.Wait() when the wg's counter is 0 is a race + // condition, since it is not clear if Wait() should block or not. This // kind of race condition is detected by Go runtime and results in a - // crash if running with `-race`. To prevent this, whole Go method is - // protected with a mutex. The call to wg.Wait() inside Stop() can still - // run in parallel with Go, but in that case g.ctx is in expired state, - // because cancel() was called in Stop, so Go returns before wg.Add(1) - // call. + // crash if running with `-race`. To prevent this, we protect the calls + // to wg.Add(1) and wg.Wait() with a mutex. If we block here because + // Stop is running first, then Stop will cancel the quit channel which + // will cause the context to be cancelled, and we will exit before + // calling wg.Add(1). If we grab the mutex here before Stop does, then + // Stop will block until after we call wg.Add(1). g.mu.Lock() defer g.mu.Unlock() - if g.ctx.Err() != nil { + // Before continuing to start the goroutine, we need to check if the + // context has already expired. This could be the case if the parent + // context has already expired or if Stop has been called. + if ctx.Err() != nil { + g.cancelUnsafe(id) + + return false + } + + // Ensure that the goroutine is not started if the manager has stopped. + select { + case <-g.quit: + g.cancelUnsafe(id) + return false + default: } g.wg.Add(1) go func() { - defer g.wg.Done() - f(g.ctx) + defer func() { + g.cancel(id) + g.wg.Done() + }() + + f(ctx) }() return true @@ -56,20 +128,28 @@ func (g *GoroutineManager) Go(f func(ctx context.Context)) bool { // Stop prevents new goroutines from being added and waits for all running // goroutines to finish. func (g *GoroutineManager) Stop() { - g.mu.Lock() - g.cancel() - g.mu.Unlock() - - // Wait for all goroutines to finish. Note that this wg.Wait() call is - // safe, since it can't run in parallel with wg.Add(1) call in Go, since - // we just cancelled the context and even if Go call starts running here - // after acquiring the mutex, it would see that the context has expired - // and return false instead of calling wg.Add(1). - g.wg.Wait() + g.stopped.Do(func() { + // Closing the quit channel will prevent any new goroutines from + // starting. + g.mu.Lock() + close(g.quit) + for _, cancel := range g.cancelFns { + cancel() + } + g.mu.Unlock() + + // Wait for all goroutines to finish. Note that this wg.Wait() + // call is safe, since it can't run in parallel with wg.Add(1) + // call in Go, since we just cancelled the context and even if + // Go call starts running here after acquiring the mutex, it + // would see that the context has expired and return false + // instead of calling wg.Add(1). + g.wg.Wait() + }) } // Done returns a channel which is closed when either the context passed to // NewGoroutineManager expires or when Stop is called. func (g *GoroutineManager) Done() <-chan struct{} { - return g.ctx.Done() + return g.quit } diff --git a/fn/goroutine_manager_test.go b/fn/goroutine_manager_test.go index 1fc945b97b..e39dce995c 100644 --- a/fn/goroutine_manager_test.go +++ b/fn/goroutine_manager_test.go @@ -2,156 +2,145 @@ package fn import ( "context" - "sync" "testing" "time" "github.com/stretchr/testify/require" ) -// TestGoroutineManager tests that the GoroutineManager starts goroutines until -// ctx expires. It also makes sure it fails to start new goroutines after the -// context expired and the GoroutineManager is in the process of waiting for -// already started goroutines in the Stop method. +// TestGoroutineManager tests the behaviour of the GoroutineManager. func TestGoroutineManager(t *testing.T) { t.Parallel() - m := NewGoroutineManager(context.Background()) - - taskChan := make(chan struct{}) - - require.True(t, m.Go(func(ctx context.Context) { - <-taskChan - })) - - t1 := time.Now() - - // Close taskChan in 1s, causing the goroutine to stop. - time.AfterFunc(time.Second, func() { - close(taskChan) - }) - - m.Stop() - stopDelay := time.Since(t1) - - // Make sure Stop was waiting for the goroutine to stop. - require.Greater(t, stopDelay, time.Second) - - // Make sure new goroutines do not start after Stop. - require.False(t, m.Go(func(ctx context.Context) {})) - - // When Stop() is called, the internal context expires and m.Done() is - // closed. Test this. - select { - case <-m.Done(): - default: - t.Errorf("Done() channel must be closed at this point") - } -} - -// TestGoroutineManagerContextExpires tests the effect of context expiry. -func TestGoroutineManagerContextExpires(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithCancel(context.Background()) - - m := NewGoroutineManager(ctx) - - require.True(t, m.Go(func(ctx context.Context) { - <-ctx.Done() - })) - - // The Done channel of the manager should not be closed, so the - // following call must block. - select { - case <-m.Done(): - t.Errorf("Done() channel must not be closed at this point") - default: - } - - cancel() - - // The Done channel of the manager should be closed, so the following - // call must not block. - select { - case <-m.Done(): - default: - t.Errorf("Done() channel must be closed at this point") - } + // Here we test that the GoroutineManager starts goroutines until it has + // been stopped. + t.Run("GM is stopped", func(t *testing.T) { + t.Parallel() + + var ( + ctx = context.Background() + m = NewGoroutineManager() + taskChan = make(chan struct{}) + ) + + // The gm has not stopped yet and the passed in context has not + // expired, so we expect the goroutine to start. The taskChan is + // blocking, so this goroutine will be live for a while. + require.True(t, m.Go(ctx, func(ctx context.Context) { + <-taskChan + })) - // Make sure new goroutines do not start after context expiry. - require.False(t, m.Go(func(ctx context.Context) {})) + t1 := time.Now() - // Stop will wait for all goroutines to stop. - m.Stop() -} + // Close taskChan in 1s, causing the goroutine to stop. + time.AfterFunc(time.Second, func() { + close(taskChan) + }) -// TestGoroutineManagerStress starts many goroutines while calling Stop. It -// is needed to make sure the GoroutineManager does not crash if this happen. -// If the mutex was not used, it would crash because of a race condition between -// wg.Add(1) and wg.Wait(). -func TestGoroutineManagerStress(t *testing.T) { - t.Parallel() + m.Stop() + stopDelay := time.Since(t1) - m := NewGoroutineManager(context.Background()) + // Make sure Stop was waiting for the goroutine to stop. + require.Greater(t, stopDelay, time.Second) - stopChan := make(chan struct{}) + // Make sure new goroutines do not start after Stop. + require.False(t, m.Go(ctx, func(ctx context.Context) {})) - time.AfterFunc(1*time.Millisecond, func() { - m.Stop() - close(stopChan) + // When Stop() is called, gm quit channel has been closed and so + // Done() should return. + select { + case <-m.Done(): + default: + t.Errorf("Done() channel must be closed at this point") + } }) - // Starts 100 goroutines sequentially. Sequential order is needed to - // keep wg.counter low (0 or 1) to increase probability of race - // condition to be caught if it exists. If mutex is removed in the - // implementation, this test crashes under `-race`. - for i := 0; i < 100; i++ { - taskChan := make(chan struct{}) - ok := m.Go(func(ctx context.Context) { - close(taskChan) - }) - // If goroutine was started, wait for its completion. - if ok { - <-taskChan + // Test that the GoroutineManager fails to start a goroutine or exits a + // goroutine if the caller context has expired. + t.Run("Caller context expires", func(t *testing.T) { + t.Parallel() + + var ( + ctx = context.Background() + m = NewGoroutineManager() + taskChan = make(chan struct{}) + ) + + // Derive a child context with a cancel function. + ctxc, cancel := context.WithCancel(ctx) + + // The gm has not stopped yet and the passed in context has not + // expired, so we expect the goroutine to start. + require.True(t, m.Go(ctxc, func(ctx context.Context) { + select { + case <-ctx.Done(): + case <-taskChan: + t.Fatalf("The task was performed when it " + + "should not have") + } + })) + + // Give the GM a little bit of time to start the goroutine so + // that we can be sure that it is already listening on the + // ctx and taskChan before calling cancel. + time.Sleep(time.Millisecond * 500) + + // Cancel the context so that the goroutine exits. + cancel() + + // Attempt to send a signal on the task channel, nothing should + // happen since the goroutine has already exited. + select { + case taskChan <- struct{}{}: + case <-time.After(time.Millisecond * 200): } - } - // Wait for Stop to complete. - <-stopChan -} + // Again attempt to add a goroutine with the same cancelled + // context. This should fail since the context has already + // expired. + require.False(t, m.Go(ctxc, func(ctx context.Context) { + t.Fatalf("The goroutine should not have started") + })) -// TestGoroutineManagerStopsStress launches many Stop() calls in parallel with a -// task exiting. It attempts to catch a race condition between wg.Done() and -// wg.Wait() calls. According to documentation of wg.Wait() this is acceptable, -// therefore this test passes even with -race. -func TestGoroutineManagerStopsStress(t *testing.T) { - t.Parallel() + // Stop the goroutine manager. + m.Stop() + }) - m := NewGoroutineManager(context.Background()) + // Start many goroutines while calling Stop. We do this to make sure + // that the GoroutineManager does not crash when these calls are done in + // parallel because of the potential race between wg.Add() and + // wg.Done() when the wg counter is 0. + t.Run("Stress test", func(t *testing.T) { + t.Parallel() - // jobChan is used to make the task to finish. - jobChan := make(chan struct{}) + var ( + ctx = context.Background() + m = NewGoroutineManager() + stopChan = make(chan struct{}) + ) - // Start a task and wait inside it until we start calling Stop() method. - ok := m.Go(func(ctx context.Context) { - <-jobChan - }) - require.True(t, ok) - - // Now launch many gorotines calling Stop() method in parallel. - var wg sync.WaitGroup - for i := 0; i < 100; i++ { - wg.Add(1) - go func() { - defer wg.Done() + time.AfterFunc(1*time.Millisecond, func() { m.Stop() - }() - } + close(stopChan) + }) - // Exit the task in parallel with Stop() calls. - close(jobChan) + // Start 100 goroutines sequentially. Sequential order is + // needed to keep wg.counter low (0 or 1) to increase + // probability of the race condition to triggered if it exists. + // If mutex is removed in the implementation, this test crashes + // under `-race`. + for i := 0; i < 100; i++ { + taskChan := make(chan struct{}) + ok := m.Go(ctx, func(ctx context.Context) { + close(taskChan) + }) + // If goroutine was started, wait for its completion. + if ok { + <-taskChan + } + } - // Wait until all the Stop() calls complete. - wg.Wait() + // Wait for Stop to complete. + <-stopChan + }) }