Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions cycode/cli/auth/auth_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
invoke_without_command=True, short_help='Authenticates your machine to associate CLI with your cycode account'
)
@click.pass_context
def authenticate(context: click.Context):
def authenticate(context: click.Context) -> None:
if context.invoked_subcommand is not None:
# if it is a subcommand, do nothing
return
Expand All @@ -34,7 +34,7 @@ def authenticate(context: click.Context):

@authenticate.command(name='check')
@click.pass_context
def authorization_check(context: click.Context):
def authorization_check(context: click.Context) -> None:
"""Check your machine associating CLI with your cycode account"""
printer = ConsolePrinter(context)

Expand All @@ -43,19 +43,22 @@ def authorization_check(context: click.Context):

client_id, client_secret = CredentialsManager().get_credentials()
if not client_id or not client_secret:
return printer.print_result(failed_auth_check_res)
printer.print_result(failed_auth_check_res)
return

try:
if CycodeTokenBasedClient(client_id, client_secret).api_token:
return printer.print_result(passed_auth_check_res)
printer.print_result(passed_auth_check_res)
return
except (NetworkError, HttpUnauthorizedError):
if context.obj['verbose']:
click.secho(f'Error: {traceback.format_exc()}', fg='red')

return printer.print_result(failed_auth_check_res)
printer.print_result(failed_auth_check_res)
return


def _handle_exception(context: click.Context, e: Exception):
def _handle_exception(context: click.Context, e: Exception) -> None:
if context.obj['verbose']:
click.secho(f'Error: {traceback.format_exc()}', fg='red')

Expand All @@ -70,7 +73,8 @@ def _handle_exception(context: click.Context, e: Exception):

error = errors.get(type(e))
if error:
return ConsolePrinter(context).print_error(error)
ConsolePrinter(context).print_error(error)
return

if isinstance(e, click.ClickException):
raise e
Expand Down
29 changes: 14 additions & 15 deletions cycode/cli/auth/auth_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import time
import webbrowser
from typing import Optional
from typing import TYPE_CHECKING, Tuple

from requests import Request

Expand All @@ -10,7 +10,10 @@
from cycode.cli.utils.string_utils import generate_random_string, hash_string_to_sha256
from cycode.cyclient import logger
from cycode.cyclient.auth_client import AuthClient
from cycode.cyclient.models import ApiToken, ApiTokenGenerationPollingResponse
from cycode.cyclient.models import ApiTokenGenerationPollingResponse

if TYPE_CHECKING:
from cycode.cyclient.models import ApiToken


class AuthManager:
Expand All @@ -20,16 +23,12 @@ class AuthManager:
FAILED_POLLING_STATUS = 'Error'
COMPLETED_POLLING_STATUS = 'Completed'

configuration_manager: ConfigurationManager
credentials_manager: CredentialsManager
auth_client: AuthClient

def __init__(self):
def __init__(self) -> None:
self.configuration_manager = ConfigurationManager()
self.credentials_manager = CredentialsManager()
self.auth_client = AuthClient()

def authenticate(self):
def authenticate(self) -> None:
logger.debug('generating pkce code pair')
code_challenge, code_verifier = self._generate_pkce_code_pair()

Expand All @@ -46,21 +45,21 @@ def authenticate(self):
logger.debug('saving get api token')
self.save_api_token(api_token)

def start_session(self, code_challenge: str):
def start_session(self, code_challenge: str) -> str:
auth_session = self.auth_client.start_session(code_challenge)
return auth_session.session_id

def redirect_to_login_page(self, code_challenge: str, session_id: str):
def redirect_to_login_page(self, code_challenge: str, session_id: str) -> None:
login_url = self._build_login_url(code_challenge, session_id)
webbrowser.open(login_url)

def get_api_token(self, session_id: str, code_verifier: str) -> Optional[ApiToken]:
def get_api_token(self, session_id: str, code_verifier: str) -> 'ApiToken':
api_token = self.get_api_token_polling(session_id, code_verifier)
if api_token is None:
raise AuthProcessError('getting api token is completed, but the token is missing')
return api_token

def get_api_token_polling(self, session_id: str, code_verifier: str) -> Optional[ApiToken]:
def get_api_token_polling(self, session_id: str, code_verifier: str) -> 'ApiToken':
end_polling_time = time.time() + self.POLLING_TIMEOUT_IN_SECONDS
while time.time() < end_polling_time:
logger.debug('trying to get api token...')
Expand All @@ -75,18 +74,18 @@ def get_api_token_polling(self, session_id: str, code_verifier: str) -> Optional

raise AuthProcessError('session expired')

def save_api_token(self, api_token: ApiToken):
def save_api_token(self, api_token: 'ApiToken') -> None:
self.credentials_manager.update_credentials_file(api_token.client_id, api_token.secret)

def _build_login_url(self, code_challenge: str, session_id: str):
def _build_login_url(self, code_challenge: str, session_id: str) -> str:
app_url = self.configuration_manager.get_cycode_app_url()
login_url = f'{app_url}/account/sign-in'
query_params = {'source': 'cycode_cli', 'code_challenge': code_challenge, 'session_id': session_id}
# TODO(MarshalX). Use auth_client instead and don't depend on "requests" lib here
request = Request(url=login_url, params=query_params)
return request.prepare().url

def _generate_pkce_code_pair(self) -> (str, str):
def _generate_pkce_code_pair(self) -> Tuple[str, str]:
code_verifier = generate_random_string(self.CODE_VERIFIER_LENGTH)
code_challenge = hash_string_to_sha256(code_verifier)
return code_challenge, code_verifier
Expand Down
8 changes: 4 additions & 4 deletions cycode/cli/ci_integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import click


def github_action_range():
def github_action_range() -> str:
before_sha = os.getenv('BEFORE_SHA')
push_base_sha = os.getenv('BASE_SHA')
pr_base_sha = os.getenv('PR_BASE_SHA')
Expand All @@ -22,7 +22,7 @@ def github_action_range():
# if push_base_sha and push_base_sha != "null":


def circleci_range():
def circleci_range() -> str:
before_sha = os.getenv('BEFORE_SHA')
current_sha = os.getenv('CURRENT_SHA')
commit_range = f'{before_sha}...{current_sha}'
Expand All @@ -36,7 +36,7 @@ def circleci_range():
return f'{commit_sha}~1...'


def gitlab_range():
def gitlab_range() -> str:
before_sha = os.getenv('CI_COMMIT_BEFORE_SHA')
commit_sha = os.getenv('CI_COMMIT_SHA', 'HEAD')

Expand All @@ -46,7 +46,7 @@ def gitlab_range():
return f'{commit_sha}'


def get_commit_range():
def get_commit_range() -> str:
if os.getenv('GITHUB_ACTIONS'):
return github_action_range()
if os.getenv('CIRCLECI'):
Expand Down
70 changes: 39 additions & 31 deletions cycode/cli/code_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import traceback
from platform import platform
from sys import getsizeof
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Callable, Dict, Iterator, List, Optional, Tuple, Union
from uuid import UUID, uuid4

import click
Expand Down Expand Up @@ -40,6 +40,10 @@
from cycode.cyclient.models import Detection, DetectionSchema, DetectionsPerFile, ZippedFileScanResult

if TYPE_CHECKING:
from git import Blob, Diff
from git.objects.base import IndexObjUnion
from git.objects.tree import TraversedTreeTup

from cycode.cli.utils.progress_bar import BaseProgressBar
from cycode.cyclient.models import ScanDetailsResponse
from cycode.cyclient.scan_client import ScanClient
Expand All @@ -58,7 +62,7 @@
required=False,
)
@click.pass_context
def scan_repository(context: click.Context, path: str, branch: str):
def scan_repository(context: click.Context, path: str, branch: str) -> None:
try:
logger.debug('Starting repository scan process, %s', {'path': path, 'branch': branch})

Expand All @@ -74,11 +78,11 @@ def scan_repository(context: click.Context, path: str, branch: str):

documents_to_scan = []
for file in file_entries:
# FIXME(MarshalX): probably file could be tree or submodule too. we expect blob only
progress_bar.update(ProgressBarSection.PREPARE_LOCAL_FILES)

path = file.path if monitor else get_path_by_os(os.path.join(path, file.path))

documents_to_scan.append(Document(path, file.data_stream.read().decode('UTF-8', errors='replace')))
file_path = file.path if monitor else get_path_by_os(os.path.join(path, file.path))
documents_to_scan.append(Document(file_path, file.data_stream.read().decode('UTF-8', errors='replace')))

documents_to_scan = exclude_irrelevant_documents_to_scan(context, documents_to_scan)

Expand All @@ -103,15 +107,17 @@ def scan_repository(context: click.Context, path: str, branch: str):
required=False,
)
@click.pass_context
def scan_repository_commit_history(context: click.Context, path: str, commit_range: str):
def scan_repository_commit_history(context: click.Context, path: str, commit_range: str) -> None:
try:
logger.debug('Starting commit history scan process, %s', {'path': path, 'commit_range': commit_range})
return scan_commit_range(context, path=path, commit_range=commit_range)
scan_commit_range(context, path=path, commit_range=commit_range)
except Exception as e:
_handle_exception(context, e)


def scan_commit_range(context: click.Context, path: str, commit_range: str, max_commits_count: Optional[int] = None):
def scan_commit_range(
context: click.Context, path: str, commit_range: str, max_commits_count: Optional[int] = None
) -> None:
scan_type = context.obj['scan_type']
progress_bar = context.obj['progress_bar']

Expand Down Expand Up @@ -166,38 +172,40 @@ def scan_commit_range(context: click.Context, path: str, commit_range: str, max_
logger.debug('List of commit ids to scan, %s', {'commit_ids': commit_ids_to_scan})
logger.debug('Starting to scan commit range (It may take a few minutes)')

return scan_documents(context, documents_to_scan, is_git_diff=True, is_commit_range=True)
scan_documents(context, documents_to_scan, is_git_diff=True, is_commit_range=True)
return None


@click.command(
short_help='Execute scan in a CI environment which relies on the '
'CYCODE_TOKEN and CYCODE_REPO_LOCATION environment variables'
)
@click.pass_context
def scan_ci(context: click.Context):
return scan_commit_range(context, path=os.getcwd(), commit_range=get_commit_range())
def scan_ci(context: click.Context) -> None:
scan_commit_range(context, path=os.getcwd(), commit_range=get_commit_range())


@click.command(short_help='Scan the files in the path supplied in the command')
@click.argument('path', nargs=1, type=click.STRING, required=True)
@click.pass_context
def scan_path(context: click.Context, path):
def scan_path(context: click.Context, path: str) -> None:
logger.debug('Starting path scan process, %s', {'path': path})
files_to_scan = get_relevant_files_in_path(path=path, exclude_patterns=['**/.git/**', '**/.cycode/**'])
files_to_scan = exclude_irrelevant_files(context, files_to_scan)
logger.debug('Found all relevant files for scanning %s', {'path': path, 'file_to_scan_count': len(files_to_scan)})
return scan_disk_files(context, path, files_to_scan)
scan_disk_files(context, path, files_to_scan)


@click.command(short_help='Use this command to scan the content that was not committed yet')
@click.argument('ignored_args', nargs=-1, type=click.UNPROCESSED)
@click.pass_context
def pre_commit_scan(context: click.Context, ignored_args: List[str]):
def pre_commit_scan(context: click.Context, ignored_args: List[str]) -> None:
scan_type = context.obj['scan_type']
progress_bar = context.obj['progress_bar']

if scan_type == consts.SCA_SCAN_TYPE:
return scan_sca_pre_commit(context)
scan_sca_pre_commit(context)
return

diff_files = Repo(os.getcwd()).index.diff('HEAD', create_patch=True, R=True)

Expand All @@ -209,13 +217,13 @@ def pre_commit_scan(context: click.Context, ignored_args: List[str]):
documents_to_scan.append(Document(get_path_by_os(get_diff_file_path(file)), get_diff_file_content(file)))

documents_to_scan = exclude_irrelevant_documents_to_scan(context, documents_to_scan)
return scan_documents(context, documents_to_scan, is_git_diff=True)
scan_documents(context, documents_to_scan, is_git_diff=True)


@click.command(short_help='Use this command to scan commits on the server side before pushing them to the repository')
@click.argument('ignored_args', nargs=-1, type=click.UNPROCESSED)
@click.pass_context
def pre_receive_scan(context: click.Context, ignored_args: List[str]):
def pre_receive_scan(context: click.Context, ignored_args: List[str]) -> None:
try:
scan_type = context.obj['scan_type']
if scan_type != consts.SECRET_SCAN_TYPE:
Expand Down Expand Up @@ -253,13 +261,13 @@ def pre_receive_scan(context: click.Context, ignored_args: List[str]):
_handle_exception(context, e)


def scan_sca_pre_commit(context: click.Context):
def scan_sca_pre_commit(context: click.Context) -> None:
scan_parameters = get_default_scan_parameters(context)
git_head_documents, pre_committed_documents = get_pre_commit_modified_documents(context.obj['progress_bar'])
git_head_documents = exclude_irrelevant_documents_to_scan(context, git_head_documents)
pre_committed_documents = exclude_irrelevant_documents_to_scan(context, pre_committed_documents)
sca_code_scanner.perform_pre_hook_range_scan_actions(git_head_documents, pre_committed_documents)
return scan_commit_range_documents(
scan_commit_range_documents(
context,
git_head_documents,
pre_committed_documents,
Expand All @@ -268,7 +276,7 @@ def scan_sca_pre_commit(context: click.Context):
)


def scan_sca_commit_range(context: click.Context, path: str, commit_range: str):
def scan_sca_commit_range(context: click.Context, path: str, commit_range: str) -> None:
progress_bar = context.obj['progress_bar']

scan_parameters = get_scan_parameters(context, path)
Expand All @@ -282,12 +290,10 @@ def scan_sca_commit_range(context: click.Context, path: str, commit_range: str):
path, from_commit_documents, from_commit_rev, to_commit_documents, to_commit_rev
)

return scan_commit_range_documents(
context, from_commit_documents, to_commit_documents, scan_parameters=scan_parameters
)
scan_commit_range_documents(context, from_commit_documents, to_commit_documents, scan_parameters=scan_parameters)


def scan_disk_files(context: click.Context, path: str, files_to_scan: List[str]):
def scan_disk_files(context: click.Context, path: str, files_to_scan: List[str]) -> None:
scan_parameters = get_scan_parameters(context, path)
scan_type = context.obj['scan_type']
progress_bar = context.obj['progress_bar']
Expand All @@ -307,7 +313,7 @@ def scan_disk_files(context: click.Context, path: str, files_to_scan: List[str])
continue

perform_pre_scan_documents_actions(context, scan_type, documents, is_git_diff)
return scan_documents(context, documents, is_git_diff=is_git_diff, scan_parameters=scan_parameters)
scan_documents(context, documents, is_git_diff=is_git_diff, scan_parameters=scan_parameters)


def set_issue_detected_by_scan_results(context: click.Context, scan_results: List[LocalScanResult]) -> None:
Expand Down Expand Up @@ -757,19 +763,21 @@ def get_oldest_unupdated_commit_for_branch(commit: str) -> Optional[str]:
return commits[0]


def get_diff_file_path(file):
def get_diff_file_path(file: 'Diff') -> Optional[str]:
return file.b_path if file.b_path else file.a_path


def get_diff_file_content(file):
def get_diff_file_content(file: 'Diff') -> str:
return file.diff.decode('UTF-8', errors='replace')


def should_process_git_object(obj, _: int) -> bool:
def should_process_git_object(obj: 'Blob', _: int) -> bool:
return obj.type == 'blob' and obj.size > 0


def get_git_repository_tree_file_entries(path: str, branch: str):
def get_git_repository_tree_file_entries(
path: str, branch: str
) -> Union[Iterator['IndexObjUnion'], Iterator['TraversedTreeTup']]:
return Repo(path).tree(branch).traverse(predicate=should_process_git_object)


Expand Down Expand Up @@ -867,7 +875,7 @@ def _exclude_detections_by_scan_type(
return detections


def exclude_detections_in_deleted_lines(detections) -> List:
def exclude_detections_in_deleted_lines(detections: List[Detection]) -> List[Detection]:
return [detection for detection in detections if detection.detection_details.get('line_type') != 'Removed']


Expand Down Expand Up @@ -969,7 +977,7 @@ def _should_exclude_detection(detection: Detection, exclusions: Dict) -> bool:
return False


def _is_detection_sha_configured_in_exclusions(detection, exclusions: List[str]) -> bool:
def _is_detection_sha_configured_in_exclusions(detection: Detection, exclusions: List[str]) -> bool:
detection_sha = detection.detection_details.get('sha512', '')
return detection_sha in exclusions

Expand Down
Loading