Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CM-42034 - Add internal AI remediations command for IDE plugins #270

Merged
merged 6 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build_executable.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [ ubuntu-20.04, macos-12, macos-14, windows-2019 ]
os: [ ubuntu-20.04, macos-13, macos-14, windows-2019 ]
mode: [ 'onefile', 'onedir' ]
exclude:
- os: ubuntu-20.04
Expand Down
Empty file.
67 changes: 67 additions & 0 deletions cycode/cli/commands/ai_remediation/ai_remediation_command.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import os

import click
from patch_ng import fromstring
from rich.console import Console
from rich.markdown import Markdown

from cycode.cli.exceptions.handle_ai_remediation_errors import handle_ai_remediation_exception
from cycode.cli.models import CliResult
from cycode.cli.printers import ConsolePrinter
from cycode.cli.utils.get_api_client import get_scan_cycode_client


def _echo_remediation(context: click.Context, remediation_markdown: str, is_fix_available: bool) -> None:
printer = ConsolePrinter(context)
if printer.is_json_printer:
data = {'remediation': remediation_markdown, 'is_fix_available': is_fix_available}
printer.print_result(CliResult(success=True, message='Remediation fetched successfully', data=data))
else: # text or table
Console().print(Markdown(remediation_markdown))


def _apply_fix(context: click.Context, diff: str, is_fix_available: bool) -> None:
printer = ConsolePrinter(context)
if not is_fix_available:
printer.print_result(CliResult(success=False, message='Fix is not available for this violation'))
return

patch = fromstring(diff.encode('UTF-8'))
if patch is False:
printer.print_result(CliResult(success=False, message='Failed to parse fix diff'))
return

is_fix_applied = patch.apply(root=os.getcwd(), strip=0)
if is_fix_applied:
printer.print_result(CliResult(success=True, message='Fix applied successfully'))
else:
printer.print_result(CliResult(success=False, message='Failed to apply fix'))


@click.command(short_help='Get AI remediation (INTERNAL).', hidden=True)
@click.argument('detection_id', nargs=1, type=click.UUID, required=True)
@click.option(
'--fix',
is_flag=True,
default=False,
help='Apply fixes to resolve violations. Fix is not available for all violations.',
type=click.BOOL,
required=False,
)
@click.pass_context
def ai_remediation_command(context: click.Context, detection_id: str, fix: bool) -> None:
client = get_scan_cycode_client()

try:
remediation_markdown = client.get_ai_remediation(detection_id)
fix_diff = client.get_ai_remediation(detection_id, fix=True)
is_fix_available = bool(fix_diff) # exclude empty string, None, etc.

if fix:
_apply_fix(context, fix_diff, is_fix_available)
else:
_echo_remediation(context, remediation_markdown, is_fix_available)
except Exception as err:
handle_ai_remediation_exception(context, err)

context.exit()
2 changes: 2 additions & 0 deletions cycode/cli/commands/main_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import click

from cycode.cli.commands.ai_remediation.ai_remediation_command import ai_remediation_command
from cycode.cli.commands.auth.auth_command import auth_command
from cycode.cli.commands.configure.configure_command import configure_command
from cycode.cli.commands.ignore.ignore_command import ignore_command
Expand Down Expand Up @@ -30,6 +31,7 @@
'auth': auth_command,
'version': version_command,
'status': status_command,
'ai_remediation': ai_remediation_command,
},
context_settings=CLI_CONTEXT_SETTINGS,
)
Expand Down
2 changes: 1 addition & 1 deletion cycode/cli/commands/version/version_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from cycode.cli.consts import PROGRAM_NAME


@click.command(short_help='Show the CLI version and exit.')
@click.command(short_help='Show the CLI version and exit. Use `cycode status` instead.', deprecated=True)
@click.pass_context
def version_command(context: click.Context) -> None:
output = context.obj['output']
Expand Down
4 changes: 4 additions & 0 deletions cycode/cli/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@
SYNC_SCAN_TIMEOUT_IN_SECONDS_ENV_VAR_NAME = 'SYNC_SCAN_TIMEOUT_IN_SECONDS'
DEFAULT_SYNC_SCAN_TIMEOUT_IN_SECONDS = 180

# ai remediation
AI_REMEDIATION_TIMEOUT_IN_SECONDS_ENV_VAR_NAME = 'AI_REMEDIATION_TIMEOUT_IN_SECONDS'
DEFAULT_AI_REMEDIATION_TIMEOUT_IN_SECONDS = 60

# report with polling
REPORT_POLLING_WAIT_INTERVAL_IN_SECONDS = 5
DEFAULT_REPORT_POLLING_TIMEOUT_IN_SECONDS = 600
Expand Down
37 changes: 37 additions & 0 deletions cycode/cli/exceptions/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import Optional

import click

from cycode.cli.models import CliError, CliErrors
from cycode.cli.printers import ConsolePrinter
from cycode.cli.sentry import capture_exception


def handle_errors(
context: click.Context, err: BaseException, cli_errors: CliErrors, *, return_exception: bool = False
) -> Optional['CliError']:
ConsolePrinter(context).print_exception(err)

if type(err) in cli_errors:
error = cli_errors[type(err)]

if error.soft_fail is True:
context.obj['soft_fail'] = True

if return_exception:
return error

ConsolePrinter(context).print_error(error)
return None

if isinstance(err, click.ClickException):
raise err

capture_exception(err)

unknown_error = CliError(code='unknown_error', message=str(err))
if return_exception:
return unknown_error

ConsolePrinter(context).print_error(unknown_error)
exit(1)
22 changes: 22 additions & 0 deletions cycode/cli/exceptions/handle_ai_remediation_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import click

from cycode.cli.exceptions.common import handle_errors
from cycode.cli.exceptions.custom_exceptions import KNOWN_USER_FRIENDLY_REQUEST_ERRORS, RequestHttpError
from cycode.cli.models import CliError, CliErrors


class AiRemediationNotFoundError(Exception): ...


def handle_ai_remediation_exception(context: click.Context, err: Exception) -> None:
if isinstance(err, RequestHttpError) and err.status_code == 404:
err = AiRemediationNotFoundError()

errors: CliErrors = {
**KNOWN_USER_FRIENDLY_REQUEST_ERRORS,
AiRemediationNotFoundError: CliError(
code='ai_remediation_not_found',
message='The AI remediation was not found. Please try different detection ID',
),
}
handle_errors(context, err, errors)
23 changes: 3 additions & 20 deletions cycode/cli/exceptions/handle_report_sbom_errors.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
from typing import Optional

import click

from cycode.cli.exceptions import custom_exceptions
from cycode.cli.exceptions.common import handle_errors
from cycode.cli.exceptions.custom_exceptions import KNOWN_USER_FRIENDLY_REQUEST_ERRORS
from cycode.cli.models import CliError, CliErrors
from cycode.cli.printers import ConsolePrinter
from cycode.cli.sentry import capture_exception


def handle_report_exception(context: click.Context, err: Exception) -> Optional[CliError]:
ConsolePrinter(context).print_exception()

def handle_report_exception(context: click.Context, err: Exception) -> None:
errors: CliErrors = {
**KNOWN_USER_FRIENDLY_REQUEST_ERRORS,
custom_exceptions.ScanAsyncError: CliError(
Expand All @@ -25,16 +20,4 @@ def handle_report_exception(context: click.Context, err: Exception) -> Optional[
'Please try again by executing the `cycode report` command',
),
}

if type(err) in errors:
error = errors[type(err)]

ConsolePrinter(context).print_error(error)
return None

if isinstance(err, click.ClickException):
raise err

capture_exception(err)

raise click.ClickException(str(err))
handle_errors(context, err, errors)
33 changes: 4 additions & 29 deletions cycode/cli/exceptions/handle_scan_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,17 @@
import click

from cycode.cli.exceptions import custom_exceptions
from cycode.cli.exceptions.common import handle_errors
from cycode.cli.exceptions.custom_exceptions import KNOWN_USER_FRIENDLY_REQUEST_ERRORS
from cycode.cli.models import CliError, CliErrors
from cycode.cli.printers import ConsolePrinter
from cycode.cli.sentry import capture_exception
from cycode.cli.utils.git_proxy import git_proxy


def handle_scan_exception(
context: click.Context, e: Exception, *, return_exception: bool = False
context: click.Context, err: Exception, *, return_exception: bool = False
) -> Optional[CliError]:
context.obj['did_fail'] = True

ConsolePrinter(context).print_exception(e)

errors: CliErrors = {
**KNOWN_USER_FRIENDLY_REQUEST_ERRORS,
custom_exceptions.ScanAsyncError: CliError(
Expand All @@ -35,7 +32,7 @@ def handle_scan_exception(
custom_exceptions.TfplanKeyError: CliError(
soft_fail=True,
code='key_error',
message=f'\n{e!s}\n'
message=f'\n{err!s}\n'
'A crucial field is missing in your terraform plan file. '
'Please make sure that your file is well formed '
'and execute the scan again',
Expand All @@ -48,26 +45,4 @@ def handle_scan_exception(
),
}

if type(e) in errors:
error = errors[type(e)]

if error.soft_fail is True:
context.obj['soft_fail'] = True

if return_exception:
return error

ConsolePrinter(context).print_error(error)
return None

if isinstance(e, click.ClickException):
raise e

capture_exception(e)

unknown_error = CliError(code='unknown_error', message=str(e))
if return_exception:
return unknown_error

ConsolePrinter(context).print_error(unknown_error)
exit(1)
return handle_errors(context, err, errors, return_exception=return_exception)
2 changes: 1 addition & 1 deletion cycode/cli/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class CliError(NamedTuple):
soft_fail: bool = False


CliErrors = Dict[Type[Exception], CliError]
CliErrors = Dict[Type[BaseException], CliError]


class CliResult(NamedTuple):
Expand Down
12 changes: 12 additions & 0 deletions cycode/cli/printers/console_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,15 @@ def print_exception(self, e: Optional[BaseException] = None, force_print: bool =
"""Print traceback message in stderr if verbose mode is set."""
if force_print or self.context.obj.get('verbose', False):
self._printer_class(self.context).print_exception(e)

@property
def is_json_printer(self) -> bool:
return self._printer_class == JsonPrinter

@property
def is_table_printer(self) -> bool:
return self._printer_class == TablePrinter

@property
def is_text_printer(self) -> bool:
return self._printer_class == TextPrinter
7 changes: 7 additions & 0 deletions cycode/cli/user_settings/configuration_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,13 @@ def get_sync_scan_timeout_in_seconds(self) -> int:
)
)

def get_ai_remediation_timeout_in_seconds(self) -> int:
return int(
self._get_value_from_environment_variables(
consts.AI_REMEDIATION_TIMEOUT_IN_SECONDS_ENV_VAR_NAME, consts.DEFAULT_AI_REMEDIATION_TIMEOUT_IN_SECONDS
)
)

def get_report_polling_timeout_in_seconds(self) -> int:
return int(
self._get_value_from_environment_variables(
Expand Down
3 changes: 3 additions & 0 deletions cycode/cyclient/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ def __init__(
detection_details: dict,
detection_rule_id: str,
severity: Optional[str] = None,
id: Optional[str] = None,
) -> None:
super().__init__()
self.id = id
self.message = message
self.type = type
self.severity = severity
Expand All @@ -36,6 +38,7 @@ class DetectionSchema(Schema):
class Meta:
unknown = EXCLUDE

id = fields.String(missing=None)
message = fields.String()
type = fields.String()
severity = fields.String(missing=None)
Expand Down
22 changes: 22 additions & 0 deletions cycode/cyclient/scan_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,28 @@ def get_supported_modules_preferences(self) -> models.SupportedModulesPreference
response = self.scan_cycode_client.get(url_path='preferences/api/v1/supportedmodules')
return models.SupportedModulesPreferencesSchema().load(response.json())

@staticmethod
def get_ai_remediation_path(detection_id: str) -> str:
return f'scm-remediator/api/v1/ContentRemediation/preview/{detection_id}'

def get_ai_remediation(self, detection_id: str, *, fix: bool = False) -> str:
path = self.get_ai_remediation_path(detection_id)

data = {
'resolving_parameters': {
'get_diff': True,
'use_code_snippet': True,
'add_diff_header': True,
}
}
if not fix:
data['resolving_parameters']['remediation_action'] = 'ReplyWithRemediationDetails'

response = self.scan_cycode_client.get(
url_path=path, json=data, timeout=configuration_manager.get_ai_remediation_timeout_in_seconds()
)
return response.text.strip()

@staticmethod
def _get_policy_type_by_scan_type(scan_type: str) -> str:
scan_type_to_policy_type = {
Expand Down
Loading
Loading