From 688360a65fc897f03f9ccaec3a79d1b9dc1cd375 Mon Sep 17 00:00:00 2001 From: Bojan Zivanovic Date: Wed, 15 May 2024 11:07:45 +0200 Subject: [PATCH] Fix unwrapping of DeadlineExceeded errors into ErrTaskTimeout. context.Cause() was called on the parent context, not the task context. This resulted in the error handler being called with a nil err. --- nanoq.go | 18 ++++++++++------- nanoq_test.go | 55 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 7 deletions(-) diff --git a/nanoq.go b/nanoq.go index 07d16ed..24dc22e 100644 --- a/nanoq.go +++ b/nanoq.go @@ -412,14 +412,9 @@ func (p *Processor) processTask(ctx context.Context, t Task) error { return fmt.Errorf("task %v canceled: %v", t.ID, context.Cause(ctx)) } - if errors.Is(err, context.DeadlineExceeded) { - // Extract a more specific timeout error, if any. - err = context.Cause(ctx) - } if p.errorHandler != nil { p.errorHandler(ctx, t, err) } - if t.Retries < t.MaxRetries && !errors.Is(err, ErrSkipRetry) { retryIn := p.retryPolicy(t) if err := p.client.RetryTask(ctx, t, retryIn); err != nil { @@ -455,8 +450,17 @@ func callHandler(ctx context.Context, h Handler, t Task) (err error) { err = fmt.Errorf("panic [%s:%d]: %v: %w", file, line, r, ErrSkipRetry) } }() - ctx, cancel := context.WithTimeoutCause(ctx, t.Timeout(), ErrTaskTimeout) + taskCtx, cancel := context.WithTimeoutCause(ctx, t.Timeout(), ErrTaskTimeout) defer cancel() - return h(ctx, t) + err = h(taskCtx, t) + if err != nil && errors.Is(err, context.DeadlineExceeded) { + // Extract a more specific timeout error, if any. + // context.Cause returns nil if the canceled context is a child of taskCtx. + if cerr := context.Cause(taskCtx); cerr != nil { + err = cerr + } + } + + return err } diff --git a/nanoq_test.go b/nanoq_test.go index af08438..6c7582e 100644 --- a/nanoq_test.go +++ b/nanoq_test.go @@ -561,3 +561,58 @@ func TestProcessor_Run_Cancel(t *testing.T) { t.Error(err) } } + +func TestProcessor_Run_Timeout(t *testing.T) { + db, mock, _ := sqlmock.New() + defer db.Close() + client := nanoq.NewClient(sqlx.NewDb(db, "sqlmock")) + processor := nanoq.NewProcessor(client, zerolog.Nop()) + processor.Handle("my-type", func(ctx context.Context, task nanoq.Task) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + continue + } + } + }) + errorHandlerCalled := 0 + processor.OnError(func(ctx context.Context, task nanoq.Task, err error) { + if !errors.Is(err, nanoq.ErrTaskTimeout) { + t.Errorf("error handler called with unexpected error: %v", err) + } + errorHandlerCalled++ + }) + + // Task claim, timeout_seconds=1. + mock.ExpectBegin() + rows := sqlmock.NewRows([]string{"id", "fingerprint", "type", "payload", "retries", "max_retries", "timeout_seconds", "created_at", "scheduled_at"}). + AddRow("01HQJHTZCAT5WDCGVTWJ640VMM", "25c084d0", "my-type", "{}", "0", "0", "1", time.Now(), time.Now()) + mock.ExpectQuery(`SELECT \* FROM tasks WHERE(.+)`).WillReturnRows(rows) + + mock.ExpectExec("UPDATE tasks SET claimed_at = (.+) WHERE id = (.+)").WithArgs(sqlmock.AnyArg(), "01HQJHTZCAT5WDCGVTWJ640VMM"). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + + mock.ExpectBegin() + mock.ExpectExec("DELETE FROM tasks WHERE id = (.+)").WithArgs("01HQJHTZCAT5WDCGVTWJ640VMM"). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + + ctx, cancel := context.WithCancel(context.Background()) + go processor.Run(ctx, 1, 1*time.Millisecond) + time.Sleep(2 * time.Second) + cancel() + // Wait for the processor to shut down. + time.Sleep(2 * time.Millisecond) + + err := mock.ExpectationsWereMet() + if err != nil { + t.Error(err) + } + + if errorHandlerCalled != 1 { + t.Errorf("erorr handler called %v times instead of %v", errorHandlerCalled, 1) + } +}