Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 61 additions & 40 deletions dvc/fs/dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,55 +131,53 @@ def __init__( # noqa: PLR0913
... rev="main",
... )
"""
from pygtrie import Trie

super().__init__()
self._repo = repo
self._repo_factory = repo_factory
self._traverse_subrepos = subrepos
self._repo_stack = ExitStack()
if repo is None:
url = url if url is not None else fo
repo = self._make_repo(
url=url,
rev=rev,
subrepos=subrepos,
config=config,
remote=remote,
remote_config=remote_config,
)
assert repo is not None
self._repo_kwargs = {
"url": url if url is not None else fo,
"rev": rev,
"subrepos": subrepos,
"config": config,
"remote": remote,
"remote_config": remote_config,
}

repo_factory = repo._fs_conf["repo_factory"]
self._repo_stack.enter_context(repo)
self.path = Path(self.sep, getcwd=self._getcwd)

if not repo_factory:
from dvc.repo import Repo
@functools.cached_property
def repo(self):
if self._repo:
return self._repo

self.repo_factory: RepoFactory = Repo
else:
self.repo_factory = repo_factory

def _getcwd():
relparts: Tuple[str, ...] = ()
assert repo is not None
if repo.fs.path.isin(repo.fs.path.getcwd(), repo.root_dir):
relparts = repo.fs.path.relparts(repo.fs.path.getcwd(), repo.root_dir)
return self.root_marker + self.sep.join(relparts)

self.path = Path(self.sep, getcwd=_getcwd)
self.repo = repo
self.hash_jobs = repo.fs.hash_jobs
self._traverse_subrepos = subrepos
repo = self._make_repo(**self._repo_kwargs)

self._subrepos_trie = Trie()
"""Keeps track of each and every path with the corresponding repo."""
self._repo_stack.enter_context(repo)
self._repo = repo
return repo

key = self._get_key(self.repo.root_dir)
self._subrepos_trie[key] = repo
@functools.cached_property
def repo_factory(self):
if self._repo_factory:
return self._repo_factory

self._datafss = {}
"""Keep a datafs instance of each repo."""
if self._repo:
from dvc.repo import Repo

return Repo

if hasattr(repo, "dvc_dir"):
self._datafss[key] = DataFileSystem(index=repo.index.data["repo"])
return self.repo._fs_conf["repo_factory"]

def _getcwd(self):
relparts: Tuple[str, ...] = ()
assert self.repo is not None
if self.repo.fs.path.isin(self.repo.fs.path.getcwd(), self.repo.root_dir):
relparts = self.repo.fs.path.relparts(
self.repo.fs.path.getcwd(), self.repo.root_dir
)
return self.root_marker + self.sep.join(relparts)

@functools.cached_property
def fsid(self) -> str:
Expand All @@ -199,6 +197,17 @@ def _get_key(self, path: "StrPath") -> Key:
return ()
return parts

@functools.cached_property
def _subrepos_trie(self):
"""Keeps track of each and every path with the corresponding repo."""

from pygtrie import Trie

trie = Trie()
key = self._get_key(self.repo.root_dir)
trie[key] = self.repo
return trie

def _get_key_from_relative(self, path) -> Key:
path = self._strip_protocol(path)
parts = self.path.relparts(path, self.root_marker)
Expand All @@ -209,6 +218,18 @@ def _get_key_from_relative(self, path) -> Key:
def _from_key(self, parts: Key) -> str:
return self.repo.fs.path.join(self.repo.root_dir, *parts)

@functools.cached_property
def _datafss(self):
"""Keep a datafs instance of each repo."""

datafss = {}

if hasattr(self.repo, "dvc_dir"):
key = self._get_key(self.repo.root_dir)
datafss[key] = DataFileSystem(index=self.repo.index.data["repo"])

return datafss

@property
def repo_url(self):
return self.repo.url
Expand Down
17 changes: 9 additions & 8 deletions tests/func/repro/test_repro.py
Original file line number Diff line number Diff line change
Expand Up @@ -1331,10 +1331,10 @@ def test_repro_keep_going(mocker, tmp_dir, dvc, copy_script):
stage1, upstream=[bar_stage], force=False, interactive=False
)
foo_call = mocker.call(foo_stage, upstream=[], force=False, interactive=False)
assert spy.call_args_list in (
[bar_call, stage1_call, foo_call],
[foo_call, bar_call, stage1_call],
)
assert len(spy.call_args_list) == 3
assert foo_call in spy.call_args_list
assert bar_call in spy.call_args_list
assert stage1_call in spy.call_args_list


def test_repro_ignore_errors(mocker, tmp_dir, dvc, copy_script):
Expand Down Expand Up @@ -1362,10 +1362,11 @@ def test_repro_ignore_errors(mocker, tmp_dir, dvc, copy_script):
force=False,
interactive=False,
)
assert spy.call_args_list in (
[bar_call, stage1_call, foo_call, stage2_call],
[foo_call, bar_call, stage1_call, stage2_call],
)
assert len(spy.call_args_list) == 4
assert foo_call in spy.call_args_list
assert bar_call in spy.call_args_list
assert stage1_call in spy.call_args_list
assert stage2_call in spy.call_args_list


@pytest.mark.parametrize("persist", [True, False])
Expand Down