+
+```shell
+$ dstack apply -f .dstack.yml
+
+Active run my-service already exists. Detected changes that can be updated in-place:
+- Repo state (branch, commit, or other)
+- File archives
+- Configuration properties:
+ - env
+ - files
+
+Update the run? [y/n]: y
+⠋ Launching my-service...
+
+ NAME BACKEND PRICE STATUS SUBMITTED
+ my-service deployment=1 running 11 mins ago
+ replica=0 deployment=0 aws (us-west-2) $0.0026 terminating 11 mins ago
+ replica=1 deployment=1 aws (us-west-2) $0.0026 running 1 min ago
+```
+
+
+
+
+
+#### Secrets
+
+Secrets let you centrally manage sensitive data like API keys and credentials. They’re scoped to a project, managed by project admins, and can be [securely referenced](../../docs/concepts/secrets.md) in run configurations.
+
+
+
+```yaml hl_lines="7"
+type: task
+name: train
+
+image: nvcr.io/nvidia/pytorch:25.05-py3
+registry_auth:
+ username: $oauthtoken
+ password: ${{ secrets.ngc_api_key }}
+
+commands:
+ - git clone https://github.com/pytorch/examples.git pytorch-examples
+ - cd pytorch-examples/distributed/ddp-tutorial-series
+ - pip install -r requirements.txt
+ - |
+ torchrun \
+ --nproc-per-node=$DSTACK_GPUS_PER_NODE \
+ --nnodes=$DSTACK_NODES_NUM \
+ multinode.py 50 10
+
+resources:
+ gpu: H100:1..2
+ shm_size: 24GB
+```
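+
+A secret can also be interpolated in other configuration properties, for example to pass it as an environment variable. Below is a minimal sketch; the `hf_token` secret name is just an illustration:
+
+```yaml
+env:
+  - HF_TOKEN=${{ secrets.hf_token }}
+```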
+
+
+
+#### Files
+
+By default, `dstack` mounts the repo directory (where you ran `dstack init`) to all runs.
+
+If the directory is large or you need files outside of it, use the new [files](../../docs/concepts/dev-environments/#files) property to map specific local paths into the container.
+
+
+
+```yaml
+type: task
+name: trl-sft
+
+files:
+  - .:examples                       # Maps the directory containing `.dstack.yml` to `/workflow/examples`
+  - ~/.ssh/id_rsa:/root/.ssh/id_rsa  # Maps `~/.ssh/id_rsa` to `/root/.ssh/id_rsa`
+
+python: 3.12
+
+env:
+ - HF_TOKEN
+ - HF_HUB_ENABLE_HF_TRANSFER=1
+ - MODEL=Qwen/Qwen2.5-0.5B
+ - DATASET=stanfordnlp/imdb
+
+commands:
+ - uv pip install trl
+ - |
+    trl sft \
+      --model_name_or_path $MODEL --dataset_name $DATASET \
+      --num_processes $DSTACK_GPUS_PER_NODE
+
+resources:
+ gpu: H100:1
+```
+
+
+
+#### Tenstorrent
+
+`dstack` remains committed to supporting a wide range of accelerators, including NVIDIA and AMD GPUs, Google TPUs, and, more recently, [Tenstorrent :material-arrow-top-right-thin:{ .external }](https://tenstorrent.com/){:target="_blank"}. The latest release improves Tenstorrent support by handling hosts with multiple N300 cards and adds Docker-in-Docker support.
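+
+Tenstorrent hosts are typically added via SSH fleets. Below is a minimal sketch of such a fleet configuration; the user, key path, and host address are placeholders:
+
+```yaml
+type: fleet
+name: tt-fleet
+
+ssh_config:
+  user: ubuntu
+  identity_file: ~/.ssh/id_rsa
+  hosts:
+    - 192.168.100.10  # A host with one or more N300 cards
+```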
+
+
+
+Huge thanks to the Tenstorrent community for testing these improvements!
+
+#### Docker in Docker
+
+Using Docker inside `dstack` run configurations is now even simpler. Just set `docker` to `true` to [enable the Docker CLI](../../docs/concepts/tasks.md#docker-in-docker) in your runs, allowing you to build images, run containers, use Docker Compose, and more.
+
+
+
+```yaml
+type: task
+name: docker-nvidia-smi
+
+docker: true
+
+commands:
+ - |
+ docker run --gpus all \
+ nvidia/cuda:12.3.0-base-ubuntu22.04 \
+ nvidia-smi
+
+resources:
+ gpu: H100:1
+```
+
+
+
+#### AWS EFA
+
+AWS Elastic Fabric Adapter (EFA) is a network interface for EC2 that enables low-latency, high-bandwidth communication between nodes, which is crucial for scaling distributed deep learning. With `dstack`, EFA is automatically enabled when using supported instance types in fleets. Check out our [example](../../examples/clusters/efa/index.md).
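+
+As a rough sketch, a cluster fleet on an EFA-capable instance type might look like the following; the instance type and node count are illustrative:
+
+```yaml
+type: fleet
+name: efa-fleet
+
+nodes: 2
+placement: cluster
+
+backends: [aws]
+instance_types:
+  - p5.48xlarge
+```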
+
+#### Default Docker images
+
+If no `image` is specified, `dstack` uses a base Docker image that now comes pre-configured with `uv`, `python`, `pip`, essential CUDA drivers, InfiniBand, and NCCL tests (located at `/opt/nccl-tests/build`).
+
+
+
+```yaml
+type: task
+name: nccl-tests
+
+nodes: 2
+
+startup_order: workers-first
+stop_criteria: master-done
+
+env:
+ - NCCL_DEBUG=INFO
+commands:
+ - |
+ if [ $DSTACK_NODE_RANK -eq 0 ]; then
+ mpirun \
+ --allow-run-as-root \
+ --hostfile $DSTACK_MPI_HOSTFILE \
+ -n $DSTACK_GPUS_NUM \
+ -N $DSTACK_GPUS_PER_NODE \
+ --bind-to none \
+ /opt/nccl-tests/build/all_reduce_perf -b 8 -e 8G -f 2 -g 1
+ else
+ sleep infinity
+ fi
+
+resources:
+ gpu: nvidia:1..8
+ shm_size: 16GB
+```
+
+
+
+These images are optimized for common use cases and kept lightweight—ideal for everyday development, training, and inference.
+
+#### Server performance
+
+Server-side performance has been improved: optimized handling and background processing allow each server replica to process more runs.
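+
+If a single replica needs to process even more runs, background processing can be scaled via the `DSTACK_SERVER_BACKGROUND_PROCESSING_FACTOR` environment variable (the value below is illustrative):
+
+```shell
+$ DSTACK_SERVER_BACKGROUND_PROCESSING_FACTOR=4 dstack server
+```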
+
+#### Google SSO
+
+Alongside the open-source version, `dstack` also offers [dstack Enterprise :material-arrow-top-right-thin:{ .external }](https://github.com/dstackai/dstack-enterprise){:target="_blank"} — which adds dedicated support and extra integrations like Single Sign-On (SSO). The latest release introduces support for configuring your company’s Google account for authentication.
+
+
+
+If you’d like to learn more about `dstack` Enterprise, [let us know](https://calendly.com/dstackai/discovery-call).
+
+That’s all for now.
+
+!!! info "What's next?"
+ Give dstack a try, and share your feedback—whether it’s [GitHub :material-arrow-top-right-thin:{ .external }](https://github.com/dstackai/dstack){:target="_blank"} issues, PRs, or questions on [Discord :material-arrow-top-right-thin:{ .external }](https://discord.gg/u8SmfwPpMd){:target="_blank"}. We’re eager to hear from you!
diff --git a/docs/blog/posts/cursor.md b/docs/blog/posts/cursor.md
index bdc1e4a61..a5f960469 100644
--- a/docs/blog/posts/cursor.md
+++ b/docs/blog/posts/cursor.md
@@ -5,7 +5,7 @@ description: "TBA"
slug: cursor
image: https://dstack.ai/static-assets/static-assets/images/dstack-cursor-v2.png
categories:
- - Releases
+ - Changelog
---
# Accessing dev environments with Cursor
diff --git a/docs/blog/posts/dstack-metrics.md b/docs/blog/posts/dstack-metrics.md
index f06ff3151..f4647d782 100644
--- a/docs/blog/posts/dstack-metrics.md
+++ b/docs/blog/posts/dstack-metrics.md
@@ -5,7 +5,7 @@ description: "dstack introduces a new CLI command (and API) for monitoring conta
slug: dstack-metrics
image: https://dstack.ai/static-assets/static-assets/images/dstack-stats-v2.png
categories:
- - Releases
+ - Changelog
---
# Monitoring essential GPU metrics via CLI
diff --git a/docs/blog/posts/dstack-sky-own-cloud-accounts.md b/docs/blog/posts/dstack-sky-own-cloud-accounts.md
index ff0b8d182..16b68867c 100644
--- a/docs/blog/posts/dstack-sky-own-cloud-accounts.md
+++ b/docs/blog/posts/dstack-sky-own-cloud-accounts.md
@@ -4,7 +4,7 @@ date: 2024-06-11
description: "With today's release, dstack Sky supports both options: accessing the GPU marketplace and using your own cloud accounts."
slug: dstack-sky-own-cloud-accounts
categories:
- - Releases
+ - Changelog
---
# dstack Sky now supports your own cloud accounts
diff --git a/docs/blog/posts/dstack-sky.md b/docs/blog/posts/dstack-sky.md
index 7cfe80097..78d35641c 100644
--- a/docs/blog/posts/dstack-sky.md
+++ b/docs/blog/posts/dstack-sky.md
@@ -3,7 +3,7 @@ date: 2024-03-11
description: A managed service that enables you to get GPUs at competitive rates from a wide pool of providers.
slug: dstack-sky
categories:
- - Releases
+ - Changelog
---
# Introducing dstack Sky
diff --git a/docs/blog/posts/gh200-on-lambda.md b/docs/blog/posts/gh200-on-lambda.md
index 970c87b12..1741e6f2e 100644
--- a/docs/blog/posts/gh200-on-lambda.md
+++ b/docs/blog/posts/gh200-on-lambda.md
@@ -5,7 +5,7 @@ description: "TBA"
slug: gh200-on-lambda
image: https://dstack.ai/static-assets/static-assets/images/dstack-arm--gh200-lambda-min.png
categories:
- - Releases
+ - Changelog
---
# Supporting ARM and NVIDIA GH200 on Lambda
diff --git a/docs/blog/posts/gpu-blocks-and-proxy-jump.md b/docs/blog/posts/gpu-blocks-and-proxy-jump.md
index dc8bea1df..cbf9ab7dc 100644
--- a/docs/blog/posts/gpu-blocks-and-proxy-jump.md
+++ b/docs/blog/posts/gpu-blocks-and-proxy-jump.md
@@ -5,7 +5,7 @@ description: "TBA"
slug: gpu-blocks-and-proxy-jump
image: https://dstack.ai/static-assets/static-assets/images/data-centers-and-private-clouds.png
categories:
- - Releases
+ - Changelog
---
# Introducing GPU blocks and proxy jump for SSH fleets
diff --git a/docs/blog/posts/inactivity-duration.md b/docs/blog/posts/inactivity-duration.md
index 28ea3f5d7..d04a8eba4 100644
--- a/docs/blog/posts/inactivity-duration.md
+++ b/docs/blog/posts/inactivity-duration.md
@@ -5,7 +5,7 @@ description: "dstack introduces a new feature that automatically detects and shu
slug: inactivity-duration
image: https://dstack.ai/static-assets/static-assets/images/inactive-dev-environments-auto-shutdown.png
categories:
- - Releases
+ - Changelog
---
# Auto-shutdown for inactive dev environments—no idle GPUs
diff --git a/docs/blog/posts/instance-volumes.md b/docs/blog/posts/instance-volumes.md
index c4f5e3b1b..95b48cee1 100644
--- a/docs/blog/posts/instance-volumes.md
+++ b/docs/blog/posts/instance-volumes.md
@@ -5,7 +5,7 @@ description: "To simplify caching across runs and the use of NFS, we introduce a
image: https://dstack.ai/static-assets/static-assets/images/dstack-instance-volumes.png
slug: instance-volumes
categories:
- - Releases
+ - Changelog
---
# Introducing instance volumes to persist data on instances
diff --git a/docs/blog/posts/intel-gaudi.md b/docs/blog/posts/intel-gaudi.md
index 7abc69f41..6f95f49d0 100644
--- a/docs/blog/posts/intel-gaudi.md
+++ b/docs/blog/posts/intel-gaudi.md
@@ -5,7 +5,7 @@ description: "dstack now supports Intel Gaudi accelerators with SSH fleets, simp
slug: intel-gaudi
image: https://dstack.ai/static-assets/static-assets/images/dstack-intel-gaudi-and-intel-tiber-cloud.png-v2
categories:
- - Releases
+ - Changelog
---
# Supporting Intel Gaudi AI accelerators with SSH fleets
diff --git a/docs/blog/posts/metrics-ui.md b/docs/blog/posts/metrics-ui.md
index 74719af2d..b15bbffc5 100644
--- a/docs/blog/posts/metrics-ui.md
+++ b/docs/blog/posts/metrics-ui.md
@@ -5,7 +5,7 @@ description: "TBA"
slug: metrics-ui
image: https://dstack.ai/static-assets/static-assets/images/dstack-metrics-ui-v3-min.png
categories:
- - Releases
+ - Changelog
---
# Built-in UI for monitoring essential GPU metrics
diff --git a/docs/blog/posts/mpi.md b/docs/blog/posts/mpi.md
index ef5d68582..5473d64a2 100644
--- a/docs/blog/posts/mpi.md
+++ b/docs/blog/posts/mpi.md
@@ -5,7 +5,7 @@ description: "TBA"
slug: mpi
image: https://dstack.ai/static-assets/static-assets/images/dstack-mpi-v2.png
categories:
- - Releases
+ - Changelog
---
# Supporting MPI and NCCL/RCCL tests
diff --git a/docs/blog/posts/nebius.md b/docs/blog/posts/nebius.md
index c6f4374db..ef484b0f3 100644
--- a/docs/blog/posts/nebius.md
+++ b/docs/blog/posts/nebius.md
@@ -5,7 +5,7 @@ description: "TBA"
slug: nebius
image: https://dstack.ai/static-assets/static-assets/images/dstack-nebius-v2.png
categories:
- - Releases
+ - Changelog
---
# Supporting GPU provisioning and orchestration on Nebius
diff --git a/docs/blog/posts/nvidia-and-amd-on-vultr.md b/docs/blog/posts/nvidia-and-amd-on-vultr.md
index eda4519b4..2fb30ebbc 100644
--- a/docs/blog/posts/nvidia-and-amd-on-vultr.md
+++ b/docs/blog/posts/nvidia-and-amd-on-vultr.md
@@ -5,7 +5,7 @@ description: "Introducing integration with Vultr: The new integration allows Vul
slug: nvidia-and-amd-on-vultr
image: https://dstack.ai/static-assets/static-assets/images/dstack-vultr.png
categories:
- - Releases
+ - Changelog
---
# Supporting NVIDIA and AMD accelerators on Vultr
diff --git a/docs/blog/posts/prometheus.md b/docs/blog/posts/prometheus.md
index 58299dfb4..5482d0c13 100644
--- a/docs/blog/posts/prometheus.md
+++ b/docs/blog/posts/prometheus.md
@@ -5,7 +5,7 @@ description: "TBA"
slug: prometheus
image: https://dstack.ai/static-assets/static-assets/images/dstack-prometheus-v3.png
categories:
- - Releases
+ - Changelog
---
# Exporting GPU, cost, and other metrics to Prometheus
diff --git a/docs/blog/posts/tpu-on-gcp.md b/docs/blog/posts/tpu-on-gcp.md
index c38bea20a..8fff83cb4 100644
--- a/docs/blog/posts/tpu-on-gcp.md
+++ b/docs/blog/posts/tpu-on-gcp.md
@@ -4,7 +4,7 @@ date: 2024-09-10
description: "Learn how to use TPUs with dstack for fine-tuning and deploying LLMs, leveraging open-source tools like Hugging Face’s Optimum TPU and vLLM."
slug: tpu-on-gcp
categories:
- - Releases
+ - Changelog
---
# Using TPUs for fine-tuning and deploying LLMs
diff --git a/docs/blog/posts/volumes-on-runpod.md b/docs/blog/posts/volumes-on-runpod.md
index a6d436790..de0c8d6d0 100644
--- a/docs/blog/posts/volumes-on-runpod.md
+++ b/docs/blog/posts/volumes-on-runpod.md
@@ -4,7 +4,7 @@ date: 2024-08-13
description: "Learn how to use volumes with dstack to optimize model inference cold start times on RunPod."
slug: volumes-on-runpod
categories:
- - Releases
+ - Changelog
---
# Using volumes to optimize cold starts on RunPod
diff --git a/docs/docs/concepts/services.md b/docs/docs/concepts/services.md
index 5dd92f19c..a93cacf0d 100644
--- a/docs/docs/concepts/services.md
+++ b/docs/docs/concepts/services.md
@@ -679,6 +679,61 @@ utilization_policy:
[`max_price`](../reference/dstack.yml/service.md#max_price), and
among [others](../reference/dstack.yml/service.md).
+## Rolling deployment
+
+To deploy a new version of a service that is already `running`, use `dstack apply`. `dstack` will automatically detect changes and suggest a rolling deployment update.
+
+
+
+```shell
+$ dstack apply -f my-service.dstack.yml
+
+Active run my-service already exists. Detected changes that can be updated in-place:
+- Repo state (branch, commit, or other)
+- File archives
+- Configuration properties:
+ - env
+ - files
+
+Update the run? [y/n]:
+```
+
+
+
+If approved, `dstack` gradually updates the service replicas. To update a replica, `dstack` starts a new replica, waits for it to become `running`, then terminates the old replica. This process is repeated for each replica, one at a time.
+
+You can track the progress of a rolling deployment in both `dstack apply` and `dstack ps`.
+Older replicas have lower `deployment` numbers; newer ones have higher.
+
+
+
+```shell
+$ dstack apply -f my-service.dstack.yml
+
+⠋ Launching my-service...
+ NAME BACKEND PRICE STATUS SUBMITTED
+ my-service deployment=1 running 11 mins ago
+ replica=0 job=0 deployment=0 aws (us-west-2) $0.0026 terminating 11 mins ago
+ replica=1 job=0 deployment=1 aws (us-west-2) $0.0026 running 1 min ago
+```
+
+The rolling deployment stops when all replicas are updated or when a new deployment is submitted.
+
+??? info "Supported properties"
+
+
+ Rolling deployment supports changes to the following properties: `port`, `resources`, `volumes`, `docker`, `files`, `image`, `user`, `privileged`, `entrypoint`, `working_dir`, `python`, `nvcc`, `single_branch`, `env`, `shell`, `commands`, as well as changes to [repo](repos.md) or [file](#files) contents.
+
+ Changes to `replicas` and `scaling` can be applied without redeploying replicas.
+
+ Changes to other properties require a full service restart.
+
+ To trigger a rolling deployment when no properties have changed (e.g., after updating [secrets](secrets.md) or to restart all replicas),
+ make a minor config change, such as adding a dummy [environment variable](#environment-variables).
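+
+    For example, adding a throwaway variable like the one below is enough to trigger a redeployment; the variable name is arbitrary:
+
+    ```yaml
+    env:
+      - REDEPLOY=1
+    ```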
+
--8<-- "docs/concepts/snippets/manage-runs.ext"
!!! info "What's next?"
diff --git a/docs/docs/guides/server-deployment.md b/docs/docs/guides/server-deployment.md
index 8ff4034e6..ee47f481a 100644
--- a/docs/docs/guides/server-deployment.md
+++ b/docs/docs/guides/server-deployment.md
@@ -124,7 +124,11 @@ Postgres has no such limitation and is recommended for production deployment.
### PostgreSQL
-To store the server state in Postgres, set the `DSTACK_DATABASE_URL` environment variable.
+To store the server state in Postgres, set the `DSTACK_DATABASE_URL` environment variable:
+
+```shell
+$ DSTACK_DATABASE_URL=postgresql+asyncpg://user:password@db-host:5432/dstack dstack server
+```
??? info "Migrate from SQLite to PostgreSQL"
You can migrate the existing state from SQLite to PostgreSQL using `pgloader`:
diff --git a/docs/docs/reference/environment-variables.md b/docs/docs/reference/environment-variables.md
index 4c5d44bd5..3c28ba333 100644
--- a/docs/docs/reference/environment-variables.md
+++ b/docs/docs/reference/environment-variables.md
@@ -123,6 +123,7 @@ For more details on the options below, refer to the [server deployment](../guide
- `DSTACK_DB_POOL_SIZE`{ #DSTACK_DB_POOL_SIZE } - The client DB connections pool size. Defaults to `20`,
- `DSTACK_DB_MAX_OVERFLOW`{ #DSTACK_DB_MAX_OVERFLOW } - The client DB connections pool allowed overflow. Defaults to `20`.
- `DSTACK_SERVER_BACKGROUND_PROCESSING_FACTOR`{ #DSTACK_SERVER_BACKGROUND_PROCESSING_FACTOR } - The number of background jobs for processing server resources. Increase if you need to process more resources per server replica quickly. Defaults to `1`.
+- `DSTACK_SERVER_BACKGROUND_PROCESSING_DISABLED`{ #DSTACK_SERVER_BACKGROUND_PROCESSING_DISABLED } - Disables background processing if set to any value. Useful for running only the web frontend and API server.
??? info "Internal environment variables"
The following environment variables are intended for development purposes:
diff --git a/docs/overrides/header-2.html b/docs/overrides/header-2.html
index 2c7e67679..4f8542d38 100644
--- a/docs/overrides/header-2.html
+++ b/docs/overrides/header-2.html
@@ -62,7 +62,7 @@
{% if "navigation.tabs.sticky" in features %}
diff --git a/docs/overrides/home.html b/docs/overrides/home.html
index 8c306dd36..b7c3ee994 100644
--- a/docs/overrides/home.html
+++ b/docs/overrides/home.html
@@ -509,20 +509,20 @@ FAQ
-
+
@@ -539,8 +539,8 @@ dstack Enterprise
diff --git a/docs/overrides/main.html b/docs/overrides/main.html
index 4a74fb3a8..8db019610 100644
--- a/docs/overrides/main.html
+++ b/docs/overrides/main.html
@@ -102,38 +102,42 @@
+
+
+
+
+
+
+
+
+
+
+
+
{% endblock %}
diff --git a/frontend/src/pages/Runs/Details/Logs/index.tsx b/frontend/src/pages/Runs/Details/Logs/index.tsx
index aaffbf41c..6fc550103 100644
--- a/frontend/src/pages/Runs/Details/Logs/index.tsx
+++ b/frontend/src/pages/Runs/Details/Logs/index.tsx
@@ -31,7 +31,7 @@ export const Logs: React.FC = ({ className, projectName, runName, jobSub
const writeDataToTerminal = (logs: ILogItem[]) => {
logs.forEach((logItem) => {
- terminalInstance.current.write(logItem.message);
+ terminalInstance.current.write(logItem.message.replace(/(? {
const { data, isLoading, refreshList, isLoadingMore } = useInfiniteScroll({
useLazyQuery: useLazyGetRunsQuery,
- args: { ...filteringRequestParams, limit: DEFAULT_TABLE_PAGE_SIZE },
+ args: { ...filteringRequestParams, limit: DEFAULT_TABLE_PAGE_SIZE, job_submissions_limit: 1 },
getPaginationParams: (lastRun) => ({ prev_submitted_at: lastRun.submitted_at }),
});
diff --git a/frontend/src/services/project.ts b/frontend/src/services/project.ts
index 9e05e444a..c7559784e 100644
--- a/frontend/src/services/project.ts
+++ b/frontend/src/services/project.ts
@@ -4,6 +4,8 @@ import { createApi, fetchBaseQuery } from '@reduxjs/toolkit/query/react';
import { base64ToArrayBuffer } from 'libs';
import fetchBaseQueryHeaders from 'libs/fetchBaseQueryHeaders';
+const decoder = new TextDecoder('utf-8');
+
// Helper function to transform backend response to frontend format
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const transformProjectResponse = (project: any): IProject => ({
@@ -131,7 +133,7 @@ export const projectApi = createApi({
transformResponse: (response: { logs: ILogItem[]; next_token: string }) => {
const logs = response.logs.map((logItem) => ({
...logItem,
- message: base64ToArrayBuffer(logItem.message as string),
+ message: decoder.decode(base64ToArrayBuffer(logItem.message)),
}));
return {
diff --git a/frontend/src/types/log.d.ts b/frontend/src/types/log.d.ts
index eec182c1e..99e9532c8 100644
--- a/frontend/src/types/log.d.ts
+++ b/frontend/src/types/log.d.ts
@@ -1,7 +1,7 @@
declare interface ILogItem {
log_source: 'stdout' | 'stderr';
timestamp: string;
- message: string | Uint8Array;
+ message: string;
}
declare type TRequestLogsParams = {
diff --git a/frontend/src/types/run.d.ts b/frontend/src/types/run.d.ts
index eae9ebacc..2e613defb 100644
--- a/frontend/src/types/run.d.ts
+++ b/frontend/src/types/run.d.ts
@@ -7,6 +7,7 @@ declare type TRunsRequestParams = {
prev_run_id?: string;
limit?: number;
ascending?: boolean;
+ job_submissions_limit?: number;
};
declare type TDeleteRunsRequestParams = {
diff --git a/mkdocs.yml b/mkdocs.yml
index e6fac58ed..7d93d9623 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -172,7 +172,7 @@ markdown_extensions:
- pymdownx.tasklist:
custom_checkbox: true
- toc:
- toc_depth: 5
+ toc_depth: 3
permalink: true
- attr_list
- md_in_html
@@ -292,8 +292,9 @@ nav:
- TPU: examples/accelerators/tpu/index.md
- Intel Gaudi: examples/accelerators/intel/index.md
- Tenstorrent: examples/accelerators/tenstorrent/index.md
- - Benchmarks: blog/benchmarks.md
+ - Changelog: blog/changelog.md
- Case studies: blog/case-studies.md
+ - Benchmarks: blog/benchmarks.md
- Blog:
- blog/index.md
# - Discord: https://discord.gg/u8SmfwPpMd" target="_blank
diff --git a/pyproject.toml b/pyproject.toml
index 47353886b..736ba6768 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,6 +35,7 @@ dependencies = [
"gpuhunt==0.1.6",
"argcomplete>=3.5.0",
"ignore-python>=0.2.0",
+ "orjson",
]
[project.urls]
diff --git a/runner/go.mod b/runner/go.mod
index 22dad6466..850ea8253 100644
--- a/runner/go.mod
+++ b/runner/go.mod
@@ -1,6 +1,6 @@
module github.com/dstackai/dstack/runner
-go 1.23
+go 1.23.8
require (
github.com/alexellis/go-execute/v2 v2.2.1
@@ -10,6 +10,7 @@ require (
github.com/docker/docker v26.0.0+incompatible
github.com/docker/go-connections v0.5.0
github.com/docker/go-units v0.5.0
+ github.com/dstackai/ansistrip v0.0.6
github.com/go-git/go-git/v5 v5.12.0
github.com/golang/gddo v0.0.0-20210115222349-20d68f94ee1f
github.com/gorilla/websocket v1.5.1
@@ -62,6 +63,7 @@ require (
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect
github.com/skeema/knownhosts v1.2.2 // indirect
+ github.com/tidwall/btree v1.7.0 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/ulikunitz/xz v0.5.12 // indirect
@@ -77,10 +79,11 @@ require (
golang.org/x/mod v0.17.0 // indirect
golang.org/x/net v0.24.0 // indirect
golang.org/x/sync v0.7.0 // indirect
+ golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.20.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240401170217-c3f982113cda // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240401170217-c3f982113cda // indirect
gopkg.in/warnings.v0 v0.1.2 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
- gotest.tools/v3 v3.5.0 // indirect
+ gotest.tools/v3 v3.5.1 // indirect
)
diff --git a/runner/go.sum b/runner/go.sum
index 41e133c46..1222fcac8 100644
--- a/runner/go.sum
+++ b/runner/go.sum
@@ -47,6 +47,8 @@ github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
+github.com/dstackai/ansistrip v0.0.6 h1:6qqeDNWt8NoqfkY1CxKUvdHpJzBl89LOE3wMwptVpaI=
+github.com/dstackai/ansistrip v0.0.6/go.mod h1:w3ejXI0twxDv6bPXhkOaPeYdbwz2nwcrcvFoZGqi9F0=
github.com/ebitengine/purego v0.8.1 h1:sdRKd6plj7KYW33EH5As6YKfe8m9zbN9JMrOjNVF/BE=
github.com/ebitengine/purego v0.8.1/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
github.com/elazarl/goproxy v0.0.0-20230808193330-2592e75ae04a h1:mATvB/9r/3gvcejNsXKSkQ6lcIaNec2nyfOdlTBR2lU=
@@ -171,6 +173,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+github.com/tidwall/btree v1.7.0 h1:L1fkJH/AuEh5zBnnBbmTwQ5Lt+bRJ5A8EWecslvo9iI=
+github.com/tidwall/btree v1.7.0/go.mod h1:twD9XRA5jj9VUQGELzDO4HPQTNJsoWWfYEL+EUQ2cKY=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
@@ -281,8 +285,9 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
-golang.org/x/time v0.0.0-20170424234030-8be79e1e0910 h1:bCMaBn7ph495H+x72gEvgcv+mDRd9dElbzo/mVCMxX4=
golang.org/x/time v0.0.0-20170424234030-8be79e1e0910/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
+golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
@@ -318,5 +323,5 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
-gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY=
-gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
+gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU=
+gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
diff --git a/runner/internal/executor/base.go b/runner/internal/executor/base.go
index 2163ca920..554bd7646 100644
--- a/runner/internal/executor/base.go
+++ b/runner/internal/executor/base.go
@@ -10,7 +10,7 @@ import (
type Executor interface {
GetHistory(timestamp int64) *schemas.PullResponse
- GetJobLogsHistory() []schemas.LogEvent
+ GetJobWsLogsHistory() []schemas.LogEvent
GetRunnerState() string
Run(ctx context.Context) error
SetCodePath(codePath string)
diff --git a/runner/internal/executor/executor.go b/runner/internal/executor/executor.go
index e14ce540d..d2ced5dc3 100644
--- a/runner/internal/executor/executor.go
+++ b/runner/internal/executor/executor.go
@@ -11,6 +11,7 @@ import (
"os/exec"
osuser "os/user"
"path/filepath"
+ "runtime"
"strconv"
"strings"
"sync"
@@ -18,6 +19,7 @@ import (
"time"
"github.com/creack/pty"
+ "github.com/dstackai/ansistrip"
"github.com/dstackai/dstack/runner/consts"
"github.com/dstackai/dstack/runner/internal/connections"
"github.com/dstackai/dstack/runner/internal/gerrors"
@@ -27,6 +29,24 @@ import (
"github.com/prometheus/procfs"
)
+// TODO: Tune these parameters for optimal experience/performance
+const (
+ // Output is flushed when the cursor doesn't move for this duration
+ AnsiStripFlushInterval = 500 * time.Millisecond
+
+ // Output is flushed regardless of cursor activity after this maximum delay
+ AnsiStripMaxDelay = 3 * time.Second
+
+ // Maximum buffer size for ansistrip
+ MaxBufferSize = 32 * 1024 // 32KB
+)
+
+type ConnectionTracker interface {
+ GetNoConnectionsSecs() int64
+ Track(ticker <-chan time.Time)
+ Stop()
+}
+
type RunExecutor struct {
tempDir string
homeDir string
@@ -47,13 +67,21 @@ type RunExecutor struct {
state string
jobStateHistory []schemas.JobStateEvent
jobLogs *appendWriter
+ jobWsLogs *appendWriter
runnerLogs *appendWriter
timestamp *MonotonicTimestamp
killDelay time.Duration
- connectionTracker *connections.ConnectionTracker
+ connectionTracker ConnectionTracker
}
+// stubConnectionTracker is a no-op implementation for when procfs is not available (only required for tests on darwin)
+type stubConnectionTracker struct{}
+
+func (s *stubConnectionTracker) GetNoConnectionsSecs() int64 { return 0 }
+func (s *stubConnectionTracker) Track(ticker <-chan time.Time) {}
+func (s *stubConnectionTracker) Stop() {}
+
func NewRunExecutor(tempDir string, homeDir string, workingDir string, sshPort int) (*RunExecutor, error) {
mu := &sync.RWMutex{}
timestamp := NewMonotonicTimestamp()
@@ -65,15 +93,25 @@ func NewRunExecutor(tempDir string, homeDir string, workingDir string, sshPort i
if err != nil {
return nil, fmt.Errorf("failed to parse current user uid: %w", err)
}
- proc, err := procfs.NewDefaultFS()
- if err != nil {
- return nil, fmt.Errorf("failed to initialize procfs: %w", err)
+
+ // Try to initialize procfs, but don't fail if it's not available (e.g., on macOS)
+ var connectionTracker ConnectionTracker
+
+ if runtime.GOOS == "linux" {
+ proc, err := procfs.NewDefaultFS()
+ if err != nil {
+ return nil, fmt.Errorf("failed to initialize procfs: %w", err)
+ }
+ connectionTracker = connections.NewConnectionTracker(connections.ConnectionTrackerConfig{
+ Port: uint64(sshPort),
+ MinConnDuration: 10 * time.Second, // shorter connections are likely from dstack-server
+ Procfs: proc,
+ })
+ } else {
+ // Use stub connection tracker (only required for tests on darwin)
+ connectionTracker = &stubConnectionTracker{}
}
- connectionTracker := connections.NewConnectionTracker(connections.ConnectionTrackerConfig{
- Port: uint64(sshPort),
- MinConnDuration: 10 * time.Second, // shorter connections are likely from dstack-server
- Procfs: proc,
- })
+
return &RunExecutor{
tempDir: tempDir,
homeDir: homeDir,
@@ -86,6 +124,7 @@ func NewRunExecutor(tempDir string, homeDir string, workingDir string, sshPort i
state: WaitSubmit,
jobStateHistory: make([]schemas.JobStateEvent, 0),
jobLogs: newAppendWriter(mu, timestamp),
+ jobWsLogs: newAppendWriter(mu, timestamp),
runnerLogs: newAppendWriter(mu, timestamp),
timestamp: timestamp,
@@ -129,7 +168,9 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) {
}
}()
- logger := io.MultiWriter(runnerLogFile, os.Stdout, ex.runnerLogs)
+ stripper := ansistrip.NewWriter(ex.runnerLogs, AnsiStripFlushInterval, AnsiStripMaxDelay, MaxBufferSize)
+ defer stripper.Close()
+ logger := io.MultiWriter(runnerLogFile, os.Stdout, stripper)
ctx = log.WithLogger(ctx, log.NewEntry(logger, int(log.DefaultEntry.Logger.Level))) // todo loglevel
log.Info(ctx, "Run job", "log_level", log.GetLogger(ctx).Logger.Level.String())
@@ -431,7 +472,9 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
defer func() { _ = ptm.Close() }()
defer func() { _ = cmd.Wait() }() // release resources if copy fails
- logger := io.MultiWriter(jobLogFile, ex.jobLogs)
+ stripper := ansistrip.NewWriter(ex.jobLogs, AnsiStripFlushInterval, AnsiStripMaxDelay, MaxBufferSize)
+ defer stripper.Close()
+ logger := io.MultiWriter(jobLogFile, ex.jobWsLogs, stripper)
_, err = io.Copy(logger, ptm)
if err != nil && !isPtyError(err) {
return gerrors.Wrap(err)
diff --git a/runner/internal/executor/executor_test.go b/runner/internal/executor/executor_test.go
index 8d275b137..e13184513 100644
--- a/runner/internal/executor/executor_test.go
+++ b/runner/internal/executor/executor_test.go
@@ -9,6 +9,7 @@ import (
"os"
"os/exec"
"path/filepath"
+ "strings"
"testing"
"time"
@@ -17,8 +18,6 @@ import (
"github.com/stretchr/testify/require"
)
-// todo test get history
-
func TestExecutor_WorkingDir_Current(t *testing.T) {
var b bytes.Buffer
ex := makeTestExecutor(t)
@@ -28,7 +27,8 @@ func TestExecutor_WorkingDir_Current(t *testing.T) {
err := ex.execJob(context.TODO(), io.Writer(&b))
assert.NoError(t, err)
- assert.Equal(t, ex.workingDir+"\r\n", b.String())
+ // Normalize line endings for cross-platform compatibility.
+ assert.Equal(t, ex.workingDir+"\n", strings.ReplaceAll(b.String(), "\r\n", "\n"))
}
func TestExecutor_WorkingDir_Nil(t *testing.T) {
@@ -39,7 +39,7 @@ func TestExecutor_WorkingDir_Nil(t *testing.T) {
err := ex.execJob(context.TODO(), io.Writer(&b))
assert.NoError(t, err)
- assert.Equal(t, ex.workingDir+"\r\n", b.String())
+ assert.Equal(t, ex.workingDir+"\n", strings.ReplaceAll(b.String(), "\r\n", "\n"))
}
func TestExecutor_HomeDir(t *testing.T) {
@@ -49,7 +49,7 @@ func TestExecutor_HomeDir(t *testing.T) {
err := ex.execJob(context.TODO(), io.Writer(&b))
assert.NoError(t, err)
- assert.Equal(t, ex.homeDir+"\r\n", b.String())
+ assert.Equal(t, ex.homeDir+"\n", strings.ReplaceAll(b.String(), "\r\n", "\n"))
}
func TestExecutor_NonZeroExit(t *testing.T) {
@@ -61,7 +61,7 @@ func TestExecutor_NonZeroExit(t *testing.T) {
assert.Error(t, err)
assert.NotEmpty(t, ex.jobStateHistory)
exitStatus := ex.jobStateHistory[len(ex.jobStateHistory)-1].ExitStatus
- assert.NotNil(t, exitStatus, ex.jobStateHistory)
+ assert.NotNil(t, exitStatus)
assert.Equal(t, 100, *exitStatus)
}
@@ -96,7 +96,7 @@ func TestExecutor_LocalRepo(t *testing.T) {
err = ex.execJob(context.TODO(), io.Writer(&b))
assert.NoError(t, err)
- assert.Equal(t, "bar\r\n", b.String())
+ assert.Equal(t, "bar\n", strings.ReplaceAll(b.String(), "\r\n", "\n"))
}
func TestExecutor_Recover(t *testing.T) {
@@ -148,8 +148,8 @@ func TestExecutor_RemoteRepo(t *testing.T) {
err = ex.execJob(context.TODO(), io.Writer(&b))
assert.NoError(t, err)
- expected := fmt.Sprintf("%s\r\n%s\r\n%s\r\n", ex.getRepoData().RepoHash, ex.getRepoData().RepoConfigName, ex.getRepoData().RepoConfigEmail)
- assert.Equal(t, expected, b.String())
+ expected := fmt.Sprintf("%s\n%s\n%s\n", ex.getRepoData().RepoHash, ex.getRepoData().RepoConfigName, ex.getRepoData().RepoConfigEmail)
+ assert.Equal(t, expected, strings.ReplaceAll(b.String(), "\r\n", "\n"))
}
/* Helpers */
@@ -236,3 +236,98 @@ func TestWriteDstackProfile(t *testing.T) {
assert.Equal(t, value, string(out))
}
}
+
+func TestExecutor_Logs(t *testing.T) {
+ var b bytes.Buffer
+ ex := makeTestExecutor(t)
+ // Use printf to generate ANSI control codes.
+ // \033[31m = red text, \033[1;32m = bold green text, \033[0m = reset
+ ex.jobSpec.Commands = append(ex.jobSpec.Commands, "printf '\\033[31mRed Hello World\\033[0m\\n' && printf '\\033[1;32mBold Green Line 2\\033[0m\\n' && printf 'Line 3\\n'")
+
+ err := ex.execJob(context.TODO(), io.Writer(&b))
+ assert.NoError(t, err)
+
+ logHistory := ex.GetHistory(0).JobLogs
+ assert.NotEmpty(t, logHistory)
+
+ logString := combineLogMessages(logHistory)
+ normalizedLogString := strings.ReplaceAll(logString, "\r\n", "\n")
+
+ expectedOutput := "Red Hello World\nBold Green Line 2\nLine 3\n"
+ assert.Equal(t, expectedOutput, normalizedLogString, "Should strip ANSI codes from regular logs")
+
+ // Verify timestamps are in order
+ assert.Greater(t, len(logHistory), 0)
+ for i := 1; i < len(logHistory); i++ {
+ assert.GreaterOrEqual(t, logHistory[i].Timestamp, logHistory[i-1].Timestamp)
+ }
+}
+
+func TestExecutor_LogsWithErrors(t *testing.T) {
+ var b bytes.Buffer
+ ex := makeTestExecutor(t)
+ ex.jobSpec.Commands = append(ex.jobSpec.Commands, "echo 'Success message' && echo 'Error message' >&2 && exit 1")
+
+ err := ex.execJob(context.TODO(), io.Writer(&b))
+ assert.Error(t, err)
+
+ logHistory := ex.GetHistory(0).JobLogs
+ assert.NotEmpty(t, logHistory)
+
+ logString := combineLogMessages(logHistory)
+ normalizedLogString := strings.ReplaceAll(logString, "\r\n", "\n")
+
+ expectedOutput := "Success message\nError message\n"
+ assert.Equal(t, expectedOutput, normalizedLogString)
+}
+
+func TestExecutor_LogsAnsiCodeHandling(t *testing.T) {
+ var b bytes.Buffer
+ ex := makeTestExecutor(t)
+
+ // Test a variety of ANSI escape sequences on stdout and stderr.
+ cmd := "printf '\\033[31mRed\\033[0m \\033[32mGreen\\033[0m\\n' && " +
+ "printf '\\033[1mBold\\033[0m \\033[4mUnderline\\033[0m\\n' && " +
+ "printf '\\033[s\\033[uPlain text\\n' >&2"
+
+ ex.jobSpec.Commands = append(ex.jobSpec.Commands, cmd)
+
+ err := ex.execJob(context.TODO(), io.Writer(&b))
+ assert.NoError(t, err)
+
+ // 1. Check WebSocket logs, which should preserve ANSI codes.
+ wsLogHistory := ex.GetJobWsLogsHistory()
+ assert.NotEmpty(t, wsLogHistory)
+ wsLogString := combineLogMessages(wsLogHistory)
+ normalizedWsLogString := strings.ReplaceAll(wsLogString, "\r\n", "\n")
+
+ expectedWsOutput := "\033[31mRed\033[0m \033[32mGreen\033[0m\n" +
+ "\033[1mBold\033[0m \033[4mUnderline\033[0m\n" +
+ "\033[s\033[uPlain text\n"
+ assert.Equal(t, expectedWsOutput, normalizedWsLogString, "Websocket logs should preserve ANSI codes")
+
+ // 2. Check regular job logs, which should have ANSI codes stripped.
+ regularLogHistory := ex.GetHistory(0).JobLogs
+ assert.NotEmpty(t, regularLogHistory)
+ regularLogString := combineLogMessages(regularLogHistory)
+ normalizedRegularLogString := strings.ReplaceAll(regularLogString, "\r\n", "\n")
+
+ expectedRegularOutput := "Red Green\n" +
+ "Bold Underline\n" +
+ "Plain text\n"
+ assert.Equal(t, expectedRegularOutput, normalizedRegularLogString, "Regular logs should have ANSI codes stripped")
+
+ // Verify timestamps are ordered for both log types.
+ assert.Greater(t, len(wsLogHistory), 0)
+ for i := 1; i < len(wsLogHistory); i++ {
+ assert.GreaterOrEqual(t, wsLogHistory[i].Timestamp, wsLogHistory[i-1].Timestamp)
+ }
+}
+
+func combineLogMessages(logHistory []schemas.LogEvent) string {
+ var logOutput bytes.Buffer
+ for _, logEvent := range logHistory {
+ logOutput.Write(logEvent.Message)
+ }
+ return logOutput.String()
+}
diff --git a/runner/internal/executor/query.go b/runner/internal/executor/query.go
index 1dff4e330..6678e5f8d 100644
--- a/runner/internal/executor/query.go
+++ b/runner/internal/executor/query.go
@@ -4,8 +4,8 @@ import (
"github.com/dstackai/dstack/runner/internal/schemas"
)
-func (ex *RunExecutor) GetJobLogsHistory() []schemas.LogEvent {
- return ex.jobLogs.history
+func (ex *RunExecutor) GetJobWsLogsHistory() []schemas.LogEvent {
+ return ex.jobWsLogs.history
}
func (ex *RunExecutor) GetHistory(timestamp int64) *schemas.PullResponse {
diff --git a/runner/internal/metrics/metrics_test.go b/runner/internal/metrics/metrics_test.go
index 7f280da25..d547e2e33 100644
--- a/runner/internal/metrics/metrics_test.go
+++ b/runner/internal/metrics/metrics_test.go
@@ -1,6 +1,7 @@
package metrics
import (
+ "runtime"
"testing"
"github.com/dstackai/dstack/runner/internal/schemas"
@@ -8,6 +9,9 @@ import (
)
func TestGetAMDGPUMetrics_OK(t *testing.T) {
+ if runtime.GOOS == "darwin" {
+ t.Skip("Skipping on macOS")
+ }
collector, err := NewMetricsCollector()
assert.NoError(t, err)
@@ -39,6 +43,9 @@ func TestGetAMDGPUMetrics_OK(t *testing.T) {
}
func TestGetAMDGPUMetrics_ErrorGPUUtilNA(t *testing.T) {
+ if runtime.GOOS == "darwin" {
+ t.Skip("Skipping on macOS")
+ }
collector, err := NewMetricsCollector()
assert.NoError(t, err)
metrics, err := collector.getAMDGPUMetrics("gpu,gfx,gfx_clock,vram_used,vram_total\n0,N/A,N/A,283,196300\n")
diff --git a/runner/internal/runner/api/ws.go b/runner/internal/runner/api/ws.go
index cade1170a..ebb0caea2 100644
--- a/runner/internal/runner/api/ws.go
+++ b/runner/internal/runner/api/ws.go
@@ -34,23 +34,23 @@ func (s *Server) streamJobLogs(conn *websocket.Conn) {
for {
s.executor.RLock()
- jobLogsHistory := s.executor.GetJobLogsHistory()
+ jobLogsWsHistory := s.executor.GetJobWsLogsHistory()
select {
case <-s.shutdownCh:
- if currentPos >= len(jobLogsHistory) {
+ if currentPos >= len(jobLogsWsHistory) {
s.executor.RUnlock()
close(s.wsDoneCh)
return
}
default:
- if currentPos >= len(jobLogsHistory) {
+ if currentPos >= len(jobLogsWsHistory) {
s.executor.RUnlock()
time.Sleep(100 * time.Millisecond)
continue
}
}
- for currentPos < len(jobLogsHistory) {
- if err := conn.WriteMessage(websocket.BinaryMessage, jobLogsHistory[currentPos].Message); err != nil {
+ for currentPos < len(jobLogsWsHistory) {
+ if err := conn.WriteMessage(websocket.BinaryMessage, jobLogsWsHistory[currentPos].Message); err != nil {
s.executor.RUnlock()
log.Error(context.TODO(), "Failed to write message", "err", err)
return
diff --git a/runner/internal/shim/docker.go b/runner/internal/shim/docker.go
index 19462c9d7..1834188ae 100644
--- a/runner/internal/shim/docker.go
+++ b/runner/internal/shim/docker.go
@@ -274,6 +274,13 @@ func (d *DockerRunner) Run(ctx context.Context, taskID string) error {
cfg := task.config
var err error
+ runnerDir, err := d.dockerParams.MakeRunnerDir(task.containerName)
+ if err != nil {
+ return tracerr.Wrap(err)
+ }
+ task.runnerDir = runnerDir
+ log.Debug(ctx, "runner dir", "task", task.ID, "path", runnerDir)
+
if cfg.GPU != 0 {
gpuIDs, err := d.gpuLock.Acquire(ctx, cfg.GPU)
if err != nil {
@@ -335,7 +342,10 @@ func (d *DockerRunner) Run(ctx context.Context, taskID string) error {
if err := d.tasks.Update(task); err != nil {
return tracerr.Errorf("%w: failed to update task %s: %w", ErrInternal, task.ID, err)
}
- if err = pullImage(pullCtx, d.client, cfg); err != nil {
+ // Although it's called "runner dir", we also use it for shim task-related data.
+ // Maybe we should rename it to "task dir" (including the `/root/.dstack/runners` dir on the host).
+ pullLogPath := filepath.Join(runnerDir, "pull.log")
+ if err = pullImage(pullCtx, d.client, cfg, pullLogPath); err != nil {
errMessage := fmt.Sprintf("pullImage error: %s", err.Error())
log.Error(ctx, errMessage)
task.SetStatusTerminated(string(types.TerminationReasonCreatingContainerError), errMessage)
@@ -655,7 +665,7 @@ func mountDisk(ctx context.Context, deviceName, mountPoint string, fsRootPerms o
return nil
}
-func pullImage(ctx context.Context, client docker.APIClient, taskConfig TaskConfig) error {
+func pullImage(ctx context.Context, client docker.APIClient, taskConfig TaskConfig, logPath string) error {
if !strings.Contains(taskConfig.ImageName, ":") {
taskConfig.ImageName += ":latest"
}
@@ -685,51 +695,70 @@ func pullImage(ctx context.Context, client docker.APIClient, taskConfig TaskConf
if err != nil {
return tracerr.Wrap(err)
}
- defer func() { _ = reader.Close() }()
+ defer reader.Close()
+
+ logFile, err := os.OpenFile(logPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644)
+ if err != nil {
+ return tracerr.Wrap(err)
+ }
+ defer logFile.Close()
+
+ teeReader := io.TeeReader(reader, logFile)
current := make(map[string]uint)
total := make(map[string]uint)
- type ProgressDetail struct {
- Current uint `json:"current"`
- Total uint `json:"total"`
- }
- type Progress struct {
- Id string `json:"id"`
- Status string `json:"status"`
- ProgressDetail ProgressDetail `json:"progressDetail"` //nolint:tagliatelle
- Error string `json:"error"`
+ // dockerd reports pulling progress as a stream of JSON Lines. The format of records is not documented in the API documentation,
+ // although it's occasionally mentioned, e.g., https://docs.docker.com/reference/api/engine/version-history/#v148-api-changes
+
+ // https://github.com/moby/moby/blob/e77ff99ede5ee5952b3a9227863552ae6e5b6fb1/pkg/jsonmessage/jsonmessage.go#L144
+ // All fields are optional
+ type PullMessage struct {
+ Id string `json:"id"` // layer id
+ Status string `json:"status"`
+ ProgressDetail struct {
+ Current uint `json:"current"` // bytes
+ Total uint `json:"total"` // bytes
+ } `json:"progressDetail"`
+ ErrorDetail struct {
+ Message string `json:"message"`
+ } `json:"errorDetail"`
}
- var status bool
+ var pullCompleted bool
pullErrors := make([]string, 0)
- scanner := bufio.NewScanner(reader)
+ scanner := bufio.NewScanner(teeReader)
for scanner.Scan() {
line := scanner.Bytes()
- var progressRow Progress
- if err := json.Unmarshal(line, &progressRow); err != nil {
+ var pullMessage PullMessage
+ if err := json.Unmarshal(line, &pullMessage); err != nil {
continue
}
- if progressRow.Status == "Downloading" {
- current[progressRow.Id] = progressRow.ProgressDetail.Current
- total[progressRow.Id] = progressRow.ProgressDetail.Total
+ if pullMessage.Status == "Downloading" {
+ current[pullMessage.Id] = pullMessage.ProgressDetail.Current
+ total[pullMessage.Id] = pullMessage.ProgressDetail.Total
}
- if progressRow.Status == "Download complete" {
- current[progressRow.Id] = total[progressRow.Id]
+ if pullMessage.Status == "Download complete" {
+ current[pullMessage.Id] = total[pullMessage.Id]
}
- if progressRow.Error != "" {
- log.Error(ctx, "error pulling image", "name", taskConfig.ImageName, "err", progressRow.Error)
- pullErrors = append(pullErrors, progressRow.Error)
+ if pullMessage.ErrorDetail.Message != "" {
+ log.Error(ctx, "error pulling image", "name", taskConfig.ImageName, "err", pullMessage.ErrorDetail.Message)
+ pullErrors = append(pullErrors, pullMessage.ErrorDetail.Message)
}
- if strings.HasPrefix(progressRow.Status, "Status:") {
- status = true
- log.Debug(ctx, progressRow.Status)
+ // If the pull is successful, the last two entries must be:
+	// "Digest: sha256:<digest>"
+	// "Status: <status>"
+	// where <status> is either "Downloaded newer image for <image>" or "Image is up to date for <image>".
+ // See: https://github.com/moby/moby/blob/e77ff99ede5ee5952b3a9227863552ae6e5b6fb1/daemon/containerd/image_pull.go#L134-L152
+ // See: https://github.com/moby/moby/blob/e77ff99ede5ee5952b3a9227863552ae6e5b6fb1/daemon/containerd/image_pull.go#L257-L263
+ if strings.HasPrefix(pullMessage.Status, "Status:") {
+ pullCompleted = true
+ log.Debug(ctx, pullMessage.Status)
}
}
duration := time.Since(startTime)
-
var currentBytes uint
var totalBytes uint
for _, v := range current {
@@ -738,9 +767,13 @@ func pullImage(ctx context.Context, client docker.APIClient, taskConfig TaskConf
for _, v := range total {
totalBytes += v
}
-
speed := bytesize.New(float64(currentBytes) / duration.Seconds())
- if status && currentBytes == totalBytes {
+
+ if err := ctx.Err(); err != nil {
+ return tracerr.Errorf("image pull interrupted: downloaded %d bytes out of %d (%s/s): %w", currentBytes, totalBytes, speed, err)
+ }
+
+ if pullCompleted {
log.Debug(ctx, "image successfully pulled", "bytes", currentBytes, "bps", speed)
} else {
return tracerr.Errorf(
@@ -749,21 +782,11 @@ func pullImage(ctx context.Context, client docker.APIClient, taskConfig TaskConf
)
}
- err = ctx.Err()
- if err != nil {
- return tracerr.Errorf("imagepull interrupted: downloaded %d bytes out of %d (%s/s): %w", currentBytes, totalBytes, speed, err)
- }
return nil
}
func (d *DockerRunner) createContainer(ctx context.Context, task *Task) error {
- runnerDir, err := d.dockerParams.MakeRunnerDir(task.containerName)
- if err != nil {
- return tracerr.Wrap(err)
- }
- task.runnerDir = runnerDir
-
- mounts, err := d.dockerParams.DockerMounts(runnerDir)
+ mounts, err := d.dockerParams.DockerMounts(task.runnerDir)
if err != nil {
return tracerr.Wrap(err)
}
diff --git a/runner/internal/shim/docker_test.go b/runner/internal/shim/docker_test.go
index 35c8cbab6..29a5e1afd 100644
--- a/runner/internal/shim/docker_test.go
+++ b/runner/internal/shim/docker_test.go
@@ -26,8 +26,9 @@ func TestDocker_SSHServer(t *testing.T) {
t.Parallel()
params := &dockerParametersMock{
- commands: []string{"echo 1"},
- sshPort: nextPort(),
+ commands: []string{"echo 1"},
+ sshPort: nextPort(),
+ runnerDir: t.TempDir(),
}
timeout := 180 // seconds
@@ -58,6 +59,7 @@ func TestDocker_SSHServerConnect(t *testing.T) {
commands: []string{"sleep 5"},
sshPort: nextPort(),
publicSSHKey: string(publicBytes),
+ runnerDir: t.TempDir(),
}
timeout := 180 // seconds
@@ -103,7 +105,8 @@ func TestDocker_ShmNoexecByDefault(t *testing.T) {
t.Parallel()
params := &dockerParametersMock{
- commands: []string{"mount | grep '/dev/shm .*size=65536k' | grep noexec"},
+ commands: []string{"mount | grep '/dev/shm .*size=65536k' | grep noexec"},
+ runnerDir: t.TempDir(),
}
timeout := 180 // seconds
@@ -125,7 +128,8 @@ func TestDocker_ShmExecIfSizeSpecified(t *testing.T) {
t.Parallel()
params := &dockerParametersMock{
- commands: []string{"mount | grep '/dev/shm .*size=1024k' | grep -v noexec"},
+ commands: []string{"mount | grep '/dev/shm .*size=1024k' | grep -v noexec"},
+ runnerDir: t.TempDir(),
}
timeout := 180 // seconds
@@ -148,6 +152,7 @@ type dockerParametersMock struct {
commands []string
sshPort int
publicSSHKey string
+ runnerDir string
}
func (c *dockerParametersMock) DockerPrivileged() bool {
@@ -184,7 +189,7 @@ func (c *dockerParametersMock) DockerMounts(string) ([]mount.Mount, error) {
}
func (c *dockerParametersMock) MakeRunnerDir(string) (string, error) {
- return "", nil
+ return c.runnerDir, nil
}
/* Utilities */
diff --git a/src/dstack/_internal/cli/services/configurators/fleet.py b/src/dstack/_internal/cli/services/configurators/fleet.py
index b501f0b9d..2a7eeb4d5 100644
--- a/src/dstack/_internal/cli/services/configurators/fleet.py
+++ b/src/dstack/_internal/cli/services/configurators/fleet.py
@@ -25,6 +25,7 @@
ServerClientError,
URLNotFoundError,
)
+from dstack._internal.core.models.common import ApplyAction
from dstack._internal.core.models.configurations import ApplyConfigurationType
from dstack._internal.core.models.fleets import (
Fleet,
@@ -72,7 +73,104 @@ def apply_configuration(
spec=spec,
)
_print_plan_header(plan)
+ if plan.action is not None:
+ self._apply_plan(plan, command_args)
+ else:
+ # Old servers don't support spec update
+ self._apply_plan_on_old_server(plan, command_args)
+
+ def _apply_plan(self, plan: FleetPlan, command_args: argparse.Namespace):
+ delete_fleet_name: Optional[str] = None
+ action_message = ""
+ confirm_message = ""
+ if plan.current_resource is None:
+ if plan.spec.configuration.name is not None:
+ action_message += (
+ f"Fleet [code]{plan.spec.configuration.name}[/] does not exist yet."
+ )
+ confirm_message += "Create the fleet?"
+ else:
+ action_message += f"Found fleet [code]{plan.spec.configuration.name}[/]."
+ if plan.action == ApplyAction.CREATE:
+ delete_fleet_name = plan.current_resource.name
+ action_message += (
+ " Configuration changes detected. Cannot update the fleet in-place"
+ )
+ confirm_message += "Re-create the fleet?"
+ elif plan.current_resource.spec == plan.effective_spec:
+ if command_args.yes and not command_args.force:
+ # --force is required only with --yes,
+ # otherwise we may ask for force apply interactively.
+ console.print(
+ "No configuration changes detected. Use --force to apply anyway."
+ )
+ return
+ delete_fleet_name = plan.current_resource.name
+ action_message += " No configuration changes detected."
+ confirm_message += "Re-create the fleet?"
+ else:
+ action_message += " Configuration changes detected."
+ confirm_message += "Update the fleet in-place?"
+
+ console.print(action_message)
+ if not command_args.yes and not confirm_ask(confirm_message):
+ console.print("\nExiting...")
+ return
+
+ if delete_fleet_name is not None:
+ with console.status("Deleting existing fleet..."):
+ self.api.client.fleets.delete(
+ project_name=self.api.project, names=[delete_fleet_name]
+ )
+ # Fleet deletion is async. Wait for fleet to be deleted.
+ while True:
+ try:
+ self.api.client.fleets.get(
+ project_name=self.api.project, name=delete_fleet_name
+ )
+ except ResourceNotExistsError:
+ break
+ else:
+ time.sleep(1)
+
+ try:
+ with console.status("Applying plan..."):
+ fleet = self.api.client.fleets.apply_plan(project_name=self.api.project, plan=plan)
+ except ServerClientError as e:
+ raise CLIError(e.msg)
+ if command_args.detach:
+ console.print("Fleet configuration submitted. Exiting...")
+ return
+ try:
+ with MultiItemStatus(
+ f"Provisioning [code]{fleet.name}[/]...", console=console
+ ) as live:
+ while not _finished_provisioning(fleet):
+ table = get_fleets_table([fleet])
+ live.update(table)
+ time.sleep(LIVE_TABLE_PROVISION_INTERVAL_SECS)
+ fleet = self.api.client.fleets.get(self.api.project, fleet.name)
+ except KeyboardInterrupt:
+ if confirm_ask("Delete the fleet before exiting?"):
+ with console.status("Deleting fleet..."):
+ self.api.client.fleets.delete(
+ project_name=self.api.project, names=[fleet.name]
+ )
+ else:
+ console.print("Exiting... Fleet provisioning will continue in the background.")
+ return
+ console.print(
+ get_fleets_table(
+ [fleet],
+ verbose=_failed_provisioning(fleet),
+ format_date=local_time,
+ )
+ )
+ if _failed_provisioning(fleet):
+ console.print("\n[error]Some instances failed. Check the table above for errors.[/]")
+ exit(1)
+ def _apply_plan_on_old_server(self, plan: FleetPlan, command_args: argparse.Namespace):
action_message = ""
confirm_message = ""
if plan.current_resource is None:
@@ -86,7 +184,7 @@ def apply_configuration(
diff = diff_models(
old=plan.current_resource.spec.configuration,
new=plan.spec.configuration,
- ignore={
+ reset={
"ssh_config": {
"ssh_key": True,
"proxy_jump": {"ssh_key"},
diff --git a/src/dstack/_internal/cli/services/profile.py b/src/dstack/_internal/cli/services/profile.py
index 23bbe55ad..d57ea2e13 100644
--- a/src/dstack/_internal/cli/services/profile.py
+++ b/src/dstack/_internal/cli/services/profile.py
@@ -159,7 +159,7 @@ def apply_profile_args(
if args.idle_duration is not None:
profile_settings.idle_duration = args.idle_duration
elif args.dont_destroy:
- profile_settings.idle_duration = False
+ profile_settings.idle_duration = "off"
if args.creation_policy_reuse:
profile_settings.creation_policy = CreationPolicy.REUSE
diff --git a/src/dstack/_internal/core/compatibility/runs.py b/src/dstack/_internal/core/compatibility/runs.py
index 1deea7ee0..5eaf10ed2 100644
--- a/src/dstack/_internal/core/compatibility/runs.py
+++ b/src/dstack/_internal/core/compatibility/runs.py
@@ -3,7 +3,16 @@
from dstack._internal.core.models.common import IncludeExcludeDictType, IncludeExcludeSetType
from dstack._internal.core.models.configurations import ServiceConfiguration
from dstack._internal.core.models.runs import ApplyRunPlanInput, JobSpec, JobSubmission, RunSpec
-from dstack._internal.server.schemas.runs import GetRunPlanRequest
+from dstack._internal.server.schemas.runs import GetRunPlanRequest, ListRunsRequest
+
+
+def get_list_runs_excludes(list_runs_request: ListRunsRequest) -> IncludeExcludeSetType:
+ excludes = set()
+ if list_runs_request.include_jobs:
+ excludes.add("include_jobs")
+ if list_runs_request.job_submissions_limit is None:
+ excludes.add("job_submissions_limit")
+ return excludes
def get_apply_plan_excludes(plan: ApplyRunPlanInput) -> Optional[IncludeExcludeDictType]:
@@ -139,6 +148,8 @@ def get_job_spec_excludes(job_specs: list[JobSpec]) -> IncludeExcludeDictType:
spec_excludes["repo_data"] = True
if all(not s.file_archives for s in job_specs):
spec_excludes["file_archives"] = True
+ if all(s.service_port is None for s in job_specs):
+ spec_excludes["service_port"] = True
return spec_excludes
diff --git a/src/dstack/_internal/core/compatibility/volumes.py b/src/dstack/_internal/core/compatibility/volumes.py
index 7395674f9..4b7be6bb0 100644
--- a/src/dstack/_internal/core/compatibility/volumes.py
+++ b/src/dstack/_internal/core/compatibility/volumes.py
@@ -30,4 +30,6 @@ def _get_volume_configuration_excludes(
configuration_excludes: IncludeExcludeDictType = {}
if configuration.tags is None:
configuration_excludes["tags"] = True
+ if configuration.auto_cleanup_duration is None:
+ configuration_excludes["auto_cleanup_duration"] = True
return configuration_excludes
diff --git a/src/dstack/_internal/core/models/common.py b/src/dstack/_internal/core/models/common.py
index c347cf0d3..a13922671 100644
--- a/src/dstack/_internal/core/models/common.py
+++ b/src/dstack/_internal/core/models/common.py
@@ -1,11 +1,14 @@
import re
from enum import Enum
-from typing import Union
+from typing import Any, Callable, Optional, Union
+import orjson
from pydantic import Field
from pydantic_duality import DualBaseModel
from typing_extensions import Annotated
+from dstack._internal.utils.json_utils import pydantic_orjson_dumps
+
IncludeExcludeFieldType = Union[int, str]
IncludeExcludeSetType = set[IncludeExcludeFieldType]
IncludeExcludeDictType = dict[
@@ -20,7 +23,40 @@
# This allows to use the same model both for a strict parsing of the user input and
# for a permissive parsing of the server responses.
class CoreModel(DualBaseModel):
- pass
+ class Config:
+ json_loads = orjson.loads
+ json_dumps = pydantic_orjson_dumps
+
+ def json(
+ self,
+ *,
+ include: Optional[IncludeExcludeType] = None,
+ exclude: Optional[IncludeExcludeType] = None,
+ by_alias: bool = False,
+ skip_defaults: Optional[bool] = None, # ignore as it's deprecated
+ exclude_unset: bool = False,
+ exclude_defaults: bool = False,
+ exclude_none: bool = False,
+ encoder: Optional[Callable[[Any], Any]] = None,
+        models_as_dict: bool = True,  # does not seem to be needed by dstack or dependencies
+ **dumps_kwargs: Any,
+ ) -> str:
+ """
+ Override the `json()` method so that it calls `dict()`.
+ This allows changing how models are serialized by overriding `dict()` only.
+ By default, `json()` doesn't call `dict()`, so changes applied in `dict()` wouldn't take effect.
+ """
+ data = self.dict(
+ by_alias=by_alias,
+ include=include,
+ exclude=exclude,
+ exclude_unset=exclude_unset,
+ exclude_defaults=exclude_defaults,
+ exclude_none=exclude_none,
+ )
+ if self.__custom_root_type__:
+ data = data["__root__"]
+ return self.__config__.json_dumps(data, default=encoder, **dumps_kwargs)
class Duration(int):
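
The `json()` override above is easiest to see with a toy model: in pydantic v1, `json()` serializes through the model's internal iterator rather than `dict()`, so a `dict()` customization alone never reaches the JSON output. Below is a minimal sketch of the same idea, assuming plain pydantic v1 and using the standard-library `json` module in place of dstack's orjson helpers; the `Point` model is made up for illustration.

```python
# A minimal sketch of routing json() through dict(), assuming pydantic v1.
# The stdlib `json` module stands in for dstack's orjson-based dumps helpers.
import json

from pydantic import BaseModel


class Point(BaseModel):
    x: int
    y: int

    def dict(self, **kwargs):
        data = super().dict(**kwargs)
        data["sum"] = data["x"] + data["y"]  # custom serialization tweak
        return data

    def json(self, **dumps_kwargs):
        # Serialize the output of dict() so the tweak above also appears in JSON.
        return json.dumps(self.dict(), **dumps_kwargs)


print(Point(x=1, y=2).dict())  # {'x': 1, 'y': 2, 'sum': 3}
print(Point(x=1, y=2).json())  # {"x": 1, "y": 2, "sum": 3}
```
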
diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py
index 97be403ca..a22db9e36 100644
--- a/src/dstack/_internal/core/models/configurations.py
+++ b/src/dstack/_internal/core/models/configurations.py
@@ -4,6 +4,7 @@
from pathlib import PurePosixPath
from typing import Any, Dict, List, Optional, Union
+import orjson
from pydantic import Field, ValidationError, conint, constr, root_validator, validator
from typing_extensions import Annotated, Literal
@@ -18,6 +19,9 @@
from dstack._internal.core.models.services import AnyModel, OpenAIChatModel
from dstack._internal.core.models.unix import UnixUser
from dstack._internal.core.models.volumes import MountPoint, VolumeConfiguration, parse_mount_point
+from dstack._internal.utils.json_utils import (
+ pydantic_orjson_dumps_with_indent,
+)
CommandsList = List[str]
ValidPort = conint(gt=0, le=65536)
@@ -394,8 +398,9 @@ class TaskConfiguration(
class ServiceConfigurationParams(CoreModel):
port: Annotated[
+ # NOTE: it's a PortMapping for historical reasons. Only `port.container_port` is used.
Union[ValidPort, constr(regex=r"^[0-9]+:[0-9]+$"), PortMapping],
- Field(description="The port, that application listens on or the mapping"),
+ Field(description="The port the application listens on"),
]
gateway: Annotated[
Optional[Union[bool, str]],
@@ -573,6 +578,9 @@ class DstackConfiguration(CoreModel):
]
class Config:
+ json_loads = orjson.loads
+ json_dumps = pydantic_orjson_dumps_with_indent
+
@staticmethod
def schema_extra(schema: Dict[str, Any]):
schema["$schema"] = "http://json-schema.org/draft-07/schema#"
diff --git a/src/dstack/_internal/core/models/fleets.py b/src/dstack/_internal/core/models/fleets.py
index 6cf970a95..fd616b754 100644
--- a/src/dstack/_internal/core/models/fleets.py
+++ b/src/dstack/_internal/core/models/fleets.py
@@ -8,7 +8,7 @@
from typing_extensions import Annotated, Literal
from dstack._internal.core.models.backends.base import BackendType
-from dstack._internal.core.models.common import CoreModel
+from dstack._internal.core.models.common import ApplyAction, CoreModel
from dstack._internal.core.models.envs import Env
from dstack._internal.core.models.instances import Instance, InstanceOfferWithAvailability, SSHKey
from dstack._internal.core.models.profiles import (
@@ -324,6 +324,7 @@ class FleetPlan(CoreModel):
offers: List[InstanceOfferWithAvailability]
total_offers: int
max_offer_price: Optional[float] = None
+ action: Optional[ApplyAction] = None # default value for backward compatibility
def get_effective_spec(self) -> FleetSpec:
if self.effective_spec is not None:
diff --git a/src/dstack/_internal/core/models/profiles.py b/src/dstack/_internal/core/models/profiles.py
index 62997ce4e..78199608c 100644
--- a/src/dstack/_internal/core/models/profiles.py
+++ b/src/dstack/_internal/core/models/profiles.py
@@ -1,12 +1,14 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Union, overload
+import orjson
from pydantic import Field, root_validator, validator
from typing_extensions import Annotated, Literal
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import CoreModel, Duration
from dstack._internal.utils.common import list_enum_values_for_annotation
+from dstack._internal.utils.json_utils import pydantic_orjson_dumps_with_indent
from dstack._internal.utils.tags import tags_validator
DEFAULT_RETRY_DURATION = 3600
@@ -74,11 +76,9 @@ def parse_off_duration(v: Optional[Union[int, str, bool]]) -> Optional[Union[str
return parse_duration(v)
-def parse_idle_duration(v: Optional[Union[int, str, bool]]) -> Optional[Union[str, int, bool]]:
- if v is False:
+def parse_idle_duration(v: Optional[Union[int, str]]) -> Optional[Union[str, int]]:
+ if v == "off" or v == -1:
return -1
- if v is True:
- return None
return parse_duration(v)
@@ -249,7 +249,7 @@ class ProfileParams(CoreModel):
),
] = None
idle_duration: Annotated[
- Optional[Union[Literal["off"], str, int, bool]],
+ Optional[Union[Literal["off"], str, int]],
Field(
description=(
"Time to wait before terminating idle instances."
@@ -343,6 +343,9 @@ class ProfilesConfig(CoreModel):
profiles: List[Profile]
class Config:
+ json_loads = orjson.loads
+ json_dumps = pydantic_orjson_dumps_with_indent
+
schema_extra = {"$schema": "http://json-schema.org/draft-07/schema#"}
def default(self) -> Optional[Profile]:
diff --git a/src/dstack/_internal/core/models/resources.py b/src/dstack/_internal/core/models/resources.py
index a0ecff953..15c80f716 100644
--- a/src/dstack/_internal/core/models/resources.py
+++ b/src/dstack/_internal/core/models/resources.py
@@ -382,14 +382,6 @@ def schema_extra(schema: Dict[str, Any]):
gpu: Annotated[Optional[GPUSpec], Field(description="The GPU requirements")] = None
disk: Annotated[Optional[DiskSpec], Field(description="The disk resources")] = DEFAULT_DISK
- # TODO: Remove in 0.20. Added for backward compatibility.
- @root_validator
- def _post_validate(cls, values):
- cpu = values.get("cpu")
- if isinstance(cpu, CPUSpec) and cpu.arch in [None, gpuhunt.CPUArchitecture.X86]:
- values["cpu"] = cpu.count
- return values
-
def pretty_format(self) -> str:
# TODO: Remove in 0.20. Use self.cpu directly
cpu = parse_obj_as(CPUSpec, self.cpu)
@@ -407,3 +399,18 @@ def pretty_format(self) -> str:
resources.update(disk_size=self.disk.size)
res = pretty_resources(**resources)
return res
+
+ def dict(self, *args, **kwargs) -> Dict:
+ # super() does not work with pydantic-duality
+ res = CoreModel.dict(self, *args, **kwargs)
+ self._update_serialized_cpu(res)
+ return res
+
+ # TODO: Remove in 0.20. Added for backward compatibility.
+ def _update_serialized_cpu(self, values: Dict):
+ cpu = values["cpu"]
+ if cpu:
+ arch = cpu.get("arch")
+ count = cpu.get("count")
+ if count and arch in [None, gpuhunt.CPUArchitecture.X86.value]:
+ values["cpu"] = count
diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py
index 49691eb50..3f67646ff 100644
--- a/src/dstack/_internal/core/models/runs.py
+++ b/src/dstack/_internal/core/models/runs.py
@@ -11,6 +11,7 @@
DEFAULT_REPO_DIR,
AnyRunConfiguration,
RunConfiguration,
+ ServiceConfiguration,
)
from dstack._internal.core.models.files import FileArchiveMapping
from dstack._internal.core.models.instances import (
@@ -101,6 +102,14 @@ def to_status(self) -> "RunStatus":
}
return mapping[self]
+ def to_error(self) -> Optional[str]:
+ if self == RunTerminationReason.RETRY_LIMIT_EXCEEDED:
+ return "retry limit exceeded"
+ elif self == RunTerminationReason.SERVER_ERROR:
+ return "server error"
+ else:
+ return None
+
class JobTerminationReason(str, Enum):
# Set by the server
@@ -162,6 +171,24 @@ def to_retry_event(self) -> Optional[RetryEvent]:
default = RetryEvent.ERROR if self.to_status() == JobStatus.FAILED else None
return mapping.get(self, default)
+ def to_error(self) -> Optional[str]:
+ # Should return None for values that are already
+ # handled and shown in status_message.
+ error_mapping = {
+ JobTerminationReason.INSTANCE_UNREACHABLE: "instance unreachable",
+ JobTerminationReason.WAITING_INSTANCE_LIMIT_EXCEEDED: "waiting instance limit exceeded",
+ JobTerminationReason.VOLUME_ERROR: "volume error",
+ JobTerminationReason.GATEWAY_ERROR: "gateway error",
+ JobTerminationReason.SCALED_DOWN: "scaled down",
+ JobTerminationReason.INACTIVITY_DURATION_EXCEEDED: "inactivity duration exceeded",
+ JobTerminationReason.TERMINATED_DUE_TO_UTILIZATION_POLICY: "utilization policy",
+ JobTerminationReason.PORTS_BINDING_FAILED: "ports binding failed",
+ JobTerminationReason.CREATING_CONTAINER_ERROR: "runner error",
+ JobTerminationReason.EXECUTOR_ERROR: "executor error",
+ JobTerminationReason.MAX_DURATION_EXCEEDED: "max duration exceeded",
+ }
+ return error_mapping.get(self)
+
class Requirements(CoreModel):
# TODO: Make requirements' fields required
@@ -227,6 +254,8 @@ class JobSpec(CoreModel):
# TODO: drop this comment when supporting jobs submitted before 0.19.17 is no longer relevant.
repo_code_hash: Optional[str] = None
file_archives: list[FileArchiveMapping] = []
+ # None for non-services and pre-0.19.19 services. See `get_service_port`
+ service_port: Optional[int] = None
class JobProvisioningData(CoreModel):
@@ -305,13 +334,12 @@ class JobSubmission(CoreModel):
finished_at: Optional[datetime]
inactivity_secs: Optional[int]
status: JobStatus
+ status_message: str = "" # default for backward compatibility
termination_reason: Optional[JobTerminationReason]
termination_reason_message: Optional[str]
exit_status: Optional[int]
job_provisioning_data: Optional[JobProvisioningData]
job_runtime_data: Optional[JobRuntimeData]
- # TODO: make status_message and error a computed field after migrating to pydanticV2
- status_message: Optional[str] = None
error: Optional[str] = None
@property
@@ -325,71 +353,6 @@ def duration(self) -> timedelta:
end_time = self.finished_at
return end_time - self.submitted_at
- @root_validator
- def _status_message(cls, values) -> Dict:
- try:
- status = values["status"]
- termination_reason = values["termination_reason"]
- exit_code = values["exit_status"]
- except KeyError:
- return values
- values["status_message"] = JobSubmission._get_status_message(
- status=status,
- termination_reason=termination_reason,
- exit_status=exit_code,
- )
- return values
-
- @staticmethod
- def _get_status_message(
- status: JobStatus,
- termination_reason: Optional[JobTerminationReason],
- exit_status: Optional[int],
- ) -> str:
- if status == JobStatus.DONE:
- return "exited (0)"
- elif status == JobStatus.FAILED:
- if termination_reason == JobTerminationReason.CONTAINER_EXITED_WITH_ERROR:
- return f"exited ({exit_status})"
- elif termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY:
- return "no offers"
- elif termination_reason == JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY:
- return "interrupted"
- else:
- return "error"
- elif status == JobStatus.TERMINATED:
- if termination_reason == JobTerminationReason.TERMINATED_BY_USER:
- return "stopped"
- elif termination_reason == JobTerminationReason.ABORTED_BY_USER:
- return "aborted"
- return status.value
-
- @root_validator
- def _error(cls, values) -> Dict:
- try:
- termination_reason = values["termination_reason"]
- except KeyError:
- return values
- values["error"] = JobSubmission._get_error(termination_reason=termination_reason)
- return values
-
- @staticmethod
- def _get_error(termination_reason: Optional[JobTerminationReason]) -> Optional[str]:
- error_mapping = {
- JobTerminationReason.INSTANCE_UNREACHABLE: "instance unreachable",
- JobTerminationReason.WAITING_INSTANCE_LIMIT_EXCEEDED: "waiting instance limit exceeded",
- JobTerminationReason.VOLUME_ERROR: "volume error",
- JobTerminationReason.GATEWAY_ERROR: "gateway error",
- JobTerminationReason.SCALED_DOWN: "scaled down",
- JobTerminationReason.INACTIVITY_DURATION_EXCEEDED: "inactivity duration exceeded",
- JobTerminationReason.TERMINATED_DUE_TO_UTILIZATION_POLICY: "utilization policy",
- JobTerminationReason.PORTS_BINDING_FAILED: "ports binding failed",
- JobTerminationReason.CREATING_CONTAINER_ERROR: "runner error",
- JobTerminationReason.EXECUTOR_ERROR: "executor error",
- JobTerminationReason.MAX_DURATION_EXCEEDED: "max duration exceeded",
- }
- return error_mapping.get(termination_reason)
-
class Job(CoreModel):
job_spec: JobSpec
@@ -524,85 +487,17 @@ class Run(CoreModel):
submitted_at: datetime
last_processed_at: datetime
status: RunStatus
- status_message: Optional[str] = None
- termination_reason: Optional[RunTerminationReason]
+ status_message: str = "" # default for backward compatibility
+ termination_reason: Optional[RunTerminationReason] = None
run_spec: RunSpec
jobs: List[Job]
- latest_job_submission: Optional[JobSubmission]
+ latest_job_submission: Optional[JobSubmission] = None
cost: float = 0
service: Optional[ServiceSpec] = None
deployment_num: int = 0 # default for compatibility with pre-0.19.14 servers
- # TODO: make error a computed field after migrating to pydanticV2
error: Optional[str] = None
deleted: Optional[bool] = None
- @root_validator
- def _error(cls, values) -> Dict:
- try:
- termination_reason = values["termination_reason"]
- except KeyError:
- return values
- values["error"] = Run._get_error(termination_reason=termination_reason)
- return values
-
- @staticmethod
- def _get_error(termination_reason: Optional[RunTerminationReason]) -> Optional[str]:
- if termination_reason == RunTerminationReason.RETRY_LIMIT_EXCEEDED:
- return "retry limit exceeded"
- elif termination_reason == RunTerminationReason.SERVER_ERROR:
- return "server error"
- else:
- return None
-
- @root_validator
- def _status_message(cls, values) -> Dict:
- try:
- status = values["status"]
- jobs: List[Job] = values["jobs"]
- retry_on_events = (
- jobs[0].job_spec.retry.on_events if jobs and jobs[0].job_spec.retry else []
- )
- job_status = (
- jobs[0].job_submissions[-1].status
- if len(jobs) == 1 and jobs[0].job_submissions
- else None
- )
- termination_reason = Run.get_last_termination_reason(jobs[0]) if jobs else None
- except KeyError:
- return values
- values["status_message"] = Run._get_status_message(
- status=status,
- job_status=job_status,
- retry_on_events=retry_on_events,
- termination_reason=termination_reason,
- )
- return values
-
- @staticmethod
- def get_last_termination_reason(job: "Job") -> Optional[JobTerminationReason]:
- for submission in reversed(job.job_submissions):
- if submission.termination_reason is not None:
- return submission.termination_reason
- return None
-
- @staticmethod
- def _get_status_message(
- status: RunStatus,
- job_status: Optional[JobStatus],
- retry_on_events: List[RetryEvent],
- termination_reason: Optional[JobTerminationReason],
- ) -> str:
- if job_status == JobStatus.PULLING:
- return "pulling"
- # Currently, `retrying` is shown only for `no-capacity` events
- if (
- status in [RunStatus.SUBMITTED, RunStatus.PENDING]
- and termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY
- and RetryEvent.NO_CAPACITY in retry_on_events
- ):
- return "retrying"
- return status.value
-
def is_deployment_in_progress(self) -> bool:
return any(
not j.job_submissions[-1].status.is_finished()
@@ -658,3 +553,11 @@ def get_policy_map(spot_policy: Optional[SpotPolicy], default: SpotPolicy) -> Op
SpotPolicy.ONDEMAND: False,
}
return policy_map[spot_policy]
+
+
+def get_service_port(job_spec: JobSpec, configuration: ServiceConfiguration) -> int:
+ # Compatibility with pre-0.19.19 job specs that do not have the `service_port` property.
+ # TODO: drop when pre-0.19.19 jobs are no longer relevant.
+ if job_spec.service_port is None:
+ return configuration.port.container_port
+ return job_spec.service_port
diff --git a/src/dstack/_internal/core/models/volumes.py b/src/dstack/_internal/core/models/volumes.py
index 773fd9429..0f89b770f 100644
--- a/src/dstack/_internal/core/models/volumes.py
+++ b/src/dstack/_internal/core/models/volumes.py
@@ -9,6 +9,7 @@
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import CoreModel
+from dstack._internal.core.models.profiles import parse_idle_duration
from dstack._internal.core.models.resources import Memory
from dstack._internal.utils.common import get_or_error
from dstack._internal.utils.tags import tags_validator
@@ -44,6 +45,16 @@ class VolumeConfiguration(CoreModel):
Optional[str],
Field(description="The volume ID. Must be specified when registering external volumes"),
] = None
+ auto_cleanup_duration: Annotated[
+ Optional[Union[str, int]],
+ Field(
+ description=(
+ "Time to wait after the volume is no longer used by any job before deleting it. "
+ "Defaults to keeping the volume indefinitely. "
+ "Use the value 'off' or -1 to disable auto-cleanup."
+ )
+ ),
+ ] = None
tags: Annotated[
Optional[Dict[str, str]],
Field(
@@ -56,6 +67,9 @@ class VolumeConfiguration(CoreModel):
] = None
_validate_tags = validator("tags", pre=True, allow_reuse=True)(tags_validator)
+ _validate_auto_cleanup_duration = validator(
+ "auto_cleanup_duration", pre=True, allow_reuse=True
+ )(parse_idle_duration)
@property
def size_gb(self) -> int:
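
To make the accepted `auto_cleanup_duration` values concrete: per the `parse_idle_duration` validator wired up above, `'off'` and `-1` disable auto-cleanup, `None` (the default) keeps the volume indefinitely, and other values are parsed into seconds. The helper below is a simplified, hypothetical stand-in for dstack's `parse_duration` that only handles the `<number><unit>` form; it is a sketch, not the actual implementation.

```python
# A simplified, hypothetical stand-in for dstack's duration parsing, mirroring
# the parse_idle_duration logic from the diff above; for illustration only.
import re
from typing import Optional, Union


def parse_duration_simplified(v: Union[int, str]) -> int:
    """Return the duration in seconds; accepts ints or strings like '30m', '72h', '3d'."""
    if isinstance(v, int):
        return v
    match = re.fullmatch(r"(\d+)([smhdw])", v)
    if match is None:
        raise ValueError(f"invalid duration: {v}")
    amount, unit = int(match.group(1)), match.group(2)
    return amount * {"s": 1, "m": 60, "h": 3600, "d": 86400, "w": 604800}[unit]


def parse_auto_cleanup_duration(v: Optional[Union[int, str]]) -> Optional[int]:
    # 'off' or -1 disables auto-cleanup; None keeps the default (no auto-cleanup).
    if v is None:
        return None
    if v == "off" or v == -1:
        return -1
    return parse_duration_simplified(v)


print(parse_auto_cleanup_duration("72h"))  # 259200
print(parse_auto_cleanup_duration("off"))  # -1
```
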
diff --git a/src/dstack/_internal/core/services/diff.py b/src/dstack/_internal/core/services/diff.py
index d50ab90e5..0d63cebc4 100644
--- a/src/dstack/_internal/core/services/diff.py
+++ b/src/dstack/_internal/core/services/diff.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, TypedDict
+from typing import Any, Optional, TypedDict, TypeVar
from pydantic import BaseModel
@@ -15,20 +15,19 @@ class ModelFieldDiff(TypedDict):
# TODO: calculate nested diffs
def diff_models(
- old: BaseModel, new: BaseModel, ignore: Optional[IncludeExcludeType] = None
+ old: BaseModel, new: BaseModel, reset: Optional[IncludeExcludeType] = None
) -> ModelDiff:
"""
Returns a diff of model instances fields.
- NOTE: `ignore` is implemented as `BaseModel.parse_obj(BaseModel.dict(exclude=ignore))`,
- that is, the "ignored" fields are actually not ignored but reset to the default values
- before comparison, meaning that 1) any field in `ignore` must have a default value,
- 2) the default value must be equal to itself (e.g. `math.nan` != `math.nan`).
+ The fields specified in the `reset` option are reset to their default values, effectively
+ excluding them from comparison (assuming that the default value is equal to itself, e.g.,
+ `None == None`, `"task" == "task"`, but `math.nan != math.nan`).
Args:
old: The "old" model instance.
new: The "new" model instance.
- ignore: Optional fields to ignore.
+ reset: Fields to reset to their default values before comparison.
Returns:
A dict of changed fields in the form of
@@ -37,9 +36,9 @@ def diff_models(
if type(old) is not type(new):
raise TypeError("Both instances must be of the same Pydantic model class.")
- if ignore is not None:
- old = type(old).parse_obj(old.dict(exclude=ignore))
- new = type(new).parse_obj(new.dict(exclude=ignore))
+ if reset is not None:
+ old = copy_model(old, reset=reset)
+ new = copy_model(new, reset=reset)
changes: ModelDiff = {}
for field in old.__fields__:
@@ -49,3 +48,24 @@ def diff_models(
changes[field] = {"old": old_value, "new": new_value}
return changes
+
+
+M = TypeVar("M", bound=BaseModel)
+
+
+def copy_model(model: M, reset: Optional[IncludeExcludeType] = None) -> M:
+ """
+ Returns a deep copy of the model instance.
+
+ Implemented as `BaseModel.parse_obj(BaseModel.dict())`, thus,
+ unlike `BaseModel.copy(deep=True)`, runs all validations.
+
+ The fields specified in the `reset` option are reset to their default values.
+
+ Args:
+ reset: Fields to reset to their default values.
+
+ Returns:
+ A deep copy of the model instance.
+ """
+ return type(model).parse_obj(model.dict(exclude=reset))
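
To make the `reset` semantics concrete: the listed fields are serialized with `exclude` and re-parsed, so they fall back to their defaults on both sides and drop out of the comparison. Here is a minimal sketch with a plain pydantic v1 model; the `Config` class and its fields are made up for illustration.

```python
# A minimal sketch of the `reset` idea behind diff_models/copy_model, assuming
# plain pydantic v1; the Config model here is made up for illustration.
from typing import Optional

from pydantic import BaseModel


class Config(BaseModel):
    name: str
    replicas: int = 1
    comment: Optional[str] = None


old = Config(name="web", replicas=2, comment="v1")
new = Config(name="web", replicas=3, comment="v2")

# "Reset" the comment field on both sides: serialize with exclude and re-parse,
# so the field falls back to its default before comparison.
old_reset = Config.parse_obj(old.dict(exclude={"comment"}))
new_reset = Config.parse_obj(new.dict(exclude={"comment"}))

changes = {
    field: {"old": getattr(old_reset, field), "new": getattr(new_reset, field)}
    for field in Config.__fields__
    if getattr(old_reset, field) != getattr(new_reset, field)
}
print(changes)  # {'replicas': {'old': 2, 'new': 3}} -- the comment change is ignored
```
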
diff --git a/src/dstack/_internal/core/services/ssh/attach.py b/src/dstack/_internal/core/services/ssh/attach.py
index a095fafb2..d0ad4ac64 100644
--- a/src/dstack/_internal/core/services/ssh/attach.py
+++ b/src/dstack/_internal/core/services/ssh/attach.py
@@ -64,6 +64,7 @@ def __init__(
run_name: str,
dockerized: bool,
ssh_proxy: Optional[SSHConnectionParams] = None,
+ service_port: Optional[int] = None,
local_backend: bool = False,
bind_address: Optional[str] = None,
):
@@ -90,6 +91,7 @@ def __init__(
},
)
self.ssh_proxy = ssh_proxy
+ self.service_port = service_port
hosts: dict[str, dict[str, Union[str, int, FilePath]]] = {}
self.hosts = hosts
diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py
index 7435ff8cb..8aff963c1 100644
--- a/src/dstack/_internal/server/app.py
+++ b/src/dstack/_internal/server/app.py
@@ -10,7 +10,7 @@
import sentry_sdk
from fastapi import FastAPI, Request, Response, status
from fastapi.datastructures import URL
-from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
+from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from prometheus_client import Counter, Histogram
@@ -56,6 +56,7 @@
)
from dstack._internal.server.utils.logging import configure_logging
from dstack._internal.server.utils.routers import (
+ CustomORJSONResponse,
check_client_server_compatibility,
error_detail,
get_server_client_error_details,
@@ -90,7 +91,10 @@ def create_app() -> FastAPI:
profiles_sample_rate=settings.SENTRY_PROFILES_SAMPLE_RATE,
)
- app = FastAPI(docs_url="/api/docs", lifespan=lifespan)
+ app = FastAPI(
+ docs_url="/api/docs",
+ lifespan=lifespan,
+ )
app.state.proxy_dependency_injector = ServerProxyDependencyInjector()
return app
@@ -147,7 +151,10 @@ async def lifespan(app: FastAPI):
)
if settings.SERVER_S3_BUCKET is not None or settings.SERVER_GCS_BUCKET is not None:
init_default_storage()
- scheduler = start_background_tasks()
+ if settings.SERVER_BACKGROUND_PROCESSING_ENABLED:
+ scheduler = start_background_tasks()
+ else:
+ logger.info("Background processing is disabled")
dstack_version = DSTACK_VERSION if DSTACK_VERSION else "(no version)"
logger.info(f"The admin token is {admin.token.get_plaintext_or_error()}", {"show_path": False})
logger.info(
@@ -157,7 +164,8 @@ async def lifespan(app: FastAPI):
for func in _ON_STARTUP_HOOKS:
await func(app)
yield
- scheduler.shutdown()
+ if settings.SERVER_BACKGROUND_PROCESSING_ENABLED:
+ scheduler.shutdown()
await gateway_connections_pool.remove_all()
service_conn_pool = await get_injector_from_app(app).get_service_connection_pool()
await service_conn_pool.remove_all()
@@ -208,14 +216,14 @@ async def forbidden_error_handler(request: Request, exc: ForbiddenError):
msg = "Access denied"
if len(exc.args) > 0:
msg = exc.args[0]
- return JSONResponse(
+ return CustomORJSONResponse(
status_code=status.HTTP_403_FORBIDDEN,
content=error_detail(msg),
)
@app.exception_handler(ServerClientError)
async def server_client_error_handler(request: Request, exc: ServerClientError):
- return JSONResponse(
+ return CustomORJSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": get_server_client_error_details(exc)},
)
@@ -223,7 +231,7 @@ async def server_client_error_handler(request: Request, exc: ServerClientError):
@app.exception_handler(OSError)
async def os_error_handler(request, exc: OSError):
if exc.errno in [36, 63]:
- return JSONResponse(
+ return CustomORJSONResponse(
{"detail": "Filename too long"},
status_code=status.HTTP_400_BAD_REQUEST,
)
@@ -309,7 +317,7 @@ async def check_client_version(request: Request, call_next):
@app.get("/healthcheck")
async def healthcheck():
- return JSONResponse(content={"status": "running"})
+ return CustomORJSONResponse(content={"status": "running"})
if ui and Path(__file__).parent.joinpath("statics").exists():
app.mount(
@@ -323,7 +331,7 @@ async def custom_http_exception_handler(request, exc):
or _is_proxy_request(request)
or _is_prometheus_request(request)
):
- return JSONResponse(
+ return CustomORJSONResponse(
{"detail": exc.detail},
status_code=status.HTTP_404_NOT_FOUND,
)
diff --git a/src/dstack/_internal/server/background/__init__.py b/src/dstack/_internal/server/background/__init__.py
index 2dd410cd2..ec927ecfc 100644
--- a/src/dstack/_internal/server/background/__init__.py
+++ b/src/dstack/_internal/server/background/__init__.py
@@ -4,9 +4,10 @@
from dstack._internal.server import settings
from dstack._internal.server.background.tasks.process_fleets import process_fleets
from dstack._internal.server.background.tasks.process_gateways import (
+ process_gateways,
process_gateways_connections,
- process_submitted_gateways,
)
+from dstack._internal.server.background.tasks.process_idle_volumes import process_idle_volumes
from dstack._internal.server.background.tasks.process_instances import (
process_instances,
)
@@ -70,11 +71,12 @@ def start_background_tasks() -> AsyncIOScheduler:
)
_scheduler.add_job(delete_prometheus_metrics, IntervalTrigger(minutes=5), max_instances=1)
_scheduler.add_job(process_gateways_connections, IntervalTrigger(seconds=15))
+ _scheduler.add_job(process_gateways, IntervalTrigger(seconds=10, jitter=2), max_instances=5)
_scheduler.add_job(
- process_submitted_gateways, IntervalTrigger(seconds=10, jitter=2), max_instances=5
+ process_submitted_volumes, IntervalTrigger(seconds=10, jitter=2), max_instances=5
)
_scheduler.add_job(
- process_submitted_volumes, IntervalTrigger(seconds=10, jitter=2), max_instances=5
+ process_idle_volumes, IntervalTrigger(seconds=60, jitter=10), max_instances=1
)
_scheduler.add_job(process_placement_groups, IntervalTrigger(seconds=30, jitter=5))
for replica in range(settings.SERVER_BACKGROUND_PROCESSING_FACTOR):
diff --git a/src/dstack/_internal/server/background/tasks/process_gateways.py b/src/dstack/_internal/server/background/tasks/process_gateways.py
index e2a17aa15..cd6025a1a 100644
--- a/src/dstack/_internal/server/background/tasks/process_gateways.py
+++ b/src/dstack/_internal/server/background/tasks/process_gateways.py
@@ -16,6 +16,7 @@
gateway_connections_pool,
)
from dstack._internal.server.services.locking import advisory_lock_ctx, get_locker
+from dstack._internal.server.services.logging import fmt
from dstack._internal.utils.common import get_current_datetime
from dstack._internal.utils.logging import get_logger
@@ -27,14 +28,14 @@ async def process_gateways_connections():
await _process_active_connections()
-async def process_submitted_gateways():
+async def process_gateways():
lock, lockset = get_locker(get_db().dialect_name).get_lockset(GatewayModel.__tablename__)
async with get_session_ctx() as session:
async with lock:
res = await session.execute(
select(GatewayModel)
.where(
- GatewayModel.status == GatewayStatus.SUBMITTED,
+ GatewayModel.status.in_([GatewayStatus.SUBMITTED, GatewayStatus.PROVISIONING]),
GatewayModel.id.not_in(lockset),
)
.options(lazyload(GatewayModel.gateway_compute))
@@ -48,7 +49,25 @@ async def process_submitted_gateways():
lockset.add(gateway_model.id)
try:
gateway_model_id = gateway_model.id
- await _process_submitted_gateway(session=session, gateway_model=gateway_model)
+ initial_status = gateway_model.status
+ if initial_status == GatewayStatus.SUBMITTED:
+ await _process_submitted_gateway(session=session, gateway_model=gateway_model)
+ elif initial_status == GatewayStatus.PROVISIONING:
+ await _process_provisioning_gateway(session=session, gateway_model=gateway_model)
+ else:
+ logger.error(
+ "%s: unexpected gateway status %r", fmt(gateway_model), initial_status.upper()
+ )
+ if gateway_model.status != initial_status:
+ logger.info(
+ "%s: gateway status has changed %s -> %s%s",
+ fmt(gateway_model),
+ initial_status.upper(),
+ gateway_model.status.upper(),
+ f": {gateway_model.status_message}" if gateway_model.status_message else "",
+ )
+ gateway_model.last_processed_at = get_current_datetime()
+ await session.commit()
finally:
lockset.difference_update([gateway_model_id])
@@ -89,7 +108,7 @@ async def _process_connection(conn: GatewayConnection):
async def _process_submitted_gateway(session: AsyncSession, gateway_model: GatewayModel):
- logger.info("Started gateway %s provisioning", gateway_model.name)
+ logger.info("%s: started gateway provisioning", fmt(gateway_model))
# Refetch to load related attributes.
# joinedload produces LEFT OUTER JOIN that can't be used with FOR UPDATE.
res = await session.execute(
@@ -110,8 +129,6 @@ async def _process_submitted_gateway(session: AsyncSession, gateway_model: Gatew
except BackendNotAvailable:
gateway_model.status = GatewayStatus.FAILED
gateway_model.status_message = "Backend not available"
- gateway_model.last_processed_at = get_current_datetime()
- await session.commit()
return
try:
@@ -123,53 +140,54 @@ async def _process_submitted_gateway(session: AsyncSession, gateway_model: Gatew
)
session.add(gateway_model)
gateway_model.status = GatewayStatus.PROVISIONING
- await session.commit()
- await session.refresh(gateway_model)
except BackendError as e:
- logger.info(
- "Failed to create gateway compute for gateway %s: %s", gateway_model.name, repr(e)
- )
+ logger.info("%s: failed to create gateway compute: %r", fmt(gateway_model), e)
gateway_model.status = GatewayStatus.FAILED
status_message = f"Backend error: {repr(e)}"
if len(e.args) > 0:
status_message = str(e.args[0])
gateway_model.status_message = status_message
- gateway_model.last_processed_at = get_current_datetime()
- await session.commit()
- return
except Exception as e:
- logger.exception(
- "Got exception when creating gateway compute for gateway %s", gateway_model.name
- )
+ logger.exception("%s: got exception when creating gateway compute", fmt(gateway_model))
gateway_model.status = GatewayStatus.FAILED
gateway_model.status_message = f"Unexpected error: {repr(e)}"
- gateway_model.last_processed_at = get_current_datetime()
- await session.commit()
- return
+
+async def _process_provisioning_gateway(
+ session: AsyncSession, gateway_model: GatewayModel
+) -> None:
+ # Refetch to load related attributes.
+ # joinedload produces LEFT OUTER JOIN that can't be used with FOR UPDATE.
+ res = await session.execute(
+ select(GatewayModel)
+ .where(GatewayModel.id == gateway_model.id)
+ .execution_options(populate_existing=True)
+ )
+ gateway_model = res.unique().scalar_one()
+
+ # FIXME: problems caused by blocking on connect_to_gateway_with_retry and configure_gateway:
+ # - cannot delete the gateway before it is provisioned because the DB model is locked
+ # - connection retry counter is reset on server restart
+ # - only one server replica is processing the gateway
+ # Easy to fix by doing only one connection/configuration attempt per processing iteration. The
+ # main challenge is applying the same provisioning model to the dstack Sky gateway to avoid
+ # maintaining a different model for Sky.
connection = await gateways_services.connect_to_gateway_with_retry(
gateway_model.gateway_compute
)
if connection is None:
gateway_model.status = GatewayStatus.FAILED
gateway_model.status_message = "Failed to connect to gateway"
- gateway_model.last_processed_at = get_current_datetime()
gateway_model.gateway_compute.deleted = True
- await session.commit()
return
-
try:
await gateways_services.configure_gateway(connection)
except Exception:
- logger.exception("Failed to configure gateway %s", gateway_model.name)
+ logger.exception("%s: failed to configure gateway", fmt(gateway_model))
gateway_model.status = GatewayStatus.FAILED
gateway_model.status_message = "Failed to configure gateway"
- gateway_model.last_processed_at = get_current_datetime()
await gateway_connections_pool.remove(gateway_model.gateway_compute.ip_address)
gateway_model.gateway_compute.active = False
- await session.commit()
return
gateway_model.status = GatewayStatus.RUNNING
- gateway_model.last_processed_at = get_current_datetime()
- await session.commit()
diff --git a/src/dstack/_internal/server/background/tasks/process_idle_volumes.py b/src/dstack/_internal/server/background/tasks/process_idle_volumes.py
new file mode 100644
index 000000000..33d9d5a9b
--- /dev/null
+++ b/src/dstack/_internal/server/background/tasks/process_idle_volumes.py
@@ -0,0 +1,139 @@
+import datetime
+from typing import List
+
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.orm import joinedload
+
+from dstack._internal.core.backends.base.compute import ComputeWithVolumeSupport
+from dstack._internal.core.errors import BackendNotAvailable
+from dstack._internal.core.models.profiles import parse_duration
+from dstack._internal.core.models.volumes import VolumeStatus
+from dstack._internal.server.db import get_db, get_session_ctx
+from dstack._internal.server.models import ProjectModel, VolumeModel
+from dstack._internal.server.services import backends as backends_services
+from dstack._internal.server.services.locking import get_locker
+from dstack._internal.server.services.volumes import (
+ get_volume_configuration,
+ volume_model_to_volume,
+)
+from dstack._internal.utils import common
+from dstack._internal.utils.common import get_current_datetime
+from dstack._internal.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+async def process_idle_volumes():
+ lock, lockset = get_locker(get_db().dialect_name).get_lockset(VolumeModel.__tablename__)
+ async with get_session_ctx() as session:
+ async with lock:
+ res = await session.execute(
+ select(VolumeModel.id)
+ .where(
+ VolumeModel.status == VolumeStatus.ACTIVE,
+ VolumeModel.deleted == False,
+ VolumeModel.id.not_in(lockset),
+ )
+ .order_by(VolumeModel.last_processed_at.asc())
+ .limit(10)
+ .with_for_update(skip_locked=True, key_share=True)
+ )
+ volume_ids = list(res.scalars().all())
+ if not volume_ids:
+ return
+ for volume_id in volume_ids:
+ lockset.add(volume_id)
+
+ res = await session.execute(
+ select(VolumeModel)
+ .where(VolumeModel.id.in_(volume_ids))
+ .options(joinedload(VolumeModel.project).joinedload(ProjectModel.backends))
+ .options(joinedload(VolumeModel.user))
+ .options(joinedload(VolumeModel.attachments))
+ .execution_options(populate_existing=True)
+ )
+ volume_models = list(res.unique().scalars().all())
+ try:
+ volumes_to_delete = [v for v in volume_models if _should_delete_volume(v)]
+ if not volumes_to_delete:
+ return
+ await _delete_idle_volumes(session, volumes_to_delete)
+ finally:
+ lockset.difference_update(volume_ids)
+
+
+def _should_delete_volume(volume: VolumeModel) -> bool:
+ if volume.attachments:
+ return False
+
+ config = get_volume_configuration(volume)
+ if not config.auto_cleanup_duration:
+ return False
+
+ duration_seconds = parse_duration(config.auto_cleanup_duration)
+ if not duration_seconds or duration_seconds <= 0:
+ return False
+
+ idle_time = _get_idle_time(volume)
+ threshold = datetime.timedelta(seconds=duration_seconds)
+ return idle_time > threshold
+
+
+def _get_idle_time(volume: VolumeModel) -> datetime.timedelta:
+ last_used = volume.last_job_processed_at or volume.created_at
+ last_used_utc = last_used.replace(tzinfo=datetime.timezone.utc)
+ idle_time = get_current_datetime() - last_used_utc
+ return max(idle_time, datetime.timedelta(0))
+
+
+async def _delete_idle_volumes(session: AsyncSession, volumes: List[VolumeModel]):
+ # Note: Multiple volumes are deleted in the same transaction,
+ # so a slow deletion of one volume may block processing of the other volumes.
+ for volume_model in volumes:
+ logger.info("Deleting idle volume %s", volume_model.name)
+ try:
+ await _delete_idle_volume(session, volume_model)
+ except Exception:
+ logger.exception("Error when deleting idle volume %s", volume_model.name)
+
+ volume_model.deleted = True
+ volume_model.deleted_at = get_current_datetime()
+
+ logger.info("Deleted idle volume %s", volume_model.name)
+
+ await session.commit()
+
+
+async def _delete_idle_volume(session: AsyncSession, volume_model: VolumeModel):
+ volume = volume_model_to_volume(volume_model)
+
+ if volume.provisioning_data is None:
+ logger.error(
+ f"Failed to delete volume {volume_model.name}. volume.provisioning_data is None."
+ )
+ return
+
+ if volume.provisioning_data.backend is None:
+ logger.error(
+ f"Failed to delete volume {volume_model.name}. volume.provisioning_data.backend is None."
+ )
+ return
+
+ try:
+ backend = await backends_services.get_project_backend_by_type_or_error(
+ project=volume_model.project,
+ backend_type=volume.provisioning_data.backend,
+ )
+ except BackendNotAvailable:
+ logger.error(
+ f"Failed to delete volume {volume_model.name}. Backend {volume.provisioning_data.backend} not available."
+ )
+ return
+
+ compute = backend.compute()
+ assert isinstance(compute, ComputeWithVolumeSupport)
+ await common.run_async(
+ compute.delete_volume,
+ volume=volume,
+ )
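
As a small worked check of the idle-volume threshold above: a volume is deleted only when it has no attachments and its idle time (measured from the last job that used it) exceeds the configured `auto_cleanup_duration`. The sketch below uses plain datetimes and made-up timestamps instead of dstack's `VolumeModel`.

```python
# A worked check of the idle-volume threshold logic, using plain datetimes and
# made-up values instead of dstack's VolumeModel; for illustration only.
import datetime

auto_cleanup_duration_seconds = 3 * 3600  # e.g. a volume configured with "3h"
last_job_processed_at = datetime.datetime(2025, 7, 15, 9, 0, tzinfo=datetime.timezone.utc)
now = datetime.datetime(2025, 7, 15, 13, 30, tzinfo=datetime.timezone.utc)

idle_time = max(now - last_job_processed_at, datetime.timedelta(0))
threshold = datetime.timedelta(seconds=auto_cleanup_duration_seconds)
print(idle_time > threshold)  # True -- 4.5h idle exceeds the 3h auto-cleanup duration
```
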
diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py
index b9c2f9c94..eba9549b6 100644
--- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py
+++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py
@@ -739,3 +739,5 @@ async def _attach_volume(
attachment_data=attachment_data.json(),
)
instance.volume_attachments.append(volume_attachment_model)
+
+ volume_model.last_job_processed_at = common_utils.get_current_datetime()
diff --git a/src/dstack/_internal/server/migrations/versions/35e90e1b0d3e_add_rolling_deployment_fields.py b/src/dstack/_internal/server/migrations/versions/35e90e1b0d3e_add_rolling_deployment_fields.py
index a0cac48af..e26206222 100644
--- a/src/dstack/_internal/server/migrations/versions/35e90e1b0d3e_add_rolling_deployment_fields.py
+++ b/src/dstack/_internal/server/migrations/versions/35e90e1b0d3e_add_rolling_deployment_fields.py
@@ -17,12 +17,6 @@
def upgrade() -> None:
- with op.batch_alter_table("jobs", schema=None) as batch_op:
- batch_op.add_column(sa.Column("deployment_num", sa.Integer(), nullable=True))
- with op.batch_alter_table("jobs", schema=None) as batch_op:
- batch_op.execute("UPDATE jobs SET deployment_num = 0")
- batch_op.alter_column("deployment_num", nullable=False)
-
with op.batch_alter_table("runs", schema=None) as batch_op:
batch_op.add_column(sa.Column("deployment_num", sa.Integer(), nullable=True))
batch_op.add_column(sa.Column("desired_replica_count", sa.Integer(), nullable=True))
@@ -32,6 +26,12 @@ def upgrade() -> None:
batch_op.alter_column("deployment_num", nullable=False)
batch_op.alter_column("desired_replica_count", nullable=False)
+ with op.batch_alter_table("jobs", schema=None) as batch_op:
+ batch_op.add_column(sa.Column("deployment_num", sa.Integer(), nullable=True))
+ with op.batch_alter_table("jobs", schema=None) as batch_op:
+ batch_op.execute("UPDATE jobs SET deployment_num = 0")
+ batch_op.alter_column("deployment_num", nullable=False)
+
def downgrade() -> None:
with op.batch_alter_table("runs", schema=None) as batch_op:
diff --git a/src/dstack/_internal/server/migrations/versions/d5863798bf41_add_volumemodel_last_job_processed_at.py b/src/dstack/_internal/server/migrations/versions/d5863798bf41_add_volumemodel_last_job_processed_at.py
new file mode 100644
index 000000000..1dc883e05
--- /dev/null
+++ b/src/dstack/_internal/server/migrations/versions/d5863798bf41_add_volumemodel_last_job_processed_at.py
@@ -0,0 +1,40 @@
+"""Add VolumeModel.last_job_processed_at
+
+Revision ID: d5863798bf41
+Revises: 644b8a114187
+Create Date: 2025-07-15 14:26:22.981687
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+import dstack._internal.server.models
+
+# revision identifiers, used by Alembic.
+revision = "d5863798bf41"
+down_revision = "644b8a114187"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table("volumes", schema=None) as batch_op:
+ batch_op.add_column(
+ sa.Column(
+ "last_job_processed_at",
+ dstack._internal.server.models.NaiveDateTime(),
+ nullable=True,
+ )
+ )
+
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table("volumes", schema=None) as batch_op:
+ batch_op.drop_column("last_job_processed_at")
+
+ # ### end Alembic commands ###
diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py
index d39d07be1..c4dafe81e 100644
--- a/src/dstack/_internal/server/models.py
+++ b/src/dstack/_internal/server/models.py
@@ -645,6 +645,7 @@ class VolumeModel(BaseModel):
last_processed_at: Mapped[datetime] = mapped_column(
NaiveDateTime, default=get_current_datetime
)
+ last_job_processed_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime)
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
deleted_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime)
diff --git a/src/dstack/_internal/server/routers/backends.py b/src/dstack/_internal/server/routers/backends.py
index 7b6056b92..b43463a90 100644
--- a/src/dstack/_internal/server/routers/backends.py
+++ b/src/dstack/_internal/server/routers/backends.py
@@ -27,7 +27,10 @@
get_backend_config_yaml,
update_backend_config_yaml,
)
-from dstack._internal.server.utils.routers import get_base_api_additional_responses
+from dstack._internal.server.utils.routers import (
+ CustomORJSONResponse,
+ get_base_api_additional_responses,
+)
root_router = APIRouter(
prefix="/api/backends",
@@ -41,35 +44,37 @@
)
-@root_router.post("/list_types")
-async def list_backend_types() -> List[BackendType]:
- return dstack._internal.core.backends.configurators.list_available_backend_types()
+@root_router.post("/list_types", response_model=List[BackendType])
+async def list_backend_types():
+ return CustomORJSONResponse(
+ dstack._internal.core.backends.configurators.list_available_backend_types()
+ )
-@project_router.post("/create")
+@project_router.post("/create", response_model=AnyBackendConfigWithCreds)
async def create_backend(
body: AnyBackendConfigWithCreds,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
-) -> AnyBackendConfigWithCreds:
+):
_, project = user_project
config = await backends.create_backend(session=session, project=project, config=body)
if settings.SERVER_CONFIG_ENABLED:
await ServerConfigManager().sync_config(session=session)
- return config
+ return CustomORJSONResponse(config)
-@project_router.post("/update")
+@project_router.post("/update", response_model=AnyBackendConfigWithCreds)
async def update_backend(
body: AnyBackendConfigWithCreds,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
-) -> AnyBackendConfigWithCreds:
+):
_, project = user_project
config = await backends.update_backend(session=session, project=project, config=body)
if settings.SERVER_CONFIG_ENABLED:
await ServerConfigManager().sync_config(session=session)
- return config
+ return CustomORJSONResponse(config)
@project_router.post("/delete")
@@ -86,16 +91,16 @@ async def delete_backends(
await ServerConfigManager().sync_config(session=session)
-@project_router.post("/{backend_name}/config_info")
+@project_router.post("/{backend_name}/config_info", response_model=AnyBackendConfigWithCreds)
async def get_backend_config_info(
backend_name: BackendType,
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
-) -> AnyBackendConfigWithCreds:
+):
_, project = user_project
config = await backends.get_backend_config(project=project, backend_type=backend_name)
if config is None:
raise ResourceNotExistsError()
- return config
+ return CustomORJSONResponse(config)
@project_router.post("/create_yaml")
@@ -126,10 +131,12 @@ async def update_backend_yaml(
)
-@project_router.post("/{backend_name}/get_yaml")
+@project_router.post("/{backend_name}/get_yaml", response_model=BackendInfoYAML)
async def get_backend_yaml(
backend_name: BackendType,
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
-) -> BackendInfoYAML:
+):
_, project = user_project
- return await get_backend_config_yaml(project=project, backend_type=backend_name)
+ return CustomORJSONResponse(
+ await get_backend_config_yaml(project=project, backend_type=backend_name)
+ )
diff --git a/src/dstack/_internal/server/routers/files.py b/src/dstack/_internal/server/routers/files.py
index 574ef0177..ff7b2a3d5 100644
--- a/src/dstack/_internal/server/routers/files.py
+++ b/src/dstack/_internal/server/routers/files.py
@@ -12,6 +12,7 @@
from dstack._internal.server.services import files
from dstack._internal.server.settings import SERVER_CODE_UPLOAD_LIMIT
from dstack._internal.server.utils.routers import (
+ CustomORJSONResponse,
get_base_api_additional_responses,
get_request_size,
)
@@ -24,12 +25,12 @@
)
-@router.post("/get_archive_by_hash")
+@router.post("/get_archive_by_hash", response_model=FileArchive)
async def get_archive_by_hash(
body: GetFileArchiveByHashRequest,
session: Annotated[AsyncSession, Depends(get_session)],
user: Annotated[UserModel, Depends(Authenticated())],
-) -> FileArchive:
+):
archive = await files.get_archive_by_hash(
session=session,
user=user,
@@ -37,16 +38,16 @@ async def get_archive_by_hash(
)
if archive is None:
raise ResourceNotExistsError()
- return archive
+ return CustomORJSONResponse(archive)
-@router.post("/upload_archive")
+@router.post("/upload_archive", response_model=FileArchive)
async def upload_archive(
request: Request,
file: UploadFile,
session: Annotated[AsyncSession, Depends(get_session)],
user: Annotated[UserModel, Depends(Authenticated())],
-) -> FileArchive:
+):
request_size = get_request_size(request)
if SERVER_CODE_UPLOAD_LIMIT > 0 and request_size > SERVER_CODE_UPLOAD_LIMIT:
diff_size_fmt = sizeof_fmt(request_size)
@@ -64,4 +65,4 @@ async def upload_archive(
user=user,
file=file,
)
- return archive
+ return CustomORJSONResponse(archive)
diff --git a/src/dstack/_internal/server/routers/fleets.py b/src/dstack/_internal/server/routers/fleets.py
index 92aca9e18..3cbab9508 100644
--- a/src/dstack/_internal/server/routers/fleets.py
+++ b/src/dstack/_internal/server/routers/fleets.py
@@ -18,7 +18,10 @@
ListFleetsRequest,
)
from dstack._internal.server.security.permissions import Authenticated, ProjectMember
-from dstack._internal.server.utils.routers import get_base_api_additional_responses
+from dstack._internal.server.utils.routers import (
+ CustomORJSONResponse,
+ get_base_api_additional_responses,
+)
root_router = APIRouter(
prefix="/api/fleets",
@@ -32,12 +35,12 @@
)
-@root_router.post("/list")
+@root_router.post("/list", response_model=List[Fleet])
async def list_fleets(
body: ListFleetsRequest,
session: AsyncSession = Depends(get_session),
user: UserModel = Depends(Authenticated()),
-) -> List[Fleet]:
+):
"""
Returns all fleets and instances within them visible to user sorted by descending `created_at`.
`project_name` and `only_active` can be specified as filters.
@@ -45,36 +48,40 @@ async def list_fleets(
The results are paginated. To get the next page, pass `created_at` and `id` of
the last fleet from the previous page as `prev_created_at` and `prev_id`.
"""
- return await fleets_services.list_fleets(
- session=session,
- user=user,
- project_name=body.project_name,
- only_active=body.only_active,
- prev_created_at=body.prev_created_at,
- prev_id=body.prev_id,
- limit=body.limit,
- ascending=body.ascending,
+ return CustomORJSONResponse(
+ await fleets_services.list_fleets(
+ session=session,
+ user=user,
+ project_name=body.project_name,
+ only_active=body.only_active,
+ prev_created_at=body.prev_created_at,
+ prev_id=body.prev_id,
+ limit=body.limit,
+ ascending=body.ascending,
+ )
)
-@project_router.post("/list")
+@project_router.post("/list", response_model=List[Fleet])
async def list_project_fleets(
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
-) -> List[Fleet]:
+):
"""
Returns all fleets in the project.
"""
_, project = user_project
- return await fleets_services.list_project_fleets(session=session, project=project)
+ return CustomORJSONResponse(
+ await fleets_services.list_project_fleets(session=session, project=project)
+ )
-@project_router.post("/get")
+@project_router.post("/get", response_model=Fleet)
async def get_fleet(
body: GetFleetRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
-) -> Fleet:
+):
"""
Returns a fleet given `name` or `id`.
If given `name`, does not return deleted fleets.
@@ -86,15 +93,15 @@ async def get_fleet(
)
if fleet is None:
raise ResourceNotExistsError()
- return fleet
+ return CustomORJSONResponse(fleet)
-@project_router.post("/get_plan")
+@project_router.post("/get_plan", response_model=FleetPlan)
async def get_plan(
body: GetFleetPlanRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
-) -> FleetPlan:
+):
"""
Returns a fleet plan for the given fleet configuration.
"""
@@ -105,45 +112,49 @@ async def get_plan(
user=user,
spec=body.spec,
)
- return plan
+ return CustomORJSONResponse(plan)
-@project_router.post("/apply")
+@project_router.post("/apply", response_model=Fleet)
async def apply_plan(
body: ApplyFleetPlanRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
-) -> Fleet:
+):
"""
Creates a new fleet or updates an existing fleet.
Errors if the expected current resource from the plan does not match the current resource.
Use `force: true` to apply even if the current resource does not match.
"""
user, project = user_project
- return await fleets_services.apply_plan(
- session=session,
- user=user,
- project=project,
- plan=body.plan,
- force=body.force,
+ return CustomORJSONResponse(
+ await fleets_services.apply_plan(
+ session=session,
+ user=user,
+ project=project,
+ plan=body.plan,
+ force=body.force,
+ )
)
-@project_router.post("/create")
+@project_router.post("/create", response_model=Fleet)
async def create_fleet(
body: CreateFleetRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
-) -> Fleet:
+):
"""
Creates a fleet given a fleet configuration.
"""
user, project = user_project
- return await fleets_services.create_fleet(
- session=session,
- project=project,
- user=user,
- spec=body.spec,
+ return CustomORJSONResponse(
+ await fleets_services.create_fleet(
+ session=session,
+ project=project,
+ user=user,
+ spec=body.spec,
+ )
)
diff --git a/src/dstack/_internal/server/routers/gateways.py b/src/dstack/_internal/server/routers/gateways.py
index e0e0ad37d..fb03a3d69 100644
--- a/src/dstack/_internal/server/routers/gateways.py
+++ b/src/dstack/_internal/server/routers/gateways.py
@@ -13,7 +13,10 @@
ProjectAdmin,
ProjectMemberOrPublicAccess,
)
-from dstack._internal.server.utils.routers import get_base_api_additional_responses
+from dstack._internal.server.utils.routers import (
+ CustomORJSONResponse,
+ get_base_api_additional_responses,
+)
router = APIRouter(
prefix="/api/project/{project_name}/gateways",
@@ -22,40 +25,44 @@
)
-@router.post("/list")
+@router.post("/list", response_model=List[models.Gateway])
async def list_gateways(
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMemberOrPublicAccess()),
-) -> List[models.Gateway]:
+):
_, project = user_project
- return await gateways.list_project_gateways(session=session, project=project)
+ return CustomORJSONResponse(
+ await gateways.list_project_gateways(session=session, project=project)
+ )
-@router.post("/get")
+@router.post("/get", response_model=models.Gateway)
async def get_gateway(
body: schemas.GetGatewayRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMemberOrPublicAccess()),
-) -> models.Gateway:
+):
_, project = user_project
gateway = await gateways.get_gateway_by_name(session=session, project=project, name=body.name)
if gateway is None:
raise ResourceNotExistsError()
- return gateway
+ return CustomORJSONResponse(gateway)
-@router.post("/create")
+@router.post("/create", response_model=models.Gateway)
async def create_gateway(
body: schemas.CreateGatewayRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
-) -> models.Gateway:
+):
user, project = user_project
- return await gateways.create_gateway(
- session=session,
- user=user,
- project=project,
- configuration=body.configuration,
+ return CustomORJSONResponse(
+ await gateways.create_gateway(
+ session=session,
+ user=user,
+ project=project,
+ configuration=body.configuration,
+ )
)
@@ -83,13 +90,15 @@ async def set_default_gateway(
await gateways.set_default_gateway(session=session, project=project, name=body.name)
-@router.post("/set_wildcard_domain")
+@router.post("/set_wildcard_domain", response_model=models.Gateway)
async def set_gateway_wildcard_domain(
body: schemas.SetWildcardDomainRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
-) -> models.Gateway:
+):
_, project = user_project
- return await gateways.set_gateway_wildcard_domain(
- session=session, project=project, name=body.name, wildcard_domain=body.wildcard_domain
+ return CustomORJSONResponse(
+ await gateways.set_gateway_wildcard_domain(
+ session=session, project=project, name=body.name, wildcard_domain=body.wildcard_domain
+ )
)
diff --git a/src/dstack/_internal/server/routers/instances.py b/src/dstack/_internal/server/routers/instances.py
index 489b3bf1c..740c51fd6 100644
--- a/src/dstack/_internal/server/routers/instances.py
+++ b/src/dstack/_internal/server/routers/instances.py
@@ -9,7 +9,10 @@
from dstack._internal.server.models import UserModel
from dstack._internal.server.schemas.instances import ListInstancesRequest
from dstack._internal.server.security.permissions import Authenticated
-from dstack._internal.server.utils.routers import get_base_api_additional_responses
+from dstack._internal.server.utils.routers import (
+ CustomORJSONResponse,
+ get_base_api_additional_responses,
+)
root_router = APIRouter(
prefix="/api/instances",
@@ -18,12 +21,12 @@
)
-@root_router.post("/list")
+@root_router.post("/list", response_model=List[Instance])
async def list_instances(
body: ListInstancesRequest,
session: AsyncSession = Depends(get_session),
user: UserModel = Depends(Authenticated()),
-) -> List[Instance]:
+):
"""
Returns all instances visible to user sorted by descending `created_at`.
`project_names` and `fleet_ids` can be specified as filters.
@@ -31,14 +34,16 @@ async def list_instances(
The results are paginated. To get the next page, pass `created_at` and `id` of
the last instance from the previous page as `prev_created_at` and `prev_id`.
"""
- return await instances.list_user_instances(
- session=session,
- user=user,
- project_names=body.project_names,
- fleet_ids=body.fleet_ids,
- only_active=body.only_active,
- prev_created_at=body.prev_created_at,
- prev_id=body.prev_id,
- limit=body.limit,
- ascending=body.ascending,
+ return CustomORJSONResponse(
+ await instances.list_user_instances(
+ session=session,
+ user=user,
+ project_names=body.project_names,
+ fleet_ids=body.fleet_ids,
+ only_active=body.only_active,
+ prev_created_at=body.prev_created_at,
+ prev_id=body.prev_id,
+ limit=body.limit,
+ ascending=body.ascending,
+ )
)
diff --git a/src/dstack/_internal/server/routers/logs.py b/src/dstack/_internal/server/routers/logs.py
index a86424ee6..29685f6f5 100644
--- a/src/dstack/_internal/server/routers/logs.py
+++ b/src/dstack/_internal/server/routers/logs.py
@@ -7,7 +7,10 @@
from dstack._internal.server.schemas.logs import PollLogsRequest
from dstack._internal.server.security.permissions import ProjectMember
from dstack._internal.server.services import logs
-from dstack._internal.server.utils.routers import get_base_api_additional_responses
+from dstack._internal.server.utils.routers import (
+ CustomORJSONResponse,
+ get_base_api_additional_responses,
+)
router = APIRouter(
prefix="/api/project/{project_name}/logs",
@@ -18,13 +21,14 @@
@router.post(
"/poll",
+ response_model=JobSubmissionLogs,
)
async def poll_logs(
body: PollLogsRequest,
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
-) -> JobSubmissionLogs:
+):
_, project = user_project
# The runner guarantees logs have different timestamps if throughput < 1k logs / sec.
# Otherwise, some logs with duplicated timestamps may be filtered out.
# This limitation is imposed by cloud log services that support up to millisecond timestamp resolution.
- return await logs.poll_logs_async(project=project, request=body)
+ return CustomORJSONResponse(await logs.poll_logs_async(project=project, request=body))
diff --git a/src/dstack/_internal/server/routers/metrics.py b/src/dstack/_internal/server/routers/metrics.py
index 1d4ffb1db..e61a0d9bf 100644
--- a/src/dstack/_internal/server/routers/metrics.py
+++ b/src/dstack/_internal/server/routers/metrics.py
@@ -11,7 +11,10 @@
from dstack._internal.server.security.permissions import ProjectMember
from dstack._internal.server.services import metrics
from dstack._internal.server.services.jobs import get_run_job_model
-from dstack._internal.server.utils.routers import get_base_api_additional_responses
+from dstack._internal.server.utils.routers import (
+ CustomORJSONResponse,
+ get_base_api_additional_responses,
+)
router = APIRouter(
prefix="/api/project/{project_name}/metrics",
@@ -22,6 +25,7 @@
@router.get(
"/job/{run_name}",
+ response_model=JobMetrics,
)
async def get_job_metrics(
run_name: str,
@@ -32,7 +36,7 @@ async def get_job_metrics(
before: Optional[datetime] = None,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
-) -> JobMetrics:
+):
"""
Returns job-level metrics such as hardware utilization
given `run_name`, `replica_num`, and `job_num`.
@@ -63,10 +67,12 @@ async def get_job_metrics(
if job_model is None:
raise ResourceNotExistsError("Found no job with given parameters")
- return await metrics.get_job_metrics(
- session=session,
- job_model=job_model,
- limit=limit,
- after=after,
- before=before,
+ return CustomORJSONResponse(
+ await metrics.get_job_metrics(
+ session=session,
+ job_model=job_model,
+ limit=limit,
+ after=after,
+ before=before,
+ )
)
diff --git a/src/dstack/_internal/server/routers/projects.py b/src/dstack/_internal/server/routers/projects.py
index 1d967c6c8..56d41b6ca 100644
--- a/src/dstack/_internal/server/routers/projects.py
+++ b/src/dstack/_internal/server/routers/projects.py
@@ -23,7 +23,10 @@
ProjectMemberOrPublicAccess,
)
from dstack._internal.server.services import projects
-from dstack._internal.server.utils.routers import get_base_api_additional_responses
+from dstack._internal.server.utils.routers import (
+ CustomORJSONResponse,
+ get_base_api_additional_responses,
+)
router = APIRouter(
prefix="/api/projects",
@@ -32,30 +35,34 @@
)
-@router.post("/list")
+@router.post("/list", response_model=List[Project])
async def list_projects(
session: AsyncSession = Depends(get_session),
user: UserModel = Depends(Authenticated()),
-) -> List[Project]:
+):
"""
Returns all projects visible to user sorted by descending `created_at`.
`members` and `backends` are always empty - call `/api/projects/{project_name}/get` to retrieve them.
"""
- return await projects.list_user_accessible_projects(session=session, user=user)
+ return CustomORJSONResponse(
+ await projects.list_user_accessible_projects(session=session, user=user)
+ )
-@router.post("/create")
+@router.post("/create", response_model=Project)
async def create_project(
body: CreateProjectRequest,
session: AsyncSession = Depends(get_session),
user: UserModel = Depends(Authenticated()),
-) -> Project:
- return await projects.create_project(
- session=session,
- user=user,
- project_name=body.project_name,
- is_public=body.is_public,
+):
+ return CustomORJSONResponse(
+ await projects.create_project(
+ session=session,
+ user=user,
+ project_name=body.project_name,
+ is_public=body.is_public,
+ )
)
@@ -72,23 +79,24 @@ async def delete_projects(
)
-@router.post("/{project_name}/get")
+@router.post("/{project_name}/get", response_model=Project)
async def get_project(
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMemberOrPublicAccess()),
-) -> Project:
+):
_, project = user_project
- return projects.project_model_to_project(project)
+ return CustomORJSONResponse(projects.project_model_to_project(project))
@router.post(
"/{project_name}/set_members",
+ response_model=Project,
)
async def set_project_members(
body: SetProjectMembersRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectManager()),
-) -> Project:
+):
user, project = user_project
await projects.set_project_members(
session=session,
@@ -97,17 +105,18 @@ async def set_project_members(
members=body.members,
)
await session.refresh(project)
- return projects.project_model_to_project(project)
+ return CustomORJSONResponse(projects.project_model_to_project(project))
@router.post(
"/{project_name}/add_members",
+ response_model=Project,
)
async def add_project_members(
body: AddProjectMemberRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectManagerOrPublicProject()),
-) -> Project:
+):
user, project = user_project
await projects.add_project_members(
session=session,
@@ -116,17 +125,18 @@ async def add_project_members(
members=body.members,
)
await session.refresh(project)
- return projects.project_model_to_project(project)
+ return CustomORJSONResponse(projects.project_model_to_project(project))
@router.post(
"/{project_name}/remove_members",
+ response_model=Project,
)
async def remove_project_members(
body: RemoveProjectMemberRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectManagerOrSelfLeave()),
-) -> Project:
+):
user, project = user_project
await projects.remove_project_members(
session=session,
@@ -135,17 +145,18 @@ async def remove_project_members(
usernames=body.usernames,
)
await session.refresh(project)
- return projects.project_model_to_project(project)
+ return CustomORJSONResponse(projects.project_model_to_project(project))
@router.post(
"/{project_name}/update",
+ response_model=Project,
)
async def update_project(
body: UpdateProjectRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
-) -> Project:
+):
user, project = user_project
await projects.update_project(
session=session,
@@ -154,4 +165,4 @@ async def update_project(
is_public=body.is_public,
)
await session.refresh(project)
- return projects.project_model_to_project(project)
+ return CustomORJSONResponse(projects.project_model_to_project(project))
diff --git a/src/dstack/_internal/server/routers/repos.py b/src/dstack/_internal/server/routers/repos.py
index 32e59f631..202732f4f 100644
--- a/src/dstack/_internal/server/routers/repos.py
+++ b/src/dstack/_internal/server/routers/repos.py
@@ -16,6 +16,7 @@
from dstack._internal.server.services import repos
from dstack._internal.server.settings import SERVER_CODE_UPLOAD_LIMIT
from dstack._internal.server.utils.routers import (
+ CustomORJSONResponse,
get_base_api_additional_responses,
get_request_size,
)
@@ -28,21 +29,21 @@
)
-@router.post("/list")
+@router.post("/list", response_model=List[RepoHead])
async def list_repos(
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
-) -> List[RepoHead]:
+):
_, project = user_project
- return await repos.list_repos(session=session, project=project)
+ return CustomORJSONResponse(await repos.list_repos(session=session, project=project))
-@router.post("/get")
+@router.post("/get", response_model=RepoHeadWithCreds)
async def get_repo(
body: GetRepoRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
-) -> RepoHeadWithCreds:
+):
user, project = user_project
repo = await repos.get_repo(
session=session,
@@ -53,7 +54,7 @@ async def get_repo(
)
if repo is None:
raise ResourceNotExistsError()
- return repo
+ return CustomORJSONResponse(repo)
@router.post("/init")
diff --git a/src/dstack/_internal/server/routers/runs.py b/src/dstack/_internal/server/routers/runs.py
index c6a4b60f8..8f3909503 100644
--- a/src/dstack/_internal/server/routers/runs.py
+++ b/src/dstack/_internal/server/routers/runs.py
@@ -18,7 +18,10 @@
)
from dstack._internal.server.security.permissions import Authenticated, ProjectMember
from dstack._internal.server.services import runs
-from dstack._internal.server.utils.routers import get_base_api_additional_responses
+from dstack._internal.server.utils.routers import (
+ CustomORJSONResponse,
+ get_base_api_additional_responses,
+)
root_router = APIRouter(
prefix="/api/runs",
@@ -32,12 +35,15 @@
)
-@root_router.post("/list")
+@root_router.post(
+ "/list",
+ response_model=List[Run],
+)
async def list_runs(
body: ListRunsRequest,
session: AsyncSession = Depends(get_session),
user: UserModel = Depends(Authenticated()),
-) -> List[Run]:
+):
"""
Returns all runs visible to user sorted by descending `submitted_at`.
`project_name`, `repo_id`, `username`, and `only_active` can be specified as filters.
@@ -47,26 +53,33 @@ async def list_runs(
The results are paginated. To get the next page, pass `submitted_at` and `id` of
the last run from the previous page as `prev_submitted_at` and `prev_run_id`.
"""
- return await runs.list_user_runs(
- session=session,
- user=user,
- project_name=body.project_name,
- repo_id=body.repo_id,
- username=body.username,
- only_active=body.only_active,
- prev_submitted_at=body.prev_submitted_at,
- prev_run_id=body.prev_run_id,
- limit=body.limit,
- ascending=body.ascending,
+ return CustomORJSONResponse(
+ await runs.list_user_runs(
+ session=session,
+ user=user,
+ project_name=body.project_name,
+ repo_id=body.repo_id,
+ username=body.username,
+ only_active=body.only_active,
+ include_jobs=body.include_jobs,
+ job_submissions_limit=body.job_submissions_limit,
+ prev_submitted_at=body.prev_submitted_at,
+ prev_run_id=body.prev_run_id,
+ limit=body.limit,
+ ascending=body.ascending,
+ )
)
-@project_router.post("/get")
+@project_router.post(
+ "/get",
+ response_model=Run,
+)
async def get_run(
body: GetRunRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
-) -> Run:
+):
"""
Returns a run given `run_name` or `id`.
If given `run_name`, does not return deleted runs.
@@ -81,15 +94,18 @@ async def get_run(
)
if run is None:
raise ResourceNotExistsError("Run not found")
- return run
+ return CustomORJSONResponse(run)
-@project_router.post("/get_plan")
+@project_router.post(
+ "/get_plan",
+ response_model=RunPlan,
+)
async def get_plan(
body: GetRunPlanRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
-) -> RunPlan:
+):
"""
Returns a run plan for the given run spec.
This is an optional step before calling `/apply`.
@@ -102,15 +118,18 @@ async def get_plan(
run_spec=body.run_spec,
max_offers=body.max_offers,
)
- return run_plan
+ return CustomORJSONResponse(run_plan)
-@project_router.post("/apply")
+@project_router.post(
+ "/apply",
+ response_model=Run,
+)
async def apply_plan(
body: ApplyRunPlanRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
-) -> Run:
+):
"""
Creates a new run or updates an existing run.
Errors if the expected current resource from the plan does not match the current resource.
@@ -118,12 +137,14 @@ async def apply_plan(
If the existing run is active and cannot be updated, it must be stopped first.
"""
user, project = user_project
- return await runs.apply_plan(
- session=session,
- user=user,
- project=project,
- plan=body.plan,
- force=body.force,
+ return CustomORJSONResponse(
+ await runs.apply_plan(
+ session=session,
+ user=user,
+ project=project,
+ plan=body.plan,
+ force=body.force,
+ )
)
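
Taken together, `/api/runs/list` stays cursor-paginated and gains the `include_jobs` and `job_submissions_limit` knobs defined in `ListRunsRequest` later in this diff. A hedged client-side sketch of paging through it; the Bearer-token header, local server URL, and `<token>` placeholder are assumptions:

```python
import requests

SERVER = "http://localhost:3000"               # assumed local dstack server
HEADERS = {"Authorization": "Bearer <token>"}  # assumed auth scheme

body = {
    "project_name": "main",
    "only_active": False,
    "include_jobs": False,  # skip per-job details for a lighter response
    "limit": 100,
    "ascending": False,
}
runs = []
while True:
    page = requests.post(f"{SERVER}/api/runs/list", json=body, headers=HEADERS).json()
    runs.extend(page)
    if len(page) < body["limit"]:
        break
    # Cursor pagination: feed the last run's `submitted_at`/`id` back as `prev_*`.
    body["prev_submitted_at"] = page[-1]["submitted_at"]
    body["prev_run_id"] = page[-1]["id"]
```
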
diff --git a/src/dstack/_internal/server/routers/secrets.py b/src/dstack/_internal/server/routers/secrets.py
index bbfa26be9..c19f15bcc 100644
--- a/src/dstack/_internal/server/routers/secrets.py
+++ b/src/dstack/_internal/server/routers/secrets.py
@@ -14,6 +14,7 @@
)
from dstack._internal.server.security.permissions import ProjectAdmin
from dstack._internal.server.services import secrets as secrets_services
+from dstack._internal.server.utils.routers import CustomORJSONResponse
router = APIRouter(
prefix="/api/project/{project_name}/secrets",
@@ -21,24 +22,26 @@
)
-@router.post("/list")
+@router.post("/list", response_model=List[Secret])
async def list_secrets(
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
-) -> List[Secret]:
+):
_, project = user_project
- return await secrets_services.list_secrets(
- session=session,
- project=project,
+ return CustomORJSONResponse(
+ await secrets_services.list_secrets(
+ session=session,
+ project=project,
+ )
)
-@router.post("/get")
+@router.post("/get", response_model=Secret)
async def get_secret(
body: GetSecretRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
-) -> Secret:
+):
_, project = user_project
secret = await secrets_services.get_secret(
session=session,
@@ -47,21 +50,23 @@ async def get_secret(
)
if secret is None:
raise ResourceNotExistsError()
- return secret
+ return CustomORJSONResponse(secret)
-@router.post("/create_or_update")
+@router.post("/create_or_update", response_model=Secret)
async def create_or_update_secret(
body: CreateOrUpdateSecretRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
-) -> Secret:
+):
_, project = user_project
- return await secrets_services.create_or_update_secret(
- session=session,
- project=project,
- name=body.name,
- value=body.value,
+ return CustomORJSONResponse(
+ await secrets_services.create_or_update_secret(
+ session=session,
+ project=project,
+ name=body.name,
+ value=body.value,
+ )
)
diff --git a/src/dstack/_internal/server/routers/server.py b/src/dstack/_internal/server/routers/server.py
index 31e1e04c9..28c742772 100644
--- a/src/dstack/_internal/server/routers/server.py
+++ b/src/dstack/_internal/server/routers/server.py
@@ -2,6 +2,7 @@
from dstack._internal import settings
from dstack._internal.core.models.server import ServerInfo
+from dstack._internal.server.utils.routers import CustomORJSONResponse
router = APIRouter(
prefix="/api/server",
@@ -9,8 +10,10 @@
)
-@router.post("/get_info")
-async def get_server_info() -> ServerInfo:
- return ServerInfo(
- server_version=settings.DSTACK_VERSION,
+@router.post("/get_info", response_model=ServerInfo)
+async def get_server_info():
+ return CustomORJSONResponse(
+ ServerInfo(
+ server_version=settings.DSTACK_VERSION,
+ )
)
diff --git a/src/dstack/_internal/server/routers/users.py b/src/dstack/_internal/server/routers/users.py
index 670f9f0a5..abb672914 100644
--- a/src/dstack/_internal/server/routers/users.py
+++ b/src/dstack/_internal/server/routers/users.py
@@ -16,7 +16,10 @@
)
from dstack._internal.server.security.permissions import Authenticated, GlobalAdmin
from dstack._internal.server.services import users
-from dstack._internal.server.utils.routers import get_base_api_additional_responses
+from dstack._internal.server.utils.routers import (
+ CustomORJSONResponse,
+ get_base_api_additional_responses,
+)
router = APIRouter(
prefix="/api/users",
@@ -25,41 +28,41 @@
)
-@router.post("/list")
+@router.post("/list", response_model=List[User])
async def list_users(
session: AsyncSession = Depends(get_session),
user: UserModel = Depends(Authenticated()),
-) -> List[User]:
- return await users.list_users_for_user(session=session, user=user)
+):
+ return CustomORJSONResponse(await users.list_users_for_user(session=session, user=user))
-@router.post("/get_my_user")
+@router.post("/get_my_user", response_model=User)
async def get_my_user(
user: UserModel = Depends(Authenticated()),
-) -> User:
- return users.user_model_to_user(user)
+):
+ return CustomORJSONResponse(users.user_model_to_user(user))
-@router.post("/get_user")
+@router.post("/get_user", response_model=UserWithCreds)
async def get_user(
body: GetUserRequest,
session: AsyncSession = Depends(get_session),
user: UserModel = Depends(Authenticated()),
-) -> UserWithCreds:
+):
res = await users.get_user_with_creds_by_name(
session=session, current_user=user, username=body.username
)
if res is None:
raise ResourceNotExistsError()
- return res
+ return CustomORJSONResponse(res)
-@router.post("/create")
+@router.post("/create", response_model=User)
async def create_user(
body: CreateUserRequest,
session: AsyncSession = Depends(get_session),
user: UserModel = Depends(GlobalAdmin()),
-) -> User:
+):
res = await users.create_user(
session=session,
username=body.username,
@@ -67,15 +70,15 @@ async def create_user(
email=body.email,
active=body.active,
)
- return users.user_model_to_user(res)
+ return CustomORJSONResponse(users.user_model_to_user(res))
-@router.post("/update")
+@router.post("/update", response_model=User)
async def update_user(
body: UpdateUserRequest,
session: AsyncSession = Depends(get_session),
user: UserModel = Depends(GlobalAdmin()),
-) -> User:
+):
res = await users.update_user(
session=session,
username=body.username,
@@ -85,19 +88,19 @@ async def update_user(
)
if res is None:
raise ResourceNotExistsError()
- return users.user_model_to_user(res)
+ return CustomORJSONResponse(users.user_model_to_user(res))
-@router.post("/refresh_token")
+@router.post("/refresh_token", response_model=UserWithCreds)
async def refresh_token(
body: RefreshTokenRequest,
session: AsyncSession = Depends(get_session),
user: UserModel = Depends(Authenticated()),
-) -> UserWithCreds:
+):
res = await users.refresh_user_token(session=session, user=user, username=body.username)
if res is None:
raise ResourceNotExistsError()
- return users.user_model_to_user_with_creds(res)
+ return CustomORJSONResponse(users.user_model_to_user_with_creds(res))
@router.post("/delete")
diff --git a/src/dstack/_internal/server/routers/volumes.py b/src/dstack/_internal/server/routers/volumes.py
index d4137099f..2ac503470 100644
--- a/src/dstack/_internal/server/routers/volumes.py
+++ b/src/dstack/_internal/server/routers/volumes.py
@@ -15,7 +15,10 @@
ListVolumesRequest,
)
from dstack._internal.server.security.permissions import Authenticated, ProjectMember
-from dstack._internal.server.utils.routers import get_base_api_additional_responses
+from dstack._internal.server.utils.routers import (
+ CustomORJSONResponse,
+ get_base_api_additional_responses,
+)
root_router = APIRouter(
prefix="/api/volumes",
@@ -25,12 +28,12 @@
project_router = APIRouter(prefix="/api/project/{project_name}/volumes", tags=["volumes"])
-@root_router.post("/list")
+@root_router.post("/list", response_model=List[Volume])
async def list_volumes(
body: ListVolumesRequest,
session: AsyncSession = Depends(get_session),
user: UserModel = Depends(Authenticated()),
-) -> List[Volume]:
+):
"""
Returns all volumes visible to user sorted by descending `created_at`.
`project_name` and `only_active` can be specified as filters.
@@ -38,36 +41,40 @@ async def list_volumes(
The results are paginated. To get the next page, pass `created_at` and `id` of
the last volume from the previous page as `prev_created_at` and `prev_id`.
"""
- return await volumes_services.list_volumes(
- session=session,
- user=user,
- project_name=body.project_name,
- only_active=body.only_active,
- prev_created_at=body.prev_created_at,
- prev_id=body.prev_id,
- limit=body.limit,
- ascending=body.ascending,
+ return CustomORJSONResponse(
+ await volumes_services.list_volumes(
+ session=session,
+ user=user,
+ project_name=body.project_name,
+ only_active=body.only_active,
+ prev_created_at=body.prev_created_at,
+ prev_id=body.prev_id,
+ limit=body.limit,
+ ascending=body.ascending,
+ )
)
-@project_router.post("/list")
+@project_router.post("/list", response_model=List[Volume])
async def list_project_volumes(
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
-) -> List[Volume]:
+):
"""
Returns all volumes in the project.
"""
_, project = user_project
- return await volumes_services.list_project_volumes(session=session, project=project)
+ return CustomORJSONResponse(
+ await volumes_services.list_project_volumes(session=session, project=project)
+ )
-@project_router.post("/get")
+@project_router.post("/get", response_model=Volume)
async def get_volume(
body: GetVolumeRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
-) -> Volume:
+):
"""
Returns a volume given a volume name.
"""
@@ -77,24 +84,26 @@ async def get_volume(
)
if volume is None:
raise ResourceNotExistsError()
- return volume
+ return CustomORJSONResponse(volume)
-@project_router.post("/create")
+@project_router.post("/create", response_model=Volume)
async def create_volume(
body: CreateVolumeRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
-) -> Volume:
+):
"""
Creates a volume given a volume configuration.
"""
user, project = user_project
- return await volumes_services.create_volume(
- session=session,
- project=project,
- user=user,
- configuration=body.configuration,
+ return CustomORJSONResponse(
+ await volumes_services.create_volume(
+ session=session,
+ project=project,
+ user=user,
+ configuration=body.configuration,
+ )
)
diff --git a/src/dstack/_internal/server/schemas/logs.py b/src/dstack/_internal/server/schemas/logs.py
index 267f5612f..f97d4fde3 100644
--- a/src/dstack/_internal/server/schemas/logs.py
+++ b/src/dstack/_internal/server/schemas/logs.py
@@ -9,8 +9,8 @@
class PollLogsRequest(CoreModel):
run_name: str
job_submission_id: UUID4
- start_time: Optional[datetime]
- end_time: Optional[datetime]
+ start_time: Optional[datetime] = None
+ end_time: Optional[datetime] = None
descending: bool = False
next_token: Optional[str] = None
limit: int = Field(100, ge=0, le=1000)
diff --git a/src/dstack/_internal/server/schemas/runs.py b/src/dstack/_internal/server/schemas/runs.py
index 8ae875df0..844724371 100644
--- a/src/dstack/_internal/server/schemas/runs.py
+++ b/src/dstack/_internal/server/schemas/runs.py
@@ -9,12 +9,24 @@
class ListRunsRequest(CoreModel):
- project_name: Optional[str]
- repo_id: Optional[str]
- username: Optional[str]
+ project_name: Optional[str] = None
+ repo_id: Optional[str] = None
+ username: Optional[str] = None
only_active: bool = False
- prev_submitted_at: Optional[datetime]
- prev_run_id: Optional[UUID]
+ include_jobs: bool = Field(
+ True,
+ description=("Whether to include `jobs` in the response"),
+ )
+ job_submissions_limit: Optional[int] = Field(
+ None,
+ ge=0,
+ description=(
+ "Limit number of job submissions returned per job to avoid large responses."
+ "Drops older job submissions. No effect with `include_jobs: false`"
+ ),
+ )
+ prev_submitted_at: Optional[datetime] = None
+ prev_run_id: Optional[UUID] = None
limit: int = Field(100, ge=0, le=100)
ascending: bool = False
diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py
index 9925483e4..93ed23b86 100644
--- a/src/dstack/_internal/server/services/fleets.py
+++ b/src/dstack/_internal/server/services/fleets.py
@@ -1,6 +1,8 @@
import uuid
+from collections.abc import Callable
from datetime import datetime, timezone
-from typing import List, Literal, Optional, Tuple, Union, cast
+from functools import wraps
+from typing import List, Literal, Optional, Tuple, TypeVar, Union, cast
from sqlalchemy import and_, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -13,10 +15,12 @@
ResourceExistsError,
ServerClientError,
)
+from dstack._internal.core.models.common import ApplyAction, CoreModel
from dstack._internal.core.models.envs import Env
from dstack._internal.core.models.fleets import (
ApplyFleetPlanInput,
Fleet,
+ FleetConfiguration,
FleetPlan,
FleetSpec,
FleetStatus,
@@ -40,6 +44,7 @@
from dstack._internal.core.models.runs import Requirements, get_policy_map
from dstack._internal.core.models.users import GlobalRole
from dstack._internal.core.services import validate_dstack_resource_name
+from dstack._internal.core.services.diff import ModelDiff, copy_model, diff_models
from dstack._internal.server.db import get_db
from dstack._internal.server.models import (
FleetModel,
@@ -49,7 +54,10 @@
)
from dstack._internal.server.services import instances as instances_services
from dstack._internal.server.services import offers as offers_services
-from dstack._internal.server.services.instances import list_active_remote_instances
+from dstack._internal.server.services.instances import (
+ get_instance_remote_connection_info,
+ list_active_remote_instances,
+)
from dstack._internal.server.services.locking import (
get_locker,
string_to_lock_id,
@@ -178,8 +186,9 @@ async def list_project_fleet_models(
async def get_fleet(
session: AsyncSession,
project: ProjectModel,
- name: Optional[str],
- fleet_id: Optional[uuid.UUID],
+ name: Optional[str] = None,
+ fleet_id: Optional[uuid.UUID] = None,
+ include_sensitive: bool = False,
) -> Optional[Fleet]:
if fleet_id is not None:
fleet_model = await get_project_fleet_model_by_id(
@@ -193,7 +202,7 @@ async def get_fleet(
raise ServerClientError("name or id must be specified")
if fleet_model is None:
return None
- return fleet_model_to_fleet(fleet_model)
+ return fleet_model_to_fleet(fleet_model, include_sensitive=include_sensitive)
async def get_project_fleet_model_by_id(
@@ -236,23 +245,32 @@ async def get_plan(
spec: FleetSpec,
) -> FleetPlan:
# Spec must be copied by parsing to calculate merged_profile
- effective_spec = FleetSpec.parse_obj(spec.dict())
+ effective_spec = copy_model(spec)
effective_spec = await apply_plugin_policies(
user=user.name,
project=project.name,
spec=effective_spec,
)
- effective_spec = FleetSpec.parse_obj(effective_spec.dict())
- _validate_fleet_spec_and_set_defaults(spec)
+ # Spec must be copied by parsing to calculate merged_profile
+ effective_spec = copy_model(effective_spec)
+ _validate_fleet_spec_and_set_defaults(effective_spec)
+
+ action = ApplyAction.CREATE
current_fleet: Optional[Fleet] = None
current_fleet_id: Optional[uuid.UUID] = None
+
if effective_spec.configuration.name is not None:
- current_fleet_model = await get_project_fleet_model_by_name(
- session=session, project=project, name=effective_spec.configuration.name
+ current_fleet = await get_fleet(
+ session=session,
+ project=project,
+ name=effective_spec.configuration.name,
+ include_sensitive=True,
)
- if current_fleet_model is not None:
- current_fleet = fleet_model_to_fleet(current_fleet_model)
- current_fleet_id = current_fleet_model.id
+ if current_fleet is not None:
+ _set_fleet_spec_defaults(current_fleet.spec)
+ if _can_update_fleet_spec(current_fleet.spec, effective_spec):
+ action = ApplyAction.UPDATE
+ current_fleet_id = current_fleet.id
await _check_ssh_hosts_not_yet_added(session, effective_spec, current_fleet_id)
offers = []
@@ -265,7 +283,10 @@ async def get_plan(
blocks=effective_spec.configuration.blocks,
)
offers = [offer for _, offer in offers_with_backends]
+
_remove_fleet_spec_sensitive_info(effective_spec)
+ if current_fleet is not None:
+ _remove_fleet_spec_sensitive_info(current_fleet.spec)
plan = FleetPlan(
project_name=project.name,
user=user.name,
@@ -275,6 +296,7 @@ async def get_plan(
offers=offers[:50],
total_offers=len(offers),
max_offer_price=max((offer.price for offer in offers), default=None),
+ action=action,
)
return plan
@@ -327,11 +349,77 @@ async def apply_plan(
plan: ApplyFleetPlanInput,
force: bool,
) -> Fleet:
- return await create_fleet(
+ spec = await apply_plugin_policies(
+ user=user.name,
+ project=project.name,
+ spec=plan.spec,
+ )
+ # Spec must be copied by parsing to calculate merged_profile
+ spec = copy_model(spec)
+ _validate_fleet_spec_and_set_defaults(spec)
+
+ if spec.configuration.ssh_config is not None:
+ _check_can_manage_ssh_fleets(user=user, project=project)
+
+ configuration = spec.configuration
+ if configuration.name is None:
+ return await _create_fleet(
+ session=session,
+ project=project,
+ user=user,
+ spec=spec,
+ )
+
+ fleet_model = await get_project_fleet_model_by_name(
+ session=session,
+ project=project,
+ name=configuration.name,
+ )
+ if fleet_model is None:
+ return await _create_fleet(
+ session=session,
+ project=project,
+ user=user,
+ spec=spec,
+ )
+
+ instances_ids = sorted(i.id for i in fleet_model.instances if not i.deleted)
+ await session.commit()
+ async with (
+ get_locker(get_db().dialect_name).lock_ctx(FleetModel.__tablename__, [fleet_model.id]),
+ get_locker(get_db().dialect_name).lock_ctx(InstanceModel.__tablename__, instances_ids),
+ ):
+ # Refetch after lock
+ # TODO: Lock instances with FOR UPDATE?
+ res = await session.execute(
+ select(FleetModel)
+ .where(
+ FleetModel.project_id == project.id,
+ FleetModel.id == fleet_model.id,
+ FleetModel.deleted == False,
+ )
+ .options(selectinload(FleetModel.instances))
+ .options(selectinload(FleetModel.runs))
+ .execution_options(populate_existing=True)
+ .order_by(FleetModel.id) # take locks in order
+ .with_for_update(key_share=True)
+ )
+ fleet_model = res.scalars().unique().one_or_none()
+ if fleet_model is not None:
+ return await _update_fleet(
+ session=session,
+ project=project,
+ spec=spec,
+ current_resource=plan.current_resource,
+ force=force,
+ fleet_model=fleet_model,
+ )
+
+ return await _create_fleet(
session=session,
project=project,
user=user,
- spec=plan.spec,
+ spec=spec,
)
@@ -341,73 +429,19 @@ async def create_fleet(
user: UserModel,
spec: FleetSpec,
) -> Fleet:
- # Spec must be copied by parsing to calculate merged_profile
spec = await apply_plugin_policies(
user=user.name,
project=project.name,
spec=spec,
)
- spec = FleetSpec.parse_obj(spec.dict())
+ # Spec must be copied by parsing to calculate merged_profile
+ spec = copy_model(spec)
_validate_fleet_spec_and_set_defaults(spec)
if spec.configuration.ssh_config is not None:
_check_can_manage_ssh_fleets(user=user, project=project)
- lock_namespace = f"fleet_names_{project.name}"
- if get_db().dialect_name == "sqlite":
- # Start new transaction to see committed changes after lock
- await session.commit()
- elif get_db().dialect_name == "postgresql":
- await session.execute(
- select(func.pg_advisory_xact_lock(string_to_lock_id(lock_namespace)))
- )
-
- lock, _ = get_locker(get_db().dialect_name).get_lockset(lock_namespace)
- async with lock:
- if spec.configuration.name is not None:
- fleet_model = await get_project_fleet_model_by_name(
- session=session,
- project=project,
- name=spec.configuration.name,
- )
- if fleet_model is not None:
- raise ResourceExistsError()
- else:
- spec.configuration.name = await generate_fleet_name(session=session, project=project)
-
- fleet_model = FleetModel(
- id=uuid.uuid4(),
- name=spec.configuration.name,
- project=project,
- status=FleetStatus.ACTIVE,
- spec=spec.json(),
- instances=[],
- )
- session.add(fleet_model)
- if spec.configuration.ssh_config is not None:
- for i, host in enumerate(spec.configuration.ssh_config.hosts):
- instances_model = await create_fleet_ssh_instance_model(
- project=project,
- spec=spec,
- ssh_params=spec.configuration.ssh_config,
- env=spec.configuration.env,
- instance_num=i,
- host=host,
- )
- fleet_model.instances.append(instances_model)
- else:
- for i in range(_get_fleet_nodes_to_provision(spec)):
- instance_model = await create_fleet_instance_model(
- session=session,
- project=project,
- user=user,
- spec=spec,
- reservation=spec.configuration.reservation,
- instance_num=i,
- )
- fleet_model.instances.append(instance_model)
- await session.commit()
- return fleet_model_to_fleet(fleet_model)
+ return await _create_fleet(session=session, project=project, user=user, spec=spec)
async def create_fleet_instance_model(
@@ -600,6 +634,235 @@ def is_fleet_empty(fleet_model: FleetModel) -> bool:
return len(active_instances) == 0
+async def _create_fleet(
+ session: AsyncSession,
+ project: ProjectModel,
+ user: UserModel,
+ spec: FleetSpec,
+) -> Fleet:
+ lock_namespace = f"fleet_names_{project.name}"
+ if get_db().dialect_name == "sqlite":
+ # Start new transaction to see committed changes after lock
+ await session.commit()
+ elif get_db().dialect_name == "postgresql":
+ await session.execute(
+ select(func.pg_advisory_xact_lock(string_to_lock_id(lock_namespace)))
+ )
+
+ lock, _ = get_locker(get_db().dialect_name).get_lockset(lock_namespace)
+ async with lock:
+ if spec.configuration.name is not None:
+ fleet_model = await get_project_fleet_model_by_name(
+ session=session,
+ project=project,
+ name=spec.configuration.name,
+ )
+ if fleet_model is not None:
+ raise ResourceExistsError()
+ else:
+ spec.configuration.name = await generate_fleet_name(session=session, project=project)
+
+ fleet_model = FleetModel(
+ id=uuid.uuid4(),
+ name=spec.configuration.name,
+ project=project,
+ status=FleetStatus.ACTIVE,
+ spec=spec.json(),
+ instances=[],
+ )
+ session.add(fleet_model)
+ if spec.configuration.ssh_config is not None:
+ for i, host in enumerate(spec.configuration.ssh_config.hosts):
+ instances_model = await create_fleet_ssh_instance_model(
+ project=project,
+ spec=spec,
+ ssh_params=spec.configuration.ssh_config,
+ env=spec.configuration.env,
+ instance_num=i,
+ host=host,
+ )
+ fleet_model.instances.append(instances_model)
+ else:
+ for i in range(_get_fleet_nodes_to_provision(spec)):
+ instance_model = await create_fleet_instance_model(
+ session=session,
+ project=project,
+ user=user,
+ spec=spec,
+ reservation=spec.configuration.reservation,
+ instance_num=i,
+ )
+ fleet_model.instances.append(instance_model)
+ await session.commit()
+ return fleet_model_to_fleet(fleet_model)
+
+
+async def _update_fleet(
+ session: AsyncSession,
+ project: ProjectModel,
+ spec: FleetSpec,
+ current_resource: Optional[Fleet],
+ force: bool,
+ fleet_model: FleetModel,
+) -> Fleet:
+ fleet = fleet_model_to_fleet(fleet_model)
+ _set_fleet_spec_defaults(fleet.spec)
+ fleet_sensitive = fleet_model_to_fleet(fleet_model, include_sensitive=True)
+ _set_fleet_spec_defaults(fleet_sensitive.spec)
+
+ if not force:
+ if current_resource is not None:
+ _set_fleet_spec_defaults(current_resource.spec)
+ if (
+ current_resource is None
+ or current_resource.id != fleet.id
+ or current_resource.spec != fleet.spec
+ ):
+ raise ServerClientError(
+ "Failed to apply plan. Resource has been changed. Try again or use force apply."
+ )
+
+ _check_can_update_fleet_spec(fleet_sensitive.spec, spec)
+
+ spec_json = spec.json()
+ fleet_model.spec = spec_json
+
+ if (
+ fleet_sensitive.spec.configuration.ssh_config is not None
+ and spec.configuration.ssh_config is not None
+ ):
+ added_hosts, removed_hosts, changed_hosts = _calculate_ssh_hosts_changes(
+ current=fleet_sensitive.spec.configuration.ssh_config.hosts,
+ new=spec.configuration.ssh_config.hosts,
+ )
+ # `_check_can_update_fleet_spec` ensures hosts are not changed
+ assert not changed_hosts, changed_hosts
+ active_instance_nums: set[int] = set()
+ removed_instance_nums: list[int] = []
+ if removed_hosts or added_hosts:
+ for instance_model in fleet_model.instances:
+ if instance_model.deleted:
+ continue
+ active_instance_nums.add(instance_model.instance_num)
+ rci = get_instance_remote_connection_info(instance_model)
+ if rci is None:
+ logger.error(
+ "Cloud instance %s in SSH fleet %s",
+ instance_model.id,
+ fleet_model.id,
+ )
+ continue
+ if rci.host in removed_hosts:
+ removed_instance_nums.append(instance_model.instance_num)
+ if added_hosts:
+ await _check_ssh_hosts_not_yet_added(session, spec, fleet.id)
+ for host in added_hosts.values():
+ instance_num = _get_next_instance_num(active_instance_nums)
+ instance_model = await create_fleet_ssh_instance_model(
+ project=project,
+ spec=spec,
+ ssh_params=spec.configuration.ssh_config,
+ env=spec.configuration.env,
+ instance_num=instance_num,
+ host=host,
+ )
+ fleet_model.instances.append(instance_model)
+ active_instance_nums.add(instance_num)
+ if removed_instance_nums:
+ _terminate_fleet_instances(fleet_model, removed_instance_nums)
+
+ await session.commit()
+ return fleet_model_to_fleet(fleet_model)
+
+
+def _can_update_fleet_spec(current_fleet_spec: FleetSpec, new_fleet_spec: FleetSpec) -> bool:
+ try:
+ _check_can_update_fleet_spec(current_fleet_spec, new_fleet_spec)
+ except ServerClientError as e:
+ logger.debug("Run cannot be updated: %s", repr(e))
+ return False
+ return True
+
+
+M = TypeVar("M", bound=CoreModel)
+
+
+def _check_can_update(*updatable_fields: str):
+ def decorator(fn: Callable[[M, M, ModelDiff], None]) -> Callable[[M, M], None]:
+ @wraps(fn)
+ def inner(current: M, new: M):
+ diff = _check_can_update_inner(current, new, updatable_fields)
+ fn(current, new, diff)
+
+ return inner
+
+ return decorator
+
+
+def _check_can_update_inner(current: M, new: M, updatable_fields: tuple[str, ...]) -> ModelDiff:
+ diff = diff_models(current, new)
+ changed_fields = diff.keys()
+ if not (changed_fields <= set(updatable_fields)):
+ raise ServerClientError(
+ f"Failed to update fields {list(changed_fields)}."
+ f" Can only update {list(updatable_fields)}."
+ )
+ return diff
+
+
+@_check_can_update("configuration", "configuration_path")
+def _check_can_update_fleet_spec(current: FleetSpec, new: FleetSpec, diff: ModelDiff):
+ if "configuration" in diff:
+ _check_can_update_fleet_configuration(current.configuration, new.configuration)
+
+
+@_check_can_update("ssh_config")
+def _check_can_update_fleet_configuration(
+ current: FleetConfiguration, new: FleetConfiguration, diff: ModelDiff
+):
+ if "ssh_config" in diff:
+ current_ssh_config = current.ssh_config
+ new_ssh_config = new.ssh_config
+ if current_ssh_config is None:
+ if new_ssh_config is not None:
+ raise ServerClientError("Fleet type changed from Cloud to SSH, cannot update")
+ elif new_ssh_config is None:
+ raise ServerClientError("Fleet type changed from SSH to Cloud, cannot update")
+ else:
+ _check_can_update_ssh_config(current_ssh_config, new_ssh_config)
+
+
+@_check_can_update("hosts")
+def _check_can_update_ssh_config(current: SSHParams, new: SSHParams, diff: ModelDiff):
+ if "hosts" in diff:
+ _, _, changed_hosts = _calculate_ssh_hosts_changes(current.hosts, new.hosts)
+ if changed_hosts:
+ raise ServerClientError(
+ f"Hosts configuration changed, cannot update: {list(changed_hosts)}"
+ )
+
+
+def _calculate_ssh_hosts_changes(
+ current: list[Union[SSHHostParams, str]], new: list[Union[SSHHostParams, str]]
+) -> tuple[dict[str, Union[SSHHostParams, str]], set[str], set[str]]:
+ current_hosts = {h if isinstance(h, str) else h.hostname: h for h in current}
+ new_hosts = {h if isinstance(h, str) else h.hostname: h for h in new}
+ added_hosts = {h: new_hosts[h] for h in new_hosts.keys() - current_hosts}
+ removed_hosts = current_hosts.keys() - new_hosts
+ changed_hosts: set[str] = set()
+ for host in current_hosts.keys() & new_hosts:
+ current_host = current_hosts[host]
+ new_host = new_hosts[host]
+ if isinstance(current_host, str) or isinstance(new_host, str):
+ if current_host != new_host:
+ changed_hosts.add(host)
+ elif diff_models(
+ current_host, new_host, reset={"identity_file": True, "proxy_jump": {"identity_file"}}
+ ):
+ changed_hosts.add(host)
+ return added_hosts, removed_hosts, changed_hosts
+
+
def _check_can_manage_ssh_fleets(user: UserModel, project: ProjectModel):
if user.global_role == GlobalRole.ADMIN:
return
@@ -654,6 +917,8 @@ def _validate_fleet_spec_and_set_defaults(spec: FleetSpec):
validate_dstack_resource_name(spec.configuration.name)
if spec.configuration.ssh_config is None and spec.configuration.nodes is None:
raise ServerClientError("No ssh_config or nodes specified")
+ if spec.configuration.ssh_config is not None and spec.configuration.nodes is not None:
+ raise ServerClientError("ssh_config and nodes are mutually exclusive")
if spec.configuration.ssh_config is not None:
_validate_all_ssh_params_specified(spec.configuration.ssh_config)
if spec.configuration.ssh_config.ssh_key is not None:
@@ -662,6 +927,10 @@ def _validate_fleet_spec_and_set_defaults(spec: FleetSpec):
if isinstance(host, SSHHostParams) and host.ssh_key is not None:
_validate_ssh_key(host.ssh_key)
_validate_internal_ips(spec.configuration.ssh_config)
+ _set_fleet_spec_defaults(spec)
+
+
+def _set_fleet_spec_defaults(spec: FleetSpec):
if spec.configuration.resources is not None:
set_resources_defaults(spec.configuration.resources)
@@ -734,3 +1003,16 @@ def _get_fleet_requirements(fleet_spec: FleetSpec) -> Requirements:
reservation=fleet_spec.configuration.reservation,
)
return requirements
+
+
+def _get_next_instance_num(instance_nums: set[int]) -> int:
+ if not instance_nums:
+ return 0
+ min_instance_num = min(instance_nums)
+ if min_instance_num > 0:
+ return 0
+ instance_num = min_instance_num + 1
+ while True:
+ if instance_num not in instance_nums:
+ return instance_num
+ instance_num += 1
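
The update path introduced above only lets an allow-listed set of fields change in place; for SSH fleets that effectively means adding or removing `hosts`, while any other change is rejected. A standalone sketch of that "only these fields may change" check, with a simplified stand-in for `diff_models` rather than dstack's real helper:

```python
from pydantic import BaseModel


def diff_models_stub(current: BaseModel, new: BaseModel) -> dict:
    """Simplified stand-in: {field: (old, new)} for top-level fields that differ."""
    return {
        name: (getattr(current, name), getattr(new, name))
        for name in current.__fields__
        if getattr(current, name) != getattr(new, name)
    }


class FleetConfigStub(BaseModel):
    name: str
    nodes: int


def check_can_update(current: BaseModel, new: BaseModel, updatable: set) -> None:
    changed = set(diff_models_stub(current, new))
    if not changed <= updatable:
        raise ValueError(f"Cannot update fields {sorted(changed - updatable)}")


current = FleetConfigStub(name="my-fleet", nodes=2)
check_can_update(current, FleetConfigStub(name="my-fleet", nodes=4), {"nodes"})  # OK
# check_can_update(current, FleetConfigStub(name="renamed", nodes=2), {"nodes"})  # would raise
```
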
diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py
index 156459257..fb3839316 100644
--- a/src/dstack/_internal/server/services/gateways/__init__.py
+++ b/src/dstack/_internal/server/services/gateways/__init__.py
@@ -2,6 +2,7 @@
import datetime
import uuid
from datetime import timedelta, timezone
+from functools import partial
from typing import List, Optional, Sequence
import httpx
@@ -186,6 +187,7 @@ async def create_gateway(
return gateway_model_to_gateway(gateway)
+# NOTE: dstack Sky imports and uses this function
async def connect_to_gateway_with_retry(
gateway_compute: GatewayComputeModel,
) -> Optional[GatewayConnection]:
@@ -380,6 +382,8 @@ async def get_or_add_gateway_connection(
async def init_gateways(session: AsyncSession):
res = await session.execute(
select(GatewayComputeModel).where(
+ # FIXME: should not include computes related to gateways in the `provisioning` status.
+ # Causes warnings and delays when restarting the server during gateway provisioning.
GatewayComputeModel.active == True,
GatewayComputeModel.deleted == False,
)
@@ -421,7 +425,8 @@ async def init_gateways(session: AsyncSession):
for gateway_compute, error in await gather_map_async(
await gateway_connections_pool.all(),
- configure_gateway,
+ # Need several attempts to handle short gateway downtime after update
+ partial(configure_gateway, attempts=7),
return_exceptions=True,
):
if isinstance(error, Exception):
@@ -461,7 +466,11 @@ def _recently_updated(gateway_compute_model: GatewayComputeModel) -> bool:
) > get_current_datetime() - timedelta(seconds=60)
-async def configure_gateway(connection: GatewayConnection) -> None:
+# NOTE: dstack Sky imports and uses this function
+async def configure_gateway(
+ connection: GatewayConnection,
+ attempts: int = GATEWAY_CONFIGURE_ATTEMPTS,
+) -> None:
"""
Try submitting gateway config several times in case gateway's HTTP server is not
running yet
@@ -469,7 +478,7 @@ async def configure_gateway(connection: GatewayConnection) -> None:
logger.debug("Configuring gateway %s", connection.ip_address)
- for attempt in range(GATEWAY_CONFIGURE_ATTEMPTS - 1):
+ for attempt in range(attempts - 1):
try:
async with connection.client() as client:
await client.submit_gateway_config()
@@ -478,7 +487,7 @@ async def configure_gateway(connection: GatewayConnection) -> None:
logger.debug(
"Failed attempt %s/%s at configuring gateway %s: %r",
attempt + 1,
- GATEWAY_CONFIGURE_ATTEMPTS,
+ attempts,
connection.ip_address,
e,
)
diff --git a/src/dstack/_internal/server/services/gateways/client.py b/src/dstack/_internal/server/services/gateways/client.py
index ed87bc84f..aa4b4823c 100644
--- a/src/dstack/_internal/server/services/gateways/client.py
+++ b/src/dstack/_internal/server/services/gateways/client.py
@@ -7,9 +7,9 @@
from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT
from dstack._internal.core.errors import GatewayError
-from dstack._internal.core.models.configurations import RateLimit
+from dstack._internal.core.models.configurations import RateLimit, ServiceConfiguration
from dstack._internal.core.models.instances import SSHConnectionParams
-from dstack._internal.core.models.runs import JobSubmission, Run
+from dstack._internal.core.models.runs import JobSpec, JobSubmission, Run, get_service_port
from dstack._internal.proxy.gateway.schemas.stats import ServiceStats
from dstack._internal.server import settings
@@ -80,13 +80,15 @@ async def unregister_service(self, project: str, run_name: str):
async def register_replica(
self,
run: Run,
+ job_spec: JobSpec,
job_submission: JobSubmission,
ssh_head_proxy: Optional[SSHConnectionParams],
ssh_head_proxy_private_key: Optional[str],
):
+ assert isinstance(run.run_spec.configuration, ServiceConfiguration)
payload = {
"job_id": job_submission.id.hex,
- "app_port": run.run_spec.configuration.port.container_port,
+ "app_port": get_service_port(job_spec, run.run_spec.configuration),
"ssh_head_proxy": ssh_head_proxy.dict() if ssh_head_proxy is not None else None,
"ssh_head_proxy_private_key": ssh_head_proxy_private_key,
}
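
Replica registration now takes the application port from the job spec via `get_service_port(job_spec, configuration)` instead of reading `configuration.port.container_port` directly. A hypothetical sketch of the fallback such a lookup plausibly performs; the real `get_service_port` is imported from `dstack._internal.core.models.runs` and may differ:

```python
from typing import Optional


def get_service_port_sketch(
    job_service_port: Optional[int], configured_container_port: int
) -> int:
    # Hypothetical: prefer the port recorded on the job spec (the new
    # `service_port` field), otherwise fall back to the service
    # configuration's container port for older job specs.
    if job_service_port is not None:
        return job_service_port
    return configured_container_port
```
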
diff --git a/src/dstack/_internal/server/services/instances.py b/src/dstack/_internal/server/services/instances.py
index 53ca55396..c2dd5c2a1 100644
--- a/src/dstack/_internal/server/services/instances.py
+++ b/src/dstack/_internal/server/services/instances.py
@@ -106,6 +106,14 @@ def get_instance_requirements(instance_model: InstanceModel) -> Requirements:
return Requirements.__response__.parse_raw(instance_model.requirements)
+def get_instance_remote_connection_info(
+ instance_model: InstanceModel,
+) -> Optional[RemoteConnectionInfo]:
+ if instance_model.remote_connection_info is None:
+ return None
+ return RemoteConnectionInfo.__response__.parse_raw(instance_model.remote_connection_info)
+
+
def get_instance_ssh_private_keys(instance_model: InstanceModel) -> tuple[str, Optional[str]]:
"""
Returns a pair of SSH private keys: host key and optional proxy jump key.
diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py
index 41aa496be..8343f4069 100644
--- a/src/dstack/_internal/server/services/jobs/__init__.py
+++ b/src/dstack/_internal/server/services/jobs/__init__.py
@@ -134,6 +134,8 @@ def job_model_to_job_submission(job_model: JobModel) -> JobSubmission:
finished_at = None
if job_model.status.is_finished():
finished_at = last_processed_at
+ status_message = _get_job_status_message(job_model)
+ error = _get_job_error(job_model)
return JobSubmission(
id=job_model.id,
submission_num=job_model.submission_num,
@@ -143,11 +145,13 @@ def job_model_to_job_submission(job_model: JobModel) -> JobSubmission:
finished_at=finished_at,
inactivity_secs=job_model.inactivity_secs,
status=job_model.status,
+ status_message=status_message,
termination_reason=job_model.termination_reason,
termination_reason_message=job_model.termination_reason_message,
exit_status=job_model.exit_status,
job_provisioning_data=job_provisioning_data,
job_runtime_data=get_job_runtime_data(job_model),
+ error=error,
)
@@ -289,6 +293,19 @@ async def process_terminating_job(
# so that stuck volumes don't prevent the instance from terminating.
job_model.instance_id = None
instance_model.last_job_processed_at = common.get_current_datetime()
+
+ volume_names = (
+ jrd.volume_names
+ if jrd and jrd.volume_names
+ else [va.volume.name for va in instance_model.volume_attachments]
+ )
+ if volume_names:
+ volumes = await list_project_volume_models(
+ session=session, project=instance_model.project, names=volume_names
+ )
+ for volume in volumes:
+ volume.last_job_processed_at = common.get_current_datetime()
+
logger.info(
"%s: instance '%s' has been released, new status is %s",
fmt(job_model),
@@ -693,3 +710,31 @@ def _get_job_mount_point_attached_volume(
continue
return volume
raise ServerClientError("Failed to find an eligible volume for the mount point")
+
+
+def _get_job_status_message(job_model: JobModel) -> str:
+ if job_model.status == JobStatus.DONE:
+ return "exited (0)"
+ elif job_model.status == JobStatus.FAILED:
+ if job_model.termination_reason == JobTerminationReason.CONTAINER_EXITED_WITH_ERROR:
+ return f"exited ({job_model.exit_status})"
+ elif (
+ job_model.termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY
+ ):
+ return "no offers"
+ elif job_model.termination_reason == JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY:
+ return "interrupted"
+ else:
+ return "error"
+ elif job_model.status == JobStatus.TERMINATED:
+ if job_model.termination_reason == JobTerminationReason.TERMINATED_BY_USER:
+ return "stopped"
+ elif job_model.termination_reason == JobTerminationReason.ABORTED_BY_USER:
+ return "aborted"
+ return job_model.status.value
+
+
+def _get_job_error(job_model: JobModel) -> Optional[str]:
+ if job_model.termination_reason is None:
+ return None
+ return job_model.termination_reason.to_error()
diff --git a/src/dstack/_internal/server/services/jobs/configurators/base.py b/src/dstack/_internal/server/services/jobs/configurators/base.py
index 079e47f0b..e1fcee597 100644
--- a/src/dstack/_internal/server/services/jobs/configurators/base.py
+++ b/src/dstack/_internal/server/services/jobs/configurators/base.py
@@ -15,6 +15,7 @@
PortMapping,
PythonVersion,
RunConfigurationType,
+ ServiceConfiguration,
)
from dstack._internal.core.models.profiles import (
DEFAULT_STOP_DURATION,
@@ -153,6 +154,7 @@ async def _get_job_spec(
repo_data=self.run_spec.repo_data,
repo_code_hash=self.run_spec.repo_code_hash,
file_archives=self.run_spec.file_archives,
+ service_port=self._service_port(),
)
return job_spec
@@ -306,6 +308,11 @@ def _ssh_key(self, jobs_per_replica: int) -> Optional[JobSSHKey]:
)
return self._job_ssh_key
+ def _service_port(self) -> Optional[int]:
+ if isinstance(self.run_spec.configuration, ServiceConfiguration):
+ return self.run_spec.configuration.port.container_port
+ return None
+
def interpolate_job_volumes(
run_volumes: List[Union[MountPoint, str]],
diff --git a/src/dstack/_internal/server/services/locking.py b/src/dstack/_internal/server/services/locking.py
index 37807b37a..4c3b7f938 100644
--- a/src/dstack/_internal/server/services/locking.py
+++ b/src/dstack/_internal/server/services/locking.py
@@ -172,7 +172,7 @@ async def _wait_to_lock_many(
The keys must be sorted to prevent deadlock.
"""
left_to_lock = keys.copy()
- while len(left_to_lock) > 0:
+ while True:
async with lock:
locked_now_num = 0
for key in left_to_lock:
@@ -182,4 +182,6 @@ async def _wait_to_lock_many(
locked.add(key)
locked_now_num += 1
left_to_lock = left_to_lock[locked_now_num:]
+ if not left_to_lock:
+ return
await asyncio.sleep(delay)
diff --git a/src/dstack/_internal/server/services/logging.py b/src/dstack/_internal/server/services/logging.py
index 1f2d106a5..545067d6a 100644
--- a/src/dstack/_internal/server/services/logging.py
+++ b/src/dstack/_internal/server/services/logging.py
@@ -1,12 +1,14 @@
from typing import Union
-from dstack._internal.server.models import JobModel, RunModel
+from dstack._internal.server.models import GatewayModel, JobModel, RunModel
-def fmt(model: Union[RunModel, JobModel]) -> str:
+def fmt(model: Union[RunModel, JobModel, GatewayModel]) -> str:
"""Consistent string representation of a model for logging."""
if isinstance(model, RunModel):
return f"run({model.id.hex[:6]}){model.run_name}"
if isinstance(model, JobModel):
return f"job({model.id.hex[:6]}){model.job_name}"
+ if isinstance(model, GatewayModel):
+ return f"gateway({model.id.hex[:6]}){model.name}"
return str(model)
diff --git a/src/dstack/_internal/server/services/logs/__init__.py b/src/dstack/_internal/server/services/logs/__init__.py
index a3623a19d..b38264980 100644
--- a/src/dstack/_internal/server/services/logs/__init__.py
+++ b/src/dstack/_internal/server/services/logs/__init__.py
@@ -8,7 +8,11 @@
from dstack._internal.server.schemas.logs import PollLogsRequest
from dstack._internal.server.schemas.runner import LogEvent as RunnerLogEvent
from dstack._internal.server.services.logs.aws import BOTO_AVAILABLE, CloudWatchLogStorage
-from dstack._internal.server.services.logs.base import LogStorage, LogStorageError
+from dstack._internal.server.services.logs.base import (
+ LogStorage,
+ LogStorageError,
+ b64encode_raw_message,
+)
from dstack._internal.server.services.logs.filelog import FileLogStorage
from dstack._internal.server.services.logs.gcp import GCP_LOGGING_AVAILABLE, GCPLogStorage
from dstack._internal.utils.common import run_async
@@ -75,4 +79,13 @@ def write_logs(
async def poll_logs_async(project: ProjectModel, request: PollLogsRequest) -> JobSubmissionLogs:
- return await run_async(get_log_storage().poll_logs, project=project, request=request)
+ job_submission_logs = await run_async(
+ get_log_storage().poll_logs, project=project, request=request
+ )
+ # Logs are stored in plaintext but transmitted in base64 for API/CLI backward compatibility.
+ # Old logs stored in base64 are encoded twice for transmission and shown as base64 in CLI/UI.
+ # We live with that.
+ # TODO: Drop base64 encoding in 0.20.
+ for log_event in job_submission_logs.logs:
+ log_event.message = b64encode_raw_message(log_event.message.encode())
+ return job_submission_logs
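
Given the convention spelled out in the comment above (plaintext at rest, base64 over the wire), a log consumer decodes each message and pages with `next_token`. A hedged sketch against `/api/project/{project_name}/logs/poll` using `PollLogsRequest` field names; the project name, auth header, and local URL are assumptions:

```python
import base64

import requests

SERVER = "http://localhost:3000"               # assumed local dstack server
HEADERS = {"Authorization": "Bearer <token>"}  # assumed auth scheme

body = {
    "run_name": "my-run",
    "job_submission_id": "<job-submission-uuid>",
    "limit": 1000,
}
while True:
    resp = requests.post(
        f"{SERVER}/api/project/main/logs/poll", json=body, headers=HEADERS
    ).json()
    for event in resp["logs"]:
        # Messages are base64-encoded for transport; decode before displaying.
        print(base64.b64decode(event["message"]).decode("utf-8", errors="replace"), end="")
    if not resp.get("next_token"):
        break
    body["next_token"] = resp["next_token"]
```
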
diff --git a/src/dstack/_internal/server/services/logs/aws.py b/src/dstack/_internal/server/services/logs/aws.py
index 92155763c..616db94db 100644
--- a/src/dstack/_internal/server/services/logs/aws.py
+++ b/src/dstack/_internal/server/services/logs/aws.py
@@ -17,7 +17,6 @@
from dstack._internal.server.services.logs.base import (
LogStorage,
LogStorageError,
- b64encode_raw_message,
datetime_to_unix_time_ms,
unix_time_ms_to_datetime,
)
@@ -238,8 +237,7 @@ def _get_next_batch(
skipped_future_events += 1
continue
cw_event = self._runner_log_event_to_cloudwatch_event(event)
- # as message is base64-encoded, length in bytes = length in code points.
- message_size = len(cw_event["message"]) + self.MESSAGE_OVERHEAD_SIZE
+ message_size = len(event.message) + self.MESSAGE_OVERHEAD_SIZE
if message_size > self.MESSAGE_MAX_SIZE:
# we should never hit this limit, as we use `io.Copy` to copy from pty to logs,
# which under the hood uses 32KiB buffer, see runner/internal/executor/executor.go,
@@ -271,7 +269,7 @@ def _runner_log_event_to_cloudwatch_event(
) -> _CloudWatchLogEvent:
return {
"timestamp": runner_log_event.timestamp,
- "message": b64encode_raw_message(runner_log_event.message),
+ "message": runner_log_event.message.decode(errors="replace"),
}
@contextmanager
diff --git a/src/dstack/_internal/server/services/logs/filelog.py b/src/dstack/_internal/server/services/logs/filelog.py
index 905ee3527..10cbe3d1a 100644
--- a/src/dstack/_internal/server/services/logs/filelog.py
+++ b/src/dstack/_internal/server/services/logs/filelog.py
@@ -2,6 +2,7 @@
from typing import List, Union
from uuid import UUID
+from dstack._internal.core.errors import ServerClientError
from dstack._internal.core.models.logs import (
JobSubmissionLogs,
LogEvent,
@@ -14,8 +15,6 @@
from dstack._internal.server.schemas.runner import LogEvent as RunnerLogEvent
from dstack._internal.server.services.logs.base import (
LogStorage,
- LogStorageError,
- b64encode_raw_message,
unix_time_ms_to_datetime,
)
@@ -30,9 +29,6 @@ def __init__(self, root: Union[Path, str, None] = None) -> None:
self.root = Path(root)
def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmissionLogs:
- if request.descending:
- raise LogStorageError("descending: true is not supported")
-
log_producer = LogProducer.RUNNER if request.diagnose else LogProducer.JOB
log_file_path = self._get_log_file_path(
project_name=project.name,
@@ -46,11 +42,11 @@ def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmi
try:
start_line = int(request.next_token)
if start_line < 0:
- raise LogStorageError(
+ raise ServerClientError(
f"Invalid next_token: {request.next_token}. Must be a non-negative integer."
)
except ValueError:
- raise LogStorageError(
+ raise ServerClientError(
f"Invalid next_token: {request.next_token}. Must be a valid integer."
)
@@ -60,31 +56,41 @@ def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmi
try:
with open(log_file_path) as f:
- lines = f.readlines()
-
- for i, line in enumerate(lines):
- if current_line < start_line:
+ # Skip to start_line if needed
+ for _ in range(start_line):
+ if f.readline() == "":
+ # File is shorter than start_line
+ return JobSubmissionLogs(logs=logs, next_token=next_token)
current_line += 1
- continue
- log_event = LogEvent.__response__.parse_raw(line)
- current_line += 1
+ # Read lines one by one
+ while True:
+ line = f.readline()
+ if line == "": # EOF
+ break
+
+ current_line += 1
- if request.start_time and log_event.timestamp <= request.start_time:
- continue
- if request.end_time is not None and log_event.timestamp >= request.end_time:
- break
+ try:
+ log_event = LogEvent.__response__.parse_raw(line)
+ except Exception:
+ # Skip malformed lines
+ continue
- logs.append(log_event)
+ if request.start_time and log_event.timestamp <= request.start_time:
+ continue
+ if request.end_time is not None and log_event.timestamp >= request.end_time:
+ break
- if len(logs) >= request.limit:
- # Only set next_token if there are more lines to read
- if current_line < len(lines):
- next_token = str(current_line)
- break
+ logs.append(log_event)
- except IOError as e:
- raise LogStorageError(f"Failed to read log file {log_file_path}: {e}")
+ if len(logs) >= request.limit:
+ # Check if there are more lines to read
+ if f.readline() != "":
+ next_token = str(current_line)
+ break
+ except FileNotFoundError:
+ pass
return JobSubmissionLogs(logs=logs, next_token=next_token)
@@ -140,5 +146,5 @@ def _runner_log_event_to_log_event(self, runner_log_event: RunnerLogEvent) -> Lo
return LogEvent(
timestamp=unix_time_ms_to_datetime(runner_log_event.timestamp),
log_source=LogEventSource.STDOUT,
- message=b64encode_raw_message(runner_log_event.message),
+ message=runner_log_event.message.decode(errors="replace"),
)
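
For readers skimming the patch: the rewritten `poll_logs` above switches from `readlines()` to streaming `readline()` calls, so large log files are never loaded into memory at once, and `next_token` is only returned when more lines actually remain. A minimal standalone sketch of that pagination pattern (hypothetical helper, not part of the codebase):

```python
from typing import List, Optional, Tuple


def paginate_lines(path: str, start_line: int, limit: int) -> Tuple[List[str], Optional[str]]:
    """Read up to `limit` lines starting at `start_line`; return a next_token
    (the next line number) only if more lines remain. Illustrative only."""
    lines: List[str] = []
    next_token: Optional[str] = None
    current_line = 0
    try:
        with open(path) as f:
            # Skip already-consumed lines without reading the whole file
            for _ in range(start_line):
                if f.readline() == "":
                    return lines, next_token  # file shorter than start_line
                current_line += 1
            while True:
                line = f.readline()
                if line == "":  # EOF
                    break
                current_line += 1
                lines.append(line.rstrip("\n"))
                if len(lines) >= limit:
                    # Only hand out a token if there is actually more to read
                    if f.readline() != "":
                        next_token = str(current_line)
                    break
    except FileNotFoundError:
        pass
    return lines, next_token
```

The caller passes the returned `next_token` back as `start_line` on the next request, which mirrors how `PollLogsRequest.next_token` is consumed above.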
diff --git a/src/dstack/_internal/server/services/logs/gcp.py b/src/dstack/_internal/server/services/logs/gcp.py
index ac228e19e..6e9314df2 100644
--- a/src/dstack/_internal/server/services/logs/gcp.py
+++ b/src/dstack/_internal/server/services/logs/gcp.py
@@ -14,7 +14,6 @@
from dstack._internal.server.services.logs.base import (
LogStorage,
LogStorageError,
- b64encode_raw_message,
unix_time_ms_to_datetime,
)
from dstack._internal.utils.common import batched
@@ -137,15 +136,14 @@ def _write_logs_to_stream(self, stream_name: str, logs: List[RunnerLogEvent]):
with self.logger.batch() as batcher:
for batch in batched(logs, self.MAX_BATCH_SIZE):
for log in batch:
- message = b64encode_raw_message(log.message)
+ message = log.message.decode(errors="replace")
timestamp = unix_time_ms_to_datetime(log.timestamp)
- # as message is base64-encoded, length in bytes = length in code points
- if len(message) > self.MAX_RUNNER_MESSAGE_SIZE:
+ if len(log.message) > self.MAX_RUNNER_MESSAGE_SIZE:
logger.error(
"Stream %s: skipping event at %s, message exceeds max size: %d > %d",
stream_name,
timestamp.isoformat(),
- len(message),
+ len(log.message),
self.MAX_RUNNER_MESSAGE_SIZE,
)
continue
diff --git a/src/dstack/_internal/server/services/proxy/repo.py b/src/dstack/_internal/server/services/proxy/repo.py
index 4578cb56f..7c21e840b 100644
--- a/src/dstack/_internal/server/services/proxy/repo.py
+++ b/src/dstack/_internal/server/services/proxy/repo.py
@@ -12,10 +12,12 @@
from dstack._internal.core.models.instances import RemoteConnectionInfo, SSHConnectionParams
from dstack._internal.core.models.runs import (
JobProvisioningData,
+ JobSpec,
JobStatus,
RunSpec,
RunStatus,
ServiceSpec,
+ get_service_port,
)
from dstack._internal.core.models.services import AnyModel
from dstack._internal.proxy.lib.models import (
@@ -97,9 +99,10 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
if rci.ssh_proxy is not None:
ssh_head_proxy = rci.ssh_proxy
ssh_head_proxy_private_key = get_or_error(rci.ssh_proxy_keys)[0].private
+ job_spec: JobSpec = JobSpec.__response__.parse_raw(job.job_spec_data)
replica = Replica(
id=job.id.hex,
- app_port=run_spec.configuration.port.container_port,
+ app_port=get_service_port(job_spec, run_spec.configuration),
ssh_destination=ssh_destination,
ssh_port=ssh_port,
ssh_proxy=ssh_proxy,
diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py
index 79d9e0a20..33b7c8299 100644
--- a/src/dstack/_internal/server/services/runs.py
+++ b/src/dstack/_internal/server/services/runs.py
@@ -24,6 +24,7 @@
)
from dstack._internal.core.models.profiles import (
CreationPolicy,
+ RetryEvent,
)
from dstack._internal.core.models.repos.virtual import DEFAULT_VIRTUAL_REPO_ID, VirtualRunRepoData
from dstack._internal.core.models.runs import (
@@ -105,6 +106,8 @@ async def list_user_runs(
repo_id: Optional[str],
username: Optional[str],
only_active: bool,
+ include_jobs: bool,
+ job_submissions_limit: Optional[int],
prev_submitted_at: Optional[datetime],
prev_run_id: Optional[uuid.UUID],
limit: int,
@@ -148,7 +151,14 @@ async def list_user_runs(
runs = []
for r in run_models:
try:
- runs.append(run_model_to_run(r, return_in_api=True))
+ runs.append(
+ run_model_to_run(
+ r,
+ return_in_api=True,
+ include_jobs=include_jobs,
+ job_submissions_limit=job_submissions_limit,
+ )
+ )
except pydantic.ValidationError:
pass
if len(run_models) > len(runs):
@@ -652,51 +662,33 @@ async def delete_runs(
def run_model_to_run(
run_model: RunModel,
- include_job_submissions: bool = True,
+ include_jobs: bool = True,
+ job_submissions_limit: Optional[int] = None,
return_in_api: bool = False,
include_sensitive: bool = False,
) -> Run:
jobs: List[Job] = []
- run_jobs = sorted(run_model.jobs, key=lambda j: (j.replica_num, j.job_num, j.submission_num))
- for replica_num, replica_submissions in itertools.groupby(
- run_jobs, key=lambda j: j.replica_num
- ):
- for job_num, job_submissions in itertools.groupby(
- replica_submissions, key=lambda j: j.job_num
- ):
- submissions = []
- job_model = None
- for job_model in job_submissions:
- if include_job_submissions:
- job_submission = job_model_to_job_submission(job_model)
- if return_in_api:
- # Set default non-None values for 0.18 backward-compatibility
- # Remove in 0.19
- if job_submission.job_provisioning_data is not None:
- if job_submission.job_provisioning_data.hostname is None:
- job_submission.job_provisioning_data.hostname = ""
- if job_submission.job_provisioning_data.ssh_port is None:
- job_submission.job_provisioning_data.ssh_port = 22
- submissions.append(job_submission)
- if job_model is not None:
- # Use the spec from the latest submission. Submissions can have different specs
- job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data)
- if not include_sensitive:
- _remove_job_spec_sensitive_info(job_spec)
- jobs.append(Job(job_spec=job_spec, job_submissions=submissions))
+ if include_jobs:
+ jobs = _get_run_jobs_with_submissions(
+ run_model=run_model,
+ job_submissions_limit=job_submissions_limit,
+ return_in_api=return_in_api,
+ include_sensitive=include_sensitive,
+ )
run_spec = RunSpec.__response__.parse_raw(run_model.run_spec)
latest_job_submission = None
- if include_job_submissions:
+ if len(jobs) > 0 and len(jobs[0].job_submissions) > 0:
# TODO(egor-s): does it make sense with replicas and multi-node?
- if jobs:
- latest_job_submission = jobs[0].job_submissions[-1]
+ latest_job_submission = jobs[0].job_submissions[-1]
service_spec = None
if run_model.service_spec is not None:
service_spec = ServiceSpec.__response__.parse_raw(run_model.service_spec)
+ status_message = _get_run_status_message(run_model)
+ error = _get_run_error(run_model)
run = Run(
id=run_model.id,
project_name=run_model.project.name,
@@ -704,18 +696,107 @@ def run_model_to_run(
submitted_at=run_model.submitted_at.replace(tzinfo=timezone.utc),
last_processed_at=run_model.last_processed_at.replace(tzinfo=timezone.utc),
status=run_model.status,
+ status_message=status_message,
termination_reason=run_model.termination_reason,
run_spec=run_spec,
jobs=jobs,
latest_job_submission=latest_job_submission,
service=service_spec,
deployment_num=run_model.deployment_num,
+ error=error,
deleted=run_model.deleted,
)
run.cost = _get_run_cost(run)
return run
+def _get_run_jobs_with_submissions(
+ run_model: RunModel,
+ job_submissions_limit: Optional[int],
+ return_in_api: bool = False,
+ include_sensitive: bool = False,
+) -> List[Job]:
+ jobs: List[Job] = []
+ run_jobs = sorted(run_model.jobs, key=lambda j: (j.replica_num, j.job_num, j.submission_num))
+ for replica_num, replica_submissions in itertools.groupby(
+ run_jobs, key=lambda j: j.replica_num
+ ):
+ for job_num, job_models in itertools.groupby(replica_submissions, key=lambda j: j.job_num):
+ submissions = []
+ job_model = None
+ if job_submissions_limit is not None:
+ if job_submissions_limit == 0:
+ # Take latest job submission to return its job_spec
+ job_models = list(job_models)[-1:]
+ else:
+ job_models = list(job_models)[-job_submissions_limit:]
+ for job_model in job_models:
+ if job_submissions_limit != 0:
+ job_submission = job_model_to_job_submission(job_model)
+ if return_in_api:
+ # Set default non-None values for 0.18 backward-compatibility
+ # Remove in 0.19
+ if job_submission.job_provisioning_data is not None:
+ if job_submission.job_provisioning_data.hostname is None:
+ job_submission.job_provisioning_data.hostname = ""
+ if job_submission.job_provisioning_data.ssh_port is None:
+ job_submission.job_provisioning_data.ssh_port = 22
+ submissions.append(job_submission)
+ if job_model is not None:
+ # Use the spec from the latest submission. Submissions can have different specs
+ job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data)
+ if not include_sensitive:
+ _remove_job_spec_sensitive_info(job_spec)
+ jobs.append(Job(job_spec=job_spec, job_submissions=submissions))
+ return jobs
+
+
+def _get_run_status_message(run_model: RunModel) -> str:
+ if len(run_model.jobs) == 0:
+ return run_model.status.value
+
+ sorted_job_models = sorted(
+ run_model.jobs, key=lambda j: (j.replica_num, j.job_num, j.submission_num)
+ )
+ job_models_grouped_by_job = list(
+ list(jm)
+ for _, jm in itertools.groupby(sorted_job_models, key=lambda j: (j.replica_num, j.job_num))
+ )
+
+ if all(job_models[-1].status == JobStatus.PULLING for job_models in job_models_grouped_by_job):
+ # Show `pulling` if the last submission of every job is pulling
+ return "pulling"
+
+ if run_model.status in [RunStatus.SUBMITTED, RunStatus.PENDING]:
+ # Show `retrying` if any job caused the run to retry
+ for job_models in job_models_grouped_by_job:
+ last_job_spec = JobSpec.__response__.parse_raw(job_models[-1].job_spec_data)
+ retry_on_events = last_job_spec.retry.on_events if last_job_spec.retry else []
+ last_job_termination_reason = _get_last_job_termination_reason(job_models)
+ if (
+ last_job_termination_reason
+ == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY
+ and RetryEvent.NO_CAPACITY in retry_on_events
+ ):
+ # TODO: Show `retrying` for other retry events
+ return "retrying"
+
+ return run_model.status.value
+
+
+def _get_last_job_termination_reason(job_models: List[JobModel]) -> Optional[JobTerminationReason]:
+ for job_model in reversed(job_models):
+ if job_model.termination_reason is not None:
+ return job_model.termination_reason
+ return None
+
+
+def _get_run_error(run_model: RunModel) -> Optional[str]:
+ if run_model.termination_reason is None:
+ return None
+ return run_model.termination_reason.to_error()
+
+
async def _get_pool_offers(
session: AsyncSession,
project: ProjectModel,
@@ -914,6 +995,8 @@ def _validate_run_spec_and_set_defaults(run_spec: RunSpec):
"replicas",
"scaling",
# rolling deployment
+ # NOTE: keep this list in sync with the "Rolling deployment" section in services.md
+ "port",
"resources",
"volumes",
"docker",
diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py
index 062390b9a..ba3168c59 100644
--- a/src/dstack/_internal/server/services/services/__init__.py
+++ b/src/dstack/_internal/server/services/services/__init__.py
@@ -22,7 +22,7 @@
from dstack._internal.core.models.configurations import SERVICE_HTTPS_DEFAULT, ServiceConfiguration
from dstack._internal.core.models.gateways import GatewayConfiguration, GatewayStatus
from dstack._internal.core.models.instances import SSHConnectionParams
-from dstack._internal.core.models.runs import Run, RunSpec, ServiceModelSpec, ServiceSpec
+from dstack._internal.core.models.runs import JobSpec, Run, RunSpec, ServiceModelSpec, ServiceSpec
from dstack._internal.server import settings
from dstack._internal.server.models import GatewayModel, JobModel, ProjectModel, RunModel
from dstack._internal.server.services.gateways import (
@@ -179,6 +179,7 @@ async def register_replica(
async with conn.client() as client:
await client.register_replica(
run=run,
+ job_spec=JobSpec.__response__.parse_raw(job_model.job_spec_data),
job_submission=job_submission,
ssh_head_proxy=ssh_head_proxy,
ssh_head_proxy_private_key=ssh_head_proxy_private_key,
diff --git a/src/dstack/_internal/server/services/users.py b/src/dstack/_internal/server/services/users.py
index 7aaf4b979..8504f2af7 100644
--- a/src/dstack/_internal/server/services/users.py
+++ b/src/dstack/_internal/server/services/users.py
@@ -44,7 +44,9 @@ async def list_users_for_user(
session: AsyncSession,
user: UserModel,
) -> List[User]:
- return await list_all_users(session=session)
+ if user.global_role == GlobalRole.ADMIN:
+ return await list_all_users(session=session)
+ return [user_model_to_user(user)]
async def list_all_users(
diff --git a/src/dstack/_internal/server/services/volumes.py b/src/dstack/_internal/server/services/volumes.py
index 9a7ed53d3..be43e02ce 100644
--- a/src/dstack/_internal/server/services/volumes.py
+++ b/src/dstack/_internal/server/services/volumes.py
@@ -401,6 +401,19 @@ def _validate_volume_configuration(configuration: VolumeConfiguration):
if configuration.name is not None:
validate_dstack_resource_name(configuration.name)
+ if configuration.volume_id is not None and configuration.auto_cleanup_duration is not None:
+ if (
+ isinstance(configuration.auto_cleanup_duration, int)
+ and configuration.auto_cleanup_duration > 0
+ ) or (
+ isinstance(configuration.auto_cleanup_duration, str)
+ and configuration.auto_cleanup_duration not in ("off", "-1")
+ ):
+ raise ServerClientError(
+ "External volumes (with volume_id) do not support auto_cleanup_duration. "
+ "Auto-cleanup only works for volumes created and managed by dstack."
+ )
+
async def _delete_volume(session: AsyncSession, project: ProjectModel, volume_model: VolumeModel):
volume = volume_model_to_volume(volume_model)
diff --git a/src/dstack/_internal/server/settings.py b/src/dstack/_internal/server/settings.py
index a60fc5e65..f025d7164 100644
--- a/src/dstack/_internal/server/settings.py
+++ b/src/dstack/_internal/server/settings.py
@@ -42,6 +42,11 @@
os.getenv("DSTACK_SERVER_BACKGROUND_PROCESSING_FACTOR", 1)
)
+SERVER_BACKGROUND_PROCESSING_DISABLED = (
+ os.getenv("DSTACK_SERVER_BACKGROUND_PROCESSING_DISABLED") is not None
+)
+SERVER_BACKGROUND_PROCESSING_ENABLED = not SERVER_BACKGROUND_PROCESSING_DISABLED
+
SERVER_EXECUTOR_MAX_WORKERS = int(os.getenv("DSTACK_SERVER_EXECUTOR_MAX_WORKERS", 128))
MAX_OFFERS_TRIED = int(os.getenv("DSTACK_SERVER_MAX_OFFERS_TRIED", 25))
@@ -113,5 +118,5 @@
UPDATE_DEFAULT_PROJECT = os.getenv("DSTACK_UPDATE_DEFAULT_PROJECT") is not None
DO_NOT_UPDATE_DEFAULT_PROJECT = os.getenv("DSTACK_DO_NOT_UPDATE_DEFAULT_PROJECT") is not None
-SKIP_GATEWAY_UPDATE = os.getenv("DSTACK_SKIP_GATEWAY_UPDATE", None) is not None
-ENABLE_PROMETHEUS_METRICS = os.getenv("DSTACK_ENABLE_PROMETHEUS_METRICS", None) is not None
+SKIP_GATEWAY_UPDATE = os.getenv("DSTACK_SKIP_GATEWAY_UPDATE") is not None
+ENABLE_PROMETHEUS_METRICS = os.getenv("DSTACK_ENABLE_PROMETHEUS_METRICS") is not None
diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py
index 047adb5c1..02d3ac242 100644
--- a/src/dstack/_internal/server/testing/common.py
+++ b/src/dstack/_internal/server/testing/common.py
@@ -31,6 +31,8 @@
FleetSpec,
FleetStatus,
InstanceGroupPlacement,
+ SSHHostParams,
+ SSHParams,
)
from dstack._internal.core.models.gateways import GatewayComputeConfiguration, GatewayStatus
from dstack._internal.core.models.instances import (
@@ -378,6 +380,7 @@ def get_job_provisioning_data(
hostname: str = "127.0.0.4",
internal_ip: Optional[str] = "127.0.0.4",
price: float = 10.5,
+ instance_type: Optional[InstanceType] = None,
) -> JobProvisioningData:
gpus = [
Gpu(
@@ -386,14 +389,16 @@ def get_job_provisioning_data(
vendor=gpuhunt.AcceleratorVendor.NVIDIA,
)
] * gpu_count
- return JobProvisioningData(
- backend=backend,
- instance_type=InstanceType(
+ if instance_type is None:
+ instance_type = InstanceType(
name="instance",
resources=Resources(
cpus=cpu_count, memory_mib=int(memory_gib * 1024), spot=spot, gpus=gpus
),
- ),
+ )
+ return JobProvisioningData(
+ backend=backend,
+ instance_type=instance_type,
instance_id="instance_id",
hostname=hostname,
internal_ip=internal_ip,
@@ -549,6 +554,31 @@ def get_fleet_configuration(
)
+def get_ssh_fleet_configuration(
+ name: str = "test-fleet",
+ user: str = "ubuntu",
+ ssh_key: Optional[SSHKey] = None,
+ hosts: Optional[list[Union[SSHHostParams, str]]] = None,
+ network: Optional[str] = None,
+ placement: Optional[InstanceGroupPlacement] = None,
+) -> FleetConfiguration:
+ if ssh_key is None:
+ ssh_key = SSHKey(public="", private=get_private_key_string())
+ if hosts is None:
+ hosts = ["10.0.0.100"]
+ ssh_config = SSHParams(
+ user=user,
+ ssh_key=ssh_key,
+ hosts=hosts,
+ network=network,
+ )
+ return FleetConfiguration(
+ name=name,
+ ssh_config=ssh_config,
+ placement=placement,
+ )
+
+
async def create_instance(
session: AsyncSession,
project: ProjectModel,
@@ -590,7 +620,9 @@ async def create_instance(
internal_ip=None,
)
if offer == "auto":
- offer = get_instance_offer_with_availability(backend=backend, region=region, spot=spot)
+ offer = get_instance_offer_with_availability(
+ backend=backend, region=region, spot=spot, price=price
+ )
if profile is None:
profile = Profile(name="test_name")
@@ -742,6 +774,7 @@ async def create_volume(
status: VolumeStatus = VolumeStatus.SUBMITTED,
created_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc),
last_processed_at: Optional[datetime] = None,
+ last_job_processed_at: Optional[datetime] = None,
configuration: Optional[VolumeConfiguration] = None,
volume_provisioning_data: Optional[VolumeProvisioningData] = None,
deleted_at: Optional[datetime] = None,
@@ -759,6 +792,7 @@ async def create_volume(
status=status,
created_at=created_at,
last_processed_at=last_processed_at,
+ last_job_processed_at=last_job_processed_at,
configuration=configuration.json(),
volume_provisioning_data=volume_provisioning_data.json()
if volume_provisioning_data
@@ -820,6 +854,7 @@ def get_volume_configuration(
region: str = "eu-west-1",
size: Optional[Memory] = Memory(100),
volume_id: Optional[str] = None,
+ auto_cleanup_duration: Optional[Union[str, int]] = None,
) -> VolumeConfiguration:
return VolumeConfiguration(
name=name,
@@ -827,6 +862,7 @@ def get_volume_configuration(
region=region,
size=size,
volume_id=volume_id,
+ auto_cleanup_duration=auto_cleanup_duration,
)
diff --git a/src/dstack/_internal/server/utils/routers.py b/src/dstack/_internal/server/utils/routers.py
index f8ca21004..131ec5cc3 100644
--- a/src/dstack/_internal/server/utils/routers.py
+++ b/src/dstack/_internal/server/utils/routers.py
@@ -1,11 +1,34 @@
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
-from fastapi import HTTPException, Request, status
-from fastapi.responses import JSONResponse
+import orjson
+from fastapi import HTTPException, Request, Response, status
from packaging import version
from dstack._internal.core.errors import ServerClientError, ServerClientErrorCode
from dstack._internal.core.models.common import CoreModel
+from dstack._internal.utils.json_utils import get_orjson_default_options, orjson_default
+
+
+class CustomORJSONResponse(Response):
+ """
+ Custom JSONResponse that uses orjson for serialization.
+
+ It's recommended to return this class from routers directly instead of
+ returning pydantic models to avoid FastAPI's jsonable_encoder overhead.
+ See https://fastapi.tiangolo.com/advanced/custom-response/#use-orjsonresponse.
+
+ Beware that FastAPI skips model validation when responses are returned directly.
+ If serialization needs to be modified, override `dict()` instead of adding validators.
+ """
+
+ media_type = "application/json"
+
+ def render(self, content: Any) -> bytes:
+ return orjson.dumps(
+ content,
+ option=get_orjson_default_options(),
+ default=orjson_default,
+ )
class BadRequestDetailsModel(CoreModel):
@@ -30,7 +53,7 @@ def get_base_api_additional_responses() -> Dict:
"""
Returns additional responses for the OpenAPI docs relevant to all API endpoints.
The endpoints may override responses to make them as specific as possible.
- E.g. an enpoint may specify which error codes it may return in `code`.
+ E.g. an endpoint may specify which error codes it may return in `code`.
"""
return {
400: get_bad_request_additional_response(),
@@ -102,7 +125,7 @@ def get_request_size(request: Request) -> int:
def check_client_server_compatibility(
client_version: Optional[str],
server_version: Optional[str],
-) -> Optional[JSONResponse]:
+) -> Optional[CustomORJSONResponse]:
"""
Returns `JSONResponse` with error if client/server versions are incompatible.
Returns `None` otherwise.
@@ -116,7 +139,7 @@ def check_client_server_compatibility(
try:
parsed_client_version = version.parse(client_version)
except version.InvalidVersion:
- return JSONResponse(
+ return CustomORJSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={
"detail": get_server_client_error_details(
@@ -138,11 +161,11 @@ def error_incompatible_versions(
client_version: Optional[str],
server_version: str,
ask_cli_update: bool,
-) -> JSONResponse:
+) -> CustomORJSONResponse:
msg = f"The client/CLI version ({client_version}) is incompatible with the server version ({server_version})."
if ask_cli_update:
msg += f" Update the dstack CLI: `pip install dstack=={server_version}`."
- return JSONResponse(
+ return CustomORJSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": get_server_client_error_details(ServerClientError(msg=msg))},
)
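
The `CustomORJSONResponse` docstring notes that returning the response object directly bypasses FastAPI's `jsonable_encoder` and response-model validation. A hedged usage sketch of the same pattern (simplified response class and route, not the actual dstack endpoints), assuming `fastapi`, `orjson`, and pydantic v1 are installed:

```python
from typing import Any

import orjson
from fastapi import FastAPI, Response
from pydantic import BaseModel


class ORJSONSketchResponse(Response):
    # Simplified stand-in for CustomORJSONResponse above
    media_type = "application/json"

    def render(self, content: Any) -> bytes:
        # orjson can't serialize pydantic models natively, so fall back to .dict()
        return orjson.dumps(
            content,
            default=lambda o: o.dict() if isinstance(o, BaseModel) else str(o),
        )


class Item(BaseModel):
    name: str
    price: float


app = FastAPI()


@app.get("/items/{name}")
def get_item(name: str) -> Response:
    # Returning a Response directly skips jsonable_encoder and response validation
    return ORJSONSketchResponse(Item(name=name, price=1.0))
```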
diff --git a/src/dstack/_internal/utils/json_utils.py b/src/dstack/_internal/utils/json_utils.py
new file mode 100644
index 000000000..9017e94c3
--- /dev/null
+++ b/src/dstack/_internal/utils/json_utils.py
@@ -0,0 +1,54 @@
+from typing import Any
+
+import orjson
+from pydantic import BaseModel
+
+FREEZEGUN = True
+try:
+ from freezegun.api import FakeDatetime
+except ImportError:
+ FREEZEGUN = False
+
+
+ASYNCPG = True
+try:
+ import asyncpg.pgproto.pgproto
+except ImportError:
+ ASYNCPG = False
+
+
+def pydantic_orjson_dumps(v: Any, *, default: Any) -> str:
+ return orjson.dumps(
+ v,
+ option=get_orjson_default_options(),
+ default=orjson_default,
+ ).decode()
+
+
+def pydantic_orjson_dumps_with_indent(v: Any, *, default: Any) -> str:
+ return orjson.dumps(
+ v,
+ option=get_orjson_default_options() | orjson.OPT_INDENT_2,
+ default=orjson_default,
+ ).decode()
+
+
+def orjson_default(obj):
+ if isinstance(obj, float):
+ # orjson does not convert float subclasses by default
+ return float(obj)
+ if isinstance(obj, BaseModel):
+ # Allows calling orjson.dumps() on pydantic models
+ # (e.g. to return from the API)
+ return obj.dict()
+ if ASYNCPG:
+ if isinstance(obj, asyncpg.pgproto.pgproto.UUID):
+ return str(obj)
+ if FREEZEGUN:
+ if isinstance(obj, FakeDatetime):
+ return obj.isoformat()
+ raise TypeError
+
+
+def get_orjson_default_options() -> int:
+ return orjson.OPT_NON_STR_KEYS
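
The `pydantic_orjson_dumps*` helpers match the `(v, *, default)` signature that pydantic v1 expects for the `Config.json_dumps` hook, so models can serialize through orjson when `.json()` is called. A minimal sketch of that wiring (hypothetical model, simplified dumps without the custom options):

```python
from typing import Any

import orjson
from pydantic import BaseModel


def orjson_dumps(v: Any, *, default: Any) -> str:
    # Same shape as pydantic_orjson_dumps above, minus the extra options
    return orjson.dumps(v, default=default).decode()


class RunInfo(BaseModel):
    # Hypothetical model, for illustration only
    name: str
    price: float

    class Config:
        json_dumps = orjson_dumps


print(RunInfo(name="my-run", price=0.0026).json())
# {"name":"my-run","price":0.0026}
```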
diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py
index 1a4e0e1e2..e1992068d 100644
--- a/src/dstack/api/_public/runs.py
+++ b/src/dstack/api/_public/runs.py
@@ -18,7 +18,11 @@
from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_RUNNER_SSH_PORT
from dstack._internal.core.errors import ClientError, ConfigurationError, ResourceNotExistsError
from dstack._internal.core.models.backends.base import BackendType
-from dstack._internal.core.models.configurations import AnyRunConfiguration, PortMapping
+from dstack._internal.core.models.configurations import (
+ AnyRunConfiguration,
+ PortMapping,
+ ServiceConfiguration,
+)
from dstack._internal.core.models.files import FileArchiveMapping, FilePathMapping
from dstack._internal.core.models.profiles import (
CreationPolicy,
@@ -38,6 +42,7 @@
RunPlan,
RunSpec,
RunStatus,
+ get_service_port,
)
from dstack._internal.core.models.runs import Run as RunModel
from dstack._internal.core.services.logs import URLReplacer
@@ -163,7 +168,7 @@ def ws_thread():
service_port = 443 if secure else 80
ports = {
**ports,
- self._run.run_spec.configuration.port.container_port: service_port,
+ get_or_error(get_or_error(self._ssh_attach).service_port): service_port,
}
path_prefix = url.path
replace_urls = URLReplacer(
@@ -338,6 +343,10 @@ def attach(
else:
container_user = "root"
+ service_port = None
+ if isinstance(self._run.run_spec.configuration, ServiceConfiguration):
+ service_port = get_service_port(job.job_spec, self._run.run_spec.configuration)
+
self._ssh_attach = SSHAttach(
hostname=provisioning_data.hostname,
ssh_port=provisioning_data.ssh_port,
@@ -349,6 +358,7 @@ def attach(
run_name=name,
dockerized=provisioning_data.dockerized,
ssh_proxy=provisioning_data.ssh_proxy,
+ service_port=service_port,
local_backend=provisioning_data.backend == BackendType.LOCAL,
bind_address=bind_address,
)
@@ -748,6 +758,7 @@ def list(self, all: bool = False, limit: Optional[int] = None) -> List[Run]:
repo_id=None,
only_active=only_active,
limit=limit or 100,
+ # TODO: Pass job_submissions_limit=1 in 0.20
)
if only_active and len(runs) == 0:
runs = self._api_client.runs.list(
diff --git a/src/dstack/api/server/_runs.py b/src/dstack/api/server/_runs.py
index 2c85792eb..745ce9c78 100644
--- a/src/dstack/api/server/_runs.py
+++ b/src/dstack/api/server/_runs.py
@@ -4,7 +4,11 @@
from pydantic import parse_obj_as
-from dstack._internal.core.compatibility.runs import get_apply_plan_excludes, get_get_plan_excludes
+from dstack._internal.core.compatibility.runs import (
+ get_apply_plan_excludes,
+ get_get_plan_excludes,
+ get_list_runs_excludes,
+)
from dstack._internal.core.models.runs import (
ApplyRunPlanInput,
Run,
@@ -33,18 +37,24 @@ def list(
prev_run_id: Optional[UUID] = None,
limit: int = 100,
ascending: bool = False,
+ include_jobs: bool = True,
+ job_submissions_limit: Optional[int] = None,
) -> List[Run]:
body = ListRunsRequest(
project_name=project_name,
repo_id=repo_id,
username=username,
only_active=only_active,
+ include_jobs=include_jobs,
+ job_submissions_limit=job_submissions_limit,
prev_submitted_at=prev_submitted_at,
prev_run_id=prev_run_id,
limit=limit,
ascending=ascending,
)
- resp = self._request("/api/runs/list", body=body.json())
+ resp = self._request(
+ "/api/runs/list", body=body.json(exclude=get_list_runs_excludes(body))
+ )
return parse_obj_as(List[Run.__response__], resp.json())
def get(self, project_name: str, run_name: str) -> Run:
diff --git a/src/tests/_internal/core/models/test_runs.py b/src/tests/_internal/core/models/test_runs.py
index 851cba9e3..23b27c018 100644
--- a/src/tests/_internal/core/models/test_runs.py
+++ b/src/tests/_internal/core/models/test_runs.py
@@ -1,9 +1,7 @@
from dstack._internal.core.models.profiles import RetryEvent
from dstack._internal.core.models.runs import (
JobStatus,
- JobSubmission,
JobTerminationReason,
- Run,
RunStatus,
RunTerminationReason,
)
@@ -33,8 +31,9 @@ def test_job_termination_reason_to_retry_event_works_with_all_enum_variants():
assert retry_event is None or isinstance(retry_event, RetryEvent)
-# Will fail if JobTerminationReason value is added without updaing JobSubmission._get_error
+# Will fail if JobTerminationReason value is added without updating JobSubmission._get_error
def test_get_error_returns_expected_messages():
+ # already handled and shown in status_message
no_error_reasons = [
JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY,
JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY,
@@ -47,7 +46,7 @@ def test_get_error_returns_expected_messages():
]
for reason in JobTerminationReason:
- if JobSubmission._get_error(reason) is None:
+ if reason.to_error() is None:
# Fail if a no-error reason is not in the list
assert reason in no_error_reasons
@@ -62,6 +61,6 @@ def test_run_get_error_returns_none_for_specific_reasons():
]
for reason in RunTerminationReason:
- if Run._get_error(reason) is None:
+ if reason.to_error() is None:
# Fail if a no-error reason is not in the list
assert reason in no_error_reasons
diff --git a/src/tests/_internal/server/background/tasks/test_process_gateways.py b/src/tests/_internal/server/background/tasks/test_process_gateways.py
index 159547af4..3460f18cb 100644
--- a/src/tests/_internal/server/background/tasks/test_process_gateways.py
+++ b/src/tests/_internal/server/background/tasks/test_process_gateways.py
@@ -5,56 +5,48 @@
from dstack._internal.core.errors import BackendError
from dstack._internal.core.models.gateways import GatewayProvisioningData, GatewayStatus
-from dstack._internal.server.background.tasks.process_gateways import process_submitted_gateways
+from dstack._internal.server.background.tasks.process_gateways import process_gateways
from dstack._internal.server.testing.common import (
AsyncContextManager,
ComputeMockSpec,
create_backend,
create_gateway,
+ create_gateway_compute,
create_project,
)
+@pytest.mark.asyncio
+@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
class TestProcessSubmittedGateways:
- @pytest.mark.asyncio
- @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
- async def test_provisions_gateway(self, test_db, session: AsyncSession):
+ async def test_submitted_to_provisioning(self, test_db, session: AsyncSession):
project = await create_project(session=session)
backend = await create_backend(session=session, project_id=project.id)
gateway = await create_gateway(
session=session,
project_id=project.id,
backend_id=backend.id,
+ status=GatewayStatus.SUBMITTED,
)
- with (
- patch(
- "dstack._internal.server.services.backends.get_project_backend_with_model_by_type_or_error"
- ) as m,
- patch(
- "dstack._internal.server.services.gateways.gateway_connections_pool.get_or_add"
- ) as pool_add,
- ):
+ with patch(
+ "dstack._internal.server.services.backends.get_project_backend_with_model_by_type_or_error"
+ ) as m:
aws = Mock()
m.return_value = (backend, aws)
- pool_add.return_value = MagicMock()
- pool_add.return_value.client.return_value = MagicMock(AsyncContextManager())
aws.compute.return_value = Mock(spec=ComputeMockSpec)
aws.compute.return_value.create_gateway.return_value = GatewayProvisioningData(
instance_id="i-1234567890",
ip_address="2.2.2.2",
region="us",
)
- await process_submitted_gateways()
+ await process_gateways()
m.assert_called_once()
aws.compute.return_value.create_gateway.assert_called_once()
- pool_add.assert_called_once()
await session.refresh(gateway)
- assert gateway.status == GatewayStatus.RUNNING
+ assert gateway.status == GatewayStatus.PROVISIONING
assert gateway.gateway_compute is not None
assert gateway.gateway_compute.ip_address == "2.2.2.2"
- @pytest.mark.asyncio
- @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_marks_gateway_as_failed_if_gateway_creation_errors(
self, test_db, session: AsyncSession
):
@@ -64,6 +56,7 @@ async def test_marks_gateway_as_failed_if_gateway_creation_errors(
session=session,
project_id=project.id,
backend_id=backend.id,
+ status=GatewayStatus.SUBMITTED,
)
with patch(
"dstack._internal.server.services.backends.get_project_backend_with_model_by_type_or_error"
@@ -72,47 +65,57 @@ async def test_marks_gateway_as_failed_if_gateway_creation_errors(
m.return_value = (backend, aws)
aws.compute.return_value = Mock(spec=ComputeMockSpec)
aws.compute.return_value.create_gateway.side_effect = BackendError("Some error")
- await process_submitted_gateways()
+ await process_gateways()
m.assert_called_once()
aws.compute.return_value.create_gateway.assert_called_once()
await session.refresh(gateway)
assert gateway.status == GatewayStatus.FAILED
assert gateway.status_message == "Some error"
- @pytest.mark.asyncio
- @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+class TestProcessProvisioningGateways:
+ async def test_provisioning_to_running(self, test_db, session: AsyncSession):
+ project = await create_project(session=session)
+ backend = await create_backend(session=session, project_id=project.id)
+ gateway_compute = await create_gateway_compute(session)
+ gateway = await create_gateway(
+ session=session,
+ project_id=project.id,
+ backend_id=backend.id,
+ gateway_compute_id=gateway_compute.id,
+ status=GatewayStatus.PROVISIONING,
+ )
+ with patch(
+ "dstack._internal.server.services.gateways.gateway_connections_pool.get_or_add"
+ ) as pool_add:
+ pool_add.return_value = MagicMock()
+ pool_add.return_value.client.return_value = MagicMock(AsyncContextManager())
+ await process_gateways()
+ pool_add.assert_called_once()
+ await session.refresh(gateway)
+ assert gateway.status == GatewayStatus.RUNNING
+
async def test_marks_gateway_as_failed_if_fails_to_connect(
self, test_db, session: AsyncSession
):
project = await create_project(session=session)
backend = await create_backend(session=session, project_id=project.id)
+ gateway_compute = await create_gateway_compute(session)
gateway = await create_gateway(
session=session,
project_id=project.id,
backend_id=backend.id,
+ gateway_compute_id=gateway_compute.id,
+ status=GatewayStatus.PROVISIONING,
)
- with (
- patch(
- "dstack._internal.server.services.backends.get_project_backend_with_model_by_type_or_error"
- ) as m,
- patch(
- "dstack._internal.server.services.gateways.connect_to_gateway_with_retry"
- ) as connect_to_gateway_with_retry_mock,
- ):
- aws = Mock()
- m.return_value = (backend, aws)
+ with patch(
+ "dstack._internal.server.services.gateways.connect_to_gateway_with_retry"
+ ) as connect_to_gateway_with_retry_mock:
connect_to_gateway_with_retry_mock.return_value = None
- aws.compute.return_value = Mock(spec=ComputeMockSpec)
- aws.compute.return_value.create_gateway.return_value = GatewayProvisioningData(
- instance_id="i-1234567890",
- ip_address="2.2.2.2",
- region="us",
- )
- await process_submitted_gateways()
- m.assert_called_once()
- aws.compute.return_value.create_gateway.assert_called_once()
+ await process_gateways()
connect_to_gateway_with_retry_mock.assert_called_once()
await session.refresh(gateway)
assert gateway.status == GatewayStatus.FAILED
- assert gateway.gateway_compute is not None
- assert gateway.gateway_compute is not None
+ assert gateway.status_message == "Failed to connect to gateway"
diff --git a/src/tests/_internal/server/background/tasks/test_process_idle_volumes.py b/src/tests/_internal/server/background/tasks/test_process_idle_volumes.py
new file mode 100644
index 000000000..8f6621186
--- /dev/null
+++ b/src/tests/_internal/server/background/tasks/test_process_idle_volumes.py
@@ -0,0 +1,190 @@
+import datetime
+from unittest.mock import Mock, patch
+
+import pytest
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from dstack._internal.core.models.backends.base import BackendType
+from dstack._internal.core.models.volumes import VolumeStatus
+from dstack._internal.server.background.tasks.process_idle_volumes import (
+ _get_idle_time,
+ _should_delete_volume,
+ process_idle_volumes,
+)
+from dstack._internal.server.models import VolumeAttachmentModel
+from dstack._internal.server.testing.common import (
+ ComputeMockSpec,
+ create_instance,
+ create_project,
+ create_user,
+ create_volume,
+ get_volume_configuration,
+ get_volume_provisioning_data,
+)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+class TestProcessIdleVolumes:
+ async def test_deletes_idle_volumes(self, test_db, session: AsyncSession):
+ project = await create_project(session=session)
+ user = await create_user(session=session)
+
+ config1 = get_volume_configuration(
+ name="test-volume",
+ auto_cleanup_duration="1h",
+ )
+ config2 = get_volume_configuration(
+ name="test-volume",
+ auto_cleanup_duration="3h",
+ )
+ volume1 = await create_volume(
+ session=session,
+ project=project,
+ user=user,
+ status=VolumeStatus.ACTIVE,
+ backend=BackendType.AWS,
+ configuration=config1,
+ volume_provisioning_data=get_volume_provisioning_data(),
+ last_job_processed_at=datetime.datetime.now(datetime.timezone.utc)
+ - datetime.timedelta(hours=2),
+ )
+ volume2 = await create_volume(
+ session=session,
+ project=project,
+ user=user,
+ status=VolumeStatus.ACTIVE,
+ backend=BackendType.AWS,
+ configuration=config2,
+ volume_provisioning_data=get_volume_provisioning_data(),
+ last_job_processed_at=datetime.datetime.now(datetime.timezone.utc)
+ - datetime.timedelta(hours=2),
+ )
+ await session.commit()
+
+ with patch(
+ "dstack._internal.server.services.backends.get_project_backend_by_type_or_error"
+ ) as m:
+ aws_mock = Mock()
+ m.return_value = aws_mock
+ aws_mock.compute.return_value = Mock(spec=ComputeMockSpec)
+ await process_idle_volumes()
+
+ await session.refresh(volume1)
+ await session.refresh(volume2)
+ assert volume1.deleted
+ assert volume1.deleted_at is not None
+ assert not volume2.deleted
+ assert volume2.deleted_at is None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+class TestShouldDeleteVolume:
+ async def test_no_idle_duration(self, test_db, session: AsyncSession):
+ project = await create_project(session=session)
+ user = await create_user(session=session)
+
+ volume = await create_volume(
+ session=session,
+ project=project,
+ user=user,
+ status=VolumeStatus.ACTIVE,
+ backend=BackendType.AWS,
+ configuration=get_volume_configuration(name="test-volume"),
+ volume_provisioning_data=get_volume_provisioning_data(),
+ )
+
+ assert not _should_delete_volume(volume)
+
+ async def test_idle_duration_disabled(self, test_db, session: AsyncSession):
+ project = await create_project(session=session)
+ user = await create_user(session=session)
+
+ config = get_volume_configuration(name="test-volume")
+ config.auto_cleanup_duration = -1
+
+ volume = await create_volume(
+ session=session,
+ project=project,
+ user=user,
+ status=VolumeStatus.ACTIVE,
+ backend=BackendType.AWS,
+ configuration=config,
+ volume_provisioning_data=get_volume_provisioning_data(),
+ )
+
+ assert not _should_delete_volume(volume)
+
+ async def test_volume_attached(self, test_db, session: AsyncSession):
+ project = await create_project(session=session)
+ user = await create_user(session=session)
+
+ config = get_volume_configuration(name="test-volume")
+ config.auto_cleanup_duration = "1h"
+
+ volume = await create_volume(
+ session=session,
+ project=project,
+ user=user,
+ status=VolumeStatus.ACTIVE,
+ backend=BackendType.AWS,
+ configuration=config,
+ volume_provisioning_data=get_volume_provisioning_data(),
+ )
+
+ instance = await create_instance(session=session, project=project)
+ volume.attachments.append(
+ VolumeAttachmentModel(volume_id=volume.id, instance_id=instance.id)
+ )
+ await session.commit()
+
+ assert not _should_delete_volume(volume)
+
+ async def test_idle_duration_threshold(self, test_db, session: AsyncSession):
+ project = await create_project(session=session)
+ user = await create_user(session=session)
+
+ config = get_volume_configuration(name="test-volume")
+ config.auto_cleanup_duration = "1h"
+
+ volume = await create_volume(
+ session=session,
+ project=project,
+ user=user,
+ status=VolumeStatus.ACTIVE,
+ backend=BackendType.AWS,
+ configuration=config,
+ volume_provisioning_data=get_volume_provisioning_data(),
+ )
+
+ # Not exceeded - 30 minutes ago
+ volume.last_job_processed_at = (
+ datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(minutes=30)
+ ).replace(tzinfo=None)
+ assert not _should_delete_volume(volume)
+
+ # Exceeded - 2 hours ago
+ volume.last_job_processed_at = (
+ datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=2)
+ ).replace(tzinfo=None)
+ assert _should_delete_volume(volume)
+
+ async def test_never_used_volume(self, test_db, session: AsyncSession):
+ project = await create_project(session=session)
+ user = await create_user(session=session)
+
+ volume = await create_volume(
+ session=session,
+ project=project,
+ user=user,
+ status=VolumeStatus.ACTIVE,
+ backend=BackendType.AWS,
+ configuration=get_volume_configuration(name="test-volume"),
+ volume_provisioning_data=get_volume_provisioning_data(),
+ created_at=datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=2),
+ )
+
+ volume.last_job_processed_at = None
+ idle_time = _get_idle_time(volume)
+ assert idle_time.total_seconds() >= 7000
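
The tests above pin down the idle-volume semantics: idle time is counted from `last_job_processed_at`, falling back to the volume's creation time when it was never used, and auto-cleanup only applies to detached, dstack-managed volumes with a positive `auto_cleanup_duration`. A rough sketch of that check, inferred from the tests rather than copied from the implementation (durations here are plain seconds for simplicity):

```python
from datetime import datetime, timedelta, timezone
from typing import Optional


def get_idle_time(
    last_job_processed_at: Optional[datetime], created_at: datetime
) -> timedelta:
    # Never-used volumes are idle since creation (see test_never_used_volume)
    reference = last_job_processed_at or created_at
    return datetime.now(timezone.utc) - reference


def should_delete(
    idle_time: timedelta,
    auto_cleanup_duration: Optional[int],  # seconds; None or <= 0 means disabled
    attached: bool,
) -> bool:
    if attached or auto_cleanup_duration is None or auto_cleanup_duration <= 0:
        return False
    return idle_time.total_seconds() > auto_cleanup_duration
```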
diff --git a/src/tests/_internal/server/routers/test_fleets.py b/src/tests/_internal/server/routers/test_fleets.py
index 87a970c73..a0ae8ce72 100644
--- a/src/tests/_internal/server/routers/test_fleets.py
+++ b/src/tests/_internal/server/routers/test_fleets.py
@@ -10,7 +10,12 @@
from sqlalchemy.ext.asyncio import AsyncSession
from dstack._internal.core.models.backends.base import BackendType
-from dstack._internal.core.models.fleets import FleetConfiguration, FleetStatus, SSHParams
+from dstack._internal.core.models.fleets import (
+ FleetConfiguration,
+ FleetStatus,
+ InstanceGroupPlacement,
+ SSHParams,
+)
from dstack._internal.core.models.instances import (
InstanceAvailability,
InstanceOfferWithAvailability,
@@ -21,6 +26,7 @@
)
from dstack._internal.core.models.users import GlobalRole, ProjectRole
from dstack._internal.server.models import FleetModel, InstanceModel
+from dstack._internal.server.services.fleets import fleet_model_to_fleet
from dstack._internal.server.services.permissions import DefaultPermissions
from dstack._internal.server.services.projects import add_project_member
from dstack._internal.server.testing.common import (
@@ -35,7 +41,11 @@
get_auth_headers,
get_fleet_configuration,
get_fleet_spec,
+ get_instance_offer_with_availability,
+ get_job_provisioning_data,
get_private_key_string,
+ get_remote_connection_info,
+ get_ssh_fleet_configuration,
)
pytestmark = pytest.mark.usefixtures("image_config_mock")
@@ -415,17 +425,14 @@ async def test_creates_ssh_fleet(self, test_db, session: AsyncSession, client: A
await add_project_member(
session=session, project=project, user=user, project_role=ProjectRole.USER
)
- spec = get_fleet_spec(
- conf=FleetConfiguration(
- name="test-ssh-fleet",
- ssh_config=SSHParams(
- user="ubuntu",
- ssh_key=SSHKey(public="", private=get_private_key_string()),
- hosts=["1.1.1.1"],
- network=None,
- ),
- )
+ conf = get_ssh_fleet_configuration(
+ name="test-ssh-fleet",
+ user="ubuntu",
+ ssh_key=SSHKey(public="", private=get_private_key_string()),
+ hosts=["1.1.1.1"],
+ network=None,
)
+ spec = get_fleet_spec(conf=conf)
with patch("uuid.uuid4") as m:
m.return_value = UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e")
response = await client.post(
@@ -541,6 +548,212 @@ async def test_creates_ssh_fleet(self, test_db, session: AsyncSession, client: A
instance = res.unique().scalar_one()
assert instance.remote_connection_info is not None
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+ @freeze_time(datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), real_asyncio=True)
+ async def test_updates_ssh_fleet(self, test_db, session: AsyncSession, client: AsyncClient):
+ user = await create_user(session, global_role=GlobalRole.USER)
+ project = await create_project(session)
+ await add_project_member(
+ session=session, project=project, user=user, project_role=ProjectRole.USER
+ )
+ current_conf = get_ssh_fleet_configuration(
+ name="test-ssh-fleet",
+ user="ubuntu",
+ ssh_key=SSHKey(public="", private=get_private_key_string()),
+ hosts=["10.0.0.100"],
+ network=None,
+ )
+ current_spec = get_fleet_spec(conf=current_conf)
+ spec = current_spec.copy(deep=True)
+ # 10.0.0.100 removed, 10.0.0.101 added
+ spec.configuration.ssh_config.hosts = ["10.0.0.101"]
+
+ fleet = await create_fleet(session=session, project=project, spec=current_spec)
+ instance_type = InstanceType(
+ name="ssh",
+ resources=Resources(cpus=2, memory_mib=8, gpus=[], spot=False),
+ )
+ instance = await create_instance(
+ session=session,
+ project=project,
+ fleet=fleet,
+ backend=BackendType.REMOTE,
+ name="test-ssh-fleet-0",
+ region="remote",
+ price=0.0,
+ status=InstanceStatus.IDLE,
+ offer=get_instance_offer_with_availability(
+ backend=BackendType.REMOTE,
+ region="remote",
+ price=0.0,
+ ),
+ job_provisioning_data=get_job_provisioning_data(
+ instance_type=instance_type,
+ hostname="10.0.0.100",
+ ),
+ remote_connection_info=get_remote_connection_info(host="10.0.0.100"),
+ )
+
+ with patch("uuid.uuid4") as m:
+ m.return_value = UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e")
+ response = await client.post(
+ f"/api/project/{project.name}/fleets/apply",
+ headers=get_auth_headers(user.token),
+ json={
+ "plan": {
+ "spec": spec.dict(),
+ "current_resource": _fleet_model_to_json_dict(fleet),
+ },
+ "force": False,
+ },
+ )
+
+ assert response.status_code == 200, response.json()
+ assert response.json() == {
+ "id": str(fleet.id),
+ "name": spec.configuration.name,
+ "project_name": project.name,
+ "spec": {
+ "configuration_path": spec.configuration_path,
+ "configuration": {
+ "env": {},
+ "ssh_config": {
+ "user": "ubuntu",
+ "port": None,
+ "identity_file": None,
+ "ssh_key": None, # should not return ssh_key
+ "proxy_jump": None,
+ "hosts": ["10.0.0.101"],
+ "network": None,
+ },
+ "nodes": None,
+ "placement": None,
+ "resources": {
+ "cpu": {"min": 2, "max": None},
+ "memory": {"min": 8.0, "max": None},
+ "shm_size": None,
+ "gpu": None,
+ "disk": {"size": {"min": 100.0, "max": None}},
+ },
+ "backends": None,
+ "regions": None,
+ "availability_zones": None,
+ "instance_types": None,
+ "spot_policy": None,
+ "retry": None,
+ "max_price": None,
+ "idle_duration": None,
+ "type": "fleet",
+ "name": spec.configuration.name,
+ "reservation": None,
+ "blocks": 1,
+ "tags": None,
+ },
+ "profile": {
+ "backends": None,
+ "regions": None,
+ "availability_zones": None,
+ "instance_types": None,
+ "spot_policy": None,
+ "retry": None,
+ "max_duration": None,
+ "stop_duration": None,
+ "max_price": None,
+ "creation_policy": None,
+ "idle_duration": None,
+ "utilization_policy": None,
+ "startup_order": None,
+ "stop_criteria": None,
+ "name": "",
+ "default": False,
+ "reservation": None,
+ "fleets": None,
+ "tags": None,
+ },
+ "autocreated": False,
+ },
+ "created_at": "2023-01-02T03:04:00+00:00",
+ "status": "active",
+ "status_message": None,
+ "instances": [
+ {
+ "id": str(instance.id),
+ "project_name": project.name,
+ "backend": "remote",
+ "instance_type": {
+ "name": "ssh",
+ "resources": {
+ "cpu_arch": None,
+ "cpus": 2,
+ "memory_mib": 8,
+ "gpus": [],
+ "spot": False,
+ "disk": {"size_mib": 102400},
+ "description": "cpu=2 mem=0GB disk=100GB",
+ },
+ },
+ "name": "test-ssh-fleet-0",
+ "fleet_id": str(fleet.id),
+ "fleet_name": "test-ssh-fleet",
+ "instance_num": 0,
+ "job_name": None,
+ "hostname": "10.0.0.100",
+ "status": "terminating",
+ "unreachable": False,
+ "termination_reason": None,
+ "created": "2023-01-02T03:04:00+00:00",
+ "region": "remote",
+ "availability_zone": None,
+ "price": 0.0,
+ "total_blocks": 1,
+ "busy_blocks": 0,
+ },
+ {
+ "id": "1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e",
+ "project_name": project.name,
+ "backend": "remote",
+ "instance_type": {
+ "name": "ssh",
+ "resources": {
+ "cpu_arch": None,
+ "cpus": 2,
+ "memory_mib": 8,
+ "gpus": [],
+ "spot": False,
+ "disk": {"size_mib": 102400},
+ "description": "cpu=2 mem=0GB disk=100GB",
+ },
+ },
+ "name": "test-ssh-fleet-1",
+ "fleet_id": str(fleet.id),
+ "fleet_name": "test-ssh-fleet",
+ "instance_num": 1,
+ "job_name": None,
+ "hostname": "10.0.0.101",
+ "status": "pending",
+ "unreachable": False,
+ "termination_reason": None,
+ "created": "2023-01-02T03:04:00+00:00",
+ "region": "remote",
+ "availability_zone": None,
+ "price": 0.0,
+ "total_blocks": 1,
+ "busy_blocks": 0,
+ },
+ ],
+ }
+ res = await session.execute(select(FleetModel))
+ assert res.scalar_one()
+ await session.refresh(instance)
+ assert instance.status == InstanceStatus.TERMINATING
+ res = await session.execute(
+ select(InstanceModel).where(InstanceModel.id == "1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e")
+ )
+ instance = res.unique().scalar_one()
+ assert instance.status == InstanceStatus.PENDING
+ assert instance.remote_connection_info is not None
+
@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
@freeze_time(datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc))
@@ -820,7 +1033,9 @@ async def test_returns_40x_if_not_authenticated(
@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
- async def test_returns_plan(self, test_db, session: AsyncSession, client: AsyncClient):
+ async def test_returns_create_plan_for_new_fleet(
+ self, test_db, session: AsyncSession, client: AsyncClient
+ ):
user = await create_user(session=session, global_role=GlobalRole.USER)
project = await create_project(session=session, owner=user)
await add_project_member(
@@ -855,10 +1070,91 @@ async def test_returns_plan(self, test_db, session: AsyncSession, client: AsyncC
assert response.json() == {
"project_name": project.name,
"user": user.name,
- "spec": spec.dict(),
- "effective_spec": spec.dict(),
+ "spec": json.loads(spec.json()),
+ "effective_spec": json.loads(spec.json()),
"current_resource": None,
"offers": [json.loads(o.json()) for o in offers],
"total_offers": len(offers),
"max_offer_price": 1.0,
+ "action": "create",
}
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+ async def test_returns_update_plan_for_existing_fleet(
+ self, test_db, session: AsyncSession, client: AsyncClient
+ ):
+ user = await create_user(session=session, global_role=GlobalRole.USER)
+ project = await create_project(session=session, owner=user)
+ await add_project_member(
+ session=session, project=project, user=user, project_role=ProjectRole.USER
+ )
+ conf = get_ssh_fleet_configuration(hosts=["10.0.0.100"])
+ spec = get_fleet_spec(conf=conf)
+ effective_spec = spec.copy(deep=True)
+ effective_spec.configuration.ssh_config.ssh_key = None
+ current_spec = spec.copy(deep=True)
+ # `hosts` can be updated in-place
+ current_spec.configuration.ssh_config.hosts = ["10.0.0.100", "10.0.0.101"]
+ fleet = await create_fleet(session=session, project=project, spec=current_spec)
+
+ response = await client.post(
+ f"/api/project/{project.name}/fleets/get_plan",
+ headers=get_auth_headers(user.token),
+ json={"spec": spec.dict()},
+ )
+
+ assert response.status_code == 200
+ assert response.json() == {
+ "project_name": project.name,
+ "user": user.name,
+ "spec": spec.dict(),
+ "effective_spec": effective_spec.dict(),
+ "current_resource": _fleet_model_to_json_dict(fleet),
+ "offers": [],
+ "total_offers": 0,
+ "max_offer_price": None,
+ "action": "update",
+ }
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+ async def test_returns_create_plan_for_existing_fleet(
+ self, test_db, session: AsyncSession, client: AsyncClient
+ ):
+ user = await create_user(session=session, global_role=GlobalRole.USER)
+ project = await create_project(session=session, owner=user)
+ await add_project_member(
+ session=session, project=project, user=user, project_role=ProjectRole.USER
+ )
+ conf = get_ssh_fleet_configuration(placement=InstanceGroupPlacement.ANY)
+ spec = get_fleet_spec(conf=conf)
+ effective_spec = spec.copy(deep=True)
+ effective_spec.configuration.ssh_config.ssh_key = None
+ current_spec = spec.copy(deep=True)
+ # `placement` cannot be updated in-place
+ current_spec.configuration.placement = InstanceGroupPlacement.CLUSTER
+ fleet = await create_fleet(session=session, project=project, spec=current_spec)
+
+ response = await client.post(
+ f"/api/project/{project.name}/fleets/get_plan",
+ headers=get_auth_headers(user.token),
+ json={"spec": spec.dict()},
+ )
+
+ assert response.status_code == 200
+ assert response.json() == {
+ "project_name": project.name,
+ "user": user.name,
+ "spec": spec.dict(),
+ "effective_spec": effective_spec.dict(),
+ "current_resource": _fleet_model_to_json_dict(fleet),
+ "offers": [],
+ "total_offers": 0,
+ "max_offer_price": None,
+ "action": "create",
+ }
+
+
+def _fleet_model_to_json_dict(fleet: FleetModel) -> dict:
+ return json.loads(fleet_model_to_fleet(fleet).json())
diff --git a/src/tests/_internal/server/routers/test_logs.py b/src/tests/_internal/server/routers/test_logs.py
index 11f0da8da..0364edee4 100644
--- a/src/tests/_internal/server/routers/test_logs.py
+++ b/src/tests/_internal/server/routers/test_logs.py
@@ -62,17 +62,17 @@ async def test_returns_logs(
{
"timestamp": "2023-10-06T10:01:53.234234+00:00",
"log_source": "stdout",
- "message": "Hello",
+ "message": "SGVsbG8=",
},
{
"timestamp": "2023-10-06T10:01:53.234235+00:00",
"log_source": "stdout",
- "message": "World",
+ "message": "V29ybGQ=",
},
{
"timestamp": "2023-10-06T10:01:53.234236+00:00",
"log_source": "stdout",
- "message": "!",
+ "message": "IQ==",
},
],
"next_token": None,
@@ -93,7 +93,7 @@ async def test_returns_logs(
{
"timestamp": "2023-10-06T10:01:53.234236+00:00",
"log_source": "stdout",
- "message": "!",
+ "message": "IQ==",
},
],
"next_token": None,
diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py
index 96adb6499..4ec0c4225 100644
--- a/src/tests/_internal/server/routers/test_runs.py
+++ b/src/tests/_internal/server/routers/test_runs.py
@@ -246,6 +246,7 @@ def get_dev_env_run_plan_dict(
"repo_code_hash": None,
"repo_data": {"repo_dir": "/repo", "repo_type": "local"},
"file_archives": [],
+ "service_port": None,
},
"offers": [json.loads(o.json()) for o in offers],
"total_offers": total_offers,
@@ -441,6 +442,7 @@ def get_dev_env_run_dict(
"repo_code_hash": None,
"repo_data": {"repo_dir": "/repo", "repo_type": "local"},
"file_archives": [],
+ "service_port": None,
},
"job_submissions": [
{
@@ -707,6 +709,108 @@ async def test_lists_runs_pagination(
assert len(response2_json) == 1
assert response2_json[0]["id"] == str(run2.id)
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+ async def test_limits_job_submissions(
+ self, test_db, session: AsyncSession, client: AsyncClient
+ ):
+ user = await create_user(session=session, global_role=GlobalRole.USER)
+ project = await create_project(session=session, owner=user)
+ await add_project_member(
+ session=session, project=project, user=user, project_role=ProjectRole.USER
+ )
+ repo = await create_repo(
+ session=session,
+ project_id=project.id,
+ )
+ run_submitted_at = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc)
+ run = await create_run(
+ session=session,
+ project=project,
+ repo=repo,
+ user=user,
+ submitted_at=run_submitted_at,
+ )
+ run_spec = RunSpec.parse_raw(run.run_spec)
+ await create_job(
+ session=session,
+ run=run,
+ submitted_at=run_submitted_at,
+ last_processed_at=run_submitted_at,
+ )
+ job2 = await create_job(
+ session=session,
+ run=run,
+ submitted_at=run_submitted_at,
+ last_processed_at=run_submitted_at,
+ )
+ job2_spec = JobSpec.parse_raw(job2.job_spec_data)
+ response = await client.post(
+ "/api/runs/list",
+ headers=get_auth_headers(user.token),
+ json={"job_submissions_limit": 1},
+ )
+ assert response.status_code == 200, response.json()
+ assert response.json() == [
+ {
+ "id": str(run.id),
+ "project_name": project.name,
+ "user": user.name,
+ "submitted_at": run_submitted_at.isoformat(),
+ "last_processed_at": run_submitted_at.isoformat(),
+ "status": "submitted",
+ "status_message": "submitted",
+ "run_spec": run_spec.dict(),
+ "jobs": [
+ {
+ "job_spec": job2_spec.dict(),
+ "job_submissions": [
+ {
+ "id": str(job2.id),
+ "submission_num": 0,
+ "deployment_num": 0,
+ "submitted_at": run_submitted_at.isoformat(),
+ "last_processed_at": run_submitted_at.isoformat(),
+ "finished_at": None,
+ "inactivity_secs": None,
+ "status": "submitted",
+ "status_message": "submitted",
+ "termination_reason": None,
+ "termination_reason_message": None,
+ "error": None,
+ "exit_status": None,
+ "job_provisioning_data": None,
+ "job_runtime_data": None,
+ }
+ ],
+ }
+ ],
+ "latest_job_submission": {
+ "id": str(job2.id),
+ "submission_num": 0,
+ "deployment_num": 0,
+ "submitted_at": run_submitted_at.isoformat(),
+ "last_processed_at": run_submitted_at.isoformat(),
+ "finished_at": None,
+ "inactivity_secs": None,
+ "status": "submitted",
+ "status_message": "submitted",
+ "termination_reason_message": None,
+ "termination_reason": None,
+ "error": None,
+ "exit_status": None,
+ "job_provisioning_data": None,
+ "job_runtime_data": None,
+ },
+ "cost": 0,
+ "service": None,
+ "deployment_num": 0,
+ "termination_reason": None,
+ "error": None,
+ "deleted": False,
+ },
+ ]
+
class TestGetRun:
@pytest.mark.asyncio
@@ -1074,12 +1178,14 @@ async def test_returns_run_plan_instance_volumes(
ServiceConfiguration(
commands=["one", "two"],
port=80,
+ gateway=None,
replicas=1,
scaling=None,
),
ServiceConfiguration(
commands=["one", "two"],
- port=8080, # not updatable
+ port=8080,
+ gateway="test-gateway", # not updatable
replicas="2..4",
scaling=ScalingSpec(metric="rps", target=5),
),
diff --git a/src/tests/_internal/server/routers/test_users.py b/src/tests/_internal/server/routers/test_users.py
index d19290749..df4a8f3c0 100644
--- a/src/tests/_internal/server/routers/test_users.py
+++ b/src/tests/_internal/server/routers/test_users.py
@@ -22,19 +22,71 @@ async def test_returns_40x_if_not_authenticated(self, test_db, client: AsyncClie
@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
- async def test_returns_users(self, test_db, session: AsyncSession, client: AsyncClient):
- user = await create_user(
+ async def test_admins_see_all_users(self, test_db, session: AsyncSession, client: AsyncClient):
+ admin = await create_user(
+ session=session,
+ name="admin",
+ created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc),
+ global_role=GlobalRole.ADMIN,
+ )
+ other_user = await create_user(
+ session=session,
+ name="other_user",
+ created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc),
+ global_role=GlobalRole.USER,
+ )
+ response = await client.post("/api/users/list", headers=get_auth_headers(admin.token))
+ assert response.status_code in [200]
+ assert response.json() == [
+ {
+ "id": str(admin.id),
+ "username": admin.name,
+ "created_at": "2023-01-02T03:04:00+00:00",
+ "global_role": admin.global_role,
+ "email": None,
+ "active": True,
+ "permissions": {
+ "can_create_projects": True,
+ },
+ },
+ {
+ "id": str(other_user.id),
+ "username": other_user.name,
+ "created_at": "2023-01-02T03:04:00+00:00",
+ "global_role": other_user.global_role,
+ "email": None,
+ "active": True,
+ "permissions": {
+ "can_create_projects": True,
+ },
+ },
+ ]
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+ async def test_non_admins_see_only_themselves(
+ self, test_db, session: AsyncSession, client: AsyncClient
+ ):
+ await create_user(
+ session=session,
+ name="admin",
+ created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc),
+ global_role=GlobalRole.ADMIN,
+ )
+ other_user = await create_user(
session=session,
+ name="other_user",
created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc),
+ global_role=GlobalRole.USER,
)
- response = await client.post("/api/users/list", headers=get_auth_headers(user.token))
+ response = await client.post("/api/users/list", headers=get_auth_headers(other_user.token))
assert response.status_code in [200]
assert response.json() == [
{
- "id": str(user.id),
- "username": user.name,
+ "id": str(other_user.id),
+ "username": other_user.name,
"created_at": "2023-01-02T03:04:00+00:00",
- "global_role": user.global_role,
+ "global_role": other_user.global_role,
"email": None,
"active": True,
"permissions": {
diff --git a/src/tests/_internal/server/services/test_logs.py b/src/tests/_internal/server/services/test_logs.py
index 19769a360..0b9420917 100644
--- a/src/tests/_internal/server/services/test_logs.py
+++ b/src/tests/_internal/server/services/test_logs.py
@@ -1,4 +1,3 @@
-import base64
import logging
from datetime import datetime, timedelta, timezone
from pathlib import Path
@@ -13,6 +12,7 @@
from pydantic import ValidationError
from sqlalchemy.ext.asyncio import AsyncSession
+from dstack._internal.core.errors import ServerClientError
from dstack._internal.core.models.logs import LogEvent, LogEventSource
from dstack._internal.server.models import ProjectModel
from dstack._internal.server.schemas.logs import PollLogsRequest
@@ -51,8 +51,8 @@ async def test_writes_logs(self, test_db, session: AsyncSession, tmp_path: Path)
/ "runner.log"
)
assert runner_log_path.read_text() == (
- '{"timestamp": "2023-10-06T10:01:53.234000+00:00", "log_source": "stdout", "message": "SGVsbG8="}\n'
- '{"timestamp": "2023-10-06T10:01:53.235000+00:00", "log_source": "stdout", "message": "V29ybGQ="}\n'
+ '{"timestamp":"2023-10-06T10:01:53.234000+00:00","log_source":"stdout","message":"Hello"}\n'
+ '{"timestamp":"2023-10-06T10:01:53.235000+00:00","log_source":"stdout","message":"World"}\n'
)
@pytest.mark.asyncio
@@ -119,12 +119,8 @@ async def test_poll_logs_with_next_token_pagination(
job_submission_logs = log_storage.poll_logs(project, poll_request)
assert len(job_submission_logs.logs) == 2
- assert job_submission_logs.logs[0].message == base64.b64encode(
- "Log1".encode("utf-8")
- ).decode("utf-8")
- assert job_submission_logs.logs[1].message == base64.b64encode(
- "Log2".encode("utf-8")
- ).decode("utf-8")
+ assert job_submission_logs.logs[0].message == "Log1"
+ assert job_submission_logs.logs[1].message == "Log2"
assert job_submission_logs.next_token == "2" # Next line to read
# Second page: use next_token
@@ -132,12 +128,8 @@ async def test_poll_logs_with_next_token_pagination(
job_submission_logs = log_storage.poll_logs(project, poll_request)
assert len(job_submission_logs.logs) == 2
- assert job_submission_logs.logs[0].message == base64.b64encode(
- "Log3".encode("utf-8")
- ).decode("utf-8")
- assert job_submission_logs.logs[1].message == base64.b64encode(
- "Log4".encode("utf-8")
- ).decode("utf-8")
+ assert job_submission_logs.logs[0].message == "Log3"
+ assert job_submission_logs.logs[1].message == "Log4"
assert job_submission_logs.next_token == "4" # Next line to read
# Third page: get remaining log
@@ -145,9 +137,7 @@ async def test_poll_logs_with_next_token_pagination(
job_submission_logs = log_storage.poll_logs(project, poll_request)
assert len(job_submission_logs.logs) == 1
- assert job_submission_logs.logs[0].message == base64.b64encode(
- "Log5".encode("utf-8")
- ).decode("utf-8")
+ assert job_submission_logs.logs[0].message == "Log5"
assert job_submission_logs.next_token is None # No more logs
@pytest.mark.asyncio
@@ -182,12 +172,8 @@ async def test_poll_logs_with_start_from_specific_line(
job_submission_logs = log_storage.poll_logs(project, poll_request)
assert len(job_submission_logs.logs) == 2
- assert job_submission_logs.logs[0].message == base64.b64encode(
- "Log2".encode("utf-8")
- ).decode("utf-8")
- assert job_submission_logs.logs[1].message == base64.b64encode(
- "Log3".encode("utf-8")
- ).decode("utf-8")
+ assert job_submission_logs.logs[0].message == "Log2"
+ assert job_submission_logs.logs[1].message == "Log3"
assert job_submission_logs.next_token is None
@pytest.mark.asyncio
@@ -206,49 +192,22 @@ async def test_poll_logs_invalid_next_token_raises_error(
limit=10,
diagnose=True,
)
- with pytest.raises(
- LogStorageError, match="Invalid next_token: invalid. Must be a valid integer."
- ):
+ with pytest.raises(ServerClientError):
log_storage.poll_logs(project, poll_request)
# Test with negative next_token
poll_request.next_token = "-1"
- with pytest.raises(
- LogStorageError, match="Invalid next_token: -1. Must be a non-negative integer."
- ):
+ with pytest.raises(ServerClientError):
log_storage.poll_logs(project, poll_request)
# Test with float next_token
poll_request.next_token = "1.5"
- with pytest.raises(
- LogStorageError, match="Invalid next_token: 1.5. Must be a valid integer."
- ):
+ with pytest.raises(ServerClientError):
log_storage.poll_logs(project, poll_request)
@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
- async def test_poll_logs_descending_raises_error(
- self, test_db, session: AsyncSession, tmp_path: Path
- ):
- project = await create_project(session=session)
- log_storage = FileLogStorage(tmp_path)
-
- # Test that descending=True raises LogStorageError
- poll_request = PollLogsRequest(
- run_name="test_run",
- job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"),
- limit=10,
- diagnose=True,
- # Note: This bypasses schema validation for testing the implementation
- )
- poll_request.descending = True # Set directly to bypass validation
-
- with pytest.raises(LogStorageError, match="descending: true is not supported"):
- log_storage.poll_logs(project, poll_request)
-
- @pytest.mark.asyncio
- @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
- async def test_poll_logs_file_not_found_raises_error(
+ async def test_poll_logs_file_not_found_raises_no_error(
self, test_db, session: AsyncSession, tmp_path: Path
):
project = await create_project(session=session)
@@ -261,11 +220,7 @@ async def test_poll_logs_file_not_found_raises_error(
limit=10,
diagnose=True,
)
-
- with pytest.raises(
- LogStorageError, match="Failed to read log file .* No such file or directory"
- ):
- log_storage.poll_logs(project, poll_request)
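+ # Polling a missing log file should no longer raise an error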
+ log_storage.poll_logs(project, poll_request)
@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
@@ -309,9 +264,7 @@ async def test_poll_logs_with_time_filtering_and_pagination(
# Should get Log3 first (timestamp > 235)
assert len(job_submission_logs.logs) == 1
- assert job_submission_logs.logs[0].message == base64.b64encode(
- "Log3".encode("utf-8")
- ).decode("utf-8")
+ assert job_submission_logs.logs[0].message == "Log3"
assert job_submission_logs.next_token == "3"
# Get next page
@@ -320,9 +273,7 @@ async def test_poll_logs_with_time_filtering_and_pagination(
# Should get Log4
assert len(job_submission_logs.logs) == 1
- assert job_submission_logs.logs[0].message == base64.b64encode(
- "Log4".encode("utf-8")
- ).decode("utf-8")
+ assert job_submission_logs.logs[0].message == "Log4"
# Should not have next_token since we reached end of file
assert job_submission_logs.next_token is None
@@ -389,9 +340,9 @@ async def test_next_token_pagination_complete_workflow(
page1 = log_storage.poll_logs(project, poll_request)
assert len(page1.logs) == 3
- assert page1.logs[0].message == base64.b64encode("Log1".encode()).decode()
- assert page1.logs[1].message == base64.b64encode("Log2".encode()).decode()
- assert page1.logs[2].message == base64.b64encode("Log3".encode()).decode()
+ assert page1.logs[0].message == "Log1"
+ assert page1.logs[1].message == "Log2"
+ assert page1.logs[2].message == "Log3"
assert page1.next_token == "3" # Next line to read
# Second page: use next_token
@@ -399,9 +350,9 @@ async def test_next_token_pagination_complete_workflow(
page2 = log_storage.poll_logs(project, poll_request)
assert len(page2.logs) == 3
- assert page2.logs[0].message == base64.b64encode("Log4".encode()).decode()
- assert page2.logs[1].message == base64.b64encode("Log5".encode()).decode()
- assert page2.logs[2].message == base64.b64encode("Log6".encode()).decode()
+ assert page2.logs[0].message == "Log4"
+ assert page2.logs[1].message == "Log5"
+ assert page2.logs[2].message == "Log6"
assert page2.next_token == "6"
# Third page: get more logs
@@ -409,9 +360,9 @@ async def test_next_token_pagination_complete_workflow(
page3 = log_storage.poll_logs(project, poll_request)
assert len(page3.logs) == 3
- assert page3.logs[0].message == base64.b64encode("Log7".encode()).decode()
- assert page3.logs[1].message == base64.b64encode("Log8".encode()).decode()
- assert page3.logs[2].message == base64.b64encode("Log9".encode()).decode()
+ assert page3.logs[0].message == "Log7"
+ assert page3.logs[1].message == "Log8"
+ assert page3.logs[2].message == "Log9"
assert page3.next_token == "9"
# Fourth page: get last log
@@ -419,8 +370,8 @@ async def test_next_token_pagination_complete_workflow(
page4 = log_storage.poll_logs(project, poll_request)
assert len(page4.logs) == 1
- assert page4.logs[0].message == base64.b64encode("Log10".encode()).decode()
- assert page4.next_token is None # No more logs
+ assert page4.logs[0].message == "Log10"
+ assert page4.next_token is None
@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
@@ -458,15 +409,15 @@ async def test_next_token_with_time_filtering(
page1 = log_storage.poll_logs(project, poll_request)
assert len(page1.logs) == 2
- assert page1.logs[0].message == base64.b64encode("Log3".encode()).decode()
- assert page1.logs[1].message == base64.b64encode("Log4".encode()).decode()
+ assert page1.logs[0].message == "Log3"
+ assert page1.logs[1].message == "Log4"
assert page1.next_token == "4"
# Get next page
poll_request.next_token = page1.next_token
page2 = log_storage.poll_logs(project, poll_request)
assert len(page2.logs) == 1
- assert page2.logs[0].message == base64.b64encode("Log5".encode()).decode()
+ assert page2.logs[0].message == "Log5"
assert page2.next_token is None
@pytest.mark.asyncio
@@ -497,16 +448,16 @@ async def test_next_token_edge_cases(self, test_db, session: AsyncSession, tmp_p
result = log_storage.poll_logs(project, poll_request)
assert len(result.logs) == 1
- assert result.logs[0].message == base64.b64encode("OnlyLog".encode()).decode()
- assert result.next_token is None # No more logs available
+ assert result.logs[0].message == "OnlyLog"
+ assert result.next_token is None
# Request with limit equal to available logs
poll_request.limit = 1
result = log_storage.poll_logs(project, poll_request)
assert len(result.logs) == 1
- assert result.logs[0].message == base64.b64encode("OnlyLog".encode()).decode()
- assert result.next_token is None # No more logs available
+ assert result.logs[0].message == "OnlyLog"
+ assert result.next_token is None
@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
@@ -607,15 +558,9 @@ async def test_poll_logs_with_limit(self, test_db, session: AsyncSession, tmp_pa
# Should return only the first 3 logs and provide next_token
assert len(job_submission_logs.logs) == 3
- assert job_submission_logs.logs[0].message == base64.b64encode(
- "Log1".encode("utf-8")
- ).decode("utf-8")
- assert job_submission_logs.logs[1].message == base64.b64encode(
- "Log2".encode("utf-8")
- ).decode("utf-8")
- assert job_submission_logs.logs[2].message == base64.b64encode(
- "Log3".encode("utf-8")
- ).decode("utf-8")
+ assert job_submission_logs.logs[0].message == "Log1"
+ assert job_submission_logs.logs[1].message == "Log2"
+ assert job_submission_logs.logs[2].message == "Log3"
# Should have next_token pointing to line 3 (fourth log)
assert job_submission_logs.next_token == "3"
@@ -624,9 +569,7 @@ async def test_poll_logs_with_limit(self, test_db, session: AsyncSession, tmp_pa
poll_request.start_time = logs[3].timestamp
job_submission_logs = log_storage.poll_logs(project, poll_request)
assert len(job_submission_logs.logs) == 1
- assert job_submission_logs.logs[0].message == base64.b64encode(
- "Log5".encode("utf-8")
- ).decode("utf-8")
+ assert job_submission_logs.logs[0].message == "Log5"
# Should not have next_token since we reached end of file
assert job_submission_logs.next_token is None
@@ -951,14 +894,14 @@ async def test_write_logs(
logGroupName="test-group",
logStreamName=expected_runner_stream,
logEvents=[
- {"timestamp": 1696586513234, "message": "SGVsbG8="},
+ {"timestamp": 1696586513234, "message": "Hello"},
],
),
call(
logGroupName="test-group",
logStreamName=expected_job_stream,
logEvents=[
- {"timestamp": 1696586513235, "message": "V29ybGQ="},
+ {"timestamp": 1696586513235, "message": "World"},
],
),
]
@@ -1056,11 +999,11 @@ async def test_write_logs_not_in_chronological_order(
logGroupName="test-group",
logStreamName="test-proj/test-run/1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e/runner",
logEvents=[
- {"timestamp": 1696586513235, "message": "MQ=="},
- {"timestamp": 1696586513236, "message": "Mg=="},
- {"timestamp": 1696586513237, "message": "Mw=="},
- {"timestamp": 1696586513237, "message": "NA=="},
- {"timestamp": 1696586513237, "message": "NQ=="},
+ {"timestamp": 1696586513235, "message": "1"},
+ {"timestamp": 1696586513236, "message": "2"},
+ {"timestamp": 1696586513237, "message": "3"},
+ {"timestamp": 1696586513237, "message": "4"},
+ {"timestamp": 1696586513237, "message": "5"},
],
)
assert "events are not in chronological order" in caplog.text
@@ -1098,7 +1041,7 @@ def _delta_ms(**kwargs: int) -> int:
assert "skipping 1 past event(s)" in caplog.text
assert "skipping 2 future event(s)" in caplog.text
actual = [
- base64.b64decode(e["message"]).decode()
+ e["message"]
for c in mock_client.put_log_events.call_args_list
for e in c.kwargs["logEvents"]
]
@@ -1143,8 +1086,8 @@ async def test_write_logs_batching_by_size(
messages: List[str],
expected: List[List[str]],
):
- # maximum 6 bytes: 12 (in base64) + 26 (overhead) = 34
- monkeypatch.setattr(CloudWatchLogStorage, "MESSAGE_MAX_SIZE", 34)
+ # maximum 6 bytes: 6 (raw bytes) + 26 (overhead) = 32
+ monkeypatch.setattr(CloudWatchLogStorage, "MESSAGE_MAX_SIZE", 32)
monkeypatch.setattr(CloudWatchLogStorage, "BATCH_MAX_SIZE", 60)
log_storage.write_logs(
project=project,
@@ -1158,7 +1101,7 @@ async def test_write_logs_batching_by_size(
)
assert mock_client.put_log_events.call_count == len(expected)
actual = [
- [base64.b64decode(e["message"]).decode() for e in c.kwargs["logEvents"]]
+ [e["message"] for e in c.kwargs["logEvents"]]
for c in mock_client.put_log_events.call_args_list
]
assert actual == expected
@@ -1173,7 +1116,7 @@ async def test_write_logs_batching_by_size(
[["111", "111", "111"], ["222"]],
],
[
- ["111", "111", "111"] + ["222", "222", "toolong", "", "222222"],
+ ["111", "111", "111"] + ["222", "222", "toolongtoolong", "", "222222"],
[["111", "111", "111"], ["222", "222", "222222"]],
],
],
@@ -1190,8 +1133,8 @@ async def test_write_logs_batching_by_count(
messages: List[str],
expected: List[List[str]],
):
- # maximum 6 bytes: 12 (in base64) + 26 (overhead) = 34
- monkeypatch.setattr(CloudWatchLogStorage, "MESSAGE_MAX_SIZE", 34)
+ # maximum 6 bytes: 6 (raw bytes) + 26 (overhead) = 32
+ monkeypatch.setattr(CloudWatchLogStorage, "MESSAGE_MAX_SIZE", 32)
monkeypatch.setattr(CloudWatchLogStorage, "EVENT_MAX_COUNT_IN_BATCH", 3)
log_storage.write_logs(
project=project,
@@ -1205,7 +1148,7 @@ async def test_write_logs_batching_by_count(
)
assert mock_client.put_log_events.call_count == len(expected)
actual = [
- [base64.b64decode(e["message"]).decode() for e in c.kwargs["logEvents"]]
+ [e["message"] for e in c.kwargs["logEvents"]]
for c in mock_client.put_log_events.call_args_list
]
assert actual == expected
@@ -1248,7 +1191,7 @@ def _delta_ms(**kwargs: int) -> int:
expected = [["1", "2", "3"], ["4", "5", "6"], ["7"]]
assert mock_client.put_log_events.call_count == len(expected)
actual = [
- [base64.b64decode(e["message"]).decode() for e in c.kwargs["logEvents"]]
+ [e["message"] for e in c.kwargs["logEvents"]]
for c in mock_client.put_log_events.call_args_list
]
assert actual == expected
@@ -1262,8 +1205,8 @@ async def test_poll_logs_non_empty_response(
poll_logs_request: PollLogsRequest,
):
mock_client.get_log_events.return_value["events"] = [
- {"timestamp": 1696586513234, "message": "SGVsbG8="},
- {"timestamp": 1696586513235, "message": "V29ybGQ="},
+ {"timestamp": 1696586513234, "message": "Hello"},
+ {"timestamp": 1696586513235, "message": "World"},
]
poll_logs_request.limit = 2
job_submission_logs = log_storage.poll_logs(project, poll_logs_request)
@@ -1272,12 +1215,12 @@ async def test_poll_logs_non_empty_response(
LogEvent(
timestamp=datetime(2023, 10, 6, 10, 1, 53, 234000, tzinfo=timezone.utc),
log_source=LogEventSource.STDOUT,
- message="SGVsbG8=",
+ message="Hello",
),
LogEvent(
timestamp=datetime(2023, 10, 6, 10, 1, 53, 235000, tzinfo=timezone.utc),
log_source=LogEventSource.STDOUT,
- message="V29ybGQ=",
+ message="World",
),
]
@@ -1290,8 +1233,8 @@ async def test_poll_logs_descending_non_empty_response_on_first_call(
poll_logs_request: PollLogsRequest,
):
mock_client.get_log_events.return_value["events"] = [
- {"timestamp": 1696586513234, "message": "SGVsbG8="},
- {"timestamp": 1696586513235, "message": "V29ybGQ="},
+ {"timestamp": 1696586513234, "message": "Hello"},
+ {"timestamp": 1696586513235, "message": "World"},
]
poll_logs_request.descending = True
poll_logs_request.limit = 2
@@ -1301,12 +1244,12 @@ async def test_poll_logs_descending_non_empty_response_on_first_call(
LogEvent(
timestamp=datetime(2023, 10, 6, 10, 1, 53, 235000, tzinfo=timezone.utc),
log_source=LogEventSource.STDOUT,
- message="V29ybGQ=",
+ message="World",
),
LogEvent(
timestamp=datetime(2023, 10, 6, 10, 1, 53, 234000, tzinfo=timezone.utc),
log_source=LogEventSource.STDOUT,
- message="SGVsbG8=",
+ message="Hello",
),
]
@@ -1322,8 +1265,8 @@ async def test_next_token_ascending_pagination(
# Setup response with nextForwardToken
mock_client.get_log_events.return_value = {
"events": [
- {"timestamp": 1696586513234, "message": "SGVsbG8="},
- {"timestamp": 1696586513235, "message": "V29ybGQ="},
+ {"timestamp": 1696586513234, "message": "Hello"},
+ {"timestamp": 1696586513235, "message": "World"},
],
"nextBackwardToken": "bwd",
"nextForwardToken": "fwd123",
@@ -1357,8 +1300,8 @@ async def test_next_token_descending_pagination(
# Setup response with nextBackwardToken
mock_client.get_log_events.return_value = {
"events": [
- {"timestamp": 1696586513234, "message": "SGVsbG8="},
- {"timestamp": 1696586513235, "message": "V29ybGQ="},
+ {"timestamp": 1696586513234, "message": "Hello"},
+ {"timestamp": 1696586513235, "message": "World"},
],
"nextBackwardToken": "bwd456",
"nextForwardToken": "fwd",
@@ -1370,8 +1313,8 @@ async def test_next_token_descending_pagination(
assert len(result.logs) == 2
# Events should be reversed for descending order
- assert result.logs[0].message == "V29ybGQ="
- assert result.logs[1].message == "SGVsbG8="
+ assert result.logs[0].message == "World"
+ assert result.logs[1].message == "Hello"
assert result.next_token == "bwd456" # Should return nextBackwardToken
# Verify API was called with correct parameters
@@ -1393,7 +1336,7 @@ async def test_next_token_provided_in_request(
"""Test that provided next_token is passed to CloudWatch API"""
mock_client.get_log_events.return_value = {
"events": [
- {"timestamp": 1696586513234, "message": "SGVsbG8="},
+ {"timestamp": 1696586513234, "message": "Hello"},
],
"nextBackwardToken": "bwd",
"nextForwardToken": "new_fwd",
@@ -1449,7 +1392,7 @@ async def test_next_token_with_time_filtering(
"""Test next_token behavior with time filtering"""
mock_client.get_log_events.return_value = {
"events": [
- {"timestamp": 1696586513234, "message": "SGVsbG8="},
+ {"timestamp": 1696586513234, "message": "Hello"},
],
"nextBackwardToken": "bwd_with_time",
"nextForwardToken": "fwd_with_time",
@@ -1487,7 +1430,7 @@ async def test_next_token_missing_in_cloudwatch_response(
"""Test behavior when CloudWatch doesn't return next tokens"""
mock_client.get_log_events.return_value = {
"events": [
- {"timestamp": 1696586513234, "message": "SGVsbG8="},
+ {"timestamp": 1696586513234, "message": "Hello"},
],
# No nextBackwardToken or nextForwardToken in response
}
@@ -1509,7 +1452,7 @@ async def test_next_token_empty_string_in_cloudwatch_response(
"""Test behavior when CloudWatch returns empty string tokens"""
mock_client.get_log_events.return_value = {
"events": [
- {"timestamp": 1696586513234, "message": "SGVsbG8="},
+ {"timestamp": 1696586513234, "message": "Hello"},
],
"nextBackwardToken": "",
"nextForwardToken": "",
@@ -1534,8 +1477,8 @@ async def test_next_token_pagination_workflow(
mock_client.get_log_events.side_effect = [
{
"events": [
- {"timestamp": 1696586513234, "message": "SGVsbG8="},
- {"timestamp": 1696586513235, "message": "V29ybGQ="},
+ {"timestamp": 1696586513234, "message": "Hello"},
+ {"timestamp": 1696586513235, "message": "World"},
],
"nextBackwardToken": "bwd",
"nextForwardToken": "token_page2",
@@ -1543,7 +1486,7 @@ async def test_next_token_pagination_workflow(
# Second call - returns final logs without next_token
{
"events": [
- {"timestamp": 1696586513236, "message": "IQ=="},
+ {"timestamp": 1696586513236, "message": "!"},
],
"nextBackwardToken": "final_bwd",
"nextForwardToken": "final_fwd",
@@ -1556,8 +1499,8 @@ async def test_next_token_pagination_workflow(
page1 = log_storage.poll_logs(project, poll_logs_request)
assert len(page1.logs) == 2
- assert page1.logs[0].message == "SGVsbG8="
- assert page1.logs[1].message == "V29ybGQ="
+ assert page1.logs[0].message == "Hello"
+ assert page1.logs[1].message == "World"
assert page1.next_token == "token_page2"
# Second page using next_token
@@ -1565,7 +1508,7 @@ async def test_next_token_pagination_workflow(
page2 = log_storage.poll_logs(project, poll_logs_request)
assert len(page2.logs) == 1
- assert page2.logs[0].message == "IQ=="
+ assert page2.logs[0].message == "!"
assert page2.next_token == "final_fwd"
# Verify both API calls
diff --git a/src/tests/_internal/server/services/test_volumes.py b/src/tests/_internal/server/services/test_volumes.py
index c2ba555a1..4de9c3f05 100644
--- a/src/tests/_internal/server/services/test_volumes.py
+++ b/src/tests/_internal/server/services/test_volumes.py
@@ -3,11 +3,78 @@
import pytest
from freezegun import freeze_time
-from dstack._internal.core.models.volumes import VolumeStatus
-from dstack._internal.server.services.volumes import _get_volume_cost
+from dstack._internal.core.errors import ServerClientError
+from dstack._internal.core.models.backends.base import BackendType
+from dstack._internal.core.models.volumes import VolumeConfiguration, VolumeStatus
+from dstack._internal.server.services.volumes import (
+ _get_volume_cost,
+ _validate_volume_configuration,
+)
from dstack._internal.server.testing.common import get_volume, get_volume_provisioning_data
+class TestValidateVolumeConfiguration:
+ def test_external_volume_with_auto_cleanup_duration_raises_error(self):
+ """External volumes (with volume_id) should not allow auto_cleanup_duration"""
+ config = VolumeConfiguration(
+ backend=BackendType.AWS,
+ region="us-east-1",
+ volume_id="vol-123456",
+ auto_cleanup_duration="1h",
+ )
+ with pytest.raises(
+ ServerClientError, match="External volumes.*do not support auto_cleanup_duration"
+ ):
+ _validate_volume_configuration(config)
+
+ def test_external_volume_with_auto_cleanup_duration_int_raises_error(self):
+ """External volumes with integer auto_cleanup_duration should also raise error"""
+ config = VolumeConfiguration(
+ backend=BackendType.AWS,
+ region="us-east-1",
+ volume_id="vol-123456",
+ auto_cleanup_duration=3600,
+ )
+ with pytest.raises(
+ ServerClientError, match="External volumes.*do not support auto_cleanup_duration"
+ ):
+ _validate_volume_configuration(config)
+
+ def test_external_volume_with_auto_cleanup_disabled_succeeds(self):
+ """External volumes with auto_cleanup_duration='off' or -1 should be allowed"""
+ config1 = VolumeConfiguration(
+ backend=BackendType.AWS,
+ region="us-east-1",
+ volume_id="vol-123456",
+ auto_cleanup_duration="off",
+ )
+ config2 = VolumeConfiguration(
+ backend=BackendType.AWS,
+ region="us-east-1",
+ volume_id="vol-123456",
+ auto_cleanup_duration=-1,
+ )
+ # Should not raise any errors
+ _validate_volume_configuration(config1)
+ _validate_volume_configuration(config2)
+
+ def test_external_volume_without_auto_cleanup_succeeds(self):
+ """External volumes without auto_cleanup_duration should be allowed"""
+ config = VolumeConfiguration(
+ backend=BackendType.AWS, region="us-east-1", volume_id="vol-123456"
+ )
+ # Should not raise any errors
+ _validate_volume_configuration(config)
+
+ def test_new_volume_with_auto_cleanup_duration_succeeds(self):
+ """New volumes (without volume_id) with auto_cleanup_duration should be allowed"""
+ config = VolumeConfiguration(
+ backend=BackendType.AWS, region="us-east-1", size=100, auto_cleanup_duration="1h"
+ )
+ # Should not raise any errors
+ _validate_volume_configuration(config)
+
+
class TestGetVolumeCost:
def test_returns_0_when_no_provisioning_data(self):
volume = get_volume(provisioning_data=None)