Skip to content

Commit

Permalink
Change secrets methods to return empty dict if env var missing
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanjdillon committed Feb 12, 2020
1 parent 1f1bbc5 commit e7807a0
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 54 deletions.
55 changes: 29 additions & 26 deletions gordo/reporters/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import os
import tempfile
from typing import List
from typing import List, Optional
from uuid import uuid4

from azureml.core import Workspace
Expand Down Expand Up @@ -284,37 +284,28 @@ def get_batch_kwargs(machine: Machine) -> dict:
return {"metrics": metric_list, "params": param_list}


def get_kwargs_from_secret(name: str, keys: List[str]) -> dict:
def get_kwargs_from_secret(secret_str: str, keys: List[str]) -> dict:
"""
Get keyword arguments dictionary from secrets environment variable
Parameters
----------
name: str
Name of the environment variable whose content is a colon separated
list of secrets.
secret_str: str
String of a colon separated list of secrets.
keys: List[str]
List of keys associated with each element in secrets string.
Returns
-------
kwargs: dict
Dictionary of keyword arguments parsed from environment variable.
"""
secret_str = os.getenv(name)

if secret_str is None:
raise MlflowLoggingError(f"The value for env var '{name}' must not be `None`.")

if secret_str:
elements = secret_str.split(":")
if len(elements) != len(keys):
raise MlflowLoggingError(
"`keys` len {len(keys)} must be of equal length with env var {name} elements {len(elements)}."
)
kwargs = {key: elements[i] for i, key in enumerate(keys)}
else:
kwargs = {}

return kwargs
elements = secret_str.split(":")
if len(elements) != len(keys):
raise MlflowLoggingError(
f"`keys` len {len(keys)} must be of equal number of elements {len(elements)} parsed from secrets str."
)
return {key: elements[i] for i, key in enumerate(keys)}


def get_workspace_kwargs() -> dict:
Expand All @@ -330,8 +321,13 @@ def get_workspace_kwargs() -> dict:
AzureML Workspace configuration to use for remote MLFlow tracking. See
:func:`gordo.builder.mlflow_utils.get_mlflow_client`.
"""
return get_kwargs_from_secret(
"AZUREML_WORKSPACE_STR", ["subscription_id", "resource_group", "workspace_name"]
secret_str = os.getenv("AZUREML_WORKSPACE_STR")
return (
get_kwargs_from_secret(
secret_str, ["subscription_id", "resource_group", "workspace_name"]
)
if secret_str
else dict()
)


Expand All @@ -348,9 +344,15 @@ def get_spauth_kwargs() -> dict:
AzureML ServicePrincipalAuthentication keyword arguments. See
:func:`gordo.builder.mlflow_utils.get_mlflow_client`
"""
return get_kwargs_from_secret(
"DL_SERVICE_AUTH_STR",
["tenant_id", "service_principal_id", "service_principal_password"],

secret_str = os.getenv("DL_SERVICE_AUTH_STR")
return (
get_kwargs_from_secret(
secret_str,
["tenant_id", "service_principal_id", "service_principal_password"],
)
if secret_str
else dict()
)


Expand Down Expand Up @@ -433,6 +435,7 @@ def report(self, machine: Machine):

workspace_kwargs = get_workspace_kwargs()
service_principal_kwargs = get_spauth_kwargs()

cache_key = ModelBuilder.calculate_cache_key(machine)

with mlflow_context(
Expand Down
57 changes: 29 additions & 28 deletions tests/gordo/reporters/test_mlflow_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,41 +196,42 @@ def _test_mlflow_batch_arg_types(metadata):
_test_mlflow_batch_arg_types(metadata)


@pytest.mark.parametrize(
"secret_str,keys,keys_valid",
[
("dummy1:dummy2:dummy3", ["key1", "key2", "key3"], True),
("dummy1:dummy2:dummy3", ["key1", "key2"], False),
],
)
def test_get_kwargs_from_secret(monkeypatch, secret_str, keys, keys_valid):
def test_get_kwargs_from_secret_invalid():
"""
Test that service principal kwargs are generated correctly if env var present
Test that method fails with number of secret elements mismatch number of keys
"""
env_var_name = "TEST_SECRET"

# TEST_SECRET doesn't exist as env var
with pytest.raises(ReporterException):
mlu.get_kwargs_from_secret(env_var_name, keys)
mlu.get_kwargs_from_secret("dummy1:dummy2:dummy3", ["key1", "key2"])

# TEST_SECRET exists as env var
monkeypatch.setenv(name=env_var_name, value=secret_str)
if keys_valid:
kwargs = mlu.get_kwargs_from_secret(env_var_name, keys)
for key, value in zip(keys, secret_str.split(":")):
assert kwargs[key] == value
else:
with pytest.raises(ReporterException):
mlu.get_kwargs_from_secret(env_var_name, keys)
with pytest.raises(AttributeError):
mlu.get_kwargs_from_secret(None, ["key1", "key2"])

def test_workspace_kwargs(monkeypatch):
"""
Test that appropriate kwargs dict is returned
"""
assert mlu.get_workspace_kwargs() == {}

def test_workspace_spauth_kwargs():
"""Make sure an error is thrown when env vars not set"""
with pytest.raises(ReporterException):
mlu.get_workspace_kwargs()
monkeypatch.setenv("AZUREML_WORKSPACE_STR", "test:test:test")
assert mlu.get_workspace_kwargs() == {
"subscription_id": "test",
"resource_group": "test",
"workspace_name": "test",
}

with pytest.raises(ReporterException):
mlu.get_spauth_kwargs()
def test_spauth_kwargs(monkeypatch):
"""
Test that appropriate kwargs dict is returned
"""

assert mlu.get_spauth_kwargs() == {}

monkeypatch.setenv("DL_SERVICE_AUTH_STR", "test:test:test")
assert mlu.get_spauth_kwargs() == {
"tenant_id": "test",
"service_principal_id": "test",
"service_principal_password": "test",
}


def test_MachineEncoder(metadata):
Expand Down

0 comments on commit e7807a0

Please sign in to comment.