Skip to content

Commit cd1e228

Browse files
AlexandreKempfpre-commit-ci[bot]skshetry
authored
experiments: auto push experiments (#10323)
* add auto_push for experiment on run and on save. add config.exp.auto_push and config.exp.git_remote * add func tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * separate auto_push behaviour out of commit --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Saugat Pachhai (सौगात) <suagatchhetri@outlook.com>
1 parent 3b2e031 commit cd1e228

File tree

3 files changed

+76
-28
lines changed

3 files changed

+76
-28
lines changed

dvc/config_schema.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,8 @@ def __call__(self, data):
336336
"params": str,
337337
"plots": str,
338338
"live": str,
339+
"auto_push": Bool,
340+
"git_remote": str,
339341
},
340342
"parsing": {
341343
"bool": All(Lower, Choices("store_true", "boolean_optional")),

dvc/repo/experiments/executor/base.py

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -297,13 +297,16 @@ def save(
297297
exp_hash = cls.hash_exp(stages)
298298
if include_untracked:
299299
dvc.scm.add(include_untracked, force=True) # type: ignore[call-arg]
300-
cls.commit(
301-
dvc.scm, # type: ignore[arg-type]
302-
exp_hash,
303-
exp_name=info.name,
304-
force=force,
305-
message=message,
306-
)
300+
301+
with cls.auto_push(dvc):
302+
cls.commit(
303+
dvc.scm, # type: ignore[arg-type]
304+
exp_hash,
305+
exp_name=info.name,
306+
force=force,
307+
message=message,
308+
)
309+
307310
ref: Optional[str] = dvc.scm.get_ref(EXEC_BRANCH, follow=False)
308311
exp_ref = ExpRefInfo.from_ref(ref) if ref else None
309312
untracked = dvc.scm.untracked_files()
@@ -460,9 +463,6 @@ def reproduce(
460463
from dvc.repo.checkout import checkout as dvc_checkout
461464
from dvc.ui import ui
462465

463-
auto_push = env2bool(DVC_EXP_AUTO_PUSH)
464-
git_remote = os.getenv(DVC_EXP_GIT_REMOTE, None)
465-
466466
if queue is not None:
467467
queue.put((rev, os.getpid()))
468468
if log_errors and log_level is not None:
@@ -483,9 +483,6 @@ def reproduce(
483483
message=message,
484484
**kwargs,
485485
) as dvc:
486-
if auto_push:
487-
cls._validate_remotes(dvc, git_remote)
488-
489486
args, kwargs = cls._repro_args(dvc)
490487
if args:
491488
targets: Optional[Union[list, str]] = args[0]
@@ -519,8 +516,6 @@ def reproduce(
519516
dvc,
520517
info,
521518
exp_hash,
522-
auto_push,
523-
git_remote,
524519
repro_force,
525520
message=message,
526521
)
@@ -550,20 +545,18 @@ def _repro_commit(
550545
dvc,
551546
info,
552547
exp_hash,
553-
auto_push,
554-
git_remote,
555548
repro_force,
556549
message: Optional[str] = None,
557550
) -> tuple[Optional[str], Optional["ExpRefInfo"], bool]:
558-
cls.commit(
559-
dvc.scm,
560-
exp_hash,
561-
exp_name=info.name,
562-
force=repro_force,
563-
message=message,
564-
)
565-
if auto_push:
566-
cls._auto_push(dvc, dvc.scm, git_remote)
551+
with cls.auto_push(dvc):
552+
cls.commit(
553+
dvc.scm,
554+
exp_hash,
555+
exp_name=info.name,
556+
force=repro_force,
557+
message=message,
558+
)
559+
567560
ref: Optional[str] = dvc.scm.get_ref(EXEC_BRANCH, follow=False)
568561
exp_ref: Optional["ExpRefInfo"] = ExpRefInfo.from_ref(ref) if ref else None
569562
if cls.WARN_UNTRACKED:
@@ -672,15 +665,30 @@ def _repro_args(cls, dvc):
672665
kwargs = {}
673666
return args, kwargs
674667

668+
@classmethod
669+
@contextmanager
670+
def auto_push(cls, dvc: "Repo") -> Iterator[None]:
671+
exp_config = dvc.config.get("exp", {})
672+
auto_push = env2bool(DVC_EXP_AUTO_PUSH, exp_config.get("auto_push", False))
673+
if not auto_push:
674+
yield
675+
return
676+
677+
git_remote = os.getenv(
678+
DVC_EXP_GIT_REMOTE, exp_config.get("git_remote", "origin")
679+
)
680+
cls._validate_remotes(dvc, git_remote)
681+
yield
682+
cls._auto_push(dvc, git_remote)
683+
675684
@staticmethod
676685
def _auto_push(
677686
dvc: "Repo",
678-
scm: "Git",
679687
git_remote: Optional[str],
680688
push_cache=True,
681689
run_cache=True,
682690
):
683-
branch = scm.get_ref(EXEC_BRANCH, follow=False)
691+
branch = dvc.scm.get_ref(EXEC_BRANCH, follow=False)
684692
try:
685693
dvc.experiments.push(
686694
git_remote,
@@ -708,6 +716,7 @@ def commit(
708716
message: Optional[str] = None,
709717
):
710718
"""Commit stages as an experiment and return the commit SHA."""
719+
711720
rev = scm.get_rev()
712721
if not scm.is_dirty(untracked_files=False):
713722
logger.debug("No changes to commit")

tests/func/experiments/test_remote.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,3 +372,40 @@ def test_push_pull_invalid_workspace(
372372
dvc.experiments.push(git_upstream.remote, push_cache=True)
373373
dvc.experiments.pull(git_upstream.remote, pull_cache=True)
374374
assert "failed to collect" not in caplog.text
375+
376+
377+
@pytest.mark.parametrize(
378+
"auto_push, expected_key", [(True, "up_to_date"), (False, "success")]
379+
)
380+
def test_auto_push_on_run(
381+
tmp_dir, scm, dvc, git_upstream, local_remote, exp_stage, auto_push, expected_key
382+
):
383+
remote = git_upstream.remote
384+
385+
with dvc.config.edit() as conf:
386+
conf["exp"]["auto_push"] = auto_push
387+
conf["exp"]["git_remote"] = remote
388+
389+
exp_name = "foo"
390+
dvc.experiments.run(exp_stage.addressing, params=["foo=2"], name=exp_name)
391+
392+
assert first(dvc.experiments.push(name=exp_name, git_remote=remote)) == expected_key
393+
394+
395+
@pytest.mark.parametrize(
396+
"auto_push, expected_key", [(True, "up_to_date"), (False, "success")]
397+
)
398+
def test_auto_push_on_save(
399+
tmp_dir, scm, dvc, git_upstream, local_remote, exp_stage, auto_push, expected_key
400+
):
401+
remote = git_upstream.remote
402+
exp_name = "foo"
403+
dvc.experiments.run(exp_stage.addressing, params=["foo=2"], name=exp_name)
404+
405+
with dvc.config.edit() as conf:
406+
conf["exp"]["auto_push"] = auto_push
407+
conf["exp"]["git_remote"] = remote
408+
409+
dvc.experiments.save(name=exp_name, force=True)
410+
411+
assert first(dvc.experiments.push(name=exp_name, git_remote=remote)) == expected_key

0 commit comments

Comments
 (0)