1414
1515from dvc .env import DVC_EXP_AUTO_PUSH , DVC_EXP_GIT_REMOTE
1616from dvc .exceptions import DvcException
17- from dvc .log import logger
17+ from dvc .log import logger
1818from dvc .repo .experiments .exceptions import ExperimentExistsError
1919from dvc .repo .experiments .refs import EXEC_BASELINE , EXEC_BRANCH , ExpRefInfo
2020from dvc .repo .experiments .utils import to_studio_params
2121from dvc .repo .metrics .show import _collect_top_level_metrics
2222from dvc .repo .params .show import _collect_top_level_params
23+ from dvc .scm import Git
2324from dvc .stage .serialize import to_lockfile
2425from dvc .utils import dict_sha256 , env2bool , relpath
2526from dvc .utils .fs import remove
3536
3637 from dvc .repo import Repo
3738 from dvc .repo .experiments .stash import ExpStashEntry
38- from dvc .scm import Git
3939 from dvc .stage import PipelineStage , Stage
4040
4141logger = logger .getChild (__name__ )
@@ -298,7 +298,7 @@ def save(
298298 if include_untracked :
299299 dvc .scm .add (include_untracked , force = True ) # type: ignore[call-arg]
300300 cls .commit (
301- dvc . scm , # type: ignore[arg-type]
301+ dvc , # type: ignore[arg-type]
302302 exp_hash ,
303303 exp_name = info .name ,
304304 force = force ,
@@ -460,9 +460,6 @@ def reproduce(
460460 from dvc .repo .checkout import checkout as dvc_checkout
461461 from dvc .ui import ui
462462
463- auto_push = env2bool (DVC_EXP_AUTO_PUSH )
464- git_remote = os .getenv (DVC_EXP_GIT_REMOTE , None )
465-
466463 if queue is not None :
467464 queue .put ((rev , os .getpid ()))
468465 if log_errors and log_level is not None :
@@ -483,9 +480,6 @@ def reproduce(
483480 message = message ,
484481 ** kwargs ,
485482 ) as dvc :
486- if auto_push :
487- cls ._validate_remotes (dvc , git_remote )
488-
489483 args , kwargs = cls ._repro_args (dvc )
490484 if args :
491485 targets : Optional [Union [list , str ]] = args [0 ]
@@ -519,8 +513,6 @@ def reproduce(
519513 dvc ,
520514 info ,
521515 exp_hash ,
522- auto_push ,
523- git_remote ,
524516 repro_force ,
525517 message = message ,
526518 )
@@ -550,20 +542,17 @@ def _repro_commit(
550542 dvc ,
551543 info ,
552544 exp_hash ,
553- auto_push ,
554- git_remote ,
555545 repro_force ,
556546 message : Optional [str ] = None ,
557547 ) -> tuple [Optional [str ], Optional ["ExpRefInfo" ], bool ]:
558548 cls .commit (
559- dvc . scm ,
549+ dvc ,
560550 exp_hash ,
561551 exp_name = info .name ,
562552 force = repro_force ,
563553 message = message ,
564554 )
565- if auto_push :
566- cls ._auto_push (dvc , dvc .scm , git_remote )
555+
567556 ref : Optional [str ] = dvc .scm .get_ref (EXEC_BRANCH , follow = False )
568557 exp_ref : Optional ["ExpRefInfo" ] = ExpRefInfo .from_ref (ref ) if ref else None
569558 if cls .WARN_UNTRACKED :
@@ -675,12 +664,11 @@ def _repro_args(cls, dvc):
675664 @staticmethod
676665 def _auto_push (
677666 dvc : "Repo" ,
678- scm : "Git" ,
679667 git_remote : Optional [str ],
680668 push_cache = True ,
681669 run_cache = True ,
682670 ):
683- branch = scm .get_ref (EXEC_BRANCH , follow = False )
671+ branch = dvc . scm .get_ref (EXEC_BRANCH , follow = False )
684672 try :
685673 dvc .experiments .push (
686674 git_remote ,
@@ -701,13 +689,30 @@ def _auto_push(
701689 @classmethod
702690 def commit (
703691 cls ,
704- scm : "Git " ,
692+ dvc : "Repo " ,
705693 exp_hash : str ,
706694 exp_name : Optional [str ] = None ,
707695 force : bool = False ,
708696 message : Optional [str ] = None ,
709697 ):
710- """Commit stages as an experiment and return the commit SHA."""
698+ """
699+ Commit stages as an experiment and return the commit SHA.
700+ Push the experiment if env `DVC_EXP_AUTO_PUSH` is True.
701+ """
702+
703+ if not isinstance (dvc .scm , Git ):
704+ raise DvcException ("Only Git supported for experiment commits" )
705+
706+ scm : Git = dvc .scm
707+
708+ exp_config = dvc .config .get ("exp" , {})
709+ auto_push = env2bool (DVC_EXP_AUTO_PUSH , exp_config .get ("auto_push" , False ))
710+ git_remote = os .getenv (
711+ DVC_EXP_GIT_REMOTE , exp_config .get ("git_remote" , "origin" )
712+ )
713+ if auto_push :
714+ cls ._validate_remotes (dvc , git_remote )
715+
711716 rev = scm .get_rev ()
712717 if not scm .is_dirty (untracked_files = False ):
713718 logger .debug ("No changes to commit" )
@@ -744,6 +749,9 @@ def commit(
744749 scm .set_ref (branch , new_rev , old_ref = old_ref )
745750 scm .set_ref (EXEC_BRANCH , branch , symbolic = True )
746751
752+ if auto_push :
753+ cls ._auto_push (dvc , git_remote )
754+
747755 return new_rev
748756
749757 @staticmethod
0 commit comments