
Commit

Replay cancel runs (#273) (#276)
* fix: emit replay lifecycle events

* fix: refactor optimus replay template
Mryashbhardwaj authored Oct 9, 2024
1 parent a7bb2c5 commit b88fa7f
Showing 7 changed files with 125 additions and 14 deletions.
39 changes: 38 additions & 1 deletion core/scheduler/service/replay_service.go
@@ -34,6 +34,7 @@ type SchedulerRunGetter interface {
type ReplayRepository interface {
RegisterReplay(ctx context.Context, replay *scheduler.Replay, runs []*scheduler.JobRunStatus) (uuid.UUID, error)
UpdateReplay(ctx context.Context, replayID uuid.UUID, state scheduler.ReplayState, runs []*scheduler.JobRunStatus, message string) error
UpdateReplayRuns(ctx context.Context, replayID uuid.UUID, runs []*scheduler.JobRunStatus) error
UpdateReplayStatus(ctx context.Context, replayID uuid.UUID, state scheduler.ReplayState, message string) error

GetReplayByFilters(ctx context.Context, projectName tenant.ProjectName, filters ...filter.FilterOpt) ([]*scheduler.ReplayWithRun, error)
@@ -52,6 +53,8 @@ type ReplayValidator interface {

type ReplayExecutor interface {
Execute(replayID uuid.UUID, jobTenant tenant.Tenant, jobName scheduler.JobName)
SyncStatus(ctx context.Context, replayWithRun *scheduler.ReplayWithRun, jobCron *cron.ScheduleSpec) (scheduler.JobRunStatusList, error)
CancelReplayRunsOnScheduler(ctx context.Context, replay *scheduler.Replay, jobCron *cron.ScheduleSpec, runs []*scheduler.JobRunStatus) []*scheduler.JobRunStatus
}

type ReplayService struct {
@@ -216,13 +219,47 @@ func (r *ReplayService) GetRunsStatus(ctx context.Context, tenant tenant.Tenant,
return runs, nil
}

func (r *ReplayService) cancelReplayRuns(ctx context.Context, replayWithRun *scheduler.ReplayWithRun) error {
// get the list of in-progress runs and stop them on the scheduler
replay := replayWithRun.Replay
jobName := replay.JobName()
jobCron, err := getJobCron(ctx, r.logger, r.jobRepo, replay.Tenant(), jobName)
if err != nil {
r.logger.Error("unable to get cron value for job [%s]: %s", jobName.String(), err.Error())
return err
}

syncedRunStatus, err := r.executor.SyncStatus(ctx, replayWithRun, jobCron)
if err != nil {
r.logger.Error("unable to sync replay runs status for job [%s]: %s", jobName.String(), err.Error())
return err
}

statesForCanceling := []scheduler.State{scheduler.StateRunning, scheduler.StateInProgress, scheduler.StateQueued}
toBeCanceledRuns := syncedRunStatus.GetSortedRunsByStates(statesForCanceling)
if len(toBeCanceledRuns) == 0 {
return nil
}

canceledRuns := r.executor.CancelReplayRunsOnScheduler(ctx, replay, jobCron, toBeCanceledRuns)

// record these runs as canceled in the DB (their DAG runs are marked failed on the scheduler)
return r.replayRepo.UpdateReplayRuns(ctx, replay.ID(), canceledRuns)
}

func (r *ReplayService) CancelReplay(ctx context.Context, replayWithRun *scheduler.ReplayWithRun) error {
if replayWithRun.Replay.IsTerminated() {
return errors.InvalidArgument(scheduler.EntityReplay, fmt.Sprintf("replay has already been terminated with status %s", replayWithRun.Replay.State().String()))
}
statusSummary := scheduler.JobRunStatusList(replayWithRun.Runs).GetJobRunStatusSummary()
cancelMessage := fmt.Sprintf("replay cancelled with run status %s", statusSummary)

err := r.replayRepo.UpdateReplayStatus(ctx, replayWithRun.Replay.ID(), scheduler.ReplayStateCancelled, cancelMessage)
if err != nil {
return err
}
return r.cancelReplayRuns(ctx, replayWithRun)
}

func NewReplayService(
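The new cancellation path in replay_service.go above first marks the replay itself as cancelled, then asks the scheduler to stop whatever is still active: only runs in the running, in_progress, or queued states are forwarded to the executor, everything else is left alone. Below is a minimal, self-contained sketch of that filtering step — it deliberately uses simplified local types instead of the scheduler package, purely to illustrate which states count as cancelable.

package main

import "fmt"

// simplified stand-ins for scheduler.JobRunStatus and the statesForCanceling filter above
type runStatus struct {
	scheduledAt string
	state       string
}

func main() {
	runs := []runStatus{
		{"2024-10-01T00:00:00Z", "success"},
		{"2024-10-02T00:00:00Z", "running"},
		{"2024-10-03T00:00:00Z", "queued"},
		{"2024-10-04T00:00:00Z", "pending"},
	}
	cancelable := map[string]bool{"running": true, "in_progress": true, "queued": true}

	var toCancel []runStatus
	for _, r := range runs {
		if cancelable[r.state] {
			toCancel = append(toCancel, r)
		}
	}
	// only the running and queued runs would be handed to CancelReplayRunsOnScheduler
	fmt.Println(toCancel)
}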
31 changes: 26 additions & 5 deletions core/scheduler/service/replay_service_test.go
@@ -410,9 +410,6 @@ func TestReplayService(t *testing.T) {
assert.ErrorContains(t, err, errorMsg)
})
t.Run("returns no error if replay has been successfully cancelled", func(t *testing.T) {
replay := scheduler.NewReplay(replayID, jobName, tnnt, replayConfig, scheduler.ReplayStateInProgress, startTime, message)
replayWithRun := &scheduler.ReplayWithRun{
Replay: replay,
@@ -423,10 +420,19 @@ func TestReplayService(t *testing.T) {
},
},
}

replayRepository := new(ReplayRepository)
defer replayRepository.AssertExpectations(t)
replayRepository.On("UpdateReplayStatus", mock.Anything, replay.ID(), scheduler.ReplayStateCancelled, mock.Anything).Return(nil).Once()

jobRepository := new(JobRepository)
defer jobRepository.AssertExpectations(t)
jobRepository.On("GetJobDetails", mock.Anything, projName, jobName).Return(jobWithDetails, nil)

replayWorker := new(ReplayExecutor)
defer replayWorker.AssertExpectations(t)
replayWorker.On("SyncStatus", ctx, replayWithRun, jobCron).Return(scheduler.JobRunStatusList{}, nil)

replayService := service.NewReplayService(replayRepository, jobRepository, nil, nil, replayWorker, nil, logger, nil, nil)
err := replayService.CancelReplay(ctx, replayWithRun)
assert.NoError(t, err)
})
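A natural companion test would cover the path where the scheduler still reports active runs. The sketch below is not part of this commit; it reuses the fixtures and mocks of the surrounding tests (ctx, logger, jobCron, jobWithDetails, projName, replayID, jobName, tnnt, replayConfig, startTime, message) and only illustrates how the new SyncStatus, CancelReplayRunsOnScheduler, and UpdateReplayRuns expectations could be wired together.

t.Run("cancels in-progress runs on the scheduler when cancelling a replay", func(t *testing.T) {
	runs := []*scheduler.JobRunStatus{
		{ScheduledAt: startTime, State: scheduler.StateRunning},
	}
	replay := scheduler.NewReplay(replayID, jobName, tnnt, replayConfig, scheduler.ReplayStateInProgress, startTime, message)
	replayWithRun := &scheduler.ReplayWithRun{Replay: replay, Runs: runs}

	replayRepository := new(ReplayRepository)
	defer replayRepository.AssertExpectations(t)
	replayRepository.On("UpdateReplayStatus", mock.Anything, replay.ID(), scheduler.ReplayStateCancelled, mock.Anything).Return(nil).Once()

	jobRepository := new(JobRepository)
	defer jobRepository.AssertExpectations(t)
	jobRepository.On("GetJobDetails", mock.Anything, projName, jobName).Return(jobWithDetails, nil)

	canceledRuns := []*scheduler.JobRunStatus{
		{ScheduledAt: startTime, State: scheduler.StateCanceled},
	}
	replayWorker := new(ReplayExecutor)
	defer replayWorker.AssertExpectations(t)
	replayWorker.On("SyncStatus", ctx, replayWithRun, jobCron).Return(scheduler.JobRunStatusList(runs), nil)
	replayWorker.On("CancelReplayRunsOnScheduler", ctx, replay, jobCron, runs).Return(canceledRuns)

	replayRepository.On("UpdateReplayRuns", mock.Anything, replay.ID(), canceledRuns).Return(nil)

	replayService := service.NewReplayService(replayRepository, jobRepository, nil, nil, replayWorker, nil, logger, nil, nil)
	err := replayService.CancelReplay(ctx, replayWithRun)
	assert.NoError(t, err)
})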
@@ -564,6 +570,11 @@ func (_m *ReplayRepository) GetReplayByID(ctx context.Context, replayID uuid.UUI
return r0, r1
}

// UpdateReplayRuns provides a mock function with given fields: ctx, replayID, runs
func (_m *ReplayRepository) UpdateReplayRuns(ctx context.Context, replayID uuid.UUID, runs []*scheduler.JobRunStatus) error {
args := _m.Called(ctx, replayID, runs)
return args.Error(0)
}

// GetReplayRequestsByStatus provides a mock function with given fields: ctx, statusList
func (_m *ReplayRepository) GetReplayRequestsByStatus(ctx context.Context, statusList []scheduler.ReplayState) ([]*scheduler.Replay, error) {
ret := _m.Called(ctx, statusList)
@@ -715,6 +726,16 @@ func (_m *ReplayExecutor) Execute(replayID uuid.UUID, jobTenant tenant.Tenant, j
_m.Called(replayID, jobTenant, jobName)
}

// SyncStatus provides a mock function with given fields: ctx, replayWithRun, jobCron
func (_m *ReplayExecutor) SyncStatus(ctx context.Context, replayWithRun *scheduler.ReplayWithRun, jobCron *cron.ScheduleSpec) (scheduler.JobRunStatusList, error) {
args := _m.Called(ctx, replayWithRun, jobCron)
return args.Get(0).(scheduler.JobRunStatusList), args.Error(1)
}

// CancelReplayRunsOnScheduler provides a mock function with given fields: ctx, replay, jobCron, runs
func (_m *ReplayExecutor) CancelReplayRunsOnScheduler(ctx context.Context, replay *scheduler.Replay, jobCron *cron.ScheduleSpec, runs []*scheduler.JobRunStatus) []*scheduler.JobRunStatus {
args := _m.Called(ctx, replay, jobCron, runs)
return args.Get(0).([]*scheduler.JobRunStatus)
}

// TenantGetter is an autogenerated mock type for the TenantGetter type
type TenantGetter struct {
mock.Mock
35 changes: 31 additions & 4 deletions core/scheduler/service/replay_worker.go
@@ -33,6 +33,7 @@ type ReplayScheduler interface {
Clear(ctx context.Context, t tenant.Tenant, jobName scheduler.JobName, scheduledAt time.Time) error
ClearBatch(ctx context.Context, t tenant.Tenant, jobName scheduler.JobName, startTime, endTime time.Time) error

CancelRun(ctx context.Context, tnnt tenant.Tenant, jobName scheduler.JobName, executionTime time.Time, dagRunIDPrefix string) error
CreateRun(ctx context.Context, tnnt tenant.Tenant, jobName scheduler.JobName, executionTime time.Time, dagRunIDPrefix string) error
GetJobRuns(ctx context.Context, t tenant.Tenant, criteria *scheduler.JobRunsCriteria, jobCron *cron.ScheduleSpec) ([]*scheduler.JobRunStatus, error)
}
@@ -98,6 +99,15 @@ func (w *ReplayWorker) Execute(replayID uuid.UUID, jobTenant tenant.Tenant, jobN
}
}

func (w *ReplayWorker) SyncStatus(ctx context.Context, replayWithRun *scheduler.ReplayWithRun, jobCron *cron.ScheduleSpec) (scheduler.JobRunStatusList, error) {
incomingRuns, err := w.fetchRuns(ctx, replayWithRun, jobCron)
if err != nil {
w.logger.Error("[ReplayID: %s] unable to get incoming runs: %s", replayWithRun.Replay.ID().String(), err)
return scheduler.JobRunStatusList{}, err
}
return syncStatus(replayWithRun.Runs, incomingRuns), nil
}

func (w *ReplayWorker) startExecutionLoop(ctx context.Context, replayID uuid.UUID, jobCron *cron.ScheduleSpec) error {
executionLoopCount := 0
for {
@@ -136,14 +146,13 @@ func (w *ReplayWorker) startExecutionLoop(ctx context.Context, replayID uuid.UUI
return nil
}

syncedRunStatus, err := w.SyncStatus(ctx, replayWithRun, jobCron)
if err != nil {
// todo: let's not kill watchers on such errors
w.logger.Error("[ReplayID: %s] unable to get incoming runs: %s", replayWithRun.Replay.ID().String(), err)
return err
}

if err := w.replayRepo.UpdateReplay(ctx, replayWithRun.Replay.ID(), scheduler.ReplayStateInProgress, syncedRunStatus, ""); err != nil {
w.logger.Error("[ReplayID: %s] unable to update replay state to failed: %s", replayWithRun.Replay.ID(), err)
return err
@@ -224,6 +233,24 @@ func (w *ReplayWorker) fetchRuns(ctx context.Context, replayReq *scheduler.Repla
return w.scheduler.GetJobRuns(ctx, replayReq.Replay.Tenant(), jobRunCriteria, jobCron)
}

func (w *ReplayWorker) CancelReplayRunsOnScheduler(ctx context.Context, replay *scheduler.Replay, jobCron *cron.ScheduleSpec, runs []*scheduler.JobRunStatus) []*scheduler.JobRunStatus {
var canceledRuns []*scheduler.JobRunStatus
for _, run := range runs {
logicalTime := run.GetLogicalTime(jobCron)

w.logger.Info("[ReplayID: %s] Canceling run with logical time: %s", replay.ID(), logicalTime)
if err := w.scheduler.CancelRun(ctx, replay.Tenant(), replay.JobName(), logicalTime, prefixReplayed); err != nil {
w.logger.Error("[ReplayID: %s] unable to cancel job run for job: %s, Schedule Time: %s, err: %s", replay.ID(), replay.JobName(), run.ScheduledAt, err.Error())
continue
}
canceledRuns = append(canceledRuns, &scheduler.JobRunStatus{
ScheduledAt: run.ScheduledAt,
State: scheduler.StateCanceled,
})
}
return canceledRuns
}
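CancelRun is addressed by the run's logical time rather than its scheduled time. Assuming GetLogicalTime follows the usual Airflow convention — the logical (execution) date is the start of the schedule interval that produced the run, i.e. one cron interval before the scheduled time — a daily run scheduled for 2024-10-09 00:00 UTC would be cancelled on the scheduler under 2024-10-08 00:00 UTC. A small illustration of that assumption:

package main

import (
	"fmt"
	"time"
)

func main() {
	// illustration only: a daily schedule, so the interval is 24h
	interval := 24 * time.Hour
	scheduledAt := time.Date(2024, 10, 9, 0, 0, 0, 0, time.UTC)

	// assumed GetLogicalTime behaviour: one schedule interval before the scheduled time
	logicalTime := scheduledAt.Add(-interval)

	fmt.Println(logicalTime.Format(time.RFC3339)) // 2024-10-08T00:00:00Z — the value passed to scheduler.CancelRun
}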

func (w *ReplayWorker) replayRunOnScheduler(ctx context.Context, jobCron *cron.ScheduleSpec, replayReq *scheduler.Replay, runs ...*scheduler.JobRunStatus) error {
// clear runs
pendingRuns := scheduler.JobRunStatusList(runs).GetSortedRunsByStates([]scheduler.State{scheduler.StatePending})
@@ -255,7 +282,7 @@ func (w *ReplayWorker) replayRunOnScheduler(ctx context.Context, jobCron *cron.S
// syncStatus syncs existing and incoming runs
// replay status: created -> in_progress -> [success, failed]
// replay runs: [missing, pending] -> in_progress -> [success, failed]
func syncStatus(existingJobRuns, incomingJobRuns []*scheduler.JobRunStatus) scheduler.JobRunStatusList {
incomingRunStatusMap := scheduler.JobRunStatusList(incomingJobRuns).ToRunStatusMap()
existingRunStatusMap := scheduler.JobRunStatusList(existingJobRuns).ToRunStatusMap()

5 changes: 5 additions & 0 deletions core/scheduler/service/replay_worker_test.go
@@ -796,6 +796,11 @@ func (_m *mockReplayScheduler) CreateRun(ctx context.Context, tnnt tenant.Tenant
return r0
}

// CancelRun provides a mock function with given fields: ctx, tnnt, jobName, executionTime, dagRunIDPrefix
func (_m *mockReplayScheduler) CancelRun(ctx context.Context, tnnt tenant.Tenant, jobName scheduler.JobName, executionTime time.Time, dagRunIDPrefix string) error {
args := _m.Called(ctx, tnnt, jobName, executionTime, dagRunIDPrefix)
return args.Error(0)
}

// GetJobRuns provides a mock function with given fields: ctx, t, criteria, jobCron
func (_m *mockReplayScheduler) GetJobRuns(ctx context.Context, t tenant.Tenant, criteria *scheduler.JobRunsCriteria, jobCron *cron.ScheduleSpec) ([]*scheduler.JobRunStatus, error) {
ret := _m.Called(ctx, t, criteria, jobCron)
3 changes: 1 addition & 2 deletions core/scheduler/status.go
@@ -16,6 +16,7 @@ const (
StateAccepted State = "accepted"
StateRunning State = "running"
StateQueued State = "queued"
StateCanceled State = "canceled"

StateRetry State = "retried"

@@ -29,8 +30,6 @@ const (
StateMissing State = "missing"
)

var TaskEndStates = []State{StateSuccess, StateFailed, StateRetry}

type State string

func StateFromString(state string) (State, error) {
22 changes: 22 additions & 0 deletions ext/scheduler/airflow/airflow.go
@@ -36,6 +36,7 @@ const (
dagURL = "api/v1/dags/%s"
dagRunClearURL = "api/v1/dags/%s/clearTaskInstances"
dagRunCreateURL = "api/v1/dags/%s/dagRuns"
dagRunModifyURL = "api/v1/dags/%s/dagRuns/%s"
airflowDateFormat = "2006-01-02T15:04:05+00:00"

schedulerHostKey = "SCHEDULER_HOST"
@@ -418,6 +419,27 @@ func (s *Scheduler) ClearBatch(ctx context.Context, tnnt tenant.Tenant, jobName
return nil
}

func (s *Scheduler) CancelRun(ctx context.Context, tnnt tenant.Tenant, jobName scheduler.JobName, executionTime time.Time, dagRunIDPrefix string) error {
spanCtx, span := startChildSpan(ctx, "CancelRun")
defer span.End()
dagRunID := fmt.Sprintf("%s__%s", dagRunIDPrefix, executionTime.UTC().Format(airflowDateFormat))
data := []byte(`{"state": "failed"}`)
req := airflowRequest{
path: fmt.Sprintf(dagRunModifyURL, jobName.String(), dagRunID),
method: http.MethodPatch,
body: data,
}
schdAuth, err := s.getSchedulerAuth(ctx, tnnt)
if err != nil {
return err
}
_, err = s.client.Invoke(spanCtx, req, schdAuth)
if err != nil {
return errors.Wrap(EntityAirflow, "failure while canceling airflow dag run", err)
}
return nil
}
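The cancellation is a PATCH against Airflow's dag-run update endpoint (api/v1/dags/{dag_id}/dagRuns/{dag_run_id}) that sets the run state to failed. The run is addressed with the same dag_run_id convention CreateRun uses: prefix, double underscore, then the execution time rendered with airflowDateFormat. The snippet below only shows how that identifier and path are assembled; it assumes prefixReplayed resolves to the string "replayed", which is not visible in this diff, and uses a made-up DAG name.

package main

import (
	"fmt"
	"time"
)

const airflowDateFormat = "2006-01-02T15:04:05+00:00"

func main() {
	prefix := "replayed" // assumption: the value behind prefixReplayed
	executionTime := time.Date(2024, 10, 8, 0, 0, 0, 0, time.UTC)

	dagRunID := fmt.Sprintf("%s__%s", prefix, executionTime.UTC().Format(airflowDateFormat))
	path := fmt.Sprintf("api/v1/dags/%s/dagRuns/%s", "sample_dag", dagRunID)

	fmt.Println("PATCH", path)               // api/v1/dags/sample_dag/dagRuns/replayed__2024-10-08T00:00:00+00:00
	fmt.Println(`body: {"state": "failed"}`) // Airflow marks the run failed and stops scheduling its remaining tasks
}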

func (s *Scheduler) CreateRun(ctx context.Context, tnnt tenant.Tenant, jobName scheduler.JobName, executionTime time.Time, dagRunIDPrefix string) error {
spanCtx, span := startChildSpan(ctx, "CreateRun")
defer span.End()
4 changes: 2 additions & 2 deletions internal/store/postgres/scheduler/replay_repository.go
@@ -210,7 +210,7 @@ func (r ReplayRepository) UpdateReplayStatus(ctx context.Context, id uuid.UUID,
}

func (r ReplayRepository) UpdateReplay(ctx context.Context, id uuid.UUID, replayStatus scheduler.ReplayState, runs []*scheduler.JobRunStatus, message string) error {
if err := r.UpdateReplayRuns(ctx, id, runs); err != nil {
return err
}
return r.updateReplayRequest(ctx, id, replayStatus, message)
@@ -329,7 +329,7 @@ func (r ReplayRepository) updateReplayRequest(ctx context.Context, id uuid.UUID,
return nil
}

func (r ReplayRepository) UpdateReplayRuns(ctx context.Context, id uuid.UUID, runs []*scheduler.JobRunStatus) error {
tx, err := r.db.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
return err
