Skip to content

Commit

Permalink
Fix middleware being applied in the reverse order.
Browse files Browse the repository at this point in the history
  • Loading branch information
bojanz committed Mar 22, 2024
1 parent e7d8c38 commit b7c1769
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 4 deletions.
8 changes: 4 additions & 4 deletions nanoq.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,8 @@ func (p *Processor) Use(m Middleware) {
// Handle registers the handler for a task type.
func (p *Processor) Handle(taskType string, h Handler, ms ...Middleware) {
// Wrap the handler with the passed middleware.
for _, m := range ms {
h = m(h)
for i := len(ms) - 1; i >= 0; i-- {
h = ms[i](h)
}
p.handlers[taskType] = h
}
Expand Down Expand Up @@ -331,8 +331,8 @@ func (p *Processor) process(ctx context.Context) error {
}
}
// Apply global middleware.
for _, m := range p.middleware {
h = m(h)
for i := len(p.middleware) - 1; i >= 0; i-- {
h = p.middleware[i](h)
}

if err = h(ctx, t); err != nil {
Expand Down
90 changes: 90 additions & 0 deletions nanoq_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,93 @@ func TestProcessor_Run_SkipRetry(t *testing.T) {
t.Errorf("erorr handler called %v times instead of %v", errorHandlerCalled, 1)
}
}

func TestProcessor_Run_Middleware(t *testing.T) {
// Used to store and retrieve the context value.
type contextKey string

db, mock, _ := sqlmock.New()
defer db.Close()
client := nanoq.NewClient(sqlx.NewDb(db, "sqlmock"))
processor := nanoq.NewProcessor(client, zerolog.Nop())
processor.Use(func(next nanoq.Handler) nanoq.Handler {
return func(ctx context.Context, t nanoq.Task) error {
middlewareValue := ctx.Value(contextKey("middleware"))
if middlewareValue == nil {
middlewareValue = make([]string, 0, 10)
}
middleware := append(middlewareValue.([]string), "first_global")
ctx = context.WithValue(ctx, contextKey("middleware"), middleware)

return next(ctx, t)
}
})
processor.Use(func(next nanoq.Handler) nanoq.Handler {
return func(ctx context.Context, t nanoq.Task) error {
middlewareValue := ctx.Value(contextKey("middleware"))
if middlewareValue == nil {
middlewareValue = make([]string, 0, 10)
}
middleware := append(middlewareValue.([]string), "second_global")
ctx = context.WithValue(ctx, contextKey("middleware"), middleware)

return next(ctx, t)
}
})

firstHandlerMiddleware := func(next nanoq.Handler) nanoq.Handler {
return func(ctx context.Context, t nanoq.Task) error {
middlewareValue := ctx.Value(contextKey("middleware"))
if middlewareValue == nil {
middlewareValue = make([]string, 0, 10)
}
middleware := append(middlewareValue.([]string), "first")
ctx = context.WithValue(ctx, contextKey("middleware"), middleware)

return next(ctx, t)
}
}
secondHandlerMiddleware := func(next nanoq.Handler) nanoq.Handler {
return func(ctx context.Context, t nanoq.Task) error {
middlewareValue := ctx.Value(contextKey("middleware"))
if middlewareValue == nil {
middlewareValue = make([]string, 0, 10)
}
middleware := append(middlewareValue.([]string), "second")
ctx = context.WithValue(ctx, contextKey("middleware"), middleware)

return next(ctx, t)
}
}
handler := func(ctx context.Context, task nanoq.Task) error {
middlewareValue := ctx.Value(contextKey("middleware"))
middleware := middlewareValue.([]string)
wantMiddleware := []string{"first_global", "second_global", "first", "second"}
if !slices.Equal(middleware, wantMiddleware) {
t.Errorf("got %v, want %v", middleware, wantMiddleware)
}

return nil
}
processor.Handle("my-type", handler, firstHandlerMiddleware, secondHandlerMiddleware)

// Task claim and deletion.
mock.ExpectBegin()
rows := sqlmock.NewRows([]string{"id", "fingerprint", "type", "payload", "retries", "max_retries", "created_at", "scheduled_at"}).
AddRow("01HQJHTZCAT5WDCGVTWJ640VMM", "25c084d0", "my-type", "{}", "0", "1", time.Now(), time.Now())
mock.ExpectQuery(`SELECT \* FROM tasks WHERE(.+)`).WillReturnRows(rows)

mock.ExpectExec("DELETE FROM tasks WHERE id = (.+)").WithArgs("01HQJHTZCAT5WDCGVTWJ640VMM").
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()

ctx, cancel := context.WithCancel(context.Background())
go processor.Run(ctx, 1, 1*time.Second)
time.Sleep(1 * time.Second)
cancel()

err := mock.ExpectationsWereMet()
if err != nil {
t.Error(err)
}
}

0 comments on commit b7c1769

Please sign in to comment.