Skip to content

Rolling deployments for repo updates #2853

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions runner/internal/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,14 @@ func (ex *RunExecutor) SetRunnerState(state string) {
ex.state = state
}

// getRepoData returns the repo data to use for the job.
// The job spec's own RepoData takes precedence; when it is nil the value
// stored in the run spec is returned instead.
func (ex *RunExecutor) getRepoData() schemas.RepoData {
	if repoData := ex.jobSpec.RepoData; repoData != nil {
		return *repoData
	}
	// jobs submitted before 0.19.17 do not have jobSpec.RepoData
	return ex.run.RunSpec.RepoData
}

func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error {
node_rank := ex.jobSpec.JobNum
nodes_num := ex.jobSpec.JobsPerReplica
Expand Down
5 changes: 3 additions & 2 deletions runner/internal/executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func TestExecutor_RemoteRepo(t *testing.T) {

var b bytes.Buffer
ex := makeTestExecutor(t)
ex.run.RunSpec.RepoData = schemas.RepoData{
ex.jobSpec.RepoData = &schemas.RepoData{
RepoType: "remote",
RepoBranch: "main",
RepoHash: "2b83592e506ed6fe8e49f4eaa97c3866bc9402b1",
Expand All @@ -148,7 +148,7 @@ func TestExecutor_RemoteRepo(t *testing.T) {

err = ex.execJob(context.TODO(), io.Writer(&b))
assert.NoError(t, err)
expected := fmt.Sprintf("%s\r\n%s\r\n%s\r\n", ex.run.RunSpec.RepoData.RepoHash, ex.run.RunSpec.RepoData.RepoConfigName, ex.run.RunSpec.RepoData.RepoConfigEmail)
expected := fmt.Sprintf("%s\r\n%s\r\n%s\r\n", ex.getRepoData().RepoHash, ex.getRepoData().RepoConfigName, ex.getRepoData().RepoConfigEmail)
assert.Equal(t, expected, b.String())
}

Expand Down Expand Up @@ -178,6 +178,7 @@ func makeTestExecutor(t *testing.T) *RunExecutor {
Env: make(map[string]string),
MaxDuration: 0, // no timeout
WorkingDir: &workingDir,
RepoData: &schemas.RepoData{RepoType: "local"},
},
Secrets: make(map[string]string),
RepoCredentials: &schemas.RepoCredentials{
Expand Down
10 changes: 5 additions & 5 deletions runner/internal/executor/repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (ex *RunExecutor) setupRepo(ctx context.Context) error {
err = gerrors.Wrap(err_)
}
}()
switch ex.run.RunSpec.RepoData.RepoType {
switch ex.getRepoData().RepoType {
case "remote":
log.Trace(ctx, "Fetching git repository")
if err := ex.prepareGit(ctx); err != nil {
Expand All @@ -52,7 +52,7 @@ func (ex *RunExecutor) setupRepo(ctx context.Context) error {
return gerrors.Wrap(err)
}
default:
return gerrors.Newf("unknown RepoType: %s", ex.run.RunSpec.RepoData.RepoType)
return gerrors.Newf("unknown RepoType: %s", ex.getRepoData().RepoType)
}
return err
}
Expand All @@ -61,8 +61,8 @@ func (ex *RunExecutor) prepareGit(ctx context.Context) error {
repoManager := repo.NewManager(
ctx,
ex.repoCredentials.CloneURL,
ex.run.RunSpec.RepoData.RepoBranch,
ex.run.RunSpec.RepoData.RepoHash,
ex.getRepoData().RepoBranch,
ex.getRepoData().RepoHash,
ex.jobSpec.SingleBranch,
).WithLocalPath(ex.workingDir)
if ex.repoCredentials != nil {
Expand Down Expand Up @@ -92,7 +92,7 @@ func (ex *RunExecutor) prepareGit(ctx context.Context) error {
if err := repoManager.Checkout(); err != nil {
return gerrors.Wrap(err)
}
if err := repoManager.SetConfig(ex.run.RunSpec.RepoData.RepoConfigName, ex.run.RunSpec.RepoData.RepoConfigEmail); err != nil {
if err := repoManager.SetConfig(ex.getRepoData().RepoConfigName, ex.getRepoData().RepoConfigEmail); err != nil {
return gerrors.Wrap(err)
}

Expand Down
4 changes: 4 additions & 0 deletions runner/internal/schemas/schemas.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ type JobSpec struct {
MaxDuration int `json:"max_duration"`
SSHKey *SSHKey `json:"ssh_key"`
WorkingDir *string `json:"working_dir"`
// `RepoData` is optional for compatibility with jobs submitted before 0.19.17.
// Use `RunExecutor.getRepoData()` to get non-nil `RepoData`.
// TODO: make required when supporting jobs submitted before 0.19.17 is no longer relevant.
RepoData *RepoData `json:"repo_data"`
}

type ClusterInfo struct {
Expand Down
71 changes: 56 additions & 15 deletions src/dstack/_internal/cli/services/configurators/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,13 @@
)
from dstack._internal.core.models.repos.base import Repo
from dstack._internal.core.models.resources import CPUSpec
from dstack._internal.core.models.runs import JobStatus, JobSubmission, RunStatus
from dstack._internal.core.models.runs import JobStatus, JobSubmission, RunSpec, RunStatus
from dstack._internal.core.services.configs import ConfigManager
from dstack._internal.core.services.diff import diff_models
from dstack._internal.utils.common import local_time
from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator
from dstack._internal.utils.logging import get_logger
from dstack._internal.utils.nested_list import NestedList, NestedListItem
from dstack.api._public.repos import get_ssh_keypair
from dstack.api._public.runs import Run
from dstack.api.utils import load_profile
Expand Down Expand Up @@ -102,25 +103,20 @@ def apply_configuration(
confirm_message = f"Submit the run [code]{conf.name}[/]?"
stop_run_name = None
if run_plan.current_resource is not None:
changed_fields = []
if run_plan.action == ApplyAction.UPDATE:
diff = diff_models(
run_plan.get_effective_run_spec().configuration,
run_plan.current_resource.run_spec.configuration,
)
changed_fields = list(diff.keys())
if run_plan.action == ApplyAction.UPDATE and len(changed_fields) > 0:
diff = render_run_spec_diff(
run_plan.get_effective_run_spec(),
run_plan.current_resource.run_spec,
)
if run_plan.action == ApplyAction.UPDATE and diff is not None:
console.print(
f"Active run [code]{conf.name}[/] already exists."
" Detected configuration changes that can be updated in-place:"
f" {changed_fields}"
f" Detected changes that [code]can[/] be updated in-place:\n{diff}"
)
confirm_message = "Update the run?"
elif run_plan.action == ApplyAction.UPDATE and len(changed_fields) == 0:
elif run_plan.action == ApplyAction.UPDATE and diff is None:
stop_run_name = run_plan.current_resource.run_spec.run_name
console.print(
f"Active run [code]{conf.name}[/] already exists."
" Detected no configuration changes."
f"Active run [code]{conf.name}[/] already exists. Detected no changes."
)
if command_args.yes and not command_args.force:
console.print("Use --force to apply anyway.")
Expand All @@ -129,7 +125,8 @@ def apply_configuration(
elif not run_plan.current_resource.status.is_finished():
stop_run_name = run_plan.current_resource.run_spec.run_name
console.print(
f"Active run [code]{conf.name}[/] already exists and cannot be updated in-place."
f"Active run [code]{conf.name}[/] already exists."
f" Detected changes that [error]cannot[/] be updated in-place:\n{diff}"
)
confirm_message = "Stop and override the run?"

Expand Down Expand Up @@ -611,3 +608,47 @@ def _run_resubmitted(run: Run, current_job_submission: Optional[JobSubmission])
not run.status.is_finished()
and run._run.latest_job_submission.submitted_at > current_job_submission.submitted_at
)


def render_run_spec_diff(old_spec: RunSpec, new_spec: RunSpec) -> Optional[str]:
    """
    Renders the top-level fields that differ between two run specs as a nested list.

    Returns `None` when `diff_models` reports no differing fields. `merged_profile`
    is never listed. For `configuration` and `profile`, the differing nested
    properties are listed as children, unless the two values are of different types,
    in which case a single item without children is shown.
    """
    differing_fields = list(diff_models(old_spec, new_spec))
    if len(differing_fields) == 0:
        return None

    display_names = {
        "repo_id": "Repo ID",
        "repo_code_hash": "Repo files",
        "repo_data": "Repo state (branch, commit, or other)",
        "ssh_key_pub": "Public SSH key",
    }

    def _nested_item(old_value, new_value, type_changed_label, properties_label):
        # Values of different types get only a label, no per-property children.
        if type(old_value) is not type(new_value):
            return NestedListItem(type_changed_label)
        property_items = [NestedListItem(name) for name in diff_models(old_value, new_value)]
        return NestedListItem(properties_label, children=property_items)

    rendered = NestedList()
    for name in differing_fields:
        if name == "merged_profile":
            continue
        if name == "configuration":
            entry = _nested_item(
                old_spec.configuration,
                new_spec.configuration,
                "Configuration type",
                "Configuration properties:",
            )
        elif name == "profile":
            entry = _nested_item(
                old_spec.profile,
                new_spec.profile,
                "Profile",
                "Profile properties:",
            )
        else:
            # Fields without a friendly name fall back to a humanized field name.
            humanized = name.replace("_", " ").capitalize()
            entry = NestedListItem(display_names.get(name, humanized))
        rendered.children.append(entry)
    return rendered.render()
25 changes: 23 additions & 2 deletions src/dstack/_internal/core/compatibility/runs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Dict, Optional

from dstack._internal.core.models.configurations import ServiceConfiguration
from dstack._internal.core.models.runs import ApplyRunPlanInput, JobSubmission, RunSpec
from dstack._internal.core.models.runs import ApplyRunPlanInput, JobSpec, JobSubmission, RunSpec
from dstack._internal.server.schemas.runs import GetRunPlanRequest


Expand All @@ -25,7 +25,10 @@ def get_apply_plan_excludes(plan: ApplyRunPlanInput) -> Optional[Dict]:
current_resource_excludes["run_spec"] = get_run_spec_excludes(current_resource.run_spec)
job_submissions_excludes = {}
current_resource_excludes["jobs"] = {
"__all__": {"job_submissions": {"__all__": job_submissions_excludes}}
"__all__": {
"job_spec": get_job_spec_excludes([job.job_spec for job in current_resource.jobs]),
"job_submissions": {"__all__": job_submissions_excludes},
}
}
job_submissions = [js for j in current_resource.jobs for js in j.job_submissions]
if all(map(_should_exclude_job_submission_jpd_cpu_arch, job_submissions)):
Expand Down Expand Up @@ -123,6 +126,24 @@ def get_run_spec_excludes(run_spec: RunSpec) -> Optional[Dict]:
return None


def get_job_spec_excludes(job_specs: list[JobSpec]) -> Optional[dict]:
    """
    Builds the `job_spec` exclude mapping used to drop certain fields from the request.

    A field is excluded only when it is `None` in every given job spec, so that
    new, unset fields are not sent and clients stay backward-compatible with
    older servers. Returns `None` when there is nothing to exclude.
    """
    optional_fields = ("repo_code_hash", "repo_data")
    spec_excludes: dict[str, Any] = {
        field_name: True
        for field_name in optional_fields
        if all(getattr(spec, field_name) is None for spec in job_specs)
    }
    return spec_excludes or None


def _should_exclude_job_submission_jpd_cpu_arch(job_submission: JobSubmission) -> bool:
try:
return job_submission.job_provisioning_data.instance_type.resources.cpu_arch is None
Expand Down
8 changes: 8 additions & 0 deletions src/dstack/_internal/core/models/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,14 @@ class JobSpec(CoreModel):
volumes: Optional[List[MountPoint]] = None
ssh_key: Optional[JobSSHKey] = None
working_dir: Optional[str]
# `repo_data` is optional for client compatibility with pre-0.19.17 servers and for compatibility
# with jobs submitted before 0.19.17. All new jobs are expected to have non-None `repo_data`.
# For --no-repo runs, `repo_data` is `VirtualRunRepoData()`.
repo_data: Annotated[Optional[AnyRunRepoData], Field(discriminator="repo_type")] = None
# `repo_code_hash` can be None because it is not used for the repo or because the job was
# submitted before 0.19.17. See `_get_repo_code_hash` on how to get the correct `repo_code_hash`
# TODO: drop this comment when supporting jobs submitted before 0.19.17 is no longer relevant.
repo_code_hash: Optional[str] = None


class JobProvisioningData(CoreModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
session=session,
project=project,
repo=repo_model,
code_hash=run.run_spec.repo_code_hash,
code_hash=_get_repo_code_hash(run, job),
)

success = await common_utils.run_async(
Expand Down Expand Up @@ -293,7 +293,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
session=session,
project=project,
repo=repo_model,
code_hash=run.run_spec.repo_code_hash,
code_hash=_get_repo_code_hash(run, job),
)
success = await common_utils.run_async(
_process_pulling_with_shim,
Expand Down Expand Up @@ -849,6 +849,19 @@ def _get_cluster_info(
return cluster_info


def _get_repo_code_hash(run: Run, job: Job) -> Optional[str]:
# TODO: drop this function when supporting jobs submitted before 0.19.17 is no longer relevant.
if (
job.job_spec.repo_code_hash is None
and run.run_spec.repo_code_hash is not None
and job.job_submissions[-1].deployment_num == run.deployment_num
):
# The job spec does not have `repo_code_hash`, because it was submitted before 0.19.17.
# Use `repo_code_hash` from the run.
return run.run_spec.repo_code_hash
return job.job_spec.repo_code_hash


async def _get_job_code(
session: AsyncSession, project: ProjectModel, repo: RepoModel, code_hash: Optional[str]
) -> bytes:
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/server/schemas/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class SubmitBody(CoreModel):
"max_duration",
"ssh_key",
"working_dir",
"repo_data",
}
),
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ async def _get_job_spec(
working_dir=self._working_dir(),
volumes=self._volumes(job_num),
ssh_key=self._ssh_key(jobs_per_replica),
repo_data=self.run_spec.repo_data,
repo_code_hash=self.run_spec.repo_code_hash,
)
return job_spec

Expand Down
16 changes: 13 additions & 3 deletions src/dstack/_internal/server/services/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -898,7 +898,14 @@ def _validate_run_spec_and_set_defaults(run_spec: RunSpec):
set_resources_defaults(run_spec.configuration.resources)


_UPDATABLE_SPEC_FIELDS = ["repo_code_hash", "configuration"]
_UPDATABLE_SPEC_FIELDS = ["configuration"]
_TYPE_SPECIFIC_UPDATABLE_SPEC_FIELDS = {
"service": [
# rolling deployment
"repo_data",
"repo_code_hash",
],
}
_CONF_UPDATABLE_FIELDS = ["priority"]
_TYPE_SPECIFIC_CONF_UPDATABLE_FIELDS = {
"dev-environment": ["inactivity_duration"],
Expand Down Expand Up @@ -935,11 +942,14 @@ def _can_update_run_spec(current_run_spec: RunSpec, new_run_spec: RunSpec) -> bo
def _check_can_update_run_spec(current_run_spec: RunSpec, new_run_spec: RunSpec):
spec_diff = diff_models(current_run_spec, new_run_spec)
changed_spec_fields = list(spec_diff.keys())
updatable_spec_fields = _UPDATABLE_SPEC_FIELDS + _TYPE_SPECIFIC_UPDATABLE_SPEC_FIELDS.get(
new_run_spec.configuration.type, []
)
for key in changed_spec_fields:
if key not in _UPDATABLE_SPEC_FIELDS:
if key not in updatable_spec_fields:
raise ServerClientError(
f"Failed to update fields {changed_spec_fields}."
f" Can only update {_UPDATABLE_SPEC_FIELDS}."
f" Can only update {updatable_spec_fields}."
)
_check_can_update_configuration(current_run_spec.configuration, new_run_spec.configuration)

Expand Down
10 changes: 6 additions & 4 deletions src/dstack/_internal/server/testing/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import uuid
from collections.abc import Callable
from contextlib import contextmanager
from datetime import datetime, timezone
from typing import Dict, List, Literal, Optional, Union
Expand Down Expand Up @@ -252,18 +253,19 @@ async def create_file_archive(
def get_run_spec(
run_name: str,
repo_id: str,
profile: Optional[Profile] = None,
configuration_path: str = "dstack.yaml",
profile: Union[Profile, Callable[[], Profile], None] = lambda: Profile(name="default"),
configuration: Optional[AnyRunConfiguration] = None,
) -> RunSpec:
if profile is None:
profile = Profile(name="default")
if callable(profile):
profile = profile()
return RunSpec(
run_name=run_name,
repo_id=repo_id,
repo_data=LocalRunRepoData(repo_dir="/"),
repo_code_hash=None,
working_dir=".",
configuration_path="dstack.yaml",
configuration_path=configuration_path,
configuration=configuration or DevEnvironmentConfiguration(ide="vscode"),
profile=profile,
ssh_key_pub="user_ssh_key",
Expand Down
Loading
Loading