Skip to content

Commit e48f51f

Browse files
committed
feat(git): Make Git.get_working_tree_changes use a temporary index instead of manually diffing every file
This is much higher performance, as only a fixed number of calls to `git` have to be made (considering the speed of git vs Python itself, we can probably say this makes the function O(1) instead of O(n)). Moreover, this makes the function work properly on Windows, as we do not require `/dev/null` to exist anymore.
1 parent 2cacab5 commit e48f51f

File tree

2 files changed

+42
-16
lines changed

2 files changed

+42
-16
lines changed

src/dda/tools/git.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from __future__ import annotations
55

66
from functools import cached_property
7-
from typing import TYPE_CHECKING
7+
from typing import TYPE_CHECKING, Any
88

99
from dda.tools.base import Tool
1010
from dda.utils.git.changeset import ChangeSet
@@ -131,7 +131,7 @@ def get_commit_details(self, sha1: str) -> CommitDetails:
131131
parent_shas=list(parents_str.split()),
132132
)
133133

134-
def _capture_diff_lines(self, *args: str) -> list[str]:
134+
def _capture_diff_lines(self, *args: str, **kwargs: Any) -> list[str]:
135135
diff_args = [
136136
"-c",
137137
"core.quotepath=false",
@@ -143,7 +143,7 @@ def _capture_diff_lines(self, *args: str) -> list[str]:
143143
"--no-ext-diff",
144144
# "-z",
145145
]
146-
return self.capture([*diff_args, *args], check=False).strip().splitlines()
146+
return self.capture([*diff_args, *args], check=False, **kwargs).strip().splitlines()
147147

148148
def _compare_refs(self, ref1: str, ref2: str) -> ChangeSet:
149149
return ChangeSet.generate_from_diff_output(self._capture_diff_lines(ref1, ref2))
@@ -164,25 +164,32 @@ def get_working_tree_changes(self) -> ChangeSet:
164164
"""
165165
Get the changes in the working tree of the Git repository in the current working directory.
166166
"""
167-
from itertools import chain
167+
from os import environ
168+
169+
from dda.utils.fs import temp_file
168170

169-
from dda.utils.git.changeset import ChangeSet
171+
with temp_file(suffix=".git_index") as temp_index_path:
172+
# Set up environment with temporary index
173+
original_env = environ.copy()
174+
temp_env = original_env | {"GIT_INDEX_FILE": str(temp_index_path.resolve())}
170175

171-
# Capture changes to already-tracked files - `diff HEAD` does not include any untracked files !
172-
tracked_changes = ChangeSet.generate_from_diff_output(self._capture_diff_lines("HEAD"))
176+
# Populate the temporary index with HEAD
177+
self.run(["read-tree", "HEAD"], env=temp_env)
173178

174-
# Capture changes to untracked files
175-
other_files_output = self.capture(["ls-files", "--others", "--exclude-standard", "-z"]).strip()
176-
untracked_files = [x.strip() for x in other_files_output.split("\0") if x] # Remove empty strings
179+
# Get list of untracked files
180+
untracked_files_output = self.capture(
181+
["ls-files", "--others", "--exclude-standard", "-z"], env=temp_env
182+
).strip()
183+
untracked_files = [x.strip() for x in untracked_files_output.split("\0") if x.strip()]
177184

178-
if not untracked_files:
179-
return tracked_changes
185+
# Add untracked files to the index with --intent-to-add
186+
if untracked_files:
187+
self.run(["add", "--intent-to-add", *untracked_files], env=temp_env)
180188

181-
diffs = list(chain.from_iterable(self._capture_diff_lines("/dev/null", file) for file in untracked_files))
182-
untracked_changes = ChangeSet.generate_from_diff_output(diffs)
189+
# Get all changes (tracked + untracked) with a single diff command
190+
diff_lines = self._capture_diff_lines("HEAD", env=temp_env)
183191

184-
# Combine the changes
185-
return tracked_changes | untracked_changes
192+
return ChangeSet.generate_from_diff_output(diff_lines)
186193

187194
def get_merge_base(self, remote_name: str = "origin") -> str:
188195
"""

src/dda/utils/fs.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,3 +194,22 @@ def temp_directory() -> Generator[Path, None, None]:
194194

195195
with TemporaryDirectory() as d:
196196
yield Path(d).resolve()
197+
198+
199+
@contextmanager
200+
def temp_file(suffix: str = "") -> Generator[Path, None, None]:
201+
"""
202+
A context manager that creates a temporary file and yields a path to it. Example:
203+
204+
```python
205+
with temp_file() as temp_file:
206+
...
207+
```
208+
209+
Yields:
210+
The resolved path to the temporary file, following all symlinks.
211+
"""
212+
from tempfile import NamedTemporaryFile
213+
214+
with NamedTemporaryFile(suffix=suffix) as f:
215+
yield Path(f.name)

0 commit comments

Comments
 (0)