diff --git a/dvc/external_repo.py b/dvc/external_repo.py index f09aae71cc..dc4c108e84 100644 --- a/dvc/external_repo.py +++ b/dvc/external_repo.py @@ -15,15 +15,16 @@ from dvc.exceptions import FileMissingError, PathMissingError from dvc.remote import RemoteConfig from dvc.utils.fs import remove, fs_copy -from dvc.scm import SCM +from dvc.scm.git import Git logger = logging.getLogger(__name__) @contextmanager -def external_repo(url, rev=None): - path = _cached_clone(url, rev) +def external_repo(url, rev=None, for_write=False): + logger.debug("Creating external repo {}@{}", url, rev) + path = _cached_clone(url, rev, for_write=for_write) try: repo = ExternalRepo(path, url) except NotDvcRepoError: @@ -41,7 +42,8 @@ def external_repo(url, rev=None): raise PathMissingError(exc.path, url) finally: repo.close() - _remove(path) + if for_write: + _remove(path) CLONES = {} @@ -129,7 +131,7 @@ def __init__(self, root_dir, url): @cached_property def scm(self): - return SCM(self.root_dir) + return Git(self.root_dir) def close(self): if "scm" in self.__dict__: @@ -158,23 +160,21 @@ def open_by_relpath(self, path, mode="r", encoding=None, **kwargs): raise PathMissingError(path, self.url) -@wrap_with(threading.Lock()) -def _cached_clone(url, rev): +def _cached_clone(url, rev, for_write=False): """Clone an external git repo to a temporary directory. Returns the path to a local temporary directory with the specified - revision checked out. + revision checked out. If for_write is set prevents reusing this dir via + cache. """ - # Get or create a clean clone - clone_path = CLONES.get(url) - if clone_path: - # Do not pull for known shas, branches and tags might move - if not _is_known_sha(clone_path, rev): - _git_pull(clone_path) - else: - clone_path = tempfile.mkdtemp("dvc-clone") - _git_clone(url, clone_path) - CLONES[url] = clone_path + if not for_write and Git.is_sha(rev) and (url, rev) in CLONES: + return CLONES[url, rev] + + clone_path = _clone_default_branch(url, rev) + rev_sha = Git(clone_path).resolve_rev(rev or "HEAD") + + if not for_write and (url, rev_sha) in CLONES: + return CLONES[url, rev_sha] # Copy to a new dir to keep the clone clean repo_path = tempfile.mkdtemp("dvc-erepo") @@ -184,49 +184,43 @@ def _cached_clone(url, rev): if rev is not None: _git_checkout(repo_path, rev) + if not for_write: + CLONES[url, rev_sha] = repo_path return repo_path -def _git_clone(url, path): - from dvc.scm.git import Git - - git = Git.clone(url, path) - git.close() - +@wrap_with(threading.Lock()) +def _clone_default_branch(url, rev): + """Get or create a clean clone of the url. -def _git_checkout(repo_path, rev): - from dvc.scm import Git + The cloned is reactualized with git pull unless rev is a known sha. + """ + clone_path = CLONES.get(url) - git = Git(repo_path) + git = None try: - git.checkout(rev) + if clone_path: + git = Git(clone_path) + # Do not pull for known shas, branches and tags might move + if not Git.is_sha(rev) or not git.is_known(rev): + git.pull() + else: + clone_path = tempfile.mkdtemp("dvc-clone") + git = Git.clone(url, clone_path) + CLONES[url] = clone_path finally: - git.close() - - -def _is_known_sha(repo_path, rev): - import git - from gitdb.exc import BadName - - if not rev or not git.Repo.re_hexsha_shortened.search(rev): - return False - - try: - git.Repo(repo_path).commit(rev) - return True - except BadName: - return False + if git: + git.close() + return clone_path -def _git_pull(repo_path): - import git - repo = git.Repo(repo_path) +def _git_checkout(repo_path, rev): + git = Git(repo_path) try: - msg = repo.git.pull() - logger.debug("external repo: git pull: {}", msg) + git.checkout(rev) finally: - repo.close() + git.close() def _remove(path): diff --git a/dvc/scm/git/__init__.py b/dvc/scm/git/__init__.py index ac63854ffa..a2b9c2569b 100644 --- a/dvc/scm/git/__init__.py +++ b/dvc/scm/git/__init__.py @@ -8,14 +8,9 @@ from dvc.exceptions import GitHookAlreadyExistsError from dvc.scm.base import Base -from dvc.scm.base import CloneError -from dvc.scm.base import FileNotInRepoError -from dvc.scm.base import RevError -from dvc.scm.base import SCMError +from dvc.scm.base import CloneError, FileNotInRepoError, RevError, SCMError from dvc.scm.git.tree import GitTree -from dvc.utils import fix_env -from dvc.utils import is_binary -from dvc.utils import relpath +from dvc.utils import fix_env, is_binary, relpath from dvc.utils.fs import path_isin logger = logging.getLogger(__name__) @@ -94,6 +89,12 @@ def clone(url, to_path, rev=None): return repo + @staticmethod + def is_sha(rev): + import git + + return rev and git.Repo.re_hexsha_shortened.search(rev) + @staticmethod def is_repo(root_dir): return os.path.isdir(Git._get_git_dir(root_dir)) @@ -211,14 +212,16 @@ def checkout(self, branch, create_new=False): self.repo.git.checkout(branch) def pull(self): - info, = self.repo.remote().pull() - if info.flags & info.ERROR: - raise SCMError("pull failed: {}".format(info.note)) + infos = self.repo.remote().pull() + for info in infos: + if info.flags & info.ERROR: + raise SCMError("pull failed: {}".format(info.note)) def push(self): - info, = self.repo.remote().push() - if info.flags & info.ERROR: - raise SCMError("push failed: {}".format(info.summary)) + infos = self.repo.remote().push() + for info in infos: + if info.flags & info.ERROR: + raise SCMError("push failed: {}".format(info.summary)) def branch(self, branch): self.repo.git.branch(branch) @@ -324,15 +327,45 @@ def get_tree(self, rev): return GitTree(self.repo, self.resolve_rev(rev)) def get_rev(self): - return self.repo.git.rev_parse("HEAD") + return self.repo.rev_parse("HEAD").hexsha def resolve_rev(self, rev): - from git.exc import GitCommandError - + from git.exc import BadName, GitCommandError + from contextlib import suppress + + def _resolve_rev(name): + with suppress(BadName, GitCommandError): + try: + # Try python implementation of rev-parse first, it's faster + return self.repo.rev_parse(name).hexsha + except NotImplementedError: + # Fall back to `git rev-parse` for advanced features + return self.repo.git.rev_parse(name) + + # Resolve across local names + sha = _resolve_rev(rev) + if sha: + return sha + + # Try all the remotes and if it resolves unambiguously then take it + if not Git.is_sha(rev): + shas = { + _resolve_rev("{}/{}".format(remote.name, rev)) + for remote in self.repo.remotes + } - {None} + if len(shas) > 1: + raise RevError("ambiguous Git revision '{}'".format(rev)) + if len(shas) == 1: + return shas.pop() + + raise RevError("unknown Git revision '{}'".format(rev)) + + def has_rev(self, rev): try: - return self.repo.git.rev_parse(rev) - except GitCommandError: - raise RevError("unknown Git revision '{}'".format(rev)) + self.resolve_rev(rev) + return True + except RevError: + return False def close(self): self.repo.close()