diff --git a/dvc/command/update.py b/dvc/command/update.py index 47df68dd6c..73f4d097d6 100644 --- a/dvc/command/update.py +++ b/dvc/command/update.py @@ -14,7 +14,7 @@ def run(self): ret = 0 for target in self.args.targets: try: - self.repo.update(target) + self.repo.update(target, self.args.rev) except DvcException: logger.exception("failed to update '{}'.".format(target)) ret = 1 @@ -33,4 +33,7 @@ def add_parser(subparsers, parent_parser): update_parser.add_argument( "targets", nargs="+", help="DVC-files to update." ) + update_parser.add_argument( + "--rev", nargs="?", help="Git revision (e.g. SHA, branch, tag)" + ) update_parser.set_defaults(func=CmdUpdate) diff --git a/dvc/dependency/base.py b/dvc/dependency/base.py index 270fdbba71..08fce73a9f 100644 --- a/dvc/dependency/base.py +++ b/dvc/dependency/base.py @@ -27,5 +27,5 @@ class DependencyBase(object): IsNotFileOrDirError = DependencyIsNotFileOrDirError IsStageFileError = DependencyIsStageFileError - def update(self): + def update(self, rev=None): pass diff --git a/dvc/dependency/repo.py b/dvc/dependency/repo.py index 26d58d609d..53edac176b 100644 --- a/dvc/dependency/repo.py +++ b/dvc/dependency/repo.py @@ -80,6 +80,9 @@ def download(self, to): repo.pull_to(self.def_path, to.path_info) - def update(self): + def update(self, rev=None): + if rev: + self.def_repo[self.PARAM_REV] = rev + with self._make_repo(locked=False) as repo: self.def_repo[self.PARAM_REV_LOCK] = repo.scm.get_rev() diff --git a/dvc/repo/update.py b/dvc/repo/update.py index 9533bb6d36..44252d0057 100644 --- a/dvc/repo/update.py +++ b/dvc/repo/update.py @@ -2,10 +2,12 @@ @locked -def update(self, target): +def update(self, target, rev=None): from dvc.stage import Stage stage = Stage.load(self, target) - stage.update() + stage.update(rev) stage.dump() + + return stage diff --git a/dvc/stage.py b/dvc/stage.py index 8c9d79b8ad..79e52add9e 100644 --- a/dvc/stage.py +++ b/dvc/stage.py @@ -405,11 +405,11 @@ def reproduce(self, interactive=False, **kwargs): return self - def update(self): + def update(self, rev=None): if not self.is_repo_import and not self.is_import: raise StageUpdateError(self.relpath) - self.deps[0].update() + self.deps[0].update(rev=rev) locked = self.locked self.locked = False try: diff --git a/scripts/completion/dvc.bash b/scripts/completion/dvc.bash index f84efcf832..0727aa8ced 100644 --- a/scripts/completion/dvc.bash +++ b/scripts/completion/dvc.bash @@ -58,7 +58,7 @@ _dvc_run='--no-exec -f --file -c --cwd -d --deps -o --outs -O --outs-no-cache -- _dvc_status='-j --jobs -r --remote -a --all-branches -T --all-tags -d --with-deps -c --cloud $(compgen -G *.dvc)' _dvc_unlock='$(compgen -G *.dvc)' _dvc_unprotect='$(compgen -G *)' -_dvc_update='$(compgen -G *.dvc)' +_dvc_update='--rev $(compgen -G *.dvc)' _dvc_version='' # Params diff --git a/scripts/completion/dvc.zsh b/scripts/completion/dvc.zsh index 7ee7386b29..4549306ac9 100644 --- a/scripts/completion/dvc.zsh +++ b/scripts/completion/dvc.zsh @@ -273,6 +273,7 @@ _dvc_unprotect=( ) _dvc_update=( + "--rev[Git revision (e.g. SHA, branch, tag)]:Revision:" "*:Stages:_files -g '(*.dvc|Dvcfile)'" ) diff --git a/tests/func/test_update.py b/tests/func/test_update.py index 54a5bacbf5..ca69a04309 100644 --- a/tests/func/test_update.py +++ b/tests/func/test_update.py @@ -2,7 +2,7 @@ import os from dvc.stage import Stage -from dvc.compat import fspath +from dvc.compat import fspath, fspath_py35 @pytest.mark.parametrize("cached", [True, False]) @@ -131,3 +131,37 @@ def test_update_import_url(tmp_dir, dvc, tmp_path_factory): assert dst.is_file() assert dst.read_text() == "updated file content" + + +def test_update_rev(tmp_dir, dvc, scm, git_dir): + with git_dir.chdir(): + git_dir.scm_gen({"foo": "foo"}, commit="first") + + dvc.imp(fspath(git_dir), "foo") + assert (tmp_dir / "foo.dvc").exists() + + with git_dir.chdir(), git_dir.branch("branch1", new=True): + git_dir.scm_gen({"foo": "foobar"}, commit="branch1 commit") + branch1_head = git_dir.scm.get_rev() + + with git_dir.chdir(), git_dir.branch("branch2", new=True): + git_dir.scm_gen({"foo": "foobar foo"}, commit="branch2 commit") + branch2_head = git_dir.scm.get_rev() + + stage = dvc.update("foo.dvc", rev="branch1") + assert stage.deps[0].def_repo == { + "url": fspath(git_dir), + "rev": "branch1", + "rev_lock": branch1_head, + } + with open(fspath_py35(tmp_dir / "foo")) as f: + assert "foobar" == f.read() + + stage = dvc.update("foo.dvc", rev="branch2") + assert stage.deps[0].def_repo == { + "url": fspath(git_dir), + "rev": "branch2", + "rev_lock": branch2_head, + } + with open(fspath_py35(tmp_dir / "foo")) as f: + assert "foobar foo" == f.read() diff --git a/tests/unit/command/test_update.py b/tests/unit/command/test_update.py index 677b24a5ab..468c731368 100644 --- a/tests/unit/command/test_update.py +++ b/tests/unit/command/test_update.py @@ -1,15 +1,19 @@ +import pytest from dvc.cli import parse_args from dvc.command.update import CmdUpdate -def test_update(dvc, mocker): +@pytest.mark.parametrize( + "command,rev", [(["update"], None), (["update", "--rev", "REV"], "REV")] +) +def test_update(dvc, mocker, command, rev): targets = ["target1", "target2", "target3"] - cli_args = parse_args(["update"] + targets) + cli_args = parse_args(command + targets) assert cli_args.func == CmdUpdate cmd = cli_args.func(cli_args) m = mocker.patch("dvc.repo.Repo.update") assert cmd.run() == 0 - calls = [mocker.call(target) for target in targets] + calls = [mocker.call(target, rev) for target in targets] m.assert_has_calls(calls)