diff --git a/flytepropeller/pkg/controller/nodes/dynamic/handler.go b/flytepropeller/pkg/controller/nodes/dynamic/handler.go index fc2a4cb57..3324c539b 100644 --- a/flytepropeller/pkg/controller/nodes/dynamic/handler.go +++ b/flytepropeller/pkg/controller/nodes/dynamic/handler.go @@ -165,7 +165,10 @@ func (d dynamicNodeTaskNodeHandler) handleDynamicSubNodes(ctx context.Context, n return trns.WithInfo(handler.PhaseInfoFailureErr(ee.ExecutionError, trns.Info().GetInfo())), handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: ee.ExecutionError.String()}, nil } taskNodeInfoMetadata := &event.TaskNodeMetadata{CacheStatus: status.GetCacheStatus(), CatalogKey: status.GetMetadata()} - trns.WithInfo(trns.Info().WithInfo(&handler.ExecutionInfo{TaskNodeInfo: &handler.TaskNodeInfo{TaskNodeMetadata: taskNodeInfoMetadata}})) + trns = trns.WithInfo(trns.Info().WithInfo(&handler.ExecutionInfo{ + OutputInfo: trns.Info().GetInfo().OutputInfo, + TaskNodeInfo: &handler.TaskNodeInfo{TaskNodeMetadata: taskNodeInfoMetadata}, + })) } return trns, newState, nil diff --git a/flytepropeller/pkg/controller/nodes/dynamic/handler_test.go b/flytepropeller/pkg/controller/nodes/dynamic/handler_test.go index 64c2eb336..27bc2d935 100644 --- a/flytepropeller/pkg/controller/nodes/dynamic/handler_test.go +++ b/flytepropeller/pkg/controller/nodes/dynamic/handler_test.go @@ -25,6 +25,7 @@ import ( lpMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/mocks" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" flyteMocks "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/executors" @@ -512,17 +513,51 @@ func Test_dynamicNodeHandler_Handle_SubTaskV1(t *testing.T) { return nCtx } + validCachePopulatedStatus := catalog.NewStatus(core.CatalogCacheStatus_CACHE_POPULATED, &core.CatalogMetadata{ + DatasetId: &core.Identifier{ + ResourceType: core.ResourceType_TASK, + Project: "project", + Domain: "domain", + Name: "name", + Version: "version", + }, + ArtifactTag: &core.CatalogArtifactTag{Name: "name", ArtifactId: "id"}, + }) + execInfoOutputOnly := &handler.ExecutionInfo{ + OutputInfo: &handler.OutputInfo{ + OutputURI: "output-dir/outputs.pb", + }, + } + execInfoWithTaskNodeMeta := &handler.ExecutionInfo{ + OutputInfo: &handler.OutputInfo{ + OutputURI: "output-dir/outputs.pb", + }, + TaskNodeInfo: &handler.TaskNodeInfo{ + TaskNodeMetadata: &event.TaskNodeMetadata{ + CacheStatus: validCachePopulatedStatus.GetCacheStatus(), + CatalogKey: &core.CatalogMetadata{ + DatasetId: validCachePopulatedStatus.GetMetadata().DatasetId, + ArtifactTag: validCachePopulatedStatus.GetMetadata().ArtifactTag, + SourceExecution: validCachePopulatedStatus.GetMetadata().SourceExecution, + }, + ReservationStatus: core.CatalogReservation_RESERVATION_DISABLED, + }, + }, + } + type args struct { - s executors.NodeStatus - isErr bool - dj *core.DynamicJobSpec - validErr *io.ExecutionError - generateOutputs bool + s executors.NodeStatus + isErr bool + dj *core.DynamicJobSpec + validErr *io.ExecutionError + validCacheStatus *catalog.Status + generateOutputs bool } type want struct { p handler.EPhase isErr bool phase v1alpha1.DynamicNodePhase + info *handler.ExecutionInfo } tests := []struct { name string @@ -531,10 +566,10 @@ func Test_dynamicNodeHandler_Handle_SubTaskV1(t *testing.T) { }{ {"error", args{isErr: true, dj: createDynamicJobSpec()}, want{isErr: true}}, {"success", args{s: executors.NodeStatusSuccess, dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, - {"complete", args{s: executors.NodeStatusComplete, dj: createDynamicJobSpec(), generateOutputs: true}, want{p: handler.EPhaseSuccess, phase: v1alpha1.DynamicNodePhaseExecuting}}, + {"complete", args{s: executors.NodeStatusComplete, dj: createDynamicJobSpec(), generateOutputs: true, validCacheStatus: &validCachePopulatedStatus}, want{p: handler.EPhaseSuccess, phase: v1alpha1.DynamicNodePhaseExecuting, info: execInfoWithTaskNodeMeta}}, {"complete-no-outputs", args{s: executors.NodeStatusComplete, dj: createDynamicJobSpec(), generateOutputs: false}, want{p: handler.EPhaseRetryableFailure, phase: v1alpha1.DynamicNodePhaseFailing}}, - {"complete-valid-error-retryable", args{s: executors.NodeStatusComplete, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{IsRecoverable: true}, generateOutputs: true}, want{p: handler.EPhaseRetryableFailure, phase: v1alpha1.DynamicNodePhaseFailing}}, - {"complete-valid-error", args{s: executors.NodeStatusComplete, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{}, generateOutputs: true}, want{p: handler.EPhaseFailed, phase: v1alpha1.DynamicNodePhaseFailing}}, + {"complete-valid-error-retryable", args{s: executors.NodeStatusComplete, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{IsRecoverable: true}, generateOutputs: true}, want{p: handler.EPhaseRetryableFailure, phase: v1alpha1.DynamicNodePhaseFailing, info: execInfoOutputOnly}}, + {"complete-valid-error", args{s: executors.NodeStatusComplete, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{}, generateOutputs: true}, want{p: handler.EPhaseFailed, phase: v1alpha1.DynamicNodePhaseFailing, info: execInfoOutputOnly}}, {"failed", args{s: executors.NodeStatusFailed(&core.ExecutionError{}), dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseFailing}}, {"running", args{s: executors.NodeStatusRunning, dj: createDynamicJobSpec()}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, {"running-valid-err", args{s: executors.NodeStatusRunning, dj: createDynamicJobSpec(), validErr: &io.ExecutionError{}}, want{p: handler.EPhaseDynamicRunning, phase: v1alpha1.DynamicNodePhaseExecuting}}, @@ -556,7 +591,15 @@ func Test_dynamicNodeHandler_Handle_SubTaskV1(t *testing.T) { if tt.args.validErr != nil { h.OnValidateOutputAndCacheAddMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(catalog.NewStatus(core.CatalogCacheStatus_CACHE_DISABLED, nil), tt.args.validErr, nil) } else { - h.OnValidateOutputAndCacheAddMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(catalog.NewStatus(core.CatalogCacheStatus_CACHE_HIT, &core.CatalogMetadata{ArtifactTag: &core.CatalogArtifactTag{Name: "name", ArtifactId: "id"}}), nil, nil) + var validCacheStatus catalog.Status + if tt.args.validCacheStatus == nil { + validCacheStatus = catalog.NewStatus(core.CatalogCacheStatus_CACHE_HIT, &core.CatalogMetadata{ + ArtifactTag: &core.CatalogArtifactTag{Name: "name", ArtifactId: "id"}, + }) + } else { + validCacheStatus = *tt.args.validCacheStatus + } + h.OnValidateOutputAndCacheAddMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(validCacheStatus, nil, nil) } n := &executorMocks.Node{} if tt.args.isErr { @@ -586,6 +629,7 @@ func Test_dynamicNodeHandler_Handle_SubTaskV1(t *testing.T) { if err == nil { assert.Equal(t, tt.want.p.String(), got.Info().GetPhase().String()) assert.Equal(t, tt.want.phase, s.s.Phase) + assert.Equal(t, tt.want.info, got.Info().GetInfo()) } }) }