Skip to content

Commit db5f48d

Browse files
add auto_push for experiment on run and on save.
add config.exp.auto_push and config.exp.git_remote
1 parent 89537a7 commit db5f48d

File tree

2 files changed

+30
-20
lines changed

2 files changed

+30
-20
lines changed

dvc/config_schema.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,8 @@ def __call__(self, data):
336336
"params": str,
337337
"plots": str,
338338
"live": str,
339+
"auto_push": Bool,
340+
"git_remote": str,
339341
},
340342
"parsing": {
341343
"bool": All(Lower, Choices("store_true", "boolean_optional")),

dvc/repo/experiments/executor/base.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414

1515
from dvc.env import DVC_EXP_AUTO_PUSH, DVC_EXP_GIT_REMOTE
1616
from dvc.exceptions import DvcException
17-
from dvc.log import logger
17+
from dvc.log import logger
1818
from dvc.repo.experiments.exceptions import ExperimentExistsError
1919
from dvc.repo.experiments.refs import EXEC_BASELINE, EXEC_BRANCH, ExpRefInfo
2020
from dvc.repo.experiments.utils import to_studio_params
2121
from dvc.repo.metrics.show import _collect_top_level_metrics
2222
from dvc.repo.params.show import _collect_top_level_params
23+
from dvc.scm import Git
2324
from dvc.stage.serialize import to_lockfile
2425
from dvc.utils import dict_sha256, env2bool, relpath
2526
from dvc.utils.fs import remove
@@ -35,7 +36,6 @@
3536

3637
from dvc.repo import Repo
3738
from dvc.repo.experiments.stash import ExpStashEntry
38-
from dvc.scm import Git
3939
from dvc.stage import PipelineStage, Stage
4040

4141
logger = logger.getChild(__name__)
@@ -298,7 +298,7 @@ def save(
298298
if include_untracked:
299299
dvc.scm.add(include_untracked, force=True) # type: ignore[call-arg]
300300
cls.commit(
301-
dvc.scm, # type: ignore[arg-type]
301+
dvc, # type: ignore[arg-type]
302302
exp_hash,
303303
exp_name=info.name,
304304
force=force,
@@ -460,9 +460,6 @@ def reproduce(
460460
from dvc.repo.checkout import checkout as dvc_checkout
461461
from dvc.ui import ui
462462

463-
auto_push = env2bool(DVC_EXP_AUTO_PUSH)
464-
git_remote = os.getenv(DVC_EXP_GIT_REMOTE, None)
465-
466463
if queue is not None:
467464
queue.put((rev, os.getpid()))
468465
if log_errors and log_level is not None:
@@ -483,9 +480,6 @@ def reproduce(
483480
message=message,
484481
**kwargs,
485482
) as dvc:
486-
if auto_push:
487-
cls._validate_remotes(dvc, git_remote)
488-
489483
args, kwargs = cls._repro_args(dvc)
490484
if args:
491485
targets: Optional[Union[list, str]] = args[0]
@@ -519,8 +513,6 @@ def reproduce(
519513
dvc,
520514
info,
521515
exp_hash,
522-
auto_push,
523-
git_remote,
524516
repro_force,
525517
message=message,
526518
)
@@ -550,20 +542,17 @@ def _repro_commit(
550542
dvc,
551543
info,
552544
exp_hash,
553-
auto_push,
554-
git_remote,
555545
repro_force,
556546
message: Optional[str] = None,
557547
) -> tuple[Optional[str], Optional["ExpRefInfo"], bool]:
558548
cls.commit(
559-
dvc.scm,
549+
dvc,
560550
exp_hash,
561551
exp_name=info.name,
562552
force=repro_force,
563553
message=message,
564554
)
565-
if auto_push:
566-
cls._auto_push(dvc, dvc.scm, git_remote)
555+
567556
ref: Optional[str] = dvc.scm.get_ref(EXEC_BRANCH, follow=False)
568557
exp_ref: Optional["ExpRefInfo"] = ExpRefInfo.from_ref(ref) if ref else None
569558
if cls.WARN_UNTRACKED:
@@ -675,12 +664,11 @@ def _repro_args(cls, dvc):
675664
@staticmethod
676665
def _auto_push(
677666
dvc: "Repo",
678-
scm: "Git",
679667
git_remote: Optional[str],
680668
push_cache=True,
681669
run_cache=True,
682670
):
683-
branch = scm.get_ref(EXEC_BRANCH, follow=False)
671+
branch = dvc.scm.get_ref(EXEC_BRANCH, follow=False)
684672
try:
685673
dvc.experiments.push(
686674
git_remote,
@@ -701,13 +689,30 @@ def _auto_push(
701689
@classmethod
702690
def commit(
703691
cls,
704-
scm: "Git",
692+
dvc: "Repo",
705693
exp_hash: str,
706694
exp_name: Optional[str] = None,
707695
force: bool = False,
708696
message: Optional[str] = None,
709697
):
710-
"""Commit stages as an experiment and return the commit SHA."""
698+
"""
699+
Commit stages as an experiment and return the commit SHA.
700+
Push the experiment if env `DVC_EXP_AUTO_PUSH` is True.
701+
"""
702+
703+
if not isinstance(dvc.scm, Git):
704+
raise DvcException("Only Git supported for experiment commits")
705+
706+
scm: Git = dvc.scm
707+
708+
exp_config = dvc.config.get("exp", {})
709+
auto_push = env2bool(DVC_EXP_AUTO_PUSH, exp_config.get("auto_push", False))
710+
git_remote = os.getenv(
711+
DVC_EXP_GIT_REMOTE, exp_config.get("git_remote", "origin")
712+
)
713+
if auto_push:
714+
cls._validate_remotes(dvc, git_remote)
715+
711716
rev = scm.get_rev()
712717
if not scm.is_dirty(untracked_files=False):
713718
logger.debug("No changes to commit")
@@ -744,6 +749,9 @@ def commit(
744749
scm.set_ref(branch, new_rev, old_ref=old_ref)
745750
scm.set_ref(EXEC_BRANCH, branch, symbolic=True)
746751

752+
if auto_push:
753+
cls._auto_push(dvc, git_remote)
754+
747755
return new_rev
748756

749757
@staticmethod

0 commit comments

Comments
 (0)