Skip to content

Commit a0ca07f

Browse files
authored
CM-22206 Add "auth check" command (#101)
1 parent 76f8dee commit a0ca07f

File tree

12 files changed

+225
-175
lines changed

12 files changed

+225
-175
lines changed

cli/auth/auth_command.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,23 @@
1+
import json
2+
13
import click
24
import traceback
5+
36
from cli.auth.auth_manager import AuthManager
4-
from cli.exceptions.custom_exceptions import AuthProcessError, CycodeError
7+
from cli.user_settings.credentials_manager import CredentialsManager
8+
from cli.exceptions.custom_exceptions import AuthProcessError, NetworkError, HttpUnauthorizedError
59
from cyclient import logger
10+
from cyclient.cycode_token_based_client import CycodeTokenBasedClient
611

712

8-
@click.command()
13+
@click.group(invoke_without_command=True)
914
@click.pass_context
1015
def authenticate(context: click.Context):
1116
""" Authenticates your machine to associate CLI with your cycode account """
17+
if context.invoked_subcommand is not None:
18+
# if it is a subcommand do nothing
19+
return
20+
1221
try:
1322
logger.debug("starting authentication process")
1423
auth_manager = AuthManager()
@@ -18,14 +27,50 @@ def authenticate(context: click.Context):
1827
_handle_exception(context, e)
1928

2029

30+
@authenticate.command(name='check')
31+
@click.pass_context
32+
def authorization_check(context: click.Context):
33+
""" Check your machine associating CLI with your cycode account """
34+
passed_auth_check_args = {'context': context, 'content': {
35+
'success': True,
36+
'message': 'You are authorized'
37+
}, 'color': 'green'}
38+
failed_auth_check_args = {'context': context, 'content': {
39+
'success': False,
40+
'message': 'You are not authorized'
41+
}, 'color': 'red'}
42+
43+
client_id, client_secret = CredentialsManager().get_credentials()
44+
if not client_id or not client_secret:
45+
return _print_result(**failed_auth_check_args)
46+
47+
try:
48+
# TODO(MarshalX): This property performs HTTP request to refresh the token. This must be the method.
49+
if CycodeTokenBasedClient(client_id, client_secret).api_token:
50+
return _print_result(**passed_auth_check_args)
51+
except (NetworkError, HttpUnauthorizedError):
52+
if context.obj['verbose']:
53+
click.secho(f'Error: {traceback.format_exc()}', fg='red', nl=False)
54+
55+
return _print_result(**failed_auth_check_args)
56+
57+
58+
def _print_result(context: click.Context, content: dict, color: str) -> None:
59+
# the current impl of printers supports only results of scans
60+
if context.obj['output'] == 'text':
61+
return click.secho(content['message'], fg=color)
62+
63+
return click.echo(json.dumps({'result': content['success'], 'message': content['message']}))
64+
65+
2166
def _handle_exception(context: click.Context, e: Exception):
2267
verbose = context.obj["verbose"]
2368
if verbose:
2469
click.secho(f'Error: {traceback.format_exc()}', fg='red', nl=False)
2570
if isinstance(e, AuthProcessError):
2671
click.secho('Authentication failed. Please try again later using the command `cycode auth`',
2772
fg='red', nl=False)
28-
elif isinstance(e, CycodeError):
73+
elif isinstance(e, NetworkError):
2974
click.secho('Authentication failed. Please try again later using the command `cycode auth`',
3075
fg='red', nl=False)
3176
elif isinstance(e, click.ClickException):

cli/auth/auth_manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def _build_login_url(self, code_challenge: str, session_id: str):
8484
'code_challenge': code_challenge,
8585
'session_id': session_id
8686
}
87+
# TODO(MarshalX). Use auth_client instead and don't depend on "requests" lib here
8788
request = Request(url=login_url, params=query_params)
8889
return request.prepare().url
8990

cli/code_scanner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -836,7 +836,7 @@ def _handle_exception(context: click.Context, e: Exception):
836836

837837
# TODO(MarshalX): Create global CLI errors database and move this
838838
errors: CliScanErrors = {
839-
CycodeError: CliScanError(
839+
NetworkError: CliScanError(
840840
soft_fail=True,
841841
code='cycode_error',
842842
message='Cycode was unable to complete this scan. '

cli/cycode.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
help='Run scan without failing, always return a non-error status code',
6161
type=bool,
6262
required=False)
63-
@click.option('--output', default='text',
63+
@click.option('--output', default=None,
6464
help="""
6565
\b
6666
Specify the results output (text/json),
@@ -104,7 +104,9 @@ def code_scan(context: click.Context, scan_type, client_id, secret, show_secret,
104104
context.obj["soft_fail"] = config["soft_fail"]
105105

106106
context.obj["scan_type"] = scan_type
107-
context.obj["output"] = output
107+
if output is not None:
108+
# save backward compatability with old style command
109+
context.obj["output"] = output
108110
context.obj["client"] = get_cycode_client(client_id, secret)
109111
context.obj["severity_threshold"] = severity_threshold
110112
context.obj["monitor"] = monitor
@@ -135,16 +137,25 @@ def finalize(context: click.Context, *args, **kwargs):
135137
@click.option(
136138
"--verbose", "-v", is_flag=True, default=False, help="Show detailed logs",
137139
)
140+
@click.option(
141+
'--output',
142+
default='text',
143+
help='Specify the output (text/json), the default is text',
144+
type=click.Choice(['text', 'json'])
145+
)
138146
@click.version_option(__version__, prog_name="cycode")
139147
@click.pass_context
140-
def main_cli(context: click.Context, verbose: bool):
148+
def main_cli(context: click.Context, verbose: bool, output: str):
141149
context.ensure_object(dict)
142150
configuration_manager = ConfigurationManager()
151+
143152
verbose = verbose or configuration_manager.get_verbose_flag()
144-
context.obj["verbose"] = verbose
153+
context.obj['verbose'] = verbose
145154
log_level = logging.DEBUG if verbose else logging.INFO
146155
logger.setLevel(log_level)
147156

157+
context.obj['output'] = output
158+
148159

149160
def get_cycode_client(client_id, client_secret):
150161
if not client_id or not client_secret:

cli/exceptions/custom_exceptions.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,23 @@
1+
from requests import Response
2+
3+
14
class CycodeError(Exception):
2-
def __init__(self, status_code: int, error_message: str):
5+
"""Base class for all custom exceptions"""
6+
7+
8+
class NetworkError(CycodeError):
9+
def __init__(self, status_code: int, error_message: str, response: Response):
310
self.status_code = status_code
411
self.error_message = error_message
12+
self.response = response
513
super().__init__(self.error_message)
614

715
def __str__(self):
816
return f'error occurred during the request. status code: {self.status_code}, error message: ' \
917
f'{self.error_message}'
1018

1119

12-
class ScanAsyncError(Exception):
20+
class ScanAsyncError(CycodeError):
1321
def __init__(self, error_message: str):
1422
self.error_message = error_message
1523
super().__init__(self.error_message)
@@ -18,17 +26,18 @@ def __str__(self):
1826
return f'error occurred during the scan. error message: {self.error_message}'
1927

2028

21-
class HttpUnauthorizedError(Exception):
22-
def __init__(self, error_message: str):
29+
class HttpUnauthorizedError(CycodeError):
30+
def __init__(self, error_message: str, response: Response):
2331
self.status_code = 401
2432
self.error_message = error_message
33+
self.response = response
2534
super().__init__(self.error_message)
2635

2736
def __str__(self):
2837
return 'Http Unauthorized Error'
2938

3039

31-
class ZipTooLargeError(Exception):
40+
class ZipTooLargeError(CycodeError):
3241
def __init__(self, size_limit: int):
3342
self.size_limit = size_limit
3443
super().__init__()
@@ -37,7 +46,7 @@ def __str__(self):
3746
return f'The size of zip to scan is too large, size limit: {self.size_limit}'
3847

3948

40-
class AuthProcessError(Exception):
49+
class AuthProcessError(CycodeError):
4150
def __init__(self, error_message: str):
4251
self.error_message = error_message
4352
super().__init__()

cyclient/auth_client.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
import requests.exceptions
2-
from requests import Response
31
from typing import Optional
2+
3+
from requests import Response
4+
45
from .cycode_client import CycodeClient
56
from . import models
6-
from cli.exceptions.custom_exceptions import CycodeError
7+
from cli.exceptions.custom_exceptions import NetworkError, HttpUnauthorizedError
78

89

910
class AuthClient:
@@ -12,26 +13,21 @@ class AuthClient:
1213
def __init__(self):
1314
self.cycode_client = CycodeClient()
1415

15-
def start_session(self, code_challenge: str):
16-
path = f"{self.AUTH_CONTROLLER_PATH}/start"
16+
def start_session(self, code_challenge: str) -> models.AuthenticationSession:
17+
path = f'{self.AUTH_CONTROLLER_PATH}/start'
1718
body = {'code_challenge': code_challenge}
18-
try:
19-
response = self.cycode_client.post(url_path=path, body=body)
20-
return self.parse_start_session_response(response)
21-
except requests.exceptions.Timeout as e:
22-
raise CycodeError(504, e.response.text)
23-
except requests.exceptions.HTTPError as e:
24-
raise CycodeError(e.response.status_code, e.response.text)
19+
response = self.cycode_client.post(url_path=path, body=body)
20+
return self.parse_start_session_response(response)
2521

2622
def get_api_token(self, session_id: str, code_verifier: str) -> Optional[models.ApiTokenGenerationPollingResponse]:
27-
path = f"{self.AUTH_CONTROLLER_PATH}/token"
23+
path = f'{self.AUTH_CONTROLLER_PATH}/token'
2824
body = {'session_id': session_id, 'code_verifier': code_verifier}
2925
try:
3026
response = self.cycode_client.post(url_path=path, body=body)
3127
return self.parse_api_token_polling_response(response)
32-
except requests.exceptions.HTTPError as e:
28+
except (NetworkError, HttpUnauthorizedError) as e:
3329
return self.parse_api_token_polling_response(e.response)
34-
except Exception as e:
30+
except Exception:
3531
return None
3632

3733
@staticmethod
@@ -42,5 +38,5 @@ def parse_start_session_response(response: Response) -> models.AuthenticationSes
4238
def parse_api_token_polling_response(response: Response) -> Optional[models.ApiTokenGenerationPollingResponse]:
4339
try:
4440
return models.ApiTokenGenerationPollingResponseSchema().load(response.json())
45-
except Exception as e:
41+
except Exception:
4642
return None

cyclient/cycode_client.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,8 @@
1-
from cyclient import config, __version__
1+
from cyclient import config
22
from cyclient.cycode_client_base import CycodeClientBase
33

44

55
class CycodeClient(CycodeClientBase):
6-
7-
MANDATORY_HEADERS: dict = {
8-
"User-Agent": f'cycode-cli_{__version__}',
9-
}
10-
116
def __init__(self):
127
super().__init__(config.cycode_api_url)
138
self.timeout = config.timeout
14-

cyclient/cycode_client_base.py

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
1-
from requests import Response, request
1+
from requests import Response, request, exceptions
22

33
from cyclient import config, __version__
4+
from cli.exceptions.custom_exceptions import NetworkError, HttpUnauthorizedError
45

56

67
class CycodeClientBase:
78

89
MANDATORY_HEADERS: dict = {
9-
"User-Agent": f'cycode-cli_{__version__}',
10+
'User-Agent': f'cycode-cli_{__version__}',
1011
}
1112

12-
def __init__(self, api_url):
13+
def __init__(self, api_url: str):
1314
self.timeout = config.timeout
1415
self.api_url = api_url
1516

@@ -20,8 +21,7 @@ def post(
2021
headers: dict = None,
2122
**kwargs
2223
) -> Response:
23-
return self._execute(
24-
method="post", endpoint=url_path, json=body, headers=headers, **kwargs)
24+
return self._execute(method='post', endpoint=url_path, json=body, headers=headers, **kwargs)
2525

2626
def put(
2727
self,
@@ -30,16 +30,15 @@ def put(
3030
headers: dict = None,
3131
**kwargs
3232
) -> Response:
33-
return self._execute(
34-
method="put", endpoint=url_path, json=body, headers=headers, **kwargs)
33+
return self._execute(method='put', endpoint=url_path, json=body, headers=headers, **kwargs)
3534

3635
def get(
3736
self,
3837
url_path: str,
3938
headers: dict = None,
4039
**kwargs
4140
) -> Response:
42-
return self._execute(method="get", endpoint=url_path, headers=headers, **kwargs)
41+
return self._execute(method='get', endpoint=url_path, headers=headers, **kwargs)
4342

4443
def _execute(
4544
self,
@@ -48,20 +47,39 @@ def _execute(
4847
headers: dict = None,
4948
**kwargs
5049
) -> Response:
51-
5250
url = self.build_full_url(self.api_url, endpoint)
5351

54-
response = request(
55-
method=method, url=url, timeout=self.timeout, headers=self.get_request_headers(headers), **kwargs
56-
)
57-
response.raise_for_status()
58-
return response
52+
try:
53+
response = request(
54+
method=method, url=url, timeout=self.timeout, headers=self.get_request_headers(headers), **kwargs
55+
)
56+
57+
response.raise_for_status()
58+
return response
59+
except Exception as e:
60+
self._handle_exception(e)
5961

60-
def get_request_headers(self, additional_headers: dict = None):
62+
def get_request_headers(self, additional_headers: dict = None) -> dict:
6163
if additional_headers is None:
62-
return self.MANDATORY_HEADERS
64+
return self.MANDATORY_HEADERS.copy()
6365
return {**self.MANDATORY_HEADERS, **additional_headers}
6466

65-
def build_full_url(self, url, endpoint):
66-
return f"{url}/{endpoint}"
67+
def build_full_url(self, url: str, endpoint: str) -> str:
68+
return f'{url}/{endpoint}'
69+
70+
def _handle_exception(self, e: Exception):
71+
if isinstance(e, exceptions.Timeout):
72+
raise NetworkError(504, 'Timeout Error', e.response)
73+
elif isinstance(e, exceptions.HTTPError):
74+
self._handle_http_exception(e)
75+
elif isinstance(e, exceptions.ConnectionError):
76+
raise NetworkError(502, 'Connection Error', e.response)
77+
else:
78+
raise e
79+
80+
@staticmethod
81+
def _handle_http_exception(e: exceptions.HTTPError):
82+
if e.response.status_code == 401:
83+
raise HttpUnauthorizedError(e.response.text, e.response)
6784

85+
raise NetworkError(e.response.status_code, e.response.text, e.response)

0 commit comments

Comments
 (0)