Skip to content

Commit

Permalink
Use function name instead of base name and address edge cases
Browse files Browse the repository at this point in the history
Signed-off-by: droctothorpe <[email protected]>
Co-authored-by: zazulam <[email protected]>
  • Loading branch information
droctothorpe and zazulam committed Aug 2, 2024
1 parent 1162bf9 commit ee1b13a
Show file tree
Hide file tree
Showing 9 changed files with 232 additions and 284 deletions.
57 changes: 39 additions & 18 deletions backend/src/v2/compiler/argocompiler/argo.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ package argocompiler

import (
"fmt"
"strconv"
"strings"

wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1"
"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
"github.com/kubeflow/pipelines/backend/src/v2/compiler"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/structpb"
k8score "k8s.io/api/core/v1"
Expand Down Expand Up @@ -191,25 +191,29 @@ const (
)

func (c *workflowCompiler) saveComponentSpec(name string, spec *pipelinespec.ComponentSpec) error {
baseComponentName := ExtractBaseComponentName(argumentsComponents + name)
return c.saveProtoToArguments(baseComponentName, spec)
functionName := c.extractFunctionName(name)

return c.saveProtoToArguments(argumentsComponents+functionName, spec)
}

// useComponentSpec returns a placeholder we can refer to the component spec
// in argo workflow fields.
func (c *workflowCompiler) useComponentSpec(name string) (string, error) {
baseComponentName := ExtractBaseComponentName(argumentsComponents + name)
return c.argumentsPlaceholder(baseComponentName)
functionName := c.extractFunctionName(name)

return c.argumentsPlaceholder(argumentsComponents + functionName)
}

func (c *workflowCompiler) saveComponentImpl(name string, msg proto.Message) error {
baseComponentName := ExtractBaseComponentName(argumentsContainers + name)
return c.saveProtoToArguments(baseComponentName, msg)
functionName := c.extractFunctionName(name)

return c.saveProtoToArguments(argumentsContainers+functionName, msg)
}

func (c *workflowCompiler) useComponentImpl(name string) (string, error) {
baseComponentName := ExtractBaseComponentName(argumentsContainers + name)
return c.argumentsPlaceholder(baseComponentName)
functionName := c.extractFunctionName(name)

return c.argumentsPlaceholder(argumentsContainers + functionName)
}

func (c *workflowCompiler) saveKubernetesSpec(name string, spec *structpb.Struct) error {
Expand Down Expand Up @@ -262,17 +266,34 @@ func (c *workflowCompiler) argumentsPlaceholder(componentName string) (string, e
return workflowParameter(componentName), nil
}

// ExtractBaseComponentName removes the iteration suffix that the IR compiler
// adds to the component name.
func ExtractBaseComponentName(componentName string) string {
baseComponentName := componentName
componentNameArray := strings.Split(componentName, "-")

if _, err := strconv.Atoi(componentNameArray[len(componentNameArray)-1]); err == nil {
baseComponentName = strings.Join(componentNameArray[:len(componentNameArray)-1], "-")
// extractFunctionName returns the function name of a component by looking up
// the component's executor in the pipeline spec and taking the last container
// argument of that executor.
//
// The component name itself is returned unchanged whenever a function name
// cannot be determined: for the root component, for a component whose
// executor label has no matching executor, and for an executor that has no
// container spec or no container arguments. We could theoretically return an
// error in those cases, but since the only consequence is reduced
// deduplication, this doesn't result in application failure and we therefore
// continue.
func (c *workflowCompiler) extractFunctionName(componentName string) string {
	log.Debug("componentName: ", componentName)
	// The root component is a DAG and therefore doesn't have a corresponding
	// executor or function name. The fallbacks below would cover this edge
	// case, but returning early saves us an unnecessary lookup.
	if componentName == "root" {
		return componentName
	}
	executorLabel := c.spec.Components[componentName].GetExecutorLabel()
	log.Debug("executorLabel: ", executorLabel)

	// c.executors is keyed by executor name, so look the label up directly
	// instead of scanning every entry for an equal key.
	executor, ok := c.executors[executorLabel]
	if !ok {
		log.Debug("No corresponding executor for component: ", componentName)
		return componentName
	}

	// Guard against executors without a container spec and against empty
	// argument lists; indexing either one unguarded would panic.
	container := executor.GetContainer()
	if container == nil || len(container.Args) == 0 {
		log.Debug("No container arguments for component: ", componentName)
		return componentName
	}
	componentFunctionName := container.Args[len(container.Args)-1]
	log.Debug("componentFunctionName: ", componentFunctionName)

	return componentFunctionName
}

const (
Expand Down
42 changes: 0 additions & 42 deletions backend/src/v2/compiler/argocompiler/argo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,45 +137,3 @@ func load(t *testing.T, path string, platformSpecPath string) (*pipelinespec.Pip
}
return job, nil
}

func Test_extractBaseComponentName(t *testing.T) {
tests := []struct {
name string
componentName string
expectedBaseName string
}{
{
name: "With dash and int",
componentName: "component-2",
expectedBaseName: "component",
},
{
name: "Without dash and int",
componentName: "component",
expectedBaseName: "component",
},
{
name: "Last char is int",
componentName: "component-v2",
expectedBaseName: "component-v2",
},
{
name: "Multiple dashes, ends with int",
componentName: "service-api-v2",
expectedBaseName: "service-api-v2",
},
{
name: "Multiple dashes and ints",
componentName: "module-1-2-3",
expectedBaseName: "module-1-2",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := argocompiler.ExtractBaseComponentName(tt.componentName)
if result != tt.expectedBaseName {
t.Errorf("Expected: %s, Got: %s", tt.expectedBaseName, result)
}
})
}
}
28 changes: 14 additions & 14 deletions sdk/python/kfp/compiler/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def comp():


@dsl.component
def return_one() -> int:
def return_1() -> int:
return 1


Expand Down Expand Up @@ -3369,43 +3369,43 @@ def test_cpu_memory_optional(self):

@dsl.pipeline
def simple_pipeline():
return_one()
return_one().set_cpu_limit('5')
return_one().set_memory_limit('50G')
return_one().set_cpu_request('2').set_cpu_limit(
return_1()
return_1().set_cpu_limit('5')
return_1().set_memory_limit('50G')
return_1().set_cpu_request('2').set_cpu_limit(
'5').set_memory_request('4G').set_memory_limit('50G')

dict_format = json_format.MessageToDict(simple_pipeline.pipeline_spec)

self.assertNotIn(
'resources', dict_format['deploymentSpec']['executors']
['exec-return-one']['container'])
['exec-return-1']['container'])

self.assertEqual(
5, dict_format['deploymentSpec']['executors']['exec-return-one-2']
5, dict_format['deploymentSpec']['executors']['exec-return-1-2']
['container']['resources']['cpuLimit'])
self.assertNotIn(
'memoryLimit', dict_format['deploymentSpec']['executors']
['exec-return-one-2']['container']['resources'])
['exec-return-1-2']['container']['resources'])

self.assertEqual(
50, dict_format['deploymentSpec']['executors']['exec-return-one-3']
50, dict_format['deploymentSpec']['executors']['exec-return-1-3']
['container']['resources']['memoryLimit'])
self.assertNotIn(
'cpuLimit', dict_format['deploymentSpec']['executors']
['exec-return-one-3']['container']['resources'])
['exec-return-1-3']['container']['resources'])

self.assertEqual(
2, dict_format['deploymentSpec']['executors']['exec-return-one-4']
2, dict_format['deploymentSpec']['executors']['exec-return-1-4']
['container']['resources']['cpuRequest'])
self.assertEqual(
5, dict_format['deploymentSpec']['executors']['exec-return-one-4']
5, dict_format['deploymentSpec']['executors']['exec-return-1-4']
['container']['resources']['cpuLimit'])
self.assertEqual(
4, dict_format['deploymentSpec']['executors']['exec-return-one-4']
4, dict_format['deploymentSpec']['executors']['exec-return-1-4']
['container']['resources']['memoryRequest'])
self.assertEqual(
50, dict_format['deploymentSpec']['executors']['exec-return-one-4']
50, dict_format['deploymentSpec']['executors']['exec-return-1-4']
['container']['resources']['memoryLimit'])


Expand Down
7 changes: 0 additions & 7 deletions sdk/python/kfp/dsl/component_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,6 @@ class ComponentInfo():

def _python_function_name_to_component_name(name):
name_with_spaces = re.sub(' +', ' ', name.replace('_', ' ')).strip(' ')
name_list = name_with_spaces.split(' ')

if name_list[-1].isdigit():
raise ValueError(
f'Invalid function name "{name}". The function name must not end in `_<int>`.'
)

return name_with_spaces[0].upper() + name_with_spaces[1:]


Expand Down
24 changes: 0 additions & 24 deletions sdk/python/kfp/dsl/component_factory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,30 +174,6 @@ def comp(Output: OutputPath(str), text: str) -> str:
pass


class TestPythonFunctionName(unittest.TestCase):

def test_invalid_function_name(self):

with self.assertRaisesRegex(
ValueError,
r'Invalid function name "comp_2". The function name must not end in `_<int>`.'
):

@component
def comp_2(text: str) -> str:
pass

def test_valid_function_name(self):

@component
def comp_v2(text: str) -> str:
pass

@component
def comp_(text: str) -> str:
pass


class TestExtractComponentInterfaceListofArtifacts(unittest.TestCase):

def test_python_component_input(self):
Expand Down
4 changes: 2 additions & 2 deletions sdk/python/test_data/pipelines/if_elif_else_complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


@dsl.component
def int_0_to_9999_func() -> int:
def int_0_to_9999() -> int:
import random
return random.randint(0, 9999)

Expand Down Expand Up @@ -49,7 +49,7 @@ def lucky_number_pipeline(add_drumroll: bool = True,
repeat_if_lucky_number: bool = True,
trials: List[int] = [1, 2, 3]):
with dsl.ParallelFor(trials) as trial:
int_task = int_0_to_9999_func().set_caching_options(False)
int_task = int_0_to_9999().set_caching_options(False)
with dsl.If(add_drumroll == True):
with dsl.If(trial == 3):
print_and_return(text='Adding drumroll on last trial!')
Expand Down
Loading

0 comments on commit ee1b13a

Please sign in to comment.