From b88fa7f74a45084fd6be8646d094414c7cf792e5 Mon Sep 17 00:00:00 2001
From: Yash Bhardwaj
Date: Wed, 9 Oct 2024 14:43:56 +0530
Subject: [PATCH] Replay cancel runs (#273) (#276)

* fix: emit replay lifecycle events

* fix: refactor optimus replay template
---
 core/scheduler/service/replay_service.go      | 39 ++++++++++++++++++-
 core/scheduler/service/replay_service_test.go | 31 ++++++++++++---
 core/scheduler/service/replay_worker.go       | 35 +++++++++++++++--
 core/scheduler/service/replay_worker_test.go  |  5 +++
 core/scheduler/status.go                      |  3 +-
 ext/scheduler/airflow/airflow.go              | 22 +++++++++++
 .../postgres/scheduler/replay_repository.go   |  4 +-
 7 files changed, 125 insertions(+), 14 deletions(-)

diff --git a/core/scheduler/service/replay_service.go b/core/scheduler/service/replay_service.go
index e7770db56a..931c2d4381 100644
--- a/core/scheduler/service/replay_service.go
+++ b/core/scheduler/service/replay_service.go
@@ -34,6 +34,7 @@ type SchedulerRunGetter interface {
 type ReplayRepository interface {
 	RegisterReplay(ctx context.Context, replay *scheduler.Replay, runs []*scheduler.JobRunStatus) (uuid.UUID, error)
 	UpdateReplay(ctx context.Context, replayID uuid.UUID, state scheduler.ReplayState, runs []*scheduler.JobRunStatus, message string) error
+	UpdateReplayRuns(ctx context.Context, replayID uuid.UUID, runs []*scheduler.JobRunStatus) error
 	UpdateReplayStatus(ctx context.Context, replayID uuid.UUID, state scheduler.ReplayState, message string) error
 
 	GetReplayByFilters(ctx context.Context, projectName tenant.ProjectName, filters ...filter.FilterOpt) ([]*scheduler.ReplayWithRun, error)
@@ -52,6 +53,8 @@ type ReplayValidator interface {
 
 type ReplayExecutor interface {
 	Execute(replayID uuid.UUID, jobTenant tenant.Tenant, jobName scheduler.JobName)
+	SyncStatus(ctx context.Context, replayWithRun *scheduler.ReplayWithRun, jobCron *cron.ScheduleSpec) (scheduler.JobRunStatusList, error)
+	CancelReplayRunsOnScheduler(ctx context.Context, replay *scheduler.Replay, jobCron *cron.ScheduleSpec, runs []*scheduler.JobRunStatus) []*scheduler.JobRunStatus
 }
 
 type ReplayService struct {
@@ -216,13 +219,47 @@ func (r *ReplayService) GetRunsStatus(ctx context.Context, tenant tenant.Tenant,
 	return runs, nil
 }
 
+func (r *ReplayService) cancelReplayRuns(ctx context.Context, replayWithRun *scheduler.ReplayWithRun) error {
+	// get list of in progress runs
+	// stop them on the scheduler
+	replay := replayWithRun.Replay
+	jobName := replay.JobName()
+	jobCron, err := getJobCron(ctx, r.logger, r.jobRepo, replay.Tenant(), jobName)
+	if err != nil {
+		r.logger.Error("unable to get cron value for job [%s]: %s", jobName.String(), err.Error())
+		return err
+	}
+
+	syncedRunStatus, err := r.executor.SyncStatus(ctx, replayWithRun, jobCron)
+	if err != nil {
+		r.logger.Error("unable to sync replay runs status for job [%s]: %s", jobName.String(), err.Error())
+		return err
+	}
+
+	statesForCanceling := []scheduler.State{scheduler.StateRunning, scheduler.StateInProgress, scheduler.StateQueued}
+	toBeCanceledRuns := syncedRunStatus.GetSortedRunsByStates(statesForCanceling)
+	if len(toBeCanceledRuns) == 0 {
+		return nil
+	}
+
+	canceledRuns := r.executor.CancelReplayRunsOnScheduler(ctx, replay, jobCron, toBeCanceledRuns)
+
+	// update the status of these runs as failed in DB
+	return r.replayRepo.UpdateReplayRuns(ctx, replay.ID(), canceledRuns)
+}
+
 func (r *ReplayService) CancelReplay(ctx context.Context, replayWithRun *scheduler.ReplayWithRun) error {
 	if replayWithRun.Replay.IsTerminated() {
 		return errors.InvalidArgument(scheduler.EntityReplay, fmt.Sprintf("replay has already been terminated with status %s", replayWithRun.Replay.State().String()))
 	}
 	statusSummary := scheduler.JobRunStatusList(replayWithRun.Runs).GetJobRunStatusSummary()
 	cancelMessage := fmt.Sprintf("replay cancelled with run status %s", statusSummary)
-	return r.replayRepo.UpdateReplayStatus(ctx, replayWithRun.Replay.ID(), scheduler.ReplayStateCancelled, cancelMessage)
+
+	err := r.replayRepo.UpdateReplayStatus(ctx, replayWithRun.Replay.ID(), scheduler.ReplayStateCancelled, cancelMessage)
+	if err != nil {
+		return err
+	}
+	return r.cancelReplayRuns(ctx, replayWithRun)
 }
 
 func NewReplayService(
diff --git a/core/scheduler/service/replay_service_test.go b/core/scheduler/service/replay_service_test.go
index 970ee39957..1aaeda6b45 100644
--- a/core/scheduler/service/replay_service_test.go
+++ b/core/scheduler/service/replay_service_test.go
@@ -410,9 +410,6 @@ func TestReplayService(t *testing.T) {
 			assert.ErrorContains(t, err, errorMsg)
 		})
 		t.Run("returns no error if replay has been successfully cancelled", func(t *testing.T) {
-			replayRepository := new(ReplayRepository)
-			defer replayRepository.AssertExpectations(t)
-
 			replay := scheduler.NewReplay(replayID, jobName, tnnt, replayConfig, scheduler.ReplayStateInProgress, startTime, message)
 			replayWithRun := &scheduler.ReplayWithRun{
 				Replay: replay,
@@ -423,10 +420,19 @@ func TestReplayService(t *testing.T) {
 					},
 				},
 			}
-
+			replayRepository := new(ReplayRepository)
+			defer replayRepository.AssertExpectations(t)
 			replayRepository.On("UpdateReplayStatus", mock.Anything, replay.ID(), scheduler.ReplayStateCancelled, mock.Anything).Return(nil).Once()
 
-			replayService := service.NewReplayService(replayRepository, nil, nil, nil, nil, nil, logger, nil, nil)
+			jobRepository := new(JobRepository)
+			defer jobRepository.AssertExpectations(t)
+			jobRepository.On("GetJobDetails", mock.Anything, projName, jobName).Return(jobWithDetails, nil)
+
+			replayWorker := new(ReplayExecutor)
+			defer replayWorker.AssertExpectations(t)
+			replayWorker.On("SyncStatus", ctx, replayWithRun, jobCron).Return(scheduler.JobRunStatusList{}, nil)
+
+			replayService := service.NewReplayService(replayRepository, jobRepository, nil, nil, replayWorker, nil, logger, nil, nil)
 			err := replayService.CancelReplay(ctx, replayWithRun)
 			assert.NoError(t, err)
 		})
@@ -564,6 +570,11 @@ func (_m *ReplayRepository) GetReplayByID(ctx context.Context, replayID uuid.UUI
 	return r0, r1
 }
 
+func (_m *ReplayRepository) UpdateReplayRuns(ctx context.Context, replayID uuid.UUID, runs []*scheduler.JobRunStatus) error {
+	args := _m.Called(ctx, replayID, runs)
+	return args.Error(0)
+}
+
 // GetReplayRequestsByStatus provides a mock function with given fields: ctx, statusList
 func (_m *ReplayRepository) GetReplayRequestsByStatus(ctx context.Context, statusList []scheduler.ReplayState) ([]*scheduler.Replay, error) {
 	ret := _m.Called(ctx, statusList)
@@ -715,6 +726,16 @@ func (_m *ReplayExecutor) Execute(replayID uuid.UUID, jobTenant tenant.Tenant, j
 	_m.Called(replayID, jobTenant, jobName)
 }
 
+func (_m *ReplayExecutor) SyncStatus(ctx context.Context, replayWithRun *scheduler.ReplayWithRun, jobCron *cron.ScheduleSpec) (scheduler.JobRunStatusList, error) {
+	args := _m.Called(ctx, replayWithRun, jobCron)
+	return args.Get(0).(scheduler.JobRunStatusList), args.Error(1)
+}
+
+func (_m *ReplayExecutor) CancelReplayRunsOnScheduler(ctx context.Context, replay *scheduler.Replay, jobCron *cron.ScheduleSpec, runs []*scheduler.JobRunStatus) []*scheduler.JobRunStatus {
+	args := _m.Called(ctx, replay, jobCron, runs)
+	return args.Get(0).([]*scheduler.JobRunStatus)
+}
+
 // TenantGetter is an autogenerated mock type for the TenantGetter type
 type TenantGetter struct {
 	mock.Mock
diff --git a/core/scheduler/service/replay_worker.go b/core/scheduler/service/replay_worker.go
index 508c939d0c..4bea4a6f81 100644
--- a/core/scheduler/service/replay_worker.go
+++ b/core/scheduler/service/replay_worker.go
@@ -33,6 +33,7 @@ type ReplayScheduler interface {
 	Clear(ctx context.Context, t tenant.Tenant, jobName scheduler.JobName, scheduledAt time.Time) error
 	ClearBatch(ctx context.Context, t tenant.Tenant, jobName scheduler.JobName, startTime, endTime time.Time) error
 
+	CancelRun(ctx context.Context, tnnt tenant.Tenant, jobName scheduler.JobName, executionTime time.Time, dagRunIDPrefix string) error
 	CreateRun(ctx context.Context, tnnt tenant.Tenant, jobName scheduler.JobName, executionTime time.Time, dagRunIDPrefix string) error
 	GetJobRuns(ctx context.Context, t tenant.Tenant, criteria *scheduler.JobRunsCriteria, jobCron *cron.ScheduleSpec) ([]*scheduler.JobRunStatus, error)
 }
@@ -98,6 +99,15 @@ func (w *ReplayWorker) Execute(replayID uuid.UUID, jobTenant tenant.Tenant, jobN
 	}
 }
 
+func (w *ReplayWorker) SyncStatus(ctx context.Context, replayWithRun *scheduler.ReplayWithRun, jobCron *cron.ScheduleSpec) (scheduler.JobRunStatusList, error) {
+	incomingRuns, err := w.fetchRuns(ctx, replayWithRun, jobCron)
+	if err != nil {
+		w.logger.Error("[ReplayID: %s] unable to get incoming runs: %s", replayWithRun.Replay.ID().String(), err)
+		return scheduler.JobRunStatusList{}, err
+	}
+	return syncStatus(replayWithRun.Runs, incomingRuns), nil
+}
+
 func (w *ReplayWorker) startExecutionLoop(ctx context.Context, replayID uuid.UUID, jobCron *cron.ScheduleSpec) error {
 	executionLoopCount := 0
 	for {
@@ -136,14 +146,13 @@ func (w *ReplayWorker) startExecutionLoop(ctx context.Context, replayID uuid.UUI
 			return nil
 		}
 
-		incomingRuns, err := w.fetchRuns(ctx, replayWithRun, jobCron)
+		syncedRunStatus, err := w.SyncStatus(ctx, replayWithRun, jobCron)
 		if err != nil { // todo: lets not kill watchers on such errors
 			w.logger.Error("[ReplayID: %s] unable to get incoming runs: %s", replayWithRun.Replay.ID().String(), err)
 			return err
 		}
-		existingRuns := replayWithRun.Runs
-		syncedRunStatus := w.syncStatus(existingRuns, incomingRuns)
+
 		if err := w.replayRepo.UpdateReplay(ctx, replayWithRun.Replay.ID(), scheduler.ReplayStateInProgress, syncedRunStatus, ""); err != nil {
 			w.logger.Error("[ReplayID: %s] unable to update replay state to failed: %s", replayWithRun.Replay.ID(), err)
 			return err
 		}
@@ -224,6 +233,24 @@ func (w *ReplayWorker) fetchRuns(ctx context.Context, replayReq *scheduler.Repla
 	return w.scheduler.GetJobRuns(ctx, replayReq.Replay.Tenant(), jobRunCriteria, jobCron)
 }
 
+func (w *ReplayWorker) CancelReplayRunsOnScheduler(ctx context.Context, replay *scheduler.Replay, jobCron *cron.ScheduleSpec, runs []*scheduler.JobRunStatus) []*scheduler.JobRunStatus {
+	var canceledRuns []*scheduler.JobRunStatus
+	for _, run := range runs {
+		logicalTime := run.GetLogicalTime(jobCron)
+
+		w.logger.Info("[ReplayID: %s] Canceling run with logical time: %s", replay.ID(), logicalTime)
+		if err := w.scheduler.CancelRun(ctx, replay.Tenant(), replay.JobName(), logicalTime, prefixReplayed); err != nil {
+			w.logger.Error("[ReplayID: %s] unable to cancel job run for job: %s, Schedule Time: %s, err: %s", replay.ID(), replay.JobName(), run.ScheduledAt, err.Error())
+			continue
+		}
+		canceledRuns = append(canceledRuns, &scheduler.JobRunStatus{
+			ScheduledAt: run.ScheduledAt,
+			State:       scheduler.StateCanceled,
+		})
+	}
+	return canceledRuns
+}
+
 func (w *ReplayWorker) replayRunOnScheduler(ctx context.Context, jobCron *cron.ScheduleSpec, replayReq *scheduler.Replay, runs ...*scheduler.JobRunStatus) error {
 	// clear runs
 	pendingRuns := scheduler.JobRunStatusList(runs).GetSortedRunsByStates([]scheduler.State{scheduler.StatePending})
@@ -255,7 +282,7 @@ func (w *ReplayWorker) replayRunOnScheduler(ctx context.Context, jobCron *cron.S
 // syncStatus syncs existing and incoming runs
 // replay status: created -> in_progress -> [success, failed]
 // replay runs: [missing, pending] -> in_progress -> [success, failed]
-func (*ReplayWorker) syncStatus(existingJobRuns, incomingJobRuns []*scheduler.JobRunStatus) scheduler.JobRunStatusList {
+func syncStatus(existingJobRuns, incomingJobRuns []*scheduler.JobRunStatus) scheduler.JobRunStatusList {
 	incomingRunStatusMap := scheduler.JobRunStatusList(incomingJobRuns).ToRunStatusMap()
 	existingRunStatusMap := scheduler.JobRunStatusList(existingJobRuns).ToRunStatusMap()
 
diff --git a/core/scheduler/service/replay_worker_test.go b/core/scheduler/service/replay_worker_test.go
index 1359f6d40e..de139862bb 100644
--- a/core/scheduler/service/replay_worker_test.go
+++ b/core/scheduler/service/replay_worker_test.go
@@ -796,6 +796,11 @@ func (_m *mockReplayScheduler) CreateRun(ctx context.Context, tnnt tenant.Tenant
 	return r0
 }
 
+func (_m *mockReplayScheduler) CancelRun(ctx context.Context, tnnt tenant.Tenant, jobName scheduler.JobName, executionTime time.Time, dagRunIDPrefix string) error {
+	args := _m.Called(ctx, tnnt, jobName, executionTime, dagRunIDPrefix)
+	return args.Error(0)
+}
+
 // GetJobRuns provides a mock function with given fields: ctx, t, criteria, jobCron
 func (_m *mockReplayScheduler) GetJobRuns(ctx context.Context, t tenant.Tenant, criteria *scheduler.JobRunsCriteria, jobCron *cron.ScheduleSpec) ([]*scheduler.JobRunStatus, error) {
 	ret := _m.Called(ctx, t, criteria, jobCron)
diff --git a/core/scheduler/status.go b/core/scheduler/status.go
index 4ffa8a5508..43e0b017ce 100644
--- a/core/scheduler/status.go
+++ b/core/scheduler/status.go
@@ -16,6 +16,7 @@ const (
 	StateAccepted State = "accepted"
 	StateRunning State = "running"
 	StateQueued State = "queued"
+	StateCanceled State = "canceled"
 
 	StateRetry State = "retried"
 
@@ -29,8 +30,6 @@ const (
 	StateMissing State = "missing"
 )
 
-var TaskEndStates = []State{StateSuccess, StateFailed, StateRetry}
-
 type State string
 
 func StateFromString(state string) (State, error) {
diff --git a/ext/scheduler/airflow/airflow.go b/ext/scheduler/airflow/airflow.go
index f7e7f3113d..6e86a8183c 100644
--- a/ext/scheduler/airflow/airflow.go
+++ b/ext/scheduler/airflow/airflow.go
@@ -36,6 +36,7 @@ const (
 	dagURL = "api/v1/dags/%s"
 	dagRunClearURL = "api/v1/dags/%s/clearTaskInstances"
 	dagRunCreateURL = "api/v1/dags/%s/dagRuns"
+	dagRunModifyURL = "api/v1/dags/%s/dagRuns/%s"
 
 	airflowDateFormat = "2006-01-02T15:04:05+00:00"
 	schedulerHostKey = "SCHEDULER_HOST"
@@ -418,6 +419,27 @@ func (s *Scheduler) ClearBatch(ctx context.Context, tnnt tenant.Tenant, jobName
 	return nil
 }
 
+func (s *Scheduler) CancelRun(ctx context.Context, tnnt tenant.Tenant, jobName scheduler.JobName, executionTime time.Time, dagRunIDPrefix string) error {
+	spanCtx, span := startChildSpan(ctx, "CancelRun")
+	defer span.End()
+	dagRunID := fmt.Sprintf("%s__%s", dagRunIDPrefix, executionTime.UTC().Format(airflowDateFormat))
+	data := []byte(`{"state": "failed"}`)
+	req := airflowRequest{
+		path:   fmt.Sprintf(dagRunModifyURL, jobName.String(), dagRunID),
+		method: http.MethodPatch,
+		body:   data,
+	}
+	schdAuth, err := s.getSchedulerAuth(ctx, tnnt)
+	if err != nil {
+		return err
+	}
+	_, err = s.client.Invoke(spanCtx, req, schdAuth)
+	if err != nil {
+		return errors.Wrap(EntityAirflow, "failure while canceling airflow dag run", err)
+	}
+	return nil
+}
+
 func (s *Scheduler) CreateRun(ctx context.Context, tnnt tenant.Tenant, jobName scheduler.JobName, executionTime time.Time, dagRunIDPrefix string) error {
 	spanCtx, span := startChildSpan(ctx, "CreateRun")
 	defer span.End()
diff --git a/internal/store/postgres/scheduler/replay_repository.go b/internal/store/postgres/scheduler/replay_repository.go
index 41f25b6447..aec5102390 100644
--- a/internal/store/postgres/scheduler/replay_repository.go
+++ b/internal/store/postgres/scheduler/replay_repository.go
@@ -210,7 +210,7 @@ func (r ReplayRepository) UpdateReplayStatus(ctx context.Context, id uuid.UUID,
 }
 
 func (r ReplayRepository) UpdateReplay(ctx context.Context, id uuid.UUID, replayStatus scheduler.ReplayState, runs []*scheduler.JobRunStatus, message string) error {
-	if err := r.updateReplayRuns(ctx, id, runs); err != nil {
+	if err := r.UpdateReplayRuns(ctx, id, runs); err != nil {
 		return err
 	}
 	return r.updateReplayRequest(ctx, id, replayStatus, message)
@@ -329,7 +329,7 @@ func (r ReplayRepository) updateReplayRequest(ctx context.Context, id uuid.UUID,
 	return nil
 }
 
-func (r ReplayRepository) updateReplayRuns(ctx context.Context, id uuid.UUID, runs []*scheduler.JobRunStatus) error {
+func (r ReplayRepository) UpdateReplayRuns(ctx context.Context, id uuid.UUID, runs []*scheduler.JobRunStatus) error {
 	tx, err := r.db.BeginTx(ctx, pgx.TxOptions{})
 	if err != nil {
 		return err