diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 073facca..443c5e72 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -13,7 +13,8 @@ "mounts": [ "source=/var/run/docker.sock,target=/var/run/docker.sock,type=bind", - "source=${localEnv:HOME}${localEnv:USERPROFILE}/.ssh,target=/home/developer/.ssh,type=bind,consistency=cached" + "source=${localEnv:HOME}${localEnv:USERPROFILE}/.ssh,target=/home/developer/.ssh,type=bind,consistency=cached", + "source=${localEnv:HOME}/.safety,target=/home/developer/.safety,type=bind,consistency=cached" ], "remoteEnv": { diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 9cd276d7..f4583b89 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -43,20 +43,23 @@ jobs: #### Quick Test with Python Package \`\`\`bash # Download and run with uv - gh run download ${context.runId} -n dist + gh run download ${context.runId} -n dist -R pyupio/safety uv run --with safety-${version}-py3-none-any.whl safety --version \`\`\` #### Binary Installation \`\`\`bash # Linux - gh run download ${context.runId} -n safety-linux -D linux + gh run download ${context.runId} -n safety-linux -D linux -R pyupio/safety cd linux && mv safety safety-pr && chmod +x safety-pr # macOS - gh run download ${context.runId} -n safety-macos -D macos + gh run download ${context.runId} -n safety-macos -D macos -R pyupio/safety cd macos && mv safety safety-pr && chmod +x safety-pr + # Windows + gh run download ${context.runId} -n safety-windows -D windows -R pyupio/safety + cd windows && mv safety.exe safety-pr.exe ./safety-pr --version \`\`\` diff --git a/.github/workflows/reusable-build.yml b/.github/workflows/reusable-build.yml index 5968bb20..0fe68b55 100644 --- a/.github/workflows/reusable-build.yml +++ b/.github/workflows/reusable-build.yml @@ -49,7 +49,7 @@ jobs: exit 1 fi BRANCH_NAME="${{ inputs.branch-name }}" - SLUG=$(echo "$BRANCH_NAME" | iconv -t ascii//TRANSLIT | sed -r s/[^a-zA-Z0-9]+/-/g | sed -r s/^-+\|-+$//g | tr A-Z a-z) + SLUG=$(echo "$BRANCH_NAME" | iconv -t ascii//TRANSLIT | sed -r 's/[^a-zA-Z0-9]+/./g' | sed -r 's/^.+\|.+$//g' | tr A-Z a-z) echo "SLUG=$SLUG" >> $GITHUB_OUTPUT - name: Version bump (PR) diff --git a/.vscode/launch.json b/.vscode/launch.json index e63fffac..51d31db1 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -10,7 +10,8 @@ // This uses the default environment which is a virtual environment // created by Hatch "python": "${workspaceFolder}/.hatch/bin/python", - "console": "integratedTerminal" + "console": "integratedTerminal", + "justMyCode": false, } ], "inputs": [ @@ -64,15 +65,26 @@ "auth login", "auth login --headless", "auth logout", + "auth status", // Scan commands "scan", - "--key ADD-YOUR-API-KEY scan", + "--key $SAFETY_API_KEY scan", + "--stage cicd --key $SAFETY_API_KEY scan", "scan --use-server-matching", "scan --detailed-output", "--debug scan", "--disable-optional-telemetry scan", "scan --output json --output-file json", + "scan --help", + + // Firewall commands + "init --help", + "init local_prj", // Directory has to be created manually + "init", + "pip list", + "pip install insecure-package", + "pip install fastapi", // Check commands "check", @@ -80,7 +92,10 @@ // Other commands "license", - "--help" + "--help", + "validate --help", + "--key foo --help", + "configure" ], "default": "scan" } diff --git a/docs/.ipynb_checkpoints/Safety-CLI-Quickstart-checkpoint.ipynb b/docs/.ipynb_checkpoints/Safety-CLI-Quickstart-checkpoint.ipynb 
index 363fcab7..4ea4f7b1 100644 --- a/docs/.ipynb_checkpoints/Safety-CLI-Quickstart-checkpoint.ipynb +++ b/docs/.ipynb_checkpoints/Safety-CLI-Quickstart-checkpoint.ipynb @@ -1,5 +1,12 @@ { - "cells": [], + "cells": [ + { + "metadata": {}, + "cell_type": "raw", + "source": "", + "id": "e4a30302820cf149" + } + ], "metadata": {}, "nbformat": 4, "nbformat_minor": 5 diff --git a/pyproject.toml b/pyproject.toml index 3c7351c1..a5b8c29a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ dependencies = [ "setuptools>=65.5.1", "typer>=0.12.1", "typing-extensions>=4.7.1", + "python-levenshtein>=0.25.1", ] license = "MIT" license-files = ["LICENSES/*"] @@ -255,6 +256,8 @@ reportMissingTypeStubs = false addopts = "--strict-markers" markers = [ "basic: requires no extras", + "windows_only: mark test to run only on Windows platforms", + "unix_only: mark test to run only on Unix platforms" ] [tool.coverage.run] diff --git a/safety/alerts/__init__.py b/safety/alerts/__init__.py index 28972b3c..5e438d1f 100644 --- a/safety/alerts/__init__.py +++ b/safety/alerts/__init__.py @@ -6,6 +6,8 @@ from dataclasses import dataclass +from safety.constants import CONTEXT_COMMAND_TYPE + from . import github from safety.util import SafetyPolicyFile from safety.scan.constants import CLI_ALERT_COMMAND_HELP @@ -17,6 +19,10 @@ def get_safety_cli_legacy_group(): from safety.cli_util import SafetyCLILegacyGroup return SafetyCLILegacyGroup +def get_context_settings(): + from safety.cli_util import CommandType + return {CONTEXT_COMMAND_TYPE: CommandType.UTILITY} + @dataclass class Alert: """ @@ -33,7 +39,8 @@ class Alert: policy: Any = None requirements_files: Any = None -@click.group(cls=get_safety_cli_legacy_group(), help=CLI_ALERT_COMMAND_HELP, deprecated=True, utility_command=True) +@click.group(cls=get_safety_cli_legacy_group(), help=CLI_ALERT_COMMAND_HELP, + deprecated=True, context_settings=get_context_settings()) @click.option('--check-report', help='JSON output of Safety Check to work with.', type=click.File('r'), default=sys.stdin, required=True) @click.option("--key", envvar="SAFETY_API_KEY", help="API Key for safetycli.com's vulnerability database. 
Can be set as SAFETY_API_KEY " diff --git a/safety/auth/cli.py b/safety/auth/cli.py index 7320d12f..9bf77ea9 100644 --- a/safety/auth/cli.py +++ b/safety/auth/cli.py @@ -1,11 +1,12 @@ -from datetime import datetime import logging import sys -from safety.auth.models import Auth +from datetime import datetime -from safety.auth.utils import is_email_verified +from safety.auth.models import Auth +from safety.auth.utils import initialize, is_email_verified from safety.console import main_console as console from safety.constants import MSG_FINISH_REGISTRATION_TPL, MSG_VERIFICATION_HINT +from safety.meta import get_version try: from typing import Annotated @@ -15,22 +16,38 @@ from typing import Optional import click -from typer import Typer import typer +from rich.padding import Padding +from typer import Typer -from safety.auth.main import get_auth_info, get_authorization_data, get_token, clean_session +from safety.auth.main import ( + clean_session, + get_auth_info, + get_authorization_data, + get_token, +) from safety.auth.server import process_browser_callback -from ..cli_util import get_command_for, pass_safety_cli_obj, SafetyCLISubGroup - -from .constants import MSG_FAIL_LOGIN_AUTHED, MSG_FAIL_REGISTER_AUTHED, MSG_LOGOUT_DONE, MSG_LOGOUT_FAILED, MSG_NON_AUTHENTICATED -from safety.scan.constants import CLI_AUTH_COMMAND_HELP, CLI_AUTH_HEADLESS_HELP, DEFAULT_EPILOG, CLI_AUTH_LOGIN_HELP, CLI_AUTH_LOGOUT_HELP, CLI_AUTH_STATUS_HELP - - -from rich.padding import Padding +from safety.scan.constants import ( + CLI_AUTH_COMMAND_HELP, + CLI_AUTH_HEADLESS_HELP, + CLI_AUTH_LOGIN_HELP, + CLI_AUTH_LOGOUT_HELP, + CLI_AUTH_STATUS_HELP, + DEFAULT_EPILOG, +) + +from ..cli_util import SafetyCLISubGroup, get_command_for, pass_safety_cli_obj +from .constants import ( + MSG_FAIL_LOGIN_AUTHED, + MSG_FAIL_REGISTER_AUTHED, + MSG_LOGOUT_DONE, + MSG_LOGOUT_FAILED, + MSG_NON_AUTHENTICATED, +) LOG = logging.getLogger(__name__) -auth_app = Typer(rich_markup_mode="rich") +auth_app = Typer(rich_markup_mode="rich", name="auth") @@ -183,6 +200,8 @@ def login( render_successful_login(ctx.obj.auth, organization=organization) + initialize(ctx, refresh=True) + console.print() if ctx.obj.auth.org or ctx.obj.auth.email_verified: console.print( @@ -249,12 +268,13 @@ def status(ctx: typer.Context, ensure_auth: bool = False, """ LOG.info('status started') current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - from safety.util import get_safety_version - safety_version = get_safety_version() + safety_version = get_version() console.print(f"[{current_time}]: Safety {safety_version}") info = get_auth_info(ctx) + initialize(ctx, refresh=True) + if ensure_auth: console.print("running: safety auth status --ensure-auth") console.print() diff --git a/safety/auth/cli_utils.py b/safety/auth/cli_utils.py index cc0cca43..1eac539e 100644 --- a/safety/auth/cli_utils.py +++ b/safety/auth/cli_utils.py @@ -12,11 +12,9 @@ from safety.auth.utils import S3PresignedAdapter, SafetyAuthSession, get_keys, is_email_verified from safety.constants import REQUEST_TIMEOUT from safety.scan.constants import CLI_KEY_HELP, CLI_PROXY_HOST_HELP, CLI_PROXY_PORT_HELP, CLI_PROXY_PROTOCOL_HELP, CLI_STAGE_HELP -from safety.scan.util import Stage from safety.util import DependentOption, SafetyContext, get_proxy_dict - -from functools import wraps - +from safety.models import SafetyCLI +from safety_schemas.models import Stage LOG = logging.getLogger(__name__) @@ -89,7 +87,7 @@ def load_auth_session(click_ctx: click.Context) -> None: click_ctx (click.Context): The 
Click context object. """ if not click_ctx: - LOG.warn("Click context is needed to be able to load the Auth data.") + LOG.warning("Click context is needed to be able to load the Auth data.") return client = click_ctx.obj.auth.client @@ -160,89 +158,57 @@ def decorator(func: Callable) -> Callable: return decorator -def inject_session(func: Callable) -> Callable: - """ - Decorator that injects a session object into Click commands. +def inject_session(ctx: click.Context, proxy_protocol: Optional[str] = None, + proxy_host: Optional[str] = None, + proxy_port: Optional[str] = None, + key: Optional[str] = None, + stage: Optional[Stage] = None, + invoked_command: str = "") -> Any: - Builds the session object to be used in each command. + # Skip injection for specific commands that do not require authentication + if invoked_command in ["configure"]: + return - Args: - func (Callable): The Click command function. + org: Optional[Organization] = get_organization() - Returns: - Callable: The wrapped Click command function with session injection. - """ - @wraps(func) - def inner(ctx: click.Context, proxy_protocol: Optional[str] = None, - proxy_host: Optional[str] = None, - proxy_port: Optional[str] = None, - key: Optional[str] = None, - stage: Optional[Stage] = None, *args, **kwargs) -> Any: - """ - Inner function that performs the session injection. - - Args: - ctx (click.Context): The Click context object. - proxy_protocol (Optional[str]): The proxy protocol. - proxy_host (Optional[str]): The proxy host. - proxy_port (Optional[int]): The proxy port. - key (Optional[str]): The API key. - stage (Optional[Stage]): The stage. - *args (Any): Additional arguments. - **kwargs (Any): Additional keyword arguments. - - Returns: - Any: The result of the decorated function. 
- """ - - if ctx.invoked_subcommand == "configure": - return - - org: Optional[Organization] = get_organization() - - if not stage: - host_stage = get_host_config(key_name="stage") - stage = host_stage if host_stage else Stage.development - - proxy_config: Optional[Dict[str, str]] = get_proxy_dict(proxy_protocol, - proxy_host, proxy_port) - - client_session, openid_config = build_client_session(api_key=key, - proxies=proxy_config) - keys = get_keys(client_session, openid_config) - - auth = Auth( - stage=stage, - keys=keys, - org=org, - client_id=CLIENT_ID, - client=client_session, - code_verifier=generate_token(48) - ) - - if not ctx.obj: - from safety.models import SafetyCLI - ctx.obj = SafetyCLI() - - ctx.obj.auth=auth - - load_auth_session(ctx) - - info = get_auth_info(ctx) - - if info: - ctx.obj.auth.name = info.get("name") - ctx.obj.auth.email = info.get("email") - ctx.obj.auth.email_verified = is_email_verified(info) - SafetyContext().account = info["email"] - else: - SafetyContext().account = "" - - @ctx.call_on_close - def clean_up_on_close(): - LOG.debug('Closing requests session.') - ctx.obj.auth.client.close() - - return func(ctx, *args, **kwargs) - - return inner + if not stage: + host_stage = get_host_config(key_name="stage") + stage = host_stage if host_stage else Stage.development + + proxy_config: Optional[Dict[str, str]] = get_proxy_dict(proxy_protocol, + proxy_host, proxy_port) + + client_session, openid_config = build_client_session(api_key=key, + proxies=proxy_config) + keys = get_keys(client_session, openid_config) + + auth = Auth( + stage=stage, + keys=keys, + org=org, + client_id=CLIENT_ID, + client=client_session, + code_verifier=generate_token(48) + ) + + if not ctx.obj: + ctx.obj = SafetyCLI() + + ctx.obj.auth = auth + + load_auth_session(ctx) + + info = get_auth_info(ctx) + + if info: + ctx.obj.auth.name = info.get("name") + ctx.obj.auth.email = info.get("email") + ctx.obj.auth.email_verified = is_email_verified(info) + SafetyContext().account = info["email"] + else: + SafetyContext().account = "" + + @ctx.call_on_close + def clean_up_on_close(): + LOG.debug('Closing requests session.') + ctx.obj.auth.client.close() diff --git a/safety/auth/main.py b/safety/auth/main.py index fdb9c225..0c994648 100644 --- a/safety/auth/main.py +++ b/safety/auth/main.py @@ -9,7 +9,7 @@ from safety.auth.models import Organization from safety.auth.constants import CLI_AUTH_LOGOUT, CLI_CALLBACK, AUTH_CONFIG_USER, CLI_AUTH from safety.constants import CONFIG -from safety.scan.util import Stage +from safety_schemas.models import Stage from safety.util import get_proxy_dict diff --git a/safety/auth/models.py b/safety/auth/models.py index 3312dedc..965d1da0 100644 --- a/safety/auth/models.py +++ b/safety/auth/models.py @@ -63,6 +63,21 @@ def refresh_from(self, info: Dict) -> None: self.email = info.get("email") self.email_verified = is_email_verified(info) + def get_auth_method(self) -> str: + """ + Get the authentication method. + + Returns: + str: The authentication method. 
+ """ + if self.client.api_key: + return "API Key" + + if self.client.token: + return "Token" + + return "None" + class XAPIKeyAuth(BaseOAuth): def __init__(self, api_key: str) -> None: """ diff --git a/safety/auth/utils.py b/safety/auth/utils.py index 390cc26b..96938f26 100644 --- a/safety/auth/utils.py +++ b/safety/auth/utils.py @@ -12,7 +12,7 @@ ) from safety.constants import ( PLATFORM_API_CHECK_UPDATES_ENDPOINT, - PLATFORM_API_INITIALIZE_SCAN_ENDPOINT, + PLATFORM_API_INITIALIZE_ENDPOINT, PLATFORM_API_POLICY_ENDPOINT, PLATFORM_API_PROJECT_CHECK_ENDPOINT, PLATFORM_API_PROJECT_ENDPOINT, @@ -20,7 +20,10 @@ PLATFORM_API_PROJECT_UPLOAD_SCAN_ENDPOINT, PLATFORM_API_REQUIREMENTS_UPLOAD_SCAN_ENDPOINT, REQUEST_TIMEOUT, + FeatureType, + get_config_setting ) +from safety.models import SafetyCLI from safety.scan.util import AuthenticationType from safety.util import SafetyContext, output_exception @@ -99,9 +102,17 @@ def wrapper(*args, **kwargs): except requests.exceptions.RequestException as e: raise e + # TODO: Handle content as JSON and fallback to text for all responses + if r.status_code == 403: + reason = None + try: + reason = r.json().get("detail") + except Exception: + LOG.debug("Failed to parse 403 response: %s", r.text) + raise InvalidCredentialError( - credential="Failed authentication.", reason=r.text + credential="Failed authentication.", reason=reason ) if r.status_code == 429: @@ -456,20 +467,22 @@ def check_updates( return self.get(url=PLATFORM_API_CHECK_UPDATES_ENDPOINT, params=data) @parse_response - def initialize_scan(self) -> Any: + def initialize(self) -> Any: """ - Initialize a scan. + Initialize a run. Returns: Any: The initialization result. """ try: - response = self.get(url=PLATFORM_API_INITIALIZE_SCAN_ENDPOINT, timeout=5) + response = self.get(url=PLATFORM_API_INITIALIZE_ENDPOINT, + headers={"Content-Type": "application/json"}, + timeout=5) return response except requests.exceptions.Timeout: - LOG.error("Auth request to initialize scan timed out after 5 seconds.") - except Exception as e: - LOG.exception("Exception trying to auth initialize scan", exc_info=True) + LOG.error("Auth request to initialize timed out after 5 seconds.") + except Exception: + LOG.exception("Exception trying to auth initialize", exc_info=True) return None @@ -591,3 +604,113 @@ def is_jupyter_notebook() -> bool: pass return False + + +def save_flags_config(flags: Dict[FeatureType, bool]) -> None: + """ + Save feature flags configuration to file. + + This function attempts to save feature flags to the configuration file + but will fail silently if unable to do so (e.g., due to permission issues + or disk problems). Silent failure is chosen to prevent configuration issues + from disrupting core application functionality. + + Note that if saving fails, the application will continue using existing + or default flag values until the next restart. + + Args: + flags: Dictionary mapping feature types to their enabled/disabled state + + The operation will be logged (with stack trace) if it fails. 
+ """ + import configparser + from safety.constants import CONFIG_FILE_USER + + config = configparser.ConfigParser() + config.read(CONFIG_FILE_USER) + + flag_settings = {key.name.upper(): str(value) for key, value in flags.items()} + + if not config.has_section('settings'): + config.add_section('settings') + + settings = dict(config.items('settings')) + settings.update(flag_settings) + + for key, value in settings.items(): + config.set('settings', key, value) + + try: + with open(CONFIG_FILE_USER, 'w') as config_file: + config.write(config_file) + except Exception: + LOG.exception("Unable to save flags configuration.") + + +def get_feature_name(feature: FeatureType, as_attr: bool = False) -> str: + """Returns a formatted feature name with enabled suffix. + + Args: + feature: The feature to format the name for + as_attr: If True, formats for attribute usage (underscore), + otherwise uses hyphen + + Returns: + Formatted feature name string with enabled suffix + """ + name = feature.name.lower() + separator = '_' if as_attr else '-' + return f"{name}{separator}enabled" + + +def str_to_bool(value) -> Optional[bool]: + """Convert basic string representations to boolean.""" + if isinstance(value, bool): + return value + + if isinstance(value, str): + value = value.lower().strip() + if value in ('true'): + return True + if value in ('false'): + return False + + return None + + +def initialize(ctx: Any, refresh: bool = True) -> None: + """ + Initializes the run by loading settings. + + Args: + ctx (Any): The context object. + refresh (bool): Whether to refresh settings from the server. Defaults to True. + """ + settings = None + current_values = {} + + if not ctx.obj: + ctx.obj = SafetyCLI() + + for feature in FeatureType: + value = get_config_setting(feature.name) + current_values[feature] = str_to_bool(value) + + if refresh: + try: + settings = ctx.obj.auth.client.initialize() + except Exception: + LOG.info("Unable to initialize, continue with default values.") + + if settings: + for feature in FeatureType: + server_value = str_to_bool(settings.get(feature.config_key)) + if server_value is not None: + if current_values[feature] != server_value: + current_values[feature] = server_value + + save_flags_config(current_values) + + for feature, value in current_values.items(): + if value is not None: + setattr(ctx.obj, feature.attr_name, value) diff --git a/safety/cli.py b/safety/cli.py index 36099f3b..b0e1ccdc 100644 --- a/safety/cli.py +++ b/safety/cli.py @@ -2,7 +2,7 @@ from __future__ import absolute_import import configparser from dataclasses import asdict -from datetime import date, datetime +from datetime import date, datetime, timedelta from enum import Enum import requests import time @@ -14,45 +14,59 @@ import platform import sys from functools import wraps -from typing import Dict, Optional from packaging import version as packaging_version from packaging.version import InvalidVersion import click import typer +from safety_schemas.models.config import VulnerabilityDefinition from safety import safety from safety.console import main_console as console from safety.alerts import alert -from safety.auth import auth, inject_session, proxy_options, auth_options +from safety.auth import proxy_options, auth_options from safety.auth.models import Organization -from safety.scan.constants import CLI_LICENSES_COMMAND_HELP, CLI_MAIN_INTRODUCTION, CLI_DEBUG_HELP, CLI_DISABLE_OPTIONAL_TELEMETRY_DATA_HELP, \ - DEFAULT_EPILOG, DEFAULT_SPINNER, CLI_CHECK_COMMAND_HELP, CLI_CHECK_UPDATES_HELP, 
CLI_CONFIGURE_HELP, CLI_GENERATE_HELP, \ - CLI_CONFIGURE_PROXY_TIMEOUT, CLI_CONFIGURE_PROXY_REQUIRED, CLI_CONFIGURE_ORGANIZATION_ID, CLI_CONFIGURE_ORGANIZATION_NAME, \ - CLI_CONFIGURE_SAVE_TO_SYSTEM, CLI_CONFIGURE_PROXY_HOST_HELP, CLI_CONFIGURE_PROXY_PORT_HELP, CLI_CONFIGURE_PROXY_PROTOCOL_HELP, \ +from safety.pip.command import pip_app +from safety.init.command import init_app +from safety.scan import command +from safety.scan.constants import CLI_LICENSES_COMMAND_HELP, CLI_MAIN_INTRODUCTION, CLI_DEBUG_HELP, \ + CLI_DISABLE_OPTIONAL_TELEMETRY_DATA_HELP, \ + DEFAULT_EPILOG, DEFAULT_SPINNER, CLI_CHECK_COMMAND_HELP, CLI_CHECK_UPDATES_HELP, CLI_CONFIGURE_HELP, \ + CLI_GENERATE_HELP, CLI_GENERATE_MINIMUM_CVSS_SEVERITY, \ + CLI_CONFIGURE_PROXY_TIMEOUT, CLI_CONFIGURE_PROXY_REQUIRED, CLI_CONFIGURE_ORGANIZATION_ID, \ + CLI_CONFIGURE_ORGANIZATION_NAME, \ + CLI_CONFIGURE_SAVE_TO_SYSTEM, CLI_CONFIGURE_PROXY_HOST_HELP, CLI_CONFIGURE_PROXY_PORT_HELP, \ + CLI_CONFIGURE_PROXY_PROTOCOL_HELP, \ CLI_GENERATE_PATH -from .cli_util import SafetyCLICommand, SafetyCLILegacyGroup, SafetyCLILegacyCommand, SafetyCLISubGroup, SafetyCLIUtilityCommand, handle_cmd_exception -from safety.constants import BAR_LINE, CONFIG_FILE_USER, CONFIG_FILE_SYSTEM, EXIT_CODE_VULNERABILITIES_FOUND, EXIT_CODE_OK, EXIT_CODE_FAILURE +from .cli_util import CommandType, SafetyCLICommand, SafetyCLILegacyGroup, SafetyCLILegacyCommand, SafetyCLISubGroup, \ + handle_cmd_exception +from safety.constants import BAR_LINE, CONFIG_FILE_USER, CONFIG_FILE_SYSTEM, EXIT_CODE_VULNERABILITIES_FOUND, \ + EXIT_CODE_OK, EXIT_CODE_FAILURE, CONTEXT_COMMAND_TYPE from safety.errors import InvalidCredentialError, SafetyException, SafetyError from safety.formatter import SafetyFormatter -from safety.models import SafetyCLI from safety.output_utils import should_add_nl -from safety.safety import get_packages, read_vulnerabilities, process_fixes +from safety.safety import get_packages, process_fixes +from safety.scan.finder import FileFinder +from safety.scan.main import process_files from safety.util import get_packages_licenses, initializate_config_dirs, output_exception, \ MutuallyExclusiveOption, DependentOption, transform_ignore, SafetyPolicyFile, active_color_if_needed, \ - get_processed_options, get_safety_version, json_alias, bare_alias, html_alias, SafetyContext, is_a_remote_mirror, \ + get_processed_options, json_alias, bare_alias, html_alias, SafetyContext, is_a_remote_mirror, \ filter_announcements, get_fix_options +from safety.meta import get_version from safety.scan.command import scan_project_app, scan_system_app from safety.auth.cli import auth_app -from safety_schemas.models import ConfigModel, Stage +from safety.firewall.command import firewall_app +from safety_schemas.config.schemas.v3_0 import main as v3_0 +from safety_schemas.models import ConfigModel, Stage, Ecosystem, VulnerabilitySeverityLabels try: - from typing import Annotated + from typing import Annotated, Optional except ImportError: - from typing_extensions import Annotated + from typing_extensions import Annotated, Optional LOG = logging.getLogger(__name__) + def get_network_telemetry(): import psutil import socket @@ -78,10 +92,10 @@ def get_network_telemetry(): network_info['download_speed'] = None network_info['error'] = str(e) - # Get network addresses net_if_addrs = psutil.net_if_addrs() - network_info['interfaces'] = {iface: [addr.address for addr in addrs if addr.family == socket.AF_INET] for iface, addrs in net_if_addrs.items()} + network_info['interfaces'] = {iface: 
[addr.address for addr in addrs if addr.family == socket.AF_INET] for + iface, addrs in net_if_addrs.items()} # Get network connections net_connections = psutil.net_connections(kind='inet') @@ -113,6 +127,7 @@ def get_network_telemetry(): return network_info + def preprocess_args(f): if '--debug' in sys.argv: index = sys.argv.index('--debug') @@ -122,6 +137,7 @@ def preprocess_args(f): sys.argv.pop(index + 1) # Remove the next argument (1 or true) return f + def configure_logger(ctx, param, debug): level = logging.CRITICAL @@ -148,14 +164,15 @@ def configure_logger(ctx, param, debug): network_telemetry = get_network_telemetry() LOG.debug('Network telemetry: %s', network_telemetry) + @click.group(cls=SafetyCLILegacyGroup, help=CLI_MAIN_INTRODUCTION, epilog=DEFAULT_EPILOG) @auth_options() @proxy_options -@click.option('--disable-optional-telemetry', default=False, is_flag=True, show_default=True, help=CLI_DISABLE_OPTIONAL_TELEMETRY_DATA_HELP) +@click.option('--disable-optional-telemetry', default=False, is_flag=True, show_default=True, + help=CLI_DISABLE_OPTIONAL_TELEMETRY_DATA_HELP) @click.option('--debug', is_flag=True, help=CLI_DEBUG_HELP, callback=configure_logger) -@click.version_option(version=get_safety_version()) +@click.version_option(version=get_version()) @click.pass_context -@inject_session @preprocess_args def cli(ctx, debug, disable_optional_telemetry): """ @@ -180,6 +197,7 @@ def clean_check_command(f): """ Main entry point for validation. """ + @wraps(f) def inner(ctx, *args, **kwargs): @@ -200,7 +218,7 @@ def inner(ctx, *args, **kwargs): kwargs.pop('proxy_port', None) if ctx.get_parameter_source("json_version") != click.core.ParameterSource.DEFAULT and not ( - save_json or json or output == 'json'): + save_json or json or output == 'json'): raise click.UsageError( "Illegal usage: `--json-version` only works with JSON related outputs." ) @@ -209,18 +227,20 @@ def inner(ctx, *args, **kwargs): if ctx.get_parameter_source("apply_remediations") != click.core.ParameterSource.DEFAULT: if not authenticated: - raise InvalidCredentialError(message="The --apply-security-updates option needs authentication. See {link}.") + raise InvalidCredentialError( + message="The --apply-security-updates option needs authentication. 
See {link}.") if not files: raise SafetyError(message='--apply-security-updates only works with files; use the "-r" option to ' 'specify files to remediate.') auto_remediation_limit = get_fix_options(policy_file, auto_remediation_limit) - policy_file, server_audit_and_monitor = safety.get_server_policies(ctx.obj.auth.client, policy_file=policy_file, + policy_file, server_audit_and_monitor = safety.get_server_policies(ctx.obj.auth.client, + policy_file=policy_file, proxy_dictionary=None) audit_and_monitor = (audit_and_monitor and server_audit_and_monitor) kwargs.update({"auto_remediation_limit": auto_remediation_limit, - "policy_file":policy_file, + "policy_file": policy_file, "audit_and_monitor": audit_and_monitor}) except SafetyError as e: @@ -235,9 +255,10 @@ def inner(ctx, *args, **kwargs): return inner + def print_deprecation_message( - old_command: str, - deprecation_date: datetime, + old_command: str, + deprecation_date: datetime, new_command: Optional[str] = None ) -> None: """ @@ -262,28 +283,34 @@ def print_deprecation_message( click.echo(click.style(BAR_LINE, fg="yellow", bold=True)) click.echo("\n") click.echo(click.style("DEPRECATED: ", fg="red", bold=True) + - click.style(f"this command (`{old_command}`) has been DEPRECATED, and will be unsupported beyond {deprecation_date.strftime('%d %B %Y')}.", fg="yellow", bold=True)) - + click.style( + f"this command (`{old_command}`) has been DEPRECATED, and will be unsupported beyond {deprecation_date.strftime('%d %B %Y')}.", + fg="yellow", bold=True)) + if new_command: click.echo("\n") click.echo(click.style("We highly encourage switching to the new ", fg="green") + click.style(f"`{new_command}`", fg="green", bold=True) + - click.style(" command which is easier to use, more powerful, and can be set up to mimic the deprecated command if required.", fg="green")) - + click.style( + " command which is easier to use, more powerful, and can be set up to mimic the deprecated command if required.", + fg="green")) + click.echo("\n") click.echo(click.style(BAR_LINE, fg="yellow", bold=True)) click.echo("\n") - -@cli.command(cls=SafetyCLILegacyCommand, utility_command=True, help=CLI_CHECK_COMMAND_HELP) +@cli.command(cls=SafetyCLILegacyCommand, + context_settings={CONTEXT_COMMAND_TYPE: CommandType.UTILITY}, + help=CLI_CHECK_COMMAND_HELP) @proxy_options @auth_options(stage=False) @click.option("--db", default="", help="Path to a local or remote vulnerability database. Default: empty") @click.option("--full-report/--short-report", default=False, cls=MutuallyExclusiveOption, mutually_exclusive=["output", "json", "bare"], - with_values={"output": ['json', 'bare'], "json": [True, False], "html": [True, False], "bare": [True, False]}, + with_values={"output": ['json', 'bare'], "json": [True, False], "html": [True, False], + "bare": [True, False]}, help='Full reports include a security advisory (if available). Default: --short-report') @click.option("--cache", is_flag=False, flag_value=60, default=0, help="Cache requests to the vulnerability database locally. 
Default: 0 seconds", @@ -298,10 +325,12 @@ def print_deprecation_message( @click.option("ignore_unpinned_requirements", "--ignore-unpinned-requirements/--check-unpinned-requirements", "-iur", default=None, help="Check or ignore unpinned requirements found.") @click.option('--json', default=False, cls=MutuallyExclusiveOption, mutually_exclusive=["output", "bare"], - with_values={"output": ['screen', 'text', 'bare', 'json', 'html'], "bare": [True, False]}, callback=json_alias, + with_values={"output": ['screen', 'text', 'bare', 'json', 'html'], "bare": [True, False]}, + callback=json_alias, hidden=True, is_flag=True, show_default=True) @click.option('--html', default=False, cls=MutuallyExclusiveOption, mutually_exclusive=["output", "bare"], - with_values={"output": ['screen', 'text', 'bare', 'json', 'html'], "bare": [True, False]}, callback=html_alias, + with_values={"output": ['screen', 'text', 'bare', 'json', 'html'], "bare": [True, False]}, + callback=html_alias, hidden=True, is_flag=True, show_default=True) @click.option('--bare', default=False, cls=MutuallyExclusiveOption, mutually_exclusive=["output", "json"], with_values={"output": ['screen', 'text', 'bare', 'json'], "json": [True, False]}, callback=bare_alias, @@ -368,9 +397,11 @@ def check(ctx, db, full_report, stdin, files, cache, ignore, ignore_unpinned_req 'ignore_unpinned_requirements': ignore_unpinned_requirements} LOG.info('Calling the check function') - vulns, db_full = safety.check(session=ctx.obj.auth.client, packages=packages, db_mirror=db, cached=cache, ignore_vulns=ignore, + vulns, db_full = safety.check(session=ctx.obj.auth.client, packages=packages, db_mirror=db, cached=cache, + ignore_vulns=ignore, ignore_severity_rules=ignore_severity_rules, proxy=None, - include_ignored=True, is_env_scan=is_env_scan, telemetry=ctx.obj.config.telemetry_enabled, + include_ignored=True, is_env_scan=is_env_scan, + telemetry=ctx.obj.config.telemetry_enabled, params=params) LOG.debug('Vulnerabilities returned: %s', vulns) LOG.debug('full database returned is None: %s', db_full is None) @@ -452,6 +483,7 @@ def clean_license_command(f): """ Main entry point for validation. 
""" + @wraps(f) def inner(ctx, *args, **kwargs): # TODO: Remove this soon, for now it keeps a legacy behavior @@ -465,7 +497,9 @@ def inner(ctx, *args, **kwargs): return inner -@cli.command(cls=SafetyCLILegacyCommand, utility_command=True, help=CLI_LICENSES_COMMAND_HELP) +@cli.command(cls=SafetyCLILegacyCommand, + context_settings={CONTEXT_COMMAND_TYPE: CommandType.UTILITY}, + help=CLI_LICENSES_COMMAND_HELP) @proxy_options @auth_options(stage=False) @click.option("--db", default="", @@ -505,7 +539,8 @@ def license(ctx, db, output, cache, files): announcements = [] if not db: - announcements = safety.get_announcements(session=ctx.obj.auth.client, telemetry=ctx.obj.config.telemetry_enabled) + announcements = safety.get_announcements(session=ctx.obj.auth.client, + telemetry=ctx.obj.config.telemetry_enabled) output_report = SafetyFormatter(output=output).render_licenses(announcements, filtered_packages_licenses) @@ -513,37 +548,113 @@ def license(ctx, db, output, cache, files): print_deprecation_message("license", date(2024, 6, 1), new_command=None) -@cli.command(cls=SafetyCLILegacyCommand, utility_command=True, help=CLI_GENERATE_HELP) +@cli.command(cls=SafetyCLILegacyCommand, + context_settings={CONTEXT_COMMAND_TYPE: CommandType.UTILITY}, + help=CLI_GENERATE_HELP) @click.option("--path", default=".", help=CLI_GENERATE_PATH) +@click.option("--minimum-cvss-severity", default="critical", help=CLI_GENERATE_MINIMUM_CVSS_SEVERITY) @click.argument('name', required=True) @click.pass_context -def generate(ctx, name, path): +def generate(ctx, name, path, minimum_cvss_severity): """Create a boilerplate Safety CLI policy file NAME is the name of the file type to generate. Valid values are: policy_file """ - if name != 'policy_file': + if name != 'policy_file' and name != 'installation_policy': click.secho(f'This Safety version only supports "policy_file" generation. "{name}" is not supported.', fg='red', file=sys.stderr) sys.exit(EXIT_CODE_FAILURE) LOG.info('Running generate %s', name) + if name == 'policy_file': + generate_policy_file(name, path) + elif name == 'installation_policy': + generate_installation_policy(ctx, name, path, minimum_cvss_severity) + + +def generate_installation_policy(ctx, name, path, minimum_cvss_severity): + all_severities = [severity.name.lower() for severity in VulnerabilitySeverityLabels] + policy_severities = all_severities[all_severities.index(minimum_cvss_severity.lower()):] + policy_severities_set = set(policy_severities[:]) + + target = path + + ecosystems = [Ecosystem.PYTHON] + to_include = {file_type: paths for file_type, paths in ctx.obj.config.scan.include_files.items() if + file_type.ecosystem in ecosystems} + + # Initialize file finder + file_finder = FileFinder(target=target, ecosystems=ecosystems, + max_level=ctx.obj.config.scan.max_depth, + exclude=ctx.obj.config.scan.ignore, + include_files=to_include, + console=console) + + for handler in file_finder.handlers: + if handler.ecosystem: + wait_msg = "Fetching Safety's vulnerability database..." 
+ with console.status(wait_msg, spinner=DEFAULT_SPINNER): + handler.download_required_assets(ctx.obj.auth.client) + + wait_msg = "Scanning project directory" + with console.status(wait_msg, spinner=DEFAULT_SPINNER): + path, file_paths = file_finder.search() + + target_ecosystems = ", ".join([member.value for member in ecosystems]) + wait_msg = f"Analyzing {target_ecosystems} files and environments for security findings" + + config = ctx.obj.config + + vulnerabilities = [] + with console.status(wait_msg, spinner=DEFAULT_SPINNER) as status: + for path, analyzed_file in process_files(paths=file_paths, + config=config): + affected_specifications = analyzed_file.dependency_results.get_affected_specifications() + if any(affected_specifications): + for spec in affected_specifications: + for vuln in spec.vulnerabilities: + if (vuln.severity + and vuln.severity.cvssv3 + and vuln.severity.cvssv3.get("base_severity", "none").lower() in policy_severities_set): + vulnerabilities.append(vuln) + + policy = v3_0.Config( + installation=v3_0.Installation( + default_action=v3_0.InstallationAction.ALLOW, + allow=v3_0.AllowedInstallation( + packages = None, + vulnerabilities={ + vuln.vulnerability_id: v3_0.IgnoredVulnerability( + reason=f"Autogenerated policy for {vuln.package_name} package.", + expires=date.today() + timedelta(days=90)) + for vuln in vulnerabilities + }), + deny=v3_0.DeniedInstallation( + packages=None, + vulnerabilities=v3_0.DeniedVulnerability( + block_on_any_of=v3_0.DeniedVulnerabilityCriteria(cvss_severity=policy_severities) + ) + ) + ) + ) + + click.secho(policy.json(by_alias=True, exclude_none=True, indent=4)) + + +def generate_policy_file(name, path): path = Path(path) if not path.exists(): click.secho(f'The path "{path}" does not exist.', fg='red', file=sys.stderr) sys.exit(EXIT_CODE_FAILURE) - policy = path / '.safety-policy.yml' - default_config = ConfigModel() - try: default_config.save_policy_file(policy) LOG.debug('Safety created the policy file.') msg = f'A default Safety policy file has been generated! Review the file contents in the path {path} in the ' \ - 'file: .safety-policy.yml' + 'file: .safety-policy.yml' click.secho(msg, fg='green') except Exception as exc: if isinstance(exc, OSError): @@ -554,8 +665,10 @@ def generate(ctx, name, path): sys.exit(EXIT_CODE_FAILURE) -@cli.command(cls=SafetyCLILegacyCommand, utility_command=True) -@click.option("--path", default=".safety-policy.yml", help="Path where the generated file will be saved. Default: current directory") +@cli.command(cls=SafetyCLILegacyCommand, + context_settings={CONTEXT_COMMAND_TYPE: CommandType.UTILITY}) +@click.option("--path", default=".safety-policy.yml", + help="Path where the generated file will be saved. Default: current directory") @click.argument('name') @click.argument('version', required=False) @click.pass_context @@ -574,7 +687,9 @@ def validate(ctx, name, version, path): sys.exit(EXIT_CODE_FAILURE) if version not in ["3.0", "2.0", None]: - click.secho(f'Version "{version}" is not a valid value, allowed values are 3.0 and 2.0. Use --path to specify the target file.', fg='red', file=sys.stderr) + click.secho( + f'Version "{version}" is not a valid value, allowed values are 3.0 and 2.0. 
Use --path to specify the target file.', + fg='red', file=sys.stderr) sys.exit(EXIT_CODE_FAILURE) def fail_validation(e): @@ -622,7 +737,7 @@ def fail_validation(e): @cli.command(cls=SafetyCLILegacyCommand, help=CLI_CONFIGURE_HELP, - utility_command=True) + context_settings={CONTEXT_COMMAND_TYPE: CommandType.UTILITY}) @click.option("--proxy-protocol", "-pr", type=click.Choice(['http', 'https']), default='https', cls=DependentOption, required_options=['proxy_host'], help=CLI_CONFIGURE_PROXY_PROTOCOL_HELP) @@ -711,7 +826,8 @@ def configure(ctx, proxy_protocol, proxy_host, proxy_port, proxy_timeout, config.write(configfile) except Exception as e: if (isinstance(e, OSError) and e.errno == 2 or e is PermissionError) and save_to_system: - click.secho("Unable to save the configuration: writing to system-wide Safety configuration file requires admin privileges") + click.secho( + "Unable to save the configuration: writing to system-wide Safety configuration file requires admin privileges") else: click.secho(f"Unable to save the configuration, error: {e}") sys.exit(1) @@ -720,32 +836,36 @@ def configure(ctx, proxy_protocol, proxy_host, proxy_port, proxy_timeout, cli_app = typer.Typer(rich_markup_mode="rich", cls=SafetyCLISubGroup) typer.rich_utils.STYLE_HELPTEXT = "" + def print_check_updates_header(console): - VERSION = get_safety_version() + VERSION = get_version() console.print( f"Safety {VERSION} checking for Safety version and configuration updates:") + class Output(str, Enum): SCREEN = "screen" JSON = "json" + @cli_app.command( - cls=SafetyCLIUtilityCommand, - help=CLI_CHECK_UPDATES_HELP, - name="check-updates", epilog=DEFAULT_EPILOG, - context_settings={"allow_extra_args": True, - "ignore_unknown_options": True}, - ) + cls=SafetyCLICommand, + help=CLI_CHECK_UPDATES_HELP, + name="check-updates", epilog=DEFAULT_EPILOG, + context_settings={"allow_extra_args": True, + "ignore_unknown_options": True, + CONTEXT_COMMAND_TYPE: CommandType.UTILITY}, +) @handle_cmd_exception def check_updates(ctx: typer.Context, - version: Annotated[ - int, - typer.Option(min=1), - ] = 1, - output: Annotated[Output, - typer.Option( - help="The main output generated by Safety CLI.") - ] = Output.SCREEN): + version: Annotated[ + int, + typer.Option(min=1), + ] = 1, + output: Annotated[Output, + typer.Option( + help="The main output generated by Safety CLI.") + ] = Output.SCREEN): """ Check for Safety CLI version updates """ @@ -757,7 +877,7 @@ def check_updates(ctx: typer.Context, wait_msg = "Authenticating and checking for Safety CLI updates" - VERSION = get_safety_version() + VERSION = get_version() PYTHON_VERSION = platform.python_version() OS_TYPE = platform.system() @@ -792,7 +912,8 @@ def check_updates(ctx: typer.Context, console.print() console.print("[red]Safety is not authenticated, please first authenticate and try again.[/red]") console.print() - console.print("To authenticate, use the `auth` command: `safety auth login` Or for more help: `safety auth —help`") + console.print( + "To authenticate, use the `auth` command: `safety auth login` Or for more help: `safety auth —help`") sys.exit(1) if not data: @@ -827,15 +948,18 @@ def check_updates(ctx: typer.Context, f"If Safety was installed from a requirements file, update Safety to version {latest_available_version} in that requirements file." 
) console.print() - console.print(f"Pip: To install the updated version of Safety directly via pip, run: pip install safety=={latest_available_version}") + console.print( + f"Pip: To install the updated version of Safety directly via pip, run: pip install safety=={latest_available_version}") elif packaging_version.parse(latest_available_version) < packaging_version.parse(VERSION): # Notify user about downgrading - console.print(f"Latest stable version is {latest_available_version}. If you want to downgrade to this version, you can run: pip install safety=={latest_available_version}") + console.print( + f"Latest stable version is {latest_available_version}. If you want to downgrade to this version, you can run: pip install safety=={latest_available_version}") else: console.print("You are already using the latest stable version of Safety.") except InvalidVersion as invalid_version: LOG.exception(f'Invalid version format encountered: {invalid_version}') - console.print(f"Error: Invalid version format encountered for the latest available version: {latest_available_version}") + console.print( + f"Error: Invalid version format encountered for the latest available version: {latest_available_version}") console.print("Please report this issue or try again later.") if console.quiet: @@ -848,11 +972,14 @@ def check_updates(ctx: typer.Context, console.print_json(json.dumps(response)) -cli.add_command(typer.main.get_command(cli_app), "check-updates") -cli.add_command(typer.main.get_command(scan_project_app), "scan") -cli.add_command(typer.main.get_command(scan_system_app), "system-scan") +cli.add_command(typer.main.get_command(cli_app), name="check-updates") +cli.add_command(typer.main.get_command(init_app), name="init") +cli.add_command(typer.main.get_command(scan_project_app), name="scan") +cli.add_command(typer.main.get_command(scan_system_app), name="system-scan") +cli.add_command(typer.main.get_command(pip_app), name="pip") -cli.add_command(typer.main.get_command(auth_app), "auth") +cli.add_command(typer.main.get_command(auth_app), name="auth") +cli.add_command(typer.main.get_command(firewall_app), name="firewall") cli.add_command(alert) diff --git a/safety/cli_util.py b/safety/cli_util.py index 3579dbc0..82aa3e9b 100644 --- a/safety/cli_util.py +++ b/safety/cli_util.py @@ -1,42 +1,65 @@ -from collections import defaultdict import logging import subprocess import sys +from collections import defaultdict +from enum import Enum +from functools import wraps + from typing import Any, DefaultDict, Dict, List, Optional, Tuple, Union + import click -from functools import wraps import typer -from typer.core import TyperGroup, TyperCommand, MarkupMode from rich.console import Console from rich.table import Table from rich.text import Text +from typer.core import MarkupMode, TyperCommand, TyperGroup +from click.utils import make_str from safety.auth.constants import CLI_AUTH, MSG_NON_AUTHENTICATED from safety.auth.models import Auth -from safety.constants import MSG_NO_AUTHD_CICD_PROD_STG, MSG_NO_AUTHD_CICD_PROD_STG_ORG, MSG_NO_AUTHD_DEV_STG, MSG_NO_AUTHD_DEV_STG_ORG_PROMPT, MSG_NO_AUTHD_DEV_STG_PROMPT, MSG_NO_AUTHD_NOTE_CICD_PROD_STG_TPL, MSG_NO_VERIFIED_EMAIL_TPL +from safety.auth.cli_utils import inject_session +from safety.constants import ( + BETA_PANEL_DESCRIPTION_HELP, + MSG_NO_AUTHD_CICD_PROD_STG, + MSG_NO_AUTHD_CICD_PROD_STG_ORG, + MSG_NO_AUTHD_DEV_STG, + MSG_NO_AUTHD_DEV_STG_ORG_PROMPT, + MSG_NO_AUTHD_DEV_STG_PROMPT, + MSG_NO_AUTHD_NOTE_CICD_PROD_STG_TPL, + MSG_NO_VERIFIED_EMAIL_TPL, + 
CONTEXT_COMMAND_TYPE, + FeatureType +) from safety.scan.constants import CONSOLE_HELP_THEME - from safety.scan.models import ScanOutput +from safety.models import SafetyCLI -from .util import output_exception from .errors import SafetyError, SafetyException +from .util import output_exception LOG = logging.getLogger(__name__) +class CommandType(Enum): + MAIN = "main" + UTILITY = "utility" + BETA = "beta" + def custom_print_options_panel(name: str, params: List[Any], ctx: Any, console: Console) -> None: """ Print a panel with options. Args: name (str): The title of the panel. - params (List[Any]): The list of options to print. + params (List[Any]): The list of options/arguments to print. ctx (Any): The context object. markup_mode (str): The markup mode. console (Console): The console to print to. """ table = Table(title=name, show_lines=True) for param in params: - table.add_row(str(param.opts), param.help or "") + opts = getattr(param, 'opts', '') + help_text = getattr(param, 'help', '') + table.add_row(str(opts), help_text) console.print(table) def custom_print_commands_panel(name: str, commands: List[Any], console: Console) -> None: @@ -131,7 +154,6 @@ def pass_safety_cli_obj(func): def inner(ctx, *args, **kwargs): if not ctx.obj: - from .models import SafetyCLI ctx.obj = SafetyCLI() return func(ctx, *args, **kwargs) @@ -149,13 +171,17 @@ def pretty_format_help(obj: Union[click.Command, click.Group], ctx (click.Context): The Click context. markup_mode (MarkupMode): The markup mode. """ - from typer.rich_utils import highlighter, STYLE_USAGE_COMMAND, \ - ARGUMENTS_PANEL_TITLE, OPTIONS_PANEL_TITLE, \ - COMMANDS_PANEL_TITLE from rich.align import Align - from rich.padding import Padding from rich.console import Console + from rich.padding import Padding from rich.theme import Theme + from typer.rich_utils import ( + ARGUMENTS_PANEL_TITLE, + COMMANDS_PANEL_TITLE, + OPTIONS_PANEL_TITLE, + STYLE_USAGE_COMMAND, + highlighter, + ) console = Console() @@ -284,6 +310,7 @@ def pretty_format_help(obj: Union[click.Command, click.Group], def print_main_command_panels(*, name: str, + commands_type: CommandType, commands: List[click.Command], markup_mode: MarkupMode, console) -> None: @@ -297,13 +324,20 @@ def print_main_command_panels(*, console: The Rich console. 
""" from rich import box + from rich.panel import Panel from rich.table import Table from rich.text import Text - from rich.panel import Panel - from typer.rich_utils import STYLE_COMMANDS_TABLE_SHOW_LINES, STYLE_COMMANDS_TABLE_LEADING, \ - STYLE_COMMANDS_TABLE_BOX, STYLE_COMMANDS_TABLE_BORDER_STYLE, STYLE_COMMANDS_TABLE_ROW_STYLES, \ - STYLE_COMMANDS_TABLE_PAD_EDGE, STYLE_COMMANDS_TABLE_PADDING, STYLE_COMMANDS_PANEL_BORDER, \ - ALIGN_COMMANDS_PANEL + from typer.rich_utils import ( + ALIGN_COMMANDS_PANEL, + STYLE_COMMANDS_PANEL_BORDER, + STYLE_COMMANDS_TABLE_BORDER_STYLE, + STYLE_COMMANDS_TABLE_BOX, + STYLE_COMMANDS_TABLE_LEADING, + STYLE_COMMANDS_TABLE_PAD_EDGE, + STYLE_COMMANDS_TABLE_PADDING, + STYLE_COMMANDS_TABLE_ROW_STYLES, + STYLE_COMMANDS_TABLE_SHOW_LINES, + ) t_styles: Dict[str, Any] = { "show_lines": STYLE_COMMANDS_TABLE_SHOW_LINES, @@ -330,6 +364,17 @@ def print_main_command_panels(*, if console.size and console.size[0] > 80: console_width = console.size[0] + from rich.console import Group + + description = None + + if commands_type is CommandType.BETA: + description = Group( + Text(""), + Text(BETA_PANEL_DESCRIPTION_HELP), + Text("") + ) + commands_table.add_column(style="bold cyan", no_wrap=True, width=column_width, max_width=column_width) commands_table.add_column(width=console_width - column_width) @@ -338,7 +383,7 @@ def print_main_command_panels(*, for command in commands: helptext = command.short_help or command.help or "" command_name = command.name or "" - command_name_text = Text(command_name) + command_name_text = Text(command_name, style="") if commands_type is CommandType.BETA else Text(command_name) rows.append( [ command_name_text, @@ -351,9 +396,10 @@ def print_main_command_panels(*, for row in rows: commands_table.add_row(*row) if commands_table.row_count: + renderables = [description, commands_table] if description is not None else [Text(""), commands_table] + console.print( - Panel( - commands_table, + Panel(Group(*renderables), border_style=STYLE_COMMANDS_PANEL_BORDER, title=name, title_align=ALIGN_COMMANDS_PANEL, @@ -371,14 +417,17 @@ def format_main_help(obj: Union[click.Command, click.Group], ctx (click.Context): The Click context. markup_mode (MarkupMode): The markup mode. """ - from typer.rich_utils import highlighter, STYLE_USAGE_COMMAND, \ - ARGUMENTS_PANEL_TITLE, OPTIONS_PANEL_TITLE, \ - COMMANDS_PANEL_TITLE - from rich.align import Align - from rich.padding import Padding from rich.console import Console + from rich.padding import Padding from rich.theme import Theme + from typer.rich_utils import ( + ARGUMENTS_PANEL_TITLE, + COMMANDS_PANEL_TITLE, + OPTIONS_PANEL_TITLE, + STYLE_USAGE_COMMAND, + highlighter, + ) typer_console = Console() @@ -404,31 +453,32 @@ def format_main_help(obj: Union[click.Command, click.Group], ) if isinstance(obj, click.MultiCommand): - UTILITY_COMMANDS_PANEL_TITLE = "Commands cont." 
- panel_to_commands: DefaultDict[str, List[click.Command]] = defaultdict(list) + UTILITY_COMMANDS_PANEL_TITLE = "Utility commands" + BETA_COMMANDS_PANEL_TITLE = "Beta Commands :rocket:" + + COMMANDS_PANEL_TITLE_CONSTANTS = { + CommandType.MAIN: COMMANDS_PANEL_TITLE, + CommandType.UTILITY: UTILITY_COMMANDS_PANEL_TITLE, + CommandType.BETA: BETA_COMMANDS_PANEL_TITLE + } + + panel_to_commands: Dict[CommandType, List[click.Command]] = {} + + # Keep order of panels + for command_type in COMMANDS_PANEL_TITLE_CONSTANTS.keys(): + panel_to_commands[command_type] = [] + for command_name in obj.list_commands(ctx): command = obj.get_command(ctx, command_name) if command and not command.hidden: - panel_name = ( - UTILITY_COMMANDS_PANEL_TITLE if command.utility_command else COMMANDS_PANEL_TITLE - ) - panel_to_commands[panel_name].append(command) + command_type = command.context_settings.get(CONTEXT_COMMAND_TYPE, CommandType.MAIN) + panel_to_commands[command_type].append(command) - # Print each command group panel - default_commands = panel_to_commands.get(COMMANDS_PANEL_TITLE, []) - print_main_command_panels( - name=COMMANDS_PANEL_TITLE, - commands=default_commands, - markup_mode=markup_mode, - console=console, - ) - for panel_name, commands in panel_to_commands.items(): - if panel_name == COMMANDS_PANEL_TITLE: - # Already printed above - continue + for command_type, commands in panel_to_commands.items(): print_main_command_panels( - name=panel_name, + name=COMMANDS_PANEL_TITLE_CONSTANTS[command_type], + commands_type=command_type, commands=commands, markup_mode=markup_mode, console=console, @@ -503,8 +553,8 @@ def process_auth_status_not_ready(console, auth: Auth, ctx: typer.Context) -> No auth (Auth): The Auth object. ctx (typer.Context): The Typer context. """ - from safety_schemas.models import Stage from rich.prompt import Confirm, Prompt + from safety_schemas.models import Stage if not auth.client or not auth.client.is_using_auth_credentials(): @@ -566,26 +616,28 @@ def process_auth_status_not_ready(console, auth: Auth, ctx: typer.Context) -> No console.print(MSG_NON_AUTHENTICATED) sys.exit(1) -class UtilityCommandMixin: - """ - Mixin to add utility command functionality. - """ - def __init__(self, *args: Any, **kwargs: Any) -> None: - """ - Initialize the UtilityCommandMixin. +class CustomContext(click.Context): + def __init__( + self, + command: "Command", + parent: Optional["Context"] = None, + command_type: CommandType = CommandType.MAIN, + feature_type: Optional[FeatureType] = None, + **kwargs + ) -> None: + self.command_type = command_type + self.feature_type = feature_type + super().__init__(command, parent=parent, **kwargs) + - Args: - *args: Variable length argument list. - **kwargs: Arbitrary keyword arguments. - """ - self.utility_command = kwargs.pop('utility_command', False) - super().__init__(*args, **kwargs) -class SafetyCLISubGroup(UtilityCommandMixin, TyperGroup): +class SafetyCLISubGroup(TyperGroup): """ Custom TyperGroup with additional functionality for Safety CLI. """ + context_class = CustomContext + def format_help(self, ctx: click.Context, formatter: click.HelpFormatter) -> None: """ Format help message with rich formatting. @@ -629,10 +681,13 @@ def command( """ super().command(*args, **kwargs) -class SafetyCLICommand(UtilityCommandMixin, TyperCommand): +class SafetyCLICommand(TyperCommand): """ Custom TyperCommand with additional functionality for Safety CLI. 
""" + + context_class = CustomContext + def format_help(self, ctx: click.Context, formatter: click.HelpFormatter) -> None: """ Format help message with rich formatting. @@ -660,25 +715,43 @@ def format_usage(self, ctx: click.Context, formatter: click.HelpFormatter) -> No formatter.write_usage(command_path, " ".join(pieces)) -class SafetyCLIUtilityCommand(TyperCommand): +class SafetyCLILegacyGroup(click.Group): """ - Custom TyperCommand designated as a utility command. + Custom Click Group to handle legacy command-line arguments. """ - def __init__(self, *args: Any, **kwargs: Any) -> None: - """ - Initialize the SafetyCLIUtilityCommand. - Args: - *args: Variable length argument list. - **kwargs: Arbitrary keyword arguments. - """ - self.utility_command = True + context_class = CustomContext + + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.all_commands = {} -class SafetyCLILegacyGroup(UtilityCommandMixin, click.Group): - """ - Custom Click Group to handle legacy command-line arguments. - """ + def __call__(self, *args, **kwargs): + return super().__call__(*args, **kwargs) + + def add_command(self, cmd, name = None) -> None: + super().add_command(cmd, name) + + name = name or cmd.name + self.all_commands[name] = cmd + + def parse_args(self, ctx: click.Context, args: List[str]) -> List[str]: + parsed_args = super().parse_args(ctx, args) + + args = ctx.args + + # Workaround for legacy check options, that now are global options + subcommand_args = set(args) + PROXY_HOST_OPTIONS = set(["--proxy-host", "-ph"]) + if "check" in ctx.protected_args or "license" in ctx.protected_args and (bool(PROXY_HOST_OPTIONS.intersection(subcommand_args) or "--key" in subcommand_args)) : + proxy_options, key = self.parse_legacy_args(args) + if proxy_options: + ctx.params.update(proxy_options) + + if key: + ctx.params.update({"key": key}) + + return parsed_args def parse_legacy_args(self, args: List[str]) -> Tuple[Optional[Dict[str, str]], Optional[str]]: """ @@ -710,6 +783,22 @@ def parse_legacy_args(self, args: List[str]) -> Tuple[Optional[Dict[str, str]], proxy = options if options['proxy_host'] else None return proxy, key + def get_filtered_commands(self, ctx: click.Context) -> Dict[str, click.Command]: + from safety.auth.utils import initialize + + initialize(ctx, refresh=False) + + # Filter commands here: + from .constants import CONTEXT_FEATURE_TYPE + + disabled_features = [ + feature_type + for feature_type in FeatureType if not getattr(ctx.obj, feature_type.attr_name, False) + ] + + return {k: v for k, v in self.commands.items() if v.context_settings.get(CONTEXT_FEATURE_TYPE, None) not in disabled_features} + + def invoke(self, ctx: click.Context) -> None: """ Invoke the command, handling legacy arguments. @@ -717,23 +806,34 @@ def invoke(self, ctx: click.Context) -> None: Args: ctx (click.Context): Click context. 
""" - args = ctx.args - - # Workaround for legacy check options, that now are global options - subcommand_args = set(args) - PROXY_HOST_OPTIONS = set(["--proxy-host", "-ph"]) - if "check" in ctx.protected_args or "license" in ctx.protected_args and (bool(PROXY_HOST_OPTIONS.intersection(subcommand_args) or "--key" in subcommand_args)) : - proxy_options, key = self.parse_legacy_args(args) - if proxy_options: - ctx.params.update(proxy_options) - - if key: - ctx.params.update({"key": key}) + session_kwargs = { + 'ctx': ctx, + 'proxy_protocol': ctx.params.pop('proxy_protocol', None), + 'proxy_host': ctx.params.pop('proxy_host', None), + 'proxy_port': ctx.params.pop('proxy_port', None), + 'key': ctx.params.pop('key', None), + 'stage': ctx.params.pop('stage', None), + } + invoked_command = make_str(next(iter(ctx.protected_args), "")) + inject_session(**session_kwargs, invoked_command=invoked_command) + + # call initialize if the --key is used. + if session_kwargs['key']: + from safety.auth.utils import initialize + initialize(ctx, refresh=True) + + self.commands = self.get_filtered_commands(ctx) # Now, invoke the original behavior super(SafetyCLILegacyGroup, self).invoke(ctx) + def list_commands(self, ctx: click.Context) -> List[str]: + """Override click.Group.list_commands with custom filtering""" + self.commands = self.get_filtered_commands(ctx) + + return super().list_commands(ctx) + def format_help(self, ctx: click.Context, formatter: click.HelpFormatter) -> None: """ Format help message with rich formatting. @@ -749,10 +849,12 @@ def format_help(self, ctx: click.Context, formatter: click.HelpFormatter) -> Non else: pretty_format_help(self, ctx, markup_mode="rich") -class SafetyCLILegacyCommand(UtilityCommandMixin, click.Command): +class SafetyCLILegacyCommand(click.Command): """ Custom Click Command to handle legacy command-line arguments. """ + context_class = CustomContext + def format_help(self, ctx: click.Context, formatter: click.HelpFormatter) -> None: """ Format help message with rich formatting. @@ -775,7 +877,7 @@ def handle_cmd_exception(func): The wrapped function. """ @wraps(func) - def inner(ctx, output: Optional[ScanOutput], *args, **kwargs): + def inner(ctx, output: Optional[ScanOutput] = None, *args, **kwargs): if output: kwargs.update({"output": output}) diff --git a/safety/constants.py b/safety/constants.py index 9e25d335..8c8f8aa5 100644 --- a/safety/constants.py +++ b/safety/constants.py @@ -3,7 +3,7 @@ import os from enum import Enum from pathlib import Path -from typing import Optional +from typing import Optional, Union JSON_SCHEMA_VERSION = '2.0.0' @@ -55,6 +55,8 @@ def get_user_dir() -> Path: CACHE_FILE_DIR = USER_CONFIG_DIR / f"{JSON_SCHEMA_VERSION.replace('.', '')}" DB_CACHE_FILE = CACHE_FILE_DIR / "cache.json" +PIP_LOCK = USER_CONFIG_DIR / "pip.lock" + CONFIG_FILE_NAME = "config.ini" CONFIG_FILE_SYSTEM = SYSTEM_CONFIG_DIR / CONFIG_FILE_NAME if SYSTEM_CONFIG_DIR else None CONFIG_FILE_USER = USER_CONFIG_DIR / CONFIG_FILE_NAME @@ -76,8 +78,32 @@ class URLSettings(Enum): AUTH_SERVER_URL = f'https://auth.{DEFAULT_DOMAIN}' SAFETY_PLATFORM_URL = f"https://platform.{DEFAULT_DOMAIN}" +class FeatureType(Enum): + """ + Defines server-controlled features for dynamic feature management. + + Each enum value represents a toggleable feature controlled through + server-side configuration, enabling gradual rollouts to different user + segments. Features are cached during CLI initialization. 
-def get_config_setting(name: str) -> Optional[str]: + History: + Created to support progressive feature rollouts and A/B testing without + disturbing users. + """ + FIREWALL = "firewall" + PLATFORM = "platform" + + @property + def config_key(self) -> str: + """For JSON/config lookup e.g. 'feature-a-enabled'""" + return f"{self.name.lower()}-enabled" + + @property + def attr_name(self) -> str: + """For Python attribute access e.g. 'feature_a_enabled'""" + return f"{self.name.lower()}_enabled" + +def get_config_setting(name: str, default=None) -> Optional[str]: """ Get the configuration setting from the config file or defaults. @@ -90,8 +116,6 @@ def get_config_setting(name: str) -> Optional[str]: config = configparser.ConfigParser() config.read(CONFIG) - default = None - if name in [setting.name for setting in URLSettings]: default = URLSettings[name] @@ -113,7 +137,7 @@ def get_config_setting(name: str) -> Optional[str]: PLATFORM_API_PROJECT_UPLOAD_SCAN_ENDPOINT = f"{PLATFORM_API_BASE_URL}/scan" PLATFORM_API_REQUIREMENTS_UPLOAD_SCAN_ENDPOINT = f"{PLATFORM_API_BASE_URL}/process_files" PLATFORM_API_CHECK_UPDATES_ENDPOINT = f"{PLATFORM_API_BASE_URL}/versions-and-configs" -PLATFORM_API_INITIALIZE_SCAN_ENDPOINT = f"{PLATFORM_API_BASE_URL}/initialize-scan" +PLATFORM_API_INITIALIZE_ENDPOINT = f"{PLATFORM_API_BASE_URL}/initialize" API_MIRRORS = [ @@ -176,7 +200,7 @@ def get_config_setting(name: str) -> Optional[str]: EXIT_CODE_OK = 0 EXIT_CODE_FAILURE = 1 EXIT_CODE_VULNERABILITIES_FOUND = 64 -EXIT_CODE_INVALID_API_KEY = 65 +EXIT_CODE_INVALID_AUTH_CREDENTIAL = 65 EXIT_CODE_TOO_MANY_REQUESTS = 66 EXIT_CODE_UNABLE_TO_LOAD_LOCAL_VULNERABILITY_DB = 67 EXIT_CODE_UNABLE_TO_FETCH_VULNERABILITY_DB = 68 @@ -187,3 +211,8 @@ def get_config_setting(name: str) -> Optional[str]: #For Depreciated Messages BAR_LINE = "+===========================================================================================================================================================================================+" + +BETA_PANEL_DESCRIPTION_HELP = "These commands are experimental and part of our commitment to delivering innovative features. As we refine functionality, they may be significantly altered or, in rare cases, removed without prior notice. We welcome your feedback and encourage cautious use." + +CONTEXT_COMMAND_TYPE = "command_type" +CONTEXT_FEATURE_TYPE = "feature_type" \ No newline at end of file diff --git a/safety/errors.py b/safety/errors.py index 65a55652..1dd31e2d 100644 --- a/safety/errors.py +++ b/safety/errors.py @@ -3,7 +3,7 @@ from safety.constants import ( EXIT_CODE_EMAIL_NOT_VERIFIED, EXIT_CODE_FAILURE, - EXIT_CODE_INVALID_API_KEY, + EXIT_CODE_INVALID_AUTH_CREDENTIAL, EXIT_CODE_INVALID_PROVIDED_REPORT, EXIT_CODE_INVALID_REQUIREMENT, EXIT_CODE_MALFORMED_DB, @@ -200,7 +200,7 @@ def get_exit_code(self) -> int: Returns: int: The exit code. 
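# Illustrative sketch (not part of the diff): how the FeatureType properties
# added above are expected to be used. `server_config` is a hypothetical
# stand-in for the cached server-side feature configuration.
from safety.constants import FeatureType

server_config = {"firewall-enabled": True, "platform-enabled": False}

for feature in FeatureType:
    enabled = server_config.get(feature.config_key, False)
    # FIREWALL -> config key "firewall-enabled", attribute "firewall_enabled";
    # get_filtered_commands() reads the attribute form off ctx.obj to decide
    # which commands stay visible.
    print(feature.name, feature.config_key, feature.attr_name, enabled)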
""" - return EXIT_CODE_INVALID_API_KEY + return EXIT_CODE_INVALID_AUTH_CREDENTIAL class NotVerifiedEmailError(SafetyError): """ diff --git a/safety/firewall/__init__.py b/safety/firewall/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/safety/firewall/command.py b/safety/firewall/command.py new file mode 100644 index 00000000..7888f73e --- /dev/null +++ b/safety/firewall/command.py @@ -0,0 +1,117 @@ +import logging +import sys + +from rich.prompt import Prompt +# TODO: refactor this import and the related code +# For now, let's keep it as is +from safety.scan.constants import DEFAULT_EPILOG +from ..cli_util import CommandType, FeatureType, SafetyCLICommand, \ + SafetyCLISubGroup, handle_cmd_exception, pass_safety_cli_obj +import typer + + +from safety.console import main_console as console + +from ..constants import CONTEXT_COMMAND_TYPE, CONTEXT_FEATURE_TYPE, EXIT_CODE_OK + +from .constants import FIREWALL_HELP, MSG_FEEDBACK, MSG_REQ_FILE_LINE, \ + MSG_UNINSTALL_EXPLANATION, MSG_UNINSTALL_SUCCESS, \ + UNINSTALL_CMD_NAME, UNINSTALL_HELP, FIREWALL_CMD_NAME, \ + MSG_UNINSTALL_PIP_CONFIG, MSG_UNINSTALL_PIP_ALIAS + +from ..tool.main import reset_system +from ..tool.interceptors import create_interceptor + +try: + from typing import Annotated +except ImportError: + from typing_extensions import Annotated + +firewall_app = typer.Typer(rich_markup_mode= "rich", cls=SafetyCLISubGroup, + name=FIREWALL_CMD_NAME) + + +LOG = logging.getLogger(__name__) + + +init_app = typer.Typer(rich_markup_mode= "rich", cls=SafetyCLISubGroup) + +@firewall_app.callback(cls=SafetyCLISubGroup, + help=FIREWALL_HELP, + epilog=DEFAULT_EPILOG, + context_settings={ + "allow_extra_args": True, + "ignore_unknown_options": True, + CONTEXT_COMMAND_TYPE: CommandType.BETA, + CONTEXT_FEATURE_TYPE: FeatureType.FIREWALL + }) +@pass_safety_cli_obj +def firewall(ctx: typer.Context) -> None: + """ + Main callback for the firewall commands. + + Args: + ctx (typer.Context): The Typer context object. + """ + LOG.info('firewall callback started') + + +@firewall_app.command( + cls=SafetyCLICommand, + name=UNINSTALL_CMD_NAME, + help=UNINSTALL_HELP, + options_metavar="[OPTIONS]", + context_settings={ + "allow_extra_args": True, + "ignore_unknown_options": True, + CONTEXT_COMMAND_TYPE: CommandType.BETA, + CONTEXT_FEATURE_TYPE: FeatureType.FIREWALL + }, +) +@handle_cmd_exception +def uninstall(ctx: typer.Context): + console.print() + console.print(MSG_UNINSTALL_EXPLANATION) + + console.print() + prompt = "Uninstall?" + should_uninstall = Prompt.ask(prompt=prompt, choices=["y", "n"], + default="y", show_default=True, + console=console).lower() == 'y' + + if not should_uninstall: + sys.exit(EXIT_CODE_OK) + + console.print() + console.print(MSG_UNINSTALL_PIP_CONFIG) + # TODO: Make it robust. 
The reset per tool should be included in remove + # interceptors + reset_system() + + # TODO: support reset project files + + console.print(MSG_UNINSTALL_PIP_ALIAS) + interceptor = create_interceptor() + interceptor.remove_interceptors() + + console.print() + console.print(MSG_UNINSTALL_SUCCESS) + + console.print() + console.print(MSG_REQ_FILE_LINE) + + console.print() + + # TODO: Ask for feedback + # console.print(MSG_FEEDBACK) + + # console.print() + # prompt = "Feedback (or enter to exit)" + # feedback = Prompt.ask(prompt) + + # if feedback: + # console.print() + # # TODO: send feedback to the server + # console.print("Thank you for your feedback!") + + diff --git a/safety/firewall/constants.py b/safety/firewall/constants.py new file mode 100644 index 00000000..00f8a506 --- /dev/null +++ b/safety/firewall/constants.py @@ -0,0 +1,19 @@ +MSG_UNINSTALL_EXPLANATION = "Would you like to uninstall Safety Firewall on this machine? Doing so will mean you are no longer protected from malicious or vulnerable packages." +MSG_UNINSTALL_SUCCESS = "Safety Firewall has been uninstalled from your machine. Note that your individual requirements files may still reference Safety Firewall. You can remove these references by removing the following line from your requirements files:" +MSG_REQ_FILE_LINE = "-i https://pkgs.safetycli.com/repository/public/pypi/simple/" + +MSG_FEEDBACK = "We're sorry to see you go. If you have any feedback on how we can do better, we'd love to hear it. Otherwise hit enter to exit." + + + +UNINSTALL_HELP = "Uninstall Safety Firewall from your machine." + + +FIREWALL_CMD_NAME = "firewall" +UNINSTALL_CMD_NAME = "uninstall" + + +FIREWALL_HELP = "[BETA] Manage Safety Firewall settings." + +MSG_UNINSTALL_PIP_CONFIG = "Removing global configuration for pip from: ~/.config/pip/pip.conf" +MSG_UNINSTALL_PIP_ALIAS = "Removing pip alias to safety from ~/.profile" diff --git a/safety/init/__init__.py b/safety/init/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/safety/init/command.py b/safety/init/command.py new file mode 100644 index 00000000..9fc258f4 --- /dev/null +++ b/safety/init/command.py @@ -0,0 +1,101 @@ +from pathlib import Path + +from rich.prompt import Prompt +from ..cli_util import CommandType, FeatureType, SafetyCLICommand, \ + SafetyCLISubGroup, handle_cmd_exception +import typer +import os + + +from safety.init.constants import PROJECT_INIT_CMD_NAME, PROJECT_INIT_HELP, PROJECT_INIT_DIRECTORY_HELP +from safety.init.main import create_project +from safety.console import main_console as console +from ..scan.command import scan +from ..scan.models import ScanOutput +from ..tool.main import configure_system, configure_local_directory, has_local_tool_files, configure_alias + +from ..constants import CONTEXT_COMMAND_TYPE, CONTEXT_FEATURE_TYPE + +try: + from typing import Annotated +except ImportError: + from typing_extensions import Annotated + +init_app = typer.Typer(rich_markup_mode= "rich", cls=SafetyCLISubGroup) + +@init_app.command( + cls=SafetyCLICommand, + help=PROJECT_INIT_HELP, + name=PROJECT_INIT_CMD_NAME, + options_metavar="[OPTIONS]", + context_settings={ + "allow_extra_args": True, + "ignore_unknown_options": True, + CONTEXT_COMMAND_TYPE: CommandType.BETA, + CONTEXT_FEATURE_TYPE: FeatureType.FIREWALL + }, +) +@handle_cmd_exception +def init(ctx: typer.Context, + directory: Annotated[ + Path, + typer.Argument( + exists=True, + file_okay=False, + dir_okay=True, + writable=False, + readable=True, + resolve_path=True, + show_default=False, + 
help=PROJECT_INIT_DIRECTORY_HELP + ), + ] = Path(".")): + + do_init(ctx, directory, False) + + +def do_init(ctx: typer.Context, directory: Path, prompt_user: bool = True): + project_dir = directory if os.path.isabs(directory) else os.path.join(os.getcwd(), directory) + create_project(ctx, console, Path(project_dir)) + + answer = 'y' if not prompt_user else None + if prompt_user: + console.print( + "Safety prevents vulnerable or malicious packages from being installed on your computer. We do this by wrapping your package manager.") + prompt = "Do you want to enable proactive malicious package prevention?" + answer = Prompt.ask(prompt=prompt, choices=["y", "n"], + default="y", show_default=True, console=console).lower() + + if answer == 'y': + configure_system() + + if prompt_user: + prompt = "Do you want to alias pip to Safety?" + answer = Prompt.ask(prompt=prompt, choices=["y", "n"], + default="y", show_default=True, console=console).lower() + + if answer == 'y': + configure_alias() + + if has_local_tool_files(project_dir): + if prompt_user: + prompt = "Do you want to enable proactive malicious package prevention for any project in working directory?" + answer = Prompt.ask(prompt=prompt, choices=["y", "n"], + default="y", show_default=True, console=console).lower() + + if answer == 'y': + configure_local_directory(project_dir) + + if prompt_user: + prompt = "It looks like your current directory contains a requirements.txt file. Would you like Safety to scan it?" + answer = Prompt.ask(prompt=prompt, choices=["y", "n"], + default="y", show_default=True, console=console).lower() + + if answer == 'y': + ctx.command.name = "scan" + ctx.params = { + "target": directory, + "output": ScanOutput.SCREEN, + "policy_file_path": None + } + scan(ctx=ctx, target=directory, output=ScanOutput.SCREEN, policy_file_path=None) diff --git a/safety/init/constants.py b/safety/init/constants.py new file mode 100644 index 00000000..fafba1f5 --- /dev/null +++ b/safety/init/constants.py @@ -0,0 +1,6 @@ +# Project options +PROJECT_INIT_CMD_NAME = "init" +PROJECT_INIT_HELP = "[BETA] Used to install Safety Firewall globally, or to initialize a project in the current directory."\ +"\nExample: safety init" +PROJECT_INIT_DIRECTORY_HELP = "[BETA] Defines a directory for creating a new project. (default: current directory)\n\n" \ + "[bold]Example: safety init /path/to/project[/bold]" diff --git a/safety/init/main.py b/safety/init/main.py new file mode 100644 index 00000000..9038d70d --- /dev/null +++ b/safety/init/main.py @@ -0,0 +1,274 @@ +import logging +import uuid +import typer +from rich.console import Console + +from .models import UnverifiedProjectModel + +import configparser +from pathlib import Path +from safety_schemas.models import ProjectModel, Stage +from safety.scan.util import GIT +from ..auth.utils import SafetyAuthSession + +from typing import Optional +from safety.scan.render import ( + print_wait_project_verification, + prompt_project_id, + prompt_link_project, +) + +PROJECT_CONFIG = ".safety-project.ini" +PROJECT_CONFIG_SECTION = "project" +PROJECT_CONFIG_ID = "id" +PROJECT_CONFIG_URL = "url" +PROJECT_CONFIG_NAME = "name" + + +LOG = logging.getLogger(__name__) + + +def check_project( + ctx: typer.Context, + session: SafetyAuthSession, + console: Console, + unverified_project: UnverifiedProjectModel, + git_origin: Optional[str], + ask_project_id: bool = False, +) -> dict: + """ + Check the project against the session and stage, verifying the project if necessary. 
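# Hedged usage sketch for the `init` command wired above (not part of the diff):
#
#   safety init                    # initialize the current directory
#   safety init /path/to/project   # initialize an existing directory
#
# do_init() creates or links the Safety project, optionally enables the
# system-wide package-prevention configuration and the pip alias, and can
# finish by handing off to `safety scan` on the same directory.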
+ + Args: + console: The console for output. + ctx (typer.Context): The context of the Typer command. + session (SafetyAuthSession): The authentication session. + unverified_project (UnverifiedProjectModel): The unverified project model. + stage (Stage): The current stage. + git_origin (Optional[str]): The Git origin URL. + ask_project_id (bool): Whether to prompt for the project ID. + + Returns: + dict: The result of the project check. + """ + stage = ctx.obj.auth.stage + source = ctx.obj.telemetry.safety_source if ctx.obj.telemetry else None + data = {"scan_stage": stage, "safety_source": source} + + PRJ_SLUG_KEY = "project_slug" + PRJ_SLUG_SOURCE_KEY = "project_slug_source" + PRJ_GIT_ORIGIN_KEY = "git_origin" + + if git_origin: + data[PRJ_GIT_ORIGIN_KEY] = git_origin + + if unverified_project.id: + data[PRJ_SLUG_KEY] = unverified_project.id + data[PRJ_SLUG_SOURCE_KEY] = ".safety-project.ini" + elif not git_origin or ask_project_id: + default_id = unverified_project.project_path.parent.name + + if not default_id: + # Sometimes the parent directory is empty, so we generate + # a random ID + default_id = str(uuid.uuid4())[:10] + + unverified_project.id = prompt_project_id(console, default_id) + data[PRJ_SLUG_KEY] = unverified_project.id + data[PRJ_SLUG_SOURCE_KEY] = "user" + + status = print_wait_project_verification( + console, + data[PRJ_SLUG_KEY] if data.get(PRJ_SLUG_KEY, None) else "-", + (session.check_project, data), + on_error_delay=1, + ) + + return status + + +def verify_project( + console: Console, + ctx: typer.Context, + session: SafetyAuthSession, + unverified_project: UnverifiedProjectModel, + stage: Stage, + git_origin: Optional[str], +): + """ + Verify the project, linking it if necessary and saving the verified project information. + + Args: + console: The console for output. + ctx (typer.Context): The context of the Typer command. + session (SafetyAuthSession): The authentication session. + unverified_project (UnverifiedProjectModel): The unverified project model. + stage (Stage): The current stage. + git_origin (Optional[str]): The Git origin URL. + """ + + verified_prj = False + + link_prj = True + + while not verified_prj: + result = check_project( + ctx, + session, + console, + unverified_project, + git_origin, + ask_project_id=not link_prj, + ) + + unverified_slug = result.get("slug") + + project = result.get("project", None) + user_confirm = result.get("user_confirm", False) + + if user_confirm: + if project and link_prj: + prj_name = project.get("name", None) + prj_admin_email = project.get("admin", None) + + link_prj = prompt_link_project( + prj_name=prj_name, prj_admin_email=prj_admin_email, console=console + ) + + if not link_prj: + continue + + verified_prj = print_wait_project_verification( + console, + unverified_slug, + (session.project, {"project_id": unverified_slug}), + on_error_delay=1, + ) + + if ( + verified_prj + and isinstance(verified_prj, dict) + and verified_prj.get("slug", None) + ): + save_verified_project( + ctx, + verified_prj["slug"], + verified_prj.get("name", None), + unverified_project.project_path, + verified_prj.get("url", None), + ) + else: + verified_prj = False + + +def load_unverified_project_from_config(project_root: Path) -> UnverifiedProjectModel: + """ + Loads an unverified project from the configuration file located at the project root. + + Args: + project_root (Path): The root directory of the project. + + Returns: + UnverifiedProjectModel: An instance of UnverifiedProjectModel. 
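# Hedged sketch (values hypothetical) of the payload check_project() builds
# above before calling session.check_project:
#
# {
#     "scan_stage": "development",
#     "safety_source": "cli",
#     "git_origin": "git@github.com:org/repo.git",
#     "project_slug": "my-project",
#     "project_slug_source": ".safety-project.ini",  # or "user" when prompted
# }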
+ """ + config = configparser.ConfigParser() + project_path = project_root / PROJECT_CONFIG + config.read(project_path) + id = config.get(PROJECT_CONFIG_SECTION, PROJECT_CONFIG_ID, fallback=None) + url = config.get(PROJECT_CONFIG_SECTION, PROJECT_CONFIG_URL, fallback=None) + name = config.get(PROJECT_CONFIG_SECTION, PROJECT_CONFIG_NAME, fallback=None) + created = True + if id: + created = False + + return UnverifiedProjectModel( + id=id, url_path=url, name=name, project_path=project_path, created=created + ) + + +def save_verified_project( + ctx: typer.Context, + slug: str, + name: Optional[str], + project_path: Path, + url_path: Optional[str], +): + """ + Save the verified project information to the context and project info file. + + Args: + ctx (typer.Context): The context of the Typer command. + slug (str): The project slug. + name (Optional[str]): The project name. + project_path (Path): The project path. + url_path (Optional[str]): The project URL path. + """ + ctx.obj.project = ProjectModel( + id=slug, name=name, project_path=project_path, url_path=url_path + ) + + save_project_info(project=ctx.obj.project, project_path=project_path) + + +def save_project_info(project: ProjectModel, project_path: Path) -> bool: + """ + Saves the project information to the configuration file. + + Args: + project (ProjectModel): The ProjectModel object containing project + information. + project_path (Path): The path to the configuration file. + + Returns: + bool: True if the project information was saved successfully, False + otherwise. + """ + config = configparser.ConfigParser() + config.read(project_path) + + if PROJECT_CONFIG_SECTION not in config.sections(): + config[PROJECT_CONFIG_SECTION] = {} + + config[PROJECT_CONFIG_SECTION][PROJECT_CONFIG_ID] = project.id + if project.url_path: + config[PROJECT_CONFIG_SECTION][PROJECT_CONFIG_URL] = project.url_path + if project.name: + config[PROJECT_CONFIG_SECTION][PROJECT_CONFIG_NAME] = project.name + + try: + with open(project_path, "w") as configfile: + config.write(configfile) + except Exception: + LOG.exception("Error saving project info") + return False + + return True + + +def create_project( + ctx: typer.Context, console: Console, target: Path +): + """ + Loads existing project from the specified target locations or creates a new project. + + Args: + ctx: The CLI context + session: The authentication session + console: The console object + target (Path): The target location + """ + # Load .safety-project.ini + unverified_project = load_unverified_project_from_config(project_root=target) + + stage = ctx.obj.auth.stage + session = ctx.obj.auth.client + git_data = GIT(root=target).build_git_data() + origin = None + + if git_data: + origin = git_data.origin + + if ctx.obj.platform_enabled: + verify_project(console, ctx, session, unverified_project, stage, origin) + else: + console.print("Project creation is not supported for your account.") diff --git a/safety/init/models.py b/safety/init/models.py new file mode 100644 index 00000000..8e334933 --- /dev/null +++ b/safety/init/models.py @@ -0,0 +1,17 @@ +from pathlib import Path +from typing import Optional + +from pydantic.dataclasses import dataclass + + +@dataclass +class UnverifiedProjectModel: + """ + Data class representing an unverified project model. 
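# Hedged sketch of the .safety-project.ini file that save_project_info()
# writes and load_unverified_project_from_config() reads back (values
# hypothetical, keys per the PROJECT_CONFIG_* constants above):
#
# [project]
# id = my-project-slug
# url = /projects/my-project-slug
# name = My Project
#
# When an id is already present the loaded project is treated as existing
# (created=False); otherwise it is flagged as newly created.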
+ """ + + id: Optional[str] + project_path: Path + created: bool + name: Optional[str] = None + url_path: Optional[str] = None diff --git a/safety/meta.py b/safety/meta.py new file mode 100644 index 00000000..26d23b41 --- /dev/null +++ b/safety/meta.py @@ -0,0 +1,20 @@ +from importlib.metadata import PackageNotFoundError, version +import logging +from typing import Optional + + +LOG = logging.getLogger(__name__) + + +def get_version() -> Optional[str]: + """ + Get the version of the Safety package. + + Returns: + Optional[str]: The Safety version if found, otherwise None. + """ + try: + return version("safety") + except PackageNotFoundError: + LOG.exception("Unable to get Safety version.") + return None diff --git a/safety/models/__init__.py b/safety/models/__init__.py new file mode 100644 index 00000000..e2799c0f --- /dev/null +++ b/safety/models/__init__.py @@ -0,0 +1,18 @@ +from .obj import SafetyCLI +from .requirements import is_pinned_requirement +from .vulnerabilities import Vulnerability, CVE, Severity, Fix, \ + SafetyRequirement, Package, SafetyEncoder, RequirementFile + + +__all__ = [ + 'Package', + 'SafetyCLI', + 'Vulnerability', + 'CVE', + 'Severity', + 'Fix', + 'is_pinned_requirement', + 'SafetyRequirement', + 'SafetyEncoder', + 'RequirementFile' +] \ No newline at end of file diff --git a/safety/models/obj.py b/safety/models/obj.py new file mode 100644 index 00000000..8a2cfccc --- /dev/null +++ b/safety/models/obj.py @@ -0,0 +1,25 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from safety.auth.models import Auth + from rich.console import Console + + from safety_schemas.models import MetadataModel, ReportSchemaVersion, \ + TelemetryModel, PolicyFileModel, ConfigModel + +@dataclass +class SafetyCLI: + """ + A class representing Safety CLI settings. + """ + auth: Optional['Auth'] = None + telemetry: Optional['TelemetryModel'] = None + metadata: Optional['MetadataModel'] = None + schema: Optional['ReportSchemaVersion'] = None + project = None + config: Optional['ConfigModel'] = None + console: Optional['Console'] = None + system_scan_policy: Optional['PolicyFileModel'] = None + platform_enabled: bool = False + firewall_enabled: bool = False diff --git a/safety/models/requirements.py b/safety/models/requirements.py new file mode 100644 index 00000000..e846a3cc --- /dev/null +++ b/safety/models/requirements.py @@ -0,0 +1,19 @@ +from packaging.specifiers import SpecifierSet + + +def is_pinned_requirement(spec: SpecifierSet) -> bool: + """ + Check if a requirement is pinned. + + Args: + spec (SpecifierSet): The specifier set of the requirement. + + Returns: + bool: True if the requirement is pinned, False otherwise. 
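# Hedged sketch (not part of the diff): the SafetyCLI settings object above
# now carries the server-driven feature flags next to the existing fields,
# and safety.meta.get_version() replaces the old util helper.
from safety.models import SafetyCLI
from safety.meta import get_version

obj = SafetyCLI()            # every field defaults to None / False
obj.platform_enabled = True  # e.g. populated from the initialize response
obj.firewall_enabled = False
print(get_version())         # package version string, or None if metadata is missing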
+ """ + if not spec or len(spec) != 1: + return False + + specifier = next(iter(spec)) + + return (specifier.operator == '==' and '*' != specifier.version[-1]) or specifier.operator == '===' diff --git a/safety/models.py b/safety/models/vulnerabilities.py similarity index 91% rename from safety/models.py rename to safety/models/vulnerabilities.py index 20a81887..988de0ed 100644 --- a/safety/models.py +++ b/safety/models/vulnerabilities.py @@ -2,28 +2,25 @@ from collections import namedtuple from dataclasses import dataclass, field from datetime import datetime -from typing import Any, List, Optional, Set, Tuple, Union, Dict +from typing import Any, List, Optional, Set, Union, Dict from dparse.dependencies import Dependency from dparse import parse, filetypes -from typing import Any, List, Optional from packaging.specifiers import SpecifierSet from packaging.requirements import Requirement from packaging.utils import canonicalize_name from packaging.version import parse as parse_version from packaging.version import Version -from safety_schemas.models import ConfigModel from safety.errors import InvalidRequirementError -from safety_schemas.models import MetadataModel, ReportSchemaVersion, TelemetryModel, \ - PolicyFileModel try: from packaging.version import LegacyVersion as legacyType except ImportError: legacyType = None +from .requirements import is_pinned_requirement class DictConverter(object): """ @@ -120,24 +117,6 @@ def to_dict(self, **kwargs: Any) -> Dict: } -def is_pinned_requirement(spec: SpecifierSet) -> bool: - """ - Check if a requirement is pinned. - - Args: - spec (SpecifierSet): The specifier set of the requirement. - - Returns: - bool: True if the requirement is pinned, False otherwise. - """ - if not spec or len(spec) != 1: - return False - - specifier = next(iter(spec)) - - return (specifier.operator == '==' and '*' != specifier.version[-1]) or specifier.operator == '===' - - @dataclass class Package(DictConverter): """ @@ -454,21 +433,3 @@ class Safety: client: Any keys: Any - -from safety.auth.models import Auth -from rich.console import Console - -@dataclass -class SafetyCLI: - """ - A class representing Safety CLI settings. 
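# Hedged example of the pinning rules implemented above (runnable if the
# packaging and safety packages are installed):
from packaging.specifiers import SpecifierSet
from safety.models import is_pinned_requirement

assert is_pinned_requirement(SpecifierSet("==2.3.1"))       # exact pin
assert is_pinned_requirement(SpecifierSet("===2.3.1"))      # arbitrary-equality pin
assert not is_pinned_requirement(SpecifierSet("==2.*"))     # wildcard is not a pin
assert not is_pinned_requirement(SpecifierSet(">=2.0,<3"))  # ranges are not pins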
- """ - auth: Optional[Auth] = None - telemetry: Optional[TelemetryModel] = None - metadata: Optional[MetadataModel] = None - schema: Optional[ReportSchemaVersion] = None - project = None - config: Optional[ConfigModel] = None - console: Optional[Console] = None - system_scan_policy: Optional[PolicyFileModel] = None - platform_enabled: bool = False diff --git a/safety/output_utils.py b/safety/output_utils.py index 0ebb928d..da10c27d 100644 --- a/safety/output_utils.py +++ b/safety/output_utils.py @@ -1,21 +1,26 @@ -from dataclasses import asdict import json import logging import os import textwrap +from dataclasses import asdict from datetime import datetime -from typing import List, Tuple, Dict, Optional, Any, Union +from typing import Any, Dict, List, Optional, Tuple, Union import click +from jinja2 import Environment, PackageLoader -from packaging.specifiers import SpecifierSet from safety.constants import RED, YELLOW +from safety.meta import get_version from safety.models import Fix, is_pinned_requirement -from safety.util import get_safety_version, Package, get_terminal_size, \ - SafetyContext, build_telemetry_data, build_git_data, is_a_remote_mirror, get_remediations_count - -from jinja2 import Environment, PackageLoader - +from safety.util import ( + Package, + SafetyContext, + build_git_data, + build_telemetry_data, + get_remediations_count, + get_terminal_size, + is_a_remote_mirror, +) LOG = logging.getLogger(__name__) @@ -956,7 +961,7 @@ def get_report_brief_info(as_dict: bool = False, report_type: int = 1, **kwargs: brief_data['api_key'] = bool(key) brief_data['account'] = account brief_data['local_database_path'] = db if db else None - brief_data['safety_version'] = get_safety_version() + brief_data['safety_version'] = get_version() brief_data['timestamp'] = current_time brief_data['packages_found'] = len(packages) # Vuln report @@ -1007,7 +1012,7 @@ def get_report_brief_info(as_dict: bool = False, report_type: int = 1, **kwargs: timestamp = [{'style': False, 'value': 'Timestamp '}, {'style': True, 'value': current_time}] brief_info = [[{'style': False, 'value': 'Safety '}, - {'style': True, 'value': 'v' + get_safety_version()}, + {'style': True, 'value': 'v' + get_version()}, {'style': False, 'value': ' is scanning for '}, {'style': True, 'value': scanning_types.get(context.command, {}).get('name', '')}, {'style': True, 'value': '...'}] + safety_policy_used + audit_and_monitor, action_executed diff --git a/safety/pip/__init__.py b/safety/pip/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/safety/pip/command.py b/safety/pip/command.py new file mode 100644 index 00000000..813a967d --- /dev/null +++ b/safety/pip/command.py @@ -0,0 +1,49 @@ +from pathlib import Path + +try: + from typing import Annotated +except ImportError: + from typing_extensions import Annotated + +import typer +from typer import Option + +from .constants import PIP_COMMAND_NAME, PIP_COMMAND_HELP +from .decorators import optional_project_command +from ..cli_util import CommandType, FeatureType, SafetyCLICommand, \ + SafetyCLISubGroup +from ..tool.utils import PipCommand + +from ..constants import CONTEXT_COMMAND_TYPE, CONTEXT_FEATURE_TYPE + +pip_app = typer.Typer(rich_markup_mode="rich", cls=SafetyCLISubGroup) + + +@pip_app.command( + cls=SafetyCLICommand, + help=PIP_COMMAND_HELP, + name=PIP_COMMAND_NAME, + options_metavar="[OPTIONS]", + context_settings={"allow_extra_args": True, + "ignore_unknown_options": True, + CONTEXT_COMMAND_TYPE: CommandType.BETA, + CONTEXT_FEATURE_TYPE: 
FeatureType.FIREWALL}, +) +@optional_project_command +def init( + ctx: typer.Context, + target: Annotated[ + Path, + Option( + exists=True, + file_okay=False, + dir_okay=True, + writable=False, + readable=True, + resolve_path=True, + show_default=False, + ), + ] = Path("."), +): + command = PipCommand.from_args(ctx.args) + command.execute(ctx) diff --git a/safety/pip/constants.py b/safety/pip/constants.py new file mode 100644 index 00000000..ed285675 --- /dev/null +++ b/safety/pip/constants.py @@ -0,0 +1,3 @@ +# PIP options +PIP_COMMAND_NAME = "pip" +PIP_COMMAND_HELP = "[BETA] Commands for managing Safety project.\nExample: safety pip list" diff --git a/safety/pip/decorators.py b/safety/pip/decorators.py new file mode 100644 index 00000000..73359342 --- /dev/null +++ b/safety/pip/decorators.py @@ -0,0 +1,53 @@ +from functools import wraps +from pathlib import Path + +from safety_schemas.models import ProjectModel + +from ..cli_util import process_auth_status_not_ready +from safety.console import main_console +from ..init.main import load_unverified_project_from_config, verify_project +from ..scan.util import GIT + + +def optional_project_command(func): + @wraps(func) + def inner(ctx, target: Path, *args, **kwargs): + ctx.obj.console = main_console + ctx.params.pop("console", None) + + if not ctx.obj.auth.is_valid(): + process_auth_status_not_ready( + console=main_console, auth=ctx.obj.auth, ctx=ctx + ) + + upload_request_id = kwargs.pop("upload_request_id", None) + + # Load .safety-project.ini + unverified_project = load_unverified_project_from_config(project_root=target) + + if ctx.obj.platform_enabled and not unverified_project.created: + stage = ctx.obj.auth.stage + session = ctx.obj.auth.client + git_data = GIT(root=target).build_git_data() + origin = None + + if git_data: + origin = git_data.origin + + verify_project( + main_console, ctx, session, unverified_project, stage, origin + ) + + ctx.obj.project.git = git_data + else: + ctx.obj.project = ProjectModel( + id="", + name="Undefined project", + project_path=unverified_project.project_path, + ) + + ctx.obj.project.upload_request_id = upload_request_id + + return func(ctx, target=target, *args, **kwargs) + + return inner diff --git a/safety/scan/command.py b/safety/scan/command.py index 4bbba379..dd8fa095 100644 --- a/safety/scan/command.py +++ b/safety/scan/command.py @@ -5,7 +5,6 @@ import json import sys from typing import Any, Dict, List, Optional, Set, Tuple, Callable -from typing_extensions import Annotated from safety.constants import EXIT_CODE_VULNERABILITIES_FOUND from safety.safety import process_fixes_scan @@ -27,13 +26,20 @@ SYSTEM_SCAN_TARGET_HELP, SCAN_APPLY_FIXES, SCAN_DETAILED_OUTPUT, CLI_SCAN_COMMAND_HELP, CLI_SYSTEM_SCAN_COMMAND_HELP from safety.scan.decorators import inject_metadata, scan_project_command_init, scan_system_command_init from safety.scan.finder.file_finder import should_exclude -from safety.scan.main import load_policy_file, load_unverified_project_from_config, process_files, save_report_as +from safety.init.main import load_unverified_project_from_config +from safety.scan.main import load_policy_file, process_files, save_report_as from safety.scan.models import ScanExport, ScanOutput, SystemScanExport, SystemScanOutput from safety.scan.render import print_detected_ecosystems_section, print_fixes_section, print_summary, render_scan_html, render_scan_spdx, render_to_console -from safety.scan.util import Stage from safety_schemas.models import Ecosystem, FileModel, FileType, ProjectModel, \ - 
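# Hedged usage sketch for the `safety pip` passthrough defined above:
#
#   safety pip list
#   safety pip install requests
#
# The extra arguments are forwarded untouched (allow_extra_args /
# ignore_unknown_options), so roughly:
#
#   command = PipCommand.from_args(ctx.args)   # e.g. ["install", "requests"]
#   command.execute(ctx)
#
# with project verification handled first by @optional_project_command.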
ReportModel, ScanType, VulnerabilitySeverityLabels, SecurityUpdates, Vulnerability + ReportModel, ScanType, VulnerabilitySeverityLabels, SecurityUpdates, Vulnerability, \ + Stage from safety.scan.fun_mode.easter_eggs import run_easter_egg + +try: + from typing import Annotated +except ImportError: + from typing_extensions import Annotated + LOG = logging.getLogger(__name__) diff --git a/safety/scan/constants.py b/safety/scan/constants.py index 7a9cf13b..6f0b5e26 100644 --- a/safety/scan/constants.py +++ b/safety/scan/constants.py @@ -1,11 +1,11 @@ -from safety.util import get_safety_version +from safety.meta import get_version # Console Help Theme CONSOLE_HELP_THEME = { "nhc": "grey82" } -CLI_VERSION = get_safety_version() +CLI_VERSION = get_version() CLI_WEBSITE_URL="https://safetycli.com" CLI_DOCUMENTATION_URL="https://docs.safetycli.com" CLI_SUPPORT_EMAIL="support@safetycli.com" @@ -53,6 +53,8 @@ CLI_VALIDATE_HELP = "Check if your local Safety CLI policy file is valid."\ "\nExample: Example: safety validate --path /path/to/policy.yml" +CLI_GATEWAY_CONFIGURE_COMMAND_HELP = "Configures the project in the working directory to use Gateway." + # Global options help _CLI_PROXY_TIP_HELP = f"[nhc]Note: proxy details can be set globally in a config file.[/nhc]\n\nSee [bold]safety configure --help[/bold]\n\n" @@ -144,6 +146,8 @@ # Generate options CLI_GENERATE_PATH = "The path where the generated file will be saved (default: current directory).\n\n" \ "[bold]Example: safety generate policy_file --path .my-project-safety-policy.yml[/bold]" +CLI_GENERATE_MINIMUM_CVSS_SEVERITY = "The minimum CVSS severity to generate the installation policy for.\n\n" \ +"[bold]Example: safety generate installation_policy --minimum-cvss-severity high[/bold]" # Command default settings CMD_PROJECT_NAME = "scan" diff --git a/safety/scan/decorators.py b/safety/scan/decorators.py index 72d7e4e1..26c03a7d 100644 --- a/safety/scan/decorators.py +++ b/safety/scan/decorators.py @@ -2,8 +2,6 @@ import logging import os from pathlib import Path -from random import randint -import sys from typing import Any, List, Optional from rich.padding import Padding @@ -12,16 +10,17 @@ from safety.auth.cli import render_email_note from safety.cli_util import process_auth_status_not_ready from safety.console import main_console -from safety.constants import SAFETY_POLICY_FILE_NAME, SYSTEM_CONFIG_DIR, SYSTEM_POLICY_FILE, USER_POLICY_FILE -from safety.errors import SafetyError, SafetyException, ServerError +from safety.constants import SYSTEM_POLICY_FILE, USER_POLICY_FILE +from safety.errors import SafetyError, SafetyException from safety.scan.constants import DEFAULT_SPINNER -from safety.scan.main import PROJECT_CONFIG, download_policy, load_policy_file, \ - load_unverified_project_from_config, resolve_policy +from safety.scan.main import download_policy, load_policy_file, resolve_policy from safety.scan.models import ScanOutput, SystemScanOutput -from safety.scan.render import print_announcements, print_header, print_project_info, print_wait_policy_download +from safety.scan.render import print_announcements, print_header, print_wait_policy_download from safety.scan.util import GIT +from ..init.main import load_unverified_project_from_config, verify_project + +from safety.scan.validators import fail_if_not_allowed_stage -from safety.scan.validators import verify_project from safety.util import build_telemetry_data, pluralize from safety_schemas.models import MetadataModel, ScanType, ReportSchemaVersion, \ PolicySource @@ -29,26 +28,6 @@ 
LOG = logging.getLogger(__name__) -def initialize_scan(ctx: Any, console: Console) -> None: - """ - Initializes the scan by setting platform_enabled based on the response from the server. - """ - data = None - - try: - data = ctx.obj.auth.client.initialize_scan() - except SafetyException as e: - LOG.error("Unable to initialize scan", exc_info=True) - except SafetyError as e: - if e.error_code: - raise e - except Exception as e: - LOG.exception("Exception trying to initialize scan", exc_info=True) - - if data: - ctx.obj.platform_enabled = data.get("platform-enabled", False) - - def scan_project_command_init(func): """ Decorator to make general verifications before each project scan command. @@ -70,10 +49,6 @@ def inner(ctx, policy_file_path: Optional[Path], target: Path, upload_request_id = kwargs.pop("upload_request_id", None) - # Run the initialize if it was not fired by a system-scan - if not upload_request_id: - initialize_scan(ctx, console) - # Load .safety-project.ini unverified_project = load_unverified_project_from_config(project_root=target) @@ -98,8 +73,8 @@ def inner(ctx, policy_file_path: Optional[Path], target: Path, project_path=unverified_project.project_path ) - ctx.obj.project.upload_request_id = upload_request_id ctx.obj.project.git = git_data + ctx.obj.project.upload_request_id = upload_request_id if not policy_file_path: policy_file_path = target / Path(".safety-policy.yml") @@ -202,8 +177,6 @@ def inner(ctx, policy_file_path: Optional[Path], targets: List[Path], process_auth_status_not_ready(console=console, auth=ctx.obj.auth, ctx=ctx) - initialize_scan(ctx, console) - console.print() print_header(console=console, targets=targets, is_system_scan=True) diff --git a/safety/scan/main.py b/safety/scan/main.py index 8df72f33..8a0c6c32 100644 --- a/safety/scan/main.py +++ b/safety/scan/main.py @@ -1,36 +1,30 @@ -import configparser import logging -from pathlib import Path -import re -import requests import os -from urllib.parse import urljoin import platform import time -from typing import Any, Dict, Generator, Optional, Set, Tuple, Union +from pathlib import Path +from typing import Any, Dict, Generator, Optional, Set, Tuple + from pydantic import ValidationError -import typer +from safety_schemas.models import ( + ConfigModel, + FileType, + PolicyFileModel, + PolicySource, + ScanType, + Stage, +) + +from safety.scan.util import GIT + from ..auth.utils import SafetyAuthSession from ..errors import SafetyError from .ecosystems.base import InspectableFile from .ecosystems.target import InspectableFileContext -from .models import ScanExport, UnverifiedProjectModel -from safety.scan.util import GIT - -from safety_schemas.models import FileType, PolicyFileModel, PolicySource, \ - ConfigModel, Stage, ProjectModel, ScanType -from safety.util import get_safety_version - -from safety.constants import PLATFORM_API_BASE_URL +from .models import ScanExport LOG = logging.getLogger(__name__) -PROJECT_CONFIG = ".safety-project.ini" -PROJECT_CONFIG_SECTION = "project" -PROJECT_CONFIG_ID = "id" -PROJECT_CONFIG_URL = "url" -PROJECT_CONFIG_NAME = "name" - def download_policy(session: SafetyAuthSession, project_id: str, stage: Stage, branch: Optional[str]) -> Optional[PolicyFileModel]: """ @@ -82,56 +76,6 @@ def download_policy(session: SafetyAuthSession, project_id: str, stage: Stage, b return None -def load_unverified_project_from_config(project_root: Path) -> UnverifiedProjectModel: - """ - Loads an unverified project from the configuration file located at the project root. 
- - Args: - project_root (Path): The root directory of the project. - - Returns: - UnverifiedProjectModel: An instance of UnverifiedProjectModel. - """ - config = configparser.ConfigParser() - project_path = project_root / PROJECT_CONFIG - config.read(project_path) - id = config.get(PROJECT_CONFIG_SECTION, PROJECT_CONFIG_ID, fallback=None) - id = config.get(PROJECT_CONFIG_SECTION, PROJECT_CONFIG_ID, fallback=None) - url = config.get(PROJECT_CONFIG_SECTION, PROJECT_CONFIG_URL, fallback=None) - name = config.get(PROJECT_CONFIG_SECTION, PROJECT_CONFIG_NAME, fallback=None) - created = True - if id: - created = False - - return UnverifiedProjectModel(id=id, url_path=url, - name=name, project_path=project_path, - created=created) - - -def save_project_info(project: ProjectModel, project_path: Path) -> None: - """ - Saves the project information to the configuration file. - - Args: - project (ProjectModel): The ProjectModel object containing project information. - project_path (Path): The path to the configuration file. - """ - config = configparser.ConfigParser() - config.read(project_path) - - if PROJECT_CONFIG_SECTION not in config.sections(): - config[PROJECT_CONFIG_SECTION] = {} - - config[PROJECT_CONFIG_SECTION][PROJECT_CONFIG_ID] = project.id - if project.url_path: - config[PROJECT_CONFIG_SECTION][PROJECT_CONFIG_URL] = project.url_path - if project.name: - config[PROJECT_CONFIG_SECTION][PROJECT_CONFIG_NAME] = project.name - - with open(project_path, 'w') as configfile: - config.write(configfile) - - def load_policy_file(path: Path) -> Optional[PolicyFileModel]: """ Loads a policy file from the specified path. @@ -236,7 +180,7 @@ def build_meta(target: Path) -> Dict[str, Any]: } client_metadata = { - "version": get_safety_version(), + "version": get_version(), } return { diff --git a/safety/scan/models.py b/safety/scan/models.py index 86ffc21a..82174605 100644 --- a/safety/scan/models.py +++ b/safety/scan/models.py @@ -1,9 +1,6 @@ from enum import Enum -from pathlib import Path from typing import Optional -from pydantic.dataclasses import dataclass - class FormatMixin: """ Mixin class providing format-related utilities for Enum classes. @@ -120,14 +117,3 @@ class SystemScanExport(str, Enum): Enum representing different system scan export formats. """ JSON = "json" - -@dataclass -class UnverifiedProjectModel(): - """ - Data class representing an unverified project model. 
- """ - id: Optional[str] - project_path: Path - created: bool - name: Optional[str] = None - url_path: Optional[str] = None diff --git a/safety/scan/render.py b/safety/scan/render.py index a44a51ef..44a63b1d 100644 --- a/safety/scan/render.py +++ b/safety/scan/render.py @@ -1,31 +1,39 @@ -from collections import defaultdict -from datetime import datetime +import datetime import itertools import json import logging -from pathlib import Path import time +from collections import defaultdict +from pathlib import Path from typing import Any, Dict, List, Optional, Set, Tuple -from rich.prompt import Prompt -from rich.text import Text + +import typer from rich.console import Console from rich.padding import Padding -from safety_schemas.models import Vulnerability, ReportModel -import typer +from rich.prompt import Prompt +from rich.text import Text +from safety_schemas.models import ( + Ecosystem, + FileType, + IgnoreCodes, + PolicyFileModel, + PolicySource, + ProjectModel, + PythonDependency, + ReportModel, + Vulnerability, +) + from safety import safety from safety.auth.constants import SAFETY_PLATFORM_URL from safety.errors import SafetyException +from safety.meta import get_version from safety.output_utils import parse_html from safety.scan.constants import DEFAULT_SPINNER - -from safety_schemas.models import Ecosystem, FileType, PolicyFileModel, \ - PolicySource, ProjectModel, IgnoreCodes, Stage, PythonDependency - -from safety.util import get_basic_announcements, get_safety_version +from safety.util import clean_project_id, get_basic_announcements LOG = logging.getLogger(__name__) -import datetime def render_header(targets: List[Path], is_system_scan: bool) -> Text: """ @@ -38,7 +46,7 @@ def render_header(targets: List[Path], is_system_scan: bool) -> Text: Returns: Text: Rendered header text. """ - version = get_safety_version() + version = get_version() scan_datetime = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z") action = f"scanning {', '.join([str(t) for t in targets])}" @@ -356,51 +364,36 @@ def print_wait_policy_download(console: Console, closure: Tuple[Any, Dict[str, A return policy -def prompt_project_id(console: Console, stage: Stage, prj_root_name: Optional[str], do_not_exit: bool = True) -> Optional[str]: +def prompt_project_id( + console: Console, + default_id: str) -> str: """ - Prompt the user to set a project ID for the scan. - - Args: - console (Console): The console for output. - stage (Stage): The current stage. - prj_root_name (Optional[str]): The root name of the project. - do_not_exit (bool): Indicates if the function should not exit on failure. - - Returns: - Optional[str]: The project ID. + Prompt the user to set a project ID, on a non-interactive mode it will + fallback to the default ID parameter. 
""" - from safety.util import clean_project_id - default_prj_id = clean_project_id(prj_root_name) if prj_root_name else None - - non_interactive_mode = console.quiet or not console.is_interactive - if stage is not Stage.development and non_interactive_mode: - # Fail here - console.print("The scan needs to be linked to a project.") - raise typer.Exit(code=1) - - hint = "" - if default_prj_id: - hint = f" If empty Safety will use [bold]{default_prj_id}[/bold]" - prompt_text = f"Set a project id for this scan (no spaces).{hint}" + default_prj_id = clean_project_id(default_id) - def ask(): - prj_id = None + interactive_mode = console.is_interactive and not console.quiet - result = Prompt.ask(prompt_text, default=None, console=console) - - if result: - prj_id = clean_project_id(result) - elif default_prj_id: - prj_id = default_prj_id - - return prj_id + if not interactive_mode: + LOG.info("Fallback to default project id, because of " \ + "non-interactive mode.") + + return default_prj_id - project_id = ask() + hint = f" If empty Safety will use [bold]{default_prj_id}[/bold]" + prompt_text = f"Set a project id (no spaces).{hint}" - while not project_id and do_not_exit: - project_id = ask() + while True: + result = Prompt.ask( + prompt_text, + console=console, + default=default_prj_id, + show_default=False + ) - return project_id + return clean_project_id(result) if result != default_prj_id \ + else default_prj_id def prompt_link_project(console: Console, prj_name: str, prj_admin_email: str) -> bool: @@ -421,7 +414,7 @@ def prompt_link_project(console: Console, prj_name: str, prj_admin_email: str) - f"[bold]Project admin:[/bold] {prj_admin_email}"): console.print(Padding(detail, (0, 0, 0, 2)), emoji=True) - prompt_question = "Do you want to link this scan with this existing project?" + prompt_question = "Do you want to link it with this existing project?" answer = Prompt.ask(prompt=prompt_question, choices=["y", "n"], default="y", show_default=True, console=console).lower() @@ -577,9 +570,8 @@ def generate_spdx_creation_info(spdx_version: str, project_identifier: str) -> A DOC_COMMENT = f"This document was created using SPDX {spdx_version}" CREATOR_COMMENT = "Safety CLI automatically created this SPDX document from a scan report." - from ..util import get_safety_version TOOL_ID = "safety" - TOOL_VERSION = get_safety_version() + TOOL_VERSION = get_version() doc_creator = Actor( actor_type=ActorType.TOOL, @@ -635,11 +627,10 @@ def create_packages(dependencies: List[PythonDependency]) -> List[Any]: Returns: List[Any]: List of SPDX packages. """ - from spdx_tools.spdx.model.spdx_no_assertion import SpdxNoAssertion - from spdx_tools.spdx.model import ( Package, ) + from spdx_tools.spdx.model.spdx_no_assertion import SpdxNoAssertion doc_pkgs = [] pkgs_added = set([]) @@ -736,10 +727,7 @@ def render_scan_spdx(report: ReportModel, obj: Any, spdx_version: Optional[str]) Returns: Optional[Any]: The rendered SPDX document in JSON format. 
""" - from spdx_tools.spdx.writer.write_utils import ( - convert, - validate_and_deduplicate - ) + from spdx_tools.spdx.writer.write_utils import convert, validate_and_deduplicate # Set to latest supported if a version is not specified if not spdx_version: diff --git a/safety/scan/validators.py b/safety/scan/validators.py index a118b47e..923c8d8b 100644 --- a/safety/scan/validators.py +++ b/safety/scan/validators.py @@ -3,12 +3,9 @@ from pathlib import Path from typing import Optional, Tuple import typer -from safety.scan.main import save_project_info -from safety.scan.models import ScanExport, ScanOutput, UnverifiedProjectModel -from safety.scan.render import print_wait_project_verification, prompt_project_id, prompt_link_project +from safety.scan.models import ScanExport, ScanOutput -from safety_schemas.models import AuthenticationType, ProjectModel, Stage -from safety.auth.utils import SafetyAuthSession +from safety_schemas.models import AuthenticationType MISSING_SPDX_EXTENSION_MSG = "spdx extra is not installed, please install it with: pip install safety[spdx]" @@ -57,122 +54,22 @@ def output_callback(output: ScanOutput) -> str: return output.value -def save_verified_project(ctx: typer.Context, slug: str, name: Optional[str], project_path: Path, url_path: Optional[str]): +def fail_if_not_allowed_stage(ctx: typer.Context): """ - Save the verified project information to the context and project info file. + Fail the command if the authentication type is not allowed in the current stage. Args: ctx (typer.Context): The context of the Typer command. - slug (str): The project slug. - name (Optional[str]): The project name. - project_path (Path): The project path. - url_path (Optional[str]): The project URL path. """ - ctx.obj.project = ProjectModel( - id=slug, - name=name, - project_path=project_path, - url_path=url_path - ) - if ctx.obj.auth.stage is Stage.development: - save_project_info(project=ctx.obj.project, - project_path=project_path) - - -def check_project(console, ctx: typer.Context, session: SafetyAuthSession, - unverified_project: UnverifiedProjectModel, stage: Stage, - git_origin: Optional[str], ask_project_id: bool = False) -> dict: - """ - Check the project against the session and stage, verifying the project if necessary. - - Args: - console: The console for output. - ctx (typer.Context): The context of the Typer command. - session (SafetyAuthSession): The authentication session. - unverified_project (UnverifiedProjectModel): The unverified project model. - stage (Stage): The current stage. - git_origin (Optional[str]): The Git origin URL. - ask_project_id (bool): Whether to prompt for the project ID. + if ctx.resilient_parsing: + return - Returns: - dict: The result of the project check. - """ stage = ctx.obj.auth.stage - source = ctx.obj.telemetry.safety_source if ctx.obj.telemetry else None - data = {"scan_stage": stage, "safety_source": source} - - PRJ_SLUG_KEY = "project_slug" - PRJ_SLUG_SOURCE_KEY = "project_slug_source" - PRJ_GIT_ORIGIN_KEY = "git_origin" - - if git_origin: - data[PRJ_GIT_ORIGIN_KEY] = git_origin - - if unverified_project.id: - data[PRJ_SLUG_KEY] = unverified_project.id - data[PRJ_SLUG_SOURCE_KEY] = ".safety-project.ini" - elif not git_origin or ask_project_id: - # Set a project id for this scan (no spaces). 
If empty Safety will use: pyupio: - parent_root_name = None - if unverified_project.project_path.parent.name: - parent_root_name = unverified_project.project_path.parent.name - - unverified_project.id = prompt_project_id(console, stage, parent_root_name) - data[PRJ_SLUG_KEY] = unverified_project.id - data[PRJ_SLUG_SOURCE_KEY] = "user" - - status = print_wait_project_verification(console, data[PRJ_SLUG_KEY] if data.get(PRJ_SLUG_KEY, None) else "-", - (session.check_project, data), on_error_delay=1) - - return status - - -def verify_project(console, ctx: typer.Context, session: SafetyAuthSession, - unverified_project: UnverifiedProjectModel, stage: Stage, - git_origin: Optional[str]): - """ - Verify the project, linking it if necessary and saving the verified project information. - - Args: - console: The console for output. - ctx (typer.Context): The context of the Typer command. - session (SafetyAuthSession): The authentication session. - unverified_project (UnverifiedProjectModel): The unverified project model. - stage (Stage): The current stage. - git_origin (Optional[str]): The Git origin URL. - """ - - verified_prj = False - - link_prj = True - - while not verified_prj: - result = check_project(console, ctx, session, unverified_project, stage, git_origin, ask_project_id=not link_prj) - - unverified_slug = result.get("slug") - - project = result.get("project", None) - user_confirm = result.get("user_confirm", False) - - if user_confirm: - if project and link_prj: - prj_name = project.get("name", None) - prj_admin_email = project.get("admin", None) - - link_prj = prompt_link_project(prj_name=prj_name, - prj_admin_email=prj_admin_email, - console=console) - - if not link_prj: - continue + auth_type: AuthenticationType = ctx.obj.auth.client.get_authentication_type() - verified_prj = print_wait_project_verification( - console, unverified_slug, (session.project, - {"project_id": unverified_slug}), - on_error_delay=1) + if os.getenv("SAFETY_DB_DIR"): + return - if verified_prj and isinstance(verified_prj, dict) and verified_prj.get("slug", None): - save_verified_project(ctx, verified_prj["slug"], verified_prj.get("name", None), - unverified_project.project_path, verified_prj.get("url", None)) - else: - verified_prj = False + if not auth_type.is_allowed_in(stage): + raise typer.BadParameter(f"'{auth_type.value}' auth type isn't allowed with " \ + f"the '{stage}' stage.") diff --git a/safety/tool/__init__.py b/safety/tool/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/safety/tool/constants.py b/safety/tool/constants.py new file mode 100644 index 00000000..dd9b9956 --- /dev/null +++ b/safety/tool/constants.py @@ -0,0 +1,1006 @@ +REPOSITORY_URL = "https://pkgs.safetycli.com/repository/public/pypi/simple/" +PROJECT_CONFIG = ".safety-project.ini" + +MOST_FREQUENTLY_DOWNLOADED_PYPI_PACKAGES = [ + "boto3", + "urllib3", + "botocore", + "requests", + "setuptools", + "certifi", + "idna", + "charset-normalizer", + "aiobotocore", + "typing-extensions", + "python-dateutil", + "s3transfer", + "packaging", + "grpcio-status", + "s3fs", + "six", + "fsspec", + "pyyaml", + "numpy", + "importlib-metadata", + "cryptography", + "zipp", + "cffi", + "pip", + "pandas", + "google-api-core", + "pycparser", + "pydantic", + "protobuf", + "wheel", + "jmespath", + "attrs", + "rsa", + "pyasn1", + "click", + "platformdirs", + "pytz", + "colorama", + "jinja2", + "awscli", + "markupsafe", + "tomli", + "pyjwt", + "googleapis-common-protos", + "filelock", + "virtualenv", + "cachetools", + "wrapt", + 
"google-auth", + "pluggy", + "pytest", + "pydantic-core", + "pyparsing", + "docutils", + "pyarrow", + "pyasn1-modules", + "requests-oauthlib", + "aiohttp", + "scipy", + "jsonschema", + "oauthlib", + "sqlalchemy", + "iniconfig", + "exceptiongroup", + "yarl", + "decorator", + "multidict", + "psutil", + "soupsieve", + "greenlet", + "tzdata", + "pillow", + "isodate", + "pygments", + "beautifulsoup4", + "annotated-types", + "requests-toolbelt", + "frozenlist", + "tomlkit", + "pyopenssl", + "aiosignal", + "distlib", + "async-timeout", + "more-itertools", + "openpyxl", + "tqdm", + "et-xmlfile", + "grpcio", + "deprecated", + "cloudpickle", + "lxml", + "pynacl", + "werkzeug", + "proto-plus", + "azure-core", + "google-cloud-storage", + "asn1crypto", + "coverage", + "websocket-client", + "msgpack", + "h11", + "rich", + "dill", + "pexpect", + "sniffio", + "anyio", + "mypy-extensions", + "ptyprocess", + "importlib-resources", + "sortedcontainers", + "matplotlib", + "chardet", + "rpds-py", + "grpcio-tools", + "aiohappyeyeballs", + "flask", + "httpx", + "referencing", + "scikit-learn", + "jsonschema-specifications", + "httpcore", + "pyzmq", + "poetry-core", + "keyring", + "google-cloud-core", + "python-dotenv", + "pathspec", + "markdown-it-py", + "pkginfo", + "msal", + "networkx", + "bcrypt", + "mdurl", + "gitpython", + "psycopg2-binary", + "poetry-plugin-export", + "google-resumable-media", + "paramiko", + "kiwisolver", + "smmap", + "gitdb", + "xmltodict", + "snowflake-connector-python", + "tabulate", + "cycler", + "typedload", + "jaraco-classes", + "jeepney", + "secretstorage", + "ruamel-yaml", + "tenacity", + "wcwidth", + "build", + "backoff", + "shellingham", + "threadpoolctl", + "regex", + "itsdangerous", + "portalocker", + "py", + "google-crc32c", + "rapidfuzz", + "pyproject-hooks", + "py4j", + "google-cloud-bigquery", + "fastjsonschema", + "sqlparse", + "mccabe", + "pytest-cov", + "awswrangler", + "trove-classifiers", + "msal-extensions", + "azure-storage-blob", + "google-api-python-client", + "pycodestyle", + "joblib", + "google-auth-oauthlib", + "ruamel-yaml-clib", + "tzlocal", + "docker", + "alembic", + "fonttools", + "prompt-toolkit", + "cachecontrol", + "azure-identity", + "distro", + "marshmallow", + "uritemplate", + "isort", + "cython", + "ply", + "httplib2", + "redis", + "pymysql", + "pyrsistent", + "gym-notices", + "google-auth-httplib2", + "poetry", + "blinker", + "defusedxml", + "dnspython", + "dulwich", + "toml", + "gunicorn", + "crashtest", + "markdown", + "nest-asyncio", + "babel", + "cleo", + "sentry-sdk", + "opentelemetry-api", + "scramp", + "multiprocess", + "installer", + "termcolor", + "black", + "huggingface-hub", + "mock", + "msrest", + "pendulum", + "requests-aws4auth", + "ipython", + "pyflakes", + "pycryptodomex", + "grpc-google-iam-v1", + "types-requests", + "azure-common", + "traitlets", + "fastapi", + "setuptools-scm", + "tornado", + "flake8", + "contourpy", + "prometheus-client", + "future", + "openai", + "mako", + "pycryptodome", + "imageio", + "jedi", + "webencodings", + "pygithub", + "parso", + "transformers", + "typing-inspect", + "kubernetes", + "jsonpointer", + "matplotlib-inline", + "starlette", + "loguru", + "opentelemetry-sdk", + "retry", + "argcomplete", + "pkgutil-resolve-name", + "redshift-connector", + "elasticsearch", + "pymongo", + "opentelemetry-semantic-conventions", + "pytzdata", + "pytest-runner", + "asgiref", + "pg8000", + "bs4", + "datadog", + "debugpy", + "python-json-logger", + "jsonpath-ng", + "uvicorn", + "executing", + "smart-open", + 
"zope-interface", + "asttokens", + "typer", + "aioitertools", + "apache-airflow", + "sagemaker", + "arrow", + "google-pasta", + "pyspark", + "humanfriendly", + "websockets", + "stack-data", + "shapely", + "pure-eval", + "torch", + "oscrypto", + "tokenizers", + "pysocks", + "sphinx", + "typeguard", + "tox", + "scikit-image", + "requests-file", + "google-cloud-pubsub", + "pytest-mock", + "google-cloud-secret-manager", + "snowflake-sqlalchemy", + "mysql-connector-python", + "pylint", + "jupyter-core", + "jupyter-client", + "astroid", + "jsonpatch", + "setproctitle", + "adal", + "types-python-dateutil", + "ipykernel", + "xgboost", + "orjson", + "schema", + "tb-nightly", + "nbconvert", + "xlrd", + "toolz", + "appdirs", + "aiofiles", + "sympy", + "opensearch-py", + "nodeenv", + "pywavelets", + "jaraco-functools", + "jupyter-server", + "nbformat", + "jupyterlab", + "progressbar2", + "comm", + "identify", + "bleach", + "mypy", + "pathos", + "pyodbc", + "pre-commit", + "xlsxwriter", + "rfc3339-validator", + "aws-requests-auth", + "gym", + "pox", + "ppft", + "mistune", + "aenum", + "jaraco-context", + "tinycss2", + "pbr", + "google-cloud-appengine-logging", + "notebook", + "db-dtypes", + "mpmath", + "sentencepiece", + "responses", + "cfgv", + "cattrs", + "python-utils", + "slack-sdk", + "jupyterlab-server", + "nbclient", + "lz4", + "ipywidgets", + "sshtunnel", + "absl-py", + "widgetsnbextension", + "watchdog", + "asynctest", + "semver", + "rfc3986", + "google-cloud-aiplatform", + "jupyterlab-widgets", + "altair", + "pandas-gbq", + "click-man", + "tensorboard", + "smdebug-rulesconfig", + "simplejson", + "text-unidecode", + "argon2-cffi", + "apache-airflow-providers-common-sql", + "snowballstemmer", + "azure-mgmt-core", + "docker-pycreds", + "nltk", + "python-slugify", + "croniter", + "structlog", + "selenium", + "antlr4-python3-runtime", + "google-cloud-logging", + "argon2-cffi-bindings", + "azure-storage-file-datalake", + "django", + "pydeequ", + "pytest-xdist", + "h5py", + "google-cloud-resource-manager", + "dataclasses", + "execnet", + "send2trash", + "opentelemetry-proto", + "google-cloud-bigquery-storage", + "oauth2client", + "dataclasses-json", + "json5", + "tiktoken", + "wandb", + "databricks-sql-connector", + "langchain-core", + "overrides", + "prettytable", + "pandocfilters", + "semantic-version", + "jupyterlab-pygments", + "msrestazure", + "safetensors", + "hvac", + "colorlog", + "imbalanced-learn", + "monotonic", + "seaborn", + "alabaster", + "terminado", + "webcolors", + "ordered-set", + "graphql-core", + "notebook-shim", + "lazy-object-proxy", + "funcsigs", + "numba", + "llvmlite", + "gremlinpython", + "xxhash", + "great-expectations", + "flatbuffers", + "pydata-google-auth", + "fqdn", + "uri-template", + "imagesize", + "opentelemetry-exporter-otlp-proto-common", + "isoduration", + "backports-tarfile", + "wsproto", + "tensorflow", + "thrift", + "hypothesis", + "rfc3986-validator", + "trio", + "inflection", + "html5lib", + "plotly", + "entrypoints", + "sphinxcontrib-serializinghtml", + "jupyter-events", + "lockfile", + "coloredlogs", + "sphinxcontrib-htmlhelp", + "cached-property", + "sphinxcontrib-qthelp", + "sphinxcontrib-devhelp", + "sphinxcontrib-applehelp", + "gast", + "azure-cli", + "azure-datalake-store", + "opentelemetry-exporter-otlp-proto-http", + "pyproject-api", + "azure-mgmt-resource", + "async-lru", + "faker", + "sphinxcontrib-jsmath", + "nose", + "opencv-python", + "outcome", + "statsmodels", + "readme-renderer", + "jupyter-server-terminals", + "libcst", + "retrying", + 
"datasets", + "aniso8601", + "pybind11", + "databricks-sdk", + "pyroaring", + "azure-keyvault-secrets", + "email-validator", + "argparse", + "parameterized", + "docopt", + "google-cloud-audit-log", + "confluent-kafka", + "kafka-python", + "pymssql", + "zeep", + "gcsfs", + "click-plugins", + "jupyter-lsp", + "ruff", + "deepdiff", + "docstring-parser", + "tblib", + "time-machine", + "jiter", + "patsy", + "azure-storage-common", + "deprecation", + "azure-nspkg", + "databricks-cli", + "nh3", + "twine", + "invoke", + "delta-spark", + "watchtower", + "mlflow", + "pydantic-settings", + "azure-mgmt-storage", + "opentelemetry-exporter-otlp-proto-grpc", + "applicationinsights", + "dbt-core", + "freezegun", + "pickleshare", + "apache-airflow-providers-ssh", + "python-multipart", + "langchain", + "uv", + "unidecode", + "azure-keyvault-keys", + "azure-cosmos", + "pytest-metadata", + "pipenv", + "tensorboard-data-server", + "azure-graphrbac", + "google-cloud-kms", + "backcall", + "trio-websocket", + "azure-keyvault", + "pytest-asyncio", + "psycopg2", + "google-cloud-dataproc", + "keras", + "datetime", + "zope-event", + "apache-airflow-providers-google", + "backports-zoneinfo", + "google-cloud-monitoring", + "looker-sdk", + "azure-mgmt-containerregistry", + "makefun", + "google-cloud-vision", + "mlflow-skinny", + "hatchling", + "spacy", + "torchvision", + "apache-airflow-providers-snowflake", + "google-cloud-spanner", + "google-cloud-container", + "nvidia-nccl-cu12", + "triton", + "gevent", + "google-cloud-dlp", + "uvloop", + "simple-salesforce", + "tldextract", + "analytics-python", + "apache-airflow-providers-databricks", + "tensorflow-estimator", + "google-cloud-bigquery-datatransfer", + "azure-mgmt-keyvault", + "azure-mgmt-cosmosdb", + "azure-mgmt-compute", + "graphviz", + "google-cloud-tasks", + "ujson", + "opentelemetry-instrumentation", + "azure-mgmt-authorization", + "fastavro", + "httptools", + "pathlib2", + "azure-mgmt-network", + "google-cloud-datacatalog", + "pkce", + "google-ads", + "opt-einsum", + "sh", + "jsondiff", + "azure-mgmt-msi", + "google-cloud-firestore", + "evergreen-py", + "google-cloud-bigtable", + "astunparse", + "watchfiles", + "configparser", + "flask-appbuilder", + "fabric", + "azure-mgmt-recoveryservices", + "apache-airflow-providers-mysql", + "scp", + "db-contrib-tool", + "google-cloud-build", + "omegaconf", + "azure-mgmt-monitor", + "ecdsa", + "gspread", + "azure-mgmt-signalr", + "azure-mgmt-containerinstance", + "blis", + "thinc", + "bitarray", + "murmurhash", + "pycrypto", + "dask", + "requests-mock", + "catalogue", + "cymem", + "azure-mgmt-sql", + "preshed", + "google-cloud-workflows", + "opentelemetry-exporter-otlp", + "azure-mgmt-web", + "google-cloud-redis", + "azure-batch", + "kombu", + "pywin32", + "azure-data-tables", + "wasabi", + "azure-mgmt-containerservice", + "azure-mgmt-servicebus", + "azure-mgmt-redis", + "google-cloud-dataplex", + "srsly", + "pytimeparse", + "google-cloud-language", + "authlib", + "google-cloud-automl", + "google-cloud-videointelligence", + "google-cloud-os-login", + "azure-mgmt-rdbms", + "brotli", + "pyserial", + "azure-mgmt-dns", + "langchain-community", + "nvidia-cudnn-cu12", + "texttable", + "azure-mgmt-advisor", + "google-cloud-memcache", + "azure-mgmt-eventhub", + "tensorflow-serving-api", + "gsutil", + "lark", + "azure-cli-core", + "flask-cors", + "pysftp", + "celery", + "langcodes", + "azure-mgmt-batch", + "azure-mgmt-loganalytics", + "azure-mgmt-cdn", + "ninja", + "azure-mgmt-recoveryservicesbackup", + "azure-mgmt-iothub", + 
"azure-mgmt-search", + "azure-mgmt-marketplaceordering", + "azure-mgmt-trafficmanager", + "azure-mgmt-managementgroups", + "pip-tools", + "azure-mgmt-cognitiveservices", + "azure-mgmt-devtestlabs", + "azure-mgmt-eventgrid", + "python-gnupg", + "jira", + "pypdf2", + "azure-mgmt-applicationinsights", + "azure-mgmt-servicefabric", + "billiard", + "azure-mgmt-media", + "azure-mgmt-billing", + "ratelimit", + "azure-mgmt-iothubprovisioningservices", + "azure-mgmt-policyinsights", + "azure-mgmt-nspkg", + "google-cloud-orchestration-airflow", + "apache-airflow-providers-cncf-kubernetes", + "azure-mgmt-batchai", + "azure-mgmt-iotcentral", + "azure-mgmt-datamigration", + "graphql-relay", + "azure-mgmt-maps", + "graphene", + "azure-appconfiguration", + "amqp", + "google-cloud-dataproc-metastore", + "mdit-py-plugins", + "google-cloud-translate", + "ijson", + "sqlalchemy-bigquery", + "vine", + "nvidia-cublas-cu12", + "nvidia-nvjitlink-cu12", + "spacy-loggers", + "spacy-legacy", + "levenshtein", + "agate", + "azure-mgmt-datalake-nspkg", + "knack", + "yapf", + "awscrt", + "azure-mgmt-datalake-store", + "google-cloud-dataform", + "types-pyyaml", + "confection", + "propcache", + "google-cloud-speech", + "nvidia-cuda-runtime-cu12", + "opencensus", + "opencensus-context", + "nvidia-cuda-cupti-cu12", + "nvidia-cuda-nvrtc-cu12", + "parsedatetime", + "nvidia-cusparse-cu12", + "nvidia-cufft-cu12", + "nvidia-cusolver-cu12", + "grpcio-gcp", + "nvidia-curand-cu12", + "google-cloud-texttospeech", + "typing", + "humanize", + "pytest-html", + "langsmith", + "nvidia-nvtx-cu12", + "flask-sqlalchemy", + "opentelemetry-util-http", + "narwhals", + "azure-multiapi-storage", + "gcloud-aio-storage", + "pycountry", + "jsonpickle", + "zstandard", + "avro-python3", + "libclang", + "apispec", + "gcloud-aio-auth", + "azure-storage-queue", + "contextlib2", + "azure-mgmt-datalake-analytics", + "gcloud-aio-bigquery", + "azure-mgmt-reservations", + "javaproperties", + "tensorflow-io-gcs-filesystem", + "azure-loganalytics", + "djangorestframework", + "azure-mgmt-consumption", + "hpack", + "google-cloud-compute", + "click-didyoumean", + "azure-mgmt-relay", + "parsimonious", + "azure-synapse-artifacts", + "python-magic", + "azure-cli-telemetry", + "click-repl", + "moto", + "pyathena", + "pyproj", + "protobuf3-to-dict", + "durationpy", + "stevedore", + "python-daemon", + "azure-synapse-spark", + "apache-airflow-providers-http", + "mypy-boto3-s3", + "pyspnego", + "cfn-lint", + "astor", + "azure-mgmt-apimanagement", + "h2", + "hyperframe", + "azure-mgmt-hdinsight", + "azure-mgmt-privatedns", + "boto3-stubs", + "mashumaro", + "dateparser", + "ml-dtypes", + "mysqlclient", + "azure-mgmt-security", + "opencensus-ext-azure", + "azure-mgmt-synapse", + "azure-mgmt-kusto", + "azure-mgmt-netapp", + "grpcio-health-checking", + "azure-mgmt-redhatopenshift", + "iso8601", + "lightgbm", + "azure-mgmt-appconfiguration", + "azure-keyvault-administration", + "boto", + "azure-mgmt-sqlvirtualmachine", + "azure-mgmt-imagebuilder", + "azure-synapse-accesscontrol", + "enum34", + "azure-mgmt-servicelinker", + "azure-mgmt-botservice", + "azure-mgmt-servicefabricmanagedclusters", + "jpype1", + "python-jose", + "azure-mgmt-databoxedge", + "azure-synapse-managedprivateendpoints", + "azure-mgmt-extendedlocation", + "office365-rest-python-client", + "onnxruntime", + "azure-mgmt-managedservices", + "cramjam", + "urllib3-secure-extra", + "avro", + "holidays", + "psycopg", + "botocore-stubs", + "fasteners", + "resolvelib", + "partd", + "hyperlink", + "leather", + 
"apscheduler", + "flask-wtf", + "jupyter", + "marisa-trie", + "locket", + "jupyter-console", + "python-http-client", + "elastic-transport", + "dbt-extractor", + "tensorflow-text", + "language-data", + "inflect", + "fuzzywuzzy", + "cytoolz", + "cmake", + "parse", + "python-gitlab", + "mypy-boto3-rds", + "tifffile", + "eth-utils", + "eth-hash", + "netaddr", + "incremental", + "setuptools-rust", + "python-levenshtein", + "geopandas", + "twisted", + "langchain-text-splitters", + "types-awscrt", + "apache-airflow-providers-fab", + "yamllint", + "cligj", + "sphinx-rtd-theme", + "azure-mgmt-deploymentmanager", + "pytest-timeout", + "lazy-loader", + "wtforms", + "bytecode", + "accelerate", + "polars", + "sendgrid", + "frozendict", + "flask-login", + "opentelemetry-instrumentation-requests", + "jaydebeapi", + "eth-typing", + "dacite", + "types-pytz", + "py-cpuinfo", + "querystring-parser", + "universal-pathlib", + "dbt-semantic-interfaces", + "magicattr", + "cssselect", + "fastparquet", + "opencv-python-headless", + "automat", + "unicodecsv", + "constantly", + "kfp", + "ddtrace", + "logbook", + "envier", + "cloudpathlib", + "types-s3transfer", + "google-cloud-dataflow-client", + "sqlalchemy-utils", + "apache-beam", + "validators", + "bracex", + "apache-airflow-providers-ftp", + "phonenumbers", + "diskcache", + "mergedeep", + "slicer", + "shap", + "python-docx", + "types-urllib3", + "pytest-rerunfailures", + "types-setuptools", + "pathy", + "pytz-deprecation-shim", + "yappi", + "pydot", + "types-protobuf", + "ipython-genutils", + "pytorch-lightning", + "fire", + "apache-airflow-providers-sqlite", + "nvidia-cublas-cu11", + "azure-storage-file-share", + "mmh3", + "azure-mgmt-datafactory", + "azure-servicebus", + "nvidia-cudnn-cu11", + "inject", + "typed-ast", + "connexion", + "configargparse", + "linkify-it-py", + "aws-sam-translator", + "slackclient", + "eth-abi", + "pydash", + "timm", + "datadog-api-client", + "nvidia-cuda-runtime-cu11", + "nvidia-cuda-nvrtc-cu11", + "geographiclib", + "gradio", + "cron-descriptor", + "ansible", + "azure-kusto-data", + "django-cors-headers", + "junit-xml", + "geopy", + "uc-micro-py", + "pyee", + "xarray", + "ansible-core", + "pypdf", + "pyotp", + "starkbank-ecdsa", + "geoip2", + "multimethod", + "eth-account", + "meson", + "jellyfish", + "futures", + "cachelib", + "flask-caching", + "natsort", + "autopep8", + "torchaudio", + "torchmetrics", + "pydub", + "pandera", + "pyhcl", + "apache-airflow-providers-slack", + "oracledb", + "google-cloud-run", + "h3", + "apache-airflow-providers-amazon", + "sqlalchemy-spanner", + "events", + "google-cloud-batch", + "requests-ntlm", + "bottle", + "google-cloud-storage-transfer", + "junitparser", + "apache-airflow-providers-smtp", + "apache-airflow-providers-imap", + "emoji", + "crcmod", + "statsd", + "limits", + "apache-airflow-providers-common-io", + "methodtools", + "asyncpg", + "strictyaml", + "wcmatch", + "marshmallow-sqlalchemy", + "faiss-cpu", + "sentence-transformers", + "psycopg-binary", + "azure-keyvault-certificates", + "django-filter", + "maxminddb", + "weasel", + "gql", + "onnx", + "fiona", + "boltons", + "dbt-common", + "bidict", + "keras-applications", + "json-merge-patch", + "elasticsearch-dsl", + "ftfy", + "swagger-ui-bundle", + "tableauserverclient", + "flask-jwt-extended", + "lightning-utilities", + "meson-python", + "google-cloud", +] + diff --git a/safety/tool/interceptors/__init__.py b/safety/tool/interceptors/__init__.py new file mode 100644 index 00000000..16db3d41 --- /dev/null +++ 
b/safety/tool/interceptors/__init__.py @@ -0,0 +1,4 @@ +from .types import InterceptorType +from .factory import create_interceptor + +__all__ = ['InterceptorType', 'create_interceptor'] diff --git a/safety/tool/interceptors/base.py b/safety/tool/interceptors/base.py new file mode 100644 index 00000000..aadf6263 --- /dev/null +++ b/safety/tool/interceptors/base.py @@ -0,0 +1,86 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import List, Dict, Optional, Tuple +from .types import InterceptorType + +from safety.meta import get_version + +@dataclass +class Tool: + name: str + binary_names: List[str] + + +# TODO: Add Event driven output and support --safety-ping flag to test the +# interceptors status. +class CommandInterceptor(ABC): + """ + Abstract base class for command interceptors. + This class provides a framework for installing and removing interceptors + for various tools. Subclasses must implement the `_batch_install_tools` + and `_batch_remove_tools` methods to handle the actual installation and + removal processes. + + Attributes: + interceptor_type (InterceptorType): The type of the interceptor. + tools (Dict[str, Tool]): A dictionary mapping tool names to Tool + objects. + Note: + All method implementations should be idempotent. + """ + + def __init__(self, interceptor_type: InterceptorType): + self.interceptor_type = interceptor_type + self.tools: Dict[str, Tool] = { + 'pip': Tool('pip', ['pip', 'pip3']), + } + + @abstractmethod + def _batch_install_tools(self, tools: List[Tool]) -> bool: + """ + Install multiple tools at once. Must be implemented by subclasses. + """ + pass + + @abstractmethod + def _batch_remove_tools(self, tools: List[Tool]) -> bool: + """ + Remove multiple tools at once. Must be implemented by subclasses. + """ + pass + + def install_interceptors(self, tools: Optional[List[str]] = None) -> bool: + """ + Install interceptors for the specified tools or all tools if none + specified. + """ + tools_to_install = self._get_tools(tools) + return self._batch_install_tools(tools_to_install) + + def remove_interceptors(self, tools: Optional[List[str]] = None) -> bool: + """ + Remove interceptors for the specified tools or all tools if none + specified. + """ + tools_to_remove = self._get_tools(tools) + return self._batch_remove_tools(tools_to_remove) + + def _get_tools(self, tools: Optional[List[str]] = None) -> List[Tool]: + """ + Get list of Tool objects based on tool names. + """ + if tools is None: + return list(self.tools.values()) + return [self.tools[name] for name in tools if name in self.tools] + + + def _generate_metadata_content(self, prepend: str) -> Tuple[str, str, str]: + """ + Create metadata for the files that are managed by us. 
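A quick usage sketch of the interceptor surface exported here (the factory implementation follows below); this is illustrative only and assumes the package is importable as `safety.tool.interceptors`:

```python
# Sketch only: exercises the public API of safety.tool.interceptors.
from safety.tool.interceptors import InterceptorType, create_interceptor

# Let the factory pick the implementation for the current OS
# (shell alias on Unix-like systems, .bat shims on Windows).
interceptor = create_interceptor()

# Install shims for every registered tool (currently only pip/pip3)...
installed = interceptor.install_interceptors()

# ...or for an explicit subset; the operations are documented as idempotent.
interceptor.install_interceptors(tools=["pip"])

# Undo the changes again.
removed = interceptor.remove_interceptors()
print(installed, removed)

# A specific interceptor type can also be forced, e.g. in tests:
unix_interceptor = create_interceptor(InterceptorType.UNIX_ALIAS)
```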
+ """ + return ( + f"{prepend} DO NOT EDIT THIS FILE DIRECTLY", + f"{prepend} Last updated at: {datetime.now(timezone.utc).isoformat()}", + f"{prepend} Updated by: safety v{get_version()}" + ) \ No newline at end of file diff --git a/safety/tool/interceptors/factory.py b/safety/tool/interceptors/factory.py new file mode 100644 index 00000000..0179b289 --- /dev/null +++ b/safety/tool/interceptors/factory.py @@ -0,0 +1,30 @@ +from sys import platform +from typing import Optional +from .types import InterceptorType +from .unix import UnixAliasInterceptor +from .windows import WindowsInterceptor +from .base import CommandInterceptor + + +def create_interceptor(interceptor_type: + Optional[InterceptorType] = None) -> CommandInterceptor: + """ + Create appropriate interceptor based on OS and type + """ + interceptor_map = { + InterceptorType.UNIX_ALIAS: UnixAliasInterceptor, + InterceptorType.WINDOWS_BAT: WindowsInterceptor + } + + if interceptor_type: + return interceptor_map[interceptor_type]() + + # Auto-select based on OS + if platform == 'win32': + return interceptor_map[InterceptorType.WINDOWS_BAT]() + + if platform in ['linux', 'linux2', 'darwin']: + # Default to alias-based on Unix-like systems + return interceptor_map[InterceptorType.UNIX_ALIAS]() + + raise NotImplementedError(f"Platform '{platform}' is not supported.") diff --git a/safety/tool/interceptors/types.py b/safety/tool/interceptors/types.py new file mode 100644 index 00000000..6e1b698c --- /dev/null +++ b/safety/tool/interceptors/types.py @@ -0,0 +1,5 @@ +from enum import Enum, auto + +class InterceptorType(Enum): + UNIX_ALIAS = auto() + WINDOWS_BAT = auto() \ No newline at end of file diff --git a/safety/tool/interceptors/unix.py b/safety/tool/interceptors/unix.py new file mode 100644 index 00000000..e50baecb --- /dev/null +++ b/safety/tool/interceptors/unix.py @@ -0,0 +1,188 @@ +import logging +from pathlib import Path +import re +import shutil +import tempfile +from typing import List +from .base import CommandInterceptor, Tool +from .types import InterceptorType + +from safety.constants import USER_CONFIG_DIR + + +LOG = logging.getLogger(__name__) + +class UnixAliasInterceptor(CommandInterceptor): + + def __init__(self): + super().__init__(InterceptorType.UNIX_ALIAS) + self.user_rc_path = self._get_user_rc_path() + self.custom_rc_path = self._get_custom_rc_path() + + # Update these markers could be a breaking change; be careful to handle + # backward compatibility + self.marker_start = "# >>> Safety >>>" + self.marker_end = "# <<< Safety <<<" + + # .profile is not a rc file, but is what we support for now. + def _get_user_rc_path(self) -> Path: + """ + We support .profile for now. + """ + home = Path.home() + return home / '.profile' + + def _get_custom_rc_path(self) -> Path: + return USER_CONFIG_DIR / ".safety_profile" + + def _backup_file(self, path: Path) -> None: + """ + Create backup of file if it exists + """ + if path.exists(): + backup_path = path.with_suffix('.backup') + shutil.copy2(path, backup_path) + + def _generate_user_rc_content(self) -> str: + """ + Generate the content to be added to user's rc. + + Example: + ``` + # >>> Safety >>> + [ -f "$HOME/.safety/.safety_profile" ] && . "$HOME/.safety/.safety_profile" + # <<< Safety <<< + ``` + """ + lines = ( + self.marker_start, + f'[ -f "{self.custom_rc_path}" ] && . 
"{self.custom_rc_path}"', + self.marker_end, + ) + return "\n".join(lines) + "\n" + + def _is_configured(self) -> bool: + """ + Check if the configuration block exists in user's rc file + """ + try: + if not self.user_rc_path.exists(): + return False + + content = self.user_rc_path.read_text() + return self.marker_start in content and self.marker_end in content + + except OSError as e: + LOG.info("Failed to read user's rc file") + return False + + def _generate_custom_rc_content(self, aliases: List[str]) -> str: + """ + Generate the content for the custom profile with metadata + """ + metadata_lines = self._generate_metadata_content(prepend="#") + aliases_lines = tuple(aliases) + + lines = (self.marker_start,) + metadata_lines + aliases_lines + \ + (self.marker_end,) + + return "\n".join(lines) + "\n" + + def _ensure_source_line_in_user_rc(self) -> None: + """ + Ensure source line exists in user's .profile + + If the source line is not present in the user's .profile, append it. + If the user's .profile does not exist, create it. + """ + source_line = self._generate_user_rc_content() + + if not self.user_rc_path.exists(): + self.user_rc_path.write_text(source_line) + return + + if not self._is_configured(): + with open(self.user_rc_path, 'a') as f: + f.write(source_line) + + def _batch_install_tools(self, tools: List[Tool]) -> bool: + """ + Install aliases for multiple tools + """ + try: + # Generate aliases + aliases = [] + for tool in tools: + for binary in tool.binary_names: + alias_def = f'alias {binary}="safety {tool.name}"' + aliases.append(alias_def) + + if not aliases: + return False + + # Create safety profile directory if it doesn't exist + self.custom_rc_path.parent.mkdir(parents=True, exist_ok=True) + + # Generate new profile content + profile_content = self._generate_custom_rc_content(aliases) + + # Backup target files + for f_path in [self.user_rc_path, self.custom_rc_path]: + self._backup_file(path=f_path) + + # Override our custom profile + # TODO: handle exceptions + self.custom_rc_path.write_text(profile_content) + + # Ensure source line in user's .profile + self._ensure_source_line_in_user_rc() + + return True + + except Exception as e: + print(f"Failed to batch install aliases: {e}") + return False + + def _batch_remove_tools(self, tools: List[Tool]) -> bool: + """ + This will remove all the tools. + + NOTE: for now this does not support to remove individual tools. + """ + try: + + # Backup target files + for f_path in [self.user_rc_path, self.custom_rc_path]: + self._backup_file(path=f_path) + + if self._is_configured(): + temp_dir = tempfile.gettempdir() + temp_file = Path(temp_dir) / f"{self.user_rc_path.name}.tmp" + + pattern = rf"{self.marker_start}\n.*?\{self.marker_end}\n?" 
+ + with open(self.user_rc_path, 'r') as src, \ + open(temp_file, 'w') as dst: + content = src.read() + cleaned_content = re.sub(pattern, '', content, + flags=re.DOTALL) + dst.write(cleaned_content) + + if not temp_file.exists(): + LOG.info("Temp file is empty or invalid") + return False + + shutil.move(str(temp_file), str(self.user_rc_path)) + + self.custom_rc_path.unlink(missing_ok=True) + + return True + except Exception as e: + print(f"Failed to batch remove aliases: {e}") + return False + + def _install_tool(self, tool: Tool) -> bool: + return self._batch_install_tools([tool]) + + def _remove_tool(self, tool: Tool) -> bool: + return self._batch_remove_tools([tool]) diff --git a/safety/tool/interceptors/windows.py b/safety/tool/interceptors/windows.py new file mode 100644 index 00000000..0eb4aa07 --- /dev/null +++ b/safety/tool/interceptors/windows.py @@ -0,0 +1,169 @@ +import logging +import os +import shutil +from pathlib import Path +from sys import platform +from typing import TYPE_CHECKING, List + +from .base import CommandInterceptor, Tool +from .types import InterceptorType + +if TYPE_CHECKING or platform == "win32": + import winreg + +LOG = logging.getLogger(__name__) + + +class WindowsInterceptor(CommandInterceptor): + def __init__(self): + super().__init__(InterceptorType.WINDOWS_BAT) + self.scripts_dir = Path.home() / 'AppData' / 'Local' / 'safety' + self.backup_dir = self.scripts_dir / 'backups' + self.backup_win_env_path = self.backup_dir / 'path_backup.txt' + + # Update these markers could be a breaking change; be careful to handle + # backward compatibility + self.marker_start = ">>> Safety >>>" + self.marker_end = "<<< Safety <<<" + + def _backup_path_env(self, path_content: str) -> None: + """ + Backup current PATH to a file + """ + self.backup_dir.mkdir(parents=True, exist_ok=True) + + metadata_lines = self._generate_metadata_content(prepend="") + + lines = (self.marker_start,) + metadata_lines + (path_content,) + \ + (self.marker_end,) + + content = "\n".join(lines) + "\n" + + self.backup_win_env_path.write_text(content) + + def _generate_bat_content(self, tool_name: str) -> str: + """ + Generate the content for the bat with metadata + """ + metadata_lines = self._generate_metadata_content(prepend="REM") + + no_echo = "@echo off" + wrapper = f"safety {tool_name} %*" + lines = (no_echo, + f"REM {self.marker_start}",) + metadata_lines + (wrapper,) + \ + (f"REM {self.marker_end}",) + + return "\n".join(lines) + "\n" + + def _batch_install_tools(self, tools: List[Tool]) -> bool: + """ + Install interceptors for multiple tools at once + """ + try: + wrappers = [] + for tool in tools: + for binary in tool.binary_names: + # TODO: Switch to binary once we support safety pip3, etc. + wrapper = self._generate_bat_content(tool.name) + wrappers.append((binary, wrapper)) + + if not wrappers: + return False + + # Create safety directory if it doesn't exist + self.scripts_dir.mkdir(parents=True, exist_ok=True) + + for binary, wrapper in wrappers: + wrapper_path = self.scripts_dir / f'{binary}.bat' + wrapper_path.write_text(wrapper) + + # Add scripts directory to PATH if needed + self._update_path() + + return True + + except Exception as e: + LOG.info("Failed to batch install tools") + return False + + def _batch_remove_tools(self, tools: List[Tool]) -> bool: + """ + Remove interceptors for multiple tools at once. + + Note: We don't support removing specific tools yet, + so we remove all tools. 
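For comparison with the Unix path: the Windows interceptor writes one `.bat` shim per binary name (`pip.bat`, `pip3.bat`) into `~/AppData/Local/safety` and adds that directory to the user `PATH`. Each shim ends up roughly like this (placeholders again, and both shims currently dispatch to `safety pip`, per the TODO above):

```
@echo off
REM >>> Safety >>>
REM DO NOT EDIT THIS FILE DIRECTLY
REM Last updated at: <ISO-8601 timestamp>
REM Updated by: safety v<version>
safety pip %*
REM <<< Safety <<<
```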
+ """ + try: + self._update_path(remove=True) + if self.scripts_dir.exists(): + shutil.rmtree(self.scripts_dir) + + return True + + except Exception as e: + LOG.info("Failed to batch remove tools.") + return False + + def _update_path(self, remove: bool = False) -> bool: + """ + Update Windows PATH environment variable + """ + + try: + with winreg.OpenKey(winreg.HKEY_CURRENT_USER, 'Environment', 0, winreg.KEY_ALL_ACCESS) as key: + # Get current PATH value + try: + path_val = winreg.QueryValueEx(key, 'PATH')[0] + self._backup_path_env(path_content=path_val) + except FileNotFoundError: + path_val = '' + + # Convert to Path objects + paths = [Path(p) for p in path_val.split(os.pathsep) if p] + + if remove: + if self.scripts_dir in paths: + paths.remove(self.scripts_dir) + new_path = os.pathsep.join(str(p) for p in paths) + winreg.SetValueEx(key, 'PATH', 0, winreg.REG_EXPAND_SZ, new_path) + else: + if self.scripts_dir not in paths: + paths.insert(0, self.scripts_dir) # Add to beginning + new_path_val = os.pathsep.join(str(p) for p in paths) + winreg.SetValueEx(key, 'PATH', 0, winreg.REG_EXPAND_SZ, new_path_val) + + return True + except Exception as e: + LOG.info("Failed to update PATH") + return False + + def _install_tool(self, tool: Tool) -> bool: + """Individual tool installation (fallback method)""" + return self._batch_install_tools([tool]) + + def _remove_tool(self, tool: Tool) -> bool: + """Individual tool removal (fallback method)""" + return self._batch_remove_tools([tool]) + + def _validate_installation(self, tool: Tool) -> bool: + try: + # Check if batch files exist + for binary in tool.binary_names: + batch_script = self.scripts_dir / f'{binary}.bat' + if not batch_script.exists(): + return False + + # Check if directory is in PATH + key = winreg.OpenKey( + winreg.HKEY_CURRENT_USER, + 'Environment', + 0, + winreg.KEY_READ + ) + path = winreg.QueryValueEx(key, 'PATH')[0] + winreg.CloseKey(key) + + return str(self.scripts_dir) in path + + except Exception as e: + return False \ No newline at end of file diff --git a/safety/tool/main.py b/safety/tool/main.py new file mode 100644 index 00000000..84b35d05 --- /dev/null +++ b/safety/tool/main.py @@ -0,0 +1,53 @@ +import os.path +from pathlib import Path + +from safety.console import main_console as console +from safety.tool.utils import PipConfigurator, PipRequirementsConfigurator, PoetryPyprojectConfigurator, is_os_supported + +from .interceptors import create_interceptor + + +def has_local_tool_files(directory: Path) -> bool: + configurators = [PipRequirementsConfigurator(), PoetryPyprojectConfigurator()] + + for file_name in os.listdir(directory): + if os.path.isfile(file_name): + file = Path(file_name) + for configurator in configurators: + if configurator.is_supported(file): + return True + + return False + + +def configure_system(): + configurators = [PipConfigurator()] + + for configurator in configurators: + configurator.configure() + +def reset_system(): + configurators = [PipConfigurator()] + + for configurator in configurators: + configurator.reset() + +def configure_alias(): + if not is_os_supported(): + return + + interceptor = create_interceptor() + interceptor.install_interceptors() + + console.print("Configured PIP alias") + + +def configure_local_directory(directory: Path): + configurators = [PipRequirementsConfigurator(), PoetryPyprojectConfigurator()] + + for file_name in os.listdir(directory): + if os.path.isfile(file_name): + file = Path(file_name) + for configurator in configurators: + if 
configurator.is_supported(file): + configurator.configure(file) diff --git a/safety/tool/pip.py b/safety/tool/pip.py new file mode 100644 index 00000000..1716439f --- /dev/null +++ b/safety/tool/pip.py @@ -0,0 +1,102 @@ +import base64 +import json +import shutil +import subprocess +from pathlib import Path +from typing import Optional +from urllib.parse import urlsplit, urlunsplit + +import typer +from rich.console import Console + +from safety.tool.resolver import get_unwrapped_command + +from safety.console import main_console + +REPOSITORY_URL = "https://pkgs.safetycli.com/repository/public/pypi/simple/" + + +class Pip: + + @classmethod + def is_installed(cls) -> bool: + """ + Checks if the PIP program is installed + + Returns: + True if PIP is installed on system, or false otherwise + """ + return shutil.which("pip") is not None + + @classmethod + def configure_requirements(cls, file: Path, console: Optional[Console] = main_console) -> None: + """ + Configures Safety index url for specified requirements file. + + Args: + file (Path): Path to requirements.txt file. + console (Console): Console instance. + """ + + with open(file, "r+") as f: + content = f.read() + + index_config = f"-i {REPOSITORY_URL}\n" + if content.find(index_config) == -1: + f.seek(0) + f.write(index_config + content) + + console.print(f"Configured {file} file") + else: + console.print(f"{file} is already configured. Skipping.") + + @classmethod + def configure_system(cls, console: Optional[Console] = main_console): + """ + Configures PIP system to use to Safety index url. + """ + try: + subprocess.run([get_unwrapped_command(name="pip"), "config", "set", "global.index-url", REPOSITORY_URL], capture_output=True) + console.print("Configured PIP global settings") + except Exception as e: + console.print("Failed to configure PIP global settings.") + + @classmethod + def reset_system(cls, console: Optional[Console] = main_console): + # TODO: Move this logic and implement it in a more robust way + try: + subprocess.run([get_unwrapped_command(name="pip"), "config", "unset", "global.index-url"], capture_output=True) + except Exception as e: + console.print("Failed to reset PIP global settings.") + + + @classmethod + def index_credentials(cls, ctx: typer.Context): + auth_envelop = json.dumps({ + "version": "1.0", + "access_token": ctx.obj.auth.client.token["access_token"], + "api_key": ctx.obj.auth.client.api_key, + "project_id": ctx.obj.project.id if ctx.obj.project else None, + }) + return base64.urlsafe_b64encode(auth_envelop.encode("utf-8")).decode("utf-8") + + @classmethod + def default_index_url(cls) -> str: + return "https://pypi.org/simple/" + + @classmethod + def build_index_url(cls, ctx: typer.Context, index_url: Optional[str]) -> str: + if index_url is None: + index_url = REPOSITORY_URL + + url = urlsplit(index_url) + + encoded_auth = cls.index_credentials(ctx) + netloc = f'user:{encoded_auth}@{url.netloc}' + + if type(url.netloc) == bytes: + url = url._replace(netloc=netloc.encode("utf-8")) + elif type(url.netloc) == str: + url = url._replace(netloc=netloc) + + return urlunsplit(url) diff --git a/safety/tool/poetry.py b/safety/tool/poetry.py new file mode 100644 index 00000000..130d5b92 --- /dev/null +++ b/safety/tool/poetry.py @@ -0,0 +1,51 @@ +import shutil +import subprocess +from pathlib import Path +from typing import Optional +import sys + +from rich.console import Console + +from safety.console import main_console +from safety.tool.pip import REPOSITORY_URL +from safety.tool.resolver import 
get_unwrapped_command + +if sys.version_info >= (3, 11): + import tomllib +else: + import tomli as tomllib + +class Poetry: + + @classmethod + def is_installed(cls) -> bool: + """ + Checks if the PIP program is installed + + Returns: + True if PIP is installed on system, or false otherwise + """ + return shutil.which("poetry") is not None + + @classmethod + def is_poetry_project_file(cls, file: Path) -> bool: + try: + cfg = tomllib.loads(file.read_text()) + return cfg.get("build-system", {}).get("requires") == "poetry-core" + except (IOError, ValueError) as e: + return False + + @classmethod + def configure_pyproject(cls, file: Path, console: Optional[Console] = main_console) -> None: + """ + Configures index url for specified requirements file. + + Args: + file (Path): Path to requirements.txt file. + console (Console): Console instance. + """ + if not cls.is_installed(): + console.log("Poetry is not installed.") + + subprocess.run([get_unwrapped_command(name="poetry"), "source", "add", "safety", REPOSITORY_URL], capture_output=True) + console.print(f"Configured {file} file") diff --git a/safety/tool/resolver.py b/safety/tool/resolver.py new file mode 100644 index 00000000..1c9cae47 --- /dev/null +++ b/safety/tool/resolver.py @@ -0,0 +1,24 @@ +from sys import platform +import subprocess + + +def get_unwrapped_command(name: str) -> str: + """ + Find the true executable for a command, skipping wrappers/aliases/.bat files. + + Args: + command: The command to resolve (e.g. 'pip', 'python') + + Returns: + Path to the actual executable + """ + if platform in ["win32"]: + lookup_term = f"{name}.exe" + where_result = subprocess.run(["where.exe", lookup_term], + capture_output=True, text=True) + if where_result.returncode == 0: + for path in where_result.stdout.splitlines(): + if not path.lower().endswith(f"{name}.bat"): + return path + + return name \ No newline at end of file diff --git a/safety/tool/utils.py b/safety/tool/utils.py new file mode 100644 index 00000000..99c8993f --- /dev/null +++ b/safety/tool/utils.py @@ -0,0 +1,274 @@ +import abc +import os.path +import re +import subprocess +from abc import abstractmethod +from pathlib import Path +from sys import platform +from tempfile import mkstemp + +import typer +from Levenshtein import distance +from filelock import FileLock +from rich.padding import Padding +from rich.prompt import Prompt + +from safety.console import main_console as console +from safety.constants import PIP_LOCK +from safety.tool.constants import MOST_FREQUENTLY_DOWNLOADED_PYPI_PACKAGES, PROJECT_CONFIG, REPOSITORY_URL +from safety.tool.pip import Pip +from safety.tool.poetry import Poetry +from safety.tool.resolver import get_unwrapped_command + +from typing_extensions import List + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from subprocess import CompletedProcess + + +def is_os_supported(): + return platform in ["linux", "linux2", "darwin", "win32"] + +class BuildFileConfigurator(abc.ABC): + + @abc.abstractmethod + def is_supported(self, file: Path) -> bool: + """ + Returns whether a specific file is supported by this class. + Args: + file (str): The file to check. + Returns: + bool: Whether the file is supported by this class. + """ + pass + + @abc.abstractmethod + def configure(self, file: Path) -> None: + """ + Configures specific file. + Args: + file (str): The file to configure. 
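Stepping back to `safety/tool/pip.py` above, a short sketch of the requirements-file rewrite performed by `Pip.configure_requirements`; the file path and the pinned package are example data only:

```python
# Sketch only: shows the effect of Pip.configure_requirements on a requirements file.
from pathlib import Path

from safety.tool.pip import Pip

req = Path("requirements.txt")
req.write_text("requests==2.32.0\n")  # example content

Pip.configure_requirements(req)
print(req.read_text())
# -i https://pkgs.safetycli.com/repository/public/pypi/simple/
# requests==2.32.0

# A second call finds the index line already present and skips the rewrite.
Pip.configure_requirements(req)
```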
+ """ + pass + + +class PipRequirementsConfigurator(BuildFileConfigurator): + __file_name_pattern = re.compile("^([a-zA-Z_-]+)?requirements([a-zA-Z_-]+)?.txt$") + + def is_supported(self, file: Path) -> bool: + return self.__file_name_pattern.match(os.path.basename(file)) is not None + + def configure(self, file: Path) -> None: + Pip.configure_requirements(file) + + +class PoetryPyprojectConfigurator(BuildFileConfigurator): + __file_name_pattern = re.compile("^pyproject.toml$") + + def is_supported(self, file: Path) -> bool: + return self.__file_name_pattern.match(os.path.basename(file)) is not None and Poetry.is_poetry_project_file( + file) + + def configure(self, file: Path) -> None: + Poetry.configure_pyproject(file) + + +# TODO: Review if we should move this/hook up this into interceptors. +class ToolConfigurator(abc.ABC): + + @abc.abstractmethod + def configure(self) -> None: + """ + Configures specific tool. + """ + pass + + @abc.abstractmethod + def reset(self) -> None: + """ + Resets specific tool. + """ + pass + +class PipConfigurator(ToolConfigurator): + + def configure(self) -> None: + Pip.configure_system() + + def reset(self) -> None: + Pip.reset_system() + + +class PipCommand(abc.ABC): + + def __init__(self, args: List[str], capture_output: bool = False) -> None: + self._args = args + self.__capture_output = capture_output + self.__filelock = FileLock(PIP_LOCK, 10) + + @abstractmethod + def before(self, ctx: typer.Context): + pass + + @abstractmethod + def after(self, ctx: typer.Context, result): + pass + + def execute(self, ctx: typer.Context): + with self.__filelock: + self.before(ctx) + # TODO: Safety should redirect to the proper pip, if the user is + # using pip3, it should be redirected to pip3, not pip to avoid any + # issues. + args = [get_unwrapped_command(name="pip")] + self.__remove_safety_args(self._args) + result = subprocess.run(args, capture_output=self.__capture_output, env=self.env(ctx)) + self.after(ctx, result) + + def env(self, ctx: typer.Context): + return os.environ.copy() + + @classmethod + def from_args(self, args): + if "install" in args: + return PipInstallCommand(args) + elif "uninstall" in args: + return PipUninstallCommand(args) + else: + return PipGenericCommand(args) + + def __remove_safety_args(self, args: List[str]): + return [arg for arg in args if not arg.startswith("--safety")] + + +class PipGenericCommand(PipCommand): + + def __init__(self, args: List[str]) -> None: + super().__init__(args) + + def before(self, ctx: typer.Context): + pass + + def after(self, ctx: typer.Context, result): + pass + + +class PipInstallCommand(PipCommand): + + def __init__(self, args: List[str]) -> None: + super().__init__(args) + self.package_names = [] + self.__index_url = None + + def before(self, ctx: typer.Context): + args = self._args + + ranges_to_delete = [] + for ind, val in enumerate(args): + if ind > 0 and (args[ind - 1].startswith("-i") or args[ind - 1].startswith("--index-url")): + if args[ind].startswith("https://pkgs.safetycli.com"): + self.__index_url = args[ind] + + ranges_to_delete.append((ind - 1, ind)) + elif ind > 0 and (args[ind - 1] == "-r" or args[ind - 1] == "--requirement"): + requirement_file = args[ind] + + if not Path(requirement_file).is_file(): + continue + + with open(requirement_file, "r") as f: + fd, tmp_requirements_path = mkstemp(suffix="safety-requirements.txt", text=True) + with os.fdopen(fd, "w") as tf: + requirements = re.sub(r"^(-i|--index-url).*$", "", f.read(), flags=re.MULTILINE) + tf.write(requirements) + + args[ind] 
= tmp_requirements_path + elif ind > 0 and (not args[ind - 1].startswith("-e") or not args[ind - 1].startswith("--editable")) and not args[ind].startswith("-"): + if args[ind] == '.': + continue + + package_name = args[ind] + (valid, candidate_package_name) = self.__check_typosquatting(package_name) + if not valid: + prompt = f"You are about to install {package_name} package. Did you mean to install {candidate_package_name}?" + answer = Prompt.ask(prompt=prompt, choices=["y", "n"], + default="y", show_default=True, console=console).lower() + if answer == 'y': + package_name = candidate_package_name + console.print(f"Installing {package_name} package instead.") + args[ind] = package_name + + self.__add_package_name(package_name) + + for (start, end) in ranges_to_delete: + args = args[:start] + args[end + 1:] + + self._args = args + + def after(self, ctx: typer.Context, result: 'CompletedProcess[str]'): + if result and result.returncode == 0: + self.__run_scan() + else: + self.__render_package_details() + + def env(self, ctx: typer.Context) -> dict: + env = super().env(ctx) + env["PIP_INDEX_URL"] = Pip.build_index_url(ctx, self.__index_url) if not self.__is_check_disabled() else Pip.default_index_url() + return env + + def __is_check_disabled(self): + return "--safety-disable-check" in self._args + + def __check_typosquatting(self, package_name): + max_edit_distance = 2 if len(package_name) > 5 else 1 + + if package_name in MOST_FREQUENTLY_DOWNLOADED_PYPI_PACKAGES: + return (True, package_name) + + for pkg in MOST_FREQUENTLY_DOWNLOADED_PYPI_PACKAGES: + if (abs(len(pkg) - len(package_name)) <= max_edit_distance + and distance(pkg, package_name) <= max_edit_distance): + return (False, pkg) + + return (True, package_name) + + def __run_scan(self): + if not is_os_supported(): + return + + target = os.getcwd() + if Path(os.path.join(target, PROJECT_CONFIG)).is_file(): + try: + subprocess.Popen( + ['safety', 'scan'], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + stdin=subprocess.DEVNULL, + start_new_session=True + ) + except Exception: + pass + + def __add_package_name(self, package_name): + r = re.compile(r"^([a-zA-Z_-]+)(([~<>=]=)[a-zA-Z0-9._-]+)?") + match = r.match(package_name) + if match: + self.package_names.append(match.group(1)) + + def __render_package_details(self): + for package_name in self.package_names: + console.print( + Padding(f"Learn more: [link]https://data.safetycli.com/packages/pypi/{package_name}/[/link]", + (0, 0, 0, 1)), emoji=True) + + +class PipUninstallCommand(PipCommand): + + def __init__(self, args: List[str]) -> None: + super().__init__(args) + + def before(self, ctx: typer.Context): + pass + + def after(self, ctx: typer.Context, result): + pass diff --git a/safety/util.py b/safety/util.py index 83c30527..79f0f3ca 100644 --- a/safety/util.py +++ b/safety/util.py @@ -7,22 +7,33 @@ from datetime import datetime from difflib import SequenceMatcher from threading import Lock -from typing import List, Optional, Dict, Generator, Tuple, Union, Any +from typing import Any, Dict, Generator, List, Optional, Tuple import click from click import BadParameter -from dparse import parse, filetypes +from dparse import filetypes, parse +from packaging.specifiers import SpecifierSet from packaging.utils import canonicalize_name from packaging.version import parse as parse_version -from packaging.specifiers import SpecifierSet from requests import PreparedRequest from ruamel.yaml import YAML from ruamel.yaml.error import MarkedYAMLError +from safety_schemas.models import 
TelemetryModel -from safety.constants import EXIT_CODE_FAILURE, EXIT_CODE_OK, HASH_REGEX_GROUPS, SYSTEM_CONFIG_DIR, USER_CONFIG_DIR +from safety.constants import ( + EXIT_CODE_FAILURE, + EXIT_CODE_OK, + HASH_REGEX_GROUPS, + SYSTEM_CONFIG_DIR, + USER_CONFIG_DIR, +) from safety.errors import InvalidProvidedReportError -from safety.models import Package, RequirementFile, is_pinned_requirement, SafetyRequirement -from safety_schemas.models import TelemetryModel +from safety.models import ( + Package, + RequirementFile, + SafetyRequirement, + is_pinned_requirement, +) LOG = logging.getLogger(__name__) @@ -230,7 +241,7 @@ def get_used_options() -> Dict[str, Dict[str, int]]: return used_options -def get_safety_version() -> str: +def get_version() -> str: """ Get the version of Safety. @@ -320,7 +331,7 @@ def build_telemetry_data(telemetry: bool = True, 'safety_options': get_used_options() } if telemetry else {} - body['safety_version'] = get_safety_version() + body['safety_version'] = get_version() body['safety_source'] = os.environ.get("SAFETY_SOURCE", None) or context.safety_source if not 'safety_options' in body: diff --git a/tests/auth/test_main.py b/tests/auth/test_auth_main.py similarity index 98% rename from tests/auth/test_main.py rename to tests/auth/test_auth_main.py index e425ba71..7a5abd65 100644 --- a/tests/auth/test_main.py +++ b/tests/auth/test_auth_main.py @@ -9,7 +9,7 @@ -class TestMain(unittest.TestCase): +class TestAuthMain(unittest.TestCase): def setUp(self): self.assets = Path(__file__).parent / Path("test_assets/") diff --git a/tests/auth/test_auth_utils.py b/tests/auth/test_auth_utils.py new file mode 100644 index 00000000..c6e9629b --- /dev/null +++ b/tests/auth/test_auth_utils.py @@ -0,0 +1,139 @@ +import unittest +from unittest.mock import MagicMock, Mock, patch, call +from safety.auth.utils import initialize +from safety.errors import InvalidCredentialError +from safety.auth.utils import FeatureType, str_to_bool, get_config_setting, save_flags_config + +class TestUtils(unittest.TestCase): + + @patch('safety.auth.utils.get_config_setting') + @patch('safety.auth.utils.str_to_bool') + @patch('safety.auth.utils.save_flags_config') + def test_initialize_with_no_session( + self, + mock_save_flags_config, + mock_str_to_bool, + mock_get_config_setting): + + ctx = Mock() + ctx.obj = None + mock_get_config_setting.return_value = 'true' + mock_str_to_bool.return_value = True + + # First test: when auth is None + with patch('safety.models.SafetyCLI') as MockSafetyCLI: + mock_safety_cli = Mock() + mock_safety_cli.auth = None + MockSafetyCLI.return_value = mock_safety_cli + + initialize(ctx, refresh=True) + + # Verify expected behavior when auth is None + mock_save_flags_config.assert_not_called() + self.assertEqual(mock_get_config_setting.call_count, len(FeatureType)) + + # Reset mock call counts + mock_get_config_setting.reset_mock() + mock_save_flags_config.reset_mock() + + # Second test: when auth is populated but raises exception + ctx = Mock() + mock_safety_cli = Mock() + + mock_initialize = Mock(side_effect=InvalidCredentialError()) + mock_client = Mock() + mock_client.initialize = mock_initialize + mock_auth = Mock() + mock_auth.client = mock_client + mock_safety_cli.auth = mock_auth + ctx.obj = mock_safety_cli + + initialize(ctx, refresh=True) + + # On exception, it should fall back to default values + mock_safety_cli.auth.client.initialize.assert_called_once() + mock_save_flags_config.assert_not_called() + self.assertEqual(mock_get_config_setting.call_count, 
len(FeatureType)) + + @patch('safety.auth.utils.get_config_setting') + @patch('safety.auth.utils.str_to_bool') + @patch('safety.auth.utils.save_flags_config') + def test_initialize_without_refresh(self, + mock_save_flags_config, + mock_str_to_bool, + mock_get_config_setting): + ctx = MagicMock() + ctx.obj = None + mock_get_config_setting.return_value = 'true' + mock_str_to_bool.return_value = True + + with patch('safety.auth.utils.SafetyCLI') as MockSafetyCLI, \ + patch('safety.auth.utils.setattr') as mock_setattr: + + mock_safety_cli = MockSafetyCLI.return_value + + initialize(ctx, refresh=False) + + mock_safety_cli.auth.client.initialize.assert_not_called() + mock_save_flags_config.assert_not_called() + self.assertEqual(mock_get_config_setting.call_count, + len(FeatureType)) + + expected_calls = [ + call(mock_safety_cli, feature.attr_name, True) + for feature in FeatureType + ] + mock_setattr.assert_has_calls(expected_calls, any_order=True) + + # Verify number of calls matches number of features + self.assertEqual(mock_setattr.call_count, len(FeatureType)) + + + @patch('safety.auth.utils.get_config_setting') + @patch('safety.auth.utils.save_flags_config') + def test_initialize_with_server_response(self, + mock_save_flags_config, + mock_get_config_setting): + + ctx = Mock() + mock_safety_cli = Mock() + + SERVER_RESPONSE = { + "organization": "Test", + "plan": {}, + "firewall-enabled": "false", + "platform-enabled": "true" + } + + mock_initialize = Mock( + return_value={"organization": "Test", + "plan": {}, + "firewall-enabled": "false", + "platform-enabled": "true"}) + mock_client = Mock() + mock_client.initialize = mock_initialize + mock_auth = Mock() + mock_auth.client = mock_client + mock_safety_cli.auth = mock_auth + ctx.obj = mock_safety_cli + + with patch('safety.auth.utils.setattr') as mock_setattr: + + initialize(ctx, refresh=True) + + mock_safety_cli.auth.client.initialize.assert_called_once() + mock_save_flags_config.assert_called_once() + self.assertEqual(mock_get_config_setting.call_count, + len(FeatureType)) + + # Server response should override current values + expected_calls = [ + call(mock_safety_cli, + feature.attr_name, + str_to_bool(SERVER_RESPONSE[feature.config_key])) + for feature in FeatureType + ] + mock_setattr.assert_has_calls(expected_calls, any_order=True) + + # Verify number of calls matches number of features + self.assertEqual(mock_setattr.call_count, len(FeatureType)) diff --git a/tests/auth/test_cli.py b/tests/auth/test_cli.py index f77aeae6..0e801fbc 100644 --- a/tests/auth/test_cli.py +++ b/tests/auth/test_cli.py @@ -1,11 +1,8 @@ -from unittest.mock import Mock, PropertyMock, patch, ANY -import click +from unittest.mock import patch, ANY from click.testing import CliRunner import unittest from safety.cli import cli -from safety.cli_util import get_command_for - class TestSafetyAuthCLI(unittest.TestCase): @@ -13,6 +10,9 @@ def setUp(self): self.maxDiff = None self.runner = CliRunner(mix_stderr=False) + cli.commands = cli.all_commands + self.cli = cli + @unittest.skip("We are bypassing email verification for now") @patch("safety.auth.cli.fail_if_authenticated") @patch("safety.auth.cli.get_authorization_data") @@ -26,7 +26,7 @@ def test_auth_calls_login( "email": "user@safetycli.com", "name": "Safety User", } - result = self.runner.invoke(cli, ["auth"]) + result = self.runner.invoke(self.cli, ["auth"]) fail_if_authenticated.assert_called_once() get_authorization_data.assert_called_once() diff --git a/tests/init/test_init_main.py 
b/tests/init/test_init_main.py new file mode 100644 index 00000000..4df8cbac --- /dev/null +++ b/tests/init/test_init_main.py @@ -0,0 +1,212 @@ +from pathlib import Path +import unittest +from unittest.mock import Mock, call, patch + +from safety.init.main import PROJECT_CONFIG_ID, PROJECT_CONFIG_NAME, \ + PROJECT_CONFIG_URL, PROJECT_CONFIG_SECTION, check_project, create_project, save_project_info, \ + save_verified_project + +from safety_schemas.models import ProjectModel, Stage + + +class TestInitMain(unittest.TestCase): + + def setUp(self): + pass + + def tearDown(self): + pass + + @patch('safety.init.main.prompt_project_id') + @patch('safety.init.main.print_wait_project_verification') + def test_check_project_without_id(self, mock_wait_verification, + mock_prompt_id): + """ + If not project id is provided, the user should be prompted for one, + then the project should be verified. + """ + ctx = Mock() + ctx.obj.auth.stage = Stage.production + ctx.obj.telemetry = Mock(safety_source="cli") + + session = Mock() + console = Mock() + unverified_project = Mock( + id=None, + project_path=Path("/test/dir/project/.safety-project.ini") + ) + + # Mock prompt returning project id + mock_prompt_id.return_value = "prompted-id" + mock_wait_verification.return_value = {"status": "success"} + + _ = check_project(ctx, session, console, unverified_project, + git_origin=None) + + # Assert prompt was called with parent dir name + mock_prompt_id.assert_called_once_with(console, "project") + + # Assert core data is correct with prompted id + expected_data = { + "scan_stage": Stage.production, + "safety_source": "cli", + "project_slug": "prompted-id", + "project_slug_source": "user" + } + + # Assert verification called with prompted data + mock_wait_verification.assert_called_once_with( + console, + "prompted-id", + (session.check_project, expected_data), + on_error_delay=1 + ) + + @patch('safety.init.main.save_project_info') + def test_save_verified_project(self, mock_save_project_info): + ctx = Mock() + ctx.obj = Mock() + + values = { + 'id': "test-project", + 'name': "Test Project", + 'project_path': Path("/path/to/project"), + 'url_path': "/test/url" + } + + save_verified_project(ctx, slug=values['id'], + **{k: v for k, v in values.items() if k != 'id'}) + + # Assert project is correct type and values + self.assertIsInstance(ctx.obj.project, ProjectModel) + for attr, expected in values.items(): + self.assertEqual(getattr(ctx.obj.project, attr), expected) + + mock_save_project_info.assert_called_once_with( + project=ctx.obj.project, + project_path=values['project_path'] + ) + + + @patch('configparser.ConfigParser') + def test_save_project_info_success(self, mock_config_cls): + # case_name, project, expected_config + test_cases = [ + ( + "full_project", + ProjectModel( + id="test-id", + url_path="http://example.com", + name="Test Project" + ), + { + PROJECT_CONFIG_ID: "test-id", + PROJECT_CONFIG_URL: "http://example.com", + PROJECT_CONFIG_NAME: "Test Project" + } + ), + ( + "id_only", + ProjectModel(id="test-id"), + {PROJECT_CONFIG_ID: "test-id"} + ), + ( + "with_url", + ProjectModel( + id="test-id", + url_path="http://example.com" + ), + { + PROJECT_CONFIG_ID: "test-id", + PROJECT_CONFIG_URL: "http://example.com" + } + ), + ( + "with_name", + ProjectModel( + id="test-id", + name="Test Project" + ), + { + PROJECT_CONFIG_ID: "test-id", + PROJECT_CONFIG_NAME: "Test Project" + } + ), + ] + + mock_config = mock_config_cls.return_value + section_mock = mock_config.__getitem__.return_value + + for case_name, 
project, expected_config in test_cases: + with self.subTest(case_name=case_name): + mock_config.reset_mock() + + result = save_project_info(project, "test_config.ini") + + # Assert the result + self.assertTrue(result) + mock_config.__getitem__.assert_called_with(PROJECT_CONFIG_SECTION) + + calls = [ + call(key, value) for key, value in expected_config.items()] + section_mock.__setitem__.assert_has_calls(calls, + any_order=False) + + + @patch('configparser.ConfigParser') + def test_save_project_info_file_error(self, mock_config_cls): + project = ProjectModel(id="test-id") + + mock_config = mock_config_cls.return_value + mock_config.write.side_effect = Exception("Write error") + + result = save_project_info(project, "test_config.ini") + + self.assertFalse(result) + + + @patch('safety.init.main.load_unverified_project_from_config') + @patch('safety.init.main.GIT') + @patch('safety.init.main.verify_project') + def test_create_project_with_platform_enabled(self, mock_verify, mock_git, \ + mock_load_project): + ctx = Mock() + ctx.obj.platform_enabled = True + console = Mock() + target = Path("/some/path") + + mock_git_instance = mock_git.return_value + mock_git_instance.build_git_data.return_value = Mock(origin="test-origin") + mock_project = mock_load_project.return_value + + create_project(ctx, console, target) + + # Make sure project loads unverified, git data is built and + # project verification is called + mock_load_project.assert_called_once_with(project_root=target) + mock_git_instance.build_git_data.assert_called_once() + mock_verify.assert_called_once_with( + console, + ctx, + ctx.obj.auth.client, + mock_project, + ctx.obj.auth.stage, + "test-origin" + ) + + + @patch('safety.init.main.load_unverified_project_from_config') + @patch('safety.init.main.GIT') + @patch('safety.init.main.verify_project') + def test_create_project_platform_disabled(self, mock_verify, mock_git, mock_load_project): + ctx = Mock() + ctx.obj.platform_enabled = False + console = Mock() + target = Path("/some/path") + + create_project(ctx, console, target) + + mock_verify.assert_not_called() + console.print.assert_called_once_with( + "Project creation is not supported for your account." 
+ ) diff --git a/tests/scan/test_command.py b/tests/scan/test_command.py index 36cef9f7..714a8aeb 100644 --- a/tests/scan/test_command.py +++ b/tests/scan/test_command.py @@ -17,14 +17,17 @@ def setUp(self): # is initialized in the CLI console.quiet = False + cli.commands = cli.all_commands + self.cli = cli + @patch.object(Auth, 'is_valid', return_value=False) @patch('safety.auth.utils.SafetyAuthSession.get_authentication_type', return_value="unauthenticated") def test_scan(self, mock_is_valid, mock_get_auth_type): - result = self.runner.invoke(cli, ["scan", "--target", self.target, "--output", "json"]) + result = self.runner.invoke(self.cli, ["scan", "--target", self.target, "--output", "json"]) self.assertEqual(result.exit_code, 1) - result = self.runner.invoke(cli, ["--stage", "production", "scan", "--target", self.target, "--output", "json"]) + result = self.runner.invoke(self.cli, ["--stage", "production", "scan", "--target", self.target, "--output", "json"]) self.assertEqual(result.exit_code, 1) - result = self.runner.invoke(cli, ["--stage", "cicd", "scan", "--target", self.target, "--output", "screen"]) + result = self.runner.invoke(self.cli, ["--stage", "cicd", "scan", "--target", self.target, "--output", "screen"]) self.assertEqual(result.exit_code, 1) diff --git a/tests/scan/test_file_handlers.py b/tests/scan/test_file_handlers.py index 43b59be6..d259e76d 100644 --- a/tests/scan/test_file_handlers.py +++ b/tests/scan/test_file_handlers.py @@ -1,5 +1,5 @@ import os -import pytest + from unittest.mock import Mock, patch from safety.scan.finder.handlers import PythonFileHandler diff --git a/tests/scan/test_render.py b/tests/scan/test_render.py index 4c5296fc..3f6213d8 100644 --- a/tests/scan/test_render.py +++ b/tests/scan/test_render.py @@ -4,7 +4,7 @@ from pathlib import Path import datetime -from safety.scan.render import print_announcements, print_summary, render_header +from safety.scan.render import print_announcements, print_summary, render_header, prompt_project_id from safety_schemas.models import ProjectModel, IgnoreCodes, PolicySource class TestRender(unittest.TestCase): @@ -15,7 +15,7 @@ def setUp(self): self.project.policy = MagicMock() self.project.policy.source = PolicySource.cloud - @patch('safety.scan.render.get_safety_version') + @patch('safety.scan.render.get_version') def test_render_header(self, mock_get_safety_version): mock_get_safety_version.return_value = '3.0.0' @@ -115,3 +115,101 @@ def test_print_summary(self, mock_render_to_console): call('0 security issues found, 0 fixes suggested.'), call('[number]0[/number] fixes suggested, resolving [number]0[/number] vulnerabilities.') ]) + + @patch("safety.scan.render.clean_project_id") + def test_prompt_project_id_non_interactive(self, clean_project_id): + """ + Under these cases, the default project ID should be cleaned and + returned. The prompt should not be shown. 
+ """ + + test_cases = [ + # Non-interactive mode + (True, False, "default_a", "default_a_cleaned"), + # Quiet mode like JSON output under interactive mode + (True, True, "default_b", "default_b_cleaned"), + # No Quiet and Not interactive mode + (False, False, "default_c", "default_c_cleaned"), + ] + + for quiet, is_interactive, default_id, expected_result in test_cases: + with self.subTest(quiet=quiet, is_interactive=is_interactive): + + clean_project_id.return_value = f"{default_id}_cleaned" + console = MagicMock(quiet=quiet, is_interactive=is_interactive) + + result = prompt_project_id(console, default_id) + + assert result == expected_result + + assert result == expected_result, ( + f"Failed for quiet={quiet}, " + f"is_interactive={is_interactive}\n" + f"Expected: {expected_result}\n" + f"Got: {result}\n" + f"Default ID was: {default_id}" + ) + + try: + clean_project_id.assert_called_once_with(default_id) + except AssertionError: + raise AssertionError( + f"Mock wasn't called correctly for " + f"quiet={quiet}, is_interactive={is_interactive}\n" + f"Expected (1) call with: {default_id}\n" + f"Actual calls were " + f"({len(clean_project_id.call_args_list)}): " + f"{clean_project_id.call_args_list}" + ) + + clean_project_id.reset_mock() + + + @patch("safety.scan.render.clean_project_id") + def test_prompt_project_id_interactive(self, clean_project_id): + default_id = "default-project" + default_id_cleaned = f"{default_id}_cleaned" + + test_cases = [("custom-project", "custom-project_cleaned"), + ("", default_id_cleaned)] + + for user_input, expected in test_cases: + with self.subTest(user_input=user_input): + console = MagicMock(quiet=False, is_interactive=True) + + clean_project_id.side_effect = lambda input_string: ( + f"{input_string}_cleaned" + ) + + with patch('safety.scan.render.Prompt.ask') as ask: + + # We mimic the behavior of Prompt.ask on empty input + ask.side_effect = lambda *args, **kwargs: ( + kwargs['default'] if user_input == "" else user_input + ) + + result = prompt_project_id(console, default_id) + + # Verify Prompt.ask was called correctly + ask.assert_called_once_with( + f"Set a project id (no spaces). 
If empty Safety will use [bold]{default_id_cleaned}[/bold]", + console=console, + default=default_id_cleaned, + show_default=False + ) + + calls = [call(default_id)] + call_count = 1 + + if user_input != "": + calls.append(call(user_input)) + call_count = 2 + + clean_project_id.assert_has_calls(calls) + assert clean_project_id.call_count == call_count + + print(result, expected) + assert result == expected + + # Reset mocks for next test case + clean_project_id.reset_mock() diff --git a/tests/test-safety-project.ini b/tests/test-safety-project.ini new file mode 100644 index 00000000..0e47071b --- /dev/null +++ b/tests/test-safety-project.ini @@ -0,0 +1,4 @@ +[project] +id = safety +url = /projects/e008f386-0a5e-4967-b8b9-079239d5f93c/findings +name = safety diff --git a/tests/test_cli.py b/tests/test_cli.py index 1114afb7..cfbd68aa 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,27 +1,27 @@ import json import logging import os -import sys import shutil +import socket +import sys import tempfile import unittest from pathlib import Path from unittest.mock import Mock, patch -import requests -import socket -import psutil import click +import psutil +import requests from click.testing import CliRunner -from packaging.specifiers import SpecifierSet from packaging.version import Version +from safety_schemas.models.base import AuthenticationType -from safety import cli -from safety.console import main_console as console -from safety.models import CVE, SafetyRequirement, Severity, Vulnerability -from safety.util import Package, SafetyContext, get_safety_version from safety.auth.models import Auth -from safety_schemas.models.base import AuthenticationType +from safety.cli import cli +from safety.console import main_console as console +from safety.meta import get_version +from safety.models import CVE, SafetyCLI, SafetyRequirement, Severity, Vulnerability +from safety.util import Package, SafetyContext def get_vulnerability(vuln_kwargs=None, cve_kwargs=None, pkg_kwargs=None): @@ -81,13 +81,23 @@ def setUp(self): # is initialized in the CLI console.quiet = False + # Reset the commands + # CLI initialization is made on the import of the module + # so we need to reset the commands to avoid side effects + # + # TODO: this is a workaround, we should improve the way the + # CLI is initialized + cli.commands = cli.all_commands + self.cli = cli + + def test_command_line_interface(self): runner = CliRunner() - result = runner.invoke(cli.cli) + result = runner.invoke(self.cli) expected = "Usage: cli [OPTIONS] COMMAND [ARGS]..." 
for option in [[], ["--help"]]: - result = runner.invoke(cli.cli, option) + result = runner.invoke(self.cli, option) self.assertEqual(result.exit_code, 0) self.assertIn(expected, click.unstyle(result.output)) @@ -95,14 +105,14 @@ def test_command_line_interface(self): def test_check_vulnerabilities_found_default(self, check_func): check_func.return_value = [get_vulnerability()], None EXPECTED_EXIT_CODE_VULNS_FOUND = 64 - result = self.runner.invoke(cli.cli, ['check']) + result = self.runner.invoke(self.cli, ['check']) self.assertEqual(result.exit_code, EXPECTED_EXIT_CODE_VULNS_FOUND) @patch("safety.safety.check") def test_check_vulnerabilities_not_found_default(self, check_func): check_func.return_value = [], None EXPECTED_EXIT_CODE_VULNS_NOT_FOUND = 0 - result = self.runner.invoke(cli.cli, ['check']) + result = self.runner.invoke(self.cli, ['check']) self.assertEqual(result.exit_code, EXPECTED_EXIT_CODE_VULNS_NOT_FOUND) @patch("safety.safety.check") @@ -111,7 +121,7 @@ def test_check_vulnerabilities_found_with_outputs(self, check_func): EXPECTED_EXIT_CODE_VULNS_FOUND = 64 for output in self.output_options: - result = self.runner.invoke(cli.cli, ['check', '--output', output]) + result = self.runner.invoke(self.cli, ['check', '--output', output]) self.assertEqual(result.exit_code, EXPECTED_EXIT_CODE_VULNS_FOUND) @patch("safety.safety.check") @@ -120,7 +130,7 @@ def test_check_vulnerabilities_not_found_with_outputs(self, check_func): EXPECTED_EXIT_CODE_VULNS_NOT_FOUND = 0 for output in self.output_options: - result = self.runner.invoke(cli.cli, ['check', '--output', output]) + result = self.runner.invoke(self.cli, ['check', '--output', output]) self.assertEqual(result.exit_code, EXPECTED_EXIT_CODE_VULNS_NOT_FOUND) @patch("safety.safety.check") @@ -131,11 +141,11 @@ def test_check_continue_on_error(self, check_func): for vulns in [[get_vulnerability()], []]: check_func.return_value = vulns, None - result = self.runner.invoke(cli.cli, ['check', '--continue-on-error']) + result = self.runner.invoke(self.cli, ['check', '--continue-on-error']) self.assertEqual(result.exit_code, EXPECTED_EXIT_CODE_CONTINUE_ON_ERROR) for output in self.output_options: - result = self.runner.invoke(cli.cli, ['check', '--output', output, '--continue-on-error']) + result = self.runner.invoke(self.cli, ['check', '--output', output, '--continue-on-error']) self.assertEqual(result.exit_code, EXPECTED_EXIT_CODE_CONTINUE_ON_ERROR) @patch("safety.safety.get_announcements") @@ -143,7 +153,7 @@ def test_announcements_if_is_not_tty(self, get_announcements_func): announcement = {'type': 'error', 'message': 'Please upgrade now'} get_announcements_func.return_value = [announcement] message = f"* {announcement.get('message')}" - result = self.runner.invoke(cli.cli, ['check']) + result = self.runner.invoke(self.cli, ['check']) self.assertTrue('ANNOUNCEMENTS' in result.stderr) self.assertTrue(message in result.stderr) @@ -157,7 +167,7 @@ def test_check_ignore_format_backward_compatible(self, check): dirname = os.path.dirname(__file__) reqs_path = os.path.join(dirname, "reqs_4.txt") - _ = runner.invoke(cli.cli, ['check', '--file', reqs_path, '--ignore', "123,456", '--ignore', "789"]) + _ = runner.invoke(self.cli, ['check', '--file', reqs_path, '--ignore', "123,456", '--ignore', "789"]) try: check_call_kwargs = check.call_args[1] # Python < 3.8 except IndexError: @@ -171,14 +181,14 @@ def test_check_ignore_format_backward_compatible(self, check): self.assertEqual(check_call_kwargs['ignore_vulns'], ignored_transformed) def 
test_validate_with_unsupported_argument(self): - result = self.runner.invoke(cli.cli, ['validate', 'safety_ci']) + result = self.runner.invoke(self.cli, ['validate', 'safety_ci']) msg = 'This Safety version only supports "policy_file" validation. "safety_ci" is not supported.\n' self.assertEqual(click.unstyle(result.stderr), msg) self.assertEqual(result.exit_code, 1) def test_validate_with_wrong_path(self): p = Path('imaginary/path') - result = self.runner.invoke(cli.cli, ['validate', 'policy_file', '--path', str(p)]) + result = self.runner.invoke(self.cli, ['validate', 'policy_file', '--path', str(p)]) msg = f'The path "{str(p)}" does not exist.\n' self.assertEqual(click.unstyle(result.stderr), msg) self.assertEqual(result.exit_code, 1) @@ -188,7 +198,7 @@ def test_validate_with_basic_policy_file(self): # Test with policy version 2.0 path = os.path.join(dirname, "test_policy_file", "default_policy_file.yml") - result = self.runner.invoke(cli.cli, ['validate', 'policy_file', '2.0', '--path', path]) + result = self.runner.invoke(self.cli, ['validate', 'policy_file', '2.0', '--path', path]) cleaned_stdout = click.unstyle(result.stdout) msg = 'The Safety policy file (Valid only for the check command) was successfully parsed with the following values:\n' parsed = json.dumps( @@ -215,7 +225,7 @@ def test_validate_with_basic_policy_file(self): # Test with policy version 3.0 path = os.path.join(dirname, "test_policy_file", "v3_0", "default_policy_file.yml") - result = self.runner.invoke(cli.cli, ['validate', 'policy_file', '3.0', '--path', path]) + result = self.runner.invoke(self.cli, ['validate', 'policy_file', '3.0', '--path', path]) cleaned_stdout = click.unstyle(result.stdout) msg = 'The Safety policy (3.0) file (Used for scan and system-scan commands) was successfully parsed with the following values:\n' @@ -292,13 +302,11 @@ def test_validate_with_basic_policy_file(self): self.assertEqual(result.exit_code, 0) - - def test_validate_with_policy_file_using_invalid_keyword(self): dirname = os.path.dirname(__file__) filename = 'default_policy_file_using_invalid_keyword.yml' path = os.path.join(dirname, "test_policy_file", filename) - result = self.runner.invoke(cli.cli, ['validate', 'policy_file', '2.0', '--path', path]) + result = self.runner.invoke(self.cli, ['validate', 'policy_file', '2.0', '--path', path]) cleaned_stdout = click.unstyle(result.stderr) msg_hint = 'HINT: "security" -> "transitive" is not a valid keyword. 
Valid keywords in this level are: ' \ 'ignore-cvss-severity-below, ignore-cvss-unknown-severity, ignore-vulnerabilities, ' \ @@ -309,7 +317,7 @@ def test_validate_with_policy_file_using_invalid_keyword(self): self.assertEqual(result.exit_code, 1) path = os.path.join(dirname, "test_policy_file", "v3_0", filename) - result = self.runner.invoke(cli.cli, ['validate', 'policy_file', '3.0', '--path', path]) + result = self.runner.invoke(self.cli, ['validate', 'policy_file', '3.0', '--path', path]) cleaned_stdout = click.unstyle(result.stderr) msg = f'Unable to load the Safety Policy file ("{path}"), this command only supports version 3.0, details: 1 validation error for Config' @@ -321,7 +329,7 @@ def test_validate_with_policy_file_using_invalid_typo_keyword(self): dirname = os.path.dirname(__file__) filename = 'default_policy_file_using_invalid_typo_keyword.yml' path = os.path.join(dirname, "test_policy_file", filename) - result = self.runner.invoke(cli.cli, ['validate', 'policy_file', '2.0', '--path', path]) + result = self.runner.invoke(self.cli, ['validate', 'policy_file', '2.0', '--path', path]) cleaned_stdout = click.unstyle(result.stderr) msg_hint = 'HINT: "security" -> "ignore-vunerabilities" is not a valid keyword. Maybe you meant: ' \ 'ignore-vulnerabilities\n' @@ -332,21 +340,21 @@ def test_validate_with_policy_file_using_invalid_typo_keyword(self): def test_generate_pass(self): with tempfile.TemporaryDirectory() as tempdir: - result = self.runner.invoke(cli.cli, ['generate', 'policy_file', '--path', tempdir]) + result = self.runner.invoke(self.cli, ['generate', 'policy_file', '--path', tempdir]) cleaned_stdout = click.unstyle(result.stdout) msg = f'A default Safety policy file has been generated! Review the file contents in the path {tempdir} ' \ f'in the file: .safety-policy.yml\n' self.assertEqual(msg, cleaned_stdout) def test_generate_with_unsupported_argument(self): - result = self.runner.invoke(cli.cli, ['generate', 'safety_ci']) + result = self.runner.invoke(self.cli, ['generate', 'safety_ci']) msg = 'This Safety version only supports "policy_file" generation. "safety_ci" is not supported.\n' self.assertEqual(click.unstyle(result.stderr), msg) self.assertEqual(result.exit_code, 1) def test_generate_with_wrong_path(self): p = Path('imaginary/path') - result = self.runner.invoke(cli.cli, ['generate', 'policy_file', '--path', str(p)]) + result = self.runner.invoke(self.cli, ['generate', 'policy_file', '--path', str(p)]) msg = f'The path "{str(p)}" does not exist.\n' self.assertEqual(click.unstyle(result.stderr), msg) self.assertEqual(result.exit_code, 1) @@ -354,13 +362,13 @@ def test_generate_with_wrong_path(self): def test_check_with_fix_does_verify_api_key(self): dirname = os.path.dirname(__file__) req_file = os.path.join(dirname, "test_fix", "basic", "reqs_simple.txt") - result = self.runner.invoke(cli.cli, ['check', '-r', req_file, '--apply-security-updates']) + result = self.runner.invoke(self.cli, ['check', '-r', req_file, '--apply-security-updates']) self.assertEqual(click.unstyle(result.stderr), "The --apply-security-updates option needs authentication. 
See https://docs.safetycli.com/safety-docs/support/invalid-api-key-error.\n") self.assertEqual(result.exit_code, 65) def test_check_with_fix_only_works_with_files(self): - result = self.runner.invoke(cli.cli, ['check', '--key', 'TEST-API_KEY', '--apply-security-updates']) + result = self.runner.invoke(self.cli, ['check', '--key', 'TEST-API_KEY', '--apply-security-updates']) self.assertEqual(click.unstyle(result.stderr), '--apply-security-updates only works with files; use the "-r" option to specify files to remediate.\n') self.assertEqual(result.exit_code, 1) @@ -398,13 +406,13 @@ def test_check_with_fix(self, get_packages, calculate_remediations, check_func, req_file = os.path.join(tempdir, 'reqs_simple_minor.txt') shutil.copy(source_req, req_file) - self.runner.invoke(cli.cli, ['check', '-r', req_file, '--key', 'TEST-API_KEY', + self.runner.invoke(self.cli, ['check', '-r', req_file, '--key', 'TEST-API_KEY', '--apply-security-updates']) with open(req_file) as f: self.assertEqual("django==1.8\nsafety==2.3.0\nflask==0.87.0", f.read()) - self.runner.invoke(cli.cli, ['check', '-r', req_file, '--key', 'TEST-API_KEY', '--apply-security-updates', + self.runner.invoke(self.cli, ['check', '-r', req_file, '--key', 'TEST-API_KEY', '--apply-security-updates', '--auto-security-updates-limit', 'minor']) with open(req_file) as f: @@ -423,12 +431,12 @@ def test_check_with_fix(self, get_packages, calculate_remediations, check_func, "more_info_url": "https://pyup.io/p/pypi/django/52d/"}} } - self.runner.invoke(cli.cli, ['check', '-r', req_file, '--key', 'TEST-API_KEY', '--apply-security-updates', + self.runner.invoke(self.cli, ['check', '-r', req_file, '--key', 'TEST-API_KEY', '--apply-security-updates', '-asul', 'minor', '--json']) with open(req_file) as f: self.assertEqual("django==1.9\nsafety==2.3.0\nflask==0.87.0", f.read()) - self.runner.invoke(cli.cli, ['check', '-r', req_file, '--key', 'TEST-API_KEY', '--apply-security-updates', + self.runner.invoke(self.cli, ['check', '-r', req_file, '--key', 'TEST-API_KEY', '--apply-security-updates', '-asul', 'major', '--output', 'bare']) with open(req_file) as f: @@ -440,7 +448,7 @@ def test_check_ignore_unpinned_requirements(self): reqs_unpinned = os.path.join(dirname, "reqs_unpinned.txt") # Test default behavior (ignore_unpinned_requirements is None) - result = self.runner.invoke(cli.cli, ['check', '-r', reqs_unpinned, '--db', db, '--output', 'text']) + result = self.runner.invoke(self.cli, ['check', '-r', reqs_unpinned, '--db', db, '--output', 'text']) # Check for deprecation message self.assertIn("DEPRECATED: this command (`check`) has been DEPRECATED", result.output) @@ -454,14 +462,14 @@ def test_check_ignore_unpinned_requirements(self): self.assertIn(expected_warning, result.output) # Test ignore_unpinned_requirements set to True - result = self.runner.invoke(cli.cli, ['check', '-r', reqs_unpinned, '--ignore-unpinned-requirements', + result = self.runner.invoke(self.cli, ['check', '-r', reqs_unpinned, '--ignore-unpinned-requirements', '--db', db, '--output', 'text']) self.assertIn("Warning: django and numpy are unpinned and potential vulnerabilities are", result.output) self.assertIn("being ignored given `ignore-unpinned-requirements` is True in your config.", result.output) # Test check_unpinned_requirements set to True - result = self.runner.invoke(cli.cli, ['check', '-r', reqs_unpinned, '--db', db, '--json', '-i', 'some id', + result = self.runner.invoke(self.cli, ['check', '-r', reqs_unpinned, '--db', db, '--json', '-i', 'some id', 
'--check-unpinned-requirements']) # Check for deprecation message @@ -490,7 +498,7 @@ def test_basic_html_output_pass(self): db = os.path.join(dirname, "test_db") reqs_unpinned = os.path.join(dirname, "reqs_unpinned.txt") - result = self.runner.invoke(cli.cli, ['check', '-r', reqs_unpinned, '--db', db, '--output', 'html']) + result = self.runner.invoke(self.cli, ['check', '-r', reqs_unpinned, '--db', db, '--output', 'html']) ignored = "
Found vulnerabilities that were ignored: 2
" announcement = "Warning: django and numpy are unpinned." @@ -500,7 +508,7 @@ def test_basic_html_output_pass(self): reqs_affected = os.path.join(dirname, "reqs_pinned_affected.txt") - result = self.runner.invoke(cli.cli, ['check', '-r', reqs_affected, '--db', db, '--output', 'html']) + result = self.runner.invoke(self.cli, ['check', '-r', reqs_affected, '--db', db, '--output', 'html']) self.assertIn("remediations-suggested", result.stdout) self.assertIn("Use API Key", result.stdout) @@ -526,7 +534,7 @@ def test_license_with_file(self, fetch_database_url): dirname = os.path.dirname(__file__) test_filename = os.path.join(dirname, "reqs_4.txt") - result = self.runner.invoke(cli.cli, ['license', '--key', 'foo', '--file', test_filename]) + result = self.runner.invoke(self.cli, ['license', '--key', 'foo', '--file', test_filename]) print(result.stdout) self.assertEqual(result.exit_code, 0) @@ -534,7 +542,9 @@ def test_license_with_file(self, fetch_database_url): @patch.object(Auth, 'is_valid', return_value=True) @patch('safety.auth.utils.SafetyAuthSession.get_authentication_type', return_value=AuthenticationType.TOKEN) @patch('safety.safety.fetch_database', return_value={'vulnerable_packages': []}) - def test_debug_flag(self, mock_get_auth_info, mock_is_valid, mock_get_auth_type, mock_fetch_database): + @patch('safety.auth.utils.initialize', return_value=None) + @patch('safety.auth.cli_utils.SafetyCLI', return_value=SafetyCLI(platform_enabled=False, firewall_enabled=False)) + def test_debug_flag(self, *args): """ Test the behavior of the CLI when invoked with the '--debug' flag. @@ -548,25 +558,26 @@ def test_debug_flag(self, mock_get_auth_info, mock_is_valid, mock_get_auth_type, mock_get_auth_type: Mock for retrieving the authentication type. mock_fetch_database: Mock for database fetching operations. 
""" - result = self.runner.invoke(cli.cli, ['--debug', 'scan']) + result = self.runner.invoke(self.cli, ['--debug', 'scan']) assert result.exit_code == 0, ( f"CLI exited with code {result.exit_code} and output: {result.output} and error: {result.stderr}" ) - expected_output_snippet = f"{get_safety_version()} scanning" + expected_output_snippet = f"{get_version()} scanning" assert expected_output_snippet in result.output, ( f"Expected output to contain: {expected_output_snippet}, but got: {result.output}" ) - @patch('safety.auth.cli.get_auth_info', return_value={'email': 'test@test.com'}) @patch.object(Auth, 'is_valid', return_value=True) @patch('safety.auth.utils.SafetyAuthSession.get_authentication_type', return_value=AuthenticationType.TOKEN) @patch('safety.safety.fetch_database', return_value={'vulnerable_packages': []}) - def test_debug_flag_with_value_1(self, mock_get_auth_info, mock_is_valid, mock_get_auth_type, mock_fetch_database): + def test_debug_flag_with_value_1(self, *args): sys.argv = ['safety', '--debug', '1', 'scan'] - @cli.preprocess_args + from safety.cli import preprocess_args + + @preprocess_args def dummy_function(): pass @@ -580,10 +591,12 @@ def dummy_function(): @patch.object(Auth, 'is_valid', return_value=True) @patch('safety.auth.utils.SafetyAuthSession.get_authentication_type', return_value=AuthenticationType.TOKEN) @patch('safety.safety.fetch_database', return_value={'vulnerable_packages': []}) - def test_debug_flag_with_value_true(self, mock_get_auth_info, mock_is_valid, mock_get_auth_type, mock_fetch_database): + def test_debug_flag_with_value_true(self, *args): sys.argv = ['safety', '--debug', 'true', 'scan'] - @cli.preprocess_args + from safety.cli import preprocess_args + + @preprocess_args def dummy_function(): pass @@ -593,6 +606,53 @@ def dummy_function(): # Assert the preprocessed arguments assert preprocessed_args == ['--debug', 'scan'], f"Preprocessed args: {preprocessed_args}" + @patch('safety.auth.utils.get_config_setting', return_value=None) + @patch('safety.auth.cli.get_auth_info', return_value={'email': 'test@test.com'}) + @patch.object(Auth, 'is_valid', return_value=True) + @patch('safety.auth.utils.SafetyAuthSession.get_authentication_type', return_value=AuthenticationType.TOKEN) + @patch('safety.auth.utils.SafetyAuthSession.check_project', return_value={'user_confirm': True}) + @patch('safety.auth.utils.SafetyAuthSession.project', return_value={'slug': 'slug'}) + @patch('safety.auth.cli_utils.SafetyCLI', return_value=SafetyCLI(platform_enabled=True, firewall_enabled=True)) + def test_init_project(self, *args): + # Workarounds + from safety.console import main_console as test_console + + test_cases = [ + (False, + "Configured PIP global settings\nConfigured PIP alias\n", 0), + (True, + "Set a project id (no spaces). 
If empty Safety will use", 1)] + + for interactive, output, exit_code in test_cases: + with self.subTest(output=output, exit_code=exit_code): + with patch('safety.cli.console', + new=test_console) as t_console, \ + tempfile.TemporaryDirectory() as tempdir: + + t_console.is_interactive = interactive + + result = self.runner.invoke(self.cli, ['init', tempdir]) + cleaned_stdout = click.unstyle(result.stdout) + assert result.exit_code == exit_code, f"CLI exited with non-zero exit code {result.exit_code}" + assert cleaned_stdout.startswith(output), f"CLI exited with output: {result.output}" + + + @patch('safety.auth.cli.get_auth_info', return_value={'email': 'test@test.com'}) + @patch.object(Auth, 'is_valid', return_value=True) + @patch('safety.auth.utils.SafetyAuthSession.get_authentication_type', return_value=AuthenticationType.TOKEN) + @patch('safety.auth.utils.SafetyAuthSession.check_project', return_value={'user_confirm': True}) + @patch('safety.auth.utils.SafetyAuthSession.project', return_value={'slug': 'slug'}) + @patch('safety.auth.cli_utils.SafetyCLI', return_value=SafetyCLI(platform_enabled=True, firewall_enabled=True)) + def test_existing_project_is_linked(self, *args): + dirname = os.path.dirname(__file__) + source_project_file = os.path.join(dirname, "test-safety-project.ini") + + with tempfile.TemporaryDirectory() as tempdir: + project_file = os.path.join(tempdir, '.safety-project.ini') + shutil.copy(source_project_file, project_file) + result = self.runner.invoke(self.cli, ['init', tempdir]) + assert result.exit_code == 0, f"CLI exited with code {result.exit_code} and output: {result.output} and error: {result.stderr}" + class TestNetworkTelemetry(unittest.TestCase): @patch('psutil.net_io_counters') @@ -615,7 +675,8 @@ def test_get_network_telemetry(self, mock_requests_get, mock_net_if_stats, mock_ mock_requests_get.return_value = mock_response # Run the function - result = cli.get_network_telemetry() + from safety.cli import get_network_telemetry + result = get_network_telemetry() # Assert the network telemetry data self.assertEqual(result['bytes_sent'], 1000) @@ -635,7 +696,8 @@ def test_get_network_telemetry(self, mock_requests_get, mock_net_if_stats, mock_ @patch('requests.get', side_effect=requests.RequestException('Network error')) def test_get_network_telemetry_request_exception(self, mock_requests_get): # Run the function - result = cli.get_network_telemetry() + from safety.cli import get_network_telemetry + result = get_network_telemetry() # Assert the download_speed is None and error is captured self.assertIsNone(result['download_speed']) @@ -644,7 +706,8 @@ def test_get_network_telemetry_request_exception(self, mock_requests_get): @patch('psutil.net_io_counters', side_effect=psutil.AccessDenied('Access denied')) def test_get_network_telemetry_access_denied(self, mock_net_io_counters): # Run the function - result = cli.get_network_telemetry() + from safety.cli import get_network_telemetry + result = get_network_telemetry() # Assert the error is captured self.assertIn('error', result) @@ -658,6 +721,8 @@ def test_configure_logger_debug(self, mock_get_network_telemetry, mock_config_re mock_get_network_telemetry.return_value = {'dummy_key': 'dummy_value'} mock_config_read.return_value = None + from safety.cli import configure_logger + ctx = Mock() param = Mock() debug = True @@ -666,7 +731,7 @@ def test_configure_logger_debug(self, mock_get_network_telemetry, mock_config_re patch('logging.basicConfig') as mock_basicConfig, \ patch('configparser.ConfigParser.items', 
return_value=[('key', 'value')]), \ patch('configparser.ConfigParser.sections', return_value=['section']): - cli.configure_logger(ctx, param, debug) + configure_logger(ctx, param, debug) mock_basicConfig.assert_called_with(format='%(asctime)s %(name)s => %(message)s', level=logging.DEBUG) # Check if network telemetry logging was called @@ -680,6 +745,8 @@ def test_configure_logger_non_debug(self, mock_config_read): param = Mock() debug = False + from safety.cli import configure_logger + with patch('logging.basicConfig') as mock_basicConfig: - cli.configure_logger(ctx, param, debug) + configure_logger(ctx, param, debug) mock_basicConfig.assert_called_with(format='%(asctime)s %(name)s => %(message)s', level=logging.CRITICAL) diff --git a/tests/tool/interceptors/test_factory.py b/tests/tool/interceptors/test_factory.py new file mode 100644 index 00000000..48b7fc3f --- /dev/null +++ b/tests/tool/interceptors/test_factory.py @@ -0,0 +1,44 @@ +import unittest +from unittest.mock import patch + +from safety.tool.interceptors.types import InterceptorType +from safety.tool.interceptors.unix import UnixAliasInterceptor +from safety.tool.interceptors.windows import WindowsInterceptor +from safety.tool.interceptors.factory import create_interceptor + + +class TestFactory(unittest.TestCase): + def test_explicit_unix_alias_interceptor(self): + interceptor = create_interceptor(InterceptorType.UNIX_ALIAS) + self.assertIsInstance(interceptor, UnixAliasInterceptor) + + def test_explicit_windows_interceptor(self): + interceptor = create_interceptor(InterceptorType.WINDOWS_BAT) + self.assertIsInstance(interceptor, WindowsInterceptor) + + @patch('safety.tool.interceptors.factory.platform', 'win32') + def test_auto_select_windows(self): + interceptor = create_interceptor() + self.assertIsInstance(interceptor, WindowsInterceptor) + + def test_auto_select_unix_like(self): + unix_platforms = ['linux', 'linux2', 'darwin'] + + for platform in unix_platforms: + with self.subTest(platform=platform): + with patch('safety.tool.interceptors.factory.platform', + platform): + interceptor = create_interceptor() + self.assertIsInstance(interceptor, UnixAliasInterceptor) + + @patch('safety.tool.interceptors.factory.platform', 'unsupported_os') + def test_unsupported_platform(self): + with self.assertRaises(NotImplementedError) as context: + create_interceptor() + self.assertIn("Platform 'unsupported_os' is not supported", + str(context.exception)) + + def test_invalid_interceptor_type(self): + invalid_type = "INVALID_TYPE" + with self.assertRaises(KeyError): + create_interceptor(invalid_type) diff --git a/tests/tool/interceptors/test_unix.py b/tests/tool/interceptors/test_unix.py new file mode 100644 index 00000000..20b6c2ee --- /dev/null +++ b/tests/tool/interceptors/test_unix.py @@ -0,0 +1,90 @@ +from sys import platform +import unittest +from unittest.mock import patch, mock_open +from pathlib import Path +import tempfile +import shutil + +from datetime import datetime, timezone + +import pytest + +from safety.tool.interceptors.unix import UnixAliasInterceptor + + +@pytest.mark.unix_only +@pytest.mark.skipif(platform not in ["linux", "linux2", "darwin"], + reason="Unix-specific tests") +class TestUnixAliasInterceptor(unittest.TestCase): + + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.temp_dir) + + @patch('safety.tool.interceptors.unix.Path.home') + @patch('safety.tool.interceptors.base.datetime') + @patch('safety.tool.interceptors.base.get_version') + def 
test_interceptors_all_tools(self, mock_version, + mock_datetime, + mock_home): + + mock_home.return_value = Path(self.temp_dir) + mock_version.return_value = "1.0.0" + mock_now = datetime(2024, 1, 1, tzinfo=timezone.utc) + mock_datetime.now.return_value = mock_now + + safety_config_user_dir = Path(self.temp_dir) / '.safety' + + with patch('safety.tool.interceptors.unix.USER_CONFIG_DIR', + safety_config_user_dir): + + interceptor = UnixAliasInterceptor() + result = interceptor.install_interceptors() + + self.assertTrue(result) + + profile_path = Path(self.temp_dir) / '.profile' + safety_profile_path = Path(self.temp_dir) / '.safety' / '.safety_profile' + + self.assertTrue(profile_path.exists()) + self.assertTrue(safety_profile_path.exists()) + + # test the content of the generated files + expected_profile_content = ( + "# >>> Safety >>>\n" + f'[ -f "{safety_profile_path}" ] && . "{safety_profile_path}"\n' + "# <<< Safety <<<\n" + ) + + expected_safety_profile_content = ( + "# >>> Safety >>>\n" + "# DO NOT EDIT THIS FILE DIRECTLY\n" + f"# Last updated at: {mock_now.isoformat()}\n" + "# Updated by: safety v1.0.0\n" + 'alias pip="safety pip"\n' + 'alias pip3="safety pip"\n' + "# <<< Safety <<<\n" + ) + + self.assertEqual(profile_path.read_text(), expected_profile_content) + self.assertEqual(safety_profile_path.read_text(), + expected_safety_profile_content) + + # Let's test remove_interceptors + result = interceptor.remove_interceptors() + + self.assertTrue(result) + self.assertTrue(profile_path.exists()) + self.assertEqual(profile_path.read_text(), "") + + self.assertFalse(safety_profile_path.exists()) + + def test_install_interceptors_nonexistent_tool(self): + interceptor = UnixAliasInterceptor() + result = interceptor.install_interceptors(['nonexistent']) + self.assertFalse(result) + + def test_uninstall_interceptors_all_tools(self): + interceptor = UnixAliasInterceptor() + result = interceptor.install_interceptors() + self.assertTrue(result) diff --git a/tests/tool/interceptors/test_windows.py b/tests/tool/interceptors/test_windows.py new file mode 100644 index 00000000..0157e05d --- /dev/null +++ b/tests/tool/interceptors/test_windows.py @@ -0,0 +1,103 @@ +import os +import shutil +from sys import platform +import tempfile +import unittest +import pytest +from unittest.mock import MagicMock, call, patch +from pathlib import Path +from datetime import datetime, timezone + +from safety.tool.interceptors.windows import WindowsInterceptor + + +@pytest.mark.windows_only +@pytest.mark.skipif(platform not in ["win32"], + reason="Windows-specific tests") +class TestWindowsInterceptor(unittest.TestCase): + + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.temp_dir) + + @patch('safety.tool.interceptors.windows.Path.home') + @patch('safety.tool.interceptors.base.datetime') + @patch('safety.tool.interceptors.base.get_version') + @patch('safety.tool.interceptors.windows.winreg') + def test_interceptors_all_tools(self, mock_winreg, mock_version, + mock_datetime, mock_home): + import winreg + + mock_home.return_value = Path(self.temp_dir) + mock_version.return_value = "1.0.0" + mock_now = datetime(2024, 1, 1, tzinfo=timezone.utc) + mock_datetime.now.return_value = mock_now + + # Mock Windows registry + mock_key = MagicMock() + mock_winreg.OpenKey.return_value.__enter__.return_value = mock_key + + original_path = "C:\\existing\\path" + mock_winreg.QueryValueEx.return_value = (original_path, mock_winreg.REG_EXPAND_SZ) + + mock_winreg.HKEY_CURRENT_USER = 
winreg.HKEY_CURRENT_USER + mock_winreg.KEY_ALL_ACCESS = winreg.KEY_ALL_ACCESS + mock_winreg.REG_EXPAND_SZ = winreg.REG_EXPAND_SZ + + # Initialize interceptor + interceptor = WindowsInterceptor() + + expected_new_path = f"{str(interceptor.scripts_dir)}{os.pathsep}{original_path}" + + # Test installation + result = interceptor.install_interceptors() + self.assertTrue(result) + + # Verify bat files were created + for tool in interceptor.tools.values(): + for binary in tool.binary_names: + bat_path = interceptor.scripts_dir / f'{binary}.bat' + self.assertTrue(bat_path.exists(), + f"Bat file for {binary} does not exist") + + expected_bat_content = ( + "@echo off\n" + "REM >>> Safety >>>\n" + "REM DO NOT EDIT THIS FILE DIRECTLY\n" + f"REM Last updated at: {mock_now.isoformat()}\n" + "REM Updated by: safety v1.0.0\n" + f"safety {tool.name} %*\n" + "REM <<< Safety <<<\n" + ) + self.assertEqual(bat_path.read_text(), expected_bat_content) + + # Verify backup was created + backup_path = interceptor.backup_dir / 'path_backup.txt' + self.assertTrue(backup_path.exists()) + expected_backup_content = ( + ">>> Safety >>>\n" + " DO NOT EDIT THIS FILE DIRECTLY\n" + f" Last updated at: {mock_now.isoformat()}\n" + " Updated by: safety v1.0.0\n" + f"{original_path}\n" + "<<< Safety <<<\n" + ) + self.assertEqual(backup_path.read_text(), expected_backup_content) + + mock_winreg.SetValueEx.assert_called_once_with( + mock_key, 'PATH', 0, winreg.REG_EXPAND_SZ, + expected_new_path + ) + + mock_winreg.SetValueEx.reset_mock() + + # Test remove + mock_winreg.QueryValueEx.return_value = (expected_new_path, mock_winreg.REG_EXPAND_SZ) + + result = interceptor.remove_interceptors() + self.assertTrue(result) + self.assertFalse(interceptor.scripts_dir.exists()) + + mock_winreg.SetValueEx.assert_called_once_with( + mock_key, 'PATH', 0, winreg.REG_EXPAND_SZ, original_path + )
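
---

For reviewers who want the intent behind the new `prompt_project_id` tests in tests/scan/test_render.py without opening safety/scan/render.py: the sketch below reconstructs, purely from the assertions above, how the helper is expected to behave. It is an illustrative approximation, not the shipped implementation; the local `clean_project_id` stub is an assumption standing in for the real helper that the tests mock.

```python
# Hypothetical sketch of the behavior encoded by the prompt_project_id tests.
# The real function lives in safety/scan/render.py; clean_project_id below is
# an assumed stand-in for the helper the tests patch.
from rich.console import Console
from rich.prompt import Prompt


def clean_project_id(raw: str) -> str:
    """Assumed helper: normalize a raw value into a slug-like project id."""
    return "-".join(raw.strip().lower().split())


def prompt_project_id(console: Console, default_id: str) -> str:
    # The default is always cleaned exactly once up front.
    default_id = clean_project_id(default_id)

    # Quiet consoles (e.g. JSON output) and non-interactive consoles never
    # prompt: the cleaned default is returned as-is.
    if console.quiet or not console.is_interactive:
        return default_id

    answer = Prompt.ask(
        f"Set a project id (no spaces). If empty Safety will use "
        f"[bold]{default_id}[/bold]",
        console=console,
        default=default_id,
        show_default=False,
    )

    # Empty input falls back to the default via Prompt.ask's `default`, so it
    # does not trigger a second clean_project_id call; explicit input does.
    return clean_project_id(answer) if answer != default_id else default_id
```

Under this reading, an empty answer never produces a second `clean_project_id` call, which is exactly what the call-count assertions in `test_prompt_project_id_interactive` check.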
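Similarly, the `save_project_info` cases in tests/init/test_init_main.py, together with the sample tests/test-safety-project.ini, imply roughly the following write path. This is a hedged reconstruction, not a copy of safety/init/main.py: the constant values and the `open()` call are assumptions inferred from the sample file and the mocked `ConfigParser`.

```python
# Hypothetical reconstruction of save_project_info based on the test
# assertions; the real implementation in safety/init/main.py may differ.
import configparser
from pathlib import Path
from typing import Union

# Assumed values, mirroring the [project] section in test-safety-project.ini.
PROJECT_CONFIG_SECTION = "project"
PROJECT_CONFIG_ID = "id"
PROJECT_CONFIG_URL = "url"
PROJECT_CONFIG_NAME = "name"


def save_project_info(project, project_path: Union[str, Path]) -> bool:
    """Persist a verified ProjectModel to a .safety-project.ini file."""
    config = configparser.ConfigParser()
    config.read(project_path)  # keep any existing settings in the file

    if PROJECT_CONFIG_SECTION not in config:
        config[PROJECT_CONFIG_SECTION] = {}

    section = config[PROJECT_CONFIG_SECTION]
    section[PROJECT_CONFIG_ID] = project.id
    if project.url_path:
        section[PROJECT_CONFIG_URL] = project.url_path
    if project.name:
        section[PROJECT_CONFIG_NAME] = project.name

    try:
        with open(project_path, "w") as f:
            config.write(f)
    except Exception:
        # Swallow write failures and report them via the return value.
        return False
    return True
```

Returning `False` instead of raising is what lets `test_save_project_info_file_error` assert on the boolean result when `ConfigParser.write` blows up.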