Skip to content

Commit

Permalink
save registration process results into result_dir
Browse files Browse the repository at this point in the history
Signed-off-by: Artem Petrov <[email protected]>
  • Loading branch information
wckdman committed Oct 19, 2023
1 parent 85b8d29 commit 3029470
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 0 deletions.
19 changes: 19 additions & 0 deletions flytekit/clis/sdk_in_container/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,21 @@
is_flag=True,
help="Activate newly registered Launchplans. This operation deactivates previous versions of Launchplans.",
)
@click.option(
"--result_dir",
required=False,
type=click.Path(dir_okay=True, file_okay=False, writable=True, resolve_path=True),
default=None,
help="Directory to write the registration process results",
)
@click.option(
"-f",
"--format",
required=False,
type=click.Choice(["json", "yaml"], case_sensitive=False),
default="json",
help="Results file format",
)
@click.argument("package-or-module", type=click.Path(exists=True, readable=True, resolve_path=True), nargs=-1)
@click.pass_context
def register(
Expand All @@ -124,6 +139,8 @@ def register(
package_or_module: typing.Tuple[str],
dry_run: bool,
activate_launchplans: bool,
result_dir: str,
format: str,
):
"""
see help
Expand Down Expand Up @@ -175,6 +192,8 @@ def register(
remote=remote,
dry_run=dry_run,
activate_launchplans=activate_launchplans,
result_dir=result_dir,
format=format,
)
except Exception as e:
raise e
1 change: 1 addition & 0 deletions flytekit/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
FUTURES_FILE_NAME = "futures.pb"
ERROR_FILE_NAME = "error.pb"
REQUIREMENTS_FILE_NAME = "requirements.txt"
REGISTRATION_RESULT_FILENAME = "output.{}"


class SdkTaskType(object):
Expand Down
24 changes: 24 additions & 0 deletions flytekit/tools/repo.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import json

Check warning on line 1 in flytekit/tools/repo.py

View check run for this annotation

Codecov / codecov/patch

flytekit/tools/repo.py#L1

Added line #L1 was not covered by tests
import os
import tarfile
import tempfile
import typing
from pathlib import Path

import click
import yaml

Check warning on line 9 in flytekit/tools/repo.py

View check run for this annotation

Codecov / codecov/patch

flytekit/tools/repo.py#L9

Added line #L9 was not covered by tests

from flytekit.configuration import FastSerializationSettings, ImageConfig, SerializationSettings
from flytekit.core.constants import REGISTRATION_RESULT_FILENAME

Check warning on line 12 in flytekit/tools/repo.py

View check run for this annotation

Codecov / codecov/patch

flytekit/tools/repo.py#L12

Added line #L12 was not covered by tests
from flytekit.core.context_manager import FlyteContextManager
from flytekit.loggers import logger
from flytekit.models import launch_plan
Expand Down Expand Up @@ -217,6 +220,8 @@ def register(
fast: bool,
package_or_module: typing.Tuple[str],
remote: FlyteRemote,
result_dir: str,
format: str,
dry_run: bool = False,
activate_launchplans: bool = False,
):
Expand Down Expand Up @@ -262,6 +267,7 @@ def register(
click.secho("No Flyte entities were detected. Aborting!", fg="red")
return

registration_results = []

Check warning on line 270 in flytekit/tools/repo.py

View check run for this annotation

Codecov / codecov/patch

flytekit/tools/repo.py#L270

Added line #L270 was not covered by tests
for cp_entity in registrable_entities:
is_lp = False
if isinstance(cp_entity, launch_plan.LaunchPlan):
Expand All @@ -282,6 +288,24 @@ def register(
secho(i, reason="activated", op="Activation")
else:
secho(og_id, reason="Dry run Mode!")
status = "SUCCESS"

Check warning on line 291 in flytekit/tools/repo.py

View check run for this annotation

Codecov / codecov/patch

flytekit/tools/repo.py#L291

Added line #L291 was not covered by tests
except RegistrationSkipped:
secho(og_id, "failed")
status = "FAILED"

Check warning on line 294 in flytekit/tools/repo.py

View check run for this annotation

Codecov / codecov/patch

flytekit/tools/repo.py#L294

Added line #L294 was not covered by tests

registration_results.append(

Check warning on line 296 in flytekit/tools/repo.py

View check run for this annotation

Codecov / codecov/patch

flytekit/tools/repo.py#L296

Added line #L296 was not covered by tests
{
"id": og_id.name,
"type": og_id.resource_type_name(),
"version": og_id.version,
"status": status,
}
)
if result_dir:
with open(os.path.join(result_dir, REGISTRATION_RESULT_FILENAME.format(format)), "w") as f:
if format == "yaml":
txt = yaml.dump(registration_results)

Check warning on line 307 in flytekit/tools/repo.py

View check run for this annotation

Codecov / codecov/patch

flytekit/tools/repo.py#L307

Added line #L307 was not covered by tests
else:
txt = json.dumps(registration_results)
f.write(txt)

Check warning on line 310 in flytekit/tools/repo.py

View check run for this annotation

Codecov / codecov/patch

flytekit/tools/repo.py#L309-L310

Added lines #L309 - L310 were not covered by tests
click.secho(f"Successfully registered {len(registrable_entities)} entities", fg="green")
35 changes: 35 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import subprocess

import mock
import pytest
from click.testing import CliRunner

from flytekit.clients.friendly import SynchronousFlyteClient
Expand Down Expand Up @@ -38,6 +39,16 @@ def my_workflow(x: int, y: int) -> int:
"""


@pytest.fixture
def expected_result_file_content():
expected_result_json = '[{"id": "core_json.sample.sum", "type": "TASK", "version": "dummy_version_from_hash", "status": "SUCCESS"}, {"id": "core_json.sample.square", "type": "TASK", "version": "dummy_version_from_hash", "status": "SUCCESS"}, {"id": "core_json.sample.my_workflow", "type": "WORKFLOW", "version": "dummy_version_from_hash", "status": "SUCCESS"}, {"id": "core_json.sample.my_workflow", "type": "LAUNCH_PLAN", "version": "dummy_version_from_hash", "status": "SUCCESS"}]'
expected_result_yaml = "- id: core_yaml.sample.sum\n status: SUCCESS\n type: TASK\n version: dummy_version_from_hash\n- id: core_yaml.sample.square\n status: SUCCESS\n type: TASK\n version: dummy_version_from_hash\n- id: core_yaml.sample.my_workflow\n status: SUCCESS\n type: WORKFLOW\n version: dummy_version_from_hash\n- id: core_yaml.sample.my_workflow\n status: SUCCESS\n type: LAUNCH_PLAN\n version: dummy_version_from_hash\n"
return {
"json": expected_result_json,
"yaml": expected_result_yaml,
}


@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote")
def test_get_remote(mock_remote):
r = get_remote(None, "p", "d")
Expand Down Expand Up @@ -142,3 +153,27 @@ def test_non_fast_register_require_version(mock_client, mock_remote):
assert result.exit_code == 1
assert str(result.exception) == "Version is a required parameter in case --non-fast is specified."
shutil.rmtree("core3")


@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote)
@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient)
@pytest.mark.parametrize("format", ["json", "yaml"])
def test_register_result_output(mock_client, mock_remote, format, expected_result_file_content):
mock_remote._client = mock_client
mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash"
mock_remote.return_value.fast_package.return_value = "dummy_md5_bytes", "dummy_native_url"
runner = CliRunner()
context_manager.FlyteEntities.entities.clear()
dirname = f"core_{format}"
with runner.isolated_filesystem():
out = subprocess.run(["git", "init"], capture_output=True)
assert out.returncode == 0
os.makedirs(dirname, exist_ok=True)
with open(os.path.join(dirname, "sample.py"), "w") as f:
f.write(sample_file_contents)
result = runner.invoke(pyflyte.main, ["register", f"--result_dir={dirname}", f"--format={format}", dirname])
assert "Successfully registered 4 entities" in result.output
with open(os.path.join(dirname, f"output.{format}"), "r") as f:
omg = f.read()
assert expected_result_file_content[format] == omg
shutil.rmtree(dirname)

0 comments on commit 3029470

Please sign in to comment.