Skip to content

Commit 8cb7da4

Browse files
authored
CM-25040 - Improve type annotations (#137)
1 parent 2368b5a commit 8cb7da4

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+381
-326
lines changed

cycode/cli/auth/auth_command.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
invoke_without_command=True, short_help='Authenticates your machine to associate CLI with your cycode account'
1616
)
1717
@click.pass_context
18-
def authenticate(context: click.Context):
18+
def authenticate(context: click.Context) -> None:
1919
if context.invoked_subcommand is not None:
2020
# if it is a subcommand, do nothing
2121
return
@@ -34,7 +34,7 @@ def authenticate(context: click.Context):
3434

3535
@authenticate.command(name='check')
3636
@click.pass_context
37-
def authorization_check(context: click.Context):
37+
def authorization_check(context: click.Context) -> None:
3838
"""Check your machine associating CLI with your cycode account"""
3939
printer = ConsolePrinter(context)
4040

@@ -43,19 +43,22 @@ def authorization_check(context: click.Context):
4343

4444
client_id, client_secret = CredentialsManager().get_credentials()
4545
if not client_id or not client_secret:
46-
return printer.print_result(failed_auth_check_res)
46+
printer.print_result(failed_auth_check_res)
47+
return
4748

4849
try:
4950
if CycodeTokenBasedClient(client_id, client_secret).api_token:
50-
return printer.print_result(passed_auth_check_res)
51+
printer.print_result(passed_auth_check_res)
52+
return
5153
except (NetworkError, HttpUnauthorizedError):
5254
if context.obj['verbose']:
5355
click.secho(f'Error: {traceback.format_exc()}', fg='red')
5456

55-
return printer.print_result(failed_auth_check_res)
57+
printer.print_result(failed_auth_check_res)
58+
return
5659

5760

58-
def _handle_exception(context: click.Context, e: Exception):
61+
def _handle_exception(context: click.Context, e: Exception) -> None:
5962
if context.obj['verbose']:
6063
click.secho(f'Error: {traceback.format_exc()}', fg='red')
6164

@@ -70,7 +73,8 @@ def _handle_exception(context: click.Context, e: Exception):
7073

7174
error = errors.get(type(e))
7275
if error:
73-
return ConsolePrinter(context).print_error(error)
76+
ConsolePrinter(context).print_error(error)
77+
return
7478

7579
if isinstance(e, click.ClickException):
7680
raise e

cycode/cli/auth/auth_manager.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import time
22
import webbrowser
3-
from typing import Optional
3+
from typing import TYPE_CHECKING, Tuple
44

55
from requests import Request
66

@@ -10,7 +10,10 @@
1010
from cycode.cli.utils.string_utils import generate_random_string, hash_string_to_sha256
1111
from cycode.cyclient import logger
1212
from cycode.cyclient.auth_client import AuthClient
13-
from cycode.cyclient.models import ApiToken, ApiTokenGenerationPollingResponse
13+
from cycode.cyclient.models import ApiTokenGenerationPollingResponse
14+
15+
if TYPE_CHECKING:
16+
from cycode.cyclient.models import ApiToken
1417

1518

1619
class AuthManager:
@@ -20,16 +23,12 @@ class AuthManager:
2023
FAILED_POLLING_STATUS = 'Error'
2124
COMPLETED_POLLING_STATUS = 'Completed'
2225

23-
configuration_manager: ConfigurationManager
24-
credentials_manager: CredentialsManager
25-
auth_client: AuthClient
26-
27-
def __init__(self):
26+
def __init__(self) -> None:
2827
self.configuration_manager = ConfigurationManager()
2928
self.credentials_manager = CredentialsManager()
3029
self.auth_client = AuthClient()
3130

32-
def authenticate(self):
31+
def authenticate(self) -> None:
3332
logger.debug('generating pkce code pair')
3433
code_challenge, code_verifier = self._generate_pkce_code_pair()
3534

@@ -46,21 +45,21 @@ def authenticate(self):
4645
logger.debug('saving get api token')
4746
self.save_api_token(api_token)
4847

49-
def start_session(self, code_challenge: str):
48+
def start_session(self, code_challenge: str) -> str:
5049
auth_session = self.auth_client.start_session(code_challenge)
5150
return auth_session.session_id
5251

53-
def redirect_to_login_page(self, code_challenge: str, session_id: str):
52+
def redirect_to_login_page(self, code_challenge: str, session_id: str) -> None:
5453
login_url = self._build_login_url(code_challenge, session_id)
5554
webbrowser.open(login_url)
5655

57-
def get_api_token(self, session_id: str, code_verifier: str) -> Optional[ApiToken]:
56+
def get_api_token(self, session_id: str, code_verifier: str) -> 'ApiToken':
5857
api_token = self.get_api_token_polling(session_id, code_verifier)
5958
if api_token is None:
6059
raise AuthProcessError('getting api token is completed, but the token is missing')
6160
return api_token
6261

63-
def get_api_token_polling(self, session_id: str, code_verifier: str) -> Optional[ApiToken]:
62+
def get_api_token_polling(self, session_id: str, code_verifier: str) -> 'ApiToken':
6463
end_polling_time = time.time() + self.POLLING_TIMEOUT_IN_SECONDS
6564
while time.time() < end_polling_time:
6665
logger.debug('trying to get api token...')
@@ -75,18 +74,18 @@ def get_api_token_polling(self, session_id: str, code_verifier: str) -> Optional
7574

7675
raise AuthProcessError('session expired')
7776

78-
def save_api_token(self, api_token: ApiToken):
77+
def save_api_token(self, api_token: 'ApiToken') -> None:
7978
self.credentials_manager.update_credentials_file(api_token.client_id, api_token.secret)
8079

81-
def _build_login_url(self, code_challenge: str, session_id: str):
80+
def _build_login_url(self, code_challenge: str, session_id: str) -> str:
8281
app_url = self.configuration_manager.get_cycode_app_url()
8382
login_url = f'{app_url}/account/sign-in'
8483
query_params = {'source': 'cycode_cli', 'code_challenge': code_challenge, 'session_id': session_id}
8584
# TODO(MarshalX). Use auth_client instead and don't depend on "requests" lib here
8685
request = Request(url=login_url, params=query_params)
8786
return request.prepare().url
8887

89-
def _generate_pkce_code_pair(self) -> (str, str):
88+
def _generate_pkce_code_pair(self) -> Tuple[str, str]:
9089
code_verifier = generate_random_string(self.CODE_VERIFIER_LENGTH)
9190
code_challenge = hash_string_to_sha256(code_verifier)
9291
return code_challenge, code_verifier

cycode/cli/ci_integrations.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import click
44

55

6-
def github_action_range():
6+
def github_action_range() -> str:
77
before_sha = os.getenv('BEFORE_SHA')
88
push_base_sha = os.getenv('BASE_SHA')
99
pr_base_sha = os.getenv('PR_BASE_SHA')
@@ -22,7 +22,7 @@ def github_action_range():
2222
# if push_base_sha and push_base_sha != "null":
2323

2424

25-
def circleci_range():
25+
def circleci_range() -> str:
2626
before_sha = os.getenv('BEFORE_SHA')
2727
current_sha = os.getenv('CURRENT_SHA')
2828
commit_range = f'{before_sha}...{current_sha}'
@@ -36,7 +36,7 @@ def circleci_range():
3636
return f'{commit_sha}~1...'
3737

3838

39-
def gitlab_range():
39+
def gitlab_range() -> str:
4040
before_sha = os.getenv('CI_COMMIT_BEFORE_SHA')
4141
commit_sha = os.getenv('CI_COMMIT_SHA', 'HEAD')
4242

@@ -46,7 +46,7 @@ def gitlab_range():
4646
return f'{commit_sha}'
4747

4848

49-
def get_commit_range():
49+
def get_commit_range() -> str:
5050
if os.getenv('GITHUB_ACTIONS'):
5151
return github_action_range()
5252
if os.getenv('CIRCLECI'):

cycode/cli/code_scanner.py

Lines changed: 39 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import traceback
77
from platform import platform
88
from sys import getsizeof
9-
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
9+
from typing import TYPE_CHECKING, Callable, Dict, Iterator, List, Optional, Tuple, Union
1010
from uuid import UUID, uuid4
1111

1212
import click
@@ -40,6 +40,10 @@
4040
from cycode.cyclient.models import Detection, DetectionSchema, DetectionsPerFile, ZippedFileScanResult
4141

4242
if TYPE_CHECKING:
43+
from git import Blob, Diff
44+
from git.objects.base import IndexObjUnion
45+
from git.objects.tree import TraversedTreeTup
46+
4347
from cycode.cli.utils.progress_bar import BaseProgressBar
4448
from cycode.cyclient.models import ScanDetailsResponse
4549
from cycode.cyclient.scan_client import ScanClient
@@ -58,7 +62,7 @@
5862
required=False,
5963
)
6064
@click.pass_context
61-
def scan_repository(context: click.Context, path: str, branch: str):
65+
def scan_repository(context: click.Context, path: str, branch: str) -> None:
6266
try:
6367
logger.debug('Starting repository scan process, %s', {'path': path, 'branch': branch})
6468

@@ -74,11 +78,11 @@ def scan_repository(context: click.Context, path: str, branch: str):
7478

7579
documents_to_scan = []
7680
for file in file_entries:
81+
# FIXME(MarshalX): probably file could be tree or submodule too. we expect blob only
7782
progress_bar.update(ProgressBarSection.PREPARE_LOCAL_FILES)
7883

79-
path = file.path if monitor else get_path_by_os(os.path.join(path, file.path))
80-
81-
documents_to_scan.append(Document(path, file.data_stream.read().decode('UTF-8', errors='replace')))
84+
file_path = file.path if monitor else get_path_by_os(os.path.join(path, file.path))
85+
documents_to_scan.append(Document(file_path, file.data_stream.read().decode('UTF-8', errors='replace')))
8286

8387
documents_to_scan = exclude_irrelevant_documents_to_scan(context, documents_to_scan)
8488

@@ -103,15 +107,17 @@ def scan_repository(context: click.Context, path: str, branch: str):
103107
required=False,
104108
)
105109
@click.pass_context
106-
def scan_repository_commit_history(context: click.Context, path: str, commit_range: str):
110+
def scan_repository_commit_history(context: click.Context, path: str, commit_range: str) -> None:
107111
try:
108112
logger.debug('Starting commit history scan process, %s', {'path': path, 'commit_range': commit_range})
109-
return scan_commit_range(context, path=path, commit_range=commit_range)
113+
scan_commit_range(context, path=path, commit_range=commit_range)
110114
except Exception as e:
111115
_handle_exception(context, e)
112116

113117

114-
def scan_commit_range(context: click.Context, path: str, commit_range: str, max_commits_count: Optional[int] = None):
118+
def scan_commit_range(
119+
context: click.Context, path: str, commit_range: str, max_commits_count: Optional[int] = None
120+
) -> None:
115121
scan_type = context.obj['scan_type']
116122
progress_bar = context.obj['progress_bar']
117123

@@ -166,38 +172,40 @@ def scan_commit_range(context: click.Context, path: str, commit_range: str, max_
166172
logger.debug('List of commit ids to scan, %s', {'commit_ids': commit_ids_to_scan})
167173
logger.debug('Starting to scan commit range (It may take a few minutes)')
168174

169-
return scan_documents(context, documents_to_scan, is_git_diff=True, is_commit_range=True)
175+
scan_documents(context, documents_to_scan, is_git_diff=True, is_commit_range=True)
176+
return None
170177

171178

172179
@click.command(
173180
short_help='Execute scan in a CI environment which relies on the '
174181
'CYCODE_TOKEN and CYCODE_REPO_LOCATION environment variables'
175182
)
176183
@click.pass_context
177-
def scan_ci(context: click.Context):
178-
return scan_commit_range(context, path=os.getcwd(), commit_range=get_commit_range())
184+
def scan_ci(context: click.Context) -> None:
185+
scan_commit_range(context, path=os.getcwd(), commit_range=get_commit_range())
179186

180187

181188
@click.command(short_help='Scan the files in the path supplied in the command')
182189
@click.argument('path', nargs=1, type=click.STRING, required=True)
183190
@click.pass_context
184-
def scan_path(context: click.Context, path):
191+
def scan_path(context: click.Context, path: str) -> None:
185192
logger.debug('Starting path scan process, %s', {'path': path})
186193
files_to_scan = get_relevant_files_in_path(path=path, exclude_patterns=['**/.git/**', '**/.cycode/**'])
187194
files_to_scan = exclude_irrelevant_files(context, files_to_scan)
188195
logger.debug('Found all relevant files for scanning %s', {'path': path, 'file_to_scan_count': len(files_to_scan)})
189-
return scan_disk_files(context, path, files_to_scan)
196+
scan_disk_files(context, path, files_to_scan)
190197

191198

192199
@click.command(short_help='Use this command to scan the content that was not committed yet')
193200
@click.argument('ignored_args', nargs=-1, type=click.UNPROCESSED)
194201
@click.pass_context
195-
def pre_commit_scan(context: click.Context, ignored_args: List[str]):
202+
def pre_commit_scan(context: click.Context, ignored_args: List[str]) -> None:
196203
scan_type = context.obj['scan_type']
197204
progress_bar = context.obj['progress_bar']
198205

199206
if scan_type == consts.SCA_SCAN_TYPE:
200-
return scan_sca_pre_commit(context)
207+
scan_sca_pre_commit(context)
208+
return
201209

202210
diff_files = Repo(os.getcwd()).index.diff('HEAD', create_patch=True, R=True)
203211

@@ -209,13 +217,13 @@ def pre_commit_scan(context: click.Context, ignored_args: List[str]):
209217
documents_to_scan.append(Document(get_path_by_os(get_diff_file_path(file)), get_diff_file_content(file)))
210218

211219
documents_to_scan = exclude_irrelevant_documents_to_scan(context, documents_to_scan)
212-
return scan_documents(context, documents_to_scan, is_git_diff=True)
220+
scan_documents(context, documents_to_scan, is_git_diff=True)
213221

214222

215223
@click.command(short_help='Use this command to scan commits on the server side before pushing them to the repository')
216224
@click.argument('ignored_args', nargs=-1, type=click.UNPROCESSED)
217225
@click.pass_context
218-
def pre_receive_scan(context: click.Context, ignored_args: List[str]):
226+
def pre_receive_scan(context: click.Context, ignored_args: List[str]) -> None:
219227
try:
220228
scan_type = context.obj['scan_type']
221229
if scan_type != consts.SECRET_SCAN_TYPE:
@@ -253,13 +261,13 @@ def pre_receive_scan(context: click.Context, ignored_args: List[str]):
253261
_handle_exception(context, e)
254262

255263

256-
def scan_sca_pre_commit(context: click.Context):
264+
def scan_sca_pre_commit(context: click.Context) -> None:
257265
scan_parameters = get_default_scan_parameters(context)
258266
git_head_documents, pre_committed_documents = get_pre_commit_modified_documents(context.obj['progress_bar'])
259267
git_head_documents = exclude_irrelevant_documents_to_scan(context, git_head_documents)
260268
pre_committed_documents = exclude_irrelevant_documents_to_scan(context, pre_committed_documents)
261269
sca_code_scanner.perform_pre_hook_range_scan_actions(git_head_documents, pre_committed_documents)
262-
return scan_commit_range_documents(
270+
scan_commit_range_documents(
263271
context,
264272
git_head_documents,
265273
pre_committed_documents,
@@ -268,7 +276,7 @@ def scan_sca_pre_commit(context: click.Context):
268276
)
269277

270278

271-
def scan_sca_commit_range(context: click.Context, path: str, commit_range: str):
279+
def scan_sca_commit_range(context: click.Context, path: str, commit_range: str) -> None:
272280
progress_bar = context.obj['progress_bar']
273281

274282
scan_parameters = get_scan_parameters(context, path)
@@ -282,12 +290,10 @@ def scan_sca_commit_range(context: click.Context, path: str, commit_range: str):
282290
path, from_commit_documents, from_commit_rev, to_commit_documents, to_commit_rev
283291
)
284292

285-
return scan_commit_range_documents(
286-
context, from_commit_documents, to_commit_documents, scan_parameters=scan_parameters
287-
)
293+
scan_commit_range_documents(context, from_commit_documents, to_commit_documents, scan_parameters=scan_parameters)
288294

289295

290-
def scan_disk_files(context: click.Context, path: str, files_to_scan: List[str]):
296+
def scan_disk_files(context: click.Context, path: str, files_to_scan: List[str]) -> None:
291297
scan_parameters = get_scan_parameters(context, path)
292298
scan_type = context.obj['scan_type']
293299
progress_bar = context.obj['progress_bar']
@@ -307,7 +313,7 @@ def scan_disk_files(context: click.Context, path: str, files_to_scan: List[str])
307313
continue
308314

309315
perform_pre_scan_documents_actions(context, scan_type, documents, is_git_diff)
310-
return scan_documents(context, documents, is_git_diff=is_git_diff, scan_parameters=scan_parameters)
316+
scan_documents(context, documents, is_git_diff=is_git_diff, scan_parameters=scan_parameters)
311317

312318

313319
def set_issue_detected_by_scan_results(context: click.Context, scan_results: List[LocalScanResult]) -> None:
@@ -757,19 +763,21 @@ def get_oldest_unupdated_commit_for_branch(commit: str) -> Optional[str]:
757763
return commits[0]
758764

759765

760-
def get_diff_file_path(file):
766+
def get_diff_file_path(file: 'Diff') -> Optional[str]:
761767
return file.b_path if file.b_path else file.a_path
762768

763769

764-
def get_diff_file_content(file):
770+
def get_diff_file_content(file: 'Diff') -> str:
765771
return file.diff.decode('UTF-8', errors='replace')
766772

767773

768-
def should_process_git_object(obj, _: int) -> bool:
774+
def should_process_git_object(obj: 'Blob', _: int) -> bool:
769775
return obj.type == 'blob' and obj.size > 0
770776

771777

772-
def get_git_repository_tree_file_entries(path: str, branch: str):
778+
def get_git_repository_tree_file_entries(
779+
path: str, branch: str
780+
) -> Union[Iterator['IndexObjUnion'], Iterator['TraversedTreeTup']]:
773781
return Repo(path).tree(branch).traverse(predicate=should_process_git_object)
774782

775783

@@ -867,7 +875,7 @@ def _exclude_detections_by_scan_type(
867875
return detections
868876

869877

870-
def exclude_detections_in_deleted_lines(detections) -> List:
878+
def exclude_detections_in_deleted_lines(detections: List[Detection]) -> List[Detection]:
871879
return [detection for detection in detections if detection.detection_details.get('line_type') != 'Removed']
872880

873881

@@ -969,7 +977,7 @@ def _should_exclude_detection(detection: Detection, exclusions: Dict) -> bool:
969977
return False
970978

971979

972-
def _is_detection_sha_configured_in_exclusions(detection, exclusions: List[str]) -> bool:
980+
def _is_detection_sha_configured_in_exclusions(detection: Detection, exclusions: List[str]) -> bool:
973981
detection_sha = detection.detection_details.get('sha512', '')
974982
return detection_sha in exclusions
975983

0 commit comments

Comments
 (0)