-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
enable parameterization of container images
This change allows component base images to be parameterized using runtime pipeline parameters. The container images can be specified within an @pipeline decorated function, and takes precedence over the @component(base_image=..) argument. This change also adds logic to resolve these runtime parameters in the argo driver logic. It also includes resolution steps for resolving the accelerator type which functions the same way but was missing the resolution logic. The resolution logic is a generic workaround solution for any run time pod spec input parameters that cannot be resolved because they cannot be added dynamically in the argo pod spec container template. Signed-off-by: Humair Khan <[email protected]>
- Loading branch information
Showing
6 changed files
with
342 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
// Copyright 2021-2024 The Kubeflow Authors | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// https://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
package driver | ||
|
||
import ( | ||
"fmt" | ||
"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec" | ||
"regexp" | ||
) | ||
|
||
// InputPipelineChannelPattern define a regex pattern to match the content within single quotes | ||
// example input channel looks like "{{$.inputs.parameters['pipelinechannel--val']}}" | ||
const InputPipelineChannelPattern = `\$.inputs.parameters\['(.+?)'\]` | ||
|
||
func isInputParameterChannel(inputChannel string) bool { | ||
re := regexp.MustCompile(InputPipelineChannelPattern) | ||
match := re.FindStringSubmatch(inputChannel) | ||
if len(match) == 2 { | ||
return true | ||
} else { | ||
// if len(match) > 2, then this is still incorrect because | ||
// inputChannel should contain only one parameter channel input | ||
return false | ||
} | ||
} | ||
|
||
// extractInputParameterFromChannel takes an inputChannel that adheres to | ||
// InputPipelineChannelPattern and extracts the channel parameter name. | ||
// For example given an input channel of the form "{{$.inputs.parameters['pipelinechannel--val']}}" | ||
// the channel parameter name "pipelinechannel--val" is returned. | ||
func extractInputParameterFromChannel(inputChannel string) (string, error) { | ||
re := regexp.MustCompile(InputPipelineChannelPattern) | ||
match := re.FindStringSubmatch(inputChannel) | ||
if len(match) > 1 { | ||
extractedValue := match[1] | ||
return extractedValue, nil | ||
} else { | ||
return "", fmt.Errorf("failed to extract input parameter from channel: %s", inputChannel) | ||
} | ||
} | ||
|
||
// resolvePodSpecInputRuntimeParameter resolves runtime value that is intended to be | ||
// utilized within the Pod Spec. parameterValue takes the form of: | ||
// "{{$.inputs.parameters['pipelinechannel--someParameterName']}}" | ||
// | ||
// parameterValue is a runtime parameter value that has been resolved and included within | ||
// the executor input. Since the pod spec patch cannot dynamically update the underlying | ||
// container template's inputs in an Argo Workflow, this is a workaround for resolving | ||
// such parameters. | ||
// | ||
// If parameter value is not a parameter channel, then a constant value is assumed and | ||
// returned as is. | ||
func resolvePodSpecInputRuntimeParameter(parameterValue string, executorInput *pipelinespec.ExecutorInput) (string, error) { | ||
if isInputParameterChannel(parameterValue) { | ||
inputImage, err1 := extractInputParameterFromChannel(parameterValue) | ||
if err1 != nil { | ||
return "", err1 | ||
} | ||
if val, ok := executorInput.Inputs.ParameterValues[inputImage]; ok { | ||
return val.GetStringValue(), nil | ||
} else { | ||
return "", fmt.Errorf("executorInput did not contain container Image input parameter") | ||
} | ||
} | ||
return parameterValue, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
// Copyright 2021-2024 The Kubeflow Authors | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// https://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
package driver | ||
|
||
import ( | ||
"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec" | ||
"github.com/stretchr/testify/assert" | ||
structpb "google.golang.org/protobuf/types/known/structpb" | ||
"testing" | ||
) | ||
|
||
func Test_isInputParameterChannel(t *testing.T) { | ||
tests := []struct { | ||
name string | ||
input string | ||
isValid bool | ||
}{ | ||
{ | ||
name: "wellformed pipeline channel should produce no errors", | ||
input: "{{$.inputs.parameters['pipelinechannel--someParameterName']}}", | ||
isValid: true, | ||
}, | ||
{ | ||
name: "pipeline channel index should have quotes", | ||
input: "{{$.inputs.parameters[pipelinechannel--someParameterName]}}", | ||
isValid: false, | ||
}, | ||
{ | ||
name: "plain text as pipelinechannel of parameter type is invalid", | ||
input: "randomtext", | ||
isValid: false, | ||
}, | ||
{ | ||
name: "inputs should be prefixed with $.", | ||
input: "{{inputs.parameters['pipelinechannel--someParameterName']}}", | ||
isValid: false, | ||
}, | ||
} | ||
|
||
for _, test := range tests { | ||
t.Run(test.name, func(t *testing.T) { | ||
assert.Equal(t, isInputParameterChannel(test.input), test.isValid) | ||
}) | ||
} | ||
} | ||
|
||
func Test_extractInputParameterFromChannel(t *testing.T) { | ||
tests := []struct { | ||
name string | ||
input string | ||
expected string | ||
wantErr bool | ||
}{ | ||
{ | ||
name: "standard parameter pipeline channel input", | ||
input: "{{$.inputs.parameters['pipelinechannel--someParameterName']}}", | ||
expected: "pipelinechannel--someParameterName", | ||
wantErr: false, | ||
}, | ||
{ | ||
name: "a more complex parameter pipeline channel input", | ||
input: "{{$.inputs.parameters['pipelinechannel--somePara-me_terName']}}", | ||
expected: "pipelinechannel--somePara-me_terName", | ||
wantErr: false, | ||
}, | ||
{ | ||
name: "invalid input should return err", | ||
input: "invalidvalue", | ||
wantErr: true, | ||
}, | ||
{ | ||
name: "invalid input should return err 2", | ||
input: "pipelinechannel--somePara-me_terName", | ||
wantErr: true, | ||
}, | ||
} | ||
|
||
for _, test := range tests { | ||
t.Run(test.name, func(t *testing.T) { | ||
actual, err := extractInputParameterFromChannel(test.input) | ||
if test.wantErr { | ||
assert.NotNil(t, err) | ||
} else { | ||
assert.NoError(t, err) | ||
assert.Equal(t, actual, test.expected) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func Test_resolvePodSpecRuntimeParameter(t *testing.T) { | ||
tests := []struct { | ||
name string | ||
input string | ||
expected string | ||
executorInput *pipelinespec.ExecutorInput | ||
wantErr bool | ||
}{ | ||
{ | ||
name: "should retrieve correct parameter value", | ||
input: "{{$.inputs.parameters['pipelinechannel--someParameterName']}}", | ||
expected: "test2", | ||
executorInput: &pipelinespec.ExecutorInput{ | ||
Inputs: &pipelinespec.ExecutorInput_Inputs{ | ||
ParameterValues: map[string]*structpb.Value{ | ||
"pipelinechannel--": structpb.NewStringValue("test1"), | ||
"pipelinechannel--someParameterName": structpb.NewStringValue("test2"), | ||
"someParameterName": structpb.NewStringValue("test3"), | ||
}, | ||
}, | ||
}, | ||
wantErr: false, | ||
}, | ||
{ | ||
name: "return err when no match is found", | ||
input: "{{$.inputs.parameters['pipelinechannel--someParameterName']}}", | ||
expected: "test1", | ||
executorInput: &pipelinespec.ExecutorInput{ | ||
Inputs: &pipelinespec.ExecutorInput_Inputs{ | ||
ParameterValues: map[string]*structpb.Value{ | ||
"doesNotMatch": structpb.NewStringValue("test2"), | ||
}, | ||
}, | ||
}, | ||
wantErr: true, | ||
}, | ||
{ | ||
name: "return const val when input is not a pipeline channel", | ||
input: "not-pipeline-channel", | ||
expected: "not-pipeline-channel", | ||
executorInput: &pipelinespec.ExecutorInput{}, | ||
wantErr: false, | ||
}, | ||
} | ||
|
||
for _, test := range tests { | ||
t.Run(test.name, func(t *testing.T) { | ||
actual, err := resolvePodSpecInputRuntimeParameter(test.input, test.executorInput) | ||
if test.wantErr { | ||
assert.NotNil(t, err) | ||
} else { | ||
assert.NoError(t, err) | ||
assert.Equal(t, actual, test.expected) | ||
} | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters