Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion safety/scan/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from enum import Enum
import logging
from pathlib import Path
from urllib.parse import urlparse, parse_qs, urlencode, urlunparse

import json
import sys
Expand Down Expand Up @@ -282,7 +283,14 @@ def process_report(

# Append the branch name if available
if branch_name:
project_url_with_branch = f"{project_url}?branch={branch_name}"
# Parse the URL to handle existing query parameters properly
parsed_url = urlparse(project_url)
query_params = parse_qs(parsed_url.query)
query_params["branch"] = [branch_name]
new_query = urlencode(query_params, doseq=True)
project_url_with_branch = urlunparse(
parsed_url._replace(query=new_query)
)
else:
project_url_with_branch = project_url

Expand Down
81 changes: 75 additions & 6 deletions tests/scan/test_command.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import unittest
import tempfile
from urllib.parse import urlparse, parse_qs

from click.testing import CliRunner
from safety.auth.models import Auth
from safety.cli import cli
from safety.console import main_console as console
from unittest.mock import patch

class TestScanCommand(unittest.TestCase):

class TestScanCommand(unittest.TestCase):
def setUp(self):
self.runner = CliRunner(mix_stderr=False)
self.target = tempfile.mkdtemp()
Expand All @@ -20,14 +21,82 @@ def setUp(self):
cli.commands = cli.all_commands
self.cli = cli

@patch.object(Auth, 'is_valid', return_value=False)
@patch('safety.auth.utils.SafetyAuthSession.get_authentication_type', return_value="unauthenticated")
@patch.object(Auth, "is_valid", return_value=False)
@patch(
"safety.auth.utils.SafetyAuthSession.get_authentication_type",
return_value="unauthenticated",
)
def test_scan(self, mock_is_valid, mock_get_auth_type):
result = self.runner.invoke(self.cli, ["scan", "--target", self.target, "--output", "json"])
result = self.runner.invoke(
self.cli, ["scan", "--target", self.target, "--output", "json"]
)
self.assertEqual(result.exit_code, 1)

result = self.runner.invoke(self.cli, ["--stage", "production", "scan", "--target", self.target, "--output", "json"])
result = self.runner.invoke(
self.cli,
[
"--stage",
"production",
"scan",
"--target",
self.target,
"--output",
"json",
],
)
self.assertEqual(result.exit_code, 1)

result = self.runner.invoke(self.cli, ["--stage", "cicd", "scan", "--target", self.target, "--output", "screen"])
result = self.runner.invoke(
self.cli,
["--stage", "cicd", "scan", "--target", self.target, "--output", "screen"],
)
self.assertEqual(result.exit_code, 1)

def test_url_parameter_handling(self):
"""Test that branch parameters are properly added to URLs with existing query parameters."""
from urllib.parse import urlencode, urlunparse

# Test cases: (input_url, branch_name, expected_url)
test_cases = [
# URL without existing parameters
(
"https://platform.safetycli.com/project/test",
"master",
"https://platform.safetycli.com/project/test?branch=master",
),
# URL with existing parameters
(
"https://platform.safetycli.com/project/test?env=prod",
"feature-branch",
"https://platform.safetycli.com/project/test?env=prod&branch=feature-branch",
),
# URL with multiple existing parameters
(
"https://platform.safetycli.com/project/test?env=prod&org=myorg",
"main",
"https://platform.safetycli.com/project/test?env=prod&org=myorg&branch=main",
),
]

for input_url, branch_name, expected_url in test_cases:
with self.subTest(input_url=input_url, branch_name=branch_name):
# This is the same logic as in scan/command.py lines 287-291
parsed_url = urlparse(input_url)
query_params = parse_qs(parsed_url.query)
query_params["branch"] = [branch_name]
new_query = urlencode(query_params, doseq=True)
result_url = urlunparse(parsed_url._replace(query=new_query))

# Parse both URLs to compare query parameters (order might differ)
expected_parsed = urlparse(expected_url)
result_parsed = urlparse(result_url)

# Check that base URL is the same
self.assertEqual(result_parsed.scheme, expected_parsed.scheme)
self.assertEqual(result_parsed.netloc, expected_parsed.netloc)
self.assertEqual(result_parsed.path, expected_parsed.path)

# Check query parameters
expected_params = parse_qs(expected_parsed.query)
result_params = parse_qs(result_parsed.query)
self.assertEqual(result_params, expected_params)
Loading