Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CM-34882 - Add one report URL for all secrets found in the same scan #228

Merged
merged 10 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 30 additions & 6 deletions cycode/cli/commands/scan/code_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,14 @@ def _enrich_scan_result_with_data_from_detection_rules(

def _get_scan_documents_thread_func(
context: click.Context, is_git_diff: bool, is_commit_range: bool, scan_parameters: dict
) -> Callable[[List[Document]], Tuple[str, CliError, LocalScanResult]]:
) -> Tuple[Callable[[List[Document]], Tuple[str, CliError, LocalScanResult]], str]:
cycode_client = context.obj['client']
scan_type = context.obj['scan_type']
severity_threshold = context.obj['severity_threshold']
sync_option = context.obj['sync']
command_scan_type = context.info_name

scan_parameters['aggregation_id'] = str(_generate_unique_id())
aggregation_id = str(_generate_unique_id())
scan_parameters['aggregation_id'] = aggregation_id

def _scan_batch_thread_func(batch: List[Document]) -> Tuple[str, CliError, LocalScanResult]:
local_scan_result = error = error_message = None
Expand Down Expand Up @@ -224,7 +224,7 @@ def _scan_batch_thread_func(batch: List[Document]) -> Tuple[str, CliError, Local

return scan_id, error, local_scan_result

return _scan_batch_thread_func
return _scan_batch_thread_func, aggregation_id


def scan_commit_range(
Expand Down Expand Up @@ -312,11 +312,16 @@ def scan_documents(
)
return

scan_batch_thread_func = _get_scan_documents_thread_func(context, is_git_diff, is_commit_range, scan_parameters)
scan_batch_thread_func, aggregation_id = _get_scan_documents_thread_func(
context, is_git_diff, is_commit_range, scan_parameters
)
errors, local_scan_results = run_parallel_batched_scan(
scan_batch_thread_func, documents_to_scan, progress_bar=progress_bar
)

aggregation_report_url = _try_get_aggregation_report_url_if_needed(
scan_parameters, context.obj['client'], context.obj['scan_type']
)
set_aggregation_report_url(context, aggregation_report_url)
progress_bar.set_section_length(ScanProgressBarSection.GENERATE_REPORT, 1)
progress_bar.update(ScanProgressBarSection.GENERATE_REPORT)
progress_bar.stop()
Expand All @@ -325,6 +330,25 @@ def scan_documents(
print_results(context, local_scan_results, errors)


def set_aggregation_report_url(context: click.Context, aggregation_report_url: str = '') -> None:
MarshalX marked this conversation as resolved.
Show resolved Hide resolved
context.obj['aggregation_report_url'] = aggregation_report_url


def _try_get_aggregation_report_url_if_needed(
scan_parameters: dict, cycode_client: 'ScanClient', scan_type: str
) -> Optional[str]:
aggregation_id = scan_parameters.get('aggregation_id')
if not scan_parameters.get('report'):
return None
if aggregation_id is None:
return None
try:
report_url_response = cycode_client.get_scan_aggregation_report_url(aggregation_id, scan_type)
return report_url_response.report_url
except Exception as e:
logger.debug('Failed to get aggregation report url: %s', str(e))


def scan_commit_range_documents(
context: click.Context,
from_documents_to_scan: List[Document],
Expand Down
6 changes: 4 additions & 2 deletions cycode/cli/printers/console_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@ def __init__(self, context: click.Context) -> None:
self.context = context
self.scan_type = self.context.obj.get('scan_type')
self.output_type = self.context.obj.get('output')

self.aggregation_report_url = self.context.obj.get('aggregation_report_url')
self._printer_class = self._AVAILABLE_PRINTERS.get(self.output_type)
if self._printer_class is None:
raise CycodeError(f'"{self.output_type}" output type is not supported.')

def print_scan_results(
self, local_scan_results: List['LocalScanResult'], errors: Optional[Dict[str, 'CliError']] = None
self,
local_scan_results: List['LocalScanResult'],
errors: Optional[Dict[str, 'CliError']] = None,
) -> None:
printer = self._get_scan_printer()
printer.print_scan_results(local_scan_results, errors)
Expand Down
6 changes: 4 additions & 2 deletions cycode/cli/printers/json_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@ def print_scan_results(
scan_ids = []
report_urls = []
detections = []
aggregation_report_url = self.context.obj.get('aggregation_report_url')
if aggregation_report_url:
report_urls.append(aggregation_report_url)

for local_scan_result in local_scan_results:
scan_ids.append(local_scan_result.scan_id)

if local_scan_result.report_url:
if not aggregation_report_url and local_scan_result.report_url:
report_urls.append(local_scan_result.report_url)

for document_detections in local_scan_result.document_detections:
detections.extend(document_detections.detections)

Expand Down
5 changes: 2 additions & 3 deletions cycode/cli/printers/tables/sca_table_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
if TYPE_CHECKING:
from cycode.cli.models import LocalScanResult


column_builder = ColumnInfoBuilder()

# Building must have strict order. Represents the order of the columns in the table (from left to right)
Expand All @@ -29,7 +28,6 @@
DIRECT_DEPENDENCY_COLUMN = column_builder.build(name='Direct Dependency')
DEVELOPMENT_DEPENDENCY_COLUMN = column_builder.build(name='Development Dependency')


COLUMN_WIDTHS_CONFIG: ColumnWidths = {
REPOSITORY_COLUMN: 2,
CODE_PROJECT_COLUMN: 2,
Expand All @@ -42,6 +40,7 @@

class ScaTablePrinter(TablePrinterBase):
def _print_results(self, local_scan_results: List['LocalScanResult']) -> None:
aggregation_report_url = self.context.obj.get('aggregation_report_url')
detections_per_policy_id = self._extract_detections_per_policy_id(local_scan_results)
for policy_id, detections in detections_per_policy_id.items():
table = self._get_table(policy_id)
Expand All @@ -53,7 +52,7 @@ def _print_results(self, local_scan_results: List['LocalScanResult']) -> None:
self._print_summary_issues(len(detections), self._get_title(policy_id))
self._print_table(table)

self._print_report_urls(local_scan_results)
self._print_report_urls(local_scan_results, aggregation_report_url)

@staticmethod
def _get_title(policy_id: str) -> str:
Expand Down
10 changes: 7 additions & 3 deletions cycode/cli/printers/tables/table_printer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, List, Optional

import click

Expand Down Expand Up @@ -51,7 +51,11 @@


class TablePrinter(TablePrinterBase):
def _print_results(self, local_scan_results: List['LocalScanResult']) -> None:
def _print_results(
self,
local_scan_results: List['LocalScanResult'],
aggregation_report_url: Optional[str] = None,
) -> None:
MarshalX marked this conversation as resolved.
Show resolved Hide resolved
table = self._get_table()
if self.scan_type in COLUMN_WIDTHS_CONFIG:
table.set_cols_width(COLUMN_WIDTHS_CONFIG[self.scan_type])
Expand All @@ -63,7 +67,7 @@ def _print_results(self, local_scan_results: List['LocalScanResult']) -> None:
self._enrich_table_with_values(table, detection, document_detections.document)

self._print_table(table)
self._print_report_urls(local_scan_results)
self._print_report_urls(local_scan_results, aggregation_report_url)

MarshalX marked this conversation as resolved.
Show resolved Hide resolved
def _get_table(self) -> Table:
table = Table()
Expand Down
10 changes: 8 additions & 2 deletions cycode/cli/printers/tables/table_printer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,15 @@ def _print_table(table: 'Table') -> None:
click.echo(table.get_table().draw())

@staticmethod
def _print_report_urls(local_scan_results: List['LocalScanResult']) -> None:
def _print_report_urls(
local_scan_results: List['LocalScanResult'],
aggregation_report_url: Optional[str] = None,
) -> None:
report_urls = [scan_result.report_url for scan_result in local_scan_results if scan_result.report_url]
if not report_urls:
if not report_urls and not aggregation_report_url:
return
if aggregation_report_url:
click.echo(f'Report URL: {aggregation_report_url}')
return

click.echo('Report URLs:')
Expand Down
32 changes: 19 additions & 13 deletions cycode/cli/printers/text_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@ def print_scan_results(

for local_scan_result in local_scan_results:
for document_detections in local_scan_result.document_detections:
self._print_document_detections(
document_detections, local_scan_result.scan_id, local_scan_result.report_url
)
self._print_document_detections(document_detections, local_scan_result.scan_id)

report_urls = [scan_result.report_url for scan_result in local_scan_results if scan_result.report_url]

self._print_report_urls(report_urls, self.context.obj.get('aggregation_report_url'))
if not errors:
return

Expand All @@ -55,27 +56,21 @@ def print_scan_results(
click.echo(f'- {scan_id}: ', nl=False)
self.print_error(error)

def _print_document_detections(
self, document_detections: DocumentDetections, scan_id: str, report_url: Optional[str]
) -> None:
def _print_document_detections(self, document_detections: DocumentDetections, scan_id: str) -> None:
document = document_detections.document
lines_to_display = self._get_lines_to_display_count()
for detection in document_detections.detections:
self._print_detection_summary(detection, document.path, scan_id, report_url)
self._print_detection_summary(detection, document.path, scan_id)
self._print_detection_code_segment(detection, document, lines_to_display)

def _print_detection_summary(
self, detection: Detection, document_path: str, scan_id: str, report_url: Optional[str]
) -> None:
def _print_detection_summary(self, detection: Detection, document_path: str, scan_id: str) -> None:
detection_name = detection.type if self.scan_type == SECRET_SCAN_TYPE else detection.message
detection_name_styled = click.style(detection_name, fg='bright_red', bold=True)

detection_sha = detection.detection_details.get('sha512')
detection_sha_message = f'\nSecret SHA: {detection_sha}' if detection_sha else ''

scan_id_message = f'\nScan ID: {scan_id}'
report_url_message = f'\nReport URL: {report_url}' if report_url else ''

detection_commit_id = detection.detection_details.get('commit_id')
detection_commit_id_message = f'\nCommit SHA: {detection_commit_id}' if detection_commit_id else ''

Expand All @@ -88,7 +83,6 @@ def _print_detection_summary(
f'(rule ID: {detection.detection_rule_id}) in file: {click.format_filename(document_path)} '
f'{detection_sha_message}'
f'{scan_id_message}'
f'{report_url_message}'
f'{detection_commit_id_message}'
f'{company_guidelines_message}'
f' ⛔'
Expand All @@ -101,6 +95,18 @@ def _print_detection_code_segment(self, detection: Detection, document: Document

self._print_detection_from_file(detection, document, code_segment_size)

@staticmethod
def _print_report_urls(report_urls: List[str], aggregation_report_url: Optional[str] = None) -> None:
    """Echo the report link(s) for a scan.

    When an aggregation report URL is given it is the only line printed;
    otherwise each URL in ``report_urls`` is listed. Nothing is printed when
    both are empty.
    """
    # The aggregated URL takes precedence over the per-scan URLs.
    if aggregation_report_url:
        click.echo(f'Report URL: {aggregation_report_url}')
        return

    if not report_urls:
        return

    click.echo('Report URLs:')
    for url in report_urls:
        click.echo(f'- {url}')

@staticmethod
def _get_code_segment_start_line(detection_line: int, code_segment_size: int) -> int:
start_line = detection_line - math.ceil(code_segment_size / 2)
Expand Down
14 changes: 13 additions & 1 deletion cycode/cyclient/scan_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def get_scan_service_url_path(
self, scan_type: str, should_use_scan_service: bool = False, should_use_sync_flow: bool = False
) -> str:
service_path = self.scan_config.get_service_name(scan_type, should_use_scan_service)
controller_path = self.get_scan_controller_path(scan_type)
controller_path = self.get_scan_controller_path(scan_type, should_use_scan_service)
flow_type = self.get_scan_flow_type(should_use_sync_flow)
return f'{service_path}/{controller_path}{flow_type}'

Expand Down Expand Up @@ -92,6 +92,12 @@ def get_scan_report_url(self, scan_id: str, scan_type: str) -> models.ScanReport
response = self.scan_cycode_client.get(url_path=self.get_scan_report_url_path(scan_id, scan_type))
return models.ScanReportUrlResponseSchema().build_dto(response.json())

def get_scan_aggregation_report_url(self, aggregation_id: str, scan_type: str) -> models.ScanReportUrlResponse:
    """GET the report URL for an aggregation ID and build the response DTO from the JSON body."""
    url_path = self.get_scan_aggregation_report_url_path(aggregation_id, scan_type)
    response = self.scan_cycode_client.get(url_path=url_path)
    return models.ScanReportUrlResponseSchema().build_dto(response.json())

def get_zipped_file_scan_async_url_path(self, scan_type: str, should_use_sync_flow: bool = False) -> str:
async_scan_type = self.scan_config.get_async_scan_type(scan_type)
async_entity_type = self.scan_config.get_async_entity_type(scan_type)
Expand Down Expand Up @@ -155,6 +161,12 @@ def get_scan_details_path(self, scan_type: str, scan_id: str) -> str:
def get_scan_report_url_path(self, scan_id: str, scan_type: str) -> str:
    """Return the scan-service URL path of the ``reportUrl`` endpoint for ``scan_id``."""
    service_url_path = self.get_scan_service_url_path(scan_type, should_use_scan_service=True)
    return f'{service_url_path}/reportUrl/{scan_id}'

def get_scan_aggregation_report_url_path(self, aggregation_id: str, scan_type: str) -> str:
    """Return the scan-service URL path of the ``reportUrlByAggregationId`` endpoint for ``aggregation_id``."""
    service_url_path = self.get_scan_service_url_path(scan_type, should_use_scan_service=True)
    return f'{service_url_path}/reportUrlByAggregationId/{aggregation_id}'

def get_scan_details(self, scan_type: str, scan_id: str) -> models.ScanDetailsResponse:
path = self.get_scan_details_path(scan_type, scan_id)
response = self.scan_cycode_client.get(url_path=path)
Expand Down
14 changes: 14 additions & 0 deletions tests/cyclient/mocked_responses/scan_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ def get_scan_report_url(scan_id: Optional[UUID], scan_client: ScanClient, scan_t
return f'{api_url}/{service_url}'


def get_scan_aggregation_report_url(aggregation_id: Optional[UUID], scan_client: ScanClient, scan_type: str) -> str:
    """Build the full URL the client requests for an aggregation report: API base URL plus endpoint path."""
    base_url = scan_client.scan_cycode_client.api_url
    endpoint_path = scan_client.get_scan_aggregation_report_url_path(str(aggregation_id), scan_type)
    return f'{base_url}/{endpoint_path}'


def get_scan_report_url_response(url: str, scan_id: Optional[UUID] = None) -> responses.Response:
if not scan_id:
scan_id = uuid4()
Expand All @@ -93,6 +99,14 @@ def get_scan_report_url_response(url: str, scan_id: Optional[UUID] = None) -> re
return responses.Response(method=responses.GET, url=url, json=json_response, status=200)


def get_scan_aggregation_report_url_response(url: str, aggregation_id: Optional[UUID] = None) -> responses.Response:
    """Create a mocked 200 GET response for ``url`` whose JSON body carries a ``report_url``.

    A random UUID is used when no aggregation ID is supplied.
    """
    aggregation_id = aggregation_id or uuid4()
    report_url = f'https://app.domain/cli-logs-aggregation/{aggregation_id}'

    return responses.Response(method=responses.GET, url=url, json={'report_url': report_url}, status=200)


def get_scan_details_response(url: str, scan_id: Optional[UUID] = None) -> responses.Response:
if not scan_id:
scan_id = uuid4()
Expand Down
48 changes: 46 additions & 2 deletions tests/test_code_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,20 @@
import pytest
import responses

from cycode.cli.commands.scan.code_scanner import _try_get_report_url_if_needed
from cycode.cli.commands.scan.code_scanner import (
_try_get_aggregation_report_url_if_needed,
_try_get_report_url_if_needed,
)
from cycode.cli.config import config
from cycode.cli.files_collector.excluder import _is_relevant_file_to_scan
from cycode.cyclient.scan_client import ScanClient
from tests.conftest import TEST_FILES_PATH
from tests.cyclient.mocked_responses.scan_client import get_scan_report_url, get_scan_report_url_response
from tests.cyclient.mocked_responses.scan_client import (
get_scan_aggregation_report_url,
get_scan_aggregation_report_url_response,
get_scan_report_url,
get_scan_report_url_response,
)


def test_is_relevant_file_to_scan_sca() -> None:
Expand Down Expand Up @@ -37,3 +45,39 @@ def test_try_get_report_url_if_needed_return_result(
scan_report_url_response = scan_client.get_scan_report_url(str(scan_id), scan_type)
result = _try_get_report_url_if_needed(scan_client, True, str(scan_id), scan_type)
assert result == scan_report_url_response.report_url


@pytest.mark.parametrize('scan_type', config['scans']['supported_scans'])
def test_try_get_aggregation_report_url_if_no_report_command_needed_return_none(
    scan_type: str, scan_client: ScanClient
) -> None:
    """Without a ``report`` entry in the scan parameters, no aggregation report URL is returned."""
    parameters_without_report = {'aggregation_id': uuid4().hex}

    assert _try_get_aggregation_report_url_if_needed(parameters_without_report, scan_client, scan_type) is None


@pytest.mark.parametrize('scan_type', config['scans']['supported_scans'])
def test_try_get_aggregation_report_url_if_no_aggregation_id_needed_return_none(
    scan_type: str, scan_client: ScanClient
) -> None:
    """With ``report`` set but no ``aggregation_id``, no aggregation report URL is returned."""
    parameters_without_aggregation_id = {'report': True}

    assert _try_get_aggregation_report_url_if_needed(parameters_without_aggregation_id, scan_client, scan_type) is None


@pytest.mark.parametrize('scan_type', config['scans']['supported_scans'])
@responses.activate
def test_try_get_aggregation_report_url_if_needed_return_result(
    scan_type: str, scan_client: ScanClient, api_token_response: responses.Response
) -> None:
    """With ``report`` and ``aggregation_id`` both set, the helper returns the URL served by the client."""
    aggregation_id = uuid4()
    mocked_url = get_scan_aggregation_report_url(aggregation_id, scan_client, scan_type)
    responses.add(api_token_response)  # mock token based client
    responses.add(get_scan_aggregation_report_url_response(mocked_url, aggregation_id))

    # Reference value: what the client itself returns for the mocked endpoint.
    expected_report_url = scan_client.get_scan_aggregation_report_url(str(aggregation_id), scan_type).report_url

    actual_report_url = _try_get_aggregation_report_url_if_needed(
        {'report': True, 'aggregation_id': aggregation_id}, scan_client, scan_type
    )
    assert actual_report_url == expected_report_url
Loading