Skip to content

Commit ae990bf

Browse files
committed
Port src/ changes from #181
1 parent 1171b07 commit ae990bf

File tree

5 files changed

+702
-0
lines changed

5 files changed

+702
-0
lines changed

src/dda/tools/git.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,15 @@
44
from __future__ import annotations
55

66
from functools import cached_property
7+
from typing import TYPE_CHECKING, Any
78

89
from dda.tools.base import Tool
10+
from dda.utils.git.changeset import ChangeSet
911
from dda.utils.git.constants import GitAuthorEnvVars
12+
from dda.utils.git.remote import Remote
13+
14+
if TYPE_CHECKING:
15+
from dda.utils.git.commit import Commit, CommitDetails
1016

1117

1218
class Git(Tool):
@@ -62,3 +68,147 @@ def author_email(self) -> str:
6268
return env_email
6369

6470
return self.capture(["config", "--get", "user.email"]).strip()
71+
72+
# === PRETEMPLATED COMMANDS === #
73+
def get_remote_details(self, remote_name: str = "origin") -> Remote:
74+
"""
75+
Get the details of the given remote for the Git repository in the current working directory.
76+
The returned tuple is (org, repo, url).
77+
"""
78+
79+
remote_url = self.capture(
80+
["config", "--get", f"remote.{remote_name}.url"],
81+
).strip()
82+
83+
return Remote(remote_url) # type: ignore[abstract]
84+
85+
def get_head_commit(self) -> Commit:
86+
"""
87+
Get the current HEAD commit of the Git repository in the current working directory.
88+
"""
89+
from dda.utils.git.commit import Commit
90+
91+
sha1 = self.capture(["rev-parse", "HEAD"]).strip()
92+
93+
return Commit(sha1=sha1)
94+
95+
def get_commit_details(self, sha1: str) -> CommitDetails:
96+
"""
97+
Get the details of the given commit in the Git repository in the current working directory.
98+
"""
99+
from datetime import datetime
100+
101+
from dda.utils.git.commit import CommitDetails
102+
103+
raw_details = self.capture([
104+
"show",
105+
"--quiet",
106+
# Use a format that is easy to parse
107+
# fmt: author name, author email, author date, parent SHAs, commit message body
108+
"--format=%an%n%ae%n%ad%n%P%n%B",
109+
"--date=iso-strict",
110+
sha1,
111+
])
112+
author_name, author_email, date_str, parents_str, *message_lines = raw_details.splitlines()
113+
114+
return CommitDetails(
115+
author_name=author_name,
116+
author_email=author_email,
117+
datetime=datetime.fromisoformat(date_str),
118+
message="\n".join(message_lines).strip().strip('"'),
119+
parent_shas=list(parents_str.split()),
120+
)
121+
122+
def _capture_diff_lines(self, *args: str, **kwargs: Any) -> list[str]:
123+
diff_args = [
124+
"-c",
125+
"core.quotepath=false",
126+
"diff",
127+
"-U0",
128+
"--no-color",
129+
"--no-prefix",
130+
"--no-renames",
131+
"--no-ext-diff",
132+
# "-z",
133+
]
134+
return self.capture([*diff_args, *args], check=False, **kwargs).strip().splitlines()
135+
136+
def _compare_refs(self, ref1: str, ref2: str) -> ChangeSet:
137+
return ChangeSet.generate_from_diff_output(self._capture_diff_lines(ref1, ref2))
138+
139+
def get_commit_changes(self, sha1: str) -> ChangeSet:
140+
"""
141+
Get the changes of the given commit in the Git repository in the current working directory.
142+
"""
143+
return self._compare_refs(f"{sha1}^", sha1)
144+
145+
def get_changes_between_commits(self, a: str, b: str) -> ChangeSet:
146+
"""
147+
Get the changes between two commits, identified by their SHA-1 hashes.
148+
"""
149+
return self._compare_refs(a, b)
150+
151+
def get_working_tree_changes(self) -> ChangeSet:
152+
"""
153+
Get the changes in the working tree of the Git repository in the current working directory.
154+
"""
155+
from os import environ
156+
157+
from dda.utils.fs import temp_file
158+
159+
with temp_file(suffix=".git_index") as temp_index_path:
160+
# Set up environment with temporary index
161+
original_env = environ.copy()
162+
temp_env = original_env | {"GIT_INDEX_FILE": str(temp_index_path.resolve())}
163+
164+
# Populate the temporary index with HEAD
165+
self.run(["read-tree", "HEAD"], env=temp_env)
166+
167+
# Get list of untracked files
168+
untracked_files_output = self.capture(
169+
["ls-files", "--others", "--exclude-standard", "-z"], env=temp_env
170+
).strip()
171+
untracked_files = [x.strip() for x in untracked_files_output.split("\0") if x.strip()]
172+
173+
# Add untracked files to the index with --intent-to-add
174+
if untracked_files:
175+
self.run(["add", "--intent-to-add", *untracked_files], env=temp_env)
176+
177+
# Get all changes (tracked + untracked) with a single diff command
178+
diff_lines = self._capture_diff_lines("HEAD", env=temp_env)
179+
180+
return ChangeSet.generate_from_diff_output(diff_lines)
181+
182+
def get_merge_base(self, remote_name: str = "origin") -> str:
183+
"""
184+
Get the merge base of the current branch.
185+
"""
186+
res = self.capture(["merge-base", "HEAD", remote_name], check=False).strip()
187+
if not res:
188+
self.app.display_warning("Could not determine merge base for current branch. Using `main` instead.")
189+
return "main"
190+
return res
191+
192+
def get_changes_with_base(
193+
self,
194+
base: str | None = None,
195+
*,
196+
include_working_tree: bool = True,
197+
remote_name: str = "origin",
198+
) -> ChangeSet:
199+
"""
200+
Get the changes with the given base.
201+
By default, this base is the merge base of the current branch.
202+
If it cannot be determined, `main` will be used instead.
203+
204+
If `include_working_tree` is True, the changes in the working tree will be included.
205+
If `remote_name` is provided, the changes will be compared to the branch in the remote with this name.
206+
"""
207+
if base is None:
208+
base = self.get_merge_base(remote_name)
209+
210+
head = self.get_head_commit()
211+
changes = ChangeSet.generate_from_diff_output(self._capture_diff_lines(base, "...", head.sha1))
212+
if include_working_tree:
213+
changes |= self.get_working_tree_changes()
214+
return changes

src/dda/utils/fs.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,30 @@ def __as_exe(self) -> Path:
152152
def __as_exe(self) -> Path:
153153
return self
154154

155+
@classmethod
156+
def enc_hook(cls, obj: Any) -> Any:
157+
if isinstance(obj, cls):
158+
return repr(obj)
159+
160+
message = f"Objects of type {type(obj)} are not supported"
161+
raise NotImplementedError(message)
162+
163+
@classmethod
164+
def dec_hook(cls, obj_type: type, obj: Any) -> Any:
165+
if obj_type is cls:
166+
# Was encoded as the repr of this object
167+
# Should be of the form f"{cls.__qualname__}({str(obj)})"
168+
qualname, path = obj.split("(")
169+
path = path.rstrip(")").strip("'").strip('"')
170+
if qualname != cls.__qualname__:
171+
message = f"Objects of type {obj_type} are not supported"
172+
raise NotImplementedError(message)
173+
174+
return cls(path)
175+
176+
message = f"Objects of type {obj_type} are not supported"
177+
raise NotImplementedError(message)
178+
155179

156180
@contextmanager
157181
def temp_directory() -> Generator[Path, None, None]:
@@ -170,3 +194,22 @@ def temp_directory() -> Generator[Path, None, None]:
170194

171195
with TemporaryDirectory() as d:
172196
yield Path(d).resolve()
197+
198+
199+
@contextmanager
200+
def temp_file(suffix: str = "") -> Generator[Path, None, None]:
201+
"""
202+
A context manager that creates a temporary file and yields a path to it. Example:
203+
204+
```python
205+
with temp_file() as temp_file:
206+
...
207+
```
208+
209+
Yields:
210+
The resolved path to the temporary file, following all symlinks.
211+
"""
212+
from tempfile import NamedTemporaryFile
213+
214+
with NamedTemporaryFile(suffix=suffix) as f:
215+
yield Path(f.name)

0 commit comments

Comments
 (0)