From 7239523a23e3b87950b06be70e00626a6a1a81ec Mon Sep 17 00:00:00 2001 From: ddalvi Date: Thu, 14 Nov 2024 17:33:38 -0500 Subject: [PATCH] Add test to verify setting of SemaphoreKey and MutexName fields in KFP DSL Signed-off-by: ddalvi --- sdk/python/kfp/compiler/compiler_test.py | 36 ++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/sdk/python/kfp/compiler/compiler_test.py b/sdk/python/kfp/compiler/compiler_test.py index 2433f09bc6d1..a0d50a3cc2c9 100644 --- a/sdk/python/kfp/compiler/compiler_test.py +++ b/sdk/python/kfp/compiler/compiler_test.py @@ -3847,6 +3847,42 @@ def outer(): foo_platform_set_bar_feature(task, 12) +class TestPipelineSemaphoreMutex(unittest.TestCase): + + def test_pipeline_with_semaphore_and_mutex(self): + from kfp import compiler + from kfp import dsl + from kfp.dsl.pipeline_config import PipelineConfig + + config = PipelineConfig() + config.set_semaphore_key("semaphore") + config.set_mutex_name("mutex") + + @dsl.pipeline(pipeline_config=config) + def my_pipeline(): + task = comp() + + with tempfile.TemporaryDirectory() as tempdir: + output_yaml = os.path.join(tempdir, 'pipeline.yaml') + compiler.Compiler().compile( + pipeline_func=my_pipeline, package_path=output_yaml) + + with open(output_yaml, 'r') as f: + pipeline_docs = list(yaml.safe_load_all(f)) + + pipeline_spec = None + for doc in pipeline_docs: + if 'platforms' in doc: + pipeline_spec = doc + break + + if pipeline_spec: + kubernetes_spec = pipeline_spec['platforms']['kubernetes'][ + 'pipelineConfig'] + assert kubernetes_spec['semaphoreKey'] == "semaphore" + assert kubernetes_spec['mutexName'] == "mutex" + + class ExtractInputOutputDescription(unittest.TestCase): def test_no_descriptions(self):