Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CM-29444 - Add caching of access token #199

Merged
merged 1 commit into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cycode/cli/commands/auth/auth_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def authorization_check(context: click.Context) -> None:
return

try:
if CycodeTokenBasedClient(client_id, client_secret).api_token:
if CycodeTokenBasedClient(client_id, client_secret).get_access_token():
printer.print_result(passed_auth_check_res)
return
except (NetworkError, HttpUnauthorizedError):
Expand Down
2 changes: 1 addition & 1 deletion cycode/cli/commands/auth/auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def get_api_token_polling(self, session_id: str, code_verifier: str) -> 'ApiToke
raise AuthProcessError('session expired')

def save_api_token(self, api_token: 'ApiToken') -> None:
self.credentials_manager.update_credentials_file(api_token.client_id, api_token.secret)
self.credentials_manager.update_credentials(api_token.client_id, api_token.secret)

def _build_login_url(self, code_challenge: str, session_id: str) -> str:
app_url = self.configuration_manager.get_cycode_app_url()
Expand Down
2 changes: 1 addition & 1 deletion cycode/cli/commands/configure/configure_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def configure_command() -> None:
credentials_updated = False
if _should_update_value(current_client_id, client_id) or _should_update_value(current_client_secret, client_secret):
credentials_updated = True
_CREDENTIALS_MANAGER.update_credentials_file(client_id, client_secret)
_CREDENTIALS_MANAGER.update_credentials(client_id, client_secret)

if config_updated:
click.echo(_get_urls_update_result_message())
Expand Down
42 changes: 31 additions & 11 deletions cycode/cli/user_settings/credentials_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@

from cycode.cli.config import CYCODE_CLIENT_ID_ENV_VAR_NAME, CYCODE_CLIENT_SECRET_ENV_VAR_NAME
from cycode.cli.user_settings.base_file_manager import BaseFileManager
from cycode.cli.utils.yaml_utils import read_file
from cycode.cli.user_settings.jwt_creator import JwtCreator


class CredentialsManager(BaseFileManager):
HOME_PATH: str = Path.home()
CYCODE_HIDDEN_DIRECTORY: str = '.cycode'
FILE_NAME: str = 'credentials.yaml'

CLIENT_ID_FIELD_NAME: str = 'cycode_client_id'
CLIENT_SECRET_FIELD_NAME: str = 'cycode_client_secret'
ACCESS_TOKEN_FIELD_NAME: str = 'cycode_access_token'
ACCESS_TOKEN_EXPIRES_IN_FIELD_NAME: str = 'cycode_access_token_expires_in'
ACCESS_TOKEN_CREATOR_FIELD_NAME: str = 'cycode_access_token_creator'

def get_credentials(self) -> Tuple[str, str]:
client_id, client_secret = self.get_credentials_from_environment_variables()
Expand All @@ -28,21 +32,37 @@ def get_credentials_from_environment_variables() -> Tuple[str, str]:
return client_id, client_secret

def get_credentials_from_file(self) -> Tuple[Optional[str], Optional[str]]:
credentials_filename = self.get_filename()
try:
file_content = read_file(credentials_filename)
except FileNotFoundError:
return None, None

file_content = self.read_file()
client_id = file_content.get(self.CLIENT_ID_FIELD_NAME)
client_secret = file_content.get(self.CLIENT_SECRET_FIELD_NAME)
return client_id, client_secret

def update_credentials_file(self, client_id: str, client_secret: str) -> None:
credentials = {self.CLIENT_ID_FIELD_NAME: client_id, self.CLIENT_SECRET_FIELD_NAME: client_secret}
def update_credentials(self, client_id: str, client_secret: str) -> None:
file_content_to_update = {self.CLIENT_ID_FIELD_NAME: client_id, self.CLIENT_SECRET_FIELD_NAME: client_secret}
self.write_content_to_file(file_content_to_update)

def get_access_token(self) -> Tuple[Optional[str], Optional[float], Optional[JwtCreator]]:
file_content = self.read_file()

access_token = file_content.get(self.ACCESS_TOKEN_FIELD_NAME)
expires_in = file_content.get(self.ACCESS_TOKEN_EXPIRES_IN_FIELD_NAME)

creator = None
hashed_creator = file_content.get(self.ACCESS_TOKEN_CREATOR_FIELD_NAME)
if hashed_creator:
creator = JwtCreator(hashed_creator)

return access_token, expires_in, creator

self.get_filename()
self.write_content_to_file(credentials)
def update_access_token(
self, access_token: Optional[str], expires_in: Optional[float], creator: Optional[JwtCreator]
) -> None:
file_content_to_update = {
self.ACCESS_TOKEN_FIELD_NAME: access_token,
self.ACCESS_TOKEN_EXPIRES_IN_FIELD_NAME: expires_in,
self.ACCESS_TOKEN_CREATOR_FIELD_NAME: str(creator) if creator else None,
}
self.write_content_to_file(file_content_to_update)

def get_filename(self) -> str:
return os.path.join(self.HOME_PATH, self.CYCODE_HIDDEN_DIRECTORY, self.FILE_NAME)
24 changes: 24 additions & 0 deletions cycode/cli/user_settings/jwt_creator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from cycode.cli.utils.string_utils import hash_string_to_sha256

_SEPARATOR = '::'


def _get_hashed_creator(client_id: str, client_secret: str) -> str:
return hash_string_to_sha256(_SEPARATOR.join([client_id, client_secret]))


class JwtCreator:
def __init__(self, hashed_creator: str) -> None:
self._hashed_creator = hashed_creator

def __str__(self) -> str:
return self._hashed_creator

@classmethod
def create(cls, client_id: str, client_secret: str) -> 'JwtCreator':
return cls(_get_hashed_creator(client_id, client_secret))

def __eq__(self, other: 'JwtCreator') -> bool:
if not isinstance(other, JwtCreator):
return NotImplemented
return str(self) == str(other)
69 changes: 52 additions & 17 deletions cycode/cyclient/cycode_token_based_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,52 @@
from typing import Optional

import arrow
from requests import Response

from .cycode_client import CycodeClient
from cycode.cli.user_settings.credentials_manager import CredentialsManager
from cycode.cli.user_settings.jwt_creator import JwtCreator
from cycode.cyclient.cycode_client import CycodeClient


class CycodeTokenBasedClient(CycodeClient):
"""Send requests with api token"""
"""Send requests with JWT."""

def __init__(self, client_id: str, client_secret: str) -> None:
super().__init__()
self.client_secret = client_secret
self.client_id = client_id

self._api_token = None
self._expires_in = None
self._credentials_manager = CredentialsManager()
# load cached access token
access_token, expires_in, creator = self._credentials_manager.get_access_token()

self._access_token = self._expires_in = None
if creator == JwtCreator.create(client_id, client_secret):
# we must be sure that cached access token is created using the same client id and client secret.
# because client id and client secret could be passed via command, via env vars or via config file.
# we must not use cached access token if client id or client secret was changed.
self._access_token = access_token
self._expires_in = arrow.get(expires_in) if expires_in else None

self._lock = Lock()

self.lock = Lock()
def get_access_token(self) -> str:
with self._lock:
self.refresh_access_token_if_needed()
return self._access_token

def invalidate_access_token(self, in_storage: bool = False) -> None:
self._access_token = None
self._expires_in = None

@property
def api_token(self) -> str:
# TODO(MarshalX): This property performs HTTP request to refresh the token. This must be the method.
with self.lock:
self.refresh_api_token_if_needed()
return self._api_token
if in_storage:
self._credentials_manager.update_access_token(None, None, None)

def refresh_api_token_if_needed(self) -> None:
if self._api_token is None or self._expires_in is None or arrow.utcnow() >= self._expires_in:
self.refresh_api_token()
def refresh_access_token_if_needed(self) -> None:
if self._access_token is None or self._expires_in is None or arrow.utcnow() >= self._expires_in:
self.refresh_access_token()

def refresh_api_token(self) -> None:
def refresh_access_token(self) -> None:
auth_response = self.post(
url_path='api/v1/auth/api-token',
body={'clientId': self.client_id, 'secret': self.client_secret},
Expand All @@ -39,9 +56,12 @@ def refresh_api_token(self) -> None:
)
auth_response_data = auth_response.json()

self._api_token = auth_response_data['token']
self._access_token = auth_response_data['token']
self._expires_in = arrow.utcnow().shift(seconds=auth_response_data['expires_in'] * 0.8)

jwt_creator = JwtCreator.create(self.client_id, self.client_secret)
self._credentials_manager.update_access_token(self._access_token, self._expires_in.timestamp(), jwt_creator)

def get_request_headers(self, additional_headers: Optional[dict] = None, without_auth: bool = False) -> dict:
headers = super().get_request_headers(additional_headers=additional_headers)

Expand All @@ -51,5 +71,20 @@ def get_request_headers(self, additional_headers: Optional[dict] = None, without
return headers

def _add_auth_header(self, headers: dict) -> dict:
headers['Authorization'] = f'Bearer {self.api_token}'
headers['Authorization'] = f'Bearer {self.get_access_token()}'
return headers

def _execute(
self,
*args,
**kwargs,
) -> Response:
response = super()._execute(*args, **kwargs)

# backend returns 200 and plain text. no way to catch it with .raise_for_status()
if response.status_code == 200 and response.content in {b'Invalid JWT Token\n\n', b'JWT Token Needed\n\n'}:
# if cached token is invalid, try to refresh it and retry the request
self.refresh_access_token()
response = super()._execute(*args, **kwargs)

return response
12 changes: 6 additions & 6 deletions tests/cli/commands/configure/test_configure_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_configure_command_no_exist_values_in_file(mocker: 'MockerFixture') -> N
)

mocked_update_credentials = mocker.patch(
'cycode.cli.user_settings.credentials_manager.CredentialsManager.update_credentials_file'
'cycode.cli.user_settings.credentials_manager.CredentialsManager.update_credentials'
)
mocked_update_api_base_url = mocker.patch(
'cycode.cli.user_settings.config_file_manager.ConfigFileManager.update_api_base_url'
Expand Down Expand Up @@ -80,7 +80,7 @@ def test_configure_command_update_current_configs_in_files(mocker: 'MockerFixtur
)

mocked_update_credentials = mocker.patch(
'cycode.cli.user_settings.credentials_manager.CredentialsManager.update_credentials_file'
'cycode.cli.user_settings.credentials_manager.CredentialsManager.update_credentials'
)
mocked_update_api_base_url = mocker.patch(
'cycode.cli.user_settings.config_file_manager.ConfigFileManager.update_api_base_url'
Expand Down Expand Up @@ -110,7 +110,7 @@ def test_set_credentials_update_only_client_id(mocker: 'MockerFixture') -> None:
# side effect - multiple return values, each item in the list represents return of a call
mocker.patch('click.prompt', side_effect=['', '', client_id_user_input, ''])
mocked_update_credentials = mocker.patch(
'cycode.cli.user_settings.credentials_manager.CredentialsManager.update_credentials_file'
'cycode.cli.user_settings.credentials_manager.CredentialsManager.update_credentials'
)

# Act
Expand All @@ -133,7 +133,7 @@ def test_configure_command_update_only_client_secret(mocker: 'MockerFixture') ->
# side effect - multiple return values, each item in the list represents return of a call
mocker.patch('click.prompt', side_effect=['', '', '', client_secret_user_input])
mocked_update_credentials = mocker.patch(
'cycode.cli.user_settings.credentials_manager.CredentialsManager.update_credentials_file'
'cycode.cli.user_settings.credentials_manager.CredentialsManager.update_credentials'
)

# Act
Expand Down Expand Up @@ -166,7 +166,7 @@ def test_configure_command_update_only_api_url(mocker: 'MockerFixture') -> None:
mocked_update_api_base_url.assert_called_once_with(api_url_user_input)


def test_configure_command_should_not_update_credentials_file(mocker: 'MockerFixture') -> None:
def test_configure_command_should_not_update_credentials(mocker: 'MockerFixture') -> None:
# Arrange
client_id_user_input = ''
client_secret_user_input = ''
Expand All @@ -179,7 +179,7 @@ def test_configure_command_should_not_update_credentials_file(mocker: 'MockerFix
# side effect - multiple return values, each item in the list represents return of a call
mocker.patch('click.prompt', side_effect=['', '', client_id_user_input, client_secret_user_input])
mocked_update_credentials = mocker.patch(
'cycode.cli.user_settings.credentials_manager.CredentialsManager.update_credentials_file'
'cycode.cli.user_settings.credentials_manager.CredentialsManager.update_credentials'
)

# Act
Expand Down
19 changes: 17 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from pathlib import Path
from typing import Optional

import pytest
import responses

from cycode.cli.user_settings.credentials_manager import CredentialsManager
from cycode.cyclient.client_creator import create_scan_client
from cycode.cyclient.cycode_token_based_client import CycodeTokenBasedClient
from cycode.cyclient.scan_client import ScanClient
Expand All @@ -29,9 +31,22 @@ def scan_client() -> ScanClient:
return create_scan_client(_CLIENT_ID, _CLIENT_SECRET, hide_response_log=False)


def create_token_based_client(
client_id: Optional[str] = None, client_secret: Optional[str] = None
) -> CycodeTokenBasedClient:
CredentialsManager.FILE_NAME = 'unit-tests-credentials.yaml'

if client_id is None:
client_id = _CLIENT_ID
if client_secret is None:
client_secret = _CLIENT_SECRET

return CycodeTokenBasedClient(client_id, client_secret)


@pytest.fixture(scope='session')
def token_based_client() -> CycodeTokenBasedClient:
return CycodeTokenBasedClient(_CLIENT_ID, _CLIENT_SECRET)
return create_token_based_client()


@pytest.fixture(scope='session')
Expand All @@ -57,4 +72,4 @@ def api_token_response(api_token_url: str) -> responses.Response:
@responses.activate
def api_token(token_based_client: CycodeTokenBasedClient, api_token_response: responses.Response) -> str:
responses.add(api_token_response)
return token_based_client.api_token
return token_based_client.get_access_token()
Loading