diff --git a/nanoq.go b/nanoq.go index 6324d1b..bc71e60 100644 --- a/nanoq.go +++ b/nanoq.go @@ -12,7 +12,6 @@ import ( "os/signal" "runtime" "strings" - "sync" "sync/atomic" "syscall" "time" @@ -40,15 +39,16 @@ var ( // Task represents a task. type Task struct { - ID string `db:"id"` - Fingerprint string `db:"fingerprint"` - Type string `db:"type"` - Payload []byte `db:"payload"` - Retries uint8 `db:"retries"` - MaxRetries uint8 `db:"max_retries"` - TimeoutSeconds int32 `db:"timeout_seconds"` - CreatedAt time.Time `db:"created_at"` - ScheduledAt time.Time `db:"scheduled_at"` + ID string `db:"id"` + Fingerprint string `db:"fingerprint"` + Type string `db:"type"` + Payload []byte `db:"payload"` + Retries uint8 `db:"retries"` + MaxRetries uint8 `db:"max_retries"` + TimeoutSeconds int32 `db:"timeout_seconds"` + CreatedAt time.Time `db:"created_at"` + ScheduledAt time.Time `db:"scheduled_at"` + ClaimedAt *time.Time `db:"claimed_at"` } // NewTask creates a new task. @@ -147,6 +147,7 @@ func NewClient(db *sqlx.DB) *Client { // CreateTask creates the given task. // +// Expected to run in an existing transaction. // Returns ErrDuplicateTask if a task with the same fingerprint already exists. func (c *Client) CreateTask(ctx context.Context, tx *sqlx.Tx, t Task) error { _, err := tx.NamedExecContext(ctx, ` @@ -166,41 +167,61 @@ func (c *Client) CreateTask(ctx context.Context, tx *sqlx.Tx, t Task) error { // ClaimTask claims a task for processing. // -// The claim is valid until the transaction is committed or rolled back. -func (c *Client) ClaimTask(ctx context.Context, tx *sqlx.Tx) (Task, error) { +// Returns ErrNoTasks if no tasks are available. +func (c *Client) ClaimTask(ctx context.Context) (Task, error) { t := Task{} - err := tx.GetContext(ctx, &t, `SELECT * FROM tasks WHERE scheduled_at <= UTC_TIMESTAMP() ORDER BY scheduled_at ASC LIMIT 1 FOR UPDATE SKIP LOCKED`) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return t, ErrNoTasks + err := c.RunTransaction(ctx, func(tx *sqlx.Tx) error { + err := tx.GetContext(ctx, &t, ` + SELECT * FROM tasks + WHERE scheduled_at <= UTC_TIMESTAMP() + AND (claimed_at IS NULL OR DATE_ADD(claimed_at, INTERVAL timeout_seconds*1.1 SECOND) < UTC_TIMESTAMP()) + ORDER BY scheduled_at ASC LIMIT 1 FOR UPDATE SKIP LOCKED`) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return ErrNoTasks + } + return fmt.Errorf("get task: %w", err) } - return t, err - } + now := time.Now().UTC() + t.ClaimedAt = &now - return t, nil -} - -// RetryTask schedules a retry of the given task. -func (c *Client) RetryTask(ctx context.Context, tx *sqlx.Tx, t Task, retryIn time.Duration) error { - t.Retries++ - t.ScheduledAt = time.Now().UTC().Add(retryIn) + _, err = tx.NamedExecContext(ctx, `UPDATE tasks SET claimed_at = :claimed_at WHERE id = :id`, t) + if err != nil { + return fmt.Errorf("update task: %w", err) + } - _, err := tx.NamedExecContext(ctx, `UPDATE tasks SET retries = :retries, scheduled_at = :scheduled_at WHERE id = :id`, t) - if err != nil { - return err - } + return nil + }) - return nil + return t, err } // DeleteTask deletes the given task. -func (c *Client) DeleteTask(ctx context.Context, tx *sqlx.Tx, t Task) error { - _, err := tx.NamedExecContext(ctx, `DELETE FROM tasks WHERE id = :id`, t) - if err != nil { +func (c *Client) DeleteTask(ctx context.Context, t Task) error { + return c.RunTransaction(ctx, func(tx *sqlx.Tx) error { + _, err := tx.NamedExecContext(ctx, `DELETE FROM tasks WHERE id = :id`, t) return err - } + }) +} - return nil +// ReleaseTask releases the given task, allowing it to be claimed again. +func (c *Client) ReleaseTask(ctx context.Context, t Task) error { + return c.RunTransaction(ctx, func(tx *sqlx.Tx) error { + _, err := tx.NamedExecContext(ctx, `UPDATE tasks SET claimed_at = NULL WHERE id = :id`, t) + return err + }) +} + +// RetryTask schedules a retry of the given task. +func (c *Client) RetryTask(ctx context.Context, t Task, retryIn time.Duration) error { + return c.RunTransaction(ctx, func(tx *sqlx.Tx) error { + t.Retries++ + t.ScheduledAt = time.Now().UTC().Add(retryIn) + t.ClaimedAt = nil + + _, err := tx.NamedExecContext(ctx, `UPDATE tasks SET retries = :retries, scheduled_at = :scheduled_at, claimed_at = :claimed_at WHERE id = :id`, t) + return err + }) } // RunTransaction runs the given function in a transaction. @@ -274,7 +295,8 @@ type Processor struct { middleware []Middleware retryPolicy RetryPolicy - done atomic.Bool + workers chan struct{} + done atomic.Bool } // NewProcessor creates a new processor. @@ -336,76 +358,80 @@ func (p *Processor) Run(ctx context.Context, concurrency int, shutdownTimeout ti }() p.logger.Info().Int("concurrency", concurrency).Msg("Starting processor") - var wg sync.WaitGroup - for range concurrency { - wg.Add(1) + p.workers = make(chan struct{}, concurrency) + for !p.done.Load() { + // Acquire a worker before claiming a task, to avoid holding claimed tasks while all workers are busy. + p.workers <- struct{}{} + + t, err := p.client.ClaimTask(processorCtx) + if err != nil { + if !errors.Is(err, ErrNoTasks) && !errors.Is(err, context.Canceled) { + p.logger.Error().Err(err).Msg("Could not claim task") + } + <-p.workers + time.Sleep(1 * time.Second) + continue + } go func() { - for !p.done.Load() { - err := p.process(processorCtx) - if err != nil { - if errors.Is(err, ErrNoTasks) { - // The queue is empty. Wait a second before trying again. - time.Sleep(1 * time.Second) - continue - } - p.logger.Error().Err(err).Msg("Could not process task") - } + if err = p.processTask(processorCtx, t); err != nil { + p.logger.Error().Err(err).Msg("Could not process task") } - wg.Done() + <-p.workers }() } - wg.Wait() + // Wait for workers to finish. + for range cap(p.workers) { + p.workers <- struct{}{} + } } -// process claims a single task and processes it. -func (p *Processor) process(ctx context.Context) error { - return p.client.RunTransaction(ctx, func(tx *sqlx.Tx) error { - t, err := p.client.ClaimTask(ctx, tx) - if err != nil { - return fmt.Errorf("claim task: %w", err) +// processTask processes a single task. +func (p *Processor) processTask(ctx context.Context, t Task) error { + h, ok := p.handlers[t.Type] + if !ok { + h = func(ctx context.Context, t Task) error { + return fmt.Errorf("no handler found for task type %v: %w", t.Type, ErrSkipRetry) } + } + // Apply global middleware. + for i := len(p.middleware) - 1; i >= 0; i-- { + h = p.middleware[i](h) + } - h, ok := p.handlers[t.Type] - if !ok { - h = func(ctx context.Context, t Task) error { - return fmt.Errorf("no handler found for task type %v: %w", t.Type, ErrSkipRetry) + if err := callHandler(ctx, h, t); err != nil { + if errors.Is(err, context.Canceled) { + // The processor is shutting down. Release the task and exit. + if err = p.client.ReleaseTask(context.Background(), t); err != nil { + return fmt.Errorf("release task %v: %w", t.ID, err) } - } - // Apply global middleware. - for i := len(p.middleware) - 1; i >= 0; i-- { - h = p.middleware[i](h) + return fmt.Errorf("task %v canceled: %v", t.ID, context.Cause(ctx)) } - if err = callHandler(ctx, h, t); err != nil { - if errors.Is(err, context.Canceled) { - return fmt.Errorf("task %v canceled: %v", t.ID, context.Cause(ctx)) - } + if errors.Is(err, context.DeadlineExceeded) { // Extract a more specific timeout error, if any. - if errors.Is(err, context.DeadlineExceeded) { - err = context.Cause(ctx) - } + err = context.Cause(ctx) + } + if p.errorHandler != nil { + p.errorHandler(ctx, t, err) + } - if p.errorHandler != nil { - p.errorHandler(ctx, t, err) + if t.Retries < t.MaxRetries && !errors.Is(err, ErrSkipRetry) { + retryIn := p.retryPolicy(t) + if err := p.client.RetryTask(ctx, t, retryIn); err != nil { + return fmt.Errorf("retry task %v: %w", t.ID, err) } - if t.Retries < t.MaxRetries && !errors.Is(err, ErrSkipRetry) { - retryIn := p.retryPolicy(t) - if err := p.client.RetryTask(ctx, tx, t, retryIn); err != nil { - return fmt.Errorf("update task %v: %w", t.ID, err) - } - return nil - } + return nil } + } - if err := p.client.DeleteTask(ctx, tx, t); err != nil { - return fmt.Errorf("delete task %v: %w", t.ID, err) - } + if err := p.client.DeleteTask(ctx, t); err != nil { + return fmt.Errorf("delete task %v: %w", t.ID, err) + } - return nil - }) + return nil } // callHandler calls the given handler, converting panics into errors. diff --git a/nanoq_test.go b/nanoq_test.go index e4a21c5..af08438 100644 --- a/nanoq_test.go +++ b/nanoq_test.go @@ -175,7 +175,12 @@ func TestProcessor_Run(t *testing.T) { AddRow("01HQJHTZCAT5WDCGVTWJ640VMM", "25c084d0", "my-type", "{}", "0", "1", "60", time.Now(), time.Now()) mock.ExpectQuery(`SELECT \* FROM tasks WHERE(.+)`).WillReturnRows(rows) - mock.ExpectExec("UPDATE tasks SET retries = (.+), scheduled_at = (.+) WHERE id = (.+)").WithArgs(1, sqlmock.AnyArg(), "01HQJHTZCAT5WDCGVTWJ640VMM"). + mock.ExpectExec("UPDATE tasks SET claimed_at = (.+) WHERE id = (.+)").WithArgs(sqlmock.AnyArg(), "01HQJHTZCAT5WDCGVTWJ640VMM"). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + + mock.ExpectBegin() + mock.ExpectExec("UPDATE tasks SET retries = (.+), scheduled_at = (.+), claimed_at = (.+) WHERE id = (.+)").WithArgs(1, sqlmock.AnyArg(), nil, "01HQJHTZCAT5WDCGVTWJ640VMM"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() @@ -185,6 +190,11 @@ func TestProcessor_Run(t *testing.T) { AddRow("01HQJHTZCAT5WDCGVTWJ640VMM", "25c084d0", "my-type", "{}", "1", "1", "60", time.Now(), time.Now()) mock.ExpectQuery(`SELECT \* FROM tasks WHERE(.+)`).WillReturnRows(rows) + mock.ExpectExec("UPDATE tasks SET claimed_at = (.+) WHERE id = (.+)").WithArgs(sqlmock.AnyArg(), "01HQJHTZCAT5WDCGVTWJ640VMM"). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + + mock.ExpectBegin() mock.ExpectExec("DELETE FROM tasks WHERE id = (.+)").WithArgs("01HQJHTZCAT5WDCGVTWJ640VMM"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() @@ -229,7 +239,12 @@ func TestProcessor_Run_RetriesExhausted(t *testing.T) { AddRow("01HQJHTZCAT5WDCGVTWJ640VMM", "25c084d0", "my-type", "{}", "0", "1", "60", time.Now(), time.Now()) mock.ExpectQuery(`SELECT \* FROM tasks WHERE(.+)`).WillReturnRows(rows) - mock.ExpectExec("UPDATE tasks SET retries = (.+), scheduled_at = (.+) WHERE id = (.+)").WithArgs(1, sqlmock.AnyArg(), "01HQJHTZCAT5WDCGVTWJ640VMM"). + mock.ExpectExec("UPDATE tasks SET claimed_at = (.+) WHERE id = (.+)").WithArgs(sqlmock.AnyArg(), "01HQJHTZCAT5WDCGVTWJ640VMM"). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + + mock.ExpectBegin() + mock.ExpectExec("UPDATE tasks SET retries = (.+), scheduled_at = (.+), claimed_at = (.+) WHERE id = (.+)").WithArgs(1, sqlmock.AnyArg(), nil, "01HQJHTZCAT5WDCGVTWJ640VMM"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() @@ -239,6 +254,11 @@ func TestProcessor_Run_RetriesExhausted(t *testing.T) { AddRow("01HQJHTZCAT5WDCGVTWJ640VMM", "25c084d0", "my-type", "{}", "1", "1", "60", time.Now(), time.Now()) mock.ExpectQuery(`SELECT \* FROM tasks WHERE(.+)`).WillReturnRows(rows) + mock.ExpectExec("UPDATE tasks SET claimed_at = (.+) WHERE id = (.+)").WithArgs(sqlmock.AnyArg(), "01HQJHTZCAT5WDCGVTWJ640VMM"). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + + mock.ExpectBegin() mock.ExpectExec("DELETE FROM tasks WHERE id = (.+)").WithArgs("01HQJHTZCAT5WDCGVTWJ640VMM"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() @@ -282,6 +302,11 @@ func TestProcessor_Run_SkipRetry(t *testing.T) { AddRow("01HQJHTZCAT5WDCGVTWJ640VMM", "25c084d0", "my-type", "{}", "0", "1", "60", time.Now(), time.Now()) mock.ExpectQuery(`SELECT \* FROM tasks WHERE(.+)`).WillReturnRows(rows) + mock.ExpectExec("UPDATE tasks SET claimed_at = (.+) WHERE id = (.+)").WithArgs(sqlmock.AnyArg(), "01HQJHTZCAT5WDCGVTWJ640VMM"). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + + mock.ExpectBegin() mock.ExpectExec("DELETE FROM tasks WHERE id = (.+)").WithArgs("01HQJHTZCAT5WDCGVTWJ640VMM"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() @@ -325,6 +350,11 @@ func TestProcessor_Run_Panic(t *testing.T) { AddRow("01HQJHTZCAT5WDCGVTWJ640VMM", "25c084d0", "my-type", "{}", "0", "1", "60", time.Now(), time.Now()) mock.ExpectQuery(`SELECT \* FROM tasks WHERE(.+)`).WillReturnRows(rows) + mock.ExpectExec("UPDATE tasks SET claimed_at = (.+) WHERE id = (.+)").WithArgs(sqlmock.AnyArg(), "01HQJHTZCAT5WDCGVTWJ640VMM"). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + + mock.ExpectBegin() mock.ExpectExec("DELETE FROM tasks WHERE id = (.+)").WithArgs("01HQJHTZCAT5WDCGVTWJ640VMM"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() @@ -365,6 +395,11 @@ func TestProcessor_Run_NoHandler(t *testing.T) { AddRow("01HQJHTZCAT5WDCGVTWJ640VMM", "25c084d0", "my-type", "{}", "0", "1", "60", time.Now(), time.Now()) mock.ExpectQuery(`SELECT \* FROM tasks WHERE(.+)`).WillReturnRows(rows) + mock.ExpectExec("UPDATE tasks SET claimed_at = (.+) WHERE id = (.+)").WithArgs(sqlmock.AnyArg(), "01HQJHTZCAT5WDCGVTWJ640VMM"). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + + mock.ExpectBegin() mock.ExpectExec("DELETE FROM tasks WHERE id = (.+)").WithArgs("01HQJHTZCAT5WDCGVTWJ640VMM"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() @@ -461,6 +496,11 @@ func TestProcessor_Run_Middleware(t *testing.T) { AddRow("01HQJHTZCAT5WDCGVTWJ640VMM", "25c084d0", "my-type", "{}", "0", "1", "60", time.Now(), time.Now()) mock.ExpectQuery(`SELECT \* FROM tasks WHERE(.+)`).WillReturnRows(rows) + mock.ExpectExec("UPDATE tasks SET claimed_at = (.+) WHERE id = (.+)").WithArgs(sqlmock.AnyArg(), "01HQJHTZCAT5WDCGVTWJ640VMM"). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + + mock.ExpectBegin() mock.ExpectExec("DELETE FROM tasks WHERE id = (.+)").WithArgs("01HQJHTZCAT5WDCGVTWJ640VMM"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() @@ -477,3 +517,47 @@ func TestProcessor_Run_Middleware(t *testing.T) { t.Error(err) } } + +func TestProcessor_Run_Cancel(t *testing.T) { + db, mock, _ := sqlmock.New() + defer db.Close() + client := nanoq.NewClient(sqlx.NewDb(db, "sqlmock")) + processor := nanoq.NewProcessor(client, zerolog.Nop()) + processor.Handle("my-type", func(ctx context.Context, task nanoq.Task) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + continue + } + } + }) + + // Task claim and release. + mock.ExpectBegin() + rows := sqlmock.NewRows([]string{"id", "fingerprint", "type", "payload", "retries", "max_retries", "timeout_seconds", "created_at", "scheduled_at"}). + AddRow("01HQJHTZCAT5WDCGVTWJ640VMM", "25c084d0", "my-type", "{}", "0", "1", "60", time.Now(), time.Now()) + mock.ExpectQuery(`SELECT \* FROM tasks WHERE(.+)`).WillReturnRows(rows) + + mock.ExpectExec("UPDATE tasks SET claimed_at = (.+) WHERE id = (.+)").WithArgs(sqlmock.AnyArg(), "01HQJHTZCAT5WDCGVTWJ640VMM"). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + + mock.ExpectBegin() + mock.ExpectExec("UPDATE tasks SET claimed_at = NULL WHERE id = (.+)").WithArgs("01HQJHTZCAT5WDCGVTWJ640VMM"). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + + ctx, cancel := context.WithCancel(context.Background()) + go processor.Run(ctx, 1, 1*time.Millisecond) + time.Sleep(1 * time.Second) + cancel() + // Wait for the processor to shut down. + time.Sleep(2 * time.Millisecond) + + err := mock.ExpectationsWereMet() + if err != nil { + t.Error(err) + } +} diff --git a/table.sql b/table.sql index a949e44..449dead 100644 --- a/table.sql +++ b/table.sql @@ -11,6 +11,7 @@ CREATE TABLE `tasks` ( `timeout_seconds` int(11) NOT NULL DEFAULT 60, `created_at` datetime(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), `scheduled_at` datetime(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + `claimed_at` datetime(6), PRIMARY KEY (`id`), UNIQUE KEY `fingerprint` (`fingerprint`), KEY `scheduled_at` (`scheduled_at`)