Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sdk/backend): enable parameterization of container images #11404

Merged
merged 1 commit into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions backend/src/v2/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package driver

import (
Expand Down Expand Up @@ -448,19 +449,28 @@ func initPodSpecPatch(
accelerator := container.GetResources().GetAccelerator()
if accelerator != nil {
if accelerator.GetType() != "" && accelerator.GetCount() > 0 {
acceleratorType, err := resolvePodSpecInputRuntimeParameter(accelerator.GetType(), executorInput)
if err != nil {
return nil, fmt.Errorf("failed to init podSpecPatch: %w", err)
}
q, err := k8sres.ParseQuantity(fmt.Sprintf("%v", accelerator.GetCount()))
if err != nil {
return nil, fmt.Errorf("failed to init podSpecPatch: %w", err)
}
res.Limits[k8score.ResourceName(accelerator.GetType())] = q
res.Limits[k8score.ResourceName(acceleratorType)] = q
}
}

containerImage, err := resolvePodSpecInputRuntimeParameter(container.Image, executorInput)
if err != nil {
return nil, fmt.Errorf("failed to init podSpecPatch: %w", err)
}
podSpec := &k8score.PodSpec{
Containers: []k8score.Container{{
Name: "main", // argo task user container is always called "main"
Command: launcherCmd,
Args: userCmdArgs,
Image: container.Image,
Image: containerImage,
Resources: res,
Env: userEnvVar,
}},
Expand Down
78 changes: 78 additions & 0 deletions backend/src/v2/driver/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright 2021-2024 The Kubeflow Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package driver

import (
"fmt"
"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
"regexp"
)

// inputPipelineChannelPattern is a regex pattern that matches a parameter
// channel placeholder and captures the parameter name found between the
// single quotes. An example input channel looks like
// "{{$.inputs.parameters['pipelinechannel--val']}}".
//
// The dots are escaped so that they only match a literal "." rather than
// any character.
const inputPipelineChannelPattern = `\$\.inputs\.parameters\['(.+?)'\]`

// inputPipelineChannelRegexp is inputPipelineChannelPattern compiled once at
// package initialization, so the helpers below do not recompile it per call.
var inputPipelineChannelRegexp = regexp.MustCompile(inputPipelineChannelPattern)

// isInputParameterChannel reports whether inputChannel contains a parameter
// channel placeholder matching inputPipelineChannelPattern.
//
// Only the presence of a placeholder is checked: the match is not anchored,
// and a string holding several placeholders is still reported as a channel.
func isInputParameterChannel(inputChannel string) bool {
	return inputPipelineChannelRegexp.MatchString(inputChannel)
}

// extractInputParameterFromChannel takes an inputChannel that adheres to
// inputPipelineChannelPattern and extracts the channel parameter name.
// For example given an input channel of the form "{{$.inputs.parameters['pipelinechannel--val']}}"
// the channel parameter name "pipelinechannel--val" is returned.
//
// If inputChannel holds more than one placeholder, the name from the first
// (leftmost) one is returned. An error is returned when inputChannel does not
// contain a placeholder at all.
func extractInputParameterFromChannel(inputChannel string) (string, error) {
	// FindStringSubmatch yields nil on no match, otherwise the full match
	// followed by the single capture group.
	match := inputPipelineChannelRegexp.FindStringSubmatch(inputChannel)
	if len(match) < 2 {
		return "", fmt.Errorf("failed to extract input parameter from channel: %s", inputChannel)
	}
	return match[1], nil
}

// resolvePodSpecInputRuntimeParameter resolves runtime value that is intended to be
// utilized within the Pod Spec. parameterValue takes the form of:
// "{{$.inputs.parameters['pipelinechannel--someParameterName']}}"
//
// parameterValue is a runtime parameter value that has been resolved and included within
// the executor input. Since the pod spec patch cannot dynamically update the underlying
// container template's inputs in an Argo Workflow, this is a workaround for resolving
// such parameters.
//
// If parameter value is not a parameter channel, then a constant value is assumed and
// returned as is.
func resolvePodSpecInputRuntimeParameter(parameterValue string, executorInput *pipelinespec.ExecutorInput) (string, error) {
if isInputParameterChannel(parameterValue) {
inputImage, err := extractInputParameterFromChannel(parameterValue)
if err != nil {
return "", err
}
if val, ok := executorInput.Inputs.ParameterValues[inputImage]; ok {
return val.GetStringValue(), nil
} else {
return "", fmt.Errorf("executorInput did not contain container Image input parameter")
HumairAK marked this conversation as resolved.
Show resolved Hide resolved
}
}
return parameterValue, nil
}
159 changes: 159 additions & 0 deletions backend/src/v2/driver/util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
// Copyright 2021-2024 The Kubeflow Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package driver

import (
"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
"github.com/stretchr/testify/assert"
structpb "google.golang.org/protobuf/types/known/structpb"
"testing"
)

// Test_isInputParameterChannel checks that only well-formed parameter channel
// placeholders are recognized by isInputParameterChannel.
func Test_isInputParameterChannel(t *testing.T) {
	tests := []struct {
		name    string
		input   string
		isValid bool
	}{
		{
			name:    "wellformed pipeline channel should produce no errors",
			input:   "{{$.inputs.parameters['pipelinechannel--someParameterName']}}",
			isValid: true,
		},
		{
			name:    "pipeline channel index should have quotes",
			input:   "{{$.inputs.parameters[pipelinechannel--someParameterName]}}",
			isValid: false,
		},
		{
			name:    "plain text as pipelinechannel of parameter type is invalid",
			input:   "randomtext",
			isValid: false,
		},
		{
			name:    "inputs should be prefixed with $.",
			input:   "{{inputs.parameters['pipelinechannel--someParameterName']}}",
			isValid: false,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			// assert.Equal takes the expected value first, then the actual one.
			assert.Equal(t, test.isValid, isInputParameterChannel(test.input))
		})
	}
}

// Test_extractInputParameterFromChannel checks that the parameter name is
// extracted from well-formed channel placeholders and that inputs without a
// placeholder produce an error.
func Test_extractInputParameterFromChannel(t *testing.T) {
	tests := []struct {
		name     string
		input    string
		expected string
		wantErr  bool
	}{
		{
			name:     "standard parameter pipeline channel input",
			input:    "{{$.inputs.parameters['pipelinechannel--someParameterName']}}",
			expected: "pipelinechannel--someParameterName",
			wantErr:  false,
		},
		{
			name:     "a more complex parameter pipeline channel input",
			input:    "{{$.inputs.parameters['pipelinechannel--somePara-me_terName']}}",
			expected: "pipelinechannel--somePara-me_terName",
			wantErr:  false,
		},
		{
			name:    "invalid input should return err",
			input:   "invalidvalue",
			wantErr: true,
		},
		{
			name:    "invalid input should return err 2",
			input:   "pipelinechannel--somePara-me_terName",
			wantErr: true,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			actual, err := extractInputParameterFromChannel(test.input)
			if test.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			// assert.Equal takes the expected value first, then the actual one.
			assert.Equal(t, test.expected, actual)
		})
	}
}

// Test_resolvePodSpecRuntimeParameter checks that channel placeholders are
// resolved against the executor input's parameter values, that a missing
// parameter yields an error, and that constant values are returned unchanged.
func Test_resolvePodSpecRuntimeParameter(t *testing.T) {
	tests := []struct {
		name          string
		input         string
		expected      string
		executorInput *pipelinespec.ExecutorInput
		wantErr       bool
	}{
		{
			name:     "should retrieve correct parameter value",
			input:    "{{$.inputs.parameters['pipelinechannel--someParameterName']}}",
			expected: "test2",
			executorInput: &pipelinespec.ExecutorInput{
				Inputs: &pipelinespec.ExecutorInput_Inputs{
					ParameterValues: map[string]*structpb.Value{
						"pipelinechannel--":                  structpb.NewStringValue("test1"),
						"pipelinechannel--someParameterName": structpb.NewStringValue("test2"),
						"someParameterName":                  structpb.NewStringValue("test3"),
					},
				},
			},
			wantErr: false,
		},
		{
			name:     "return err when no match is found",
			input:    "{{$.inputs.parameters['pipelinechannel--someParameterName']}}",
			expected: "test1",
			executorInput: &pipelinespec.ExecutorInput{
				Inputs: &pipelinespec.ExecutorInput_Inputs{
					ParameterValues: map[string]*structpb.Value{
						"doesNotMatch": structpb.NewStringValue("test2"),
					},
				},
			},
			wantErr: true,
		},
		{
			name:          "return const val when input is not a pipeline channel",
			input:         "not-pipeline-channel",
			expected:      "not-pipeline-channel",
			executorInput: &pipelinespec.ExecutorInput{},
			wantErr:       false,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			actual, err := resolvePodSpecInputRuntimeParameter(test.input, test.executorInput)
			if test.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			// assert.Equal takes the expected value first, then the actual one.
			assert.Equal(t, test.expected, actual)
		})
	}
}
1 change: 1 addition & 0 deletions sdk/RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Features
* Expose `--existing-token` flag in `kfp` CLI to allow users to provide an existing token for authentication. [\#11400](https://github.com/kubeflow/pipelines/pull/11400)
* Add the ability to parameterize container images for tasks within pipelines. [\#11404](https://github.com/kubeflow/pipelines/pull/11404)

## Breaking changes

Expand Down
64 changes: 64 additions & 0 deletions sdk/python/kfp/compiler/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,70 @@ def my_pipeline() -> NamedTuple('Outputs', [
]):
task = print_and_return(text='Hello')

def test_pipeline_with_parameterized_container_image(self):
    """Compiles a pipeline whose task image is set from a pipeline parameter.

    Verifies that the executor's container image is emitted as an input
    parameter placeholder and that the task receives the inputs needed to
    resolve that placeholder at runtime.
    """
    with tempfile.TemporaryDirectory() as tmpdir:

        @dsl.component(base_image='docker.io/python:3.9.17')
        def empty_component():
            pass

        @dsl.pipeline()
        def simple_pipeline(img: str):
            task = empty_component()
            # overwrite base_image="docker.io/python:3.9.17"
            task.set_container_image(img)

        output_yaml = os.path.join(tmpdir, 'result.yaml')
        compiler.Compiler().compile(
            pipeline_func=simple_pipeline,
            package_path=output_yaml,
            pipeline_parameters={'img': 'someimage'})
        self.assertTrue(os.path.exists(output_yaml))

        with open(output_yaml, 'r') as f:
            pipeline_spec = yaml.safe_load(f)
        container = pipeline_spec['deploymentSpec']['executors'][
            'exec-empty-component']['container']
        self.assertEqual(
            container['image'],
            "{{$.inputs.parameters['pipelinechannel--img']}}")
        # A parameterized image should result in two task input parameters:
        #   1. 'base_image', storing the pipeline channel template that is
        #      resolved at runtime.
        #   2. 'pipelinechannel--img', holding the key to the resolved input.
        input_parameters = pipeline_spec['root']['dag']['tasks'][
            'empty-component']['inputs']['parameters']
        self.assertIn('base_image', input_parameters)
        self.assertIn('pipelinechannel--img', input_parameters)

def test_pipeline_with_constant_container_image(self):
    """Compiles a pipeline whose task image is overridden with a constant.

    Verifies that the constant is written to the executor's container image
    verbatim and that no task inputs are generated for it.
    """
    with tempfile.TemporaryDirectory() as tmpdir:

        @dsl.component(base_image='docker.io/python:3.9.17')
        def empty_component():
            pass

        @dsl.pipeline()
        def simple_pipeline():
            task = empty_component()
            # overwrite base_image="docker.io/python:3.9.17"
            task.set_container_image('constant-value')

        output_yaml = os.path.join(tmpdir, 'result.yaml')
        compiler.Compiler().compile(
            pipeline_func=simple_pipeline, package_path=output_yaml)

        self.assertTrue(os.path.exists(output_yaml))

        with open(output_yaml, 'r') as f:
            pipeline_spec = yaml.safe_load(f)
        container = pipeline_spec['deploymentSpec']['executors'][
            'exec-empty-component']['container']
        self.assertEqual(container['image'], 'constant-value')
        # A constant value should yield no parameters
        dag_task = pipeline_spec['root']['dag']['tasks'][
            'empty-component']
        self.assertNotIn('inputs', dag_task)


class TestCompilePipelineCaching(unittest.TestCase):

Expand Down
7 changes: 6 additions & 1 deletion sdk/python/kfp/compiler/pipeline_spec_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,11 @@ def build_task_spec_for_task(
if val and pipeline_channel.extract_pipeline_channels_from_any(val):
task.inputs[key] = val

if task.container_spec and task.container_spec.image:
val = task.container_spec.image
if val and pipeline_channel.extract_pipeline_channels_from_any(val):
task.inputs['base_image'] = val

for input_name, input_value in task.inputs.items():
# Since LoopParameterArgument and LoopArtifactArgument and LoopArgumentVariable are narrower
# types than PipelineParameterChannel, start with them.
Expand Down Expand Up @@ -634,7 +639,7 @@ def convert_to_placeholder(input_value: str) -> str:

container_spec = (
pipeline_spec_pb2.PipelineDeploymentConfig.PipelineContainerSpec(
image=task.container_spec.image,
image=convert_to_placeholder(task.container_spec.image),
command=task.container_spec.command,
args=task.container_spec.args,
env=[
Expand Down
Loading
Loading