diff --git a/flytekit/clients/auth_helper.py b/flytekit/clients/auth_helper.py index b4a6b7a438..317a16c5f4 100644 --- a/flytekit/clients/auth_helper.py +++ b/flytekit/clients/auth_helper.py @@ -122,9 +122,12 @@ def upgrade_channel_to_proxy_authenticated(cfg: PlatformConfig, in_channel: grpc :param in_channel: grpc.Channel Precreated channel :return: grpc.Channel. New composite channel """ + + def authenticator_factory(): + return get_proxy_authenticator(cfg) + if cfg.proxy_command: - proxy_authenticator = get_proxy_authenticator(cfg) - return grpc.intercept_channel(in_channel, AuthUnaryInterceptor(proxy_authenticator)) + return grpc.intercept_channel(in_channel, AuthUnaryInterceptor(authenticator_factory)) else: return in_channel @@ -137,8 +140,11 @@ def upgrade_channel_to_authenticated(cfg: PlatformConfig, in_channel: grpc.Chann :param in_channel: grpc.Channel Precreated channel :return: grpc.Channel. New composite channel """ - authenticator = get_authenticator(cfg, RemoteClientConfigStore(in_channel)) - return grpc.intercept_channel(in_channel, AuthUnaryInterceptor(authenticator)) + + def authenticator_factory(): + return get_authenticator(cfg, RemoteClientConfigStore(in_channel)) + + return grpc.intercept_channel(in_channel, AuthUnaryInterceptor(authenticator_factory)) def get_authenticated_channel(cfg: PlatformConfig) -> grpc.Channel: diff --git a/flytekit/clients/grpc_utils/auth_interceptor.py b/flytekit/clients/grpc_utils/auth_interceptor.py index 6a73e0764e..05a5cb53fa 100644 --- a/flytekit/clients/grpc_utils/auth_interceptor.py +++ b/flytekit/clients/grpc_utils/auth_interceptor.py @@ -25,15 +25,22 @@ class AuthUnaryInterceptor(grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamCli is needed. """ - def __init__(self, authenticator: Authenticator): - self._authenticator = authenticator + def __init__(self, get_authenticator: typing.Callable[[], Authenticator]): + self._get_authenticator = get_authenticator + self._authenticator = None + + @property + def authenticator(self) -> Authenticator: + if self._authenticator is None: + self._authenticator = self._get_authenticator() + return self._authenticator def _call_details_with_auth_metadata(self, client_call_details: grpc.ClientCallDetails) -> grpc.ClientCallDetails: """ Returns new ClientCallDetails with metadata added. """ metadata = client_call_details.metadata - auth_metadata = self._authenticator.fetch_grpc_call_auth_metadata() + auth_metadata = self.authenticator.fetch_grpc_call_auth_metadata() if auth_metadata: metadata = [] if client_call_details.metadata: @@ -65,6 +72,7 @@ def intercept_unary_unary( raise e if e.code() == grpc.StatusCode.UNAUTHENTICATED or e.code() == grpc.StatusCode.UNKNOWN: self._authenticator.refresh_credentials() + self.authenticator.refresh_credentials() updated_call_details = self._call_details_with_auth_metadata(client_call_details) return continuation(updated_call_details, request) return fut @@ -77,6 +85,7 @@ def intercept_unary_stream(self, continuation, client_call_details, request): c: grpc.Call = continuation(updated_call_details, request) if c.code() == grpc.StatusCode.UNAUTHENTICATED: self._authenticator.refresh_credentials() + self.authenticator.refresh_credentials() updated_call_details = self._call_details_with_auth_metadata(client_call_details) return continuation(updated_call_details, request) return c diff --git a/tests/flytekit/unit/clients/auth/test_keyring_store.py b/tests/flytekit/unit/clients/auth/test_keyring_store.py index d068a1f451..fc6a3b98df 100644 --- a/tests/flytekit/unit/clients/auth/test_keyring_store.py +++ b/tests/flytekit/unit/clients/auth/test_keyring_store.py @@ -4,6 +4,15 @@ from flytekit.clients.auth.keyring import Credentials, KeyringStore +from flytekit.clients.auth_helper import upgrade_channel_to_authenticated, upgrade_channel_to_proxy_authenticated + +from flytekit.configuration import PlatformConfig + +import pytest + +from flytekit.clients.auth.authenticator import CommandAuthenticator + +from flytekit.clients.grpc_utils.auth_interceptor import AuthUnaryInterceptor @patch("keyring.get_password") def test_keyring_store_get(kr_get_password: MagicMock): @@ -30,3 +39,34 @@ def test_keyring_store_set(kr_set_password: MagicMock): kr_set_password.side_effect = NoKeyringError() assert KeyringStore.retrieve("example2.com") is None + +@patch("flytekit.clients.auth.authenticator.KeyringStore") +def test_upgrade_channel_to_authenticated_with_keyring_exception(mock_keyring_store): + mock_keyring_store.retrieve.side_effect = Exception("mock exception") + + mock_channel = MagicMock() + + platform_config = PlatformConfig() + + try: + out_ch = upgrade_channel_to_authenticated(platform_config, mock_channel) + except Exception as e: + pytest.fail(f"upgrade_channel_to_authenticated Exception: {e}") + + assert isinstance(out_ch._interceptor, AuthUnaryInterceptor) + +@patch("flytekit.clients.auth.authenticator.KeyringStore") +def test_upgrade_channel_to_proxy_authenticated_with_keyring_exception(mock_keyring_store): + mock_keyring_store.retrieve.side_effect = Exception("mock exception") + + mock_channel = MagicMock() + + platform_config = PlatformConfig(auth_mode="Pkce", proxy_command=["echo", "foo-bar"]) + + try: + out_ch = upgrade_channel_to_proxy_authenticated(platform_config, mock_channel) + except Exception as e: + pytest.fail(f"upgrade_channel_to_proxy_authenticated Exception: {e}") + + assert isinstance(out_ch._interceptor, AuthUnaryInterceptor) + assert isinstance(out_ch._interceptor.authenticator, CommandAuthenticator) diff --git a/tests/flytekit/unit/clients/test_auth_helper.py b/tests/flytekit/unit/clients/test_auth_helper.py index 4baac2ebc5..d1a1851e7e 100644 --- a/tests/flytekit/unit/clients/test_auth_helper.py +++ b/tests/flytekit/unit/clients/test_auth_helper.py @@ -172,7 +172,7 @@ def test_upgrade_channel_to_proxy_auth(): ch, ) assert isinstance(out_ch._interceptor, AuthUnaryInterceptor) - assert isinstance(out_ch._interceptor._authenticator, CommandAuthenticator) + assert isinstance(out_ch._interceptor.authenticator, CommandAuthenticator) def test_get_proxy_authenticated_session(): diff --git a/tests/flytekit/unit/clients/test_friendly.py b/tests/flytekit/unit/clients/test_friendly.py index b553ae78a0..59211c3cd2 100644 --- a/tests/flytekit/unit/clients/test_friendly.py +++ b/tests/flytekit/unit/clients/test_friendly.py @@ -9,7 +9,6 @@ from flytekit.configuration import PlatformConfig from flytekit.models.project import Project as _Project - @mock.patch("flytekit.clients.friendly._RawSynchronousFlyteClient.update_project") def test_update_project(mock_raw_update_project): client = _SynchronousFlyteClient(PlatformConfig.for_endpoint("a.b.com", True)) diff --git a/tests/flytekit/unit/clients/test_raw.py b/tests/flytekit/unit/clients/test_raw.py index ee4e516354..00ca38b807 100644 --- a/tests/flytekit/unit/clients/test_raw.py +++ b/tests/flytekit/unit/clients/test_raw.py @@ -5,7 +5,6 @@ from flytekit.clients.raw import RawSynchronousFlyteClient from flytekit.configuration import PlatformConfig - @mock.patch("flytekit.clients.raw._admin_service") @mock.patch("flytekit.clients.raw.grpc.insecure_channel") def test_update_project(mock_channel, mock_admin): @@ -14,7 +13,6 @@ def test_update_project(mock_channel, mock_admin): client.update_project(project) mock_admin.AdminServiceStub().UpdateProject.assert_called_with(project, metadata=None) - @mock.patch("flytekit.clients.raw._admin_service") @mock.patch("flytekit.clients.raw.grpc.insecure_channel") def test_list_projects_paginated(mock_channel, mock_admin):