Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 48 additions & 3 deletions cli/auth/auth_command.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
import json

import click
import traceback

from cli.auth.auth_manager import AuthManager
from cli.exceptions.custom_exceptions import AuthProcessError, CycodeError
from cli.user_settings.credentials_manager import CredentialsManager
from cli.exceptions.custom_exceptions import AuthProcessError, NetworkError, HttpUnauthorizedError
from cyclient import logger
from cyclient.cycode_token_based_client import CycodeTokenBasedClient


@click.command()
@click.group(invoke_without_command=True)
@click.pass_context
def authenticate(context: click.Context):
""" Authenticates your machine to associate CLI with your cycode account """
if context.invoked_subcommand is not None:
# if it is a subcommand do nothing
return

try:
logger.debug("starting authentication process")
auth_manager = AuthManager()
Expand All @@ -18,14 +27,50 @@ def authenticate(context: click.Context):
_handle_exception(context, e)


@authenticate.command(name='check')
@click.pass_context
def authorization_check(context: click.Context):
""" Check your machine associating CLI with your cycode account """
passed_auth_check_args = {'context': context, 'content': {
'success': True,
'message': 'You are authorized'
}, 'color': 'green'}
failed_auth_check_args = {'context': context, 'content': {
'success': False,
'message': 'You are not authorized'
}, 'color': 'red'}

client_id, client_secret = CredentialsManager().get_credentials()
if not client_id or not client_secret:
return _print_result(**failed_auth_check_args)

try:
# TODO(MarshalX): This property performs HTTP request to refresh the token. This must be the method.
if CycodeTokenBasedClient(client_id, client_secret).api_token:
return _print_result(**passed_auth_check_args)
except (NetworkError, HttpUnauthorizedError):
if context.obj['verbose']:
click.secho(f'Error: {traceback.format_exc()}', fg='red', nl=False)

return _print_result(**failed_auth_check_args)


def _print_result(context: click.Context, content: dict, color: str) -> None:
# the current impl of printers supports only results of scans
if context.obj['output'] == 'text':
return click.secho(content['message'], fg=color)

return click.echo(json.dumps({'result': content['success'], 'message': content['message']}))


def _handle_exception(context: click.Context, e: Exception):
verbose = context.obj["verbose"]
if verbose:
click.secho(f'Error: {traceback.format_exc()}', fg='red', nl=False)
if isinstance(e, AuthProcessError):
click.secho('Authentication failed. Please try again later using the command `cycode auth`',
fg='red', nl=False)
elif isinstance(e, CycodeError):
elif isinstance(e, NetworkError):
click.secho('Authentication failed. Please try again later using the command `cycode auth`',
fg='red', nl=False)
elif isinstance(e, click.ClickException):
Expand Down
1 change: 1 addition & 0 deletions cli/auth/auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def _build_login_url(self, code_challenge: str, session_id: str):
'code_challenge': code_challenge,
'session_id': session_id
}
# TODO(MarshalX). Use auth_client instead and don't depend on "requests" lib here
request = Request(url=login_url, params=query_params)
return request.prepare().url

Expand Down
2 changes: 1 addition & 1 deletion cli/code_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,7 @@ def _handle_exception(context: click.Context, e: Exception):

# TODO(MarshalX): Create global CLI errors database and move this
errors: CliScanErrors = {
CycodeError: CliScanError(
NetworkError: CliScanError(
soft_fail=True,
code='cycode_error',
message='Cycode was unable to complete this scan. '
Expand Down
19 changes: 15 additions & 4 deletions cli/cycode.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
help='Run scan without failing, always return a non-error status code',
type=bool,
required=False)
@click.option('--output', default='text',
@click.option('--output', default=None,
help="""
\b
Specify the results output (text/json),
Expand Down Expand Up @@ -104,7 +104,9 @@ def code_scan(context: click.Context, scan_type, client_id, secret, show_secret,
context.obj["soft_fail"] = config["soft_fail"]

context.obj["scan_type"] = scan_type
context.obj["output"] = output
if output is not None:
# save backward compatability with old style command
context.obj["output"] = output
context.obj["client"] = get_cycode_client(client_id, secret)
context.obj["severity_threshold"] = severity_threshold
context.obj["monitor"] = monitor
Expand Down Expand Up @@ -135,16 +137,25 @@ def finalize(context: click.Context, *args, **kwargs):
@click.option(
"--verbose", "-v", is_flag=True, default=False, help="Show detailed logs",
)
@click.option(
'--output',
default='text',
help='Specify the output (text/json), the default is text',
type=click.Choice(['text', 'json'])
)
@click.version_option(__version__, prog_name="cycode")
@click.pass_context
def main_cli(context: click.Context, verbose: bool):
def main_cli(context: click.Context, verbose: bool, output: str):
context.ensure_object(dict)
configuration_manager = ConfigurationManager()

verbose = verbose or configuration_manager.get_verbose_flag()
context.obj["verbose"] = verbose
context.obj['verbose'] = verbose
log_level = logging.DEBUG if verbose else logging.INFO
logger.setLevel(log_level)

context.obj['output'] = output


def get_cycode_client(client_id, client_secret):
if not client_id or not client_secret:
Expand Down
21 changes: 15 additions & 6 deletions cli/exceptions/custom_exceptions.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
from requests import Response


class CycodeError(Exception):
def __init__(self, status_code: int, error_message: str):
"""Base class for all custom exceptions"""


class NetworkError(CycodeError):
def __init__(self, status_code: int, error_message: str, response: Response):
self.status_code = status_code
self.error_message = error_message
self.response = response
super().__init__(self.error_message)

def __str__(self):
return f'error occurred during the request. status code: {self.status_code}, error message: ' \
f'{self.error_message}'


class ScanAsyncError(Exception):
class ScanAsyncError(CycodeError):
def __init__(self, error_message: str):
self.error_message = error_message
super().__init__(self.error_message)
Expand All @@ -18,17 +26,18 @@ def __str__(self):
return f'error occurred during the scan. error message: {self.error_message}'


class HttpUnauthorizedError(Exception):
def __init__(self, error_message: str):
class HttpUnauthorizedError(CycodeError):
def __init__(self, error_message: str, response: Response):
self.status_code = 401
self.error_message = error_message
self.response = response
super().__init__(self.error_message)

def __str__(self):
return 'Http Unauthorized Error'


class ZipTooLargeError(Exception):
class ZipTooLargeError(CycodeError):
def __init__(self, size_limit: int):
self.size_limit = size_limit
super().__init__()
Expand All @@ -37,7 +46,7 @@ def __str__(self):
return f'The size of zip to scan is too large, size limit: {self.size_limit}'


class AuthProcessError(Exception):
class AuthProcessError(CycodeError):
def __init__(self, error_message: str):
self.error_message = error_message
super().__init__()
Expand Down
28 changes: 12 additions & 16 deletions cyclient/auth_client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import requests.exceptions
from requests import Response
from typing import Optional

from requests import Response

from .cycode_client import CycodeClient
from . import models
from cli.exceptions.custom_exceptions import CycodeError
from cli.exceptions.custom_exceptions import NetworkError, HttpUnauthorizedError


class AuthClient:
Expand All @@ -12,26 +13,21 @@ class AuthClient:
def __init__(self):
self.cycode_client = CycodeClient()

def start_session(self, code_challenge: str):
path = f"{self.AUTH_CONTROLLER_PATH}/start"
def start_session(self, code_challenge: str) -> models.AuthenticationSession:
path = f'{self.AUTH_CONTROLLER_PATH}/start'
body = {'code_challenge': code_challenge}
try:
response = self.cycode_client.post(url_path=path, body=body)
return self.parse_start_session_response(response)
except requests.exceptions.Timeout as e:
raise CycodeError(504, e.response.text)
except requests.exceptions.HTTPError as e:
raise CycodeError(e.response.status_code, e.response.text)
response = self.cycode_client.post(url_path=path, body=body)
return self.parse_start_session_response(response)

def get_api_token(self, session_id: str, code_verifier: str) -> Optional[models.ApiTokenGenerationPollingResponse]:
path = f"{self.AUTH_CONTROLLER_PATH}/token"
path = f'{self.AUTH_CONTROLLER_PATH}/token'
body = {'session_id': session_id, 'code_verifier': code_verifier}
try:
response = self.cycode_client.post(url_path=path, body=body)
return self.parse_api_token_polling_response(response)
except requests.exceptions.HTTPError as e:
except (NetworkError, HttpUnauthorizedError) as e:
return self.parse_api_token_polling_response(e.response)
except Exception as e:
except Exception:
return None

@staticmethod
Expand All @@ -42,5 +38,5 @@ def parse_start_session_response(response: Response) -> models.AuthenticationSes
def parse_api_token_polling_response(response: Response) -> Optional[models.ApiTokenGenerationPollingResponse]:
try:
return models.ApiTokenGenerationPollingResponseSchema().load(response.json())
except Exception as e:
except Exception:
return None
8 changes: 1 addition & 7 deletions cyclient/cycode_client.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
from cyclient import config, __version__
from cyclient import config
from cyclient.cycode_client_base import CycodeClientBase


class CycodeClient(CycodeClientBase):

MANDATORY_HEADERS: dict = {
"User-Agent": f'cycode-cli_{__version__}',
}

def __init__(self):
super().__init__(config.cycode_api_url)
self.timeout = config.timeout

54 changes: 36 additions & 18 deletions cyclient/cycode_client_base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from requests import Response, request
from requests import Response, request, exceptions

from cyclient import config, __version__
from cli.exceptions.custom_exceptions import NetworkError, HttpUnauthorizedError


class CycodeClientBase:

MANDATORY_HEADERS: dict = {
"User-Agent": f'cycode-cli_{__version__}',
'User-Agent': f'cycode-cli_{__version__}',
}

def __init__(self, api_url):
def __init__(self, api_url: str):
self.timeout = config.timeout
self.api_url = api_url

Expand All @@ -20,8 +21,7 @@ def post(
headers: dict = None,
**kwargs
) -> Response:
return self._execute(
method="post", endpoint=url_path, json=body, headers=headers, **kwargs)
return self._execute(method='post', endpoint=url_path, json=body, headers=headers, **kwargs)

def put(
self,
Expand All @@ -30,16 +30,15 @@ def put(
headers: dict = None,
**kwargs
) -> Response:
return self._execute(
method="put", endpoint=url_path, json=body, headers=headers, **kwargs)
return self._execute(method='put', endpoint=url_path, json=body, headers=headers, **kwargs)

def get(
self,
url_path: str,
headers: dict = None,
**kwargs
) -> Response:
return self._execute(method="get", endpoint=url_path, headers=headers, **kwargs)
return self._execute(method='get', endpoint=url_path, headers=headers, **kwargs)

def _execute(
self,
Expand All @@ -48,20 +47,39 @@ def _execute(
headers: dict = None,
**kwargs
) -> Response:

url = self.build_full_url(self.api_url, endpoint)

response = request(
method=method, url=url, timeout=self.timeout, headers=self.get_request_headers(headers), **kwargs
)
response.raise_for_status()
return response
try:
response = request(
method=method, url=url, timeout=self.timeout, headers=self.get_request_headers(headers), **kwargs
)

response.raise_for_status()
return response
except Exception as e:
self._handle_exception(e)

def get_request_headers(self, additional_headers: dict = None):
def get_request_headers(self, additional_headers: dict = None) -> dict:
if additional_headers is None:
return self.MANDATORY_HEADERS
return self.MANDATORY_HEADERS.copy()
return {**self.MANDATORY_HEADERS, **additional_headers}

def build_full_url(self, url, endpoint):
return f"{url}/{endpoint}"
def build_full_url(self, url: str, endpoint: str) -> str:
return f'{url}/{endpoint}'

def _handle_exception(self, e: Exception):
if isinstance(e, exceptions.Timeout):
raise NetworkError(504, 'Timeout Error', e.response)
elif isinstance(e, exceptions.HTTPError):
self._handle_http_exception(e)
elif isinstance(e, exceptions.ConnectionError):
raise NetworkError(502, 'Connection Error', e.response)
else:
raise e

@staticmethod
def _handle_http_exception(e: exceptions.HTTPError):
if e.response.status_code == 401:
raise HttpUnauthorizedError(e.response.text, e.response)

raise NetworkError(e.response.status_code, e.response.text, e.response)
Loading