diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index 73fcf08a..a7f86c2b 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -45,7 +45,7 @@ jobs: run: poetry install - name: Run linter check - run: poetry run ruff check . + run: poetry run ruff check --output-format=github . - name: Run code style check run: poetry run ruff format --check . diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..a3e324ca --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,164 @@ +![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/cycodehq/cycode-cli/tests.yml) +![PyPI - Version](https://img.shields.io/pypi/v/cycode) +![GitHub License](https://img.shields.io/github/license/cycodehq/cycode-cli) + +## How to contribute to Cycode CLI + +The minimum version of Python that we support is 3.7. +We recommend using this version for local development. +But it’s fine to use a higher version without using new features from these versions. +We prefer 3.8 because it comes with the support of Apple Silicon, and it is as low as possible. + +The project is under Poetry project management. +To deal with it, you should install it on your system: + +Install Poetry (feel free to use Brew, etc): + +```shell +curl -sSL https://install.python-poetry.org | python - -y +``` + +Add Poetry to PATH if required. + +Add a plugin to support dynamic versioning from Git Tags: + +```shell +poetry self add "poetry-dynamic-versioning[plugin]" +``` + +Install dependencies of the project: + +```shell +poetry install +``` + +Check that the version is valid (not 0.0.0): + +```shell +poetry version +``` + +You are ready to write code! + +To run the project use: + +```shell +poetry run cycode +``` + +or main entry point in an activated virtual environment: + +```shell +python cycode/cli/main.py +``` + +### Code linting and formatting + +We use `ruff` and `ruff format`. +It is configured well, so you don’t need to do anything. +You can see all enabled rules in the `pyproject.toml` file. +Both tests and the main codebase are checked. +Try to avoid type annotations like `Any`, etc. + +GitHub Actions will check that your code is formatted well. You can run it locally: + +```shell +# lint +poetry run ruff . +# format +poetry run ruff format . +``` + +Many rules support auto-fixing. You can run it with the `--fix` flag. + +### Branching and versioning + +We use the `main` branch as the main one. +All development should be done in feature branches. +When you are ready create a Pull Request to the `main` branch. + +Each commit in the `main` branch will be built and published to PyPI as a pre-release! +Such builds could be installed with the `--pre` flag. For example: + +```shell +pip install --pre cycode +``` + +Also, you can select a specific version of the pre-release: + +```shell +pip install cycode==1.7.2.dev6 +``` + +We are using [Semantic Versioning](https://semver.org/) and the version is generated automatically from Git Tags. So, +when you are ready to release a new version, you should create a new Git Tag. The version will be generated from it. + +Pre-release versions are generated on distance from the latest Git Tag. For example, if the latest Git Tag is `1.7.2`, +then the next pre-release version will be `1.7.2.dev1`. + +We are using GitHub Releases to create Git Tags with changelogs. +For changelogs, we are using a standard template +of [Automatically generated release notes](https://docs.github.com/en/repositories/releasing-projects-on-github/automatically-generated-release-notes). + +### Testing + +We are using `pytest` for testing. You can run tests with: + +```shell +poetry run pytest +``` + +The library used for sending requests is [requests](https://github.com/psf/requests). +To mock requests, we are using the [responses](https://github.com/getsentry/responses) library. +All requests must be mocked. + +To see the code coverage of the project, you can run: + +```shell +poetry run coverage run -m pytest . +``` + +To generate the HTML report, you can run: + +```shell +poetry run coverage html +``` + +The report will be generated in the `htmlcov` folder. + +### Documentation + +Keep [README.md](README.md) up to date. +All CLI commands are documented automatically if you add a docstring to the command. +Clean up the changelog before release. + +### Publishing + +New versions are published automatically on the new GitHub Release. +It uses the OpenID Connect publishing mechanism to upload on PyPI. + +[Homebrew formula](https://formulae.brew.sh/formula/cycode) is updated automatically on the new PyPI release. + +The CLI is also distributed as executable files for Linux, macOS, and Windows. +It is powered by [PyInstaller](https://pyinstaller.org/) and the process is automated by GitHub Actions. +These executables are attached to GitHub Releases as assets. + +To pack the project locally, you should run: + +```shell +poetry build +``` + +It will create a `dist` folder with the package (sdist and wheel). You can install it locally: + +```shell +pip install dist/cycode-{version}-py3-none-any.whl +``` + +To create an executable file locally, you should run: + +```shell +poetry run pyinstaller pyinstaller.spec +``` + +It will create an executable file for **the current platform** in the `dist` folder. diff --git a/cycode/cli/commands/auth/auth_command.py b/cycode/cli/commands/auth/auth_command.py index 51eb212b..87171441 100644 --- a/cycode/cli/commands/auth/auth_command.py +++ b/cycode/cli/commands/auth/auth_command.py @@ -50,7 +50,7 @@ def authorization_check(context: click.Context) -> None: return try: - if CycodeTokenBasedClient(client_id, client_secret).api_token: + if CycodeTokenBasedClient(client_id, client_secret).get_access_token(): printer.print_result(passed_auth_check_res) return except (NetworkError, HttpUnauthorizedError): diff --git a/cycode/cli/commands/auth/auth_manager.py b/cycode/cli/commands/auth/auth_manager.py index 11fbf751..829164c2 100644 --- a/cycode/cli/commands/auth/auth_manager.py +++ b/cycode/cli/commands/auth/auth_manager.py @@ -75,7 +75,7 @@ def get_api_token_polling(self, session_id: str, code_verifier: str) -> 'ApiToke raise AuthProcessError('session expired') def save_api_token(self, api_token: 'ApiToken') -> None: - self.credentials_manager.update_credentials_file(api_token.client_id, api_token.secret) + self.credentials_manager.update_credentials(api_token.client_id, api_token.secret) def _build_login_url(self, code_challenge: str, session_id: str) -> str: app_url = self.configuration_manager.get_cycode_app_url() diff --git a/cycode/cli/commands/configure/configure_command.py b/cycode/cli/commands/configure/configure_command.py index 5f9bad0e..5fe695ac 100644 --- a/cycode/cli/commands/configure/configure_command.py +++ b/cycode/cli/commands/configure/configure_command.py @@ -48,7 +48,7 @@ def configure_command() -> None: credentials_updated = False if _should_update_value(current_client_id, client_id) or _should_update_value(current_client_secret, client_secret): credentials_updated = True - _CREDENTIALS_MANAGER.update_credentials_file(client_id, client_secret) + _CREDENTIALS_MANAGER.update_credentials(client_id, client_secret) if config_updated: click.echo(_get_urls_update_result_message()) diff --git a/cycode/cli/commands/scan/code_scanner.py b/cycode/cli/commands/scan/code_scanner.py index 4a444ded..baf5c66f 100644 --- a/cycode/cli/commands/scan/code_scanner.py +++ b/cycode/cli/commands/scan/code_scanner.py @@ -143,11 +143,13 @@ def _get_scan_documents_thread_func( severity_threshold = context.obj['severity_threshold'] command_scan_type = context.info_name + scan_parameters['aggregation_id'] = str(_generate_unique_id()) + def _scan_batch_thread_func(batch: List[Document]) -> Tuple[str, CliError, LocalScanResult]: local_scan_result = error = error_message = None detections_count = relevant_detections_count = zip_file_size = 0 - scan_id = str(_get_scan_id()) + scan_id = str(_generate_unique_id()) scan_completed = False should_use_scan_service = _should_use_scan_service(scan_type, scan_parameters) @@ -280,6 +282,9 @@ def scan_documents( is_commit_range: bool = False, scan_parameters: Optional[dict] = None, ) -> None: + if not scan_parameters: + scan_parameters = get_default_scan_parameters(context) + progress_bar = context.obj['progress_bar'] if not documents_to_scan: @@ -320,7 +325,7 @@ def scan_commit_range_documents( local_scan_result = error_message = None scan_completed = False - scan_id = str(_get_scan_id()) + scan_id = str(_generate_unique_id()) should_use_scan_service = _should_use_scan_service(scan_type, scan_parameters) from_commit_zipped_documents = InMemoryZip() to_commit_zipped_documents = InMemoryZip() @@ -393,7 +398,6 @@ def scan_commit_range_documents( zip_file_size, scan_command_type, error_message, - should_use_scan_service, ) @@ -614,6 +618,11 @@ def get_default_scan_parameters(context: click.Context) -> dict: def get_scan_parameters(context: click.Context, paths: Tuple[str]) -> dict: scan_parameters = get_default_scan_parameters(context) + if not paths: + return scan_parameters + + scan_parameters['paths'] = paths + if len(paths) != 1: # ignore remote url if multiple paths are provided return scan_parameters @@ -622,11 +631,7 @@ def get_scan_parameters(context: click.Context, paths: Tuple[str]) -> dict: if remote_url: # TODO(MarshalX): remove hardcode in context context.obj['remote_url'] = remote_url - scan_parameters.update( - { - 'remote_url': remote_url, - } - ) + scan_parameters['remote_url'] = remote_url return scan_parameters @@ -788,7 +793,7 @@ def _report_scan_status( logger.debug('Failed to report scan status, %s', {'exception_message': str(e)}) -def _get_scan_id() -> UUID: +def _generate_unique_id() -> UUID: return uuid4() diff --git a/cycode/cli/user_settings/credentials_manager.py b/cycode/cli/user_settings/credentials_manager.py index 02653f6d..c302fc96 100644 --- a/cycode/cli/user_settings/credentials_manager.py +++ b/cycode/cli/user_settings/credentials_manager.py @@ -4,15 +4,19 @@ from cycode.cli.config import CYCODE_CLIENT_ID_ENV_VAR_NAME, CYCODE_CLIENT_SECRET_ENV_VAR_NAME from cycode.cli.user_settings.base_file_manager import BaseFileManager -from cycode.cli.utils.yaml_utils import read_file +from cycode.cli.user_settings.jwt_creator import JwtCreator class CredentialsManager(BaseFileManager): HOME_PATH: str = Path.home() CYCODE_HIDDEN_DIRECTORY: str = '.cycode' FILE_NAME: str = 'credentials.yaml' + CLIENT_ID_FIELD_NAME: str = 'cycode_client_id' CLIENT_SECRET_FIELD_NAME: str = 'cycode_client_secret' + ACCESS_TOKEN_FIELD_NAME: str = 'cycode_access_token' + ACCESS_TOKEN_EXPIRES_IN_FIELD_NAME: str = 'cycode_access_token_expires_in' + ACCESS_TOKEN_CREATOR_FIELD_NAME: str = 'cycode_access_token_creator' def get_credentials(self) -> Tuple[str, str]: client_id, client_secret = self.get_credentials_from_environment_variables() @@ -28,21 +32,37 @@ def get_credentials_from_environment_variables() -> Tuple[str, str]: return client_id, client_secret def get_credentials_from_file(self) -> Tuple[Optional[str], Optional[str]]: - credentials_filename = self.get_filename() - try: - file_content = read_file(credentials_filename) - except FileNotFoundError: - return None, None - + file_content = self.read_file() client_id = file_content.get(self.CLIENT_ID_FIELD_NAME) client_secret = file_content.get(self.CLIENT_SECRET_FIELD_NAME) return client_id, client_secret - def update_credentials_file(self, client_id: str, client_secret: str) -> None: - credentials = {self.CLIENT_ID_FIELD_NAME: client_id, self.CLIENT_SECRET_FIELD_NAME: client_secret} + def update_credentials(self, client_id: str, client_secret: str) -> None: + file_content_to_update = {self.CLIENT_ID_FIELD_NAME: client_id, self.CLIENT_SECRET_FIELD_NAME: client_secret} + self.write_content_to_file(file_content_to_update) + + def get_access_token(self) -> Tuple[Optional[str], Optional[float], Optional[JwtCreator]]: + file_content = self.read_file() + + access_token = file_content.get(self.ACCESS_TOKEN_FIELD_NAME) + expires_in = file_content.get(self.ACCESS_TOKEN_EXPIRES_IN_FIELD_NAME) + + creator = None + hashed_creator = file_content.get(self.ACCESS_TOKEN_CREATOR_FIELD_NAME) + if hashed_creator: + creator = JwtCreator(hashed_creator) + + return access_token, expires_in, creator - self.get_filename() - self.write_content_to_file(credentials) + def update_access_token( + self, access_token: Optional[str], expires_in: Optional[float], creator: Optional[JwtCreator] + ) -> None: + file_content_to_update = { + self.ACCESS_TOKEN_FIELD_NAME: access_token, + self.ACCESS_TOKEN_EXPIRES_IN_FIELD_NAME: expires_in, + self.ACCESS_TOKEN_CREATOR_FIELD_NAME: str(creator) if creator else None, + } + self.write_content_to_file(file_content_to_update) def get_filename(self) -> str: return os.path.join(self.HOME_PATH, self.CYCODE_HIDDEN_DIRECTORY, self.FILE_NAME) diff --git a/cycode/cli/user_settings/jwt_creator.py b/cycode/cli/user_settings/jwt_creator.py new file mode 100644 index 00000000..e3778f92 --- /dev/null +++ b/cycode/cli/user_settings/jwt_creator.py @@ -0,0 +1,24 @@ +from cycode.cli.utils.string_utils import hash_string_to_sha256 + +_SEPARATOR = '::' + + +def _get_hashed_creator(client_id: str, client_secret: str) -> str: + return hash_string_to_sha256(_SEPARATOR.join([client_id, client_secret])) + + +class JwtCreator: + def __init__(self, hashed_creator: str) -> None: + self._hashed_creator = hashed_creator + + def __str__(self) -> str: + return self._hashed_creator + + @classmethod + def create(cls, client_id: str, client_secret: str) -> 'JwtCreator': + return cls(_get_hashed_creator(client_id, client_secret)) + + def __eq__(self, other: 'JwtCreator') -> bool: + if not isinstance(other, JwtCreator): + return NotImplemented + return str(self) == str(other) diff --git a/cycode/cyclient/__init__.py b/cycode/cyclient/__init__.py index 7018a231..9bea26e9 100644 --- a/cycode/cyclient/__init__.py +++ b/cycode/cyclient/__init__.py @@ -1,4 +1,4 @@ -from .config import logger +from cycode.cyclient.config import logger __all__ = [ 'logger', diff --git a/cycode/cyclient/auth_client.py b/cycode/cyclient/auth_client.py index 626d4ff9..91f43ad1 100644 --- a/cycode/cyclient/auth_client.py +++ b/cycode/cyclient/auth_client.py @@ -3,9 +3,8 @@ from requests import Response from cycode.cli.exceptions.custom_exceptions import HttpUnauthorizedError, NetworkError - -from . import models -from .cycode_client import CycodeClient +from cycode.cyclient import models +from cycode.cyclient.cycode_client import CycodeClient class AuthClient: diff --git a/cycode/cyclient/cycode_client.py b/cycode/cyclient/cycode_client.py index dfbd2269..eded92da 100644 --- a/cycode/cyclient/cycode_client.py +++ b/cycode/cyclient/cycode_client.py @@ -1,5 +1,5 @@ -from . import config -from .cycode_client_base import CycodeClientBase +from cycode.cyclient import config +from cycode.cyclient.cycode_client_base import CycodeClientBase class CycodeClient(CycodeClientBase): diff --git a/cycode/cyclient/cycode_client_base.py b/cycode/cyclient/cycode_client_base.py index d804b8cb..a1fb68bb 100644 --- a/cycode/cyclient/cycode_client_base.py +++ b/cycode/cyclient/cycode_client_base.py @@ -6,9 +6,7 @@ from cycode import __version__ from cycode.cli.exceptions.custom_exceptions import HttpUnauthorizedError, NetworkError from cycode.cli.user_settings.configuration_manager import ConfigurationManager -from cycode.cyclient import logger - -from . import config +from cycode.cyclient import config, logger def get_cli_user_agent() -> str: diff --git a/cycode/cyclient/cycode_dev_based_client.py b/cycode/cyclient/cycode_dev_based_client.py index f325bd6e..347797c3 100644 --- a/cycode/cyclient/cycode_dev_based_client.py +++ b/cycode/cyclient/cycode_dev_based_client.py @@ -1,7 +1,7 @@ from typing import Dict, Optional -from .config import dev_tenant_id -from .cycode_client_base import CycodeClientBase +from cycode.cyclient.config import dev_tenant_id +from cycode.cyclient.cycode_client_base import CycodeClientBase """ Send requests with api token diff --git a/cycode/cyclient/cycode_token_based_client.py b/cycode/cyclient/cycode_token_based_client.py index c73999fb..d13ce62d 100644 --- a/cycode/cyclient/cycode_token_based_client.py +++ b/cycode/cyclient/cycode_token_based_client.py @@ -2,35 +2,52 @@ from typing import Optional import arrow +from requests import Response -from .cycode_client import CycodeClient +from cycode.cli.user_settings.credentials_manager import CredentialsManager +from cycode.cli.user_settings.jwt_creator import JwtCreator +from cycode.cyclient.cycode_client import CycodeClient class CycodeTokenBasedClient(CycodeClient): - """Send requests with api token""" + """Send requests with JWT.""" def __init__(self, client_id: str, client_secret: str) -> None: super().__init__() self.client_secret = client_secret self.client_id = client_id - self._api_token = None - self._expires_in = None + self._credentials_manager = CredentialsManager() + # load cached access token + access_token, expires_in, creator = self._credentials_manager.get_access_token() + + self._access_token = self._expires_in = None + if creator == JwtCreator.create(client_id, client_secret): + # we must be sure that cached access token is created using the same client id and client secret. + # because client id and client secret could be passed via command, via env vars or via config file. + # we must not use cached access token if client id or client secret was changed. + self._access_token = access_token + self._expires_in = arrow.get(expires_in) if expires_in else None + + self._lock = Lock() - self.lock = Lock() + def get_access_token(self) -> str: + with self._lock: + self.refresh_access_token_if_needed() + return self._access_token + + def invalidate_access_token(self, in_storage: bool = False) -> None: + self._access_token = None + self._expires_in = None - @property - def api_token(self) -> str: - # TODO(MarshalX): This property performs HTTP request to refresh the token. This must be the method. - with self.lock: - self.refresh_api_token_if_needed() - return self._api_token + if in_storage: + self._credentials_manager.update_access_token(None, None, None) - def refresh_api_token_if_needed(self) -> None: - if self._api_token is None or self._expires_in is None or arrow.utcnow() >= self._expires_in: - self.refresh_api_token() + def refresh_access_token_if_needed(self) -> None: + if self._access_token is None or self._expires_in is None or arrow.utcnow() >= self._expires_in: + self.refresh_access_token() - def refresh_api_token(self) -> None: + def refresh_access_token(self) -> None: auth_response = self.post( url_path='api/v1/auth/api-token', body={'clientId': self.client_id, 'secret': self.client_secret}, @@ -39,9 +56,12 @@ def refresh_api_token(self) -> None: ) auth_response_data = auth_response.json() - self._api_token = auth_response_data['token'] + self._access_token = auth_response_data['token'] self._expires_in = arrow.utcnow().shift(seconds=auth_response_data['expires_in'] * 0.8) + jwt_creator = JwtCreator.create(self.client_id, self.client_secret) + self._credentials_manager.update_access_token(self._access_token, self._expires_in.timestamp(), jwt_creator) + def get_request_headers(self, additional_headers: Optional[dict] = None, without_auth: bool = False) -> dict: headers = super().get_request_headers(additional_headers=additional_headers) @@ -51,5 +71,20 @@ def get_request_headers(self, additional_headers: Optional[dict] = None, without return headers def _add_auth_header(self, headers: dict) -> dict: - headers['Authorization'] = f'Bearer {self.api_token}' + headers['Authorization'] = f'Bearer {self.get_access_token()}' return headers + + def _execute( + self, + *args, + **kwargs, + ) -> Response: + response = super()._execute(*args, **kwargs) + + # backend returns 200 and plain text. no way to catch it with .raise_for_status() + if response.status_code == 200 and response.content in {b'Invalid JWT Token\n\n', b'JWT Token Needed\n\n'}: + # if cached token is invalid, try to refresh it and retry the request + self.refresh_access_token() + response = super()._execute(*args, **kwargs) + + return response diff --git a/cycode/cyclient/scan_client.py b/cycode/cyclient/scan_client.py index 74c23f83..e7f2cafe 100644 --- a/cycode/cyclient/scan_client.py +++ b/cycode/cyclient/scan_client.py @@ -10,7 +10,7 @@ from cycode.cyclient.cycode_client_base import CycodeClientBase if TYPE_CHECKING: - from .scan_config_base import ScanConfigBase + from cycode.cyclient.scan_config_base import ScanConfigBase class ScanClient: diff --git a/pyproject.toml b/pyproject.toml index 87bdd11f..40e28317 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,6 +113,9 @@ docstring-quotes = "double" multiline-quotes = "double" inline-quotes = "single" +[tool.ruff.lint.flake8-tidy-imports] +ban-relative-imports = "all" + [tool.ruff.per-file-ignores] "tests/*.py" = ["S101", "S105"] "cycode/*.py" = ["BLE001"] diff --git a/tests/cli/commands/configure/test_configure_command.py b/tests/cli/commands/configure/test_configure_command.py index 4c42971b..c5ae2b9c 100644 --- a/tests/cli/commands/configure/test_configure_command.py +++ b/tests/cli/commands/configure/test_configure_command.py @@ -35,7 +35,7 @@ def test_configure_command_no_exist_values_in_file(mocker: 'MockerFixture') -> N ) mocked_update_credentials = mocker.patch( - 'cycode.cli.user_settings.credentials_manager.CredentialsManager.update_credentials_file' + 'cycode.cli.user_settings.credentials_manager.CredentialsManager.update_credentials' ) mocked_update_api_base_url = mocker.patch( 'cycode.cli.user_settings.config_file_manager.ConfigFileManager.update_api_base_url' @@ -80,7 +80,7 @@ def test_configure_command_update_current_configs_in_files(mocker: 'MockerFixtur ) mocked_update_credentials = mocker.patch( - 'cycode.cli.user_settings.credentials_manager.CredentialsManager.update_credentials_file' + 'cycode.cli.user_settings.credentials_manager.CredentialsManager.update_credentials' ) mocked_update_api_base_url = mocker.patch( 'cycode.cli.user_settings.config_file_manager.ConfigFileManager.update_api_base_url' @@ -110,7 +110,7 @@ def test_set_credentials_update_only_client_id(mocker: 'MockerFixture') -> None: # side effect - multiple return values, each item in the list represents return of a call mocker.patch('click.prompt', side_effect=['', '', client_id_user_input, '']) mocked_update_credentials = mocker.patch( - 'cycode.cli.user_settings.credentials_manager.CredentialsManager.update_credentials_file' + 'cycode.cli.user_settings.credentials_manager.CredentialsManager.update_credentials' ) # Act @@ -133,7 +133,7 @@ def test_configure_command_update_only_client_secret(mocker: 'MockerFixture') -> # side effect - multiple return values, each item in the list represents return of a call mocker.patch('click.prompt', side_effect=['', '', '', client_secret_user_input]) mocked_update_credentials = mocker.patch( - 'cycode.cli.user_settings.credentials_manager.CredentialsManager.update_credentials_file' + 'cycode.cli.user_settings.credentials_manager.CredentialsManager.update_credentials' ) # Act @@ -166,7 +166,7 @@ def test_configure_command_update_only_api_url(mocker: 'MockerFixture') -> None: mocked_update_api_base_url.assert_called_once_with(api_url_user_input) -def test_configure_command_should_not_update_credentials_file(mocker: 'MockerFixture') -> None: +def test_configure_command_should_not_update_credentials(mocker: 'MockerFixture') -> None: # Arrange client_id_user_input = '' client_secret_user_input = '' @@ -179,7 +179,7 @@ def test_configure_command_should_not_update_credentials_file(mocker: 'MockerFix # side effect - multiple return values, each item in the list represents return of a call mocker.patch('click.prompt', side_effect=['', '', client_id_user_input, client_secret_user_input]) mocked_update_credentials = mocker.patch( - 'cycode.cli.user_settings.credentials_manager.CredentialsManager.update_credentials_file' + 'cycode.cli.user_settings.credentials_manager.CredentialsManager.update_credentials' ) # Act diff --git a/tests/conftest.py b/tests/conftest.py index dc1a84fe..821a0289 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,10 @@ from pathlib import Path +from typing import Optional import pytest import responses +from cycode.cli.user_settings.credentials_manager import CredentialsManager from cycode.cyclient.client_creator import create_scan_client from cycode.cyclient.cycode_token_based_client import CycodeTokenBasedClient from cycode.cyclient.scan_client import ScanClient @@ -29,9 +31,22 @@ def scan_client() -> ScanClient: return create_scan_client(_CLIENT_ID, _CLIENT_SECRET, hide_response_log=False) +def create_token_based_client( + client_id: Optional[str] = None, client_secret: Optional[str] = None +) -> CycodeTokenBasedClient: + CredentialsManager.FILE_NAME = 'unit-tests-credentials.yaml' + + if client_id is None: + client_id = _CLIENT_ID + if client_secret is None: + client_secret = _CLIENT_SECRET + + return CycodeTokenBasedClient(client_id, client_secret) + + @pytest.fixture(scope='session') def token_based_client() -> CycodeTokenBasedClient: - return CycodeTokenBasedClient(_CLIENT_ID, _CLIENT_SECRET) + return create_token_based_client() @pytest.fixture(scope='session') @@ -57,4 +72,4 @@ def api_token_response(api_token_url: str) -> responses.Response: @responses.activate def api_token(token_based_client: CycodeTokenBasedClient, api_token_response: responses.Response) -> str: responses.add(api_token_response) - return token_based_client.api_token + return token_based_client.get_access_token() diff --git a/tests/cyclient/test_token_based_client.py b/tests/cyclient/test_token_based_client.py index b5d824f4..4c3dd4c5 100644 --- a/tests/cyclient/test_token_based_client.py +++ b/tests/cyclient/test_token_based_client.py @@ -2,30 +2,31 @@ import responses from cycode.cyclient.cycode_token_based_client import CycodeTokenBasedClient -from tests.conftest import _EXPECTED_API_TOKEN +from tests.conftest import _EXPECTED_API_TOKEN, create_token_based_client @responses.activate -def test_api_token_new(token_based_client: CycodeTokenBasedClient, api_token_response: responses.Response) -> None: +def test_access_token_new(token_based_client: CycodeTokenBasedClient, api_token_response: responses.Response) -> None: responses.add(api_token_response) - api_token = token_based_client.api_token + api_token = token_based_client.get_access_token() assert api_token == _EXPECTED_API_TOKEN @responses.activate -def test_api_token_expired(token_based_client: CycodeTokenBasedClient, api_token_response: responses.Response) -> None: +def test_access_token_expired( + token_based_client: CycodeTokenBasedClient, api_token_response: responses.Response +) -> None: responses.add(api_token_response) - # this property performs HTTP req to refresh the token. IDE doesn't know it - token_based_client.api_token # noqa: B018 + token_based_client.get_access_token() # mark token as expired token_based_client._expires_in = arrow.utcnow().shift(hours=-1) # refresh token - api_token_refreshed = token_based_client.api_token + api_token_refreshed = token_based_client.get_access_token() assert api_token_refreshed == _EXPECTED_API_TOKEN @@ -35,3 +36,63 @@ def test_get_request_headers(token_based_client: CycodeTokenBasedClient, api_tok expected_headers = {**token_based_client.MANDATORY_HEADERS, **token_based_headers} assert token_based_client.get_request_headers() == expected_headers + + +@responses.activate +def test_access_token_cached( + token_based_client: CycodeTokenBasedClient, api_token_response: responses.Response +) -> None: + # save to cache + responses.add(api_token_response) + token_based_client.get_access_token() + + # load from cache + client2 = create_token_based_client() + assert client2._access_token == token_based_client._access_token + assert client2._expires_in == token_based_client._expires_in + + +@responses.activate +def test_access_token_cached_creator_changed( + token_based_client: CycodeTokenBasedClient, api_token_response: responses.Response +) -> None: + # save to cache + responses.add(api_token_response) + token_based_client.get_access_token() + + # load from cache with another client id and client secret + client2 = create_token_based_client('client_id2', 'client_secret2') + assert client2._access_token is None + assert client2._expires_in is None + + +@responses.activate +def test_access_token_invalidation( + token_based_client: CycodeTokenBasedClient, api_token_response: responses.Response +) -> None: + # save to cache + responses.add(api_token_response) + token_based_client.get_access_token() + + expected_access_token = token_based_client._access_token + expected_expires_in = token_based_client._expires_in + + # invalidate in runtime + token_based_client.invalidate_access_token() + assert token_based_client._access_token is None + assert token_based_client._expires_in is None + + # load from cache + client2 = create_token_based_client() + assert client2._access_token == expected_access_token + assert client2._expires_in == expected_expires_in + + # invalidate in storage + client2.invalidate_access_token(in_storage=True) + assert client2._access_token is None + assert client2._expires_in is None + + # load from cache again + client3 = create_token_based_client() + assert client3._access_token is None + assert client3._expires_in is None