diff --git a/git/cmd.py b/git/cmd.py index 3665eb029..d6f8f946a 100644 --- a/git/cmd.py +++ b/git/cmd.py @@ -14,7 +14,6 @@ import subprocess import threading from textwrap import dedent -import unittest.mock from git.compat import ( defenc, @@ -24,7 +23,7 @@ is_win, ) from git.exc import CommandError -from git.util import is_cygwin_git, cygpath, expand_path, remove_password_if_present +from git.util import is_cygwin_git, cygpath, expand_path, remove_password_if_present, patch_env from .exc import GitCommandError, GitCommandNotFound, UnsafeOptionError, UnsafeProtocolError from .util import ( @@ -965,10 +964,10 @@ def execute( '"kill_after_timeout" feature is not supported on Windows.', ) # Only search PATH, not CWD. This must be in the *caller* environment. The "1" can be any value. - patch_caller_env = unittest.mock.patch.dict(os.environ, {"NoDefaultCurrentDirectoryInExePath": "1"}) + maybe_patch_caller_env = patch_env("NoDefaultCurrentDirectoryInExePath", "1") else: cmd_not_found_exception = FileNotFoundError # NOQA # exists, flake8 unknown @UndefinedVariable - patch_caller_env = contextlib.nullcontext() + maybe_patch_caller_env = contextlib.nullcontext() # end handle stdout_sink = PIPE if with_stdout else getattr(subprocess, "DEVNULL", None) or open(os.devnull, "wb") @@ -984,7 +983,7 @@ def execute( istream_ok, ) try: - with patch_caller_env: + with maybe_patch_caller_env: proc = Popen( command, env=env, diff --git a/git/util.py b/git/util.py index f6dedf0f2..f80580cff 100644 --- a/git/util.py +++ b/git/util.py @@ -158,6 +158,20 @@ def cwd(new_dir: PathLike) -> Generator[PathLike, None, None]: os.chdir(old_dir) +@contextlib.contextmanager +def patch_env(name: str, value: str) -> Generator[None, None, None]: + """Context manager to temporarily patch an environment variable.""" + old_value = os.getenv(name) + os.environ[name] = value + try: + yield + finally: + if old_value is None: + del os.environ[name] + else: + os.environ[name] = old_value + + def rmtree(path: PathLike) -> None: """Remove the given recursively. @@ -935,7 +949,7 @@ def _obtain_lock_or_raise(self) -> None: ) try: - with open(lock_file, mode='w'): + with open(lock_file, mode="w"): pass except OSError as e: raise IOError(str(e)) from e