Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dvc/config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,8 @@ def __call__(self, data):
"params": str,
"plots": str,
"live": str,
"auto_push": Bool,
"git_remote": str,
},
"parsing": {
"bool": All(Lower, Choices("store_true", "boolean_optional")),
Expand Down
65 changes: 37 additions & 28 deletions dvc/repo/experiments/executor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,13 +297,16 @@ def save(
exp_hash = cls.hash_exp(stages)
if include_untracked:
dvc.scm.add(include_untracked, force=True) # type: ignore[call-arg]
cls.commit(
dvc.scm, # type: ignore[arg-type]
exp_hash,
exp_name=info.name,
force=force,
message=message,
)

with cls.auto_push(dvc):
cls.commit(
dvc.scm, # type: ignore[arg-type]
exp_hash,
exp_name=info.name,
force=force,
message=message,
)

ref: Optional[str] = dvc.scm.get_ref(EXEC_BRANCH, follow=False)
exp_ref = ExpRefInfo.from_ref(ref) if ref else None
untracked = dvc.scm.untracked_files()
Expand Down Expand Up @@ -460,9 +463,6 @@ def reproduce(
from dvc.repo.checkout import checkout as dvc_checkout
from dvc.ui import ui

auto_push = env2bool(DVC_EXP_AUTO_PUSH)
git_remote = os.getenv(DVC_EXP_GIT_REMOTE, None)

if queue is not None:
queue.put((rev, os.getpid()))
if log_errors and log_level is not None:
Expand All @@ -483,9 +483,6 @@ def reproduce(
message=message,
**kwargs,
) as dvc:
if auto_push:
cls._validate_remotes(dvc, git_remote)

args, kwargs = cls._repro_args(dvc)
if args:
targets: Optional[Union[list, str]] = args[0]
Expand Down Expand Up @@ -519,8 +516,6 @@ def reproduce(
dvc,
info,
exp_hash,
auto_push,
git_remote,
repro_force,
message=message,
)
Expand Down Expand Up @@ -550,20 +545,18 @@ def _repro_commit(
dvc,
info,
exp_hash,
auto_push,
git_remote,
repro_force,
message: Optional[str] = None,
) -> tuple[Optional[str], Optional["ExpRefInfo"], bool]:
cls.commit(
dvc.scm,
exp_hash,
exp_name=info.name,
force=repro_force,
message=message,
)
if auto_push:
cls._auto_push(dvc, dvc.scm, git_remote)
with cls.auto_push(dvc):
cls.commit(
dvc.scm,
exp_hash,
exp_name=info.name,
force=repro_force,
message=message,
)

ref: Optional[str] = dvc.scm.get_ref(EXEC_BRANCH, follow=False)
exp_ref: Optional["ExpRefInfo"] = ExpRefInfo.from_ref(ref) if ref else None
if cls.WARN_UNTRACKED:
Expand Down Expand Up @@ -672,15 +665,30 @@ def _repro_args(cls, dvc):
kwargs = {}
return args, kwargs

@classmethod
@contextmanager
def auto_push(cls, dvc: "Repo") -> Iterator[None]:
exp_config = dvc.config.get("exp", {})
auto_push = env2bool(DVC_EXP_AUTO_PUSH, exp_config.get("auto_push", False))
if not auto_push:
yield
return

git_remote = os.getenv(
DVC_EXP_GIT_REMOTE, exp_config.get("git_remote", "origin")
)
cls._validate_remotes(dvc, git_remote)
yield
cls._auto_push(dvc, git_remote)

@staticmethod
def _auto_push(
dvc: "Repo",
scm: "Git",
git_remote: Optional[str],
push_cache=True,
run_cache=True,
):
branch = scm.get_ref(EXEC_BRANCH, follow=False)
branch = dvc.scm.get_ref(EXEC_BRANCH, follow=False)
try:
dvc.experiments.push(
git_remote,
Expand Down Expand Up @@ -708,6 +716,7 @@ def commit(
message: Optional[str] = None,
):
"""Commit stages as an experiment and return the commit SHA."""

rev = scm.get_rev()
if not scm.is_dirty(untracked_files=False):
logger.debug("No changes to commit")
Expand Down
37 changes: 37 additions & 0 deletions tests/func/experiments/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,3 +372,40 @@ def test_push_pull_invalid_workspace(
dvc.experiments.push(git_upstream.remote, push_cache=True)
dvc.experiments.pull(git_upstream.remote, pull_cache=True)
assert "failed to collect" not in caplog.text


@pytest.mark.parametrize(
"auto_push, expected_key", [(True, "up_to_date"), (False, "success")]
)
def test_auto_push_on_run(
tmp_dir, scm, dvc, git_upstream, local_remote, exp_stage, auto_push, expected_key
):
remote = git_upstream.remote

with dvc.config.edit() as conf:
conf["exp"]["auto_push"] = auto_push
conf["exp"]["git_remote"] = remote

exp_name = "foo"
dvc.experiments.run(exp_stage.addressing, params=["foo=2"], name=exp_name)

assert first(dvc.experiments.push(name=exp_name, git_remote=remote)) == expected_key


@pytest.mark.parametrize(
"auto_push, expected_key", [(True, "up_to_date"), (False, "success")]
)
def test_auto_push_on_save(
tmp_dir, scm, dvc, git_upstream, local_remote, exp_stage, auto_push, expected_key
):
remote = git_upstream.remote
exp_name = "foo"
dvc.experiments.run(exp_stage.addressing, params=["foo=2"], name=exp_name)

with dvc.config.edit() as conf:
conf["exp"]["auto_push"] = auto_push
conf["exp"]["git_remote"] = remote

dvc.experiments.save(name=exp_name, force=True)

assert first(dvc.experiments.push(name=exp_name, git_remote=remote)) == expected_key