diff --git a/src/gitingest/__init__.py b/src/gitingest/__init__.py index 0248ad0e..75f3ea41 100644 --- a/src/gitingest/__init__.py +++ b/src/gitingest/__init__.py @@ -1,8 +1,5 @@ """Gitingest: A package for ingesting data from Git repositories.""" -from gitingest.clone import clone_repo from gitingest.entrypoint import ingest, ingest_async -from gitingest.ingestion import ingest_query -from gitingest.query_parser import parse_query -__all__ = ["clone_repo", "ingest", "ingest_async", "ingest_query", "parse_query"] +__all__ = ["ingest", "ingest_async"] diff --git a/src/gitingest/clone.py b/src/gitingest/clone.py index 1f091486..6ccf599b 100644 --- a/src/gitingest/clone.py +++ b/src/gitingest/clone.py @@ -8,13 +8,15 @@ from gitingest.config import DEFAULT_TIMEOUT from gitingest.utils.git_utils import ( check_repo_exists, + checkout_partial_clone, create_git_auth_header, create_git_command, ensure_git_installed, is_github_host, + resolve_commit, run_command, ) -from gitingest.utils.os_utils import ensure_directory +from gitingest.utils.os_utils import ensure_directory_exists_or_create from gitingest.utils.timeout_wrapper import async_timeout if TYPE_CHECKING: @@ -45,71 +47,42 @@ async def clone_repo(config: CloneConfig, *, token: str | None = None) -> None: # Extract and validate query parameters url: str = config.url local_path: str = config.local_path - commit: str | None = config.commit - branch: str | None = config.branch - tag: str | None = config.tag partial_clone: bool = config.subpath != "/" - # Create parent directory if it doesn't exist - await ensure_directory(Path(local_path).parent) + await ensure_git_installed() + await ensure_directory_exists_or_create(Path(local_path).parent) - # Check if the repository exists if not await check_repo_exists(url, token=token): msg = "Repository not found. Make sure it is public or that you have provided a valid token." raise ValueError(msg) + commit = await resolve_commit(config, token=token) + clone_cmd = ["git"] if token and is_github_host(url): clone_cmd += ["-c", create_git_auth_header(token, url=url)] - clone_cmd += ["clone", "--single-branch"] - - if config.include_submodules: - clone_cmd += ["--recurse-submodules"] - + clone_cmd += ["clone", "--single-branch", "--no-checkout", "--depth=1"] if partial_clone: clone_cmd += ["--filter=blob:none", "--sparse"] - # Shallow clone unless a specific commit is requested - if not commit: - clone_cmd += ["--depth=1"] - - # Prefer tag over branch when both are provided - if tag: - clone_cmd += ["--branch", tag] - elif branch and branch.lower() not in ("main", "master"): - clone_cmd += ["--branch", branch] - clone_cmd += [url, local_path] # Clone the repository - await ensure_git_installed() await run_command(*clone_cmd) # Checkout the subpath if it is a partial clone if partial_clone: - await _checkout_partial_clone(config, token) + await checkout_partial_clone(config, token=token) - # Checkout the commit if it is provided - if commit: - checkout_cmd = create_git_command(["git"], local_path, url, token) - await run_command(*checkout_cmd, "checkout", commit) + git = create_git_command(["git"], local_path, url, token) + # Ensure the commit is locally available + await run_command(*git, "fetch", "--depth=1", "origin", commit) -async def _checkout_partial_clone(config: CloneConfig, token: str | None) -> None: - """Configure sparse-checkout for a partially cloned repository. 
+ # Write the work-tree at that commit + await run_command(*git, "checkout", commit) - Parameters - ---------- - config : CloneConfig - The configuration for cloning the repository, including subpath and blob flag. - token : str | None - GitHub personal access token (PAT) for accessing private repositories. - - """ - subpath = config.subpath.lstrip("/") - if config.blob: - # Remove the file name from the subpath when ingesting from a file url (e.g. blob/branch/path/file.txt) - subpath = str(Path(subpath).parent.as_posix()) - checkout_cmd = create_git_command(["git"], config.local_path, config.url, token) - await run_command(*checkout_cmd, "sparse-checkout", "set", subpath) + # Update submodules + if config.include_submodules: + await run_command(*git, "submodule", "update", "--init", "--recursive", "--depth=1") diff --git a/src/gitingest/entrypoint.py b/src/gitingest/entrypoint.py index f64dec08..ac7eaf8f 100644 --- a/src/gitingest/entrypoint.py +++ b/src/gitingest/entrypoint.py @@ -8,14 +8,21 @@ import warnings from contextlib import asynccontextmanager from pathlib import Path -from typing import AsyncGenerator +from typing import TYPE_CHECKING, AsyncGenerator +from urllib.parse import urlparse from gitingest.clone import clone_repo from gitingest.config import MAX_FILE_SIZE from gitingest.ingestion import ingest_query -from gitingest.query_parser import IngestionQuery, parse_query +from gitingest.query_parser import parse_local_dir_path, parse_remote_repo from gitingest.utils.auth import resolve_token +from gitingest.utils.compat_func import removesuffix from gitingest.utils.ignore_patterns import load_ignore_patterns +from gitingest.utils.pattern_utils import process_patterns +from gitingest.utils.query_parser_utils import KNOWN_GIT_HOSTS + +if TYPE_CHECKING: + from gitingest.schemas import IngestionQuery async def ingest_async( @@ -74,23 +81,28 @@ async def ingest_async( """ token = resolve_token(token) - query: IngestionQuery = await parse_query( - source=source, - max_file_size=max_file_size, - from_web=False, + source = removesuffix(source.strip(), ".git") + + # Determine the parsing method based on the source type + if urlparse(source).scheme in ("https", "http") or any(h in source for h in KNOWN_GIT_HOSTS): + # We either have a full URL or a domain-less slug + query = await parse_remote_repo(source, token=token) + query.include_submodules = include_submodules + _override_branch_and_tag(query, branch=branch, tag=tag) + + else: + # Local path scenario + query = parse_local_dir_path(source) + + query.max_file_size = max_file_size + query.ignore_patterns, query.include_patterns = process_patterns( + exclude_patterns=exclude_patterns, include_patterns=include_patterns, - ignore_patterns=exclude_patterns, - token=token, ) if not include_gitignored: _apply_gitignores(query) - if query.url: - _override_branch_and_tag(query, branch=branch, tag=tag) - - query.include_submodules = include_submodules - async with _clone_repo_if_remote(query, token=token): summary, tree, content = ingest_query(query) await _write_output(tree, content=content, target=output) diff --git a/src/gitingest/ingestion.py b/src/gitingest/ingestion.py index 2990a875..489a41a4 100644 --- a/src/gitingest/ingestion.py +++ b/src/gitingest/ingestion.py @@ -11,7 +11,7 @@ from gitingest.utils.ingestion_utils import _should_exclude, _should_include if TYPE_CHECKING: - from gitingest.query_parser import IngestionQuery + from gitingest.schemas import IngestionQuery def ingest_query(query: IngestionQuery) -> tuple[str, str, str]: 
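
With the entrypoint refactor above, ``ingest_async`` now decides for itself whether ``source`` is a remote URL/slug (parsed via ``parse_remote_repo`` and cloned) or a local directory (parsed via ``parse_local_dir_path``). A minimal usage sketch, assuming ``ingest_async`` still returns the ``(summary, tree, content)`` tuple produced by ``ingest_query``:

import asyncio

from gitingest import ingest_async


async def main() -> None:
    # Remote URL or "user/repo" slug -> parse_remote_repo() + clone; local path -> parse_local_dir_path()
    summary, tree, content = await ingest_async(
        "https://github.com/user/repo",
        include_patterns="*.py",  # forwarded to process_patterns() together with exclude_patterns
    )
    print(summary)


asyncio.run(main())
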
diff --git a/src/gitingest/output_formatter.py b/src/gitingest/output_formatter.py index 94bbee62..27ad10ae 100644 --- a/src/gitingest/output_formatter.py +++ b/src/gitingest/output_formatter.py @@ -10,7 +10,7 @@ from gitingest.utils.compat_func import readlink if TYPE_CHECKING: - from gitingest.query_parser import IngestionQuery + from gitingest.schemas import IngestionQuery _TOKEN_THRESHOLDS: list[tuple[int, str]] = [ (1_000_000, "M"), @@ -84,6 +84,8 @@ def _create_summary_prefix(query: IngestionQuery, *, single_file: bool = False) if query.commit: parts.append(f"Commit: {query.commit}") + elif query.tag: + parts.append(f"Tag: {query.tag}") elif query.branch and query.branch not in ("main", "master"): parts.append(f"Branch: {query.branch}") diff --git a/src/gitingest/query_parser.py b/src/gitingest/query_parser.py index 5fabb226..0626fb8a 100644 --- a/src/gitingest/query_parser.py +++ b/src/gitingest/query_parser.py @@ -2,7 +2,6 @@ from __future__ import annotations -import re import uuid import warnings from pathlib import Path @@ -10,91 +9,18 @@ from gitingest.config import TMP_BASE_PATH from gitingest.schemas import IngestionQuery -from gitingest.utils.exceptions import InvalidPatternError from gitingest.utils.git_utils import check_repo_exists, fetch_remote_branches_or_tags -from gitingest.utils.ignore_patterns import DEFAULT_IGNORE_PATTERNS from gitingest.utils.query_parser_utils import ( KNOWN_GIT_HOSTS, _get_user_and_repo_from_path, _is_valid_git_commit_hash, - _is_valid_pattern, _validate_host, _validate_url_scheme, ) -async def parse_query( - source: str, - *, - max_file_size: int, - from_web: bool, - include_patterns: set[str] | str | None = None, - ignore_patterns: set[str] | str | None = None, - token: str | None = None, -) -> IngestionQuery: - """Parse the input source to extract details for the query and process the include and ignore patterns. - - Parameters - ---------- - source : str - The source URL or file path to parse. - max_file_size : int - The maximum file size in bytes to include. - from_web : bool - Flag indicating whether the source is a web URL. - include_patterns : set[str] | str | None - Patterns to include. Can be a set of strings or a single string. - ignore_patterns : set[str] | str | None - Patterns to ignore. Can be a set of strings or a single string. - token : str | None - GitHub personal access token (PAT) for accessing private repositories. - - Returns - ------- - IngestionQuery - A dataclass object containing the parsed details of the repository or file path. 
- - """ - # Determine the parsing method based on the source type - if from_web or urlparse(source).scheme in ("https", "http") or any(h in source for h in KNOWN_GIT_HOSTS): - # We either have a full URL or a domain-less slug - query = await _parse_remote_repo(source, token=token) - else: - # Local path scenario - query = _parse_local_dir_path(source) - - # Combine default ignore patterns + custom patterns - ignore_patterns_set = DEFAULT_IGNORE_PATTERNS.copy() - if ignore_patterns: - ignore_patterns_set.update(_parse_patterns(ignore_patterns)) - - # Process include patterns and override ignore patterns accordingly - if include_patterns: - parsed_include = _parse_patterns(include_patterns) - # Override ignore patterns with include patterns - ignore_patterns_set = set(ignore_patterns_set) - set(parsed_include) - else: - parsed_include = None - - return IngestionQuery( - user_name=query.user_name, - repo_name=query.repo_name, - url=query.url, - subpath=query.subpath, - local_path=query.local_path, - slug=query.slug, - id=query.id, - type=query.type, - branch=query.branch, - commit=query.commit, - max_file_size=max_file_size, - ignore_patterns=ignore_patterns_set, - include_patterns=parsed_include, - ) - - -async def _parse_remote_repo(source: str, token: str | None = None) -> IngestionQuery: - """Parse a repository URL into a structured query dictionary. +async def parse_remote_repo(source: str, token: str | None = None) -> IngestionQuery: + """Parse a repository URL and return an ``IngestionQuery`` object. If source is: - A fully qualified URL ('https://gitlab.com/...'), parse & verify that domain @@ -143,7 +69,8 @@ async def _parse_remote_repo(source: str, token: str | None = None) -> Ingestion local_path = TMP_BASE_PATH / _id / slug url = f"https://{host}/{user_name}/{repo_name}" - parsed = IngestionQuery( + query = IngestionQuery( + host=host, user_name=user_name, repo_name=repo_name, url=url, @@ -155,37 +82,37 @@ async def _parse_remote_repo(source: str, token: str | None = None) -> Ingestion remaining_parts = parsed_url.path.strip("/").split("/")[2:] if not remaining_parts: - return parsed + return query possible_type = remaining_parts.pop(0) # e.g. 'issues', 'pull', 'tree', 'blob' # If no extra path parts, just return if not remaining_parts: - return parsed + return query # If this is an issues page or pull requests, return early without processing subpath # TODO: Handle issues and pull requests if remaining_parts and possible_type in {"issues", "pull"}: msg = f"Warning: Issues and pull requests are not yet supported: {url}. Returning repository root." warnings.warn(msg, RuntimeWarning, stacklevel=2) - return parsed + return query if possible_type not in {"tree", "blob"}: # TODO: Handle other types msg = f"Warning: Type '{possible_type}' is not yet supported: {url}. Returning repository root." 
warnings.warn(msg, RuntimeWarning, stacklevel=2) - return parsed + return query - parsed.type = possible_type # 'tree' or 'blob' + query.type = possible_type # Commit, branch, or tag commit_or_branch_or_tag = remaining_parts[0] if _is_valid_git_commit_hash(commit_or_branch_or_tag): # Commit - parsed.commit = commit_or_branch_or_tag + query.commit = commit_or_branch_or_tag remaining_parts.pop(0) # Consume the commit hash else: # Branch or tag # Try to resolve a tag - parsed.tag = await _configure_branch_or_tag( + query.tag = await _configure_branch_or_tag( remaining_parts, url=url, ref_type="tags", @@ -193,8 +120,8 @@ async def _parse_remote_repo(source: str, token: str | None = None) -> Ingestion ) # If no tag found, try to resolve a branch - if not parsed.tag: - parsed.branch = await _configure_branch_or_tag( + if not query.tag: + query.branch = await _configure_branch_or_tag( remaining_parts, url=url, ref_type="branches", @@ -202,10 +129,29 @@ async def _parse_remote_repo(source: str, token: str | None = None) -> Ingestion ) # Only configure subpath if we have identified a commit, branch, or tag. - if remaining_parts and (parsed.commit or parsed.branch or parsed.tag): - parsed.subpath += "/".join(remaining_parts) + if remaining_parts and (query.commit or query.branch or query.tag): + query.subpath += "/".join(remaining_parts) + + return query - return parsed + +def parse_local_dir_path(path_str: str) -> IngestionQuery: + """Parse the given file path into a structured query dictionary. + + Parameters + ---------- + path_str : str + The file path to parse. + + Returns + ------- + IngestionQuery + A dictionary containing the parsed details of the file path. + + """ + path_obj = Path(path_str).resolve() + slug = path_obj.name if path_str == "." else path_str.strip("/") + return IngestionQuery(local_path=path_obj, slug=slug, id=str(uuid.uuid4())) async def _configure_branch_or_tag( @@ -269,69 +215,6 @@ async def _configure_branch_or_tag( return None -def _parse_patterns(pattern: set[str] | str) -> set[str]: - """Parse and validate file/directory patterns for inclusion or exclusion. - - Takes either a single pattern string or set of pattern strings and processes them into a normalized list. - Patterns are split on commas and spaces, validated for allowed characters, and normalized. - - Parameters - ---------- - pattern : set[str] | str - Pattern(s) to parse - either a single string or set of strings - - Returns - ------- - set[str] - A set of normalized patterns. - - Raises - ------ - InvalidPatternError - If any pattern contains invalid characters. Only alphanumeric characters, - dash (-), underscore (_), dot (.), forward slash (/), plus (+), and - asterisk (*) are allowed. - - """ - patterns = pattern if isinstance(pattern, set) else {pattern} - - parsed_patterns: set[str] = set() - for p in patterns: - parsed_patterns = parsed_patterns.union(set(re.split(",| ", p))) - - # Remove empty string if present - parsed_patterns = parsed_patterns - {""} - - # Normalize Windows paths to Unix-style paths - parsed_patterns = {p.replace("\\", "/") for p in parsed_patterns} - - # Validate and normalize each pattern - for p in parsed_patterns: - if not _is_valid_pattern(p): - raise InvalidPatternError(p) - - return parsed_patterns - - -def _parse_local_dir_path(path_str: str) -> IngestionQuery: - """Parse the given file path into a structured query dictionary. - - Parameters - ---------- - path_str : str - The file path to parse. 
- - Returns - ------- - IngestionQuery - A dictionary containing the parsed details of the file path. - - """ - path_obj = Path(path_str).resolve() - slug = path_obj.name if path_str == "." else path_str.strip("/") - return IngestionQuery(local_path=path_obj, slug=slug, id=str(uuid.uuid4())) - - async def try_domains_for_user_and_repo(user_name: str, repo_name: str, token: str | None = None) -> str: """Attempt to find a valid repository host for the given ``user_name`` and ``repo_name``. diff --git a/src/gitingest/schemas/ingestion.py b/src/gitingest/schemas/ingestion.py index c40e11d6..9aa1c818 100644 --- a/src/gitingest/schemas/ingestion.py +++ b/src/gitingest/schemas/ingestion.py @@ -53,6 +53,8 @@ class IngestionQuery(BaseModel): # pylint: disable=too-many-instance-attributes Attributes ---------- + host : str | None + The host of the repository. user_name : str | None The username or owner of the repository. repo_name : str | None @@ -86,6 +88,7 @@ class IngestionQuery(BaseModel): # pylint: disable=too-many-instance-attributes """ + host: str | None = None user_name: str | None = None repo_name: str | None = None local_path: Path diff --git a/src/gitingest/utils/git_utils.py b/src/gitingest/utils/git_utils.py index f4215ca4..a094e944 100644 --- a/src/gitingest/utils/git_utils.py +++ b/src/gitingest/utils/git_utils.py @@ -6,7 +6,8 @@ import base64 import re import sys -from typing import Final +from pathlib import Path +from typing import TYPE_CHECKING, Final, Iterable from urllib.parse import urlparse import httpx @@ -16,6 +17,9 @@ from gitingest.utils.exceptions import InvalidGitHubTokenError from server.server_utils import Colors +if TYPE_CHECKING: + from gitingest.schemas import CloneConfig + # GitHub Personal-Access tokens (classic + fine-grained). # - ghp_ / gho_ / ghu_ / ghs_ / ghr_ → 36 alphanumerics # - github_pat_ → 22 alphanumerics + "_" + 59 alphanumerics @@ -237,7 +241,6 @@ async def fetch_remote_branches_or_tags(url: str, *, ref_type: str, token: str | await ensure_git_installed() stdout, _ = await run_command(*cmd) - # For each line in the output: # - Skip empty lines and lines that don't contain "refs/{to_fetch}/" # - Extract the branch or tag name after "refs/{to_fetch}/" @@ -321,3 +324,126 @@ def validate_github_token(token: str) -> None: """ if not re.fullmatch(_GITHUB_PAT_PATTERN, token): raise InvalidGitHubTokenError + + +async def checkout_partial_clone(config: CloneConfig, token: str | None) -> None: + """Configure sparse-checkout for a partially cloned repository. + + Parameters + ---------- + config : CloneConfig + The configuration for cloning the repository, including subpath and blob flag. + token : str | None + GitHub personal access token (PAT) for accessing private repositories. + + """ + subpath = config.subpath.lstrip("/") + if config.blob: + # Remove the file name from the subpath when ingesting from a file url (e.g. blob/branch/path/file.txt) + subpath = str(Path(subpath).parent.as_posix()) + checkout_cmd = create_git_command(["git"], config.local_path, config.url, token) + await run_command(*checkout_cmd, "sparse-checkout", "set", subpath) + + +async def resolve_commit(config: CloneConfig, token: str | None) -> str: + """Resolve the commit to use for the clone. + + Parameters + ---------- + config : CloneConfig + The configuration for cloning the repository. + token : str | None + GitHub personal access token (PAT) for accessing private repositories. + + Returns + ------- + str + The commit SHA. 
+
+    """
+    if config.commit:
+        commit = config.commit
+    elif config.tag:
+        commit = await _resolve_ref_to_sha(config.url, pattern=f"refs/tags/{config.tag}*", token=token)
+    elif config.branch:
+        commit = await _resolve_ref_to_sha(config.url, pattern=f"refs/heads/{config.branch}", token=token)
+    else:
+        commit = await _resolve_ref_to_sha(config.url, pattern="HEAD", token=token)
+    return commit
+
+
+async def _resolve_ref_to_sha(url: str, pattern: str, token: str | None = None) -> str:
+    """Return the commit SHA that the ref matching ``pattern`` points to in ``url``.
+
+    * Branch → first line from ``git ls-remote``.
+    * Tag → if annotated, prefer the peeled ``^{}`` line (commit).
+
+    Parameters
+    ----------
+    url : str
+        The URL of the remote repository.
+    pattern : str
+        The pattern to use to resolve the commit SHA.
+    token : str | None
+        GitHub personal access token (PAT) for accessing private repositories.
+
+    Returns
+    -------
+    str
+        The commit SHA.
+
+    Raises
+    ------
+    ValueError
+        If the ref does not exist in the remote repository.
+
+    """
+    # Build: git [-c http.<url>/.extraheader=Auth...] ls-remote <url> <pattern>
+    cmd: list[str] = ["git"]
+    if token and is_github_host(url):
+        cmd += ["-c", create_git_auth_header(token, url=url)]
+
+    cmd += ["ls-remote", url, pattern]
+    stdout, _ = await run_command(*cmd)
+    lines = stdout.decode().splitlines()
+    sha = _pick_commit_sha(lines)
+    if not sha:
+        msg = f"{pattern!r} not found in {url}"
+        raise ValueError(msg)
+
+    return sha
+
+
+def _pick_commit_sha(lines: Iterable[str]) -> str | None:
+    """Return a commit SHA from ``git ls-remote`` output.
+
+    • Annotated tag → prefer the peeled line (``<sha> refs/tags/x^{}``)
+    • Branch / lightweight tag → first non-peeled line
+
+    Parameters
+    ----------
+    lines : Iterable[str]
+        The lines of a ``git ls-remote`` output.
+
+    Returns
+    -------
+    str | None
+        The commit SHA, or ``None`` if no commit SHA is found.
+
+    """
+    first_non_peeled: str | None = None
+
+    for ln in lines:
+        if not ln.strip():
+            continue
+
+        sha, ref = ln.split(maxsplit=1)
+
+        if ref.endswith("^{}"):  # peeled commit of annotated tag
+            return sha  # ← best match, done
+
+        if first_non_peeled is None:  # remember the first ordinary line
+            first_non_peeled = sha
+
+    return first_non_peeled  # branch or lightweight tag (or None)
diff --git a/src/gitingest/utils/os_utils.py b/src/gitingest/utils/os_utils.py
index d90dddd2..e9c3b3e4 100644
--- a/src/gitingest/utils/os_utils.py
+++ b/src/gitingest/utils/os_utils.py
@@ -3,7 +3,7 @@
 from pathlib import Path
 
 
-async def ensure_directory(path: Path) -> None:
+async def ensure_directory_exists_or_create(path: Path) -> None:
     """Ensure the directory exists, creating it if necessary.
 
     Parameters
diff --git a/src/gitingest/utils/pattern_utils.py b/src/gitingest/utils/pattern_utils.py
new file mode 100644
index 00000000..9c555873
--- /dev/null
+++ b/src/gitingest/utils/pattern_utils.py
@@ -0,0 +1,108 @@
+"""Pattern utilities for the Gitingest package."""
+
+from __future__ import annotations
+
+import re
+
+from gitingest.utils.exceptions import InvalidPatternError
+from gitingest.utils.ignore_patterns import DEFAULT_IGNORE_PATTERNS
+
+
+def process_patterns(
+    exclude_patterns: str | set[str] | None = None,
+    include_patterns: str | set[str] | None = None,
+) -> tuple[set[str], set[str] | None]:
+    """Process include and exclude patterns.
+
+    Parameters
+    ----------
+    exclude_patterns : str | set[str] | None
+        Exclude patterns to process.
+    include_patterns : str | set[str] | None
+        Include patterns to process.
+ + Returns + ------- + tuple[set[str], set[str] | None] + A tuple containing the processed ignore patterns and include patterns. + + """ + # Combine default ignore patterns + custom patterns + ignore_patterns_set = DEFAULT_IGNORE_PATTERNS.copy() + if exclude_patterns: + ignore_patterns_set.update(_parse_patterns(exclude_patterns)) + + # Process include patterns and override ignore patterns accordingly + if include_patterns: + parsed_include = _parse_patterns(include_patterns) + # Override ignore patterns with include patterns + ignore_patterns_set = set(ignore_patterns_set) - set(parsed_include) + else: + parsed_include = None + + return ignore_patterns_set, parsed_include + + +def _parse_patterns(pattern: set[str] | str) -> set[str]: + """Parse and validate file/directory patterns for inclusion or exclusion. + + Takes either a single pattern string or set of pattern strings and processes them into a normalized list. + Patterns are split on commas and spaces, validated for allowed characters, and normalized. + + Parameters + ---------- + pattern : set[str] | str + Pattern(s) to parse - either a single string or set of strings + + Returns + ------- + set[str] + A set of normalized patterns. + + Raises + ------ + InvalidPatternError + If any pattern contains invalid characters. Only alphanumeric characters, + dash (-), underscore (_), dot (.), forward slash (/), plus (+), and + asterisk (*) are allowed. + + """ + patterns = pattern if isinstance(pattern, set) else {pattern} + + parsed_patterns: set[str] = set() + for p in patterns: + parsed_patterns = parsed_patterns.union(set(re.split(",| ", p))) + + # Remove empty string if present + parsed_patterns = parsed_patterns - {""} + + # Normalize Windows paths to Unix-style paths + parsed_patterns = {p.replace("\\", "/") for p in parsed_patterns} + + # Validate and normalize each pattern + for p in parsed_patterns: + if not _is_valid_pattern(p): + raise InvalidPatternError(p) + + return parsed_patterns + + +def _is_valid_pattern(pattern: str) -> bool: + """Validate if the given pattern contains only valid characters. + + This function checks if the pattern contains only alphanumeric characters or one + of the following allowed characters: dash ('-'), underscore ('_'), dot ('.'), + forward slash ('/'), plus ('+'), asterisk ('*'), or the at sign ('@'). + + Parameters + ---------- + pattern : str + The pattern to validate. + + Returns + ------- + bool + ``True`` if the pattern is valid, otherwise ``False``. + + """ + return all(c.isalnum() or c in "-_./+*@" for c in pattern) diff --git a/src/gitingest/utils/query_parser_utils.py b/src/gitingest/utils/query_parser_utils.py index 4bde02cc..881f46ea 100644 --- a/src/gitingest/utils/query_parser_utils.py +++ b/src/gitingest/utils/query_parser_utils.py @@ -38,27 +38,6 @@ def _is_valid_git_commit_hash(commit: str) -> bool: return len(commit) == sha_hex_length and all(c in HEX_DIGITS for c in commit) -def _is_valid_pattern(pattern: str) -> bool: - """Validate if the given pattern contains only valid characters. - - This function checks if the pattern contains only alphanumeric characters or one - of the following allowed characters: dash ('-'), underscore ('_'), dot ('.'), - forward slash ('/'), plus ('+'), asterisk ('*'), or the at sign ('@'). - - Parameters - ---------- - pattern : str - The pattern to validate. - - Returns - ------- - bool - ``True`` if the pattern is valid, otherwise ``False``. 
- - """ - return all(c.isalnum() or c in "-_./+*@" for c in pattern) - - def _validate_host(host: str) -> None: """Validate a hostname. diff --git a/src/server/models.py b/src/server/models.py index a6e71edc..1ed95710 100644 --- a/src/server/models.py +++ b/src/server/models.py @@ -7,6 +7,8 @@ from pydantic import BaseModel, Field, field_validator +from gitingest.utils.compat_func import removesuffix + # needed for type checking (pydantic) from server.form_types import IntForm, OptStrForm, StrForm # noqa: TC001 (typing-only-first-party-import) @@ -45,16 +47,16 @@ class IngestRequest(BaseModel): @field_validator("input_text") @classmethod def validate_input_text(cls, v: str) -> str: - """Validate that input_text is not empty.""" + """Validate that ``input_text`` is not empty.""" if not v.strip(): err = "input_text cannot be empty" raise ValueError(err) - return v.strip() + return removesuffix(v.strip(), ".git") @field_validator("pattern") @classmethod def validate_pattern(cls, v: str) -> str: - """Validate pattern field.""" + """Validate ``pattern`` field.""" return v.strip() diff --git a/src/server/query_processor.py b/src/server/query_processor.py index 8513426b..a7b60f61 100644 --- a/src/server/query_processor.py +++ b/src/server/query_processor.py @@ -7,9 +7,10 @@ from gitingest.clone import clone_repo from gitingest.ingestion import ingest_query -from gitingest.query_parser import IngestionQuery, parse_query +from gitingest.query_parser import parse_remote_repo from gitingest.utils.git_utils import validate_github_token -from server.models import IngestErrorResponse, IngestResponse, IngestSuccessResponse +from gitingest.utils.pattern_utils import process_patterns +from server.models import IngestErrorResponse, IngestResponse, IngestSuccessResponse, PatternType from server.server_config import MAX_DISPLAY_SIZE from server.server_utils import Colors, log_slider_to_size @@ -17,8 +18,8 @@ async def process_query( input_text: str, slider_position: int, - pattern_type: str = "exclude", - pattern: str = "", + pattern_type: PatternType, + pattern: str, token: str | None = None, ) -> IngestResponse: """Process a query by parsing input, cloning a repository, and generating a summary. @@ -32,8 +33,8 @@ async def process_query( Input text provided by the user, typically a Git repository URL or slug. slider_position : int Position of the slider, representing the maximum file size in the query. - pattern_type : str - Type of pattern to use (either "include" or "exclude") (default: ``"exclude"``). + pattern_type : PatternType + Type of pattern to use (either "include" or "exclude") pattern : str Pattern to include or exclude in the query, depending on the pattern type. token : str | None @@ -44,61 +45,42 @@ async def process_query( IngestResponse A union type, corresponding to IngestErrorResponse or IngestSuccessResponse - Raises - ------ - ValueError - If an invalid pattern type is provided. 
- """ - if pattern_type == "include": - include_patterns = pattern - exclude_patterns = None - elif pattern_type == "exclude": - exclude_patterns = pattern - include_patterns = None - else: - msg = f"Invalid pattern type: {pattern_type}" - raise ValueError(msg) - if token: validate_github_token(token) max_file_size = log_slider_to_size(slider_position) - query: IngestionQuery | None = None - short_repo_url = "" - try: - query = await parse_query( - source=input_text, - max_file_size=max_file_size, - from_web=True, - include_patterns=include_patterns, - ignore_patterns=exclude_patterns, - token=token, - ) - query.ensure_url() + query = await parse_remote_repo(input_text, token=token) + except Exception as exc: + print(f"{Colors.BROWN}WARN{Colors.END}: {Colors.RED}<- {Colors.END}", end="") + print(f"{Colors.RED}{exc}{Colors.END}") + return IngestErrorResponse(error=str(exc)) - # Sets the "/" for the page title - short_repo_url = f"{query.user_name}/{query.repo_name}" + query.url = cast("str", query.url) + query.host = cast("str", query.host) + query.max_file_size = max_file_size + query.ignore_patterns, query.include_patterns = process_patterns( + exclude_patterns=pattern if pattern_type == PatternType.EXCLUDE else None, + include_patterns=pattern if pattern_type == PatternType.INCLUDE else None, + ) + + clone_config = query.extract_clone_config() + await clone_repo(clone_config, token=token) - clone_config = query.extract_clone_config() - await clone_repo(clone_config, token=token) + short_repo_url = f"{query.user_name}/{query.repo_name}" # Sets the "/" for the page title + try: summary, tree, content = ingest_query(query) + # TODO: why are we writing the tree and content to a file here? local_txt_file = Path(clone_config.local_path).with_suffix(".txt") - with local_txt_file.open("w", encoding="utf-8") as f: f.write(tree + "\n" + content) except Exception as exc: - if query and query.url: - _print_error(query.url, exc, max_file_size, pattern_type, pattern) - else: - print(f"{Colors.BROWN}WARN{Colors.END}: {Colors.RED}<- {Colors.END}", end="") - print(f"{Colors.RED}{exc}{Colors.END}") - + _print_error(query.url, exc, max_file_size, pattern_type, pattern) return IngestErrorResponse(error=str(exc)) if len(content) > MAX_DISPLAY_SIZE: @@ -107,9 +89,6 @@ async def process_query( "download full ingest to see more)\n" + content[:MAX_DISPLAY_SIZE] ) - query.ensure_url() - query.url = cast("str", query.url) - _print_success( url=query.url, max_file_size=max_file_size, diff --git a/src/server/routers_utils.py b/src/server/routers_utils.py index 358596fb..665ead8f 100644 --- a/src/server/routers_utils.py +++ b/src/server/routers_utils.py @@ -7,7 +7,7 @@ from fastapi import status from fastapi.responses import JSONResponse -from server.models import IngestErrorResponse, IngestSuccessResponse +from server.models import IngestErrorResponse, IngestSuccessResponse, PatternType from server.query_processor import process_query COMMON_INGEST_RESPONSES: dict[int | str, dict[str, Any]] = { @@ -32,7 +32,7 @@ async def _perform_ingestion( result = await process_query( input_text=input_text, slider_position=max_file_size, - pattern_type=pattern_type, + pattern_type=PatternType(pattern_type), pattern=pattern, token=token, ) diff --git a/tests/conftest.py b/tests/conftest.py index 15e1d2ad..30a73f0c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,7 @@ from __future__ import annotations import json +import sys from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict from 
unittest.mock import AsyncMock @@ -22,6 +23,26 @@ DEMO_URL = "https://github.com/user/repo" LOCAL_REPO_PATH = "/tmp/repo" +DEMO_COMMIT = "deadbeefdeadbeefdeadbeefdeadbeefdeadbeef" + + +def get_ensure_git_installed_call_count() -> int: + """Get the number of calls made by ensure_git_installed based on platform. + + On Windows, ensure_git_installed makes 2 calls: + 1. git --version + 2. git config core.longpaths + + On other platforms, it makes 1 call: + 1. git --version + + Returns + ------- + int + The number of calls made by ensure_git_installed + + """ + return 2 if sys.platform == "win32" else 1 @pytest.fixture @@ -168,11 +189,14 @@ def run_command_mock(mocker: MockerFixture) -> AsyncMock: The mocked function returns a dummy process whose ``communicate`` method yields generic ``stdout`` / ``stderr`` bytes. Tests can still access / tweak the mock via the fixture argument. """ - mock_exec = mocker.patch("gitingest.clone.run_command", new_callable=AsyncMock) + mock = AsyncMock(side_effect=_fake_run_command) + mocker.patch("gitingest.utils.git_utils.run_command", mock) + mocker.patch("gitingest.clone.run_command", mock) + return mock - # Provide a default dummy process so most tests don't have to create one. - dummy_process = AsyncMock() - dummy_process.communicate.return_value = (b"output", b"error") - mock_exec.return_value = dummy_process - return mock_exec +async def _fake_run_command(*args: str) -> tuple[bytes, bytes]: + if "ls-remote" in args: + # single match: refs/heads/main + return (f"{DEMO_COMMIT}\trefs/heads/main\n".encode(), b"") + return (b"output", b"error") diff --git a/tests/query_parser/test_git_host_agnostic.py b/tests/query_parser/test_git_host_agnostic.py index d3d2542a..1288674b 100644 --- a/tests/query_parser/test_git_host_agnostic.py +++ b/tests/query_parser/test_git_host_agnostic.py @@ -8,7 +8,8 @@ import pytest -from gitingest.query_parser import parse_query +from gitingest.config import MAX_FILE_SIZE +from gitingest.query_parser import parse_remote_repo from gitingest.utils.query_parser_utils import KNOWN_GIT_HOSTS # Repository matrix: (host, user, repo) @@ -33,7 +34,7 @@ async def test_parse_query_without_host( repo: str, variant: str, ) -> None: - """Verify that ``parse_query`` handles URLs, host-omitted URLs and raw slugs.""" + """Verify that ``parse_remote_repo`` handles URLs, host-omitted URLs and raw slugs.""" # Build the input URL based on the selected variant if variant == "full": url = f"https://{host}/{user}/{repo}" @@ -48,15 +49,16 @@ async def test_parse_query_without_host( # because the parser cannot guess which domain to use. if variant == "slug" and host not in KNOWN_GIT_HOSTS: with pytest.raises(ValueError, match="Could not find a valid repository host"): - await parse_query(url, max_file_size=50, from_web=True) + await parse_remote_repo(url) return - query = await parse_query(url, max_file_size=50, from_web=True) + query = await parse_remote_repo(url) # Compare against the canonical dict while ignoring unpredictable fields. 
actual = query.model_dump(exclude={"id", "local_path", "ignore_patterns"}) expected = { + "host": host, "user_name": user, "repo_name": repo, "url": expected_url, @@ -65,8 +67,8 @@ async def test_parse_query_without_host( "type": None, "branch": None, "tag": None, + "max_file_size": MAX_FILE_SIZE, "commit": None, - "max_file_size": 50, "include_patterns": None, "include_submodules": False, } diff --git a/tests/query_parser/test_query_parser.py b/tests/query_parser/test_query_parser.py index f6033352..8d3913bb 100644 --- a/tests/query_parser/test_query_parser.py +++ b/tests/query_parser/test_query_parser.py @@ -11,12 +11,11 @@ import pytest -from gitingest.query_parser import _parse_patterns, _parse_remote_repo, parse_query -from gitingest.utils.ignore_patterns import DEFAULT_IGNORE_PATTERNS +from gitingest.query_parser import parse_local_dir_path, parse_remote_repo from tests.conftest import DEMO_URL if TYPE_CHECKING: - from gitingest.schemas.ingestion import IngestionQuery + from gitingest.schemas import IngestionQuery URLS_HTTPS: list[str] = [ @@ -52,83 +51,55 @@ async def test_parse_url_valid_http(url: str) -> None: @pytest.mark.asyncio async def test_parse_url_invalid() -> None: - """Test ``_parse_remote_repo`` with an invalid URL. + """Test ``parse_remote_repo`` with an invalid URL. Given an HTTPS URL lacking a repository structure (e.g., "https://github.com"), - When ``_parse_remote_repo`` is called, + When ``parse_remote_repo`` is called, Then a ValueError should be raised indicating an invalid repository URL. """ url = "https://github.com" with pytest.raises(ValueError, match="Invalid repository URL"): - await _parse_remote_repo(url) + await parse_remote_repo(url) @pytest.mark.asyncio @pytest.mark.parametrize("url", [DEMO_URL, "https://gitlab.com/user/repo"]) async def test_parse_query_basic(url: str) -> None: - """Test ``parse_query`` with a basic valid repository URL. + """Test ``parse_remote_repo`` with a basic valid repository URL. - Given an HTTPS URL and ignore_patterns="*.txt": - When ``parse_query`` is called, - Then user/repo, URL, and ignore patterns should be parsed correctly. + Given an HTTPS URL: + When ``parse_remote_repo`` is called, + Then user/repo, URL should be parsed correctly. """ - query = await parse_query(source=url, max_file_size=50, from_web=True, ignore_patterns="*.txt") + query = await parse_remote_repo(url) assert query.user_name == "user" assert query.repo_name == "repo" assert query.url == url - assert query.ignore_patterns - assert "*.txt" in query.ignore_patterns @pytest.mark.asyncio async def test_parse_query_mixed_case() -> None: - """Test ``parse_query`` with mixed-case URLs. + """Test ``parse_remote_repo`` with mixed-case URLs. Given a URL with mixed-case parts (e.g. "Https://GitHub.COM/UsEr/rEpO"): - When ``parse_query`` is called, + When ``parse_remote_repo`` is called, Then the user and repo names should be normalized to lowercase. """ url = "Https://GitHub.COM/UsEr/rEpO" - query = await parse_query(url, max_file_size=50, from_web=True) + query = await parse_remote_repo(url) assert query.user_name == "user" assert query.repo_name == "repo" -@pytest.mark.asyncio -async def test_parse_query_include_pattern() -> None: - """Test ``parse_query`` with a specified include pattern. - - Given a URL and include_patterns="*.py": - When ``parse_query`` is called, - Then the include pattern should be set, and default ignore patterns remain applied. 
- """ - query = await parse_query(DEMO_URL, max_file_size=50, from_web=True, include_patterns="*.py") - - assert query.include_patterns == {"*.py"} - assert query.ignore_patterns == DEFAULT_IGNORE_PATTERNS - - -@pytest.mark.asyncio -async def test_parse_query_invalid_pattern() -> None: - """Test ``parse_query`` with an invalid pattern. - - Given an include pattern containing special characters (e.g., "*.py;rm -rf"): - When ``parse_query`` is called, - Then a ValueError should be raised indicating invalid characters. - """ - with pytest.raises(ValueError, match="Pattern.*contains invalid characters"): - await parse_query(DEMO_URL, max_file_size=50, from_web=True, include_patterns="*.py;rm -rf") - - @pytest.mark.asyncio async def test_parse_url_with_subpaths(stub_branches: Callable[[list[str]], None]) -> None: - """Test ``_parse_remote_repo`` with a URL containing branch and subpath. + """Test ``parse_remote_repo`` with a URL containing branch and subpath. Given a URL referencing a branch ("main") and a subdir ("subdir/file"): - When ``_parse_remote_repo`` is called with remote branch fetching, + When ``parse_remote_repo`` is called with remote branch fetching, Then user, repo, branch, and subpath should be identified correctly. """ url = DEMO_URL + "/tree/main/subdir/file" @@ -145,104 +116,27 @@ async def test_parse_url_with_subpaths(stub_branches: Callable[[list[str]], None @pytest.mark.asyncio async def test_parse_url_invalid_repo_structure() -> None: - """Test ``_parse_remote_repo`` with a URL missing a repository name. + """Test ``parse_remote_repo`` with a URL missing a repository name. Given a URL like "https://github.com/user": - When ``_parse_remote_repo`` is called, + When ``parse_remote_repo`` is called, Then a ValueError should be raised indicating an invalid repository URL. """ url = "https://github.com/user" with pytest.raises(ValueError, match="Invalid repository URL"): - await _parse_remote_repo(url) + await parse_remote_repo(url) -def test_parse_patterns_valid() -> None: - """Test ``_parse_patterns`` with valid comma-separated patterns. - - Given patterns like "*.py, *.md, docs/*": - When ``_parse_patterns`` is called, - Then it should return a set of parsed strings. - """ - patterns = "*.py, *.md, docs/*" - parsed_patterns = _parse_patterns(patterns) - - assert parsed_patterns == {"*.py", "*.md", "docs/*"} - - -def test_parse_patterns_invalid_characters() -> None: - """Test ``_parse_patterns`` with invalid characters. - - Given a pattern string containing special characters (e.g. "*.py;rm -rf"): - When ``_parse_patterns`` is called, - Then a ValueError should be raised indicating invalid pattern syntax. - """ - patterns = "*.py;rm -rf" - - with pytest.raises(ValueError, match="Pattern.*contains invalid characters"): - _parse_patterns(patterns) - - -@pytest.mark.asyncio -async def test_parse_query_with_large_file_size() -> None: - """Test ``parse_query`` with a very large file size limit. +async def test_parse_local_dir_path_local_path() -> None: + """Test ``parse_local_dir_path``. - Given a URL and max_file_size=10**9: - When ``parse_query`` is called, - Then ``max_file_size`` should be set correctly and default ignore patterns remain unchanged. - """ - query = await parse_query(DEMO_URL, max_file_size=10**9, from_web=True) - - assert query.max_file_size == 10**9 - assert query.ignore_patterns == DEFAULT_IGNORE_PATTERNS - - -@pytest.mark.asyncio -async def test_parse_query_empty_patterns() -> None: - """Test ``parse_query`` with empty patterns. 
- - Given empty include_patterns and ignore_patterns: - When ``parse_query`` is called, - Then ``include_patterns`` becomes ``None`` and default ignore patterns apply. - """ - query = await parse_query(DEMO_URL, max_file_size=50, from_web=True, include_patterns="", ignore_patterns="") - - assert query.include_patterns is None - assert query.ignore_patterns == DEFAULT_IGNORE_PATTERNS - - -@pytest.mark.asyncio -async def test_parse_query_include_and_ignore_overlap() -> None: - """Test ``parse_query`` with overlapping patterns. - - Given include="*.py" and ignore={"*.py", "*.txt"}: - When ``parse_query`` is called, - Then "*.py" should be removed from ignore patterns. - """ - query = await parse_query( - DEMO_URL, - max_file_size=50, - from_web=True, - include_patterns="*.py", - ignore_patterns={"*.py", "*.txt"}, - ) - - assert query.include_patterns == {"*.py"} - assert query.ignore_patterns is not None - assert "*.py" not in query.ignore_patterns - assert "*.txt" in query.ignore_patterns - - -@pytest.mark.asyncio -async def test_parse_query_local_path() -> None: - """Test ``parse_query`` with a local file path. - - Given "/home/user/project" and from_web=False: - When ``parse_query`` is called, + Given "/home/user/project": + When ``parse_local_dir_path`` is called, Then the local path should be set, id generated, and slug formed accordingly. """ path = "/home/user/project" - query = await parse_query(path, max_file_size=100, from_web=False) + query = parse_local_dir_path(path) tail = Path("home/user/project") assert query.local_path.parts[-len(tail.parts) :] == tail.parts @@ -250,16 +144,15 @@ async def test_parse_query_local_path() -> None: assert query.slug == "home/user/project" -@pytest.mark.asyncio -async def test_parse_query_relative_path() -> None: - """Test ``parse_query`` with a relative path. +async def test_parse_local_dir_path_relative_path() -> None: + """Test ``parse_local_dir_path`` with a relative path. - Given "./project" and from_web=False: - When ``parse_query`` is called, + Given "./project": + When ``parse_local_dir_path`` is called, Then ``local_path`` resolves relatively, and ``slug`` ends with "project". """ path = "./project" - query = await parse_query(path, max_file_size=100, from_web=False) + query = parse_local_dir_path(path) tail = Path("project") assert query.local_path.parts[-len(tail.parts) :] == tail.parts @@ -267,17 +160,17 @@ async def test_parse_query_relative_path() -> None: @pytest.mark.asyncio -async def test_parse_query_empty_source() -> None: - """Test ``parse_query`` with an empty string. +async def test_parse_remote_repo_empty_source() -> None: + """Test ``parse_remote_repo`` with an empty string. Given an empty source string: - When ``parse_query`` is called, + When ``parse_remote_repo`` is called, Then a ValueError should be raised indicating an invalid repository URL. """ url = "" with pytest.raises(ValueError, match="Invalid repository URL"): - await parse_query(url, max_file_size=100, from_web=True) + await parse_remote_repo(url) @pytest.mark.asyncio @@ -294,10 +187,10 @@ async def test_parse_url_branch_and_commit_distinction( expected_commit: str, stub_branches: Callable[[list[str]], None], ) -> None: - """Test ``_parse_remote_repo`` distinguishing branch vs. commit hash. + """Test ``parse_remote_repo`` distinguishing branch vs. commit hash. 
Given either a branch URL (e.g., ".../tree/main") or a 40-character commit URL: - When ``_parse_remote_repo`` is called with branch fetching, + When ``parse_remote_repo`` is called with branch fetching, Then the function should correctly set ``branch`` or ``commit`` based on the URL content. """ stub_branches(["main", "dev", "feature-branch"]) @@ -309,31 +202,30 @@ async def test_parse_url_branch_and_commit_distinction( assert query.commit == expected_commit -@pytest.mark.asyncio -async def test_parse_query_uuid_uniqueness() -> None: - """Test ``parse_query`` for unique UUID generation. +async def test_parse_local_dir_path_uuid_uniqueness() -> None: + """Test ``parse_local_dir_path`` for unique UUID generation. Given the same path twice: - When ``parse_query`` is called repeatedly, + When ``parse_local_dir_path`` is called repeatedly, Then each call should produce a different query id. """ path = "/home/user/project" - query_1 = await parse_query(path, max_file_size=100, from_web=False) - query_2 = await parse_query(path, max_file_size=100, from_web=False) + query_1 = parse_local_dir_path(path) + query_2 = parse_local_dir_path(path) assert query_1.id != query_2.id @pytest.mark.asyncio async def test_parse_url_with_query_and_fragment() -> None: - """Test ``_parse_remote_repo`` with query parameters and a fragment. + """Test ``parse_remote_repo`` with query parameters and a fragment. Given a URL like "https://github.com/user/repo?arg=value#fragment": - When ``_parse_remote_repo`` is called, + When ``parse_remote_repo`` is called, Then those parts should be stripped, leaving a clean user/repo URL. """ url = DEMO_URL + "?arg=value#fragment" - query = await _parse_remote_repo(url) + query = await parse_remote_repo(url) assert query.user_name == "user" assert query.repo_name == "repo" @@ -342,28 +234,28 @@ async def test_parse_url_with_query_and_fragment() -> None: @pytest.mark.asyncio async def test_parse_url_unsupported_host() -> None: - """Test ``_parse_remote_repo`` with an unsupported host. + """Test ``parse_remote_repo`` with an unsupported host. Given "https://only-domain.com": - When ``_parse_remote_repo`` is called, + When ``parse_remote_repo`` is called, Then a ValueError should be raised for the unknown domain. """ url = "https://only-domain.com" with pytest.raises(ValueError, match="Unknown domain 'only-domain.com' in URL"): - await _parse_remote_repo(url) + await parse_remote_repo(url) @pytest.mark.asyncio async def test_parse_query_with_branch() -> None: - """Test ``parse_query`` when a branch is specified in a blob path. + """Test ``parse_remote_repo`` when a branch is specified in a blob path. Given "https://github.com/pandas-dev/pandas/blob/2.2.x/...": - When ``parse_query`` is called, + When ``parse_remote_repo`` is called, Then the branch should be identified, subpath set, and commit remain None. """ url = "https://github.com/pandas-dev/pandas/blob/2.2.x/.github/ISSUE_TEMPLATE/documentation_improvement.yaml" - query = await parse_query(url, max_file_size=10**9, from_web=True) + query = await parse_remote_repo(url) assert query.user_name == "pandas-dev" assert query.repo_name == "pandas" @@ -394,10 +286,10 @@ async def test_parse_repo_source_with_various_url_patterns( expected_subpath: str, stub_branches: Callable[[list[str]], None], ) -> None: - """Test ``_parse_remote_repo`` with various GitHub-style URL permutations. + """Test ``parse_remote_repo`` with various GitHub-style URL permutations. 
Given various GitHub-style URL permutations: - When ``_parse_remote_repo`` is called, + When ``parse_remote_repo`` is called, Then it should detect (or reject) a branch and resolve the sub-path. Branch discovery is stubbed so that only names passed to ``stub_branches`` are considered "remote". @@ -411,9 +303,10 @@ async def test_parse_repo_source_with_various_url_patterns( assert query.subpath == expected_subpath +@pytest.mark.asyncio async def _assert_basic_repo_fields(url: str) -> IngestionQuery: - """Run ``_parse_remote_repo`` and assert user, repo and slug are parsed.""" - query = await _parse_remote_repo(url) + """Run ``parse_remote_repo`` and assert user, repo and slug are parsed.""" + query = await parse_remote_repo(url) assert query.user_name == "user" assert query.repo_name == "repo" diff --git a/tests/test_clone.py b/tests/test_clone.py index 42ca1994..e3b374b8 100644 --- a/tests/test_clone.py +++ b/tests/test_clone.py @@ -4,26 +4,34 @@ and handling edge cases such as nonexistent URLs, timeouts, redirects, and specific commits or branches. """ +from __future__ import annotations + import asyncio -import subprocess -from pathlib import Path +from typing import TYPE_CHECKING from unittest.mock import AsyncMock import httpx import pytest -from pytest_mock import MockerFixture from starlette.status import HTTP_200_OK, HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND from gitingest.clone import clone_repo from gitingest.schemas import CloneConfig from gitingest.utils.exceptions import AsyncTimeoutError from gitingest.utils.git_utils import check_repo_exists -from tests.conftest import DEMO_URL, LOCAL_REPO_PATH +from tests.conftest import DEMO_COMMIT, DEMO_URL, LOCAL_REPO_PATH, get_ensure_git_installed_call_count + +if TYPE_CHECKING: + from pathlib import Path + + from pytest_mock import MockerFixture + # All cloning-related tests assume (unless explicitly overridden) that the repository exists. # Apply the check-repo patch automatically so individual tests don't need to repeat it. pytestmark = pytest.mark.usefixtures("repo_exists_true") +GIT_INSTALLED_CALLS = get_ensure_git_installed_call_count() + @pytest.mark.asyncio async def test_clone_with_commit(repo_exists_true: AsyncMock, run_command_mock: AsyncMock) -> None: @@ -33,18 +41,20 @@ async def test_clone_with_commit(repo_exists_true: AsyncMock, run_command_mock: When ``clone_repo`` is called, Then the repository should be cloned and checked out at that commit. """ - expected_call_count = 2 + expected_call_count = GIT_INSTALLED_CALLS + 3 # ensure_git_installed + clone + fetch + checkout + commit_hash = "a" * 40 # Simulating a valid commit hash clone_config = CloneConfig( url=DEMO_URL, local_path=LOCAL_REPO_PATH, - commit="a" * 40, # Simulating a valid commit hash + commit=commit_hash, branch="main", ) await clone_repo(clone_config) - repo_exists_true.assert_called_once_with(clone_config.url, token=None) - assert run_command_mock.call_count == expected_call_count # Clone and checkout calls + repo_exists_true.assert_any_call(clone_config.url, token=None) + assert_standard_calls(run_command_mock, clone_config, commit=commit_hash) + assert run_command_mock.call_count == expected_call_count @pytest.mark.asyncio @@ -55,13 +65,14 @@ async def test_clone_without_commit(repo_exists_true: AsyncMock, run_command_moc When ``clone_repo`` is called, Then only the clone_repo operation should be performed (no checkout). 
""" - expected_call_count = 1 + expected_call_count = GIT_INSTALLED_CALLS + 4 # ensure_git_installed + resolve_commit + clone + fetch + checkout clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH, commit=None, branch="main") await clone_repo(clone_config) - repo_exists_true.assert_called_once_with(clone_config.url, token=None) - assert run_command_mock.call_count == expected_call_count # Only clone call + repo_exists_true.assert_any_call(clone_config.url, token=None) + assert_standard_calls(run_command_mock, clone_config, commit=DEMO_COMMIT) + assert run_command_mock.call_count == expected_call_count @pytest.mark.asyncio @@ -84,7 +95,7 @@ async def test_clone_nonexistent_repository(repo_exists_true: AsyncMock) -> None with pytest.raises(ValueError, match="Repository not found"): await clone_repo(clone_config) - repo_exists_true.assert_called_once_with(clone_config.url, token=None) + repo_exists_true.assert_any_call(clone_config.url, token=None) @pytest.mark.asyncio @@ -117,20 +128,13 @@ async def test_clone_with_custom_branch(run_command_mock: AsyncMock) -> None: When ``clone_repo`` is called, Then the repository should be cloned shallowly to that branch. """ + expected_call_count = GIT_INSTALLED_CALLS + 4 # ensure_git_installed + resolve_commit + clone + fetch + checkout clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH, branch="feature-branch") await clone_repo(clone_config) - run_command_mock.assert_called_once_with( - "git", - "clone", - "--single-branch", - "--depth=1", - "--branch", - "feature-branch", - clone_config.url, - clone_config.local_path, - ) + assert_standard_calls(run_command_mock, clone_config, commit=DEMO_COMMIT) + assert run_command_mock.call_count == expected_call_count @pytest.mark.asyncio @@ -143,9 +147,9 @@ async def test_git_command_failure(run_command_mock: AsyncMock) -> None: """ clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH) - run_command_mock.side_effect = RuntimeError("Git command failed") + run_command_mock.side_effect = RuntimeError("Git is not installed or not accessible. Please install Git first.") - with pytest.raises(RuntimeError, match="Git command failed"): + with pytest.raises(RuntimeError, match="Git is not installed or not accessible"): await clone_repo(clone_config) @@ -157,18 +161,13 @@ async def test_clone_default_shallow_clone(run_command_mock: AsyncMock) -> None: When ``clone_repo`` is called, Then the repository should be cloned with ``--depth=1`` and ``--single-branch``. """ + expected_call_count = GIT_INSTALLED_CALLS + 4 # ensure_git_installed + resolve_commit + clone + fetch + checkout clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH) await clone_repo(clone_config) - run_command_mock.assert_called_once_with( - "git", - "clone", - "--single-branch", - "--depth=1", - clone_config.url, - clone_config.local_path, - ) + assert_standard_calls(run_command_mock, clone_config, commit=DEMO_COMMIT) + assert run_command_mock.call_count == expected_call_count @pytest.mark.asyncio @@ -179,15 +178,14 @@ async def test_clone_commit(run_command_mock: AsyncMock) -> None: When ``clone_repo`` is called, Then the repository should be cloned and checked out at that commit. 
""" - expected_call_count = 2 - # Simulating a valid commit hash - clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH, commit="a" * 40) + expected_call_count = GIT_INSTALLED_CALLS + 3 # ensure_git_installed + clone + fetch + checkout + commit_hash = "a" * 40 # Simulating a valid commit hash + clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH, commit=commit_hash) await clone_repo(clone_config) - assert run_command_mock.call_count == expected_call_count # Clone and checkout calls - run_command_mock.assert_any_call("git", "clone", "--single-branch", clone_config.url, clone_config.local_path) - run_command_mock.assert_any_call("git", "-C", clone_config.local_path, "checkout", clone_config.commit) + assert_standard_calls(run_command_mock, clone_config, commit=commit_hash) + assert run_command_mock.call_count == expected_call_count @pytest.mark.asyncio @@ -225,40 +223,6 @@ async def test_clone_with_timeout(run_command_mock: AsyncMock) -> None: await clone_repo(clone_config) -@pytest.mark.asyncio -async def test_clone_specific_branch(tmp_path: Path) -> None: - """Test cloning a specific branch of a repository. - - Given a valid repository URL and a branch name: - When ``clone_repo`` is called, - Then the repository should be cloned and checked out at that branch. - """ - repo_url = "https://github.com/coderamp-labs/gitingest.git" - branch_name = "main" - local_path = tmp_path / "gitingest" - clone_config = CloneConfig(url=repo_url, local_path=str(local_path), branch=branch_name) - - await clone_repo(clone_config) - - assert local_path.exists(), "The repository was not cloned successfully." - assert local_path.is_dir(), "The cloned repository path is not a directory." - - loop = asyncio.get_running_loop() - current_branch = ( - ( - await loop.run_in_executor( - None, - subprocess.check_output, - ["git", "-C", str(local_path), "branch", "--show-current"], - ) - ) - .decode() - .strip() - ) - - assert current_branch == branch_name, f"Expected branch '{branch_name}', got '{current_branch}'." - - @pytest.mark.asyncio async def test_clone_branch_with_slashes(tmp_path: Path, run_command_mock: AsyncMock) -> None: """Test cloning a branch with slashes in the name. @@ -269,20 +233,13 @@ async def test_clone_branch_with_slashes(tmp_path: Path, run_command_mock: Async """ branch_name = "fix/in-operator" local_path = tmp_path / "gitingest" + expected_call_count = GIT_INSTALLED_CALLS + 4 # ensure_git_installed + resolve_commit + clone + fetch + checkout clone_config = CloneConfig(url=DEMO_URL, local_path=str(local_path), branch=branch_name) await clone_repo(clone_config) - run_command_mock.assert_called_once_with( - "git", - "clone", - "--single-branch", - "--depth=1", - "--branch", - "fix/in-operator", - clone_config.url, - clone_config.local_path, - ) + assert_standard_calls(run_command_mock, clone_config, commit=DEMO_COMMIT) + assert run_command_mock.call_count == expected_call_count @pytest.mark.asyncio @@ -293,20 +250,16 @@ async def test_clone_creates_parent_directory(tmp_path: Path, run_command_mock: When ``clone_repo`` is called, Then it should create the parent directories before attempting to clone. 
""" + expected_call_count = GIT_INSTALLED_CALLS + 4 # ensure_git_installed + resolve_commit + clone + fetch + checkout nested_path = tmp_path / "deep" / "nested" / "path" / "repo" + clone_config = CloneConfig(url=DEMO_URL, local_path=str(nested_path)) await clone_repo(clone_config) assert nested_path.parent.exists() - run_command_mock.assert_called_once_with( - "git", - "clone", - "--single-branch", - "--depth=1", - clone_config.url, - str(nested_path), - ) + assert_standard_calls(run_command_mock, clone_config, commit=DEMO_COMMIT) + assert run_command_mock.call_count == expected_call_count @pytest.mark.asyncio @@ -317,26 +270,15 @@ async def test_clone_with_specific_subpath(run_command_mock: AsyncMock) -> None: When ``clone_repo`` is called, Then the repository should be cloned with sparse checkout enabled and the specified subpath. """ - expected_call_count = 2 - clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH, subpath="src/docs") + # ensure_git_installed + resolve_commit + clone + sparse-checkout + fetch + checkout + subpath = "src/docs" + expected_call_count = GIT_INSTALLED_CALLS + 5 + clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH, subpath=subpath) await clone_repo(clone_config) # Verify the clone command includes sparse checkout flags - run_command_mock.assert_any_call( - "git", - "clone", - "--single-branch", - "--filter=blob:none", - "--sparse", - "--depth=1", - clone_config.url, - clone_config.local_path, - ) - - # Verify the sparse-checkout command sets the correct path - run_command_mock.assert_any_call("git", "-C", clone_config.local_path, "sparse-checkout", "set", "src/docs") - + assert_partial_clone_calls(run_command_mock, clone_config, commit=DEMO_COMMIT) assert run_command_mock.call_count == expected_call_count @@ -349,42 +291,14 @@ async def test_clone_with_commit_and_subpath(run_command_mock: AsyncMock) -> Non Then the repository should be cloned with sparse checkout enabled, checked out at the specific commit, and only include the specified subpath. """ - expected_call_count = 3 - # Simulating a valid commit hash - clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH, commit="a" * 40, subpath="src/docs") + subpath = "src/docs" + expected_call_count = GIT_INSTALLED_CALLS + 4 # ensure_git_installed + clone + sparse-checkout + fetch + checkout + commit_hash = "a" * 40 # Simulating a valid commit hash + clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH, commit=commit_hash, subpath=subpath) await clone_repo(clone_config) - # Verify the clone command includes sparse checkout flags - run_command_mock.assert_any_call( - "git", - "clone", - "--single-branch", - "--filter=blob:none", - "--sparse", - clone_config.url, - clone_config.local_path, - ) - - # Verify sparse-checkout set - run_command_mock.assert_any_call( - "git", - "-C", - clone_config.local_path, - "sparse-checkout", - "set", - "src/docs", - ) - - # Verify checkout commit - run_command_mock.assert_any_call( - "git", - "-C", - clone_config.local_path, - "checkout", - clone_config.commit, - ) - + assert_partial_clone_calls(run_command_mock, clone_config, commit=commit_hash) assert run_command_mock.call_count == expected_call_count @@ -396,18 +310,58 @@ async def test_clone_with_include_submodules(run_command_mock: AsyncMock) -> Non When ``clone_repo`` is called, Then the repository should be cloned with ``--recurse-submodules`` in the git command. 
""" - expected_call_count = 1 # No commit and no partial clone + # ensure_git_installed + resolve_commit + clone + fetch + checkout + checkout submodules + expected_call_count = GIT_INSTALLED_CALLS + 5 clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH, branch="main", include_submodules=True) await clone_repo(clone_config) + assert_standard_calls(run_command_mock, clone_config, commit=DEMO_COMMIT) + assert_submodule_calls(run_command_mock, clone_config) assert run_command_mock.call_count == expected_call_count - run_command_mock.assert_called_once_with( - "git", - "clone", - "--single-branch", - "--recurse-submodules", - "--depth=1", - clone_config.url, - clone_config.local_path, - ) + + +_ENSURE_GIT_INSTALLED_CALL_ARGS = ("git", "--version") + + +def _clone_call_args(url: str, local_path: str, *, partial_clone: bool = False) -> tuple[str, ...]: + cmd = ["git", "clone", "--single-branch", "--no-checkout", "--depth=1"] + if partial_clone: + cmd += ["--filter=blob:none", "--sparse"] + cmd += [url, local_path] + return tuple(cmd) + + +def _checkout_call_args(local_path: str, commit: str | None) -> tuple[str, ...]: + return ("git", "-C", local_path, "checkout", str(commit)) + + +def _fetch_call_args(local_path: str, commit: str | None) -> tuple[str, ...]: + return ("git", "-C", local_path, "fetch", "--depth=1", "origin", str(commit)) + + +def _sparse_checkout_call_args(local_path: str, subpath: str) -> tuple[str, ...]: + return ("git", "-C", local_path, "sparse-checkout", "set", subpath) + + +def _submodule_call_args(local_path: str) -> tuple[str, ...]: + return ("git", "-C", local_path, "submodule", "update", "--init", "--recursive", "--depth=1") + + +def assert_standard_calls(mock: AsyncMock, cfg: CloneConfig, commit: str, *, partial_clone: bool = False) -> None: + """Assert that the standard clone sequence of git commands was called.""" + mock.assert_any_call(*_ENSURE_GIT_INSTALLED_CALL_ARGS) + mock.assert_any_call(*_clone_call_args(cfg.url, cfg.local_path, partial_clone=partial_clone)) + mock.assert_any_call(*_fetch_call_args(cfg.local_path, commit)) + mock.assert_any_call(*_checkout_call_args(cfg.local_path, commit)) + + +def assert_partial_clone_calls(mock: AsyncMock, cfg: CloneConfig, commit: str) -> None: + """Assert that the partial clone sequence of git commands was called.""" + assert_standard_calls(mock, cfg, commit=commit, partial_clone=True) + mock.assert_any_call(*_sparse_checkout_call_args(cfg.local_path, cfg.subpath)) + + +def assert_submodule_calls(mock: AsyncMock, cfg: CloneConfig) -> None: + """Assert that submodule update commands were called.""" + mock.assert_any_call(*_submodule_call_args(cfg.local_path)) diff --git a/tests/test_pattern_utils.py b/tests/test_pattern_utils.py new file mode 100644 index 00000000..17a4687a --- /dev/null +++ b/tests/test_pattern_utils.py @@ -0,0 +1,60 @@ +"""Test pattern utilities.""" + +import pytest + +from gitingest.utils.ignore_patterns import DEFAULT_IGNORE_PATTERNS +from gitingest.utils.pattern_utils import _parse_patterns, process_patterns + + +def test_process_patterns_empty_patterns() -> None: + """Test ``process_patterns`` with empty patterns. + + Given empty ``include_patterns`` and ``exclude_patterns``: + When ``process_patterns`` is called, + Then ``include_patterns`` becomes ``None`` and ``DEFAULT_IGNORE_PATTERNS`` apply. 
+ """ + exclude_patterns, include_patterns = process_patterns(exclude_patterns="", include_patterns="") + + assert include_patterns is None + assert exclude_patterns == DEFAULT_IGNORE_PATTERNS + + +def test_parse_patterns_valid() -> None: + """Test ``_parse_patterns`` with valid comma-separated patterns. + + Given patterns like "*.py, *.md, docs/*": + When ``_parse_patterns`` is called, + Then it should return a set of parsed strings. + """ + patterns = "*.py, *.md, docs/*" + parsed_patterns = _parse_patterns(patterns) + + assert parsed_patterns == {"*.py", "*.md", "docs/*"} + + +def test_parse_patterns_invalid_characters() -> None: + """Test ``_parse_patterns`` with invalid characters. + + Given a pattern string containing special characters (e.g. "*.py;rm -rf"): + When ``_parse_patterns`` is called, + Then a ValueError should be raised indicating invalid pattern syntax. + """ + patterns = "*.py;rm -rf" + + with pytest.raises(ValueError, match="Pattern.*contains invalid characters"): + _parse_patterns(patterns) + + +def test_process_patterns_include_and_ignore_overlap() -> None: + """Test ``process_patterns`` with overlapping patterns. + + Given include="*.py" and ignore={"*.py", "*.txt"}: + When ``process_patterns`` is called, + Then "*.py" should be removed from ignore patterns. + """ + exclude_patterns, include_patterns = process_patterns(exclude_patterns={"*.py", "*.txt"}, include_patterns="*.py") + + assert include_patterns == {"*.py"} + assert exclude_patterns is not None + assert "*.py" not in exclude_patterns + assert "*.txt" in exclude_patterns