Skip to content

Commit

Permalink
Ensure proper cache initialization before writing
Browse files Browse the repository at this point in the history
Writing cache data is interruptible; this prevents a pathological case
where interrupting a cache write can cause the cache directory to never
be properly initialized with its supporting files.

Unify `Cache.mkdir` with `Cache.set` while I'm here so the former also
properly initializes the cache directory.

Fixes pytest-dev#12167.
  • Loading branch information
tamird committed Mar 31, 2024
1 parent 12e061e commit 98cfcce
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 27 deletions.
1 change: 1 addition & 0 deletions changelog/12167.trivial.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cache: ensure supporting files (``CACHEDIR.TAG``, ``.gitignore``, etc.) are always created in the cache directory, even in the event of the test session being interrupted.
46 changes: 28 additions & 18 deletions src/_pytest/cacheprovider.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Union

from .pathlib import resolve_from_str
from .pathlib import rm_rf
from .reports import CollectReport
from _pytest import nodes
from _pytest._io import TerminalWriter
from _pytest.compat import assert_never
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config import hookimpl
Expand Down Expand Up @@ -123,6 +125,10 @@ def warn(self, fmt: str, *, _ispytest: bool = False, **args: object) -> None:
stacklevel=3,
)

def _mkdir(self, path: Path) -> None:
self._ensure_cache_dir_and_supporting_files()
path.mkdir(exist_ok=True, parents=True)

def mkdir(self, name: str) -> Path:
"""Return a directory path object with the given name.
Expand All @@ -141,7 +147,7 @@ def mkdir(self, name: str) -> Path:
if len(path.parts) > 1:
raise ValueError("name is not allowed to contain path separators")
res = self._cachedir.joinpath(self._CACHE_PREFIX_DIRS, path)
res.mkdir(exist_ok=True, parents=True)
self._mkdir(res)
return res

def _getvaluepath(self, key: str) -> Path:
Expand Down Expand Up @@ -178,19 +184,13 @@ def set(self, key: str, value: object) -> None:
"""
path = self._getvaluepath(key)
try:
if path.parent.is_dir():
cache_dir_exists_already = True
else:
cache_dir_exists_already = self._cachedir.exists()
path.parent.mkdir(exist_ok=True, parents=True)
self._mkdir(path.parent)
except OSError as exc:
self.warn(
f"could not create cache path {path}: {exc}",
_ispytest=True,
)
return
if not cache_dir_exists_already:
self._ensure_supporting_files()
data = json.dumps(value, ensure_ascii=False, indent=2)
try:
f = path.open("w", encoding="UTF-8")
Expand All @@ -203,17 +203,27 @@ def set(self, key: str, value: object) -> None:
with f:
f.write(data)

def _ensure_supporting_files(self) -> None:
"""Create supporting files in the cache dir that are not really part of the cache."""
readme_path = self._cachedir / "README.md"
readme_path.write_text(README_CONTENT, encoding="UTF-8")

gitignore_path = self._cachedir.joinpath(".gitignore")
msg = "# Created by pytest automatically.\n*\n"
gitignore_path.write_text(msg, encoding="UTF-8")
def _ensure_cache_dir_and_supporting_files(self) -> None:
"""Create the cache dir and its supporting files."""
self._cachedir.mkdir(exist_ok=True, parents=True)

cachedir_tag_path = self._cachedir.joinpath("CACHEDIR.TAG")
cachedir_tag_path.write_bytes(CACHEDIR_TAG_CONTENT)
files: Iterable[Tuple[str, Union[str, bytes]]] = (
("README.md", README_CONTENT),
(".gitignore", "# Created by pytest automatically.\n*\n"),
("CACHEDIR.TAG", CACHEDIR_TAG_CONTENT),
)
for file, content in files:
if isinstance(content, str):
mode = "xt"
elif isinstance(content, bytes):
mode = "xb"
else:
assert_never(content)
try:
with open(self._cachedir.joinpath(file), mode, encoding="UTF-8") as f:
f.write(content)
except FileExistsError:
pass


class LFPluginCollWrapper:
Expand Down
22 changes: 13 additions & 9 deletions testing/test_cacheprovider.py
Original file line number Diff line number Diff line change
Expand Up @@ -1263,12 +1263,7 @@ def test_gitignore(pytester: Pytester) -> None:
cache.set("foo", "bar")
msg = "# Created by pytest automatically.\n*\n"
gitignore_path = cache._cachedir.joinpath(".gitignore")
assert gitignore_path.read_text(encoding="UTF-8") == msg

# Does not overwrite existing/custom one.
gitignore_path.write_text("custom", encoding="utf-8")
cache.set("something", "else")
assert gitignore_path.read_text(encoding="UTF-8") == "custom"
assert gitignore_path.read_text(encoding="utf-8") == msg


def test_preserve_keys_order(pytester: Pytester) -> None:
Expand All @@ -1282,9 +1277,15 @@ def test_preserve_keys_order(pytester: Pytester) -> None:
assert list(read_back.items()) == [("z", 1), ("b", 2), ("a", 3), ("d", 10)]


def test_does_not_create_boilerplate_in_existing_dirs(pytester: Pytester) -> None:
def test_does_not_overwrite_with_boilerplate(pytester: Pytester) -> None:
from _pytest.cacheprovider import Cache

files = ["README.md", ".gitignore"]

for filename in files:
with open(filename, "w", encoding="utf-8") as f:
f.write(filename)

pytester.makeini(
"""
[pytest]
Expand All @@ -1296,8 +1297,11 @@ def test_does_not_create_boilerplate_in_existing_dirs(pytester: Pytester) -> Non
cache.set("foo", "bar")

assert os.path.isdir("v") # cache contents
assert not os.path.exists(".gitignore")
assert not os.path.exists("README.md")

# Original files already existed so left unchanged.
for filename in files:
with open(filename, encoding="utf-8") as f:
assert f.read() == filename


def test_cachedir_tag(pytester: Pytester) -> None:
Expand Down

0 comments on commit 98cfcce

Please sign in to comment.