diff --git a/gordo/reporters/mlflow.py b/gordo/reporters/mlflow.py index c31cc3ddc..d8e01014c 100644 --- a/gordo/reporters/mlflow.py +++ b/gordo/reporters/mlflow.py @@ -4,7 +4,7 @@ import logging import os import tempfile -from typing import List +from typing import List, Optional from uuid import uuid4 from azureml.core import Workspace @@ -284,37 +284,28 @@ def get_batch_kwargs(machine: Machine) -> dict: return {"metrics": metric_list, "params": param_list} -def get_kwargs_from_secret(name: str, keys: List[str]) -> dict: +def get_kwargs_from_secret(secret_str: str, keys: List[str]) -> dict: """ Get keyword arguments dictionary from secrets environment variable Parameters ---------- - name: str - Name of the environment variable whose content is a colon separated - list of secrets. + secret_str: str + String of a colon separated list of secrets. + keys: List[str] + List of keys associated with each element in secrets string. Returns ------- kwargs: dict Dictionary of keyword arguments parsed from environment variable. """ - secret_str = os.getenv(name) - - if secret_str is None: - raise MlflowLoggingError(f"The value for env var '{name}' must not be `None`.") - - if secret_str: - elements = secret_str.split(":") - if len(elements) != len(keys): - raise MlflowLoggingError( - "`keys` len {len(keys)} must be of equal length with env var {name} elements {len(elements)}." - ) - kwargs = {key: elements[i] for i, key in enumerate(keys)} - else: - kwargs = {} - - return kwargs + elements = secret_str.split(":") + if len(elements) != len(keys): + raise MlflowLoggingError( + f"`keys` len {len(keys)} must be of equal number of elements {len(elements)} parsed from secrets str." + ) + return {key: elements[i] for i, key in enumerate(keys)} def get_workspace_kwargs() -> dict: @@ -330,8 +321,13 @@ def get_workspace_kwargs() -> dict: AzureML Workspace configuration to use for remote MLFlow tracking. See :func:`gordo.builder.mlflow_utils.get_mlflow_client`. 
""" - return get_kwargs_from_secret( - "AZUREML_WORKSPACE_STR", ["subscription_id", "resource_group", "workspace_name"] + secret_str = os.getenv("AZUREML_WORKSPACE_STR") + return ( + get_kwargs_from_secret( + secret_str, ["subscription_id", "resource_group", "workspace_name"] + ) + if secret_str + else dict() ) @@ -348,9 +344,15 @@ def get_spauth_kwargs() -> dict: AzureML ServicePrincipalAuthentication keyword arguments. See :func:`gordo.builder.mlflow_utils.get_mlflow_client` """ - return get_kwargs_from_secret( - "DL_SERVICE_AUTH_STR", - ["tenant_id", "service_principal_id", "service_principal_password"], + + secret_str = os.getenv("DL_SERVICE_AUTH_STR") + return ( + get_kwargs_from_secret( + secret_str, + ["tenant_id", "service_principal_id", "service_principal_password"], + ) + if secret_str + else dict() ) @@ -433,6 +435,7 @@ def report(self, machine: Machine): workspace_kwargs = get_workspace_kwargs() service_principal_kwargs = get_spauth_kwargs() + cache_key = ModelBuilder.calculate_cache_key(machine) with mlflow_context( diff --git a/tests/gordo/reporters/test_mlflow_reporter.py b/tests/gordo/reporters/test_mlflow_reporter.py index d497ae192..f55c368b3 100644 --- a/tests/gordo/reporters/test_mlflow_reporter.py +++ b/tests/gordo/reporters/test_mlflow_reporter.py @@ -196,41 +196,42 @@ def _test_mlflow_batch_arg_types(metadata): _test_mlflow_batch_arg_types(metadata) -@pytest.mark.parametrize( - "secret_str,keys,keys_valid", - [ - ("dummy1:dummy2:dummy3", ["key1", "key2", "key3"], True), - ("dummy1:dummy2:dummy3", ["key1", "key2"], False), - ], -) -def test_get_kwargs_from_secret(monkeypatch, secret_str, keys, keys_valid): +def test_get_kwargs_from_secret_invalid(): """ - Test that service principal kwargs are generated correctly if env var present + Test that method fails with number of secret elements mismatch number of keys """ - env_var_name = "TEST_SECRET" - - # TEST_SECRET doesn't exist as env var with pytest.raises(ReporterException): - 
mlu.get_kwargs_from_secret(env_var_name, keys) + mlu.get_kwargs_from_secret("dummy1:dummy2:dummy3", ["key1", "key2"]) - # TEST_SECRET exists as env var - monkeypatch.setenv(name=env_var_name, value=secret_str) - if keys_valid: - kwargs = mlu.get_kwargs_from_secret(env_var_name, keys) - for key, value in zip(keys, secret_str.split(":")): - assert kwargs[key] == value - else: - with pytest.raises(ReporterException): - mlu.get_kwargs_from_secret(env_var_name, keys) + with pytest.raises(AttributeError): + mlu.get_kwargs_from_secret(None, ["key1", "key2"]) +def test_workspace_kwargs(monkeypatch): + """ + Test that an empty dict is returned when AZUREML_WORKSPACE_STR is unset, and the parsed workspace kwargs once it is set + """ + assert mlu.get_workspace_kwargs() == {} -def test_workspace_spauth_kwargs(): - """Make sure an error is thrown when env vars not set""" - with pytest.raises(ReporterException): - mlu.get_workspace_kwargs() + monkeypatch.setenv("AZUREML_WORKSPACE_STR", "test:test:test") + assert mlu.get_workspace_kwargs() == { + "subscription_id": "test", + "resource_group": "test", + "workspace_name": "test", + } - with pytest.raises(ReporterException): - mlu.get_spauth_kwargs() +def test_spauth_kwargs(monkeypatch): + """ + Test that an empty dict is returned when DL_SERVICE_AUTH_STR is unset, and the parsed service principal kwargs once it is set + """ + + assert mlu.get_spauth_kwargs() == {} + + monkeypatch.setenv("DL_SERVICE_AUTH_STR", "test:test:test") + assert mlu.get_spauth_kwargs() == { + "tenant_id": "test", + "service_principal_id": "test", + "service_principal_password": "test", + } def test_MachineEncoder(metadata):