Skip to content

Commit

Permalink
Merge pull request #46 from castai/fix-shutdown
Browse files Browse the repository at this point in the history
Fix Shutdown method of log receiver
  • Loading branch information
apasyniuk authored Aug 25, 2023
2 parents 577062a + 71f0076 commit 5d46afa
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 13 deletions.
25 changes: 13 additions & 12 deletions auditlogsreceiver/audit_logs.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,21 @@ type auditLogsReceiver struct {
logger *zap.Logger
pollInterval time.Duration
pageLimit int
wg *sync.WaitGroup
doneChan chan bool
storage storage.Storage
rest *resty.Client
consumer consumer.Logs

wg *sync.WaitGroup
stopPolling context.CancelFunc

storage storage.Storage
rest *resty.Client
consumer consumer.Logs
}

func (a *auditLogsReceiver) Start(ctx context.Context, _ component.Host) error {
func (a *auditLogsReceiver) Start(_ context.Context, _ component.Host) error {
a.logger.Debug("starting audit logs receiver")

// According to Component interface, Start function should not reuse context for background tasks.
ctx, cancel := context.WithCancel(context.Background())
a.stopPolling = cancel
a.wg.Add(1)
go a.startPolling(ctx)

Expand All @@ -49,7 +54,7 @@ func (a *auditLogsReceiver) Start(ctx context.Context, _ component.Host) error {

func (a *auditLogsReceiver) Shutdown(_ context.Context) error {
a.logger.Debug("shutting down audit logs receiver")
close(a.doneChan)
a.stopPolling()
a.wg.Wait()

return nil
Expand All @@ -58,15 +63,13 @@ func (a *auditLogsReceiver) Shutdown(_ context.Context) error {
func (a *auditLogsReceiver) startPolling(ctx context.Context) {
defer a.wg.Done()

ctx, cancel := context.WithCancel(ctx)

t := time.NewTicker(a.pollInterval)
defer t.Stop()

for {
err := a.poll(ctx, func() {
// Stop function is called in case of critical errors (error that cannot be restored from).
cancel()
a.stopPolling()

// TODO: reconsider this approach based on Open Telemetry practices.
err := syscall.Kill(syscall.Getpid(), syscall.SIGTERM)
Expand All @@ -81,8 +84,6 @@ func (a *auditLogsReceiver) startPolling(ctx context.Context) {
select {
case <-ctx.Done():
return
case <-a.doneChan:
return
case <-t.C:
continue
}
Expand Down
65 changes: 65 additions & 0 deletions auditlogsreceiver/audit_logs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"net/http"
"strconv"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -308,4 +309,68 @@ func TestPoll(t *testing.T) {
err := receiver.poll(ctx, nil)
r.NoError(err)
})

t.Run("should cancel work immediately after shutdown is called", func(t *testing.T) {
r := require.New(t)
ctx := context.Background()

mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()

storageMock := mock_storage.NewMockStorage(mockCtrl)
storageMock.EXPECT().
Get().
Return(storage.PollData{
CheckPoint: time.Now(),
}).AnyTimes()
storageMock.EXPECT().
Save(gomock.Any()).AnyTimes()

consumerMock := logsConsumerMock{
ConsumeLogsFunc: func(logs plog.Logs) error {
return nil
},
}

restConfig := Config{
API: API{
Url: "https://api.cast.ai",
Key: uuid.NewString(),
},
PageLimit: 2,
}
rest := newRestyClient(&restConfig)
httpmock.ActivateNonDefault(rest.GetClient())
defer httpmock.Reset()

reqStarted := make(chan struct{})
reqStoped := make(chan struct{})
httpmock.RegisterResponder(
http.MethodGet,
`=~^https:\/\/api\.cast\.ai/v1/audit.?`,
func(req *http.Request) (*http.Response, error) {
close(reqStarted)
<-req.Context().Done()
close(reqStoped)
return httpmock.NewStringResponse(200, "none"), nil
})

receiver := auditLogsReceiver{
logger: logger,
pageLimit: restConfig.PageLimit,
pollInterval: 1 * time.Millisecond,
wg: &sync.WaitGroup{},
storage: storageMock,
rest: rest,
consumer: consumerMock,
}
err := receiver.Start(ctx, nil)
<-reqStarted
go func() {
err := receiver.Shutdown(ctx)
r.NoError(err)
}()
<-reqStoped
r.NoError(err)
})
}
2 changes: 1 addition & 1 deletion auditlogsreceiver/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func NewAuditLogsReceiver(
pollInterval: time.Second * time.Duration(cfg.PollIntervalSec),
pageLimit: cfg.PageLimit,
wg: &sync.WaitGroup{},
doneChan: make(chan bool),
stopPolling: func() {},
storage: st,
rest: newRestyClient(cfg),
consumer: consumer,
Expand Down

0 comments on commit 5d46afa

Please sign in to comment.