Skip to content

Commit 3fd2630

Browse files
committed
Port src/ changes from #181
1 parent 432b1f8 commit 3fd2630

File tree

5 files changed

+701
-1
lines changed

5 files changed

+701
-1
lines changed

src/dda/tools/git.py

Lines changed: 149 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,18 @@
55

66
from contextlib import contextmanager
77
from functools import cached_property
8-
from typing import TYPE_CHECKING
8+
from typing import TYPE_CHECKING, Any
99

1010
from dda.tools.base import ExecutionContext, Tool
11+
from dda.utils.git.changeset import ChangeSet
1112
from dda.utils.git.constants import GitEnvVars
13+
from dda.utils.git.remote import Remote
1214

1315
if TYPE_CHECKING:
1416
from collections.abc import Generator
1517

18+
from dda.utils.git.commit import Commit, CommitDetails
19+
1620

1721
class Git(Tool):
1822
"""
@@ -70,3 +74,147 @@ def author_email(self) -> str:
7074
return env_email
7175

7276
return self.capture(["config", "--get", "user.email"]).strip()
77+
78+
# === PRETEMPLATED COMMANDS === #
79+
def get_remote_details(self, remote_name: str = "origin") -> Remote:
80+
"""
81+
Get the details of the given remote for the Git repository in the current working directory.
82+
The returned tuple is (org, repo, url).
83+
"""
84+
85+
remote_url = self.capture(
86+
["config", "--get", f"remote.{remote_name}.url"],
87+
).strip()
88+
89+
return Remote(remote_url) # type: ignore[abstract]
90+
91+
def get_head_commit(self) -> Commit:
92+
"""
93+
Get the current HEAD commit of the Git repository in the current working directory.
94+
"""
95+
from dda.utils.git.commit import Commit
96+
97+
sha1 = self.capture(["rev-parse", "HEAD"]).strip()
98+
99+
return Commit(sha1=sha1)
100+
101+
def get_commit_details(self, sha1: str) -> CommitDetails:
102+
"""
103+
Get the details of the given commit in the Git repository in the current working directory.
104+
"""
105+
from datetime import datetime
106+
107+
from dda.utils.git.commit import CommitDetails
108+
109+
raw_details = self.capture([
110+
"show",
111+
"--quiet",
112+
# Use a format that is easy to parse
113+
# fmt: author name, author email, author date, parent SHAs, commit message body
114+
"--format=%an%n%ae%n%ad%n%P%n%B",
115+
"--date=iso-strict",
116+
sha1,
117+
])
118+
author_name, author_email, date_str, parents_str, *message_lines = raw_details.splitlines()
119+
120+
return CommitDetails(
121+
author_name=author_name,
122+
author_email=author_email,
123+
datetime=datetime.fromisoformat(date_str),
124+
message="\n".join(message_lines).strip().strip('"'),
125+
parent_shas=list(parents_str.split()),
126+
)
127+
128+
def _capture_diff_lines(self, *args: str, **kwargs: Any) -> list[str]:
129+
diff_args = [
130+
"-c",
131+
"core.quotepath=false",
132+
"diff",
133+
"-U0",
134+
"--no-color",
135+
"--no-prefix",
136+
"--no-renames",
137+
"--no-ext-diff",
138+
# "-z",
139+
]
140+
return self.capture([*diff_args, *args], check=False, **kwargs).strip().splitlines()
141+
142+
def _compare_refs(self, ref1: str, ref2: str) -> ChangeSet:
143+
return ChangeSet.generate_from_diff_output(self._capture_diff_lines(ref1, ref2))
144+
145+
def get_commit_changes(self, sha1: str) -> ChangeSet:
146+
"""
147+
Get the changes of the given commit in the Git repository in the current working directory.
148+
"""
149+
return self._compare_refs(f"{sha1}^", sha1)
150+
151+
def get_changes_between_commits(self, a: str, b: str) -> ChangeSet:
152+
"""
153+
Get the changes between two commits, identified by their SHA-1 hashes.
154+
"""
155+
return self._compare_refs(a, b)
156+
157+
def get_working_tree_changes(self) -> ChangeSet:
158+
"""
159+
Get the changes in the working tree of the Git repository in the current working directory.
160+
"""
161+
from os import environ
162+
163+
from dda.utils.fs import temp_file
164+
165+
with temp_file(suffix=".git_index") as temp_index_path:
166+
# Set up environment with temporary index
167+
original_env = environ.copy()
168+
temp_env = original_env | {"GIT_INDEX_FILE": str(temp_index_path.resolve())}
169+
170+
# Populate the temporary index with HEAD
171+
self.run(["read-tree", "HEAD"], env=temp_env)
172+
173+
# Get list of untracked files
174+
untracked_files_output = self.capture(
175+
["ls-files", "--others", "--exclude-standard", "-z"], env=temp_env
176+
).strip()
177+
untracked_files = [x.strip() for x in untracked_files_output.split("\0") if x.strip()]
178+
179+
# Add untracked files to the index with --intent-to-add
180+
if untracked_files:
181+
self.run(["add", "--intent-to-add", *untracked_files], env=temp_env)
182+
183+
# Get all changes (tracked + untracked) with a single diff command
184+
diff_lines = self._capture_diff_lines("HEAD", env=temp_env)
185+
186+
return ChangeSet.generate_from_diff_output(diff_lines)
187+
188+
def get_merge_base(self, remote_name: str = "origin") -> str:
189+
"""
190+
Get the merge base of the current branch.
191+
"""
192+
res = self.capture(["merge-base", "HEAD", remote_name], check=False).strip()
193+
if not res:
194+
self.app.display_warning("Could not determine merge base for current branch. Using `main` instead.")
195+
return "main"
196+
return res
197+
198+
def get_changes_with_base(
199+
self,
200+
base: str | None = None,
201+
*,
202+
include_working_tree: bool = True,
203+
remote_name: str = "origin",
204+
) -> ChangeSet:
205+
"""
206+
Get the changes with the given base.
207+
By default, this base is the merge base of the current branch.
208+
If it cannot be determined, `main` will be used instead.
209+
210+
If `include_working_tree` is True, the changes in the working tree will be included.
211+
If `remote_name` is provided, the changes will be compared to the branch in the remote with this name.
212+
"""
213+
if base is None:
214+
base = self.get_merge_base(remote_name)
215+
216+
head = self.get_head_commit()
217+
changes = ChangeSet.generate_from_diff_output(self._capture_diff_lines(base, "...", head.sha1))
218+
if include_working_tree:
219+
changes |= self.get_working_tree_changes()
220+
return changes

src/dda/utils/fs.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,30 @@ def __as_exe(self) -> Path:
152152
def __as_exe(self) -> Path:
153153
return self
154154

155+
@classmethod
156+
def enc_hook(cls, obj: Any) -> Any:
157+
if isinstance(obj, cls):
158+
return repr(obj)
159+
160+
message = f"Objects of type {type(obj)} are not supported"
161+
raise NotImplementedError(message)
162+
163+
@classmethod
164+
def dec_hook(cls, obj_type: type, obj: Any) -> Any:
165+
if obj_type is cls:
166+
# Was encoded as the repr of this object
167+
# Should be of the form f"{cls.__qualname__}({str(obj)})"
168+
qualname, path = obj.split("(")
169+
path = path.rstrip(")").strip("'").strip('"')
170+
if qualname != cls.__qualname__:
171+
message = f"Objects of type {obj_type} are not supported"
172+
raise NotImplementedError(message)
173+
174+
return cls(path)
175+
176+
message = f"Objects of type {obj_type} are not supported"
177+
raise NotImplementedError(message)
178+
155179

156180
@contextmanager
157181
def temp_directory() -> Generator[Path, None, None]:
@@ -170,3 +194,22 @@ def temp_directory() -> Generator[Path, None, None]:
170194

171195
with TemporaryDirectory() as d:
172196
yield Path(d).resolve()
197+
198+
199+
@contextmanager
200+
def temp_file(suffix: str = "") -> Generator[Path, None, None]:
201+
"""
202+
A context manager that creates a temporary file and yields a path to it. Example:
203+
204+
```python
205+
with temp_file() as temp_file:
206+
...
207+
```
208+
209+
Yields:
210+
The resolved path to the temporary file, following all symlinks.
211+
"""
212+
from tempfile import NamedTemporaryFile
213+
214+
with NamedTemporaryFile(suffix=suffix) as f:
215+
yield Path(f.name)

0 commit comments

Comments
 (0)