Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions src/gh_worktree/commands/init.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
import re
from typing import Optional
from urllib import parse
Expand Down Expand Up @@ -87,13 +86,13 @@ def __call__(self, repo: str, *destination_dir: Optional[str]):
repo_target = RepositoryTarget(repo, destination_dir=destination_dir)
repo_target.validate()

project_dir = os.path.join(self._context.cwd, repo_target.destination_dir)
project_dir = self._context.cwd / repo_target.destination_dir

if os.path.exists(project_dir):
if project_dir.exists():
# this would be problematic!
raise AssertionError(f"Project directory {project_dir} already exists")

os.makedirs(project_dir, exist_ok=True)
project_dir.mkdir(parents=True, exist_ok=True)

with self._context.use(project_dir):
self._runtime.hooks.fire(
Expand All @@ -104,7 +103,7 @@ def __call__(self, repo: str, *destination_dir: Optional[str]):
)
self._runtime.git.clone(repo_target.uri, ".bare")

with open(os.path.join(project_dir, ".git"), "w") as f:
with (project_dir / ".git").open("w", encoding="utf-8") as f:
f.write("gitdir: ./.bare")

self._runtime.git.config(
Expand All @@ -121,8 +120,10 @@ def __call__(self, repo: str, *destination_dir: Optional[str]):
is_private=repo_data["isPrivate"],
)

os.makedirs(self._context.config_dir, exist_ok=True)
with open(os.path.join(self._context.config_dir, "config.json"), "w") as f:
self._context.config_dir.mkdir(parents=True, exist_ok=True)
with (self._context.config_dir / "config.json").open(
"w", encoding="utf-8"
) as f:
config.save(f)

self._add_hooks(config)
Expand Down
4 changes: 1 addition & 3 deletions src/gh_worktree/commands/remove.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import os.path

from gh_worktree.command import Command
from gh_worktree.hooks import Hook

Expand All @@ -25,7 +23,7 @@ def __call__(self, worktree_name: str, force: bool = False):
:param force: Whether to force the removal of the worktree, if it's unmerged
"""
project_dir = self._context.project_dir
if not os.path.exists(os.path.join(project_dir, worktree_name)):
if not (project_dir / worktree_name).exists():
raise ValueError(f"Worktree {worktree_name} does not exist")

with self._context.use(project_dir):
Expand Down
47 changes: 25 additions & 22 deletions src/gh_worktree/context.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,48 @@
import os
from contextlib import contextmanager
from functools import cached_property
from pathlib import Path
from typing import Union

from gh_worktree.config import Config
from gh_worktree.config import RepositoryConfig
from gh_worktree.config import GlobalConfig
from gh_worktree.config import RepositoryConfig
from gh_worktree.utils import find_up


class Context(object):
def __init__(self):
self.cwd = os.getcwd()
self.cwd = Path.cwd()

@cached_property
def project_dir(self) -> str:
def project_dir(self) -> Path:
git_bare_dir = find_up(".bare", self.cwd)
return os.path.dirname(git_bare_dir)
return git_bare_dir.parent

@property
def config_dir(self) -> str:
return os.path.join(self.project_dir, ".gh", "worktree")
def config_dir(self) -> Path:
return self.project_dir / ".gh" / "worktree"

@property
def global_config_dir(self) -> str:
def global_config_dir(self) -> Path:
try:
parent_dir = os.path.dirname(self.project_dir)
parent_dir = self.project_dir.parent
except RuntimeError:
parent_dir = os.path.dirname(self.cwd)
parent_dir = self.cwd.parent

try:
closest_gh_dir = find_up(".gh", parent_dir)
return os.path.join(closest_gh_dir, "worktree")
return closest_gh_dir / "worktree"
except RuntimeError:
# default to ~/.gh/worktree
return os.path.join(os.path.expanduser("~"), ".gh", "worktree")
return Path.home() / ".gh" / "worktree"

@contextmanager
def use(self, cwd: str):
def use(self, cwd: Union[str, Path]):
old_cwd = self.cwd
os.chdir(cwd)
self.cwd = cwd
cwd_path = Path(cwd)
os.chdir(cwd_path)
self.cwd = cwd_path
try:
yield
finally:
Expand All @@ -53,19 +56,19 @@ def assert_within_project(self):
raise AssertionError("Project not found")

def get_config(self) -> RepositoryConfig:
file_path = os.path.join(self.config_dir, "config.json")
if not os.path.exists(file_path):
file_path = self.config_dir / "config.json"
if not file_path.exists():
return RepositoryConfig()

with open(file_path, "r") as f:
with file_path.open("r", encoding="utf-8") as f:
return RepositoryConfig.load(f)

def get_global_config(self):
file_path = os.path.join(self.global_config_dir, "config.json")
if not os.path.exists(file_path):
file_path = self.global_config_dir / "config.json"
if not file_path.exists():
return GlobalConfig()

with open(file_path, "r") as f:
with file_path.open("r", encoding="utf-8") as f:
return GlobalConfig.load(f)

def set_config(self, config: Config):
Expand All @@ -76,6 +79,6 @@ def set_config(self, config: Config):
else:
raise ValueError(f"Unknown config type: {type(config)}")

os.makedirs(config_dir, exist_ok=True)
with open(os.path.join(config_dir, "config.json"), "w") as f:
config_dir.mkdir(parents=True, exist_ok=True)
with (config_dir / "config.json").open("w", encoding="utf-8") as f:
config.save(f)
28 changes: 15 additions & 13 deletions src/gh_worktree/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,23 @@ def fire(self, hook: Hook, *args, skip_project: bool = False) -> bool:
fired = False

for hooks_dir in self.iter_config_dirs(skip_project=skip_project):
hook_file = os.path.join(hooks_dir, hook.name)
if not os.path.exists(hook_file):
hook_file = hooks_dir / hook.name
if not hook_file.exists():
continue

# Ensure the hook file is executable
if not os.access(hook_file, os.X_OK):
print(f"Hook {hook_file} is not executable. Skipping.")
hook_file_str = str(hook_file)
if not os.access(hook_file_str, os.X_OK):
print(f"Hook {hook_file_str} is not executable. Skipping.")
continue

if not self._check_allowed(hook_file):
print(f"Hook {hook_file} is not allowed to run. Skipping.")
if not self._check_allowed(hook_file_str):
print(f"Hook {hook_file_str} is not allowed to run. Skipping.")
continue

fired = True
return_status = stream_exec([hook_file, *args], cwd=self.context.cwd)
command_args = [hook_file_str, *[str(arg) for arg in args]]
return_status = stream_exec(command_args, cwd=self.context.cwd)
if return_status != 0:
raise RuntimeError(
f"Hook {hook.name} failed with exit code {return_status}"
Expand Down Expand Up @@ -80,16 +82,16 @@ def _check_allowed(self, hook_file: str) -> bool:

@contextmanager
def add(self, hook: Hook):
hooks_dir = os.path.join(self.context.config_dir, "hooks")
hook_file = os.path.join(hooks_dir, hook.name)
os.makedirs(hooks_dir, exist_ok=True)
hooks_dir = self.context.config_dir / "hooks"
hook_file = hooks_dir / hook.name
hooks_dir.mkdir(parents=True, exist_ok=True)

if os.path.exists(hook_file):
if hook_file.exists():
raise HookExists(f"Hook {hook_file} already exists.")

# copy it to config
with open(hook_file, "w", newline="\n") as f:
with hook_file.open("w", encoding="utf-8", newline="\n") as f:
yield f

# allow exec
os.chmod(hook_file, os.stat(hook_file).st_mode | stat.S_IEXEC)
hook_file.chmod(hook_file.stat().st_mode | stat.S_IEXEC)
8 changes: 4 additions & 4 deletions src/gh_worktree/operator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os
from pathlib import Path
from typing import Iterator

from gh_worktree.context import Context
Expand All @@ -10,14 +10,14 @@ class ConfigOperator(object):
def __init__(self, context: Context):
self.context = context

def iter_config_dirs(self, skip_project: bool = False) -> Iterator[str]:
def iter_config_dirs(self, skip_project: bool = False) -> Iterator[Path]:
configs = [self.context.global_config_dir]
if not skip_project:
configs.append(self.context.config_dir)

for config_dir in configs:
op_dir = os.path.join(config_dir, self.dir_name)
if not os.path.exists(op_dir):
op_dir = config_dir / self.dir_name
if not op_dir.exists():
continue

yield op_dir
6 changes: 3 additions & 3 deletions src/gh_worktree/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ def __init__(self, context: Context):

def copy(self, worktree_name: str):
config = self.context.get_config()
worktree_dir = Path(os.path.join(self.context.project_dir, worktree_name))
worktree_dir = self.context.project_dir / worktree_name

self.replacement_map["REPO_NAME"] = config.name
self.replacement_map["REPO_DIR"] = self.context.project_dir
self.replacement_map["REPO_DIR"] = str(self.context.project_dir)
self.replacement_map["WORKTREE_NAME"] = worktree_name
self.replacement_map["WORKTREE_DIR"] = str(worktree_dir)

for templates_dir in self.iter_config_dirs():
for path in Path(templates_dir).rglob("*"):
for path in templates_dir.rglob("*"):
relative_path = path.relative_to(templates_dir)
self._copy(worktree_dir, path, relative_path)

Expand Down
50 changes: 31 additions & 19 deletions src/gh_worktree/utils.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,53 @@
import os
import random
import shlex
import subprocess
from pathlib import Path
from typing import List
from typing import Optional
from typing import Union

# Simple ANSI colors for the prefix
COLORS = ["\033[92m", "\033[94m", "\033[95m", "\033[96m", "\033[93m"]
COLOR_RESET = "\033[0m"


def find_up(name: str, start_path: str):
def find_up(name: str, start_path: Union[str, Path]) -> Path:
"""
Looks upward for a directory that has file or directory with `name`
:param name: The name of the file or directory to look for
:param start_path: The path to start looking from
:return: The path to the directory
"""
search_path = os.path.realpath(start_path)
search_path = Path(start_path).resolve()

while True:
name_path = os.path.join(search_path, name)
if os.path.isdir(search_path) and os.path.exists(name_path):
name_path = search_path / name
if search_path.is_dir() and name_path.exists():
return name_path
if search_path == os.path.dirname(search_path):
if search_path == search_path.parent:
break
search_path = os.path.dirname(search_path)
search_path = search_path.parent

raise RuntimeError(f"Could not find {name} in {start_path} ancestors")


def stream_exec(command: List[str], wait_time: int = 60, cwd: str = None) -> int:
def _log_prefix(command: List[str]) -> str:
"""
Returns a prefix string for visibility, via logging, into the command being executed
:param command: The command list to be executed
:return: A string for prefixing log messages
"""
command_prefix = command[:2]
command_script_path = Path(command_prefix[0])
if command_script_path.exists():
command_prefix[0] = command_script_path.name

return shlex.join(command_prefix)


def stream_exec(
command: List[str], wait_time: int = 60, cwd: Optional[Union[str, Path]] = None
) -> int:
"""
Executes a command in a subprocess and streams its output to stdout.
:param command: The command to execute as a list of strings
Expand All @@ -48,13 +66,9 @@ def stream_exec(command: List[str], wait_time: int = 60, cwd: str = None) -> int
output_color = COLORS[process.pid % len(COLORS)]
print(f"Executing: {output_color}{shlex.join(command)}{COLOR_RESET}")

command_prefix = command[:2]
if os.path.exists(command_prefix[0]):
command_prefix[0] = os.path.basename(command_prefix[0])

for line in process.stdout:
print(
f"{output_color}{shlex.join(command_prefix)} |{COLOR_RESET} {line}",
f"{output_color}{_log_prefix(command)} |{COLOR_RESET} {line}",
end="",
flush=True,
)
Expand All @@ -64,7 +78,9 @@ def stream_exec(command: List[str], wait_time: int = 60, cwd: str = None) -> int
return process.returncode


def iter_output(command: List[str], wait_time: int = 60, cwd: str = None):
def iter_output(
command: List[str], wait_time: int = 60, cwd: Optional[Union[str, Path]] = None
):
"""
Executes a command in a subprocess and iterates its output after completion
:param command: The command to execute as a list of strings
Expand All @@ -83,13 +99,9 @@ def iter_output(command: List[str], wait_time: int = 60, cwd: str = None):
cwd=cwd,
)

command_prefix = command[:2]
if os.path.exists(command_prefix[0]):
command_prefix[0] = os.path.basename(command_prefix[0])

for line in result.stdout.splitlines():
print(
f"{output_color}{shlex.join(command_prefix)} |{COLOR_RESET} {line}",
f"{output_color}{_log_prefix(command)} |{COLOR_RESET} {line}",
flush=True,
)
yield line
5 changes: 3 additions & 2 deletions tests/commands/test_checkout.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from contextlib import contextmanager
from pathlib import Path
from types import SimpleNamespace
from unittest import TestCase
from unittest.mock import ANY
Expand All @@ -11,7 +12,7 @@

class StubContext:
def __init__(self, project_dir, config):
self.project_dir = str(project_dir)
self.project_dir = project_dir
self._config = config
self.assert_called = False

Expand All @@ -33,7 +34,7 @@ def setUp(self):
owner="octo",
name="repo",
)
self.context = StubContext("/repo", self.config)
self.context = StubContext(Path("/repo"), self.config)
self.hooks = SimpleNamespace(fire=Mock())
self.git = SimpleNamespace(open_worktree=Mock(), fetch=Mock())
self.gh = SimpleNamespace(pr_status=Mock())
Expand Down
Loading