Skip to content

Commit 44ecee8

Browse files
committed
refactor(git): Address more review comments
* Move `static_tool_capture` out of the `SubprocessRunner` class
* Remove the `SHA1Hash` class, using normal `str`s and validating in `Commit.__post_init__` instead
* Tweak args used for diffing in `Git._capture_diff_lines`
* Tweak args used for getting untracked files in `Git.get_working_tree_changes`
* Use `...` diffing in `Git.get_changes_with_base`
* Convert `Commit` to a `Struct` instead of a dataclass
* Fix tests
1 parent eeae3ce commit 44ecee8

File tree

8 files changed

+147
-151
lines changed

8 files changed

+147
-151
lines changed

src/dda/config/model/user.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,24 @@ def _get_name_from_git() -> str:
1010
from os import environ
1111

1212
from dda.tools.git import Git
13-
from dda.utils.process import SubprocessRunner
13+
from dda.utils.process import static_capture
1414

1515
if name := environ.get(Git.AUTHOR_NAME_ENV_VAR):
1616
return name
1717

18-
return SubprocessRunner.static_capture(["git", "config", "--global", "--get", "user.name"]).strip()
18+
return static_capture(["git", "config", "--global", "--get", "user.name"]).strip()
1919

2020

2121
def _get_emails_from_git() -> list[str]:
2222
from os import environ
2323

2424
from dda.tools.git import Git
25-
from dda.utils.process import SubprocessRunner
25+
from dda.utils.process import static_capture
2626

2727
if email := environ.get(Git.AUTHOR_EMAIL_ENV_VAR):
2828
return [email]
2929

30-
return [SubprocessRunner.static_capture(["git", "config", "--global", "--get", "user.email"]).strip()]
30+
return [static_capture(["git", "config", "--global", "--get", "user.email"]).strip()]
3131

3232

3333
class UserConfig(Struct, frozen=True):

src/dda/tools/git.py

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
from __future__ import annotations
55

66
from functools import cached_property
7+
from typing import TYPE_CHECKING
78

89
from dda.tools.base import Tool
910
from dda.utils.git.changeset import ChangeSet
10-
from dda.utils.git.commit import Commit, CommitDetails, SHA1Hash
11+
12+
if TYPE_CHECKING:
13+
from dda.utils.git.commit import Commit, CommitDetails
1114

1215

1316
class Git(Tool):
@@ -95,22 +98,21 @@ def get_head_commit(self) -> Commit:
9598
"""
9699
Get the current HEAD commit of the Git repository in the current working directory.
97100
"""
98-
from dda.utils.git.commit import Commit, SHA1Hash
101+
from dda.utils.git.commit import Commit
99102

100-
sha1_str = self.capture(["rev-parse", "HEAD"]).strip()
101-
sha1 = SHA1Hash(sha1_str)
103+
sha1 = self.capture(["rev-parse", "HEAD"]).strip()
102104

103105
# Get the org/repo from the remote URL
104106
org, repo, _ = self.get_remote_details()
105107
return Commit(org=org, repo=repo, sha1=sha1)
106108

107-
def get_commit_details(self, sha1: SHA1Hash) -> CommitDetails:
109+
def get_commit_details(self, sha1: str) -> CommitDetails:
108110
"""
109111
Get the details of the given commit in the Git repository in the current working directory.
110112
"""
111113
from datetime import datetime
112114

113-
from dda.utils.git.commit import CommitDetails, SHA1Hash
115+
from dda.utils.git.commit import CommitDetails
114116

115117
raw_details = self.capture([
116118
"show",
@@ -119,7 +121,7 @@ def get_commit_details(self, sha1: SHA1Hash) -> CommitDetails:
119121
# fmt: author name, author email, author date, parent SHAs, commit message body
120122
"--format=%an%n%ae%n%ad%n%P%n%B",
121123
"--date=iso-strict",
122-
str(sha1),
124+
sha1,
123125
])
124126
author_name, author_email, date_str, parents_str, *message_lines = raw_details.splitlines()
125127

@@ -128,31 +130,37 @@ def get_commit_details(self, sha1: SHA1Hash) -> CommitDetails:
128130
author_email=author_email,
129131
datetime=datetime.fromisoformat(date_str),
130132
message="\n".join(message_lines).strip().strip('"'),
131-
parent_shas=[SHA1Hash(parent_sha) for parent_sha in parents_str.split()],
133+
parent_shas=list(parents_str.split()),
132134
)
133135

134136
def _capture_diff_lines(self, *args: str) -> list[str]:
135-
diff_args = ["diff", "-U0", "--no-color", "--no-prefix", "--no-renames"]
137+
diff_args = [
138+
"-c",
139+
"core.quotepath=false",
140+
"diff",
141+
"-U0",
142+
"--no-color",
143+
"--no-prefix",
144+
"--no-renames",
145+
"--no-ext-diff",
146+
# "-z",
147+
]
136148
return self.capture([*diff_args, *args], check=False).strip().splitlines()
137149

138150
def _compare_refs(self, ref1: str, ref2: str) -> ChangeSet:
139-
return ChangeSet.generate_from_diff_output(self._capture_diff_lines(str(ref1), str(ref2)))
151+
return ChangeSet.generate_from_diff_output(self._capture_diff_lines(ref1, ref2))
140152

141-
def get_commit_changes(self, sha1: SHA1Hash) -> ChangeSet:
153+
def get_commit_changes(self, sha1: str) -> ChangeSet:
142154
"""
143155
Get the changes of the given commit in the Git repository in the current working directory.
144156
"""
145-
return self._compare_refs(f"{sha1}^", str(sha1))
157+
return self._compare_refs(f"{sha1}^", sha1)
146158

147-
def get_changes_between_commits(self, a: SHA1Hash | Commit, b: SHA1Hash | Commit) -> ChangeSet:
159+
def get_changes_between_commits(self, a: str, b: str) -> ChangeSet:
148160
"""
149-
Get the changes between two commits.
161+
Get the changes between two commits, identified by their SHA-1 hashes.
150162
"""
151-
if isinstance(a, Commit):
152-
a = a.sha1
153-
if isinstance(b, Commit):
154-
b = b.sha1
155-
return self._compare_refs(str(a), str(b))
163+
return self._compare_refs(a, b)
156164

157165
def get_working_tree_changes(self) -> ChangeSet:
158166
"""
@@ -166,26 +174,31 @@ def get_working_tree_changes(self) -> ChangeSet:
166174
tracked_changes = ChangeSet.generate_from_diff_output(self._capture_diff_lines("HEAD"))
167175

168176
# Capture changes to untracked files
169-
untracked_files = self.capture(["ls-files", "--others", "--exclude-standard"]).strip().splitlines()
177+
other_files_output = self.capture(["ls-files", "--others", "--exclude-standard", "-z"]).strip()
178+
untracked_files = [x.strip() for x in other_files_output.split("\0") if x] # Remove empty strings
179+
180+
if not untracked_files:
181+
return tracked_changes
182+
170183
diffs = list(chain.from_iterable(self._capture_diff_lines("/dev/null", file) for file in untracked_files))
171184
untracked_changes = ChangeSet.generate_from_diff_output(diffs)
172185

173186
# Combine the changes
174187
return tracked_changes | untracked_changes
175188

176-
def get_merge_base(self, remote_name: str = "origin") -> SHA1Hash | str:
189+
def get_merge_base(self, remote_name: str = "origin") -> str:
177190
"""
178191
Get the merge base of the current branch.
179192
"""
180-
res = self.capture(["merge-base", "HEAD", remote_name]).strip()
193+
res = self.capture(["merge-base", "HEAD", remote_name], check=False).strip()
181194
if not res:
182195
self.app.display_warning("Could not determine merge base for current branch. Using `main` instead.")
183196
return "main"
184-
return SHA1Hash(res)
197+
return res
185198

186199
def get_changes_with_base(
187200
self,
188-
base: SHA1Hash | Commit | str | None = None,
201+
base: str | None = None,
189202
*,
190203
include_working_tree: bool = True,
191204
remote_name: str = "origin",
@@ -200,12 +213,9 @@ def get_changes_with_base(
200213
"""
201214
if base is None:
202215
base = self.get_merge_base(remote_name)
203-
if isinstance(base, Commit):
204-
base = base.sha1
205-
base = str(base)
206216

207217
head = self.get_head_commit()
208-
changes = self._compare_refs(base, head.sha1)
218+
changes = ChangeSet.generate_from_diff_output(self._capture_diff_lines(base, "...", head.sha1))
209219
if include_working_tree:
210220
changes |= self.get_working_tree_changes()
211221
return changes

src/dda/utils/git/changeset.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from msgspec import Struct, field
1212

1313
from dda.utils.fs import Path
14-
from dda.utils.git.commit import SHA1Hash
1514

1615
if TYPE_CHECKING:
1716
from _collections_abc import dict_items, dict_keys, dict_values
@@ -52,10 +51,10 @@ class FileChanges(Struct, frozen=True):
5251
5352
Example:
5453
```diff
55-
@@ -15,2 +15 @@ if TYPE_CHECKING:
56-
- from dda.utils.git.commit import Commit, CommitDetails
57-
- from dda.utils.git.commit import SHA1Hash
58-
+ from dda.utils.git.commit import Commit, CommitDetails, SHA1Hash
54+
@@ -15,1 +15 @@ if TYPE_CHECKING:
55+
- from dda.utils.git.commit import Commit
56+
- from dda.utils.git.commit import CommitDetails
57+
+ from dda.utils.git.commit import Commit, CommitDetails
5958
```
6059
"""
6160

@@ -239,7 +238,7 @@ def changed(self) -> set[Path]:
239238
return set(self.keys())
240239

241240
# == methods == #
242-
def digest(self) -> SHA1Hash:
241+
def digest(self) -> str:
243242
"""Compute a hash of the changeset."""
244243
from hashlib import sha1
245244

@@ -249,7 +248,7 @@ def digest(self) -> SHA1Hash:
249248
digester.update(change.type.value.encode())
250249
digester.update(change.patch.encode())
251250

252-
return SHA1Hash(digester.hexdigest())
251+
return str(digester.hexdigest())
253252

254253
@classmethod
255254
def from_iter(cls, data: Iterable[FileChanges]) -> Self:

src/dda/utils/git/commit.py

Lines changed: 22 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
# SPDX-License-Identifier: MIT
44
from __future__ import annotations
55

6-
from dataclasses import dataclass, field
76
from functools import cached_property
8-
from typing import TYPE_CHECKING, Self
7+
from typing import TYPE_CHECKING
8+
9+
from msgspec import Struct, field
910

1011
if TYPE_CHECKING:
1112
from datetime import datetime
@@ -14,21 +15,30 @@
1415
from dda.utils.git.changeset import ChangeSet
1516

1617

17-
@dataclass
18-
class Commit:
18+
class Commit(Struct, dict=True):
1919
"""
2020
A Git commit, identified by its SHA-1 hash.
2121
"""
2222

2323
org: str
2424
repo: str
25-
sha1: SHA1Hash
25+
sha1: str
2626

27-
_details: CommitDetails | None = field(default=None, init=False)
27+
_details: CommitDetails | None = field(default=None)
2828
_changes: ChangeSet | None = field(default=None)
2929

30+
def __post_init__(self) -> None:
31+
if len(self.sha1) != 40: # noqa: PLR2004
32+
msg = "SHA-1 hash must be 40 characters long"
33+
raise ValueError(msg)
34+
for c in self.sha1:
35+
code = ord(c)
36+
if code not in range(48, 58) and code not in range(97, 103):
37+
msg = "SHA-1 hash must contain only hexadecimal characters"
38+
raise ValueError(msg)
39+
3040
def __str__(self) -> str:
31-
return str(self.sha1)
41+
return self.sha1
3242

3343
@property
3444
def full_repo(self) -> str:
@@ -53,7 +63,7 @@ def compare_to(self, app: Application, other: Commit) -> ChangeSet:
5363
"""
5464
Compare this commit to another commit.
5565
"""
56-
return app.tools.git.get_changes_between_commits(self, other)
66+
return app.tools.git.get_changes_between_commits(self.sha1, other.sha1)
5767

5868
def get_details_and_changes_from_github(self) -> tuple[CommitDetails, ChangeSet]:
5969
"""
@@ -88,7 +98,7 @@ def get_details_and_changes_from_github(self) -> tuple[CommitDetails, ChangeSet]
8898
author_email=data["commit"]["author"]["email"],
8999
datetime=datetime.fromisoformat(data["commit"]["author"]["date"]),
90100
message=data["commit"]["message"],
91-
parent_shas=[SHA1Hash(parent["sha"]) for parent in data.get("parents", [])],
101+
parent_shas=[parent["sha"] for parent in data.get("parents", [])],
92102
)
93103

94104
return self.details, self.changes
@@ -146,35 +156,13 @@ def message(self) -> str:
146156
return self.details.message
147157

148158
@property
149-
def parent_shas(self) -> list[SHA1Hash]:
159+
def parent_shas(self) -> list[str]:
150160
return self.details.parent_shas
151161

152162

153-
@dataclass
154-
class CommitDetails:
163+
class CommitDetails(Struct):
155164
author_name: str
156165
author_email: str
157166
datetime: datetime
158167
message: str
159-
parent_shas: list[SHA1Hash]
160-
161-
162-
class SHA1Hash(str):
163-
"""
164-
A hexadecimal representation of a SHA-1 hash.
165-
"""
166-
167-
LENGTH = 40
168-
__slots__ = ()
169-
170-
def __new__(cls, value: str) -> Self:
171-
if len(value) != cls.LENGTH or any(c not in "0123456789abcdef" for c in value.lower()):
172-
msg = f"Invalid SHA-1 hash: {value}"
173-
raise ValueError(msg)
174-
return str.__new__(cls, value)
175-
176-
def __repr__(self) -> str:
177-
return f"{self.__class__.__name__}({super().__repr__()})"
178-
179-
def __bytes__(self) -> bytes:
180-
return bytes.fromhex(self)
168+
parent_shas: list[str]

0 commit comments

Comments
 (0)